concauu committed · verified
Commit 343469f · 1 Parent(s): 8686600

Update app.py

Files changed (1): app.py (+16 −134)
app.py CHANGED
@@ -39,135 +39,19 @@ login(token=decrypted_token)
 groq_client = Groq(api_key="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")
 
 
-# Load T5 components for longer context
-t5_tokenizer = T5Tokenizer.from_pretrained(
-    "google-t5/t5-base",
-    legacy=False,
-    model_max_length=512
-)
-t5_text_encoder = T5EncoderModel.from_pretrained(
-    "google-t5/t5-base",
-    torch_dtype=dtype
-).to(device)
-
-# --- UPDATED PROJECTION LAYER ---
-# Now project from 768 to 4096 (instead of 3072)
-class TextProjection(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.proj = torch.nn.Linear(768, 4096)  # Updated: 4096 output features
-        torch.nn.init.normal_(self.proj.weight, std=0.02)
-
-    def forward(self, x):
-        return self.proj(x.to(dtype))
-
-# Custom pipeline with T5 support
-class T5FluxPipeline(FluxPipeline):
-    def _get_clip_prompt_embeds(self, prompt, num_images_per_prompt, device):
-        """Modified to work with T5 outputs (without classifier-free guidance handling)"""
-        text_inputs = self.tokenizer(
-            prompt,
-            padding="max_length",
-            max_length=512,
-            truncation=True,
-            return_tensors="pt",
-        ).to(device)
-        text_outputs = self.text_encoder(**text_inputs)
-        prompt_embeds = text_outputs.last_hidden_state
-        pooled_prompt_embeds = prompt_embeds.mean(dim=1)
-        prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
-        pooled_prompt_embeds = pooled_prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
-        return prompt_embeds, pooled_prompt_embeds
+dtype = torch.bfloat16
+device = "cuda" if torch.cuda.is_available() else "cpu"
 
-# Initialize pipeline components
 taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
 good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=dtype).to(device)
-pipe = T5FluxPipeline.from_pretrained(
-    "black-forest-labs/FLUX.1-dev",
-    text_encoder=t5_text_encoder,
-    tokenizer=t5_tokenizer,
-    torch_dtype=dtype,
-    vae=taef1,
-    safety_checker=None
-).to(device)
-
-# Add our updated projection layer to the pipeline
-pipe.text_projection = TextProjection().to(device, dtype=dtype)
+pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=dtype, vae=taef1).to(device)
 torch.cuda.empty_cache()
 
 MAX_SEED = np.iinfo(np.int32).max
 MAX_IMAGE_SIZE = 2048
 
-# Custom low-level CLIP prompt embedder override
-def custom_get_clip_prompt_embeds(self, prompt, num_images_per_prompt, device):
-    text_inputs = self.tokenizer(
-        prompt,
-        padding="max_length",
-        max_length=512,
-        truncation=True,
-        return_tensors="pt",
-    ).to(device)
-    text_outputs = self.text_encoder(**text_inputs)
-    prompt_embeds = text_outputs.last_hidden_state
-    pooled_prompt_embeds = prompt_embeds.mean(dim=1)
-    prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
-    pooled_prompt_embeds = pooled_prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
-    return prompt_embeds, pooled_prompt_embeds
-
-# Override the high-level encode_prompt to use T5 encoding and return three outputs.
-# --- KEY CHANGE: Return token_ids as a single tensor.
-def custom_encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance=False,
-                         negative_prompt=None, prompt_embeds=None, prompt_2=None, **kwargs):
-    text_inputs = self.tokenizer(
-        prompt,
-        padding="max_length",
-        max_length=512,
-        truncation=True,
-        return_tensors="pt",
-    ).to(device)
-    text_outputs = self.text_encoder(**text_inputs)
-    # Project T5 embeddings into CLIP space using our updated projection layer.
-    text_embeddings = self.text_projection(text_outputs.last_hidden_state)
-    pooled_text_embeddings = text_embeddings.mean(dim=1)
-    if do_classifier_free_guidance:
-        uncond_input = self.tokenizer(
-            [negative_prompt] if negative_prompt else [""],
-            padding="max_length",
-            max_length=512,
-            truncation=True,
-            return_tensors="pt",
-        ).to(device)
-        uncond_outputs = self.text_encoder(**uncond_input)
-        uncond_embeddings = self.text_projection(uncond_outputs.last_hidden_state)
-        pooled_uncond_embeddings = uncond_embeddings.mean(dim=1)
-        text_embeddings = torch.cat([uncond_embeddings, text_embeddings], dim=0)
-        pooled_text_embeddings = torch.cat([pooled_uncond_embeddings, pooled_text_embeddings], dim=0)
-        token_ids = text_inputs.input_ids
-    else:
-        token_ids = text_inputs.input_ids
-    text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
-    pooled_text_embeddings = pooled_text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
-    token_ids = token_ids.repeat_interleave(num_images_per_prompt, dim=0)
-    return text_embeddings, pooled_text_embeddings, token_ids
-
-pipe._get_clip_prompt_embeds = custom_get_clip_prompt_embeds.__get__(pipe)
-pipe._encode_prompt = custom_encode_prompt.__get__(pipe)
-pipe.encode_prompt = custom_encode_prompt.__get__(pipe)
 pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
 
-# ----- PATCH THE TRANSFORMER'S TIME EMBEDDING LAYER -----
-# Force-override the fixed_text_proj attribute so that it maps from 4096 to 256.
-pipe.transformer.time_text_embed.fixed_text_proj = nn.Linear(4096, 256).to(device, dtype=dtype)
-
-def patched_time_embed(self, timestep, guidance, pooled_projections):
-    # Compute timestep embedding (expected shape: (B, 256))
-    time_out = self.time_proj(timestep)
-    # Use the pre-assigned fixed_text_proj (mapping from 4096 to 256)
-    text_out = self.fixed_text_proj(pooled_projections)
-    return time_out + text_out
-
-pipe.transformer.time_text_embed.forward = patched_time_embed.__get__(pipe.transformer.time_text_embed)
-
 # ----- HISTORY FUNCTIONS & GRADIO INTERFACE -----
 def append_to_history(image, prompt, seed, width, height, guidance_scale, steps, history):
     if image is None:
@@ -212,24 +96,22 @@ def create_history_html(history):
     return html + "</div>" if history else "<p style='margin: 20px;'>No generations yet</p>"
 
 @spaces.GPU(duration=75)
-def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024,
-          guidance_scale=3.5, num_inference_steps=28, progress=gr.Progress(track_tqdm=True)):
+def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, guidance_scale=3.5, num_inference_steps=28, progress=gr.Progress(track_tqdm=True)):
     if randomize_seed:
-        seed = random.randint(0, MAX_SEED)
+        seed = random.randint(0, MAX_SEED)
     generator = torch.Generator().manual_seed(seed)
-    tokens = t5_tokenizer.encode(prompt)[:512]
-    processed_prompt = t5_tokenizer.decode(tokens, skip_special_tokens=True)
+
     for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
-        prompt=processed_prompt,
-        guidance_scale=guidance_scale,
-        num_inference_steps=num_inference_steps,
-        width=width,
-        height=height,
-        generator=generator,
-        output_type="pil",
-        good_vae=good_vae,
-    ):
-        yield img, seed
+        prompt=prompt,
+        guidance_scale=guidance_scale,
+        num_inference_steps=num_inference_steps,
+        width=width,
+        height=height,
+        generator=generator,
+        output_type="pil",
+        good_vae=good_vae,
+    ):
+        yield img, seed
 
 def enhance_prompt(user_prompt):
     try:
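A note on the idiom in this file: both the removed overrides (e.g. pipe.encode_prompt = custom_encode_prompt.__get__(pipe)) and the surviving live-preview hook (flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)) use the descriptor protocol to bind a free function to one object as a method, leaving the class untouched. A minimal, self-contained sketch of that pattern, with illustrative names that are not from app.py:

# Illustrative sketch (not from app.py): instance-level method override.
class Greeter:
    def hello(self):
        return "hello"

def loud_hello(self):
    return "HELLO from " + type(self).__name__

g = Greeter()
g.hello = loud_hello.__get__(g)  # bind the function to this one instance
print(g.hello())          # HELLO from Greeter
print(Greeter().hello())  # hello  (other instances keep the class method)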
 
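For reference, the code path this commit reverts to is the stock diffusers FLUX pipeline. The sketch below shows an equivalent one-shot generation under stated assumptions: a diffusers release with FLUX support, an authenticated session for the gated black-forest-labs/FLUX.1-dev weights, and a CUDA device; the prompt and output filename are placeholders. It swaps the Space's live-preview helper for the pipeline's standard call, so it yields one final image rather than intermediate previews:

# Minimal sketch, not an excerpt from app.py.
import torch
from diffusers import AutoencoderTiny, DiffusionPipeline

dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"

# Tiny VAE ("taef1") for cheap decoding, as in the Space; loading
# FLUX.1-dev requires a prior huggingface_hub login for the gated repo.
taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=dtype, vae=taef1
).to(device)

image = pipe(
    prompt="a watercolor fox in a snowy forest",  # placeholder prompt
    guidance_scale=3.5,
    num_inference_steps=28,
    width=1024,
    height=1024,
    generator=torch.Generator().manual_seed(42),
).images[0]
image.save("flux_out.png")  # placeholder output path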