Shaoan committed on
Commit 9128255 · verified · 1 Parent(s): 6af382a

Upload folder using huggingface_hub

Files changed (1):
  1. app.py +158 -173
app.py CHANGED
@@ -1,7 +1,11 @@
 """
-ConceptAligner Hugging Face Demo
+ConceptAligner Hugging Face Demo - ZeroGPU Compatible
 """
 
+# CRITICAL: Import spaces FIRST
+import spaces
+
+# Now import everything else
 import torch
 import gradio as gr
 import os
@@ -13,14 +17,6 @@ from pipeline import CustomFluxKontextPipeline
 from diffusers import FluxTransformer2DModel, FlowMatchEulerDiscreteScheduler, AutoencoderKL
 from peft import LoraConfig
 
-# For HF Spaces GPU support
-try:
-    import spaces
-    GPU_AVAILABLE = True
-except ImportError:
-    GPU_AVAILABLE = False
-    print("⚠️ spaces package not available, running without @spaces.GPU decorator")
-
 # Login with token from environment
 HF_TOKEN = os.environ.get("HF_TOKEN")
 if HF_TOKEN:
@@ -56,161 +52,151 @@ def download_checkpoint():
 
 print("✓ All checkpoint files ready!")
 
-class ConceptAlignerModel:
-    def __init__(self):
-        download_checkpoint()
-
-        self.checkpoint_path = CHECKPOINT_DIR
-        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
-        self.dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
-
-        self.previous_image = None
-        self.previous_prompt = None
-
-        self.setup_models()
-
-    def setup_models(self):
-        """Load all models"""
-        print(f"Loading models on {self.device}...")
-
-        # Load ConceptAligner
-        print(" Loading ConceptAligner...")
-        self.model = ConceptAligner().to(self.device).to(self.dtype)
-        adapter_state = load_file(os.path.join(self.checkpoint_path, "model_1.safetensors"))
-        self.model.load_state_dict(adapter_state, strict=True)
-        print(" ConceptAligner loaded")
-
-        # Load T5 encoder
-        print(" Loading fine-tuned T5 encoder...")
-        self.text_encoder = LoraT5Embedder(device=self.device).to(self.dtype)
-        adapter_state = load_file(os.path.join(self.checkpoint_path, "model_2.safetensors"))
-        if "t5_encoder.shared.weight" in adapter_state:
-            adapter_state["t5_encoder.encoder.embed_tokens.weight"] = adapter_state["t5_encoder.shared.weight"]
-        self.text_encoder.load_state_dict(adapter_state, strict=True)
-        print(" T5 encoder loaded")
-
-        # Download VAE
-        print(" Loading VAE from FLUX.1-dev...")
-        vae = AutoencoderKL.from_pretrained(
-            'black-forest-labs/FLUX.1-dev',
-            subfolder="vae",
-            torch_dtype=self.dtype,
-            token=HF_TOKEN
-        ).to(self.device)
-        print(" VAE loaded")
-
-        # Create transformer from config
-        print(" Downloading transformer config...")
-        config = FluxTransformer2DModel.load_config(
-            'black-forest-labs/FLUX.1-dev',
-            subfolder="transformer",
-            token=HF_TOKEN
-        )
-
-        print(" Initializing transformer...")
-        transformer = FluxTransformer2DModel.from_config(config, torch_dtype=self.dtype)
-
-        print(" Adding LoRA adapters...")
-        transformer_lora_config = LoraConfig(
-            r=256, lora_alpha=256, lora_dropout=0.0, init_lora_weights="gaussian",
-            target_modules=[
-                "attn.to_k", "attn.to_q", "attn.to_v", "attn.to_out.0",
-                "attn.add_k_proj", "attn.add_q_proj", "attn.add_v_proj", "attn.to_add_out",
-                "ff.net.0.proj", "ff.net.2", "ff_context.net.0.proj", "ff_context.net.2",
-                "proj_mlp", "proj_out", "norm.linear", "norm1.linear"
-            ],
-        )
-        transformer.add_adapter(transformer_lora_config)
-        transformer.context_embedder.requires_grad_(True)
-
-        print(" Loading fine-tuned transformer weights...")
-        transformer_state = load_file(os.path.join(self.checkpoint_path, "model.safetensors"))
-        transformer.load_state_dict(transformer_state, strict=False)
-        transformer = transformer.to(self.device).to(self.dtype)
-        print(" ✓ Transformer loaded")
-
-        # Load empty pooled clip
-        self.empty_pooled_clip = torch.load(
-            os.path.join(self.checkpoint_path, "empty_pooled_clip.pt"),
-            map_location=self.device,
-            weights_only=True
-        ).to(self.dtype)
-
-        # Create scheduler
-        noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
-            'black-forest-labs/FLUX.1-dev',
-            subfolder="scheduler",
-            token=HF_TOKEN
-        )
-
-        # Create pipeline
-        self.pipe = CustomFluxKontextPipeline(
-            scheduler=noise_scheduler,
-            aligner=self.model,
-            transformer=transformer,
-            vae=vae,
-            text_embedder=self.text_encoder,
-        ).to(self.device)
-
-        print(" ALL MODELS LOADED!")
-
-        if torch.cuda.is_available():
-            allocated = torch.cuda.memory_allocated(0) / 1024**3
-            print(f"📊 GPU Memory: {allocated:.2f}GB allocated")
-
-    @torch.no_grad()
-    def generate_image(self, prompt, threshold=0.0, topk=0, height=512, width=512,
-                       guidance_scale=3.5, true_cf_scale=1.0, num_inference_steps=20, seed=1995):
-        if not prompt.strip():
-            return self.previous_image, None, self.previous_prompt or ""
-
-        try:
-            generator = torch.Generator(device=self.device).manual_seed(int(seed))
-            current_image = self.pipe(
-                prompt=prompt, guidance_scale=guidance_scale, true_cfg_scale=true_cf_scale,
-                max_sequence_length=512, num_inference_steps=num_inference_steps,
-                height=height, width=width, generator=generator,
-            ).images[0]
-
-            prev_image = self.previous_image
-            prev_prompt = self.previous_prompt or "No previous generation"
-
-            self.previous_image = current_image
-            self.previous_prompt = prompt
-
-            return prev_image, current_image, prev_prompt
-        except Exception as e:
-            import traceback
-            print(f"❌ Error: {e}")
-            print(traceback.format_exc())
-            return self.previous_image, None, self.previous_prompt or ""
-
-    def reset_history(self):
-        self.previous_image = None
-        self.previous_prompt = None
-        return None, None, "No previous generation"
-
-# Initialize model
-print("🚀 Initializing ConceptAligner...")
-model = ConceptAlignerModel()
-
-# Wrap generation function with @spaces.GPU if available
-if GPU_AVAILABLE:
-    generate_fn = spaces.GPU(model.generate_image)
-else:
-    generate_fn = model.generate_image
+# Global model variable
+model_pipeline = None
+
+def load_models():
+    """Load all models - called once at startup"""
+    global model_pipeline
+
+    if model_pipeline is not None:
+        return model_pipeline
+
+    print("🚀 Loading models...")
+
+    checkpoint_path = CHECKPOINT_DIR
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
+
+    # Load ConceptAligner
+    print(" Loading ConceptAligner...")
+    aligner_model = ConceptAligner().to(device).to(dtype)
+    adapter_state = load_file(os.path.join(checkpoint_path, "model_1.safetensors"))
+    aligner_model.load_state_dict(adapter_state, strict=True)
+
+    # Load T5 encoder
+    print(" Loading T5 encoder...")
+    text_encoder = LoraT5Embedder(device=device).to(dtype)
+    adapter_state = load_file(os.path.join(checkpoint_path, "model_2.safetensors"))
+    if "t5_encoder.shared.weight" in adapter_state:
+        adapter_state["t5_encoder.encoder.embed_tokens.weight"] = adapter_state["t5_encoder.shared.weight"]
+    text_encoder.load_state_dict(adapter_state, strict=True)
+
+    # Load VAE
+    print(" Loading VAE...")
+    vae = AutoencoderKL.from_pretrained(
+        'black-forest-labs/FLUX.1-dev',
+        subfolder="vae",
+        torch_dtype=dtype,
+        token=HF_TOKEN
+    ).to(device)
+
+    # Load transformer
+    print(" Loading transformer...")
+    config = FluxTransformer2DModel.load_config(
+        'black-forest-labs/FLUX.1-dev',
+        subfolder="transformer",
+        token=HF_TOKEN
+    )
+
+    transformer = FluxTransformer2DModel.from_config(config, torch_dtype=dtype)
+
+    transformer_lora_config = LoraConfig(
+        r=256, lora_alpha=256, lora_dropout=0.0, init_lora_weights="gaussian",
+        target_modules=[
+            "attn.to_k", "attn.to_q", "attn.to_v", "attn.to_out.0",
+            "attn.add_k_proj", "attn.add_q_proj", "attn.add_v_proj", "attn.to_add_out",
+            "ff.net.0.proj", "ff.net.2", "ff_context.net.0.proj", "ff_context.net.2",
+            "proj_mlp", "proj_out", "norm.linear", "norm1.linear"
+        ],
+    )
+    transformer.add_adapter(transformer_lora_config)
+    transformer.context_embedder.requires_grad_(True)
+
+    transformer_state = load_file(os.path.join(checkpoint_path, "model.safetensors"))
+    transformer.load_state_dict(transformer_state, strict=False)
+    transformer = transformer.to(device).to(dtype)
+
+    # Load scheduler
+    noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
+        'black-forest-labs/FLUX.1-dev',
+        subfolder="scheduler",
+        token=HF_TOKEN
+    )
+
+    # Create pipeline
+    pipeline = CustomFluxKontextPipeline(
+        scheduler=noise_scheduler,
+        aligner=aligner_model,
+        transformer=transformer,
+        vae=vae,
+        text_embedder=text_encoder,
+    ).to(device)
+
+    model_pipeline = pipeline
+    print("✅ Models loaded!")
+    return pipeline
+
+# Download checkpoint at startup
+download_checkpoint()
+
+# ZeroGPU decorator - this moves computation to GPU when called
+@spaces.GPU(duration=60)  # 60 seconds of GPU time per generation
+@torch.no_grad()
+def generate_image(prompt, threshold=0.0, topk=0, height=512, width=512,
+                   guidance_scale=3.5, true_cf_scale=1.0, num_inference_steps=20, seed=1995):
+    """Generate image using the model"""
+
+    if not prompt.strip():
+        return None, None, "Please enter a prompt"
+
+    try:
+        # Load models (will use cached version after first call)
+        pipe = load_models()
+
+        device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        generator = torch.Generator(device=device).manual_seed(int(seed))
+
+        print(f"Generating image: {prompt[:50]}...")
+
+        image = pipe(
+            prompt=prompt,
+            guidance_scale=guidance_scale,
+            true_cfg_scale=true_cf_scale,
+            max_sequence_length=512,
+            num_inference_steps=num_inference_steps,
+            height=height,
+            width=width,
+            generator=generator,
+        ).images[0]
+
+        return None, image, prompt
+
+    except Exception as e:
+        import traceback
+        error_msg = f"❌ Error: {str(e)}\n{traceback.format_exc()}"
+        print(error_msg)
+        return None, None, f"Error: {str(e)}"
 
 # Create Gradio interface
 with gr.Blocks(title="ConceptAligner") as demo:
-    gr.Markdown("# 🎨 ConceptAligner Demo\nGenerate images with fine-tuned concept alignment!")
+    gr.Markdown("""
+    # 🎨 ConceptAligner Demo
+
+    Generate images with fine-tuned concept alignment using FLUX!
+
+    ⚡ Running on ZeroGPU - GPU allocated on-demand for each generation
+    """)
 
     with gr.Row():
         with gr.Column(scale=1):
-            prompt_input = gr.Textbox(label="Prompt", lines=6, placeholder="Describe your image...")
+            prompt_input = gr.Textbox(
+                label="Prompt",
+                lines=6,
+                placeholder="Describe your image..."
+            )
 
-            with gr.Row():
-                generate_btn = gr.Button("✨ Generate", variant="primary", size="lg", scale=3)
-                reset_btn = gr.Button("🔄 Reset", variant="secondary", size="lg", scale=1)
+            generate_btn = gr.Button("✨ Generate", variant="primary", size="lg")
 
             with gr.Accordion("⚙️ Settings", open=True):
                 guidance_scale = gr.Slider(1.0, 10.0, value=3.5, step=0.5, label="Guidance Scale")
@@ -226,26 +212,25 @@ with gr.Blocks(title="ConceptAligner") as demo:
                 width = gr.Slider(256, 1024, value=512, step=64, label="Width")
 
         with gr.Column(scale=2):
-            gr.Markdown("### 📊 Comparison View")
-            with gr.Row():
-                with gr.Column():
-                    gr.Markdown("**Previous**")
-                    prev_image = gr.Image(label="Previous", type="pil", height=450)
-                    prev_prompt_display = gr.Textbox(label="Previous Prompt", lines=3, interactive=False)
-                with gr.Column():
-                    gr.Markdown("**Current**")
-                    current_image = gr.Image(label="Current", type="pil", height=450)
-
-            gr.Markdown("### 📝 Example")
+            gr.Markdown("### 🖼️ Generated Image")
+            output_image = gr.Image(label="Output", type="pil", height=512)
+            status_text = gr.Textbox(label="Status", interactive=False, visible=False)
+
+            gr.Markdown("### 📝 Example Prompt")
             gr.Examples(examples=EXAMPLE_PROMPTS, inputs=prompt_input)
 
+    # Hidden components for compatibility
+    prev_image_hidden = gr.Image(visible=False)
+    prev_prompt_hidden = gr.Textbox(visible=False)
+
     generate_btn.click(
-        fn=generate_fn,
-        inputs=[prompt_input, threshold, topk, height, width, guidance_scale, true_cfg_scale, num_steps, seed],
-        outputs=[prev_image, current_image, prev_prompt_display]
+        fn=generate_image,
+        inputs=[
+            prompt_input, threshold, topk,
+            height, width, guidance_scale, true_cfg_scale, num_steps, seed
+        ],
+        outputs=[prev_image_hidden, output_image, prev_prompt_hidden]
     )
 
-    reset_btn.click(fn=model.reset_history, outputs=[prev_image, current_image, prev_prompt_display])
-
 if __name__ == "__main__":
     demo.launch()
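
For context on the pattern this commit switches to: instead of wrapping a bound method in spaces.GPU at module scope (the old GPU_AVAILABLE branch), the new app.py decorates a top-level function and loads models lazily into a module-level cache, so GPU time is only attached while a generation runs. A minimal sketch of that pattern, assuming a Hugging Face Space where the spaces package is available; build_pipeline here is a hypothetical placeholder, not the app's real loading code:

# Minimal ZeroGPU pattern sketch (assumes a Hugging Face Space with the
# `spaces` package; `build_pipeline` is a hypothetical stand-in for the
# real model-loading code in app.py).
import spaces  # must be imported before anything touches CUDA
import gradio as gr

_pipeline = None  # module-level cache, populated on first call

def build_pipeline():
    """Lazily build and cache the pipeline (placeholder implementation)."""
    global _pipeline
    if _pipeline is None:
        _pipeline = lambda prompt: f"(generated output for: {prompt})"
    return _pipeline

@spaces.GPU(duration=60)  # GPU is attached only while this function runs
def generate(prompt):
    pipe = build_pipeline()  # cached after the first generation
    return pipe(prompt)

with gr.Blocks() as demo:
    inp = gr.Textbox(label="Prompt")
    out = gr.Textbox(label="Output")
    gr.Button("Generate").click(fn=generate, inputs=inp, outputs=out)

if __name__ == "__main__":
    demo.launch()

The two constraints this works around are that spaces must be imported before CUDA is initialized, and that GPU work has to happen inside a decorated top-level callable, which is presumably why the class-based ConceptAlignerModel was flattened into module-level functions in this commit.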