# Source: sageattn-xpu-profiling / bench_optimizations.py
# Author: Yi30
# Commit 8f3b910: "Add FlexAttention profiling traces, benchmark script, and updated analysis"
"""
CogVideoX-2B Optimization Benchmark — Intel XPU
Compares 4 configurations on the same pipeline:
1. baseline — stock CogVideoXAttnProcessor2_0 (SDPA)
2. flex_attention — FlexAttention processor (compiled fused kernel)
3. torch_compile — torch.compile(transformer) with SDPA
4. flex_attn + compile — both combined
Each config: 1 warmup run (JIT) + 3 timed runs → mean ± std, peak memory, speedup.
Usage:
cd tasks/profile-cogvideo
uv run python bench_optimizations.py
"""
import gc
import os
import time
from dataclasses import dataclass, field

import torch
import torch.nn.functional as F
from torch.nn.attention.flex_attention import flex_attention

from diffusers import CogVideoXPipeline
from diffusers.models.attention import Attention
from diffusers.models.attention_processor import CogVideoXAttnProcessor2_0
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Benchmark knobs shared by every configuration.
NUM_INFERENCE_STEPS = 4  # denoising steps per pipeline call
NUM_FRAMES = 49
GUIDANCE_SCALE = 6
DTYPE = torch.float16
SEED = 42  # re-applied on every run so each run sees identical inputs
# Local checkpoint directory. Set COGVIDEOX_MODEL_PATH to point at another
# copy of CogVideoX-2b; the default is the original author's location.
MODEL_PATH = os.environ.get(
    "COGVIDEOX_MODEL_PATH", "/mnt/data4/yiliu/zai-org/CogVideoX-2b"
)
WARMUP_RUNS = 1  # untimed runs per config (absorb JIT / compile cost)
TIMED_RUNS = 3  # timed runs per config
PROMPT = (
    "A panda, dressed in a small, red jacket and a tiny hat, "
    "sits on a wooden stool in a serene bamboo forest. "
    "The panda gently sways as a breeze rustles through the tall bamboo stalks."
)
# ---------------------------------------------------------------------------
# FlexAttention Processor — drop-in replacement via diffusers plugin API
# ---------------------------------------------------------------------------
# Wrap flex_attention with torch.compile once, at import time. Compilation is
# deferred: the fused kernel is generated on the first call (Triton codegen),
# and this single compiled callable is shared by every
# CogVideoXFlexAttnProcessor invocation.
_compiled_flex_attention = torch.compile(flex_attention)
class CogVideoXFlexAttnProcessor:
    """CogVideoX attention processor that computes attention with FlexAttention.

    Same as CogVideoXAttnProcessor2_0 but replaces F.scaled_dot_product_attention
    with the module-level torch.compile(flex_attention) callable, for fused
    attention on XPU.

    Only the unmasked path is implemented. flex_attention is called without
    score_mod / block_mask, so a non-None ``attention_mask`` is rejected
    rather than silently ignored.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        image_rotary_emb: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run joint attention over [encoder tokens | hidden tokens].

        Args:
            attn: diffusers ``Attention`` module. It supplies the q/k/v/out
                projections, the optional q/k norms and the head count.
            hidden_states: Tokens placed after the encoder (text) tokens.
                RoPE, when given, is applied only to these positions.
            encoder_hidden_states: Text tokens. Their length along dim 1 is
                where the output is split back apart.
            attention_mask: Must be ``None``; masking is not supported.
            image_rotary_emb: Optional rotary embedding, forwarded unchanged
                to diffusers' ``apply_rotary_emb``.
                NOTE(review): its exact type (tensor vs. (cos, sin) pair) is
                decided by diffusers; it is not checked here.

        Returns:
            ``(hidden_states, encoder_hidden_states)`` after the output
            projection, in that order.

        Raises:
            NotImplementedError: if ``attention_mask`` is not ``None``.
        """
        # flex_attention below gets no mask, so a mask would be dropped
        # silently and change the result. Fail loudly instead.
        if attention_mask is not None:
            raise NotImplementedError(
                "CogVideoXFlexAttnProcessor does not support attention_mask; "
                "use CogVideoXAttnProcessor2_0 for masked attention."
            )

        text_seq_length = encoder_hidden_states.size(1)
        # Joint sequence: text tokens first, then the remaining tokens.
        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
        # Unpacking three names also asserts hidden_states is rank 3.
        batch_size, _, _ = hidden_states.shape

        # Q / K / V projections
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply RoPE if needed (only to positions after the text prefix).
        if image_rotary_emb is not None:
            from diffusers.models.embeddings import apply_rotary_emb

            query[:, :, text_seq_length:] = apply_rotary_emb(
                query[:, :, text_seq_length:], image_rotary_emb
            )
            if not attn.is_cross_attention:
                key[:, :, text_seq_length:] = apply_rotary_emb(
                    key[:, :, text_seq_length:], image_rotary_emb
                )

        # >>> FlexAttention replaces F.scaled_dot_product_attention <<<
        # No score_mod / block_mask / scale: plain softmax attention using
        # flex_attention's default 1/sqrt(head_dim) scaling.
        hidden_states = _compiled_flex_attention(query, key, value)
        # (batch, heads, seq, head_dim) -> (batch, seq, heads * head_dim)
        hidden_states = hidden_states.transpose(1, 2).reshape(
            batch_size, -1, attn.heads * head_dim
        )

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        # Split the joint sequence back into its text prefix and the rest.
        encoder_hidden_states, hidden_states = hidden_states.split(
            [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
        )
        return hidden_states, encoder_hidden_states
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def mb(b: int) -> float:
    """Convert a byte count to megabytes, where 1 MB = 2**20 bytes."""
    bytes_per_mb = 1 << 20
    return b / bytes_per_mb
def run_inference(pipe):
    """Run one full pipeline call and measure it.

    Args:
        pipe: The loaded pipeline; it is expected to already live on XPU.

    Returns:
        ``(elapsed_s, peak_mem_mb)``: wall-clock seconds for the call, and
        the peak XPU memory allocated during the call, in MB.
    """
    # Clear garbage left by the previous run so a Python GC pass is less
    # likely to land inside the timed region.
    gc.collect()
    # Build the seeded generator before starting the clock. Every run then
    # sees identical inputs, and generator setup is not timed.
    generator = torch.Generator(device="xpu").manual_seed(SEED)
    # Drain pending device work, then reset the peak counter so the reading
    # below covers only this call.
    torch.xpu.synchronize()
    torch.xpu.reset_peak_memory_stats()
    t0 = time.perf_counter()
    pipe(
        prompt=PROMPT,
        num_videos_per_prompt=1,
        num_inference_steps=NUM_INFERENCE_STEPS,
        num_frames=NUM_FRAMES,
        guidance_scale=GUIDANCE_SCALE,
        generator=generator,
    )
    # Device work is asynchronous; wait for it so the timer includes it.
    torch.xpu.synchronize()
    elapsed = time.perf_counter() - t0
    peak_mem = mb(torch.xpu.max_memory_allocated())
    return elapsed, peak_mem
@dataclass
class BenchResult:
    """Timing and peak-memory samples collected for one configuration.

    Each timed run adds one entry to ``times`` (seconds) and one to
    ``peak_mems`` (MB). The summary properties return ``nan`` when no samples
    were recorded (e.g. ``TIMED_RUNS = 0``) instead of raising.
    """

    name: str  # config label shown in the results table
    times: list[float] = field(default_factory=list)  # seconds per timed run
    peak_mems: list[float] = field(default_factory=list)  # MB per timed run

    @property
    def mean_time(self) -> float:
        """Arithmetic mean of ``times``; ``nan`` if there are none."""
        if not self.times:
            return float("nan")
        return sum(self.times) / len(self.times)

    @property
    def std_time(self) -> float:
        """Population standard deviation of ``times`` (divides by N); ``nan`` if empty."""
        if not self.times:
            return float("nan")
        m = self.mean_time
        return (sum((t - m) ** 2 for t in self.times) / len(self.times)) ** 0.5

    @property
    def max_peak_mem(self) -> float:
        """Largest value in ``peak_mems``; ``nan`` if there are none."""
        return max(self.peak_mems, default=float("nan"))
def bench_config(pipe, name: str) -> BenchResult:
    """Benchmark ``pipe`` in whatever configuration it currently has.

    First runs ``WARMUP_RUNS`` untimed passes and discards their results.
    Then runs ``TIMED_RUNS`` measured passes, printing each one as it ends.

    Args:
        pipe: Pipeline handed to ``run_inference``.
        name: Label used in progress output and stored on the result.

    Returns:
        A ``BenchResult`` with one time and one peak-memory sample per
        timed run.
    """
    print(f"\n [{name}] warmup ({WARMUP_RUNS} run)...")
    for _ in range(WARMUP_RUNS):
        run_inference(pipe)  # result intentionally discarded

    print(f" [{name}] timing ({TIMED_RUNS} runs)...")
    times: list[float] = []
    peak_mems: list[float] = []
    for run_no in range(1, TIMED_RUNS + 1):
        seconds, mem_mb = run_inference(pipe)
        times.append(seconds)
        peak_mems.append(mem_mb)
        print(f" run {run_no}: {seconds:.2f}s peak_mem={mem_mb:.0f} MB")

    return BenchResult(name=name, times=times, peak_mems=peak_mems)
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def _print_section(title: str) -> None:
    """Print a config banner: ``title`` framed by 72-dash rules."""
    print("\n" + "-" * 72)
    print(title)
    print("-" * 72)


def _print_results(results: list[BenchResult]) -> None:
    """Print the summary table. Speedups are relative to ``results[0]``."""
    baseline_mean = results[0].mean_time
    print("\n")
    print("=" * 72)
    print("CogVideoX-2B Optimization Benchmark Results")
    print(f" ({NUM_INFERENCE_STEPS} steps, {NUM_FRAMES} frames, {DTYPE})")
    print("=" * 72)
    header = (
        f" {'Config':<24s} {'Mean (s)':>9s} {'Std (s)':>8s} "
        f"{'Peak Mem (MB)':>14s} {'Speedup':>8s}"
    )
    print(header)
    print(f" {'-'*24} {'-'*9} {'-'*8} {'-'*14} {'-'*8}")
    for r in results:
        speedup = baseline_mean / r.mean_time
        print(
            f" {r.name:<24s} {r.mean_time:>9.2f} {r.std_time:>8.2f} "
            f"{r.max_peak_mem:>14.0f} {speedup:>7.2f}x"
        )
    print()
    print("Notes:")
    print(" - Speedup is relative to the baseline (SDPA, uncompiled)")
    print(" - Peak Mem is the max across timed runs (excludes warmup)")
    print(" - flex_attention uses torch.compile(flex_attention) for fused kernels")
    print(" - torch_compile uses torch.compile(transformer, backend='inductor')")
    print()


def main():
    """Benchmark four attention/compile configurations on one pipeline.

    The pipeline is loaded once and reused. torch.compile wraps the
    transformer irreversibly, so the two uncompiled configs must run first.

    Raises:
        SystemExit: if no XPU device is available.
    """
    # Fail early with a clear message. Otherwise the first torch.xpu call
    # (get_device_name below) fails with a less obvious error.
    if not torch.xpu.is_available():
        raise SystemExit(
            "No Intel XPU device available; this benchmark requires torch.xpu."
        )

    print("=" * 72)
    print("CogVideoX-2B Optimization Benchmark")
    print(f" {NUM_INFERENCE_STEPS} steps, {NUM_FRAMES} frames, {DTYPE}")
    print(f" warmup={WARMUP_RUNS}, timed={TIMED_RUNS}")
    print(f" Device: {torch.xpu.get_device_name(0)}")
    print(f" PyTorch: {torch.__version__}")
    print("=" * 72)

    # Load pipeline once
    print("\nLoading pipeline...")
    pipe = CogVideoXPipeline.from_pretrained(MODEL_PATH, torch_dtype=DTYPE)
    pipe = pipe.to("xpu")
    pipe.vae.enable_slicing()
    pipe.vae.enable_tiling()

    results: list[BenchResult] = []

    # 1. Baseline: SDPA (uncompiled)
    _print_section("Config 1/4: baseline (SDPA, uncompiled)")
    results.append(bench_config(pipe, "baseline"))

    # 2. FlexAttention only (uncompiled transformer)
    _print_section("Config 2/4: flex_attention")
    pipe.transformer.set_attn_processor(CogVideoXFlexAttnProcessor())
    results.append(bench_config(pipe, "flex_attention"))
    # Restore the baseline processor before the compile step.
    pipe.transformer.set_attn_processor(CogVideoXAttnProcessor2_0())

    # 3. torch.compile only: SDPA + compiled transformer.
    # NOTE: compile is irreversible, so this must come after configs 1-2.
    _print_section("Config 3/4: torch_compile (SDPA + compiled transformer)")
    pipe.transformer = torch.compile(pipe.transformer, backend="inductor")
    results.append(bench_config(pipe, "torch_compile"))

    # 4. FlexAttention + compile: swap the processor on the compiled transformer.
    # set_attn_processor reaches the underlying module inside the compiled wrapper.
    _print_section("Config 4/4: flex_attention + torch_compile")
    pipe.transformer.set_attn_processor(CogVideoXFlexAttnProcessor())
    results.append(bench_config(pipe, "flex_attn + compile"))

    _print_results(results)


if __name__ == "__main__":
    main()