Enable Flash Attention by trimming prompt embedding padding

#8
Files changed (1) hide show
  1. pipeline_motif_video.py +25 -6
pipeline_motif_video.py CHANGED
@@ -539,6 +539,11 @@ class MotifVideoPipeline(DiffusionPipeline):
539
  **prompt_embeds_kwargs,
540
  )
541
 
 
 
 
 
 
542
  # duplicate text embeddings for each generation per prompt, using mps friendly method
543
  seq_len = prompt_embeds.shape[1]
544
  prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
@@ -547,15 +552,16 @@ class MotifVideoPipeline(DiffusionPipeline):
547
  if pooled_prompt_embeds is not None:
548
  pooled_prompt_embeds = pooled_prompt_embeds.repeat_interleave(num_videos_per_prompt, dim=0)
549
 
550
- # Keep attention mask handling
551
- prompt_attention_mask = prompt_attention_mask.bool()
552
- prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
553
- prompt_attention_mask = prompt_attention_mask.repeat_interleave(num_videos_per_prompt, dim=0)
554
 
555
  return (
556
  prompt_embeds,
557
  pooled_prompt_embeds,
558
  prompt_attention_mask,
 
559
  )
560
 
561
  @property
@@ -1081,7 +1087,7 @@ class MotifVideoPipeline(DiffusionPipeline):
1081
  device = self._execution_device
1082
 
1083
  # 3. Prepare text embeddings
1084
- prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = self.encode_prompt(
1085
  prompt=prompt,
1086
  num_videos_per_prompt=num_videos_per_prompt,
1087
  prompt_embeds=prompt_embeds,
@@ -1091,12 +1097,17 @@ class MotifVideoPipeline(DiffusionPipeline):
1091
  device=device,
1092
  )
1093
 
 
 
 
 
1094
  if self.guider._enabled:
1095
  negative_prompt = self._prepare_negative_prompt(negative_prompt, batch_size)
1096
  (
1097
  negative_prompt_embeds,
1098
  negative_pooled_prompt_embeds,
1099
  negative_prompt_attention_mask,
 
1100
  ) = self.encode_prompt(
1101
  prompt=negative_prompt,
1102
  num_videos_per_prompt=num_videos_per_prompt,
@@ -1107,6 +1118,14 @@ class MotifVideoPipeline(DiffusionPipeline):
1107
  device=device,
1108
  )
1109
 
 
 
 
 
 
 
 
 
1110
  num_channels_latents = self.vae.config.z_dim
1111
  latents = self.prepare_latents(
1112
  batch_size * num_videos_per_prompt,
@@ -1229,7 +1248,7 @@ class MotifVideoPipeline(DiffusionPipeline):
1229
  guider_inputs = {
1230
  "encoder_hidden_states": (prompt_embeds, negative_prompt_embeds),
1231
  }
1232
- if use_attention_mask:
1233
  guider_inputs["encoder_attention_mask"] = (
1234
  prompt_attention_mask,
1235
  negative_prompt_attention_mask,
 
539
  **prompt_embeds_kwargs,
540
  )
541
 
542
+ # Compute actual (non-padding) token count for batch=1 Flash Attention trimming in __call__
543
+ actual_seq_len = None
544
+ if batch_size == 1 and prompt_attention_mask is not None:
545
+ actual_seq_len = int(prompt_attention_mask.sum(dim=-1).max().item())
546
+
547
  # duplicate text embeddings for each generation per prompt, using mps friendly method
548
  seq_len = prompt_embeds.shape[1]
549
  prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
 
552
  if pooled_prompt_embeds is not None:
553
  pooled_prompt_embeds = pooled_prompt_embeds.repeat_interleave(num_videos_per_prompt, dim=0)
554
 
555
+ if prompt_attention_mask is not None:
556
+ prompt_attention_mask = prompt_attention_mask.bool()
557
+ prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
558
+ prompt_attention_mask = prompt_attention_mask.repeat_interleave(num_videos_per_prompt, dim=0)
559
 
560
  return (
561
  prompt_embeds,
562
  pooled_prompt_embeds,
563
  prompt_attention_mask,
564
+ actual_seq_len,
565
  )
566
 
567
  @property
 
1087
  device = self._execution_device
1088
 
1089
  # 3. Prepare text embeddings
1090
+ prompt_embeds, pooled_prompt_embeds, prompt_attention_mask, pos_actual_len = self.encode_prompt(
1091
  prompt=prompt,
1092
  num_videos_per_prompt=num_videos_per_prompt,
1093
  prompt_embeds=prompt_embeds,
 
1097
  device=device,
1098
  )
1099
 
1100
+ if not self.guider._enabled and pos_actual_len is not None:
1101
+ prompt_embeds = prompt_embeds[:, :pos_actual_len, :]
1102
+ prompt_attention_mask = None
1103
+
1104
  if self.guider._enabled:
1105
  negative_prompt = self._prepare_negative_prompt(negative_prompt, batch_size)
1106
  (
1107
  negative_prompt_embeds,
1108
  negative_pooled_prompt_embeds,
1109
  negative_prompt_attention_mask,
1110
+ neg_actual_len,
1111
  ) = self.encode_prompt(
1112
  prompt=negative_prompt,
1113
  num_videos_per_prompt=num_videos_per_prompt,
 
1118
  device=device,
1119
  )
1120
 
1121
+ # Trim each to its own actual length — guider runs pos/neg in separate loop iterations,
1122
+ # so different seq lengths are fine. No padding embeddings attend without mask.
1123
+ if pos_actual_len is not None and neg_actual_len is not None:
1124
+ prompt_embeds = prompt_embeds[:, :pos_actual_len, :]
1125
+ negative_prompt_embeds = negative_prompt_embeds[:, :neg_actual_len, :]
1126
+ prompt_attention_mask = None
1127
+ negative_prompt_attention_mask = None
1128
+
1129
  num_channels_latents = self.vae.config.z_dim
1130
  latents = self.prepare_latents(
1131
  batch_size * num_videos_per_prompt,
 
1248
  guider_inputs = {
1249
  "encoder_hidden_states": (prompt_embeds, negative_prompt_embeds),
1250
  }
1251
+ if use_attention_mask and prompt_attention_mask is not None:
1252
  guider_inputs["encoder_attention_mask"] = (
1253
  prompt_attention_mask,
1254
  negative_prompt_attention_mask,