Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -187,24 +187,14 @@ pipe.encode_prompt = custom_encode_prompt.__get__(pipe)
|
|
| 187 |
|
| 188 |
pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
|
| 189 |
|
|
|
|
| 190 |
|
| 191 |
def patched_time_embed(self, timestep, guidance, pooled_projections):
|
| 192 |
# Compute the timestep embedding (expected shape: (B,256))
|
| 193 |
time_out = self.time_proj(timestep)
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
# matching the dimension of time_out.
|
| 197 |
-
if not hasattr(self, "fixed_text_proj"):
|
| 198 |
-
self.fixed_text_proj = nn.Linear(3072, 256).to(
|
| 199 |
-
device=pooled_projections.device, dtype=pooled_projections.dtype
|
| 200 |
-
)
|
| 201 |
-
|
| 202 |
-
# Apply the new projection to pooled text embeddings.
|
| 203 |
-
text_out = self.fixed_text_proj(pooled_projections) # now shape (B,256)
|
| 204 |
-
|
| 205 |
return time_out + text_out
|
| 206 |
-
|
| 207 |
-
|
| 208 |
# Apply the patch after the pipeline is created and patched with your custom encode methods:
|
| 209 |
pipe.transformer.time_text_embed.forward = patched_time_embed.__get__(pipe.transformer.time_text_embed)
|
| 210 |
|
|
|
|
| 187 |
|
| 188 |
pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
|
| 189 |
|
| 190 |
+
pipe.transformer.time_text_embed.fixed_text_proj = nn.Linear(3072, 256).to(device, dtype=dtype)
|
| 191 |
|
| 192 |
def patched_time_embed(self, timestep, guidance, pooled_projections):
|
| 193 |
# Compute the timestep embedding (expected shape: (B,256))
|
| 194 |
time_out = self.time_proj(timestep)
|
| 195 |
+
# Use our fixed text projection (now with output dimension 256) on the pooled text embeddings.
|
| 196 |
+
text_out = self.fixed_text_proj(pooled_projections)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 197 |
return time_out + text_out
|
|
|
|
|
|
|
| 198 |
# Apply the patch after the pipeline is created and patched with your custom encode methods:
|
| 199 |
pipe.transformer.time_text_embed.forward = patched_time_embed.__get__(pipe.transformer.time_text_embed)
|
| 200 |
|