```python
"""Unit test: 4-bit activation quantization error vs group size (the granularity ladder).

This is the microbenchmark behind the per-group-activation decision (RESULTS.md 2026-06-10).
It measures the relative L2 reconstruction error of `flux2distill.svdquant.fake_quant_act`
(the EXACT function SVDQuantLinear calls at inference) at 4 bits across group sizes.

Two modes:
    python3 scripts/test_act_group_quant.py          # synthetic only (CPU, seconds)
    python3 scripts/test_act_group_quant.py --real   # synthetic, then real activations hooked
                                                     # from the model (loads the pipeline, ~3 min, GPU)

Synthetic mode: a Gaussian tensor shaped like our real activations (batch 4 x 1536 tokens x
3072 channels, bf16) with 3 hand-planted outlier channels (60x/25x/10x) imitating the documented
transformer outlier-channel phenomenon. Real mode: forward-pre-hooks capture the actual inputs
of 3 representative Linears in single_transformer_blocks (0.attn.to_qkv_mlp_proj early /
10.attn.to_out mid / 19.attn.to_qkv_mlp_proj late) over 8 calib images from data/monet_cache
(same forward as scripts/12), then the same ladder runs on those tensors.

Reference results (2026-06-10 box, recorded for regression):
  synthetic: per-token 0.602 | g256 0.233 | g128 0.172 | g64 0.129 | g32 0.099 | g16 0.077
  REAL acts (8 calib imgs, bf16):
    S0  qkv_mlp (3072-ch):  per-token 0.469 | g64 0.169 | g16 0.109
    S10 to_out  (12288-ch): per-token 0.443 | g64 0.152 | g16 0.104
    S19 qkv_mlp (3072-ch):  per-token 0.366 | g64 0.134 | g16 0.095
  -> rel-err drops ~20-25% per halving of group size on synthetic and ~16-20% per halving on
     real (g64 -> g16); per-token A4 matches the catastrophic cells (vel rel-err 0.51-0.66)
     and g64 matches the fix (RESULTS.md).
"""
import sys

import torch

from flux2distill.svdquant import fake_quant_act

GROUPS = (0, 256, 128, 64, 32, 16)  # 0 = per-token
BITS = 4


def ladder(x: torch.Tensor, tag: str):
    """Print the A{BITS} relative L2 error of fake_quant_act(x) for every group size in GROUPS."""
    xf = x.float()
    print(f"-- {tag} shape={tuple(x.shape)} absmax={xf.abs().max():.1f}")
    for g in GROUPS:
        xq = fake_quant_act(x, BITS, group=g)
        rel = ((xq.float() - xf).norm() / xf.norm()).item()  # ||Q(x) - x|| / ||x||
        print(f"  A{BITS} g={g if g else 'per-token':>9}: rel-err {rel:.4f}")


def synthetic():
    torch.manual_seed(0)
    x = torch.randn(4, 1536, 3072, dtype=torch.bfloat16)
    # plant 3 outlier channels: (channel index, magnitude multiplier)
    for c, m in [(137, 60.0), (901, 25.0), (2050, 10.0)]:
        x[..., c] *= m
    ladder(x, "synthetic gaussian + 3 planted outlier channels")


def real():
    from flux2distill.data import LatentCaptionDataset
    from flux2distill.losses import build_x_t
    from flux2distill.model_utils import load_pipeline

    pipe = load_pipeline(device="cuda")
    tf = pipe.transformer.eval().requires_grad_(False)

    # early / mid / late probes: block 0 fused qkv+mlp proj input,
    # block 10 attn.to_out input, block 19 fused qkv+mlp proj input
    probes = {
        "single_transformer_blocks.0.attn.to_qkv_mlp_proj": None,
        "single_transformer_blocks.10.attn.to_out": None,
        "single_transformer_blocks.19.attn.to_qkv_mlp_proj": None,
    }
    handles = []
    for name in probes:
        mod = tf.get_submodule(name)

        def hook(m, args, _name=name):  # default arg binds this probe's name
            if probes[_name] is None:  # keep the first batch only
                probes[_name] = args[0].detach().to("cpu", torch.bfloat16)

        handles.append(mod.register_forward_pre_hook(hook))

    # 8 calib latents + captions, noised to x_t at random sigma, then one transformer forward
    ds = LatentCaptionDataset(cache_dir="data/monet_cache")
    x0 = ds.latents[:8].to("cuda", torch.bfloat16)
    pe, tid = pipe.encode_prompt([ds.captions[j] for j in range(8)], device="cuda")
    g = torch.Generator(device="cuda").manual_seed(0)
    eps = torch.randn(x0.shape, generator=g, device="cuda", dtype=torch.float32)
    sigma = torch.rand(8, generator=g, device="cuda", dtype=torch.float32)
    x_t = build_x_t(x0.float(), eps, sigma).to(torch.bfloat16)
    _, img_ids = pipe.prepare_latents(1, 32, 512, 512, torch.bfloat16, "cuda",
                                      torch.Generator(device="cuda").manual_seed(0))
    with torch.autocast("cuda", dtype=torch.bfloat16):
        tf(hidden_states=x_t, timestep=sigma, guidance=None, encoder_hidden_states=pe,
           txt_ids=tid, img_ids=img_ids, return_dict=False)
    for h in handles:
        h.remove()
    for name, x in probes.items():
        ladder(x, f"REAL act input of {name}")


if __name__ == "__main__":
    synthetic()
    if "--real" in sys.argv:
        real()
```
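The ladder treats `fake_quant_act` as a black box, and its internals are not shown here. To make concrete what the group size changes, here is a minimal sketch that assumes symmetric absmax quantization along the last (channel) dimension. In this sketch `group=0` uses one scale per token and `group=g` uses one scale per `g` consecutive channels. The function `fake_quant_groupwise` and its rounding scheme are hypothetical illustrations. They are not the `flux2distill.svdquant` implementation.

```python
import torch


def fake_quant_groupwise(x: torch.Tensor, bits: int, group: int = 0) -> torch.Tensor:
    """HYPOTHETICAL sketch (not flux2distill's fake_quant_act): symmetric absmax
    fake quantization over the last dim.

    group=0 -> one scale per token (all channels share it);
    group=g -> one scale per g consecutive channels (channel count must divide by g).
    """
    qmax = 2 ** (bits - 1) - 1                              # 7 for signed 4-bit
    xf = x.float()
    c = xf.shape[-1]
    g = c if group == 0 else group
    assert c % g == 0, "channel count must be divisible by the group size"
    xg = xf.reshape(*xf.shape[:-1], c // g, g)              # [..., n_groups, g]
    scale = xg.abs().amax(dim=-1, keepdim=True) / qmax      # one scale per group
    scale = scale.clamp(min=1e-8)                           # guard all-zero groups
    q = torch.clamp(torch.round(xg / scale), -qmax, qmax)   # snap to the integer grid
    return (q * scale).reshape(xf.shape).to(x.dtype)        # dequantize, restore dtype


def rel_err(x: torch.Tensor, xq: torch.Tensor) -> float:
    """Relative L2 reconstruction error ||xq - x|| / ||x||, as in ladder()."""
    xf = x.float()
    return ((xq.float() - xf).norm() / xf.norm()).item()


# Toy usage: one planted 60x outlier channel, per-token vs g64 vs g16.
torch.manual_seed(0)
x = torch.randn(2, 64, 256)
x[..., 7] *= 60.0
errs = {g: rel_err(x, fake_quant_groupwise(x, 4, g)) for g in (0, 64, 16)}
for g, e in errs.items():
    print(f"g={g if g else 'per-token':>9}: rel-err {e:.4f}")
```

In this scheme a single large channel inflates the shared scale for every channel in its group. The ordinary channels in that group then round to zero alongside it, so smaller groups confine the damage to fewer channels. This is why planted outlier channels make the ladder steep.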
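The per-halving trend in the docstring can be checked against the docstring's own reference numbers. This sketch copies those numbers into plain dictionaries and reports the fractional error reduction per halving of group size. The helper name `drop_per_halving` is hypothetical. For the real activations only g64 and g16 were recorded, so the sketch converts the two-halving ratio into a per-halving geometric mean.

```python
# Relative errors copied from the docstring's reference results (2026-06-10).
synthetic = {256: 0.233, 128: 0.172, 64: 0.129, 32: 0.099, 16: 0.077}
real_g64_g16 = {
    "S0 qkv_mlp": (0.169, 0.109),
    "S10 to_out": (0.152, 0.104),
    "S19 qkv_mlp": (0.134, 0.095),
}


def drop_per_halving(err_big: float, err_small: float, halvings: int) -> float:
    """Hypothetical helper: fractional error reduction per halving,
    as the geometric mean over `halvings` consecutive halvings."""
    return 1.0 - (err_small / err_big) ** (1.0 / halvings)


sizes = sorted(synthetic, reverse=True)  # [256, 128, 64, 32, 16]
syn_drops = {
    f"g{a}->g{b}": drop_per_halving(synthetic[a], synthetic[b], 1)
    for a, b in zip(sizes, sizes[1:])
}
# g64 -> g16 is two halvings (64 -> 32 -> 16)
real_drops = {k: drop_per_halving(e64, e16, 2) for k, (e64, e16) in real_g64_g16.items()}

for k, v in {**syn_drops, **real_drops}.items():
    print(f"{k:>12}: -{100 * v:.1f}% per halving")
```

On these numbers the synthetic ladder loses about 22-26% of its error per halving. Between g64 and g16, the three real probes lose about 16-20% per halving.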
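Real mode captures activations with `nn.Module.register_forward_pre_hook`. The hook receives the module and the tuple of positional inputs before `forward` runs. The sketch below shows the same capture pattern on a CPU-only toy model, so it runs without the pipeline. The `nn.Sequential` model and the probe names `"0"`/`"2"` are hypothetical stand-ins for the transformer's submodule paths.

```python
import torch
from torch import nn

torch.manual_seed(0)
# HYPOTHETICAL toy model standing in for the transformer; submodule names are "0", "1", "2".
model = nn.Sequential(nn.Linear(8, 16), nn.GELU(), nn.Linear(16, 4)).eval()

probes = {"0": None, "2": None}  # submodule path -> captured input (first call only)
handles = []
for name in probes:
    mod = model.get_submodule(name)

    def hook(m, args, _name=name):   # default arg binds this probe's name now
        if probes[_name] is None:    # keep the first batch only
            probes[_name] = args[0].detach().to("cpu")

    handles.append(mod.register_forward_pre_hook(hook))

x1 = torch.randn(3, 8)
with torch.no_grad():
    model(x1)                        # captured by both hooks
    model(torch.randn(5, 8))         # ignored: probes already filled

for h in handles:
    h.remove()                       # later forwards run without hooks

print({k: tuple(v.shape) for k, v in probes.items()})
```

The `_name=name` default argument matters. A plain closure over the loop variable is evaluated at call time, so every hook would write into the last probe's slot. Removing the handles afterwards keeps later forwards hook-free.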