Buckets:

Mercity/FluxDistill / scripts /27_convert_full_4b.py
Pranav2748's picture
download
raw
5.07 kB
"""Convert klein-4B to a DEPLOYABLE fused NVFP4 model: quantize+pack every Linear into
Nunchaku's format, load into NunchakuFlux2Transformer2DModel (fused attention + real FP4 kernel),
and generate images vs the bf16 teacher. The converter convention is validated (scripts/26).
Weight->fused-module map (from the diffusers Flux2 structure):
double: attn.to_qkv = cat(to_q,to_k,to_v); to_out.0; to_added_qkv = cat(add_q/k/v); to_add_out;
ff.linear_in/out; ff_context.linear_in/out
single: to_qkv_mlp_proj (3072->27648) -> qkv_proj[:9216] + mlp_fc1[9216:];
to_out (12288->3072) -> out_proj[:, :3072] + mlp_fc2[:, 3072:]
"""
import sys, time, torch
import torch.nn as nn
from flux2distill.model_utils import load_pipeline
from flux2distill.eval_utils import side_by_side
from flux2distill.nunchaku_export import quantize_pack_nvfp4
from nunchaku.models.transformers.transformer_flux2 import NunchakuFlux2Transformer2DModel
from nunchaku.models.linear import SVDQW4A4Linear
RANK = int(sys.argv[1]) if len(sys.argv) > 1 else 128
PROMPTS = [
'a vintage bookshop storefront with a wooden sign that reads "THE OPEN PAGE"',
"a serene mountain lake at sunrise reflecting snow-capped peaks, mist over the water",
"a close-up of a smiling young woman holding up five fingers, natural window light, sharp focus on the hand",
"a bustling tokyo street at night, neon signs, rain-slicked pavement, reflections",
]
def gen(pipe, tag):
    """Render every prompt in PROMPTS with ``pipe`` and return the first image of each.

    Each prompt is its own pipeline call (batch size 1) with 4 inference steps,
    guidance_scale=1.0, 512x512 output and a fresh CUDA generator seeded to 0.
    The teacher and quantized runs therefore start from the same seed per prompt.
    ``tag`` is accepted for the caller's labelling but is not used.
    """
    # batch=1 per prompt: the fused attention's packed rotary is per-token (assumes batch 1)
    def _render(prompt):
        seeded = torch.Generator(device="cuda").manual_seed(0)
        result = pipe(
            prompt=prompt,
            num_inference_steps=4,
            guidance_scale=1.0,
            height=512,
            width=512,
            generator=seeded,
        )
        return result.images[0]

    return [_render(prompt) for prompt in PROMPTS]
pipe = load_pipeline(device="cuda")
tf = pipe.transformer
tf.eval().requires_grad_(False)
print("generating teacher images...")
t_imgs = gen(pipe, "teacher")
# ---- capture source weights BEFORE patching (patching discards the fused originals) ----
# src maps (block_kind, block_index, name) -> weight tensor:
#   "d": double block (tf.transformer_blocks), keyed by the fused Nunchaku
#        module name it will be loaded into;
#   "s": single block (tf.single_transformer_blocks), keyed "Win"/"Wout"; these
#        fused projections are sliced apart later in the conversion loop.
src = {}
for i, b in enumerate(tf.transformer_blocks):
    a = b.attn
    # torch.cat already allocates a fresh tensor, so it is stored directly.
    # Cloning it again would only make a second, immediately discarded copy.
    src[("d", i, "attn.to_qkv")] = torch.cat(
        [a.to_q.weight, a.to_k.weight, a.to_v.weight], 0
    )
    src[("d", i, "attn.to_added_qkv")] = torch.cat(
        [a.add_q_proj.weight, a.add_k_proj.weight, a.add_v_proj.weight], 0
    )
    # One-to-one layers: clone so the capture is independent of the module
    # storage that patching replaces.
    for key, lin in (
        ("attn.to_out.0", a.to_out[0]),
        ("attn.to_add_out", a.to_add_out),
        ("ff.linear_in", b.ff.linear_in),
        ("ff.linear_out", b.ff.linear_out),
        ("ff_context.linear_in", b.ff_context.linear_in),
        ("ff_context.linear_out", b.ff_context.linear_out),
    ):
        src[("d", i, key)] = lin.weight.clone()
for i, b in enumerate(tf.single_transformer_blocks):
    src[("s", i, "Win")] = b.attn.to_qkv_mlp_proj.weight.clone()  # (27648, 3072)
    src[("s", i, "Wout")] = b.attn.to_out.weight.clone()  # (3072, 12288)
