Shaoan committed on
Commit
6af382a
·
verified ·
1 Parent(s): 58dc7cb

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. app.py +33 -53
  2. requirements.txt +2 -1
app.py CHANGED
@@ -1,6 +1,5 @@
1
  """
2
- ConceptAligner Hugging Face Demo - Minimal downloads
3
- Only downloads VAE, uses your fine-tuned weights for everything else
4
  """
5
 
6
  import torch
@@ -14,6 +13,14 @@ from pipeline import CustomFluxKontextPipeline
14
  from diffusers import FluxTransformer2DModel, FlowMatchEulerDiscreteScheduler, AutoencoderKL
15
  from peft import LoraConfig
16
 
 
 
 
 
 
 
 
 
17
  # Login with token from environment
18
  HF_TOKEN = os.environ.get("HF_TOKEN")
19
  if HF_TOKEN:
@@ -73,7 +80,7 @@ class ConceptAlignerModel:
73
  self.model.load_state_dict(adapter_state, strict=True)
74
  print(" ✓ ConceptAligner loaded")
75
 
76
- # Load T5 encoder (your fine-tuned version with full weights)
77
  print(" Loading fine-tuned T5 encoder...")
78
  self.text_encoder = LoraT5Embedder(device=self.device).to(self.dtype)
79
  adapter_state = load_file(os.path.join(self.checkpoint_path, "model_2.safetensors"))
@@ -82,7 +89,7 @@ class ConceptAlignerModel:
82
  self.text_encoder.load_state_dict(adapter_state, strict=True)
83
  print(" ✓ T5 encoder loaded")
84
 
85
- # Only download VAE (small ~168MB)
86
  print(" Loading VAE from FLUX.1-dev...")
87
  vae = AutoencoderKL.from_pretrained(
88
  'black-forest-labs/FLUX.1-dev',
@@ -92,21 +99,18 @@ class ConceptAlignerModel:
92
  ).to(self.device)
93
  print(" ✓ VAE loaded")
94
 
95
- # Create transformer from config only (download config.json but not weights)
96
- print(" Downloading transformer config only...")
97
  config = FluxTransformer2DModel.load_config(
98
  'black-forest-labs/FLUX.1-dev',
99
  subfolder="transformer",
100
  token=HF_TOKEN
101
  )
102
 
103
- # Initialize transformer from config (no weights)
104
- print(" Initializing transformer architecture from config...")
105
  transformer = FluxTransformer2DModel.from_config(config, torch_dtype=self.dtype)
106
- print(" ✓ Empty transformer initialized")
107
 
108
- # Add LoRA adapter config
109
- print(" Adding LoRA adapter layers...")
110
  transformer_lora_config = LoraConfig(
111
  r=256, lora_alpha=256, lora_dropout=0.0, init_lora_weights="gaussian",
112
  target_modules=[
@@ -118,43 +122,28 @@ class ConceptAlignerModel:
118
  )
119
  transformer.add_adapter(transformer_lora_config)
120
  transformer.context_embedder.requires_grad_(True)
121
- print(" ✓ LoRA adapters added")
122
 
123
- # Load YOUR FULL fine-tuned transformer weights
124
- print(" Loading your fine-tuned transformer weights...")
125
  transformer_state = load_file(os.path.join(self.checkpoint_path, "model.safetensors"))
126
-
127
- # Load with strict=False in case of minor key mismatches
128
- missing_keys, unexpected_keys = transformer.load_state_dict(transformer_state, strict=False)
129
-
130
- if missing_keys:
131
- print(f" ⚠️ Missing keys: {len(missing_keys)}")
132
- if unexpected_keys:
133
- print(f" ⚠️ Unexpected keys: {len(unexpected_keys)}")
134
-
135
  transformer = transformer.to(self.device).to(self.dtype)
136
- print(" ✓ Fine-tuned transformer loaded")
137
 
138
  # Load empty pooled clip
139
- print(" Loading empty pooled clip...")
140
  self.empty_pooled_clip = torch.load(
141
  os.path.join(self.checkpoint_path, "empty_pooled_clip.pt"),
142
  map_location=self.device,
143
  weights_only=True
144
  ).to(self.dtype)
145
- print(" ✓ Empty pooled clip loaded")
146
 
147
- # Create scheduler (just config)
148
- print(" Loading scheduler...")
149
  noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
150
  'black-forest-labs/FLUX.1-dev',
151
  subfolder="scheduler",
152
  token=HF_TOKEN
153
  )
154
- print(" ✓ Scheduler loaded")
155
 
