Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -192,8 +192,15 @@ pipe.transformer.time_text_embed.fixed_text_proj = nn.Linear(3072, 256).to(devic
|
|
| 192 |
def patched_time_embed(self, timestep, guidance, pooled_projections):
|
| 193 |
# Compute the timestep embedding (expected shape: (B,256))
|
| 194 |
time_out = self.time_proj(timestep)
|
| 195 |
-
|
| 196 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 197 |
return time_out + text_out
|
| 198 |
# Apply the patch after the pipeline is created and patched with your custom encode methods:
|
| 199 |
pipe.transformer.time_text_embed.forward = patched_time_embed.__get__(pipe.transformer.time_text_embed)
|
|
|
|
| 192 |
def patched_time_embed(self, timestep, guidance, pooled_projections):
|
| 193 |
# Compute the timestep embedding (expected shape: (B,256))
|
| 194 |
time_out = self.time_proj(timestep)
|
| 195 |
+
|
| 196 |
+
# Ensure fixed_text_proj is set to map from 3072 to 256.
|
| 197 |
+
# If it doesn't exist or its output dimension is not 256, recreate it.
|
| 198 |
+
if (not hasattr(self, "fixed_text_proj")) or (self.fixed_text_proj.out_features != 256):
|
| 199 |
+
self.fixed_text_proj = nn.Linear(3072, 256).to(
|
| 200 |
+
device=pooled_projections.device, dtype=pooled_projections.dtype
|
| 201 |
+
)
|
| 202 |
+
|
| 203 |
+
text_out = self.fixed_text_proj(pooled_projections) # Should produce shape (B,256)
|
| 204 |
return time_out + text_out
|
| 205 |
# Apply the patch after the pipeline is created and patched with your custom encode methods:
|
| 206 |
pipe.transformer.time_text_embed.forward = patched_time_embed.__get__(pipe.transformer.time_text_embed)
|