# Mercity/FluxDistill: scripts/test_act_group_quant.py
"""Unit test: 4-bit activation quantization error vs group size (the granularity ladder).
This is the microbenchmark behind the per-group-activation decision (RESULTS.md 2026-06-10).
It measures the relative L2 reconstruction error of `flux2distill.svdquant.fake_quant_act`
(the EXACT function SVDQuantLinear calls at inference) at 4 bits across group sizes.
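For orientation, here is a minimal sketch of what per-group 4-bit activation fake quantization
typically looks like. The real `fake_quant_act` is NOT reproduced here. The hypothetical
`group_fake_quant` below assumes symmetric absmax scaling with round-to-nearest, groups of
contiguous channels along the last dim, and group=0 meaning one scale per token. It illustrates
why a single outlier channel hurts per-token scaling much more than small groups:

```python
import torch


def group_fake_quant(x: torch.Tensor, bits: int = 4, group: int = 0) -> torch.Tensor:
    # Hypothetical reference, NOT the real fake_quant_act: symmetric absmax,
    # round-to-nearest, groups of `group` contiguous channels on the last dim,
    # group=0 meaning a single scale per token (the whole last dim).
    xf = x.float()
    c = xf.shape[-1]
    g = c if group == 0 else group
    assert c % g == 0, "channel count must be divisible by the group size"
    xg = xf.reshape(*xf.shape[:-1], c // g, g)  # (..., n_groups, g)
    qmax = 2 ** (bits - 1) - 1  # 7 for signed 4-bit
    scale = xg.abs().amax(dim=-1, keepdim=True).clamp_min(1e-8) / qmax
    q = torch.clamp(torch.round(xg / scale), -qmax - 1, qmax)
    return (q * scale).reshape(xf.shape).to(x.dtype)


def rel_err(x: torch.Tensor, xq: torch.Tensor) -> float:
    # Relative L2 reconstruction error ||xq - x|| / ||x||, computed in fp32.
    xf = x.float()
    return ((xq.float() - xf).norm() / xf.norm()).item()


torch.manual_seed(0)
x = torch.randn(2, 64, 512)
x[..., 37] *= 60.0  # one planted outlier channel
errs = {g: rel_err(x, group_fake_quant(x, 4, g)) for g in (0, 128, 64, 16)}
for g, e in errs.items():
    print(f"g={g if g else 'per-token':>9}: rel-err {e:.4f}")
```

With per-token scaling, the outlier sets the step size for all 512 channels of a token, so most
ordinary values round to zero. With g16, only the one group containing the outlier pays that cost.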
Two modes:
    python3 scripts/test_act_group_quant.py          # synthetic only (CPU, seconds)
    python3 scripts/test_act_group_quant.py --real   # synthetic, then real activations hooked
                                                     # from the model (loads the pipeline,
                                                     # ~3 min, GPU)
Synthetic mode: a Gaussian tensor shaped like our real activations (4 x 1536 tokens x 3072
channels, bf16) with 3 hand-planted outlier channels (60x/25x/10x) imitating the documented
outlier-channel phenomenon in transformer activations. Real mode: forward pre-hooks capture the
actual inputs of 3 representative block Linears (single_transformer_blocks 0 / 10 / 19:
to_qkv_mlp_proj early, to_out mid, to_qkv_mlp_proj late) during one forward pass over 8 calib
images from data/monet_cache (same forward as scripts/12). The same ladder then runs on those
tensors.
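The real-mode capture uses the standard PyTorch forward pre-hook pattern. Below is a CPU-only
sketch on a hypothetical toy `nn.Sequential` (standing in for the transformer; not the real
model). It shows the two details `real()` relies on: binding the loop variable as a default
argument, and keeping only the first call's input:

```python
import torch
from torch import nn

# Hypothetical toy stand-in for the transformer: Linear -> GELU -> Linear.
model = nn.Sequential(nn.Linear(16, 32), nn.GELU(), nn.Linear(32, 8)).eval()
probes = {"0": None, "2": None}  # submodule names whose inputs we capture

handles = []
for name in probes:
    mod = model.get_submodule(name)

    # `_name=name` binds the current name now, so each hook fills its own slot.
    def hook(m, args, _name=name):
        if probes[_name] is None:  # keep the first call only
            probes[_name] = args[0].detach().cpu()

    handles.append(mod.register_forward_pre_hook(hook))

first = torch.randn(4, 10, 16)
try:
    with torch.no_grad():
        model(first)
        model(torch.randn(4, 10, 16))  # second call: hooks fire but keep nothing
finally:
    for h in handles:  # always detach the hooks, even if the forward fails
        h.remove()

for name, x in probes.items():
    print(name, tuple(x.shape))
```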
Reference results (2026-06-10 box, recorded for regression), A4 rel-err:
    synthetic:  per-token 0.602 | g256 0.233 | g128 0.172 | g64 0.129 | g32 0.099 | g16 0.077
    REAL acts (8 calib imgs, bf16):
        S0  qkv_mlp (3072-ch):   per-token 0.469 | g64 0.169 | g16 0.109
        S10 to_out  (12288-ch):  per-token 0.443 | g64 0.152 | g16 0.104
        S19 qkv_mlp (3072-ch):   per-token 0.366 | g64 0.134 | g16 0.095
-> rel-err drops by a roughly constant fraction per halving of group size, synthetic AND real
(~22-26% per halving on synthetic; ~16-20% per halving on real, g64 -> g16). Per-token A4 matches
the catastrophic cells (vel rel-err 0.51-0.66) and g64 matches the fix (RESULTS.md).
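The per-halving trend can be checked directly from the reference numbers above. Synthetic mode
records every halving. Real mode records only g64 and g16, which are two halvings apart, so this
sketch takes the geometric mean per halving there:

```python
import math

# Reference A4 rel-err values, copied from the results above.
synthetic = {256: 0.233, 128: 0.172, 64: 0.129, 32: 0.099, 16: 0.077}
real_g64_g16 = {"S0": (0.169, 0.109), "S10": (0.152, 0.104), "S19": (0.134, 0.095)}

# Synthetic: each adjacent pair of group sizes is exactly one halving.
sizes = sorted(synthetic, reverse=True)  # 256, 128, 64, 32, 16
syn_drops = [1 - synthetic[b] / synthetic[a] for a, b in zip(sizes, sizes[1:])]

# Real: g64 -> g16 spans two halvings, so use the per-halving geometric mean.
real_drops = {k: 1 - math.sqrt(g16 / g64) for k, (g64, g16) in real_g64_g16.items()}

print("synthetic per-halving drop:", [f"{d:.1%}" for d in syn_drops])
print("real per-halving drop:", {k: f"{d:.1%}" for k, d in real_drops.items()})
```

The real-activation drops come out somewhat smaller than the synthetic ones, but both are
steady from one halving to the next.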
"""
import sys
import torch
from flux2distill.svdquant import fake_quant_act
GROUPS = (0, 256, 128, 64, 32, 16)  # activation group sizes along channels; 0 = per-token
BITS = 4


def ladder(x: torch.Tensor, tag: str):
    """Print the A4 relative L2 error of fake_quant_act on `x` for every group size."""
    xf = x.float()
    print(f"-- {tag} shape={tuple(x.shape)} absmax={xf.abs().max():.1f}")
    for g in GROUPS:
        xq = fake_quant_act(x, BITS, group=g)
        # relative L2 reconstruction error ||xq - x|| / ||x||, in fp32
        rel = ((xq.float() - xf).norm() / xf.norm()).item()
        print(f" A{BITS} g={g if g else 'per-token':>9}: rel-err {rel:.4f}")
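

def synthetic():
    """Gaussian activations with 3 planted outlier channels, shaped like the real ones."""
    torch.manual_seed(0)
    x = torch.randn(4, 1536, 3072, dtype=torch.bfloat16)  # (batch, tokens, channels)
    for c, m in [(137, 60.0), (901, 25.0), (2050, 10.0)]:  # (channel index, scale factor)
        x[..., c] *= m
    ladder(x, "synthetic gaussian + 3 planted outlier channels")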


@torch.no_grad()
def real():
    """Capture real block-Linear inputs with forward pre-hooks, then run the ladder on them."""
    from flux2distill.data import LatentCaptionDataset
    from flux2distill.losses import build_x_t
    from flux2distill.model_utils import load_pipeline

    pipe = load_pipeline(device="cuda")
    tf = pipe.transformer.eval().requires_grad_(False)

    # one early attn input, one mid mlp input, one late block input
    probes = {
        "single_transformer_blocks.0.attn.to_qkv_mlp_proj": None,
        "single_transformer_blocks.10.attn.to_out": None,
        "single_transformer_blocks.19.attn.to_qkv_mlp_proj": None,
    }
    handles = []
    for name in probes:
        mod = tf.get_submodule(name)

        # `_name=name` binds the loop variable now (avoids the late-binding closure bug)
        def hook(m, args, _name=name):
            if probes[_name] is None:  # keep the first batch only
                probes[_name] = args[0].detach().to("cpu", torch.bfloat16)

        handles.append(mod.register_forward_pre_hook(hook))

    try:
        # 8 cached calib latents + captions, noised with build_x_t at a random sigma each
        ds = LatentCaptionDataset(cache_dir="data/monet_cache")
        x0 = ds.latents[:8].to("cuda", torch.bfloat16)
        pe, tid = pipe.encode_prompt([ds.captions[j] for j in range(8)], device="cuda")
        g = torch.Generator(device="cuda").manual_seed(0)
        eps = torch.randn(x0.shape, generator=g, device="cuda", dtype=torch.float32)
        sigma = torch.rand(8, generator=g, device="cuda", dtype=torch.float32)
        x_t = build_x_t(x0.float(), eps, sigma).to(torch.bfloat16)
        _, img_ids = pipe.prepare_latents(1, 32, 512, 512, torch.bfloat16, "cuda",
                                          torch.Generator(device="cuda").manual_seed(0))
        # a single forward pass; the hooks record the probe inputs on the way through
        with torch.autocast("cuda", dtype=torch.bfloat16):
            tf(hidden_states=x_t, timestep=sigma, guidance=None, encoder_hidden_states=pe,
               txt_ids=tid, img_ids=img_ids, return_dict=False)
    finally:
        for h in handles:  # detach the hooks even if the forward fails
            h.remove()

    for name, x in probes.items():
        ladder(x, f"REAL act input of {name}")


if __name__ == "__main__":
    synthetic()  # always runs
    if "--real" in sys.argv:
        real()
