Buckets:
"""Assemble the v2 student model.

Most single blocks are kept at full size. Only DROP_K of them (default 6),
chosen by leave-one-out importance, are dropped and replaced with
warm-started surrogates.
"""
import json
import os
import sys
import time

import torch

from flux2distill.config import Config
from flux2distill.model_utils import load_pipeline, param_summary
from flux2distill.surgery import (
    capture_single_block_io,
    importance_by_ablation,
    select_blocks_by_importance,
    build_student,
)
from flux2distill.calibration import fit_surrogate

# Project configuration; the surgery settings (rank, act) are read from it
# when the student is built.
cfg = Config()

# How many single blocks to replace: first command-line argument if one was
# given, otherwise 6.
DROP_K = int(sys.argv[1]) if sys.argv[1:] else 6

# Directory that receives the student weights and the selection metadata.
OUT = "outputs/student_v2"
os.makedirs(OUT, exist_ok=True)

# Prompts used both for the importance measurement and for capturing the
# block inputs/outputs that warm-start the surrogates.
IMP_PROMPTS = [
    "a serene mountain lake at sunrise reflecting snow-capped peaks",
    'a vintage bookshop storefront with a wooden sign that reads "THE OPEN PAGE"',
    "a close-up portrait of an elderly fisherman with a weathered face",
    "exactly five ripe red apples arranged in a row on a wooden table",
]
# --- Teacher pipeline and per-block importance ------------------------------
print("loading teacher pipeline...")
pipe = load_pipeline(device="cuda")

print(f"\nmeasuring per-block importance (leave-one-out ablation, {len(IMP_PROMPTS)} prompts)...")
started = time.time()
imp = importance_by_ablation(pipe, IMP_PROMPTS, num_inference_steps=4, seed=0)
elapsed = time.time() - started
print(f" done in {elapsed:.1f}s")

# `imp` is indexed by block number; report the scores in ascending block order.
for blk in sorted(imp):
    score = imp[blk]
    print(f" blk{blk:2d}: importance={score:.4f}")

# Split the block indices into the ones kept at full size and the DROP_K ones
# that will be replaced by surrogates.
selection = select_blocks_by_importance(imp, DROP_K)
keep_idx, drop_idx = selection
print(f"\nDROP (surrogate) {len(drop_idx)} least-important: {drop_idx}")
print(f"KEEP full {len(keep_idx)}: {keep_idx}")
# --- Surrogate construction and calibration ---------------------------------
print("\ncapturing I/O for the dropped blocks (surrogate warm-start)...")
io = capture_single_block_io(
    pipe,
    IMP_PROMPTS,
    num_inference_steps=4,
    max_tokens_per_block=12000,
)

print("building v2 student (lstsq warm-start on dropped blocks)...")
errs = build_student(
    pipe.transformer,
    drop_idx,
    io,
    rank=cfg.surgery.rank,
    act=cfg.surgery.act,
    device="cuda",
)
for blk in drop_idx:
    rel_err = errs[blk]
    print(f" blk{blk:2d}: lstsq rel-err {rel_err:.4f}")

print("calibration fit...")
for blk in drop_idx:
    # Fetch the module now sitting at this index and fit it against the
    # captured "X" / "Delta" tensors for the same block.
    surrogate = pipe.transformer.single_transformer_blocks[blk]
    captured = io[blk]
    err_before, err_after = fit_surrogate(
        surrogate, captured["X"], captured["Delta"], steps=200, lr=1e-3
    )
    # Cast the fitted surrogate to bfloat16 once fitting is done.
    surrogate.to(dtype=torch.bfloat16)
    print(f" blk{blk:2d}: {err_before:.4f} -> {err_after:.4f}")
# --- Persist the student weights and the selection metadata -----------------
# Summarise the parameters once, so the console output and the stored metadata
# report the same value without walking the model a second time.
student_params = param_summary(pipe.transformer)
print("\nv2 student params:", student_params)

torch.save(pipe.transformer.state_dict(), f"{OUT}/student_state.pt")

# json.dump only accepts built-in scalar types. The indices and scores come
# from project helpers and may be numpy/torch scalars (TODO confirm), so coerce
# them to int/float/str here; for plain Python values this is a no-op and the
# resulting JSON is the same as before.
meta = {
    "keep_idx": [int(i) for i in keep_idx],
    "surrogate_idx": [int(i) for i in drop_idx],
    "drop_k": DROP_K,
    "rank": cfg.surgery.rank,
    "act": cfg.surgery.act,
    "importance": {str(k): float(v) for k, v in imp.items()},
    "lstsq_rel_err": {str(k): float(v) for k, v in errs.items()},
    "param_summary": student_params,
}
with open(f"{OUT}/selection.json", "w", encoding="utf-8") as f:
    json.dump(meta, f, indent=2)
print(f"saved to {OUT}/")
# smoke gen: render two prompts with the freshly built student and save the
# images, as a quick check before any training.
os.makedirs("outputs/student_v2_smoke", exist_ok=True)
SMOKE = [
    'a vintage bookshop storefront with a wooden sign that reads "THE OPEN PAGE"',
    "a serene mountain lake at sunrise reflecting snow-capped peaks, mist over the water",
]

# Seed the CUDA generator with a fixed value (0).
gen = torch.Generator(device="cuda")
gen.manual_seed(0)

out = pipe(
    prompt=SMOKE,
    num_inference_steps=4,
    guidance_scale=1.0,
    height=512,
    width=512,
    generator=gen,
)
for idx, image in enumerate(out.images):
    image.save(f"outputs/student_v2_smoke/v2_{idx}.png")
print("saved v2 smoke images (pre-training, warm-start only)")
Xet Storage Details
- Size: 3.34 kB
- Xet hash: 17c32603ceab203863ae3017b9e8b0c308a1b98603655bb30fd795c6e0aeeed7
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.