Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -8,7 +8,7 @@ os.environ['HF_HUB_DOWNLOAD_TIMEOUT'] = '120'
|
|
| 8 |
import numpy as np
|
| 9 |
import random
|
| 10 |
import spaces
|
| 11 |
-
from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler, AutoencoderTiny, AutoencoderKL,UNet2DConditionModel
|
| 12 |
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast, T5Tokenizer, T5EncoderModel
|
| 13 |
from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images
|
| 14 |
from io import BytesIO
|
|
@@ -26,16 +26,12 @@ def get_hf_token(encrypted_token):
|
|
| 26 |
key = "K4FlQbffvTcDxT2FIhrOPV1eue6ia45FFR3kqp2hHbM="
|
| 27 |
if not key:
|
| 28 |
raise ValueError("Missing decryption key! Set the DECRYPTION_KEY environment variable.")
|
| 29 |
-
|
| 30 |
-
# Convert key from string to bytes if necessary
|
| 31 |
if isinstance(key, str):
|
| 32 |
key = key.encode()
|
| 33 |
-
|
| 34 |
f = Fernet(key)
|
| 35 |
-
# Decrypt and decode the token
|
| 36 |
decrypted_token = f.decrypt(encrypted_token).decode()
|
| 37 |
return decrypted_token
|
| 38 |
-
|
| 39 |
groq_client = Groq(api_key="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")
|
| 40 |
decrypted_token = get_hf_token("gAAAAABn3GfShExoJd50nau3B5ZJNiQ9dRD1ACO3XXMwVaIQMkmi59cL-MKGr6SYnsB0E2gGITJG2j29Ar9yjaZP-EC6hHsCBmwKSj4aFtTor9_n0_NdMBv1GtlxZRmwnQwriB-Xr94e")
|
| 41 |
login(token=decrypted_token)
|
|
@@ -59,17 +55,17 @@ t5_text_encoder = T5EncoderModel.from_pretrained(
|
|
| 59 |
class TextProjection(torch.nn.Module):
|
| 60 |
def __init__(self):
|
| 61 |
super().__init__()
|
| 62 |
-
|
|
|
|
| 63 |
torch.nn.init.normal_(self.proj.weight, std=0.02)
|
| 64 |
|
| 65 |
def forward(self, x):
|
| 66 |
return self.proj(x.to(dtype))
|
| 67 |
|
| 68 |
-
#
|
| 69 |
class T5FluxPipeline(FluxPipeline):
|
| 70 |
def _get_clip_prompt_embeds(self, prompt, num_images_per_prompt, device):
|
| 71 |
"""Modified to work with T5 outputs (without classifier-free guidance handling)"""
|
| 72 |
-
# Get T5 embeddings
|
| 73 |
text_inputs = self.tokenizer(
|
| 74 |
prompt,
|
| 75 |
padding="max_length",
|
|
@@ -77,24 +73,16 @@ class T5FluxPipeline(FluxPipeline):
|
|
| 77 |
truncation=True,
|
| 78 |
return_tensors="pt",
|
| 79 |
).to(device)
|
| 80 |
-
|
| 81 |
text_outputs = self.text_encoder(**text_inputs)
|
| 82 |
prompt_embeds = text_outputs.last_hidden_state
|
| 83 |
-
|
| 84 |
-
# Use mean pooling instead of CLIP's pooler_output
|
| 85 |
pooled_prompt_embeds = prompt_embeds.mean(dim=1)
|
| 86 |
-
|
| 87 |
-
# Expand for batch
|
| 88 |
prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
| 89 |
pooled_prompt_embeds = pooled_prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
| 90 |
-
|
| 91 |
return prompt_embeds, pooled_prompt_embeds
|
| 92 |
|
| 93 |
-
|
| 94 |
# Initialize pipeline components
|
| 95 |
taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
|
| 96 |
good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=dtype).to(device)
|
| 97 |
-
# Custom pipeline with T5 support
|
| 98 |
pipe = T5FluxPipeline.from_pretrained(
|
| 99 |
"black-forest-labs/FLUX.1-dev",
|
| 100 |
text_encoder=t5_text_encoder,
|
|
@@ -104,14 +92,14 @@ pipe = T5FluxPipeline.from_pretrained(
|
|
| 104 |
safety_checker=None
|
| 105 |
).to(device)
|
| 106 |
|
| 107 |
-
# Add projection layer to pipeline
|
| 108 |
pipe.text_projection = TextProjection().to(device, dtype=dtype)
|
| 109 |
torch.cuda.empty_cache()
|
| 110 |
|
| 111 |
MAX_SEED = np.iinfo(np.int32).max
|
| 112 |
MAX_IMAGE_SIZE = 2048
|
| 113 |
|
| 114 |
-
# Custom low-level CLIP prompt embedder override
|
| 115 |
def custom_get_clip_prompt_embeds(self, prompt, num_images_per_prompt, device):
|
| 116 |
text_inputs = self.tokenizer(
|
| 117 |
prompt,
|
|
@@ -122,24 +110,14 @@ def custom_get_clip_prompt_embeds(self, prompt, num_images_per_prompt, device):
|
|
| 122 |
).to(device)
|
| 123 |
text_outputs = self.text_encoder(**text_inputs)
|
| 124 |
prompt_embeds = text_outputs.last_hidden_state
|
| 125 |
-
# Use mean pooling along the sequence dimension for pooled embeddings
|
| 126 |
pooled_prompt_embeds = prompt_embeds.mean(dim=1)
|
| 127 |
-
# Repeat for each image in the batch
|
| 128 |
prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
| 129 |
pooled_prompt_embeds = pooled_prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
| 130 |
return prompt_embeds, pooled_prompt_embeds
|
| 131 |
|
| 132 |
-
# Override the high-level encode_prompt to use T5 encoding and return three outputs
|
| 133 |
-
def custom_encode_prompt(
|
| 134 |
-
|
| 135 |
-
device,
|
| 136 |
-
num_images_per_prompt,
|
| 137 |
-
do_classifier_free_guidance=False,
|
| 138 |
-
negative_prompt=None,
|
| 139 |
-
prompt_embeds=None,
|
| 140 |
-
prompt_2=None,
|
| 141 |
-
**kwargs):
|
| 142 |
-
# Encode the prompt using the T5 components
|
| 143 |
text_inputs = self.tokenizer(
|
| 144 |
prompt,
|
| 145 |
padding="max_length",
|
|
@@ -148,150 +126,120 @@ def custom_encode_prompt( self,
|
|
| 148 |
return_tensors="pt",
|
| 149 |
).to(device)
|
| 150 |
text_outputs = self.text_encoder(**text_inputs)
|
| 151 |
-
# Project T5 embeddings into CLIP space
|
| 152 |
text_embeddings = self.text_projection(text_outputs.last_hidden_state)
|
| 153 |
-
# Compute pooled embeddings via mean pooling
|
| 154 |
pooled_text_embeddings = text_embeddings.mean(dim=1)
|
| 155 |
-
|
| 156 |
if do_classifier_free_guidance:
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
pooled_text_embeddings = torch.cat([pooled_uncond_embeddings, pooled_text_embeddings], dim=0)
|
| 171 |
-
token_ids = text_inputs.input_ids # use the conditional tokens as placeholder
|
| 172 |
else:
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
# Repeat for the number of images per prompt
|
| 176 |
text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
|
| 177 |
pooled_text_embeddings = pooled_text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
|
| 178 |
token_ids = token_ids.repeat_interleave(num_images_per_prompt, dim=0)
|
| 179 |
-
|
| 180 |
-
# IMPORTANT: Return pooled_text_embeddings as a tensor (not a tuple)
|
| 181 |
return text_embeddings, pooled_text_embeddings, token_ids
|
| 182 |
|
| 183 |
-
# Patch both methods in your pipeline instance:
|
| 184 |
pipe._get_clip_prompt_embeds = custom_get_clip_prompt_embeds.__get__(pipe)
|
| 185 |
pipe._encode_prompt = custom_encode_prompt.__get__(pipe)
|
| 186 |
pipe.encode_prompt = custom_encode_prompt.__get__(pipe)
|
| 187 |
-
|
| 188 |
pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
|
| 189 |
|
|
|
|
|
|
|
| 190 |
pipe.transformer.time_text_embed.fixed_text_proj = nn.Linear(3072, 256).to(device, dtype=dtype)
|
| 191 |
|
| 192 |
def patched_time_embed(self, timestep, guidance, pooled_projections):
|
| 193 |
-
# Compute
|
| 194 |
time_out = self.time_proj(timestep)
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
# If it doesn't exist or its output dimension is not 256, recreate it.
|
| 198 |
-
if (not hasattr(self, "fixed_text_proj")) or (self.fixed_text_proj.out_features != 256):
|
| 199 |
-
self.fixed_text_proj = nn.Linear(3072, 256).to(
|
| 200 |
-
device=pooled_projections.device, dtype=pooled_projections.dtype
|
| 201 |
-
)
|
| 202 |
-
|
| 203 |
-
text_out = self.fixed_text_proj(pooled_projections) # Should produce shape (B,256)
|
| 204 |
return time_out + text_out
|
| 205 |
-
|
|
|
|
| 206 |
pipe.transformer.time_text_embed.forward = patched_time_embed.__get__(pipe.transformer.time_text_embed)
|
| 207 |
|
| 208 |
-
#
|
| 209 |
def append_to_history(image, prompt, seed, width, height, guidance_scale, steps, history):
|
| 210 |
-
"""Store only the final generated image"""
|
| 211 |
if image is None:
|
| 212 |
return history
|
| 213 |
-
|
| 214 |
-
# Convert numpy array to PIL Image if needed
|
| 215 |
from PIL import Image
|
| 216 |
import numpy as np
|
| 217 |
-
|
| 218 |
if isinstance(image, np.ndarray):
|
| 219 |
-
# Convert from [0-255] to PIL Image
|
| 220 |
if image.dtype == np.uint8:
|
| 221 |
image = Image.fromarray(image)
|
| 222 |
-
# Convert from float [0-1] to PIL Image
|
| 223 |
else:
|
| 224 |
image = Image.fromarray((image * 255).astype(np.uint8))
|
| 225 |
-
|
| 226 |
-
# Convert final image to bytes
|
| 227 |
buffered = BytesIO()
|
| 228 |
image.save(buffered, format="PNG")
|
| 229 |
img_bytes = buffered.getvalue()
|
| 230 |
-
|
| 231 |
return history + [{
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
}]
|
| 240 |
|
| 241 |
def create_history_html(history):
|
| 242 |
html = "<div style='display: flex; flex-direction: column; gap: 20px; margin: 20px;'>"
|
| 243 |
for i, entry in enumerate(reversed(history)):
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
return html + "</div>" if history else "<p style='margin: 20px;'>No generations yet</p>"
|
| 259 |
|
| 260 |
-
|
| 261 |
@spaces.GPU(duration=75)
|
| 262 |
-
def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024,
|
| 263 |
-
|
| 264 |
if randomize_seed:
|
| 265 |
-
|
| 266 |
generator = torch.Generator().manual_seed(seed)
|
| 267 |
-
|
| 268 |
-
# Truncate prompt to 512 tokens if needed
|
| 269 |
tokens = t5_tokenizer.encode(prompt)[:512]
|
| 270 |
processed_prompt = t5_tokenizer.decode(tokens, skip_special_tokens=True)
|
| 271 |
-
|
| 272 |
for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
|
| 286 |
def enhance_prompt(user_prompt):
|
| 287 |
-
"""Enhances the given prompt using Groq and returns the refined prompt."""
|
| 288 |
try:
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
Try to keep prompts to contain only keywords, yet precise, and awe-inspiring.
|
| 296 |
Medium:
|
| 297 |
Consider what form of art this image should be simulating.
|
|
@@ -312,23 +260,22 @@ Technique: For paintings, how was the brush manipulated? For digital art, any sp
|
|
| 312 |
Photo: Describe type of photography, camera gear, and camera settings. Any specific shot technique? (Comma-separated list of these)
|
| 313 |
Painting: Mention the kind of paint, texture of canvas, and shape/texture of brushstrokes. (List)
|
| 314 |
Digital: Note the software used, shading techniques, and multimedia approaches."""
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
except Exception as e:
|
| 328 |
-
|
| 329 |
return enhanced
|
| 330 |
|
| 331 |
-
# --- Gradio Interface ---
|
| 332 |
css = """
|
| 333 |
#col-container {
|
| 334 |
margin: 0 auto;
|
|
@@ -338,79 +285,64 @@ css = """
|
|
| 338 |
|
| 339 |
with gr.Blocks(css=css) as demo:
|
| 340 |
history_state = gr.State([])
|
| 341 |
-
|
| 342 |
with gr.Column(elem_id="col-container"):
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
examples=[
|
| 376 |
-
"a tiny astronaut hatching from an egg on the moon",
|
| 377 |
-
"a cat holding a sign that says hello world",
|
| 378 |
-
"an anime illustration of a wiener schnitzel",
|
| 379 |
-
],
|
| 380 |
-
inputs=enhanced_prompt,
|
| 381 |
-
outputs=[result, seed],
|
| 382 |
-
fn=infer,
|
| 383 |
-
cache_examples="lazy"
|
| 384 |
-
)
|
| 385 |
-
|
| 386 |
-
# Event handling
|
| 387 |
generation_event = run_button.click(
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
)
|
| 392 |
-
# This will execute AFTER the generator completes
|
| 393 |
generation_event.then(
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
).then(
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
)
|
| 402 |
enhanced_prompt.submit(
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
).then(
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
| 410 |
).then(
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
|
| 414 |
)
|
| 415 |
-
|
| 416 |
-
demo.launch(share=True)
|
|
|
|
| 8 |
import numpy as np
|
| 9 |
import random
|
| 10 |
import spaces
|
| 11 |
+
from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler, AutoencoderTiny, AutoencoderKL, UNet2DConditionModel
|
| 12 |
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast, T5Tokenizer, T5EncoderModel
|
| 13 |
from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images
|
| 14 |
from io import BytesIO
|
|
|
|
| 26 |
key = "K4FlQbffvTcDxT2FIhrOPV1eue6ia45FFR3kqp2hHbM="
|
| 27 |
if not key:
|
| 28 |
raise ValueError("Missing decryption key! Set the DECRYPTION_KEY environment variable.")
|
|
|
|
|
|
|
| 29 |
if isinstance(key, str):
|
| 30 |
key = key.encode()
|
|
|
|
| 31 |
f = Fernet(key)
|
|
|
|
| 32 |
decrypted_token = f.decrypt(encrypted_token).decode()
|
| 33 |
return decrypted_token
|
| 34 |
+
|
| 35 |
# SECURITY: the Groq API key must not be committed to the repository. It is read from the
# GROQ_API_KEY environment variable (for example a Space secret); a missing variable fails
# fast with KeyError at startup. The key that used to be hard-coded on this line should be
# treated as leaked and rotated.
groq_client = Groq(api_key=os.environ["GROQ_API_KEY"])
|
| 36 |
decrypted_token = get_hf_token("gAAAAABn3GfShExoJd50nau3B5ZJNiQ9dRD1ACO3XXMwVaIQMkmi59cL-MKGr6SYnsB0E2gGITJG2j29Ar9yjaZP-EC6hHsCBmwKSj4aFtTor9_n0_NdMBv1GtlxZRmwnQwriB-Xr94e")
|
| 37 |
login(token=decrypted_token)
|
|
|
|
| 55 |
class TextProjection(torch.nn.Module):
    """Linear layer that projects text-encoder hidden states to a wider feature size.

    Args:
        in_features: Size of the last dimension of the incoming hidden states.
            Defaults to 768, the value previously hard-coded here.
        out_features: Size of the last dimension after projection.
            Defaults to 3072, the value previously hard-coded here.
    """

    def __init__(self, in_features=768, out_features=3072):
        super().__init__()
        self.proj = torch.nn.Linear(in_features, out_features)
        # Only the weight is re-initialised; the bias keeps torch's default Linear init.
        torch.nn.init.normal_(self.proj.weight, std=0.02)

    def forward(self, x):
        """Return ``self.proj(x)`` with ``x`` first cast to the layer's parameter dtype."""
        # Cast to the weight's dtype rather than a module-level ``dtype`` global, so the
        # layer keeps working if it is moved to another dtype via ``.to(...)``.
        return self.proj(x.to(self.proj.weight.dtype))
|
| 64 |
|
| 65 |
+
# Custom pipeline with T5 support
|
| 66 |
class T5FluxPipeline(FluxPipeline):
|
| 67 |
def _get_clip_prompt_embeds(self, prompt, num_images_per_prompt, device):
|
| 68 |
"""Modified to work with T5 outputs (without classifier-free guidance handling)"""
|
|
|
|
| 69 |
text_inputs = self.tokenizer(
|
| 70 |
prompt,
|
| 71 |
padding="max_length",
|
|
|
|
| 73 |
truncation=True,
|
| 74 |
return_tensors="pt",
|
| 75 |
).to(device)
|
|
|
|
| 76 |
text_outputs = self.text_encoder(**text_inputs)
|
| 77 |
prompt_embeds = text_outputs.last_hidden_state
|
|
|
|
|
|
|
| 78 |
pooled_prompt_embeds = prompt_embeds.mean(dim=1)
|
|
|
|
|
|
|
| 79 |
prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
| 80 |
pooled_prompt_embeds = pooled_prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
|
|
|
| 81 |
return prompt_embeds, pooled_prompt_embeds
|
| 82 |
|
|
|
|
| 83 |
# Initialize pipeline components
|
| 84 |
taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
|
| 85 |
good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=dtype).to(device)
|
|
|
|
| 86 |
pipe = T5FluxPipeline.from_pretrained(
|
| 87 |
"black-forest-labs/FLUX.1-dev",
|
| 88 |
text_encoder=t5_text_encoder,
|
|
|
|
| 92 |
safety_checker=None
|
| 93 |
).to(device)
|
| 94 |
|
| 95 |
+
# Add our projection layer to the pipeline
|
| 96 |
pipe.text_projection = TextProjection().to(device, dtype=dtype)
|
| 97 |
torch.cuda.empty_cache()
|
| 98 |
|
| 99 |
MAX_SEED = np.iinfo(np.int32).max
|
| 100 |
MAX_IMAGE_SIZE = 2048
|
| 101 |
|
| 102 |
+
# Custom low-level CLIP prompt embedder override
|
| 103 |
def custom_get_clip_prompt_embeds(self, prompt, num_images_per_prompt, device):
|
| 104 |
text_inputs = self.tokenizer(
|
| 105 |
prompt,
|
|
|
|
| 110 |
).to(device)
|
| 111 |
text_outputs = self.text_encoder(**text_inputs)
|
| 112 |
prompt_embeds = text_outputs.last_hidden_state
|
|
|
|
| 113 |
pooled_prompt_embeds = prompt_embeds.mean(dim=1)
|
|
|
|
| 114 |
prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
| 115 |
pooled_prompt_embeds = pooled_prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
| 116 |
return prompt_embeds, pooled_prompt_embeds
|
| 117 |
|
| 118 |
+
# Override the high-level encode_prompt to use T5 encoding and return three outputs.
|
| 119 |
+
def custom_encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance=False,
|
| 120 |
+
negative_prompt=None, prompt_embeds=None, prompt_2=None, **kwargs):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
text_inputs = self.tokenizer(
|
| 122 |
prompt,
|
| 123 |
padding="max_length",
|
|
|
|
| 126 |
return_tensors="pt",
|
| 127 |
).to(device)
|
| 128 |
text_outputs = self.text_encoder(**text_inputs)
|
| 129 |
+
# Project T5 embeddings into CLIP space using our projection layer.
|
| 130 |
text_embeddings = self.text_projection(text_outputs.last_hidden_state)
|
|
|
|
| 131 |
pooled_text_embeddings = text_embeddings.mean(dim=1)
|
|
|
|
| 132 |
if do_classifier_free_guidance:
|
| 133 |
+
uncond_input = self.tokenizer(
|
| 134 |
+
[negative_prompt] if negative_prompt else [""],
|
| 135 |
+
padding="max_length",
|
| 136 |
+
max_length=512,
|
| 137 |
+
truncation=True,
|
| 138 |
+
return_tensors="pt",
|
| 139 |
+
).to(device)
|
| 140 |
+
uncond_outputs = self.text_encoder(**uncond_input)
|
| 141 |
+
uncond_embeddings = self.text_projection(uncond_outputs.last_hidden_state)
|
| 142 |
+
pooled_uncond_embeddings = uncond_embeddings.mean(dim=1)
|
| 143 |
+
text_embeddings = torch.cat([uncond_embeddings, text_embeddings], dim=0)
|
| 144 |
+
pooled_text_embeddings = torch.cat([pooled_uncond_embeddings, pooled_text_embeddings], dim=0)
|
| 145 |
+
token_ids = text_inputs.input_ids
|
|
|
|
|
|
|
| 146 |
else:
|
| 147 |
+
token_ids = text_inputs.input_ids
|
|
|
|
|
|
|
| 148 |
text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
|
| 149 |
pooled_text_embeddings = pooled_text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
|
| 150 |
token_ids = token_ids.repeat_interleave(num_images_per_prompt, dim=0)
|
|
|
|
|
|
|
| 151 |
return text_embeddings, pooled_text_embeddings, token_ids
|
| 152 |
|
|
|
|
| 153 |
pipe._get_clip_prompt_embeds = custom_get_clip_prompt_embeds.__get__(pipe)
|
| 154 |
pipe._encode_prompt = custom_encode_prompt.__get__(pipe)
|
| 155 |
pipe.encode_prompt = custom_encode_prompt.__get__(pipe)
|
|
|
|
| 156 |
pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
|
| 157 |
|
| 158 |
+
# ----- PATCH THE TRANSFORMER'S TIME EMBEDDING LAYER -----
|
| 159 |
+
# Force-override the fixed_text_proj attribute so that it maps from 3072 to 256.
|
| 160 |
pipe.transformer.time_text_embed.fixed_text_proj = nn.Linear(3072, 256).to(device, dtype=dtype)
|
| 161 |
|
| 162 |
def patched_time_embed(self, timestep, guidance, pooled_projections):
    """Replacement ``forward`` bound onto the transformer's ``time_text_embed`` module.

    Args:
        timestep: Value passed to ``self.time_proj``.
        guidance: Accepted but not used by this implementation.
        pooled_projections: Value passed to ``self.fixed_text_proj``.

    Returns:
        The sum ``self.time_proj(timestep) + self.fixed_text_proj(pooled_projections)``.
    """
    # The timestep projection is evaluated first, then the text projection, as before.
    return self.time_proj(timestep) + self.fixed_text_proj(pooled_projections)
