gkalstn0 Claude Opus 4.6 (1M context) committed on
Commit
eabf2fa
·
1 Parent(s): d1a46b1

feat: move prompt_embeds trim to __call__ for correct CFG alignment

Browse files

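encode_prompt no longer trims prompt_embeds itself; it now also returns
actual_seq_len, the non-padding token count (None unless batch_size == 1 and
an attention mask is available). __call__ performs the trim instead: without
CFG it trims the positive embeds to their own length; with CFG it trims both
positive and negative embeds to max(pos_len, neg_len), using real encoder
outputs (no zero-padding). After trimming, the attention masks are set to
None, which enables Flash Attention.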

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

Files changed (1) hide show
  1. pipeline_motif_video.py +20 -5
pipeline_motif_video.py CHANGED
@@ -539,11 +539,10 @@ class MotifVideoPipeline(DiffusionPipeline):
539
  **prompt_embeds_kwargs,
540
  )
541
 
542
- # Trim padding for batch=1 to enable Flash Attention (attn_mask=None SDPA uses Flash backend)
 
543
  if batch_size == 1 and prompt_attention_mask is not None:
544
- actual_len = prompt_attention_mask.sum(dim=-1).max().item()
545
- prompt_embeds = prompt_embeds[:, :actual_len, :]
546
- prompt_attention_mask = None
547
 
548
  # duplicate text embeddings for each generation per prompt, using mps friendly method
549
  seq_len = prompt_embeds.shape[1]
@@ -562,6 +561,7 @@ class MotifVideoPipeline(DiffusionPipeline):
562
  prompt_embeds,
563
  pooled_prompt_embeds,
564
  prompt_attention_mask,
 
565
  )
566
 
567
  @property
@@ -1087,7 +1087,7 @@ class MotifVideoPipeline(DiffusionPipeline):
1087
  device = self._execution_device
1088
 
1089
  # 3. Prepare text embeddings
1090
- prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = self.encode_prompt(
1091
  prompt=prompt,
1092
  num_videos_per_prompt=num_videos_per_prompt,
1093
  prompt_embeds=prompt_embeds,
@@ -1097,12 +1097,18 @@ class MotifVideoPipeline(DiffusionPipeline):
1097
  device=device,
1098
  )
1099
 
 
 
 
 
 
1100
  if self.guider._enabled:
1101
  negative_prompt = self._prepare_negative_prompt(negative_prompt, batch_size)
1102
  (
1103
  negative_prompt_embeds,
1104
  negative_pooled_prompt_embeds,
1105
  negative_prompt_attention_mask,
 
1106
  ) = self.encode_prompt(
1107
  prompt=negative_prompt,
1108
  num_videos_per_prompt=num_videos_per_prompt,
@@ -1113,6 +1119,15 @@ class MotifVideoPipeline(DiffusionPipeline):
1113
  device=device,
1114
  )
1115
 
 
 
 
 
 
 
 
 
 
1116
  num_channels_latents = self.vae.config.z_dim
1117
  latents = self.prepare_latents(
1118
  batch_size * num_videos_per_prompt,
 
539
  **prompt_embeds_kwargs,
540
  )
541
 
542
+ # Compute actual (non-padding) token count for batch=1 Flash Attention trimming in __call__
543
+ actual_seq_len = None
544
  if batch_size == 1 and prompt_attention_mask is not None:
545
+ actual_seq_len = int(prompt_attention_mask.sum(dim=-1).max().item())
 
 
546
 
547
  # duplicate text embeddings for each generation per prompt, using mps friendly method
548
  seq_len = prompt_embeds.shape[1]
 
561
  prompt_embeds,
562
  pooled_prompt_embeds,
563
  prompt_attention_mask,
564
+ actual_seq_len,
565
  )
566
 
567
  @property
 
1087
  device = self._execution_device
1088
 
1089
  # 3. Prepare text embeddings
1090
+ prompt_embeds, pooled_prompt_embeds, prompt_attention_mask, pos_actual_len = self.encode_prompt(
1091
  prompt=prompt,
1092
  num_videos_per_prompt=num_videos_per_prompt,
1093
  prompt_embeds=prompt_embeds,
 
1097
  device=device,
1098
  )
1099
 
1100
+ if not self.guider._enabled and pos_actual_len is not None:
1101
+ # No CFG: trim positive only
1102
+ prompt_embeds = prompt_embeds[:, :pos_actual_len, :]
1103
+ prompt_attention_mask = None
1104
+
1105
  if self.guider._enabled:
1106
  negative_prompt = self._prepare_negative_prompt(negative_prompt, batch_size)
1107
  (
1108
  negative_prompt_embeds,
1109
  negative_pooled_prompt_embeds,
1110
  negative_prompt_attention_mask,
1111
+ neg_actual_len,
1112
  ) = self.encode_prompt(
1113
  prompt=negative_prompt,
1114
  num_videos_per_prompt=num_videos_per_prompt,
 
1119
  device=device,
1120
  )
1121
 
1122
+ # Trim prompt_embeds for batch=1 to enable Flash Attention (attn_mask=None → SDPA uses Flash backend).
1123
+ # Use max(pos, neg) actual_len so both have real encoder embeddings at every position (no zero-padding).
1124
+ if pos_actual_len is not None and neg_actual_len is not None:
1125
+ trim_len = max(pos_actual_len, neg_actual_len)
1126
+ prompt_embeds = prompt_embeds[:, :trim_len, :]
1127
+ negative_prompt_embeds = negative_prompt_embeds[:, :trim_len, :]
1128
+ prompt_attention_mask = None
1129
+ negative_prompt_attention_mask = None
1130
+
1131
  num_channels_latents = self.vae.config.z_dim
1132
  latents = self.prepare_latents(
1133
  batch_size * num_videos_per_prompt,