Enable Flash Attention by trimming prompt embedding padding

#8
by gkalstn0 - opened
Motif Technologies org

Summary

  • Trim prompt_embeds to the actual token length (removing padding) for batch_size=1 inference
  • Pass attention_mask=None to the transformer so PyTorch SDPA can select its Flash Attention backend, which does not accept an explicit attention mask
  • Positive and negative prompts are trimmed independently (the guider runs them in separate iterations)
  • For batch_size>1, the original attention_mask path is kept, since prompts of different lengths still need padding and masking
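Trimming is safe because padding keys that are masked out get zero softmax weight. Removing those keys and dropping the mask should therefore give the same attention output for the real tokens. The sketch below checks this numerically with a small numpy reimplementation of scaled dot-product attention. It is not PyTorch's SDPA or the Motif transformer. It assumes the mask only excludes right-padded text keys. The sizes are illustrative (117 real tokens out of 512 mirrors the positive prompt in the test plan).

```python
import numpy as np

def sdpa(q, k, v, key_mask=None):
    """Single-head scaled dot-product attention (illustrative, not torch SDPA).

    q: (Lq, d); k, v: (Lk, d); key_mask: optional bool (Lk,), True = real token.
    """
    scores = q @ k.T / np.sqrt(q.shape[-1])  # (Lq, Lk)
    if key_mask is not None:
        # Padding keys get -inf, so softmax assigns them exactly zero weight.
        scores = np.where(key_mask[None, :], scores, -np.inf)
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v

rng = np.random.default_rng(0)
d, n_video_tokens, max_len, actual_seq_len = 16, 8, 512, 117

q = rng.standard_normal((n_video_tokens, d))  # video tokens as queries
k = rng.standard_normal((max_len, d))         # padded text keys
v = rng.standard_normal((max_len, d))         # padded text values
mask = np.arange(max_len) < actual_seq_len    # True only for real tokens

masked_out = sdpa(q, k, v, key_mask=mask)                       # original path
trimmed_out = sdpa(q, k[:actual_seq_len], v[:actual_seq_len])   # trimmed, no mask

print(np.allclose(masked_out, trimmed_out))  # True
```

The unmasked call is the case where a fused kernel such as Flash Attention can be used. The masked call shows what the padded path computes.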

Changes

  • pipeline_motif_video.py: encode_prompt() computes actual_seq_len; __call__() trims the embeddings and drops the mask
  • No transformer code changes needed (its existing None guard already handles a missing mask)
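Below is a minimal sketch of the trim-or-keep decision, assuming right padding (real tokens first) and a 0/1 attention mask of shape (batch, max_len). The helper names maybe_trim and make_prompt, and the hidden size D, are hypothetical and not the pipeline's actual code. Positive and negative prompts are passed through separately, matching the independent trimming described in the summary.

```python
import numpy as np

D = 16  # hypothetical hidden size, kept small for the example

def maybe_trim(prompt_embeds, attention_mask):
    """Return (embeds, mask) to feed the transformer (hypothetical helper).

    batch_size == 1: slice off right padding and return mask=None.
    batch_size > 1: prompts may differ in length, so keep padding and mask.
    """
    if prompt_embeds.shape[0] != 1:
        return prompt_embeds, attention_mask
    actual_seq_len = int(attention_mask[0].sum())  # count of real tokens
    return prompt_embeds[:, :actual_seq_len], None

def make_prompt(n_tokens, max_len=512):
    """Build dummy right-padded embeddings and mask for one prompt."""
    embeds = np.zeros((1, max_len, D), dtype=np.float32)
    mask = (np.arange(max_len) < n_tokens).astype(np.int64)[None, :]
    return embeds, mask

# batch_size=1: positive and negative prompts trimmed independently.
pos_embeds, pos_mask = maybe_trim(*make_prompt(117))
neg_embeds, neg_mask = maybe_trim(*make_prompt(113))
print(pos_embeds.shape, pos_mask)  # (1, 117, 16) None
print(neg_embeds.shape, neg_mask)  # (1, 113, 16) None

# batch_size>1: mask path preserved, no trimming.
batch_embeds = np.zeros((2, 512, D), dtype=np.float32)
batch_mask = np.concatenate([make_prompt(117)[1], make_prompt(113)[1]])
out_embeds, out_mask = maybe_trim(batch_embeds, batch_mask)
print(out_embeds.shape, out_mask.shape)  # (2, 512, 16) (2, 512)
```

With batch_size>1, each row could in principle be trimmed to the longest prompt. Even then, the shorter rows would still contain padding, so the mask would still be needed.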

Test plan

  • batch_size=1 with CFG: trimming confirmed (positive 512->117 tokens, negative 512->113 tokens)
  • batch_size>1: mask path preserved (no trimming)
  • Video output quality verified (720p, 121 frames, 50 steps)
  • I2V compatibility: the transformer handles encoder_attention_mask=None safely
gkalstn0 changed pull request status to open
gkalstn0 changed pull request status to merged