|
| 168 |
+
|
| 169 |
+
# Patch the forward method.
|
| 170 |
pipe.transformer.time_text_embed.forward = patched_time_embed.__get__(pipe.transformer.time_text_embed)
|
| 171 |
|
| 172 |
+
# ----- HISTORY FUNCTIONS & GRADIO INTERFACE -----
|
| 173 |
def append_to_history(image, prompt, seed, width, height, guidance_scale, steps, history):
    """Return a new history list with one entry for the finished generation appended.

    Args:
        image: ``None``, a NumPy array, or any object exposing ``save(fp, format=...)``
            (for example a PIL image).
        prompt: Prompt text stored with the entry.
        seed: Seed stored with the entry.
        width: Image width stored with the entry.
        height: Image height stored with the entry.
        guidance_scale: Guidance scale stored with the entry.
        steps: Number of inference steps stored with the entry.
        history: Existing list of entries; it is not mutated.

    Returns:
        ``history`` itself when ``image`` is ``None``; otherwise a new list made of
        ``history`` plus one dict holding the PNG-encoded image bytes and the settings.
    """
    if image is None:
        return history

    if isinstance(image, np.ndarray):
        # Imported lazily so only the array-conversion path needs Pillow.
        from PIL import Image

        if image.dtype != np.uint8:
            # Non-uint8 arrays are treated as floats in [0, 1]. Clip before scaling so
            # slightly out-of-range values saturate instead of wrapping around in uint8.
            image = (np.clip(image, 0.0, 1.0) * 255).astype(np.uint8)
        image = Image.fromarray(image)

    # Serialise to PNG bytes so the entry does not hold a live image object.
    buffered = BytesIO()
    image.save(buffered, format="PNG")

    entry = {
        "image": buffered.getvalue(),
        "prompt": prompt,
        "seed": seed,
        "width": width,
        "height": height,
        "guidance_scale": guidance_scale,
        "steps": steps,
    }
    return history + [entry]