# ---- patch to fused model ----
# Swap the class of the live transformer object in place, so `pipe` (which
# holds this same `tf`) keeps using it. Then let Nunchaku's _patch_model
# rebuild the layers as NVFP4 SVDQW4A4Linear modules with the chosen RANK.
# NOTE(review): the conversion loop re-materializes every SVDQW4A4Linear with
# to_empty() and overwrites it, so whatever _patch_model initializes is
# assumed to be placeholder storage — confirm against nunchaku.
tf.__class__ = NunchakuFlux2Transformer2DModel
tf._patch_model(precision="nvfp4", rank=RANK, torch_dtype=torch.bfloat16)
def load_into(m, W, rank=None, refine=3):
    """Quantize ``W`` to NVFP4 and copy the packed buffers into the SVDQW4A4Linear ``m``.

    Args:
        m: a patched SVDQW4A4Linear whose storage is already allocated; its
            qweight, wscales, wcscales, proj_down, proj_up, smooth_factor and
            smooth_factor_orig tensors are overwritten, and wtscale is replaced.
        W: source weight for this layer. It is cast to float32 and moved to
            CUDA before quantization.
        rank: rank forwarded to quantize_pack_nvfp4. Defaults to the module-level
            RANK, which is the rank the model was patched with. Presumably it
            must match the module's proj_down/proj_up widths, or the reshape below fails.
        refine: forwarded unchanged to quantize_pack_nvfp4 (default 3, as before).
    """
    bufs = quantize_pack_nvfp4(
        W.float().cuda(), rank=RANK if rank is None else rank, refine=refine
    )
    # (module tensor, packed buffer) pairs; smooth_factor_orig receives the
    # same values as smooth_factor.
    targets = (
        ("qweight", "qweight"),
        ("wscales", "wscales"),
        ("wcscales", "wcscales"),
        ("proj_down", "proj_down"),
        ("proj_up", "proj_up"),
        ("smooth_factor", "smooth_factor"),
        ("smooth_factor_orig", "smooth_factor"),
    )
    for dst_name, src_name in targets:
        dst = getattr(m, dst_name)
        # Match the module's layout and dtype. .data keeps the in-place copy
        # out of autograd.
        dst.data.copy_(bufs[src_name].reshape(dst.shape).to(dst.dtype))
    # wtscale is rebound as a plain attribute rather than copied into a tensor.
    m.wtscale = bufs["wtscale"]
# ---- quantize + load every patched SVDQW4A4Linear from the captured weights ----
t0 = time.perf_counter()  # monotonic clock for elapsed-time reporting
n = 0
for name, m in tf.named_modules():
    if not isinstance(m, SVDQW4A4Linear):
        continue
    # Allocate uninitialized CUDA storage. Every tensor is then written below:
    # the bias is zeroed and the rest is filled by load_into.
    m.to_empty(device="cuda")
    if m.bias is not None:
        # Source biases are never captured, so the bias is zeroed.
        # NOTE(review): assumes the Flux2 source Linears are bias-free — confirm.
        # .data keeps this in-place write legal even if the bias requires grad.
        m.bias.data.zero_()
    parts = name.split(".")
    group, local = parts[0], ".".join(parts[2:])
    if group == "transformer_blocks":
        W = src[("d", int(parts[1]), local)]
    elif group == "single_transformer_blocks":
        # Split the fused projections. The split offsets come from the sibling
        # modules that define the first slice, so the two slices always tile
        # the fused weight. Per the file header these are 9216 rows and 3072 cols.
        idx = int(parts[1])
        prefix = ".".join(parts[:2])
        Win, Wout = src[("s", idx, "Win")], src[("s", idx, "Wout")]
        if local == "attn.qkv_proj":
            W = Win[:m.out_features]
        elif local == "attn.mlp_fc1":
            off = tf.get_submodule(prefix + ".attn.qkv_proj").out_features
            W = Win[off:off + m.out_features]
        elif local == "attn.out_proj":
            W = Wout[:, :m.in_features]
        elif local == "attn.mlp_fc2":
            off = tf.get_submodule(prefix + ".attn.out_proj").in_features
            W = Wout[:, off:off + m.in_features]
        else:
            raise KeyError(local)
    else:
        # A patched Linear outside the transformer blocks has no captured source.
        raise KeyError(name)
    load_into(m, W)
    n += 1
    if n % 25 == 0:
        print(f" converted {n} layers ({time.perf_counter()-t0:.0f}s)")
# Source copies are no longer needed; release them and cached CUDA blocks.
del src
torch.cuda.empty_cache()
print(f"converted {n} Linears in {time.perf_counter()-t0:.0f}s")
print("generating NVFP4-fused images...")
q_imgs = gen(pipe, "nvfp4")
import os
os.makedirs("outputs/nvfp4/deploy", exist_ok=True)
for i, (t, q) in enumerate(zip(t_imgs, q_imgs)):
side_by_side(t, q, "teacher", "NVFP4-fused-4B", PROMPTS[i]).save(f"outputs/nvfp4/deploy/cmp_{i}.png")
q.save(f"outputs/nvfp4/deploy/q_{i}.png")
print("saved montages -> outputs/nvfp4/deploy/")

Xet Storage Details

Size:
5.07 kB
·
Xet hash:
b5633d15fa9ac49dc292396f4423573e12e459b7a303a60a365ca1126cf38dc8

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.