Shaoan committed on
Commit
490f253
·
verified ·
1 Parent(s): 7ce8df1

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +28 -23
app.py CHANGED
@@ -82,7 +82,7 @@ class ConceptAlignerModel:
82
  self.text_encoder.load_state_dict(adapter_state, strict=True)
83
  print(" βœ“ T5 encoder loaded")
84
 
85
- # Only download VAE (small ~330MB) - not fine-tuned
86
  print(" Loading VAE from FLUX.1-dev...")
87
  vae = AutoencoderKL.from_pretrained(
88
  'black-forest-labs/FLUX.1-dev',
@@ -90,25 +90,23 @@ class ConceptAlignerModel:
90
  torch_dtype=self.dtype,
91
  token=HF_TOKEN
92
  ).to(self.device)
93
- print(" βœ“ VAE loaded (~330MB download)")
94
 
95
- # Create transformer architecture WITHOUT downloading base weights
96
- print(" Initializing transformer architecture...")
97
-
98
- # Get config only (no weights download)
99
- from diffusers.models.transformers.transformer_flux import FluxTransformerConfig
100
  config = FluxTransformer2DModel.load_config(
101
  'black-forest-labs/FLUX.1-dev',
102
  subfolder="transformer",
103
  token=HF_TOKEN
104
  )
105
 
106
- # Initialize empty transformer from config
107
- transformer = FluxTransformer2DModel.from_config(config)
108
- print(" βœ“ Transformer architecture initialized")
 
109
 