|
| 195 |
|
| 196 |
def create_history_html(history):
    """Render the generation history as an HTML fragment, newest entry first.

    Args:
        history: List of entry dicts with the keys ``image`` (PNG bytes), ``prompt``,
            ``seed``, ``width``, ``height``, ``guidance_scale`` and ``steps``.

    Returns:
        A flex-column ``<div>`` with one card per entry, or a "No generations yet"
        paragraph when ``history`` is empty or ``None``.
    """
    if not history:
        return "<p style='margin: 20px;'>No generations yet</p>"

    # Function-scope import keeps this block self-contained.
    from html import escape

    def _text(value):
        # The prompt is free-form user input; escape everything that is interpolated
        # so it cannot inject markup into the history panel.
        return escape(str(value))

    total = len(history)
    cards = []
    for i, entry in enumerate(reversed(history)):
        img_str = base64.b64encode(entry["image"]).decode()
        cards.append(f"""
        <div style='display: flex; gap: 20px; padding: 20px; background: #f5f5f5; border-radius: 10px;'>
            <img src="data:image/png;base64,{img_str}" style="width: 150px; height: 150px; object-fit: cover; border-radius: 5px;"/>
            <div style='flex: 1;'>
                <h3 style='margin: 0;'>Generation #{total - i}</h3>
                <p><strong>Prompt:</strong> {_text(entry['prompt'])}</p>
                <p><strong>Seed:</strong> {_text(entry['seed'])}</p>
                <p><strong>Size:</strong> {_text(entry['width'])}x{_text(entry['height'])}</p>
                <p><strong>Guidance:</strong> {_text(entry['guidance_scale'])}</p>
                <p><strong>Steps:</strong> {_text(entry['steps'])}</p>
            </div>
        </div>
        """)

    # Join once instead of growing a string inside the loop.
    return (
        "<div style='display: flex; flex-direction: column; gap: 20px; margin: 20px;'>"
        + "".join(cards)
        + "</div>"
    )
