concauu committed on
Commit
be5fc1c
·
verified ·
1 Parent(s): dbc3b23

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -13
app.py CHANGED
@@ -187,24 +187,14 @@ pipe.encode_prompt = custom_encode_prompt.__get__(pipe)
187
 
188
  pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
189
 
 
190
 
191
  def patched_time_embed(self, timestep, guidance, pooled_projections):
192
  # Compute the timestep embedding (expected shape: (B,256))
193
  time_out = self.time_proj(timestep)
194
-
195
- # On first call, attach a new linear layer to project pooled_projections from 3072 to 256,
196
- # matching the dimension of time_out.
197
- if not hasattr(self, "fixed_text_proj"):
198
- self.fixed_text_proj = nn.Linear(3072, 256).to(
199
- device=pooled_projections.device, dtype=pooled_projections.dtype
200
- )
201
-
202
- # Apply the new projection to pooled text embeddings.
203
- text_out = self.fixed_text_proj(pooled_projections) # now shape (B,256)
204
-
205
  return time_out + text_out
206
-
207
-
208
  # Apply the patch after the pipeline is created and patched with your custom encode methods:
209
  pipe.transformer.time_text_embed.forward = patched_time_embed.__get__(pipe.transformer.time_text_embed)
210
 
 
187
 
188
  pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
189
 
190
+ pipe.transformer.time_text_embed.fixed_text_proj = nn.Linear(3072, 256).to(device, dtype=dtype)
191
 
192
  def patched_time_embed(self, timestep, guidance, pooled_projections):
193
  # Compute the timestep embedding (expected shape: (B,256))
194
  time_out = self.time_proj(timestep)
195
+ # Use our fixed text projection (now with output dimension 256) on the pooled text embeddings.
196
+ text_out = self.fixed_text_proj(pooled_projections)
 
 
 
 
 
 
 
 
 
197
  return time_out + text_out
 
 
198
  # Apply the patch after the pipeline is created and patched with your custom encode methods:
199
  pipe.transformer.time_text_embed.forward = patched_time_embed.__get__(pipe.transformer.time_text_embed)
200