Shaoan committed
Commit 3b895b3 · verified · 1 Parent(s): 9128255

Upload folder using huggingface_hub

Files changed (1): app.py +222 -157
app.py CHANGED
@@ -1,11 +1,11 @@
  """
- ConceptAligner Hugging Face Demo - ZeroGPU Compatible
+ ConceptAligner - Same GPU behavior as FLUX demo
+ Models loaded at startup, GPU allocated only for inference
  """

  # CRITICAL: Import spaces FIRST
  import spaces

- # Now import everything else
  import torch
  import gradio as gr
  import os
@@ -17,7 +17,7 @@ from pipeline import CustomFluxKontextPipeline
  from diffusers import FluxTransformer2DModel, FlowMatchEulerDiscreteScheduler, AutoencoderKL
  from peft import LoraConfig

- # Login with token from environment
+ # Login
  HF_TOKEN = os.environ.get("HF_TOKEN")
  if HF_TOKEN:
      login(token=HF_TOKEN)
@@ -26,15 +26,16 @@ if HF_TOKEN:
  # Configuration
  MODEL_REPO = "Shaoan/ConceptAligner-Weights"
  CHECKPOINT_DIR = "./checkpoint"
+ dtype = torch.bfloat16
+ device = "cuda" if torch.cuda.is_available() else "cpu"

  EXAMPLE_PROMPTS = [
      ["""In the image, a single white duck walks proudly across a cobblestone street. It wears a red ribbon around its neck, and the morning sun glints off puddles from a recent rain. In the background, a few people watch and smile, giving the scene a playful charm. The duck's confident stride and upright posture make it appear oddly dignified."""]
  ]

  def download_checkpoint():
-     """Download checkpoint files from HF model repo"""
+     """Download checkpoint files"""
      print("Downloading checkpoint files...")
-
      files = ["model.safetensors", "model_1.safetensors", "model_2.safetensors", "empty_pooled_clip.pt"]
      os.makedirs(CHECKPOINT_DIR, exist_ok=True)
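The body of the download loop sits between this hunk and the next and is unchanged, so the diff does not show it. From the visible pieces (`files`, `MODEL_REPO`, and the `local_dir=CHECKPOINT_DIR, token=HF_TOKEN` arguments in the next hunk), it presumably iterates over `files` and fetches each one with `huggingface_hub.hf_hub_download`. A minimal sketch, not the commit's actual code:

```python
from huggingface_hub import hf_hub_download

# Hypothetical reconstruction of the unchanged loop body elided by the diff
for filename in files:
    hf_hub_download(
        repo_id=MODEL_REPO,        # "Shaoan/ConceptAligner-Weights"
        filename=filename,
        local_dir=CHECKPOINT_DIR,  # these two arguments are visible in the next hunk
        token=HF_TOKEN,
    )
```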
 
@@ -48,118 +49,101 @@ def download_checkpoint():
              local_dir=CHECKPOINT_DIR,
              token=HF_TOKEN
          )
-         print(f" ✓ {filename} downloaded")
-
-     print("✓ All checkpoint files ready!")
-
- # Global model variable
- model_pipeline = None
-
- def load_models():
-     """Load all models - called once at startup"""
-     global model_pipeline
-
-     if model_pipeline is not None:
-         return model_pipeline
-
-     print("🚀 Loading models...")
-
-     checkpoint_path = CHECKPOINT_DIR
-     device = 'cuda' if torch.cuda.is_available() else 'cpu'
-     dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
-
-     # Load ConceptAligner
-     print("  Loading ConceptAligner...")
-     aligner_model = ConceptAligner().to(device).to(dtype)
-     adapter_state = load_file(os.path.join(checkpoint_path, "model_1.safetensors"))
-     aligner_model.load_state_dict(adapter_state, strict=True)
-
-     # Load T5 encoder
-     print("  Loading T5 encoder...")
-     text_encoder = LoraT5Embedder(device=device).to(dtype)
-     adapter_state = load_file(os.path.join(checkpoint_path, "model_2.safetensors"))
-     if "t5_encoder.shared.weight" in adapter_state:
-         adapter_state["t5_encoder.encoder.embed_tokens.weight"] = adapter_state["t5_encoder.shared.weight"]
-     text_encoder.load_state_dict(adapter_state, strict=True)
-
-     # Load VAE
-     print("  Loading VAE...")
-     vae = AutoencoderKL.from_pretrained(
-         'black-forest-labs/FLUX.1-dev',
-         subfolder="vae",
-         torch_dtype=dtype,
-         token=HF_TOKEN
-     ).to(device)
-
-     # Load transformer
-     print("  Loading transformer...")
-     config = FluxTransformer2DModel.load_config(
-         'black-forest-labs/FLUX.1-dev',
-         subfolder="transformer",
-         token=HF_TOKEN
-     )
-
-     transformer = FluxTransformer2DModel.from_config(config, torch_dtype=dtype)
-
-     transformer_lora_config = LoraConfig(
-         r=256, lora_alpha=256, lora_dropout=0.0, init_lora_weights="gaussian",
-         target_modules=[
-             "attn.to_k", "attn.to_q", "attn.to_v", "attn.to_out.0",
-             "attn.add_k_proj", "attn.add_q_proj", "attn.add_v_proj", "attn.to_add_out",
-             "ff.net.0.proj", "ff.net.2", "ff_context.net.0.proj", "ff_context.net.2",
-             "proj_mlp", "proj_out", "norm.linear", "norm1.linear"
-         ],
-     )
-     transformer.add_adapter(transformer_lora_config)
-     transformer.context_embedder.requires_grad_(True)
-
-     transformer_state = load_file(os.path.join(checkpoint_path, "model.safetensors"))
-     transformer.load_state_dict(transformer_state, strict=False)
-     transformer = transformer.to(device).to(dtype)
-
-     # Load scheduler
-     noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
-         'black-forest-labs/FLUX.1-dev',
-         subfolder="scheduler",
-         token=HF_TOKEN
-     )
+     print("✓ Checkpoint files ready!")

-     # Create pipeline
-     pipeline = CustomFluxKontextPipeline(
-         scheduler=noise_scheduler,
-         aligner=aligner_model,
-         transformer=transformer,
-         vae=vae,
-         text_embedder=text_encoder,
-     ).to(device)
-
-     model_pipeline = pipeline
-     print("✅ Models loaded!")
-     return pipeline
-
- # Download checkpoint at startup
+ # Download at startup
  download_checkpoint()

- # ZeroGPU decorator - this moves computation to GPU when called
- @spaces.GPU(duration=60)  # 60 seconds of GPU time per generation
+ # Load models at startup (like FLUX does)
+ print("Loading models...")
+
+ # Load ConceptAligner
+ aligner_model = ConceptAligner().to(device).to(dtype)
+ adapter_state = load_file(os.path.join(CHECKPOINT_DIR, "model_1.safetensors"))
+ aligner_model.load_state_dict(adapter_state, strict=True)
+ print(" ✓ ConceptAligner")
+
+ # Load T5 encoder
+ text_encoder = LoraT5Embedder(device=device).to(dtype)
+ adapter_state = load_file(os.path.join(CHECKPOINT_DIR, "model_2.safetensors"))
+ if "t5_encoder.shared.weight" in adapter_state:
+     adapter_state["t5_encoder.encoder.embed_tokens.weight"] = adapter_state["t5_encoder.shared.weight"]
+ text_encoder.load_state_dict(adapter_state, strict=True)
+ print(" ✓ T5 Encoder")
+
+ # Load VAE
+ vae = AutoencoderKL.from_pretrained(
+     'black-forest-labs/FLUX.1-dev',
+     subfolder="vae",
+     torch_dtype=dtype,
+     token=HF_TOKEN
+ ).to(device)
+ print(" ✓ VAE")
+
+ # Load transformer
+ config = FluxTransformer2DModel.load_config(
+     'black-forest-labs/FLUX.1-dev',
+     subfolder="transformer",
+     token=HF_TOKEN
+ )
+
+ transformer = FluxTransformer2DModel.from_config(config, torch_dtype=dtype)
+
+ transformer_lora_config = LoraConfig(
+     r=256, lora_alpha=256, lora_dropout=0.0, init_lora_weights="gaussian",
+     target_modules=[
+         "attn.to_k", "attn.to_q", "attn.to_v", "attn.to_out.0",
+         "attn.add_k_proj", "attn.add_q_proj", "attn.add_v_proj", "attn.to_add_out",
+         "ff.net.0.proj", "ff.net.2", "ff_context.net.0.proj", "ff_context.net.2",
+         "proj_mlp", "proj_out", "norm.linear", "norm1.linear"
+     ],
+ )
+ transformer.add_adapter(transformer_lora_config)
+ transformer.context_embedder.requires_grad_(True)
+
+ transformer_state = load_file(os.path.join(CHECKPOINT_DIR, "model.safetensors"))
+ transformer.load_state_dict(transformer_state, strict=False)
+ transformer = transformer.to(device).to(dtype)
+ print(" ✓ Transformer")
+
+ # Load scheduler
+ noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
+     'black-forest-labs/FLUX.1-dev',
+     subfolder="scheduler",
+     token=HF_TOKEN
+ )
+
+ # Create pipeline
+ pipe = CustomFluxKontextPipeline(
+     scheduler=noise_scheduler,
+     aligner=aligner_model,
+     transformer=transformer,
+     vae=vae,
+     text_embedder=text_encoder,
+ ).to(device)
+
+ print("✅ Models loaded and ready!")
+ torch.cuda.empty_cache()
+
+ # History tracking
+ previous_image = None
+ previous_prompt = None
+
+ @spaces.GPU(duration=75)
  @torch.no_grad()
- def generate_image(prompt, threshold=0.0, topk=0, height=512, width=512,
-                    guidance_scale=3.5, true_cf_scale=1.0, num_inference_steps=20, seed=1995):
-     """Generate image using the model"""
+ def generate_image(prompt, height=512, width=512, guidance_scale=3.5,
+                    true_cf_scale=1.0, num_inference_steps=20, seed=0,
+                    progress=gr.Progress(track_tqdm=True)):
+     """Generate image - models already loaded"""
+     global previous_image, previous_prompt

      if not prompt.strip():
-         return None, None, "Please enter a prompt"
+         return previous_image, None, previous_prompt or "No previous generation", seed

      try:
-         # Load models (will use cached version after first call)
-         pipe = load_models()
-
-         device = 'cuda' if torch.cuda.is_available() else 'cpu'
          generator = torch.Generator(device=device).manual_seed(int(seed))

-         print(f"Generating image: {prompt[:50]}...")
-
-         image = pipe(
+         current_image = pipe(
              prompt=prompt,
              guidance_scale=guidance_scale,
              true_cfg_scale=true_cf_scale,
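The substantive change in the hunk above is architectural: the old code built everything lazily inside a `load_models()` call wrapped by the request handler, while the new code constructs all models at import time and reserves the `@spaces.GPU` decorator for inference alone, raising the per-call lease from 60 to 75 seconds. A minimal sketch of that ZeroGPU pattern, assuming a toy model rather than this repo's pipeline:

```python
import spaces
import torch

# Built once at startup, before any GPU is attached.
model = torch.nn.Linear(4, 4)

@spaces.GPU(duration=75)  # a GPU is allocated only while this call runs
def infer(x: torch.Tensor) -> torch.Tensor:
    # Inside the decorated function, CUDA is available on a ZeroGPU Space.
    return model.to("cuda")(x.to("cuda")).cpu()
```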
@@ -170,66 +154,147 @@ def generate_image(prompt, threshold=0.0, topk=0, height=512, width=512,
              generator=generator,
          ).images[0]

-         return None, image, prompt
+         # Store for comparison
+         prev_image = previous_image
+         prev_prompt = previous_prompt or "No previous generation"
+
+         previous_image = current_image
+         previous_prompt = prompt
+
+         return prev_image, current_image, prev_prompt, seed

      except Exception as e:
          import traceback
-         error_msg = f"❌ Error: {str(e)}\n{traceback.format_exc()}"
-         print(error_msg)
-         return None, None, f"Error: {str(e)}"
+         print(f"❌ Error: {e}")
+         print(traceback.format_exc())
+         return previous_image, None, previous_prompt or "", seed

- # Create Gradio interface
- with gr.Blocks(title="ConceptAligner") as demo:
-     gr.Markdown("""
-     # 🎨 ConceptAligner Demo
-
-     Generate images with fine-tuned concept alignment using FLUX!
-
-     ⚡ Running on ZeroGPU - GPU allocated on-demand for each generation
-     """)
-
-     with gr.Row():
-         with gr.Column(scale=1):
-             prompt_input = gr.Textbox(
-                 label="Prompt",
-                 lines=6,
-                 placeholder="Describe your image..."
-             )
+ def reset_history():
+     """Clear generation history"""
+     global previous_image, previous_prompt
+     previous_image = None
+     previous_prompt = None
+     return None, None, "No previous generation"
+
+ # Create Gradio interface
+ css = """
+ #col-container {
+     margin: 0 auto;
+     max-width: 1200px;
+ }
+ """

-             generate_btn = gr.Button("✨ Generate", variant="primary", size="lg")
+ with gr.Blocks(css=css, title="ConceptAligner") as demo:
+     with gr.Column(elem_id="col-container"):
+         gr.Markdown("""
+         # 🎨 ConceptAligner Image Generator
+
+         Create stunning AI-generated images from text descriptions.
+         """)
+
+         with gr.Row():
+             with gr.Column(scale=1):
+                 prompt_input = gr.Text(
+                     label="Prompt",
+                     show_label=False,
+                     max_lines=3,
+                     placeholder="Describe your image...",
+                     container=False,
+                 )

-             with gr.Accordion("⚙️ Settings", open=True):
-                 guidance_scale = gr.Slider(1.0, 10.0, value=3.5, step=0.5, label="Guidance Scale")
-                 num_steps = gr.Slider(10, 50, value=20, step=1, label="Steps")
-                 seed = gr.Number(value=0, label="Seed", precision=0)
-
-             with gr.Accordion("🔬 Advanced", open=False):
-                 true_cfg_scale = gr.Slider(1.0, 10.0, value=1.0, step=0.5, label="True CFG")
-                 threshold = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Threshold")
-                 topk = gr.Slider(0, 300, value=0, step=1, label="Top-K")
                  with gr.Row():
-                     height = gr.Slider(256, 1024, value=512, step=64, label="Height")
-                     width = gr.Slider(256, 1024, value=512, step=64, label="Width")
-
-         with gr.Column(scale=2):
-             gr.Markdown("### 🖼️ Generated Image")
-             output_image = gr.Image(label="Output", type="pil", height=512)
-             status_text = gr.Textbox(label="Status", interactive=False, visible=False)
+                     generate_btn = gr.Button("✨ Generate", variant="primary", scale=3)
+                     reset_btn = gr.Button("🔄 Clear", variant="secondary", scale=1)
+
+                 with gr.Accordion("⚙️ Settings", open=False):
+                     seed = gr.Slider(
+                         label="Seed",
+                         minimum=0,
+                         maximum=2147483647,
+                         step=1,
+                         value=0,
+                     )
+
+                     guidance_scale = gr.Slider(
+                         label="Creativity Level",
+                         minimum=1.0,
+                         maximum=10.0,
+                         step=0.5,
+                         value=3.5,
+                     )
+
+                     num_inference_steps = gr.Slider(
+                         label="Quality (steps)",
+                         minimum=10,
+                         maximum=50,
+                         step=1,
+                         value=20,
+                     )
+
+                     with gr.Row():
+                         width = gr.Slider(
+                             label="Width",
+                             minimum=256,
+                             maximum=1024,
+                             step=64,
+                             value=512,
+                         )
+
+                         height = gr.Slider(
+                             label="Height",
+                             minimum=256,
+                             maximum=1024,
+                             step=64,
+                             value=512,
+                         )
+
+                     true_cfg_scale = gr.Slider(
+                         label="True CFG Scale",
+                         minimum=1.0,
+                         maximum=10.0,
+                         step=0.5,
+                         value=1.0,
+                         visible=False
+                     )
+
+             with gr.Column(scale=2):
+                 gr.Markdown("### 📊 Your Generations")

-     gr.Markdown("### 📝 Example Prompt")
-     gr.Examples(examples=EXAMPLE_PROMPTS, inputs=prompt_input)
-
-     # Hidden components for compatibility
-     prev_image_hidden = gr.Image(visible=False)
-     prev_prompt_hidden = gr.Textbox(visible=False)
-
-     generate_btn.click(
+                 with gr.Row():
+                     with gr.Column():
+                         gr.Markdown("**Previous**")
+                         prev_image = gr.Image(label="Previous", show_label=False, type="pil", height=400)
+                         prev_prompt_display = gr.Textbox(
+                             label="Previous Prompt",
+                             lines=2,
+                             interactive=False,
+                             show_label=False
+                         )
+
+                     with gr.Column():
+                         gr.Markdown("**Latest**")
+                         current_image = gr.Image(label="Current", show_label=False, type="pil", height=400)
+
+                 gr.Markdown("### 📝 Try This Example")
+                 gr.Examples(
+                     examples=EXAMPLE_PROMPTS,
+                     inputs=prompt_input,
+                     outputs=[prev_image, current_image, prev_prompt_display, seed],
+                     fn=generate_image,
+                     cache_examples=False
+                 )
+
+     # Event handlers
+     gr.on(
+         triggers=[generate_btn.click, prompt_input.submit],
          fn=generate_image,
-         inputs=[
-             prompt_input, threshold, topk,
-             height, width, guidance_scale, true_cfg_scale, num_steps, seed
-         ],
-         outputs=[prev_image_hidden, output_image, prev_prompt_hidden]
+         inputs=[prompt_input, height, width, guidance_scale, true_cfg_scale, num_inference_steps, seed],
+         outputs=[prev_image, current_image, prev_prompt_display, seed]
+     )
+
+     reset_btn.click(
+         fn=reset_history,
+         outputs=[prev_image, current_image, prev_prompt_display]
      )

  if __name__ == "__main__":
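The rewired events replace the single `generate_btn.click(...)` with `gr.on`, which binds one handler to several triggers (button click and prompt submit), plus a separate `reset_btn.click` for clearing the history. A self-contained sketch of the same wiring with hypothetical components (Gradio 4+ API):

```python
import gradio as gr

def shout(text: str) -> str:
    return text.upper()

with gr.Blocks() as demo:
    box = gr.Textbox(label="Input")
    btn = gr.Button("Go")
    out = gr.Textbox(label="Output")
    # One handler bound to two triggers, mirroring the gr.on call in the diff.
    gr.on(triggers=[btn.click, box.submit], fn=shout, inputs=box, outputs=out)

demo.launch()
```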
 