Buckets:

Mercity/FluxDistill / scripts /05_build_student_v2.py
Pranav2748's picture
download
raw
3.34 kB
"""Build student v2: keep most single blocks full, drop only DROP_K (default 6) by
leave-one-out importance, replace the dropped ones with warm-started surrogates.

Usage:
    python scripts/05_build_student_v2.py [DROP_K]

Steps (requires a CUDA device):
  1. Load the teacher pipeline and score every single block by leave-one-out
     ablation over a small prompt set.
  2. Select the DROP_K least-important blocks for replacement.
  3. Capture the dropped blocks' I/O and warm-start each surrogate with a
     least-squares fit, then refine it with a short calibration fit.
  4. Save ``outputs/student_v2/student_state.pt`` (transformer state dict) and
     ``outputs/student_v2/selection.json`` (selection + diagnostics), then
     render smoke-test images into ``outputs/student_v2_smoke/``.
"""
import json
import os
import sys
import time
import torch
from flux2distill.config import Config
from flux2distill.model_utils import load_pipeline, param_summary
from flux2distill.surgery import (capture_single_block_io, importance_by_ablation,
select_blocks_by_importance, build_student)
from flux2distill.calibration import fit_surrogate
cfg = Config()

# Number of single blocks to replace with surrogates; optional first CLI arg.
DROP_K = int(sys.argv[1]) if len(sys.argv) > 1 else 6
if DROP_K < 0:
    # Reject this up front, before the expensive pipeline load. A negative
    # count is meaningless here, and downstream selection may treat it as a
    # slice bound (e.g. "all but K"). NOTE(review): confirm against
    # select_blocks_by_importance.
    sys.exit(f"DROP_K must be a non-negative integer, got {DROP_K}")

OUT = "outputs/student_v2"
os.makedirs(OUT, exist_ok=True)

# Prompts used both to score block importance and to capture the activations
# that warm-start the surrogates.
IMP_PROMPTS = [
    "a serene mountain lake at sunrise reflecting snow-capped peaks",
    'a vintage bookshop storefront with a wooden sign that reads "THE OPEN PAGE"',
    "a close-up portrait of an elderly fisherman with a weathered face",
    "exactly five ripe red apples arranged in a row on a wooden table",
]
print("loading teacher pipeline...")
pipe = load_pipeline(device="cuda")

print(f"\nmeasuring per-block importance (leave-one-out ablation, {len(IMP_PROMPTS)} prompts)...")
# perf_counter is monotonic, so it is the right clock for elapsed time.
t0 = time.perf_counter()
# imp maps block index -> importance score; higher presumably means removing
# the block hurts more (TODO confirm with importance_by_ablation).
imp = importance_by_ablation(pipe, IMP_PROMPTS, num_inference_steps=4, seed=0)
print(f" done in {time.perf_counter() - t0:.1f}s")
for blk, score in sorted(imp.items()):
    print(f" blk{blk:2d}: importance={score:.4f}")

# Split block indices into those kept at full size and those to be replaced
# by surrogates.
keep_idx, drop_idx = select_blocks_by_importance(imp, DROP_K)
print(f"\nDROP (surrogate) {len(drop_idx)} least-important: {drop_idx}")
print(f"KEEP full {len(keep_idx)}: {keep_idx}")
print("\ncapturing I/O for the dropped blocks (surrogate warm-start)...")
# io[i] is indexed later with "X" and "Delta" keys. Presumably these are the
# block's inputs and the residual it contributes (TODO confirm with
# capture_single_block_io); activations are capped at 12000 tokens per block.
io = capture_single_block_io(pipe, IMP_PROMPTS, num_inference_steps=4, max_tokens_per_block=12000)

print("building v2 student (lstsq warm-start on dropped blocks)...")
# Replace each dropped block in pipe.transformer with a surrogate initialised
# by least squares; errs maps block index -> relative fit error.
errs = build_student(pipe.transformer, drop_idx, io, rank=cfg.surgery.rank, act=cfg.surgery.act, device="cuda")
for i in drop_idx:
    print(f" blk{i:2d}: lstsq rel-err {errs[i]:.4f}")
print("calibration fit...")
for i in drop_idx:
    # After build_student, this slot holds the surrogate module.
    sur = pipe.transformer.single_transformer_blocks[i]
    # Short gradient-based refinement on the captured (X, Delta) pairs.
    # Returns (initial_error, final_error).
    ie, fe = fit_surrogate(sur, io[i]["X"], io[i]["Delta"], steps=200, lr=1e-3)
    # Return the surrogate to bf16 in case fitting left it in another dtype
    # (NOTE(review): presumably fit_surrogate trains in fp32 — confirm).
    sur.to(dtype=torch.bfloat16)
    print(f" blk{i:2d}: {ie:.4f} -> {fe:.4f}")

# The captured activations are not used past this point (the saved metadata
# uses `errs`). Drop the reference so the memory can be reclaimed before the
# smoke generation runs.
del io
# The transformer is not modified between here and the metadata dump, so one
# summary serves both the log line and selection.json.
summary = param_summary(pipe.transformer)
print("\nv2 student params:", summary)
torch.save(pipe.transformer.state_dict(), f"{OUT}/student_state.pt")

# Coerce indices and scores to built-in int/float. json.dump cannot serialize
# torch tensors or numpy scalars, and the helpers' return types are not
# visible here. Dict keys are stringified explicitly; json would do the same
# for int keys, so the file contents are unchanged for plain Python values.
meta = {
    "keep_idx": [int(i) for i in keep_idx],
    "surrogate_idx": [int(i) for i in drop_idx],
    "drop_k": DROP_K,
    "rank": cfg.surgery.rank,
    "act": cfg.surgery.act,
    "importance": {str(k): float(v) for k, v in imp.items()},
    "lstsq_rel_err": {str(k): float(v) for k, v in errs.items()},
    "param_summary": summary,
}
with open(f"{OUT}/selection.json", "w", encoding="utf-8") as f:
    json.dump(meta, f, indent=2)
print(f"saved to {OUT}/")
# Smoke generation: render a few images with the freshly built student
# (warm-started surrogates only, no further training) as a quick sanity check.
SMOKE_DIR = "outputs/student_v2_smoke"
os.makedirs(SMOKE_DIR, exist_ok=True)
SMOKE = [
    'a vintage bookshop storefront with a wooden sign that reads "THE OPEN PAGE"',
    "a serene mountain lake at sunrise reflecting snow-capped peaks, mist over the water",
]
# Fixed seed so successive runs produce comparable images.
gen = torch.Generator(device="cuda").manual_seed(0)
out = pipe(prompt=SMOKE, num_inference_steps=4, guidance_scale=1.0, height=512, width=512, generator=gen)
for i, im in enumerate(out.images):
    im.save(f"{SMOKE_DIR}/v2_{i}.png")
print("saved v2 smoke images (pre-training, warm-start only)")

Xet Storage Details

Size:
3.34 kB
·
Xet hash:
17c32603ceab203863ae3017b9e8b0c308a1b98603655bb30fd795c6e0aeeed7

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.