Text-to-Image
Diffusers
English
Mixture of Experts
sparse-moe
diffusion
image-generation
quantization
fp8
Instructions to use D-Squarius-Green-Jr/Nucleus-Image-FP8 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use D-Squarius-Green-Jr/Nucleus-Image-FP8 with Diffusers:
pip install -U diffusers transformers accelerate
import torch from diffusers import DiffusionPipeline # switch to "mps" for apple devices pipe = DiffusionPipeline.from_pretrained("D-Squarius-Green-Jr/Nucleus-Image-FP8", dtype=torch.bfloat16, device_map="cuda") prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" image = pipe(prompt).images[0] - Notebooks
- Google Colab
- Kaggle
- Local Apps
- Draw Things
- DiffusionBee
| """ | |
| SwiGLUExperts FP8 monkey-patch. | |
| Imported by both 03_quantize_fp8.py (before loading BF16 model) and 04_test_inference.py | |
| (before loading FP8 model). Idempotent: importing twice is a no-op. | |
| What it does: | |
| - Adds two persistent buffers (gate_up_proj_scale, down_proj_scale) to every SwiGLUExperts. | |
| - Replaces _run_experts_for_loop to dequantize per-expert weights on-the-fly when stored as fp8_e4m3fn. | |
| - Forces use_grouped_mm=False (the grouped_mm kernel doesn't accept fp8 e4m3 inputs as of torch 2.11). | |
| - If weights are still bf16 (un-quantized model), behavior is identical to the original SwiGLU forward. | |
| """ | |
| from __future__ import annotations | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from diffusers.models.transformers import transformer_nucleusmoe_image as _moe_mod | |
| FP8_E4M3_MAX = 448.0 # float8_e4m3fn dynamic range | |
| SCALE_DTYPE = torch.bfloat16 | |
| _PATCH_FLAG = "_nucleus_fp8_patched_v1" | |
| def apply_patch(): | |
| if getattr(_moe_mod.SwiGLUExperts, _PATCH_FLAG, False): | |
| return # already patched | |
| cls = _moe_mod.SwiGLUExperts | |
| orig_init = cls.__init__ | |
| def patched_init(self, hidden_size, moe_intermediate_dim, num_experts, use_grouped_mm: bool = False): | |
| # Force use_grouped_mm off — fp8 inputs aren't supported by F.grouped_mm on Blackwell yet, | |
| # and the for-loop path is the patched/quantized path. | |
| orig_init(self, hidden_size, moe_intermediate_dim, num_experts, use_grouped_mm=False) | |
| # Persistent buffers; default to ones so a non-quantized BF16 checkpoint still produces | |
| # mathematically identical output through patched_for_loop. | |
| self.register_buffer( | |
| "gate_up_proj_scale", | |
| torch.ones(num_experts, 1, 2 * moe_intermediate_dim, dtype=SCALE_DTYPE), | |
| persistent=True, | |
| ) | |
| self.register_buffer( | |
| "down_proj_scale", | |
| torch.ones(num_experts, 1, hidden_size, dtype=SCALE_DTYPE), | |
| persistent=True, | |
| ) | |
| def patched_for_loop(self, x: torch.Tensor, num_tokens_per_expert: torch.Tensor) -> torch.Tensor: | |
| n_list = num_tokens_per_expert.tolist() | |
| n_real = sum(n_list) | |
| n_pad = x.shape[0] - n_real | |
| x_per_expert = torch.split(x[:n_real], split_size_or_sections=n_list, dim=0) | |
| is_fp8 = self.gate_up_proj.dtype == torch.float8_e4m3fn | |
| compute_dtype = x.dtype # bf16 in normal use | |
| outs = [] | |
| for i, xe in enumerate(x_per_expert): | |
| if is_fp8: | |
| w_gu = self.gate_up_proj[i].to(compute_dtype) * self.gate_up_proj_scale[i].to(compute_dtype) | |
| w_dn = self.down_proj[i].to(compute_dtype) * self.down_proj_scale[i].to(compute_dtype) | |
| else: | |
| w_gu = self.gate_up_proj[i] | |
| w_dn = self.down_proj[i] | |
| gate_up = torch.matmul(xe, w_gu) | |
| gate, up = gate_up.chunk(2, dim=-1) | |
| outs.append(torch.matmul(F.silu(gate) * up, w_dn)) | |
| out = torch.cat(outs, dim=0) | |
| return torch.vstack((out, out.new_zeros((n_pad, out.shape[-1])))) | |
| def patched_forward(self, x, num_tokens_per_expert): | |
| return patched_for_loop(self, x, num_tokens_per_expert) | |
| cls.__init__ = patched_init | |
| cls._run_experts_for_loop = patched_for_loop | |
| cls.forward = patched_forward | |
| setattr(cls, _PATCH_FLAG, True) | |
| def _patched_fp8_linear_forward(self, x): | |
| # Dequantize qdata fp8 + per-output-channel scale on the fly, then F.linear. | |
| w = self.weight.to(x.dtype) * self._fp8_scale.to(x.dtype) | |
| return F.linear(x, w, self.bias) | |
| def install_fp8_linear(linear: nn.Linear, fp8_qdata: torch.Tensor, scale: torch.Tensor) -> None: | |
| """Convert one Linear instance to fp8-stored, per-instance forward override.""" | |
| assert isinstance(linear, nn.Linear), f"expected nn.Linear, got {type(linear).__name__}" | |
| linear.weight = nn.Parameter(fp8_qdata, requires_grad=False) | |
| linear.register_buffer("_fp8_scale", scale.to(SCALE_DTYPE)) | |
| # Bind forward at the instance level — does NOT affect other nn.Linear instances. | |
| linear.forward = _patched_fp8_linear_forward.__get__(linear, type(linear)) | |
| def load_fp8_safetensors_transformer(safetensors_path: str, config_dir: str | None = None): | |
| """ | |
| Load a Nucleus-Image FP8 transformer from a single safetensors file. The safetensors | |
| holds raw fp8 + per-channel bf16 scale tensors (no torchao wrapper objects). | |
| SwiGLUExperts: weights `gate_up_proj` / `down_proj` (fp8) + buffers `*_scale` (bf16). | |
| Standard nn.Linear: weight (fp8) + buffer `_fp8_scale` (bf16) — forward override | |
| installed by `install_fp8_linear` so the dequantize+matmul happens inline. | |
| """ | |
| from pathlib import Path | |
| import json | |
| from safetensors.torch import safe_open | |
| from diffusers import AutoModel | |
| from accelerate import init_empty_weights | |
| apply_patch() | |
| cfg_dir = Path(config_dir) if config_dir else Path(safetensors_path).parent | |
| cfg = json.loads((cfg_dir / "config.json").read_text(encoding="utf-8")) | |
| # Locate the actual transformer class via the AutoModel mapping. | |
| from diffusers import NucleusMoEImageTransformer2DModel | |
| with init_empty_weights(): | |
| model = NucleusMoEImageTransformer2DModel.from_config(cfg) | |
| fp8_dtypes = (torch.float8_e4m3fn, torch.float8_e5m2) | |
| # Two passes over the file: | |
| # Pass 1: identify Linears that have an `_fp8_scale` sibling — install fp8 weight + scale + forward. | |
| # Pass 2: load everything else via standard assign. | |
| with safe_open(safetensors_path, framework="pt", device="cpu") as f: | |
| all_keys = set(f.keys()) | |
| # Linear keys with fp8 representation: those with a `_fp8_scale` buffer. | |
| linear_paths = set() | |
| for k in all_keys: | |
| if k.endswith("._fp8_scale"): | |
| base = k[: -len("._fp8_scale")] | |
| if (base + ".weight") in all_keys: | |
| linear_paths.add(base) | |
| # Pass 1: install fp8 Linears. | |
| for base in linear_paths: | |
| module = model.get_submodule(base) | |
| qdata = f.get_tensor(base + ".weight").to("cpu") | |
| scale = f.get_tensor(base + "._fp8_scale").to("cpu") | |
| bias = None | |
| if (base + ".bias") in all_keys: | |
| bias = f.get_tensor(base + ".bias").to("cpu").to(torch.bfloat16) | |
| # The init_empty_weights model has meta-device weights; replace cleanly. | |
| if module.bias is not None and bias is not None: | |
| module.bias = nn.Parameter(bias) | |
| install_fp8_linear(module, qdata, scale) | |
| # Pass 2: rest of the state dict (SwiGLUExperts fp8 + scales, norms, embeddings, etc.). | |
| for k in all_keys: | |
| # Skip keys we already handled in pass 1. | |
| if any(k == p + ".weight" or k == p + "._fp8_scale" or k == p + ".bias" for p in linear_paths): | |
| continue | |
| t = f.get_tensor(k).to("cpu") | |
| module_path, _, attr = k.rpartition(".") | |
| module = model.get_submodule(module_path) if module_path else model | |
| cur = getattr(module, attr, None) | |
| if isinstance(cur, nn.Parameter) or cur is None: | |
| # cur may be a meta-device Parameter — replace whole thing | |
| if t.dtype.is_floating_point and t.dtype not in fp8_dtypes and t.dtype != torch.bfloat16: | |
| t = t.to(torch.bfloat16) | |
| setattr(module, attr, nn.Parameter(t, requires_grad=False) if attr in dict(module.named_parameters(recurse=False)) or cur is None or isinstance(cur, nn.Parameter) else t) | |
| else: | |
| # Buffer | |
| if t.dtype.is_floating_point and t.dtype not in fp8_dtypes and t.dtype != torch.bfloat16: | |
| t = t.to(torch.bfloat16) | |
| module.register_buffer(attr, t) | |
| # Final cleanup: any param/buffer still on meta should never happen, but assert. | |
| meta_left = [n for n, p in model.named_parameters() if p.is_meta] | |
| if meta_left: | |
| raise RuntimeError(f"Some parameters still on meta device after load: {meta_left[:5]}... ({len(meta_left)} total)") | |
| print(f" load_fp8_safetensors_transformer: installed {len(linear_paths)} fp8 Linears") | |
| return model | |
| def load_fp8_transformer(model_dir: str): | |
| """ | |
| Load an FP8-quantized Nucleus transformer from `model_dir` while preserving on-disk | |
| fp8_e4m3fn dtypes for SwiGLUExperts weights. Other floating params are normalized to bf16. | |
| Why: AutoModel.from_pretrained(torch_dtype=torch.bfloat16) force-casts ALL floating | |
| weights including fp8. We do the standard load (which casts to bf16), then re-stream | |
| the on-disk shards and reassign every fp8_e4m3fn tensor as a fresh nn.Parameter so | |
| the dtype is preserved. TorchAO Float8Tensor wrappers for nn.Linear are restored | |
| correctly by the standard loader (its DiffusersAutoQuantizer hook intercepts them). | |
| """ | |
| import json | |
| from pathlib import Path | |
| from diffusers import AutoModel | |
| apply_patch() # ensure SwiGLUExperts is patched before construction | |
| model = AutoModel.from_pretrained( | |
| model_dir, | |
| use_safetensors=False, | |
| low_cpu_mem_usage=True, | |
| ) | |
| # Re-stream disk to recover fp8 dtypes (auto-cast lost them). | |
| md = Path(model_dir) | |
| idx_path = md / "diffusion_pytorch_model.bin.index.json" | |
| if idx_path.exists(): | |
| idx = json.loads(idx_path.read_text(encoding="utf-8")) | |
| files = sorted(set(idx["weight_map"].values())) | |
| else: | |
| files = ["diffusion_pytorch_model.bin"] | |
| fp8_dtypes = (torch.float8_e4m3fn, torch.float8_e5m2) | |
| fp8_reassigned = 0 | |
| for fname in files: | |
| shard = torch.load(md / fname, map_location="cpu", weights_only=False) | |
| for key, tensor in shard.items(): | |
| if not hasattr(tensor, "dtype") or tensor.dtype not in fp8_dtypes: | |
| continue | |
| module_path, _, attr = key.rpartition(".") | |
| module = model.get_submodule(module_path) | |
| cur = getattr(module, attr, None) | |
| if isinstance(cur, nn.Parameter): | |
| setattr(module, attr, nn.Parameter(tensor, requires_grad=False)) | |
| else: | |
| # buffer | |
| module.register_buffer(attr, tensor) | |
| fp8_reassigned += 1 | |
| del shard | |
| # Final pass: any non-fp8 floating param in fp32/fp16 → bf16 (uniformity). | |
| for _, p in model.named_parameters(): | |
| if p.dtype in fp8_dtypes: | |
| continue | |
| if p.dtype.is_floating_point and p.dtype != torch.bfloat16: | |
| p.data = p.data.to(torch.bfloat16) | |
| for _, b in model.named_buffers(): | |
| if b.dtype in fp8_dtypes: | |
| continue | |
| if b.dtype.is_floating_point and b.dtype != torch.bfloat16: | |
| b.data = b.data.to(torch.bfloat16) | |
| print(f" load_fp8_transformer: re-assigned {fp8_reassigned} fp8 tensors from shards") | |
| return model | |
| def quantize_swiglu_experts_(module: nn.Module) -> dict: | |
| """Quantize a single SwiGLUExperts module in-place. Returns a small report dict.""" | |
| assert type(module).__name__ == "SwiGLUExperts", f"expected SwiGLUExperts, got {type(module).__name__}" | |
| device = module.gate_up_proj.device | |
| def _quant(w: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: | |
| # w shape: (num_experts, in_dim, out_dim). Per-expert per-output-channel scale on dim=-2. | |
| w32 = w.detach().to(torch.float32) | |
| scale = w32.abs().amax(dim=-2, keepdim=True).clamp(min=1e-12) / FP8_E4M3_MAX | |
| q = (w32 / scale).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX).to(torch.float8_e4m3fn) | |
| return q, scale.to(SCALE_DTYPE) | |
| q_gu, s_gu = _quant(module.gate_up_proj.data) | |
| q_dn, s_dn = _quant(module.down_proj.data) | |
| # Reconstruction error gate. Use rel L2 (||W_hat - W||_2 / ||W||_2) per tensor — the right | |
| # metric for forward fidelity (matmul amplifies L2-style error, not max-abs). | |
| w_gu_orig = module.gate_up_proj.data.to(torch.float32) | |
| w_dn_orig = module.down_proj.data.to(torch.float32) | |
| rec_gu = q_gu.to(torch.float32) * s_gu.to(torch.float32) | |
| rec_dn = q_dn.to(torch.float32) * s_dn.to(torch.float32) | |
| rep = { | |
| "gu_rel_l2": ((rec_gu - w_gu_orig).norm() / w_gu_orig.norm().clamp(min=1e-12)).item(), | |
| "dn_rel_l2": ((rec_dn - w_dn_orig).norm() / w_dn_orig.norm().clamp(min=1e-12)).item(), | |
| # Keep the loose rel-max-err for visibility but no longer used as a gate. | |
| "gu_rel_max": ((rec_gu - w_gu_orig).abs().amax() / w_gu_orig.abs().amax().clamp(min=1e-12)).item(), | |
| "dn_rel_max": ((rec_dn - w_dn_orig).abs().amax() / w_dn_orig.abs().amax().clamp(min=1e-12)).item(), | |
| } | |
| # In-place replacement | |
| module.gate_up_proj = nn.Parameter(q_gu.to(device), requires_grad=False) | |
| module.down_proj = nn.Parameter(q_dn.to(device), requires_grad=False) | |
| module.gate_up_proj_scale.data = s_gu.to(device) | |
| module.down_proj_scale.data = s_dn.to(device) | |
| return rep | |