156
  # Create pipeline
157
- print(" Assembling pipeline...")
158
  self.pipe = CustomFluxKontextPipeline(
159
  scheduler=noise_scheduler,
160
  aligner=self.model,
@@ -163,15 +152,11 @@ class ConceptAlignerModel:
163
  text_embedder=self.text_encoder,
164
  ).to(self.device)
165
 
166
- print("="*60)
167
- print("✅ ALL MODELS LOADED SUCCESSFULLY!")
168
- print("="*60)
169
 
170
- # Print memory usage
171
  if torch.cuda.is_available():
172
  allocated = torch.cuda.memory_allocated(0) / 1024**3
173
- reserved = torch.cuda.memory_reserved(0) / 1024**3
174
- print(f"📊 GPU Memory: {allocated:.2f}GB allocated, {reserved:.2f}GB reserved")
175
 
176
  @torch.no_grad()
177
  def generate_image(self, prompt, threshold=0.0, topk=0, height=512, width=512,
@@ -196,7 +181,7 @@ class ConceptAlignerModel:
196
  return prev_image, current_image, prev_prompt
197
  except Exception as e:
198
  import traceback
199
- print(f"❌ Generation error: {e}")
200
  print(traceback.format_exc())
201
  return self.previous_image, None, self.previous_prompt or ""
202
 
@@ -206,18 +191,18 @@ class ConceptAlignerModel:
206
  return None, None, "No previous generation"
207
 
208
  # Initialize model
209
- print("="*60)
210
- print("🚀 Initializing ConceptAligner Demo")
211
- print("="*60)
212
  model = ConceptAlignerModel()
213
 
214
- # Create Gradio interface (without theme in constructor for Gradio 6.0)
 
 
 
 
 
 
215
  with gr.Blocks(title="ConceptAligner") as demo:
216
- gr.Markdown("""
217
- # 🎨 ConceptAligner Demo
218
-
219
- Generate images with fine-tuned concept alignment using FLUX!
220
- """)
221
 
222
  with gr.Row():
223
  with gr.Column(scale=1):
@@ -255,7 +240,7 @@ with gr.Blocks(title="ConceptAligner") as demo:
255
  gr.Examples(examples=EXAMPLE_PROMPTS, inputs=prompt_input)
256
 
257
  generate_btn.click(
258
- fn=model.generate_image,
259
  inputs=[prompt_input, threshold, topk, height, width, guidance_scale, true_cfg_scale, num_steps, seed],
260
  outputs=[prev_image, current_image, prev_prompt_display]
261
  )
@@ -263,9 +248,4 @@ with gr.Blocks(title="ConceptAligner") as demo:
263
  reset_btn.click(fn=model.reset_history, outputs=[prev_image, current_image, prev_prompt_display])
264
 
265
  if __name__ == "__main__":
266
- # Launch with proper configuration for Hugging Face Spaces
267
- demo.launch(
268
- server_name="0.0.0.0",
269
- server_port=7860,
270
- show_error=True
271
- )
 
1
  """
2
+ ConceptAligner Hugging Face Demo
 
3
  """
4
 
5
  import torch
 
13
  from diffusers import FluxTransformer2DModel, FlowMatchEulerDiscreteScheduler, AutoencoderKL
14
  from peft import LoraConfig
15
 
16
+ # For HF Spaces GPU support
17
+ try:
18
+ import spaces
19
+ GPU_AVAILABLE = True
20
+ except ImportError:
21
+ GPU_AVAILABLE = False
22
+ print("⚠️ spaces package not available, running without @spaces.GPU decorator")
23
+
24
  # Login with token from environment
25
  HF_TOKEN = os.environ.get("HF_TOKEN")
26
  if HF_TOKEN:
 
80
  self.model.load_state_dict(adapter_state, strict=True)
81
  print(" ✓ ConceptAligner loaded")
82
 
83
+ # Load T5 encoder
84
  print(" Loading fine-tuned T5 encoder...")
85
  self.text_encoder = LoraT5Embedder(device=self.device).to(self.dtype)
86
  adapter_state = load_file(os.path.join(self.checkpoint_path, "model_2.safetensors"))
 
89
  self.text_encoder.load_state_dict(adapter_state, strict=True)
90
  print(" ✓ T5 encoder loaded")
91
 
92
+ # Download VAE
93
  print(" Loading VAE from FLUX.1-dev...")
94
  vae = AutoencoderKL.from_pretrained(
95
  'black-forest-labs/FLUX.1-dev',
 
99
  ).to(self.device)
100
  print(" ✓ VAE loaded")
101
 
102
+ # Create transformer from config
103
+ print(" Downloading transformer config...")
104
  config = FluxTransformer2DModel.load_config(
105
  'black-forest-labs/FLUX.1-dev',
106
  subfolder="transformer",
107
  token=HF_TOKEN
108
  )
109
 
110
+ print(" Initializing transformer...")
 
111
  transformer = FluxTransformer2DModel.from_config(config, torch_dtype=self.dtype)
 
112
 
113
+ print(" Adding LoRA adapters...")
 
114
  transformer_lora_config = LoraConfig(
115
  r=256, lora_alpha=256, lora_dropout=0.0, init_lora_weights="gaussian",
116
  target_modules=[
 
122
  )
123
  transformer.add_adapter(transformer_lora_config)
124
  transformer.context_embedder.requires_grad_(True)
 
125
 
126
+ print(" Loading fine-tuned transformer weights...")
 
127
  transformer_state = load_file(os.path.join(self.checkpoint_path, "model.safetensors"))
128
+ transformer.load_state_dict(transformer_state, strict=False)
 
 
 
 
 
 
 
 
129
  transformer = transformer.to(self.device).to(self.dtype)
130
+ print(" ✓ Transformer loaded")
131
 
132
  # Load empty pooled clip
 
133
  self.empty_pooled_clip = torch.load(
134
  os.path.join(self.checkpoint_path, "empty_pooled_clip.pt"),
135
  map_location=self.device,
136
  weights_only=True
137
  ).to(self.dtype)
 
138
 
139
+ # Create scheduler
 
140
  noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
141
  'black-forest-labs/FLUX.1-dev',
142
  subfolder="scheduler",
143
  token=HF_TOKEN
144
  )
 
145
 
146
  # Create pipeline
 
147
  self.pipe = CustomFluxKontextPipeline(
148
  scheduler=noise_scheduler,
149
  aligner=self.model,
 
152
  text_embedder=self.text_encoder,
153
  ).to(self.device)
154
 
155
+ print("✅ ALL MODELS LOADED!")
 
 
156
 
 
157
  if torch.cuda.is_available():
158
  allocated = torch.cuda.memory_allocated(0) / 1024**3
159
+ print(f"📊 GPU Memory: {allocated:.2f}GB allocated")
 
160
 
161
  @torch.no_grad()
162
  def generate_image(self, prompt, threshold=0.0, topk=0, height=512, width=512,
 
181
  return prev_image, current_image, prev_prompt
182
  except Exception as e:
183
  import traceback
184
+ print(f"❌ Error: {e}")
185
  print(traceback.format_exc())
186
  return self.previous_image, None, self.previous_prompt or ""
187
 
 
191
  return None, None, "No previous generation"
192
 
193
  # Initialize model
194
+ print("🚀 Initializing ConceptAligner...")
 
 
195
  model = ConceptAlignerModel()
196
 
197
+ # Wrap generation function with @spaces.GPU if available
198
+ if GPU_AVAILABLE:
199
+ generate_fn = spaces.GPU(model.generate_image)
200
+ else:
201
+ generate_fn = model.generate_image
202
+
203
+ # Create Gradio interface
204
  with gr.Blocks(title="ConceptAligner") as demo:
205
+ gr.Markdown("# 🎨 ConceptAligner Demo\nGenerate images with fine-tuned concept alignment!")
 
 
 
 
206
 
207
  with gr.Row():
208
  with gr.Column(scale=1):
 
240
  gr.Examples(examples=EXAMPLE_PROMPTS, inputs=prompt_input)
241
 
242
  generate_btn.click(
243
+ fn=generate_fn,
244
  inputs=[prompt_input, threshold, topk, height, width, guidance_scale, true_cfg_scale, num_steps, seed],
245
  outputs=[prev_image, current_image, prev_prompt_display]
246
  )
 
248
  reset_btn.click(fn=model.reset_history, outputs=[prev_image, current_image, prev_prompt_display])
249
 
250
  if __name__ == "__main__":
251
+ demo.launch()
 
 
 
 
 
requirements.txt CHANGED
@@ -18,4 +18,5 @@ httpx==0.28.1
18
  requests==2.32.5
19
  numpy==1.26.4
20
  pydantic==2.11.9
21
- python-multipart==0.0.20
 
 
18
  requests==2.32.5
19
  numpy==1.26.4
20
  pydantic==2.11.9
21
+ python-multipart==0.0.20
22
+ spaces