Update pipeline.py

#3
by linoyts HF Staff - opened
Files changed (1) hide show
  1. pipeline.py +4 -2
pipeline.py CHANGED
@@ -1233,8 +1233,10 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
1233
 
1234
  # If we have concat conditioning, extend video_coords with concat_positions
1235
  if concat_positions is not None:
1236
- # video_coords is [B, 3, base_num_tokens]
1237
- # concat_positions is [B, 3, concat_num_tokens]
 
 
1238
  video_coords = torch.cat([video_coords, concat_positions], dim=2)
1239
 
1240
  audio_coords = self.transformer.audio_rope.prepare_audio_coords(
 
1233
 
1234
  # If we have concat conditioning, extend video_coords with concat_positions
1235
  if concat_positions is not None:
1236
+ # video_coords is [B, 3, base_num_tokens, 2]
1237
+ # concat_positions is [B, 3, concat_num_tokens], so it must be expanded to 4D before it can be concatenated
1238
+ # Add a trailing dimension and expand it to the size of video_coords' last dimension
1239
+ concat_positions = concat_positions.unsqueeze(-1).expand(-1, -1, -1, video_coords.shape[-1])
1240
  video_coords = torch.cat([video_coords, concat_positions], dim=2)
1241
 
1242
  audio_coords = self.transformer.audio_rope.prepare_audio_coords(