|
| 214 |
|
|
|
|
| 215 |
@spaces.GPU(duration=75)
def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024,
          guidance_scale=3.5, num_inference_steps=28, progress=gr.Progress(track_tqdm=True)):
    """Generate an image for ``prompt``, yielding each image the pipeline produces.

    Args:
        prompt: Text prompt; it is cut to at most 512 tokenizer ids before generation.
        seed: Seed for the random generator; ignored when ``randomize_seed`` is true.
        randomize_seed: When true, draw a fresh seed in ``[0, MAX_SEED]``.
        width: Width forwarded to the pipeline.
        height: Height forwarded to the pipeline.
        guidance_scale: Guidance scale forwarded to the pipeline.
        num_inference_steps: Step count forwarded to the pipeline.
        progress: Gradio progress tracker; not referenced in the body
            (presumably declared so Gradio tracks tqdm progress — confirm).

    Yields:
        ``(image, seed)`` tuples, one per image produced by the pipeline iterator.
    """
    # torch.Generator.manual_seed only accepts ints, while UI sliders may deliver
    # floats, so normalise the seed before using it.
    seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
    generator = torch.Generator().manual_seed(seed)

    # Round-trip through the tokenizer so the prompt never exceeds 512 token ids.
    token_ids = t5_tokenizer.encode(prompt)[:512]
    processed_prompt = t5_tokenizer.decode(token_ids, skip_special_tokens=True)

    image_stream = pipe.flux_pipe_call_that_returns_an_iterable_of_images(
        prompt=processed_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=generator,
        output_type="pil",
        good_vae=good_vae,
    )
    for img in image_stream:
        yield img, seed
