Buckets:
"""Build a student with the full upgraded linear-attention surrogate
(RoPE + depthwise-conv + focused/learnable feature map + FFN = a real linear-attn block),
dropping DROP_K blocks chosen by importance, warm-started to mimic the teacher blocks.

Saves the student state dict and selection metadata to outputs/student_linattn4/
and a couple of smoke-test images to outputs/student_linattn4_smoke/.

Usage: python3 scripts/09_build_linattn.py [drop_k] [n_ffn]
    drop_k  number of blocks to replace with surrogates (default 10)
    n_ffn   how many of the deepest replaced blocks also get an FFN (default 4)
"""
import json
import os
import sys

import torch

from flux2distill.model_utils import load_pipeline, param_summary
from flux2distill.surgery import (importance_by_ablation, select_blocks_by_importance,
                                  capture_block_io_seq, attach_surrogates)
from flux2distill.calibration import fit_surrogate_to_block

# Command-line configuration (both arguments are optional positionals).
DROP_K = int(sys.argv[1]) if len(sys.argv) > 1 else 10
N_FFN = int(sys.argv[2]) if len(sys.argv) > 2 else 4  # FFN on the N_FFN deepest dropped blocks

# Surrogate hyper-parameters, forwarded unchanged to attach_surrogates().
HEADS, HEAD_DIM, CONV_K, FFN_HIDDEN = 4, 128, 5, 1024

OUT = "outputs/student_linattn4"
SMOKE_OUT = f"{OUT}_smoke"
os.makedirs(OUT, exist_ok=True)

# Prompts used both for the importance measurement and for capturing block I/O.
PROMPTS = [
    "a serene mountain lake at sunrise reflecting snow-capped peaks",
    'a vintage bookshop storefront with a wooden sign that reads "THE OPEN PAGE"',
    "a close-up portrait of an elderly fisherman with a weathered face",
    "a futuristic city skyline at night with neon lights",
]

print("loading teacher pipeline...")
pipe = load_pipeline(device="cuda")

print("measuring per-block importance (leave-one-out ablation)...")
imp = importance_by_ablation(pipe, PROMPTS, num_inference_steps=4, seed=0)
keep_idx, drop_idx = select_blocks_by_importance(imp, DROP_K)

# FFN on the N_FFN deepest dropped blocks (incl. the last).
# Guard N_FFN <= 0 explicitly: a bare `[-0:]` slice would select *every* block.
ffn_idx = sorted(drop_idx)[-N_FFN:] if N_FFN > 0 else []

print(f"DROP {len(drop_idx)} least-important: {drop_idx}")
print(f"KEEP {len(keep_idx)}: {keep_idx}")
# Report the real number of light blocks rather than DROP_K-N_FFN, which is
# wrong (possibly negative) when N_FFN exceeds the number of dropped blocks.
print(f"FFN on {len(ffn_idx)} deepest dropped: {ffn_idx} "
      f"(the other {len(drop_idx) - len(ffn_idx)} are light)")

print("capturing teacher block I/O sequences (+rotary) for warm-start...")
# Capture happens BEFORE the surrogates are attached, so the recorded
# input/output pairs come from the original teacher blocks.
block_io, rotary = capture_block_io_seq(pipe, PROMPTS, drop_idx, num_inference_steps=4, max_seqs=16)

attach_surrogates(pipe.transformer, drop_idx, kind="linear_attention", heads=HEADS, head_dim=HEAD_DIM,
                  conv_kernel=CONV_K, ffn_hidden=FFN_HIDDEN, ffn_idx=ffn_idx, device="cuda",
                  dtype=torch.bfloat16)
ps = param_summary(pipe.transformer)
print("student params:", ps)

print("warm-start: fitting each surrogate (attn+conv+focus+FFN) to mimic its teacher block...")
for i in drop_idx:
    sur = pipe.transformer.single_transformer_blocks[i]
    # ie / fe are printed below as the residual relative error before -> after fitting.
    ie, fe = fit_surrogate_to_block(sur, block_io[i]["X"], block_io[i]["Y"], rotary, steps=400, lr=1e-3)
    # NOTE(review): presumably the fit may leave the module in another dtype;
    # cast back to bf16 to match the rest of the transformer - confirm.
    sur.to(dtype=torch.bfloat16)
    print(f" blk{i:2d}: residual rel-err {ie:.3f} -> {fe:.3f}")

torch.save(pipe.transformer.state_dict(), f"{OUT}/student_state.pt")
meta = {"keep_idx": keep_idx, "surrogate_idx": drop_idx, "drop_k": DROP_K, "kind": "linear_attention",
        "heads": HEADS, "head_dim": HEAD_DIM, "conv_kernel": CONV_K, "ffn_hidden": FFN_HIDDEN,
        "ffn_idx": ffn_idx, "importance": imp, "param_summary": ps}
# Use a context manager so the file is flushed and closed deterministically.
with open(f"{OUT}/selection.json", "w") as fh:
    json.dump(meta, fh, indent=2)
print(f"saved -> {OUT}/")

# Smoke test: generate a few images with the warm-started student.
os.makedirs(SMOKE_OUT, exist_ok=True)
SMOKE = ['a vintage bookshop storefront with a wooden sign that reads "THE OPEN PAGE"',
         "a serene mountain lake at sunrise reflecting snow-capped peaks, mist over the water"]
g = torch.Generator(device="cuda").manual_seed(0)
with torch.autocast("cuda", dtype=torch.bfloat16):
    out = pipe(prompt=SMOKE, num_inference_steps=4, guidance_scale=1.0, height=512, width=512, generator=g)
for i, im in enumerate(out.images):
    im.save(f"{SMOKE_OUT}/ls3_{i}.png")
print("saved warm-started smoke images")
Xet Storage Details
- Size:
- 3.66 kB
- Xet hash:
- f6de4b3c66ec97e416c20d65e76c2cdc854b8f2d220bdb9bb009b8a2005eb446
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.