concauu committed on
Commit
2b6f6e8
·
verified ·
1 Parent(s): be5fc1c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -2
app.py CHANGED
@@ -192,8 +192,15 @@ pipe.transformer.time_text_embed.fixed_text_proj = nn.Linear(3072, 256).to(devic
192
def patched_time_embed(self, timestep, guidance, pooled_projections):
    """Replacement forward for the time/text embed module.

    Sums the timestep embedding with a fixed linear projection of the
    pooled text embeddings; both terms are expected to be (B, 256).
    `guidance` is accepted for signature compatibility but not used here.
    """
    # Timestep embedding (expected shape: (B, 256)).
    timestep_embedding = self.time_proj(timestep)
    # Fixed text projection (output dimension 256) applied to the pooled text embeddings.
    projected_text = self.fixed_text_proj(pooled_projections)
    return timestep_embedding + projected_text
# Apply the patch after the pipeline is created and patched with your custom
# encode methods: bind patched_time_embed as a method of the embed module so
# it replaces the module's original forward.
_embed_module = pipe.transformer.time_text_embed
_embed_module.forward = patched_time_embed.__get__(_embed_module)
 
192
def patched_time_embed(self, timestep, guidance, pooled_projections):
    """Patched forward for the time/text embed module.

    Returns the sum of the timestep embedding and a linear projection of the
    pooled text embeddings; both terms are expected to have shape (B, 256).
    `guidance` is accepted for signature compatibility but unused here.
    """
    # Compute the timestep embedding (expected shape: (B, 256)).
    time_out = self.time_proj(timestep)

    # Lazily ensure fixed_text_proj maps 3072 -> 256.  Recreate it only when
    # it is missing or its dimensions are wrong (the original guard checked
    # only out_features, so a wrong in_features would crash the matmul).
    # When a correctly-shaped layer already exists, move it to the input's
    # device/dtype instead of letting a mismatch crash the forward pass.
    # NOTE(review): a recreated layer has fresh random (untrained) weights —
    # confirm this is the intended behavior.
    proj = getattr(self, "fixed_text_proj", None)
    if proj is None or proj.out_features != 256 or proj.in_features != 3072:
        self.fixed_text_proj = nn.Linear(3072, 256).to(
            device=pooled_projections.device, dtype=pooled_projections.dtype
        )
    else:
        # Module.to returns self and is a no-op when already matching.
        self.fixed_text_proj = proj.to(
            device=pooled_projections.device, dtype=pooled_projections.dtype
        )

    text_out = self.fixed_text_proj(pooled_projections)  # (B, 256)
    return time_out + text_out
205
# Apply the patch after the pipeline is created and patched with your custom
# encode methods.  __get__ binds the plain function to the embed module
# instance so it behaves as that module's forward method.
pipe.transformer.time_text_embed.forward = patched_time_embed.__get__(
    pipe.transformer.time_text_embed
)