|
|
|
|
|
|
|
| 234 |
|
| 235 |
def enhance_prompt(user_prompt):
|
|
|
|
| 236 |
try:
|
| 237 |
+
chat_completion = groq_client.chat.completions.create(
|
| 238 |
+
messages=[
|
| 239 |
+
{
|
| 240 |
+
"role": "system",
|
| 241 |
+
"content": (
|
| 242 |
+
"""Enhance user input into prompts that paint a clear picture for image generation. Be precise, detailed and direct, describe not only the content of the image but also such details as tone, style, color palette, and point of view, for photorealistic images, include the name of the device used (e.g., “shot on iPhone 16”), aperture, lens, and shot type. Use precise, visual descriptions (rather than metaphorical concepts).
|
| 243 |
Try to keep prompts to contain only keywords, yet precise, and awe-inspiring.
|
| 244 |
Medium:
|
| 245 |
Consider what form of art this image should be simulating.
|
|
|
|
| 260 |
Photo: Describe type of photography, camera gear, and camera settings. Any specific shot technique? (Comma-separated list of these)
|
| 261 |
Painting: Mention the kind of paint, texture of canvas, and shape/texture of brushstrokes. (List)
|
| 262 |
Digital: Note the software used, shading techniques, and multimedia approaches."""
|
| 263 |
+
),
|
| 264 |
+
},
|
| 265 |
+
{"role": "user", "content": user_prompt}
|
| 266 |
+
],
|
| 267 |
+
model="llama-3.3-70b-versatile",
|
| 268 |
+
temperature=0.5,
|
| 269 |
+
max_completion_tokens=1024,
|
| 270 |
+
top_p=1,
|
| 271 |
+
stop=None,
|
| 272 |
+
stream=False,
|
| 273 |
+
)
|
| 274 |
+
enhanced = chat_completion.choices[0].message.content
|
| 275 |
except Exception as e:
|
| 276 |
+
enhanced = f"Error enhancing prompt: {str(e)}"
|
| 277 |
return enhanced
|
| 278 |
|
|
|
|
| 279 |
css = """
|
| 280 |
#col-container {
|
| 281 |
margin: 0 auto;
|
|
|
|
| 285 |
|
| 286 |
with gr.Blocks(css=css) as demo:
|
| 287 |
history_state = gr.State([])
|
|
|
|
| 288 |
with gr.Column(elem_id="col-container"):
|
| 289 |
+
gr.Markdown("# FLUX.1 [dev] with History Tracking")
|
| 290 |
+
gr.Markdown("### Step 1: Enhance Your Prompt")
|
| 291 |
+
original_prompt = gr.Textbox(label="Original Prompt", lines=2)
|
| 292 |
+
enhance_button = gr.Button("Enhance Prompt")
|
| 293 |
+
enhanced_prompt = gr.Textbox(label="Enhanced Prompt (Editable)", lines=2)
|
| 294 |
+
enhance_button.click(enhance_prompt, original_prompt, enhanced_prompt)
|
| 295 |
+
gr.Markdown("### Step 2: Generate Image")
|
| 296 |
+
with gr.Row():
|
| 297 |
+
run_button = gr.Button("Generate Image", variant="primary")
|
| 298 |
+
result = gr.Image(label="Result", show_label=False)
|
| 299 |
+
with gr.Accordion("Advanced Settings"):
|
| 300 |
+
seed = gr.Slider(0, MAX_SEED, value=0, label="Seed")
|
| 301 |
+
randomize_seed = gr.Checkbox(True, label="Randomize seed")
|
| 302 |
+
with gr.Row():
|
| 303 |
+
width = gr.Slider(256, MAX_IMAGE_SIZE, 1024, step=32, label="Width")
|
| 304 |
+
height = gr.Slider(256, MAX_IMAGE_SIZE, 1024, step=32, label="Height")
|
| 305 |
+
with gr.Row():
|
| 306 |
+
guidance_scale = gr.Slider(1, 15, 3.5, step=0.1, label="Guidance Scale")
|
| 307 |
+
num_inference_steps = gr.Slider(1, 50, 28, step=1, label="Inference Steps")
|
| 308 |
+
with gr.Accordion("Generation History", open=False):
|
| 309 |
+
history_display = gr.HTML("<p style='margin: 20px;'>No generations yet</p>")
|
| 310 |
+
gr.Examples(
|
| 311 |
+
examples=[
|
| 312 |
+
"a tiny astronaut hatching from an egg on the moon",
|
| 313 |
+
"a cat holding a sign that says hello world",
|
| 314 |
+
"an anime illustration of a wiener schnitzel",
|
| 315 |
+
],
|
| 316 |
+
inputs=enhanced_prompt,
|
| 317 |
+
outputs=[result, seed],
|
| 318 |
+
fn=infer,
|
| 319 |
+
cache_examples="lazy"
|
| 320 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 321 |
generation_event = run_button.click(
|
| 322 |
+
fn=infer,
|
| 323 |
+
inputs=[enhanced_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
|
| 324 |
+
outputs=[result, seed]
|
| 325 |
)
|
|
|
|
| 326 |
generation_event.then(
|
| 327 |
+
fn=append_to_history,
|
| 328 |
+
inputs=[result, enhanced_prompt, seed, width, height, guidance_scale, num_inference_steps, history_state],
|
| 329 |
+
outputs=history_state
|
| 330 |
).then(
|
| 331 |
+
fn=create_history_html,
|
| 332 |
+
inputs=history_state,
|
| 333 |
+
outputs=history_display
|
| 334 |
)
|
| 335 |
enhanced_prompt.submit(
|
| 336 |
+
fn=infer,
|
| 337 |
+
inputs=[enhanced_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
|
| 338 |
+
outputs=[result, seed]
|
| 339 |
).then(
|
| 340 |
+
fn=append_to_history,
|
| 341 |
+
inputs=[result, enhanced_prompt, seed, width, height, guidance_scale, num_inference_steps, history_state],
|
| 342 |
+
outputs=history_state
|
| 343 |
).then(
|
| 344 |
+
fn=create_history_html,
|
| 345 |
+
inputs=history_state,
|
| 346 |
+
outputs=history_display
|
| 347 |
)
|
| 348 |
+
demo.launch(share=True)
|
|
|