110
- # Add LoRA config (needed for architecture)
111
- print(" Adding LoRA adapter config...")
112
  transformer_lora_config = LoraConfig(
113
  r=256, lora_alpha=256, lora_dropout=0.0, init_lora_weights="gaussian",
114
  target_modules=[
@@ -120,15 +118,25 @@ class ConceptAlignerModel:
120
  )
121
  transformer.add_adapter(transformer_lora_config)
122
  transformer.context_embedder.requires_grad_(True)
 
123
 
124
- # Load YOUR FULL fine-tuned transformer weights (no base model needed!)
125
- print(" Loading YOUR fine-tuned transformer weights...")
126
  transformer_state = load_file(os.path.join(self.checkpoint_path, "model.safetensors"))
127
- transformer.load_state_dict(transformer_state, strict=True)
 
 
 
 
 
 
 
 
128
  transformer = transformer.to(self.device).to(self.dtype)
129
- print(" βœ“ Fine-tuned transformer loaded (~26GB from your checkpoint)")
130
 
131
  # Load empty pooled clip
 
132
  self.empty_pooled_clip = torch.load(
133
  os.path.join(self.checkpoint_path, "empty_pooled_clip.pt"),
134
  map_location=self.device,
@@ -136,8 +144,8 @@ class ConceptAlignerModel:
136
  ).to(self.dtype)
137
  print(" βœ“ Empty pooled clip loaded")
138
 
139
- # Create scheduler (just config, no weights)
140
- print(" Loading scheduler config...")
141
  noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
142
  'black-forest-labs/FLUX.1-dev',
143
  subfolder="scheduler",
@@ -146,7 +154,7 @@ class ConceptAlignerModel:
146
  print(" βœ“ Scheduler loaded")
147
 
148
  # Create pipeline
149
- print(" Creating pipeline...")
150
  self.pipe = CustomFluxKontextPipeline(
151
  scheduler=noise_scheduler,
152
  aligner=self.model,
@@ -156,10 +164,8 @@ class ConceptAlignerModel:
156
  ).to(self.device)
157
 
158
  print("="*60)
159
- print("βœ“ ALL MODELS LOADED SUCCESSFULLY!")
160
  print("="*60)
161
- print(f"Total downloads: ~330MB VAE + ~26GB your checkpoints")
162
- print(f"Saved: ~24GB by not downloading base FLUX transformer!")
163
 
164
  # Print memory usage
165
  if torch.cuda.is_available():
@@ -211,7 +217,6 @@ with gr.Blocks(title="ConceptAligner", theme=gr.themes.Soft()) as demo:
211
  # 🎨 ConceptAligner Demo
212
 
213
  Generate images with fine-tuned concept alignment using FLUX!
214
- This demo uses fully fine-tuned weights - no base model downloads needed.
215
  """)
216
 
217
  with gr.Row():
 
82
  self.text_encoder.load_state_dict(adapter_state, strict=True)
83
  print(" βœ“ T5 encoder loaded")
84
 
85
+ # Only download VAE (small ~168MB)
86
  print(" Loading VAE from FLUX.1-dev...")
87
  vae = AutoencoderKL.from_pretrained(
88
  'black-forest-labs/FLUX.1-dev',
 
90
  torch_dtype=self.dtype,
91
  token=HF_TOKEN
92
  ).to(self.device)
93
+ print(" βœ“ VAE loaded")
94
 
95
+ # Create transformer from config only (download config.json but not weights)
96
+ print(" Downloading transformer config only...")
 
 
 
97
  config = FluxTransformer2DModel.load_config(
98
  'black-forest-labs/FLUX.1-dev',
99
  subfolder="transformer",
100
  token=HF_TOKEN
101
  )
102
 
103
+ # Initialize transformer from config (no weights)
104
+ print(" Initializing transformer architecture from config...")
105
+ transformer = FluxTransformer2DModel.from_config(config, torch_dtype=self.dtype)
106
+ print(" βœ“ Empty transformer initialized")
107
 
108
+ # Add LoRA adapter config
109
+ print(" Adding LoRA adapter layers...")
110
  transformer_lora_config = LoraConfig(
111
  r=256, lora_alpha=256, lora_dropout=0.0, init_lora_weights="gaussian",
112
  target_modules=[
 
118
  )
119
  transformer.add_adapter(transformer_lora_config)
120
  transformer.context_embedder.requires_grad_(True)
121
+ print(" βœ“ LoRA adapters added")
122
 
123
+ # Load YOUR FULL fine-tuned transformer weights
124
+ print(" Loading your fine-tuned transformer weights...")
125
  transformer_state = load_file(os.path.join(self.checkpoint_path, "model.safetensors"))
126
+
127
+ # Load with strict=False in case of minor key mismatches
128
+ missing_keys, unexpected_keys = transformer.load_state_dict(transformer_state, strict=False)
129
+
130
+ if missing_keys:
131
+ print(f" ⚠️ Missing keys: {len(missing_keys)}")
132
+ if unexpected_keys:
133
+ print(f" ⚠️ Unexpected keys: {len(unexpected_keys)}")
134
+
135
  transformer = transformer.to(self.device).to(self.dtype)
136
+ print(" βœ“ Fine-tuned transformer loaded")
137
 
138
  # Load empty pooled clip
139
+ print(" Loading empty pooled clip...")
140
  self.empty_pooled_clip = torch.load(
141
  os.path.join(self.checkpoint_path, "empty_pooled_clip.pt"),
142
  map_location=self.device,
 
144
  ).to(self.dtype)
145
  print(" βœ“ Empty pooled clip loaded")
146
 
147
+ # Create scheduler (just config)
148
+ print(" Loading scheduler...")
149
  noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
150
  'black-forest-labs/FLUX.1-dev',
151
  subfolder="scheduler",
 
154
  print(" βœ“ Scheduler loaded")
155
 
156
  # Create pipeline
157
+ print(" Assembling pipeline...")
158
  self.pipe = CustomFluxKontextPipeline(
159
  scheduler=noise_scheduler,
160
  aligner=self.model,
 
164
  ).to(self.device)
165
 
166
  print("="*60)
167
+ print("βœ… ALL MODELS LOADED SUCCESSFULLY!")
168
  print("="*60)
 
 
169
 
170
  # Print memory usage
171
  if torch.cuda.is_available():
 
217
  # 🎨 ConceptAligner Demo
218
 
219
  Generate images with fine-tuned concept alignment using FLUX!
 
220
  """)
221
 
222
  with gr.Row():