concauu committed on
Commit
772461e
·
verified ·
1 Parent(s): 2b6f6e8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +137 -205
app.py CHANGED
@@ -8,7 +8,7 @@ os.environ['HF_HUB_DOWNLOAD_TIMEOUT'] = '120'
8
  import numpy as np
9
  import random
10
  import spaces
11
- from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler, AutoencoderTiny, AutoencoderKL,UNet2DConditionModel
12
  from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast, T5Tokenizer, T5EncoderModel
13
  from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images
14
  from io import BytesIO
def get_hf_token(encrypted_token):
    """Decrypt a Fernet-encrypted Hugging Face token and return it as a string.

    Args:
        encrypted_token: Fernet ciphertext (bytes or str) produced with the
            matching key.

    Returns:
        The decrypted token as a str.

    Raises:
        ValueError: If no decryption key is available.
    """
    # SECURITY: the fallback key below is committed to the repository; prefer
    # the DECRYPTION_KEY environment variable so the secret can be rotated
    # without a code change. Previously the key was hardcoded unconditionally,
    # which made the "not key" check below dead code and its error message
    # (referencing DECRYPTION_KEY) misleading.
    key = os.environ.get("DECRYPTION_KEY") or "K4FlQbffvTcDxT2FIhrOPV1eue6ia45FFR3kqp2hHbM="
    if not key:
        raise ValueError("Missing decryption key! Set the DECRYPTION_KEY environment variable.")
    # Fernet requires the key as bytes.
    if isinstance(key, str):
        key = key.encode()
    f = Fernet(key)
    decrypted_token = f.decrypt(encrypted_token).decode()
    return decrypted_token
38
-
39
  groq_client = Groq(api_key="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")
40
  decrypted_token = get_hf_token("gAAAAABn3GfShExoJd50nau3B5ZJNiQ9dRD1ACO3XXMwVaIQMkmi59cL-MKGr6SYnsB0E2gGITJG2j29Ar9yjaZP-EC6hHsCBmwKSj4aFtTor9_n0_NdMBv1GtlxZRmwnQwriB-Xr94e")
41
  login(token=decrypted_token)
@@ -59,17 +55,17 @@ t5_text_encoder = T5EncoderModel.from_pretrained(
59
class TextProjection(torch.nn.Module):
    """Linear map from 768-dim T5 hidden states to the 3072-dim space the
    transformer expects."""

    def __init__(self):
        super().__init__()
        proj = torch.nn.Linear(768, 3072)
        # Small-std normal init keeps the projected embeddings near zero at start.
        torch.nn.init.normal_(proj.weight, std=0.02)
        self.proj = proj

    def forward(self, x):
        # Cast input to the module-level `dtype` before projecting.
        return self.proj(x.to(dtype))
67
 
68
- # Add this override to your existing pipeline setup
69
  class T5FluxPipeline(FluxPipeline):
70
  def _get_clip_prompt_embeds(self, prompt, num_images_per_prompt, device):
71
  """Modified to work with T5 outputs (without classifier-free guidance handling)"""
72
- # Get T5 embeddings
73
  text_inputs = self.tokenizer(
74
  prompt,
75
  padding="max_length",
@@ -77,24 +73,16 @@ class T5FluxPipeline(FluxPipeline):
77
  truncation=True,
78
  return_tensors="pt",
79
  ).to(device)
80
-
81
  text_outputs = self.text_encoder(**text_inputs)
82
  prompt_embeds = text_outputs.last_hidden_state
83
-
84
- # Use mean pooling instead of CLIP's pooler_output
85
  pooled_prompt_embeds = prompt_embeds.mean(dim=1)
86
-
87
- # Expand for batch
88
  prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
89
  pooled_prompt_embeds = pooled_prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
90
-
91
  return prompt_embeds, pooled_prompt_embeds
92
 
93
-
94
  # Initialize pipeline components
95
  taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
96
  good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=dtype).to(device)
