Buckets:

Mercity/FluxDistill / scripts /09_build_linattn.py
Pranav2748's picture
download
raw
3.66 kB
"""Build a student with the full upgraded linear-attention surrogate
(RoPE + depthwise-conv + focused/learnable feature map + FFN = a real linear-attn block),
dropping DROP_K blocks chosen by importance, warm-started to mimic the teacher blocks.
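Saves to outputs/student_linattn4/. Usage: python3 scripts/09_build_linattn.py [drop_k] [n_ffn]
(defaults: drop_k=10, n_ffn=4; the FFN is attached only to the n_ffn deepest dropped blocks).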
"""
import json
import os
import sys
import torch
from flux2distill.model_utils import load_pipeline, param_summary
from flux2distill.surgery import (importance_by_ablation, select_blocks_by_importance,
capture_block_io_seq, attach_surrogates)
from flux2distill.calibration import fit_surrogate_to_block
# Optional positional CLI overrides:
#   argv[1] -> number of blocks to drop (DROP_K)
#   argv[2] -> number of dropped blocks that also get an FFN (N_FFN)
DROP_K = 10 if len(sys.argv) < 2 else int(sys.argv[1])
N_FFN = 4 if len(sys.argv) < 3 else int(sys.argv[2])  # FFN on the N_FFN deepest dropped blocks

# Surrogate hyperparameters; these are forwarded to attach_surrogates() further down.
HEADS = 4
HEAD_DIM = 128
CONV_K = 5
FFN_HIDDEN = 1024

# Output directory for the student state dict and the selection metadata.
OUT = "outputs/student_linattn4"
os.makedirs(OUT, exist_ok=True)

# Prompts used both for the importance measurement and for capturing teacher block I/O.
PROMPTS = [
    "a serene mountain lake at sunrise reflecting snow-capped peaks",
    'a vintage bookshop storefront with a wooden sign that reads "THE OPEN PAGE"',
    "a close-up portrait of an elderly fisherman with a weathered face",
    "a futuristic city skyline at night with neon lights",
]
# Load the teacher, score every block, and decide which blocks are kept, which are
# replaced by surrogates, and which of the replaced ones additionally get an FFN.
print("loading teacher pipeline...")
pipe = load_pipeline(device="cuda")

print("measuring per-block importance (leave-one-out ablation)...")
imp = importance_by_ablation(pipe, PROMPTS, num_inference_steps=4, seed=0)
keep_idx, drop_idx = select_blocks_by_importance(imp, DROP_K)

# FFN on the N_FFN deepest dropped blocks (incl. the last).
# A plain `sorted(drop_idx)[-N_FFN:]` is wrong for N_FFN == 0, because `seq[-0:]` is the
# *whole* sequence (every dropped block would get an FFN instead of none), and a negative
# N_FFN would silently slice from the front. Treat any non-positive N_FFN as "no FFN".
ffn_idx = sorted(drop_idx)[-N_FFN:] if N_FFN > 0 else []

print(f"DROP {len(drop_idx)} least-important: {drop_idx}")
print(f"KEEP {len(keep_idx)}: {keep_idx}")
# Report the light-surrogate count from the actual lists rather than DROP_K - N_FFN, which
# is wrong (possibly negative) whenever N_FFN exceeds the number of dropped blocks.
n_light = len(drop_idx) - len(ffn_idx)
print(f"FFN on {len(ffn_idx)} deepest dropped: {ffn_idx} (the other {n_light} are light)")
# Record the teacher's per-block inputs/outputs (and the rotary data returned alongside them)
# for the blocks about to be replaced; this must happen before the surrogates are attached.
print("capturing teacher block I/O sequences (+rotary) for warm-start...")
io, rotary = capture_block_io_seq(
    pipe,
    PROMPTS,
    drop_idx,
    num_inference_steps=4,
    max_seqs=16,
)

# Install linear-attention surrogates at the dropped block indices of the transformer.
attach_surrogates(
    pipe.transformer,
    drop_idx,
    kind="linear_attention",
    heads=HEADS,
    head_dim=HEAD_DIM,
    conv_kernel=CONV_K,
    ffn_hidden=FFN_HIDDEN,
    ffn_idx=ffn_idx,
    device="cuda",
    dtype=torch.bfloat16,
)

# Parameter summary of the modified transformer; it is also stored in selection.json later.
ps = param_summary(pipe.transformer)
print("student params:", ps)
# Warm-start: fit every attached surrogate against the captured I/O of the teacher block
# it replaced.
print("warm-start: fitting each surrogate (attn+conv+focus+FFN) to mimic its teacher block...")
for blk in drop_idx:
    surrogate = pipe.transformer.single_transformer_blocks[blk]
    captured = io[blk]
    # The two returned values are reported below as the residual relative error
    # before -> after fitting (presumably initial/final — confirm in flux2distill.calibration).
    err_before, err_after = fit_surrogate_to_block(
        surrogate,
        captured["X"],
        captured["Y"],
        rotary,
        steps=400,
        lr=1e-3,
    )
    # Make sure the surrogate's parameters are bfloat16 once fitting is done.
    surrogate.to(dtype=torch.bfloat16)
    print(f" blk{blk:2d}: residual rel-err {err_before:.3f} -> {err_after:.3f}")
# Persist the warm-started student: weights as a state dict, plus a JSON sidecar that
# records which blocks were kept/replaced and the surrogate hyperparameters used.
torch.save(pipe.transformer.state_dict(), f"{OUT}/student_state.pt")
meta = {
    "keep_idx": keep_idx,
    "surrogate_idx": drop_idx,
    "drop_k": DROP_K,
    "kind": "linear_attention",
    "heads": HEADS,
    "head_dim": HEAD_DIM,
    "conv_kernel": CONV_K,
    "ffn_hidden": FFN_HIDDEN,
    "ffn_idx": ffn_idx,
    "importance": imp,
    "param_summary": ps,
}
# Use a context manager so the file is flushed and closed deterministically; the previous
# `json.dump(meta, open(...))` left the handle to be closed by garbage collection.
with open(f"{OUT}/selection.json", "w", encoding="utf-8") as fh:
    json.dump(meta, fh, indent=2)
print(f"saved -> {OUT}/")
# Smoke test: generate a couple of images with the warm-started student and write them to disk.
os.makedirs("outputs/student_linattn4_smoke", exist_ok=True)
SMOKE = [
    'a vintage bookshop storefront with a wooden sign that reads "THE OPEN PAGE"',
    "a serene mountain lake at sunrise reflecting snow-capped peaks, mist over the water",
]

# Fixed seed so repeated runs start from the same generator state.
g = torch.Generator(device="cuda")
g.manual_seed(0)

with torch.autocast("cuda", dtype=torch.bfloat16):
    out = pipe(
        prompt=SMOKE,
        num_inference_steps=4,
        guidance_scale=1.0,
        height=512,
        width=512,
        generator=g,
    )

# One PNG per prompt, numbered in prompt order.
for idx, image in enumerate(out.images):
    image.save(f"outputs/student_linattn4_smoke/ls3_{idx}.png")
print("saved warm-started smoke images")

Xet Storage Details

Size:
3.66 kB
·
Xet hash:
f6de4b3c66ec97e416c20d65e76c2cdc854b8f2d220bdb9bb009b8a2005eb446

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.