97
- # Custom pipeline with T5 support
98
  pipe = T5FluxPipeline.from_pretrained(
99
  "black-forest-labs/FLUX.1-dev",
100
  text_encoder=t5_text_encoder,
@@ -104,14 +92,14 @@ pipe = T5FluxPipeline.from_pretrained(
104
  safety_checker=None
105
  ).to(device)
106
 
107
- # Add projection layer to pipeline
108
  pipe.text_projection = TextProjection().to(device, dtype=dtype)
109
  torch.cuda.empty_cache()
110
 
111
  MAX_SEED = np.iinfo(np.int32).max
112
  MAX_IMAGE_SIZE = 2048
113
 
114
- # Custom low-level CLIP prompt embedder override (returns exactly two tensors)
115
  def custom_get_clip_prompt_embeds(self, prompt, num_images_per_prompt, device):
116
  text_inputs = self.tokenizer(
117
  prompt,
@@ -122,24 +110,14 @@ def custom_get_clip_prompt_embeds(self, prompt, num_images_per_prompt, device):
122
  ).to(device)
123
  text_outputs = self.text_encoder(**text_inputs)
124
  prompt_embeds = text_outputs.last_hidden_state
125
- # Use mean pooling along the sequence dimension for pooled embeddings
126
  pooled_prompt_embeds = prompt_embeds.mean(dim=1)
127
- # Repeat for each image in the batch
128
  prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
129
  pooled_prompt_embeds = pooled_prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
130
  return prompt_embeds, pooled_prompt_embeds
131
 
132
- # Override the high-level encode_prompt to use T5 encoding and return three outputs:
133
- def custom_encode_prompt( self,
134
- prompt,
135
- device,
136
- num_images_per_prompt,
137
- do_classifier_free_guidance=False,
138
- negative_prompt=None,
139
- prompt_embeds=None,
140
- prompt_2=None,
141
- **kwargs):
142
- # Encode the prompt using the T5 components
143
  text_inputs = self.tokenizer(
144
  prompt,
145
  padding="max_length",
@@ -148,150 +126,120 @@ def custom_encode_prompt( self,
148
  return_tensors="pt",
149
  ).to(device)
150
  text_outputs = self.text_encoder(**text_inputs)
151
- # Project T5 embeddings into CLIP space
152
  text_embeddings = self.text_projection(text_outputs.last_hidden_state)
153
- # Compute pooled embeddings via mean pooling
154
  pooled_text_embeddings = text_embeddings.mean(dim=1)
155
-
156
  if do_classifier_free_guidance:
157
- # For classifier-free guidance, get negative prompt embeddings:
158
- uncond_input = self.tokenizer(
159
- [negative_prompt] if negative_prompt else [""],
160
- padding="max_length",
161
- max_length=512,
162
- truncation=True,
163
- return_tensors="pt",
164
- ).to(device)
165
- uncond_outputs = self.text_encoder(**uncond_input)
166
- uncond_embeddings = self.text_projection(uncond_outputs.last_hidden_state)
167
- pooled_uncond_embeddings = uncond_embeddings.mean(dim=1)
168
- # Concatenate unconditional and conditional embeddings
169
- text_embeddings = torch.cat([uncond_embeddings, text_embeddings], dim=0)
170
- pooled_text_embeddings = torch.cat([pooled_uncond_embeddings, pooled_text_embeddings], dim=0)
171
- token_ids = text_inputs.input_ids # use the conditional tokens as placeholder
172
  else:
173
- token_ids = text_inputs.input_ids
174
-
175
- # Repeat for the number of images per prompt
176
  text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
177
  pooled_text_embeddings = pooled_text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
178
  token_ids = token_ids.repeat_interleave(num_images_per_prompt, dim=0)
179
-
180
- # IMPORTANT: Return pooled_text_embeddings as a tensor (not a tuple)
181
  return text_embeddings, pooled_text_embeddings, token_ids
182
 
183
- # Patch both methods in your pipeline instance:
184
  pipe._get_clip_prompt_embeds = custom_get_clip_prompt_embeds.__get__(pipe)
185
  pipe._encode_prompt = custom_encode_prompt.__get__(pipe)
186
  pipe.encode_prompt = custom_encode_prompt.__get__(pipe)
187
-
188
  pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
189
 
 
 
190
  pipe.transformer.time_text_embed.fixed_text_proj = nn.Linear(3072, 256).to(device, dtype=dtype)
191
 
192
def patched_time_embed(self, timestep, guidance, pooled_projections):
    """Replacement forward for the transformer's time_text_embed: sum of the
    timestep embedding and a 256-dim projection of the pooled text embeddings.

    NOTE(review): `guidance` is accepted but ignored — confirm that dropping
    the guidance embedding is intentional for this pipeline.
    """
    # Timestep embedding (expected shape (B, 256)).
    time_out = self.time_proj(timestep)

    # Lazily (re)create fixed_text_proj on the right device/dtype whenever it
    # is missing or does not map to 256 features.
    needs_proj = (
        not hasattr(self, "fixed_text_proj")
        or self.fixed_text_proj.out_features != 256
    )
    if needs_proj:
        self.fixed_text_proj = nn.Linear(3072, 256).to(
            device=pooled_projections.device, dtype=pooled_projections.dtype
        )

    # Project pooled text embeddings to (B, 256) and combine additively.
    text_out = self.fixed_text_proj(pooled_projections)
    return time_out + text_out
205
- # Apply the patch after the pipeline is created and patched with your custom encode methods:
 
206
  pipe.transformer.time_text_embed.forward = patched_time_embed.__get__(pipe.transformer.time_text_embed)
207
 
208
- # History functions
209
def append_to_history(image, prompt, seed, width, height, guidance_scale, steps, history):
    """Append the final generated image (as PNG bytes) together with its
    generation parameters to the history list, returning a new list.

    Returns `history` unchanged when `image` is None (generation produced
    nothing to record).
    """
    if image is None:
        return history

    from PIL import Image
    import numpy as np

    # Normalize numpy arrays to PIL images; non-uint8 arrays are assumed to be
    # floats in [0, 1] — TODO confirm against the pipeline's output range.
    if isinstance(image, np.ndarray):
        if image.dtype == np.uint8:
            image = Image.fromarray(image)
        else:
            image = Image.fromarray((image * 255).astype(np.uint8))

    # Serialize to PNG bytes so the entry is compact and JSON-state friendly.
    buffered = BytesIO()
    image.save(buffered, format="PNG")
    img_bytes = buffered.getvalue()

    entry = {
        "image": img_bytes,
        "prompt": prompt,
        "seed": seed,
        "width": width,
        "height": height,
        "guidance_scale": guidance_scale,
        "steps": steps,
    }
    return history + [entry]
240
 
241
def create_history_html(history):
    """Render the generation history (newest first) as a column of HTML cards,
    embedding each stored PNG as a base64 data URI."""
    if not history:
        return "<p style='margin: 20px;'>No generations yet</p>"

    total = len(history)
    html = "<div style='display: flex; flex-direction: column; gap: 20px; margin: 20px;'>"
    for i, entry in enumerate(reversed(history)):
        img_str = base64.b64encode(entry["image"]).decode()
        html += f"""
        <div style='display: flex; gap: 20px; padding: 20px; background: #f5f5f5; border-radius: 10px;'>
            <img src="data:image/png;base64,{img_str}" style="width: 150px; height: 150px; object-fit: cover; border-radius: 5px;"/>
            <div style='flex: 1;'>
                <h3 style='margin: 0;'>Generation #{total - i}</h3>
                <p><strong>Prompt:</strong> {entry["prompt"]}</p>
                <p><strong>Seed:</strong> {entry["seed"]}</p>
                <p><strong>Size:</strong> {entry["width"]}x{entry["height"]}</p>
                <p><strong>Guidance:</strong> {entry["guidance_scale"]}</p>
                <p><strong>Steps:</strong> {entry["steps"]}</p>
            </div>
        </div>
        """
    return html + "</div>"
259
 
260
-
261
@spaces.GPU(duration=75)
def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024,
          guidance_scale=3.5, num_inference_steps=28, progress=gr.Progress(track_tqdm=True)):
    """Generator that streams preview images from the FLUX pipeline, yielding
    (image, seed) for each intermediate/final frame.

    The prompt is first truncated to 512 T5 tokens to stay within the
    encoder's context window.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator().manual_seed(seed)

    # Truncate via encode/decode so the prompt fits in 512 T5 tokens.
    tokens = t5_tokenizer.encode(prompt)[:512]
    processed_prompt = t5_tokenizer.decode(tokens, skip_special_tokens=True)

    image_stream = pipe.flux_pipe_call_that_returns_an_iterable_of_images(
        prompt=processed_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=generator,
        output_type="pil",
        good_vae=good_vae,
    )
    for img in image_stream:
        yield img, seed
283
-
284
-
285
 
286
  def enhance_prompt(user_prompt):
287
- """Enhances the given prompt using Groq and returns the refined prompt."""
288
  try:
289
- chat_completion = groq_client.chat.completions.create(
290
- messages=[
291
- {
292
- "role": "system",
293
- "content": (
294
- """Enhance user input into prompts that paint a clear picture for image generation. Be precise, detailed and direct, describe not only the content of the image but also such details as tone, style, color palette, and point of view, for photorealistic images, include the name of the device used (e.g., “shot on iPhone 16”), aperture, lens, and shot type. Use precise, visual descriptions (rather than metaphorical concepts).
295
  Try to keep prompts to contain only keywords, yet precise, and awe-inspiring.
296
  Medium:
297
  Consider what form of art this image should be simulating.
@@ -312,23 +260,22 @@ Technique: For paintings, how was the brush manipulated? For digital art, any sp
312
  Photo: Describe type of photography, camera gear, and camera settings. Any specific shot technique? (Comma-separated list of these)
313
  Painting: Mention the kind of paint, texture of canvas, and shape/texture of brushstrokes. (List)
314
  Digital: Note the software used, shading techniques, and multimedia approaches."""
315
- ),
316
- },
317
- {"role": "user", "content": user_prompt}
318
- ],
319
- model="llama-3.3-70b-versatile",
320
- temperature=0.5,
321
- max_completion_tokens=1024,
322
- top_p=1,
323
- stop=None,
324
- stream=False,
325
- )
326
- enhanced = chat_completion.choices[0].message.content
327
  except Exception as e:
328
- enhanced = f"Error enhancing prompt: {str(e)}"
329
  return enhanced
330
 
331
- # --- Gradio Interface ---
332
  css = """
333
  #col-container {
334
  margin: 0 auto;
@@ -338,79 +285,64 @@ css = """
338
 
339
with gr.Blocks(css=css) as demo:
    # Server-side state holding the list of history entries (PNG bytes + params).
    history_state = gr.State([])

    with gr.Column(elem_id="col-container"):
        gr.Markdown("# FLUX.1 [dev] with History Tracking")

        # Step 1: prompt enhancement via the Groq LLM.
        gr.Markdown("### Step 1: Enhance Your Prompt")
        original_prompt = gr.Textbox(label="Original Prompt", lines=2)
        enhance_button = gr.Button("Enhance Prompt")
        enhanced_prompt = gr.Textbox(label="Enhanced Prompt (Editable)", lines=2)
        enhance_button.click(enhance_prompt, original_prompt, enhanced_prompt)

        # Step 2: image generation controls and live preview.
        gr.Markdown("### Step 2: Generate Image")
        with gr.Row():
            run_button = gr.Button("Generate Image", variant="primary")
            result = gr.Image(label="Result", show_label=False)

        with gr.Accordion("Advanced Settings"):
            seed = gr.Slider(0, MAX_SEED, value=0, label="Seed")
            randomize_seed = gr.Checkbox(True, label="Randomize seed")
            with gr.Row():
                width = gr.Slider(256, MAX_IMAGE_SIZE, 1024, step=32, label="Width")
                height = gr.Slider(256, MAX_IMAGE_SIZE, 1024, step=32, label="Height")
            with gr.Row():
                guidance_scale = gr.Slider(1, 15, 3.5, step=0.1, label="Guidance Scale")
                num_inference_steps = gr.Slider(1, 50, 28, step=1, label="Inference Steps")

        with gr.Accordion("Generation History", open=False):
            history_display = gr.HTML("<p style='margin: 20px;'>No generations yet</p>")

        gr.Examples(
            examples=[
                "a tiny astronaut hatching from an egg on the moon",
                "a cat holding a sign that says hello world",
                "an anime illustration of a wiener schnitzel",
            ],
            inputs=enhanced_prompt,
            outputs=[result, seed],
            fn=infer,
            cache_examples="lazy",
        )

    # Shared input lists for the two generation triggers (button and Enter key).
    gen_inputs = [enhanced_prompt, seed, randomize_seed, width, height,
                  guidance_scale, num_inference_steps]
    history_inputs = [result, enhanced_prompt, seed, width, height,
                      guidance_scale, num_inference_steps, history_state]

    # Button click: run the generator; the .then() chain fires only after the
    # streaming generator completes, so only the final image is recorded.
    generation_event = run_button.click(fn=infer, inputs=gen_inputs, outputs=[result, seed])
    generation_event.then(
        fn=append_to_history, inputs=history_inputs, outputs=history_state
    ).then(
        fn=create_history_html, inputs=history_state, outputs=history_display
    )

    # Pressing Enter in the enhanced-prompt box mirrors the button behavior.
    enhanced_prompt.submit(
        fn=infer, inputs=gen_inputs, outputs=[result, seed]
    ).then(
        fn=append_to_history, inputs=history_inputs, outputs=history_state
    ).then(
        fn=create_history_html, inputs=history_state, outputs=history_display
    )

demo.launch(share=True)
 
8
  import numpy as np
9
  import random
10
  import spaces
11
+ from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler, AutoencoderTiny, AutoencoderKL, UNet2DConditionModel
12
  from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast, T5Tokenizer, T5EncoderModel
13
  from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images
14
  from io import BytesIO
 
26
  key = "K4FlQbffvTcDxT2FIhrOPV1eue6ia45FFR3kqp2hHbM="
27
  if not key:
28
  raise ValueError("Missing decryption key! Set the DECRYPTION_KEY environment variable.")
 
 
29
  if isinstance(key, str):
30
  key = key.encode()
 
31
  f = Fernet(key)
 
32
  decrypted_token = f.decrypt(encrypted_token).decode()
33
  return decrypted_token
34
+
35
  groq_client = Groq(api_key="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")
36
  decrypted_token = get_hf_token("gAAAAABn3GfShExoJd50nau3B5ZJNiQ9dRD1ACO3XXMwVaIQMkmi59cL-MKGr6SYnsB0E2gGITJG2j29Ar9yjaZP-EC6hHsCBmwKSj4aFtTor9_n0_NdMBv1GtlxZRmwnQwriB-Xr94e")
37
  login(token=decrypted_token)
 
55
  class TextProjection(torch.nn.Module):
56
  def __init__(self):
57
  super().__init__()
58
+ # Project from 768 to 3072 (T5 output to our combined text space)
59
+ self.proj = torch.nn.Linear(768, 3072)
60
  torch.nn.init.normal_(self.proj.weight, std=0.02)
61
 
62
  def forward(self, x):
63
  return self.proj(x.to(dtype))
64
 
65
+ # Custom pipeline with T5 support
66
  class T5FluxPipeline(FluxPipeline):
67
  def _get_clip_prompt_embeds(self, prompt, num_images_per_prompt, device):
68
  """Modified to work with T5 outputs (without classifier-free guidance handling)"""
 
69
  text_inputs = self.tokenizer(
70
  prompt,
71
  padding="max_length",
 
73
  truncation=True,
74
  return_tensors="pt",
75
  ).to(device)
 
76
  text_outputs = self.text_encoder(**text_inputs)
77
  prompt_embeds = text_outputs.last_hidden_state
 
 
78
  pooled_prompt_embeds = prompt_embeds.mean(dim=1)
 
 
79
  prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
80
  pooled_prompt_embeds = pooled_prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
 
81
  return prompt_embeds, pooled_prompt_embeds
82
 
 
83
  # Initialize pipeline components
84
  taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
85
  good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=dtype).to(device)
 
86
  pipe = T5FluxPipeline.from_pretrained(
87
  "black-forest-labs/FLUX.1-dev",
88
  text_encoder=t5_text_encoder,
 
92
  safety_checker=None
93
  ).to(device)
94
 
95
+ # Add our projection layer to the pipeline
96
  pipe.text_projection = TextProjection().to(device, dtype=dtype)
97
  torch.cuda.empty_cache()
98
 
99
  MAX_SEED = np.iinfo(np.int32).max
100
  MAX_IMAGE_SIZE = 2048
101
 
102
+ # Custom low-level CLIP prompt embedder override
103
  def custom_get_clip_prompt_embeds(self, prompt, num_images_per_prompt, device):
104
  text_inputs = self.tokenizer(
105
  prompt,
 
110
  ).to(device)
111
  text_outputs = self.text_encoder(**text_inputs)
112
  prompt_embeds = text_outputs.last_hidden_state
 
113
  pooled_prompt_embeds = prompt_embeds.mean(dim=1)
 
114
  prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
115
  pooled_prompt_embeds = pooled_prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
116
  return prompt_embeds, pooled_prompt_embeds
117
 
118
+ # Override the high-level encode_prompt to use T5 encoding and return three outputs.
119
+ def custom_encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance=False,
120
+ negative_prompt=None, prompt_embeds=None, prompt_2=None, **kwargs):
 
 
 
 
 
 
 
 
121
  text_inputs = self.tokenizer(
122
  prompt,
123
  padding="max_length",
 
126
  return_tensors="pt",
127
  ).to(device)
128
  text_outputs = self.text_encoder(**text_inputs)
129
+ # Project T5 embeddings into CLIP space using our projection layer.
130
  text_embeddings = self.text_projection(text_outputs.last_hidden_state)
 
131
  pooled_text_embeddings = text_embeddings.mean(dim=1)
 
132
  if do_classifier_free_guidance:
133
+ uncond_input = self.tokenizer(
134
+ [negative_prompt] if negative_prompt else [""],
135
+ padding="max_length",
136
+ max_length=512,
137
+ truncation=True,
138
+ return_tensors="pt",
139
+ ).to(device)
140
+ uncond_outputs = self.text_encoder(**uncond_input)
141
+ uncond_embeddings = self.text_projection(uncond_outputs.last_hidden_state)
142
+ pooled_uncond_embeddings = uncond_embeddings.mean(dim=1)
143
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings], dim=0)
144
+ pooled_text_embeddings = torch.cat([pooled_uncond_embeddings, pooled_text_embeddings], dim=0)
145
+ token_ids = text_inputs.input_ids
 
 
146
  else:
147
+ token_ids = text_inputs.input_ids
 
 
148
  text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
149
  pooled_text_embeddings = pooled_text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
150
  token_ids = token_ids.repeat_interleave(num_images_per_prompt, dim=0)
 
 
151
  return text_embeddings, pooled_text_embeddings, token_ids
152
 
 
153
  pipe._get_clip_prompt_embeds = custom_get_clip_prompt_embeds.__get__(pipe)
154
  pipe._encode_prompt = custom_encode_prompt.__get__(pipe)
155
  pipe.encode_prompt = custom_encode_prompt.__get__(pipe)
 
156
  pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
157
 
158
+ # ----- PATCH THE TRANSFORMER'S TIME EMBEDDING LAYER -----
159
+ # Force-override the fixed_text_proj attribute so that it maps from 3072 to 256.
160
  pipe.transformer.time_text_embed.fixed_text_proj = nn.Linear(3072, 256).to(device, dtype=dtype)
161
 
162
  def patched_time_embed(self, timestep, guidance, pooled_projections):
163
+ # Compute timestep embedding (expected shape: (B,256))
164
  time_out = self.time_proj(timestep)
165
+ # Use the fixed_text_proj we just set.
166
+ text_out = self.fixed_text_proj(pooled_projections)
 
 
 
 
 
 
 
167
  return time_out + text_out
168
+
169
+ # Patch the forward method.
170
  pipe.transformer.time_text_embed.forward = patched_time_embed.__get__(pipe.transformer.time_text_embed)
171
 
172
+ # ----- HISTORY FUNCTIONS & GRADIO INTERFACE -----
173
  def append_to_history(image, prompt, seed, width, height, guidance_scale, steps, history):
 
174
  if image is None:
175
  return history
 
 
176
  from PIL import Image
177
  import numpy as np
 
178
  if isinstance(image, np.ndarray):
 
179
  if image.dtype == np.uint8:
180
  image = Image.fromarray(image)
 
181
  else:
182
  image = Image.fromarray((image * 255).astype(np.uint8))
 
 
183
  buffered = BytesIO()
184
  image.save(buffered, format="PNG")
185
  img_bytes = buffered.getvalue()
 
186
  return history + [{
187
+ "image": img_bytes,
188
+ "prompt": prompt,
189
+ "seed": seed,
190
+ "width": width,
191
+ "height": height,
192
+ "guidance_scale": guidance_scale,
193
+ "steps": steps,
194
  }]
195
 
196
  def create_history_html(history):
197
  html = "<div style='display: flex; flex-direction: column; gap: 20px; margin: 20px;'>"
198
  for i, entry in enumerate(reversed(history)):
199
+ img_str = base64.b64encode(entry["image"]).decode()
200
+ html += f"""
201
+ <div style='display: flex; gap: 20px; padding: 20px; background: #f5f5f5; border-radius: 10px;'>
202
+ <img src="data:image/png;base64,{img_str}" style="width: 150px; height: 150px; object-fit: cover; border-radius: 5px;"/>
203
+ <div style='flex: 1;'>
204
+ <h3 style='margin: 0;'>Generation #{len(history)-i}</h3>
205
+ <p><strong>Prompt:</strong> {entry["prompt"]}</p>
206
+ <p><strong>Seed:</strong> {entry["seed"]}</p>
207
+ <p><strong>Size:</strong> {entry["width"]}x{entry["height"]}</p>
208
+ <p><strong>Guidance:</strong> {entry["guidance_scale"]}</p>
209
+ <p><strong>Steps:</strong> {entry["steps"]}</p>
210
+ </div>
211
+ </div>
212
+ """
213
  return html + "</div>" if history else "<p style='margin: 20px;'>No generations yet</p>"
214
 
 
215
  @spaces.GPU(duration=75)
216
+ def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024,
217
+ guidance_scale=3.5, num_inference_steps=28, progress=gr.Progress(track_tqdm=True)):
218
  if randomize_seed:
219
+ seed = random.randint(0, MAX_SEED)
220
  generator = torch.Generator().manual_seed(seed)
 
 
221
  tokens = t5_tokenizer.encode(prompt)[:512]
222
  processed_prompt = t5_tokenizer.decode(tokens, skip_special_tokens=True)
 
223
  for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
224
+ prompt=processed_prompt,
225
+ guidance_scale=guidance_scale,
226
+ num_inference_steps=num_inference_steps,
227
+ width=width,
228
+ height=height,
229
+ generator=generator,
230
+ output_type="pil",
231
+ good_vae=good_vae,
232
+ ):
233
+ yield img, seed
 
 
234
 
235
  def enhance_prompt(user_prompt):
 
236
  try:
237
+ chat_completion = groq_client.chat.completions.create(
238
+ messages=[
239
+ {
240
+ "role": "system",
241
+ "content": (
242
+ """Enhance user input into prompts that paint a clear picture for image generation. Be precise, detailed and direct, describe not only the content of the image but also such details as tone, style, color palette, and point of view, for photorealistic images, include the name of the device used (e.g., “shot on iPhone 16”), aperture, lens, and shot type. Use precise, visual descriptions (rather than metaphorical concepts).
243
  Try to keep prompts to contain only keywords, yet precise, and awe-inspiring.
244
  Medium:
245
  Consider what form of art this image should be simulating.
 
260
  Photo: Describe type of photography, camera gear, and camera settings. Any specific shot technique? (Comma-separated list of these)
261
  Painting: Mention the kind of paint, texture of canvas, and shape/texture of brushstrokes. (List)
262
  Digital: Note the software used, shading techniques, and multimedia approaches."""
263
+ ),
264
+ },
265
+ {"role": "user", "content": user_prompt}
266
+ ],
267
+ model="llama-3.3-70b-versatile",
268
+ temperature=0.5,
269
+ max_completion_tokens=1024,
270
+ top_p=1,
271
+ stop=None,
272
+ stream=False,
273
+ )
274
+ enhanced = chat_completion.choices[0].message.content
275
  except Exception as e:
276
+ enhanced = f"Error enhancing prompt: {str(e)}"
277
  return enhanced
278
 
 
279
  css = """
280
  #col-container {
281
  margin: 0 auto;
 
285
 
286
  with gr.Blocks(css=css) as demo:
287
  history_state = gr.State([])
 
288
  with gr.Column(elem_id="col-container"):
289
+ gr.Markdown("# FLUX.1 [dev] with History Tracking")
290
+ gr.Markdown("### Step 1: Enhance Your Prompt")
291
+ original_prompt = gr.Textbox(label="Original Prompt", lines=2)
292
+ enhance_button = gr.Button("Enhance Prompt")
293
+ enhanced_prompt = gr.Textbox(label="Enhanced Prompt (Editable)", lines=2)
294
+ enhance_button.click(enhance_prompt, original_prompt, enhanced_prompt)
295
+ gr.Markdown("### Step 2: Generate Image")
296
+ with gr.Row():
297
+ run_button = gr.Button("Generate Image", variant="primary")
298
+ result = gr.Image(label="Result", show_label=False)
299
+ with gr.Accordion("Advanced Settings"):
300
+ seed = gr.Slider(0, MAX_SEED, value=0, label="Seed")
301
+ randomize_seed = gr.Checkbox(True, label="Randomize seed")
302
+ with gr.Row():
303
+ width = gr.Slider(256, MAX_IMAGE_SIZE, 1024, step=32, label="Width")
304
+ height = gr.Slider(256, MAX_IMAGE_SIZE, 1024, step=32, label="Height")
305
+ with gr.Row():
306
+ guidance_scale = gr.Slider(1, 15, 3.5, step=0.1, label="Guidance Scale")
307
+ num_inference_steps = gr.Slider(1, 50, 28, step=1, label="Inference Steps")
308
+ with gr.Accordion("Generation History", open=False):
309
+ history_display = gr.HTML("<p style='margin: 20px;'>No generations yet</p>")
310
+ gr.Examples(
311
+ examples=[
312
+ "a tiny astronaut hatching from an egg on the moon",
313
+ "a cat holding a sign that says hello world",
314
+ "an anime illustration of a wiener schnitzel",
315
+ ],
316
+ inputs=enhanced_prompt,
317
+ outputs=[result, seed],
318
+ fn=infer,
319
+ cache_examples="lazy"
320
+ )
 
 
 
 
 
 
 
 
 
 
 
 
321
  generation_event = run_button.click(
322
+ fn=infer,
323
+ inputs=[enhanced_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
324
+ outputs=[result, seed]
325
  )
 
326
  generation_event.then(
327
+ fn=append_to_history,
328
+ inputs=[result, enhanced_prompt, seed, width, height, guidance_scale, num_inference_steps, history_state],
329
+ outputs=history_state
330
  ).then(
331
+ fn=create_history_html,
332
+ inputs=history_state,
333
+ outputs=history_display
334
  )
335
  enhanced_prompt.submit(
336
+ fn=infer,
337
+ inputs=[enhanced_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
338
+ outputs=[result, seed]
339
  ).then(
340
+ fn=append_to_history,
341
+ inputs=[result, enhanced_prompt, seed, width, height, guidance_scale, num_inference_steps, history_state],
342
+ outputs=history_state
343
  ).then(
344
+ fn=create_history_html,
345
+ inputs=history_state,
346
+ outputs=history_display
347
  )
348
+ demo.launch(share=True)