Upload folder using huggingface_hub
Browse files
load_torchao.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from typing import Callable, Union, Dict, Any
|
| 4 |
+
|
| 5 |
+
from torchao.quantization import quantize_, PerTensor, Float8StaticActivationFloat8WeightConfig
|
| 6 |
+
try:
|
| 7 |
+
from torchao.quantization import FqnToConfig
|
| 8 |
+
except ImportError:
|
| 9 |
+
from torchao.quantization import ModuleFqnToConfig as FqnToConfig
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def load_torchao_fp8_static_model(
    *,
    ckpt_path: str,
    base_model_or_factory: Union[nn.Module, Callable[[], nn.Module]],
    device: str = "cuda",
    strict: bool = True,
) -> nn.Module:
    """Rebuild a torchao fp8 static-activation quantized model from a checkpoint.

    The checkpoint must be a ``torch.save``-d dict with three keys:

    - ``"state_dict"``: the quantized model weights,
    - ``"act_scales"``: per-Linear activation scales keyed by module FQN,
    - ``"fp8_dtype"``:  a string naming the fp8 dtype used at export time.

    Args:
        ckpt_path: Path to the checkpoint file.
        base_model_or_factory: An already-built model, or a zero-arg callable
            that constructs one.
        device: Target device for the returned model.
        strict: Whether a state-dict key mismatch is an error.

    Returns:
        The quantized model, in eval mode, on ``device``.

    Raises:
        ValueError: Missing checkpoint keys or unrecognized fp8 dtype string.
        TypeError: The factory did not yield an ``nn.Module``.
        RuntimeError: Scale keys cannot be matched to any Linear layer, or
            (with ``strict=True``) the state dict does not match the model.
    """
    # SECURITY: torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    ckpt: Dict[str, Any] = torch.load(ckpt_path, map_location="cpu")

    if not all(k in ckpt for k in ("state_dict", "act_scales", "fp8_dtype")):
        raise ValueError(f"Checkpoint missing required keys. Found: {list(ckpt.keys())}")

    # -------------------------
    # Parse dtype (substring match tolerates reprs like "torch.float8_e4m3fn")
    # -------------------------
    dtype_str = str(ckpt["fp8_dtype"])
    if "float8_e4m3fn" in dtype_str:
        fp8_dtype = torch.float8_e4m3fn
    elif "float8_e5m2" in dtype_str:
        fp8_dtype = torch.float8_e5m2
    else:
        raise ValueError(f"Unsupported fp8 dtype string: {dtype_str}")

    # -------------------------
    # Normalize scales to fp32 scalar tensors
    # -------------------------
    act_scales_raw = {}
    for k, v in ckpt["act_scales"].items():
        if torch.is_tensor(v):
            # Flatten so 0-d, 1-element, and broadcast-shaped scales all work.
            act_scales_raw[k] = v.detach().to(torch.float32).reshape(-1)[0]
        else:
            act_scales_raw[k] = torch.tensor(float(v), dtype=torch.float32)

    # -------------------------
    # Build model
    # -------------------------
    if isinstance(base_model_or_factory, nn.Module):
        model = base_model_or_factory
    else:
        model = base_model_or_factory()

    if model is None or not isinstance(model, nn.Module):
        raise TypeError("base_model_or_factory must return an nn.Module")

    model.eval().to(device)

    # -------------------------
    # Collect Linear FQNs
    # -------------------------
    linear_fqns = [fqn for fqn, m in model.named_modules() if isinstance(m, nn.Linear)]
    linear_set = set(linear_fqns)

    # -------------------------
    # Auto-fix FQN prefix mismatch: scale keys may or may not carry a
    # leading "model." depending on how the checkpoint was exported.
    # -------------------------
    def score(keys):
        return sum(1 for k in keys if k in linear_set)

    candidates = [
        # 1) identity
        act_scales_raw,
        # 2) strip "model." — keys without the prefix are kept unchanged so
        #    mixed-prefix checkpoints are not silently dropped (the old code
        #    discarded every key that lacked the prefix)
        {(k[6:] if k.startswith("model.") else k): v for k, v in act_scales_raw.items()},
        # 3) add "model."
        {("model." + k): v for k, v in act_scales_raw.items()},
    ]

    best = max(candidates, key=lambda d: score(d.keys()))
    best_score = score(best.keys())
    if best_score == 0:
        raise RuntimeError(
            "Could not match any activation scale keys to Linear layers.\n"
            f"Example Linear FQNs:\n{linear_fqns[:20]}\n\n"
            f"Example scale keys:\n{list(act_scales_raw.keys())[:20]}"
        )

    act_scales = best

    # -------------------------
    # Build torchao config map (only Linears that have a scale get quantized)
    # -------------------------
    fqn_to_cfg = {}
    for fqn in linear_fqns:
        if fqn in act_scales:
            fqn_to_cfg[fqn] = Float8StaticActivationFloat8WeightConfig(
                scale=act_scales[fqn],
                activation_dtype=fp8_dtype,
                weight_dtype=fp8_dtype,
                granularity=PerTensor(),
            )

    if not fqn_to_cfg:
        raise RuntimeError("No Linear layers matched activation scales.")

    try:
        cfg = FqnToConfig(fqn_to_config=fqn_to_cfg)
    except TypeError:
        # Older torchao (ModuleFqnToConfig) takes the mapping positionally.
        cfg = FqnToConfig(fqn_to_cfg)

    # -------------------------
    # Quantize the structure first so the checkpoint's tensor subclasses
    # line up with the module parameters.
    # -------------------------
    quantize_(model, cfg, filter_fn=None, device=device)

    # -------------------------
    # Load weights. assign=True swaps the checkpoint tensors in directly
    # instead of copy_-ing into existing ones, which would fail to dispatch
    # on the quantized tensor subclasses.
    # -------------------------
    try:
        missing, unexpected = model.load_state_dict(
            ckpt["state_dict"],
            strict=strict,
            assign=True,
        )
    except TypeError:
        # Fallback for PyTorch without `assign` (< 2.1): swap tensors by hand.
        modules = dict(model.named_modules())  # build once, not per tensor
        for name, tensor in ckpt["state_dict"].items():
            # rpartition handles dot-less top-level names ("" maps to the
            # root module in named_modules); rsplit would raise on them.
            module_name, _, attr = name.rpartition(".")
            mod = modules[module_name]
            if isinstance(getattr(mod, attr), nn.Parameter):
                setattr(mod, attr, nn.Parameter(tensor, requires_grad=False))
            else:
                setattr(mod, attr, tensor)
        missing, unexpected = [], []

    if strict and (missing or unexpected):
        raise RuntimeError(f"load_state_dict mismatch. missing={missing} unexpected={unexpected}")

    # BUGFIX: assign=True installs the checkpoint's CPU tensors directly,
    # undoing the earlier .to(device) for every loaded weight — move again.
    model.to(device)

    return model
|
| 144 |
+
|
transformer_bf16/config.json
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "Flux2Transformer2DModel",
|
| 3 |
+
"_diffusers_version": "0.37.0.dev0",
|
| 4 |
+
"attention_head_dim": 128,
|
| 5 |
+
"axes_dims_rope": [
|
| 6 |
+
32,
|
| 7 |
+
32,
|
| 8 |
+
32,
|
| 9 |
+
32
|
| 10 |
+
],
|
| 11 |
+
"eps": 1e-06,
|
| 12 |
+
"guidance_embeds": false,
|
| 13 |
+
"in_channels": 128,
|
| 14 |
+
"joint_attention_dim": 7680,
|
| 15 |
+
"mlp_ratio": 3.0,
|
| 16 |
+
"num_attention_heads": 24,
|
| 17 |
+
"num_layers": 5,
|
| 18 |
+
"num_single_layers": 20,
|
| 19 |
+
"out_channels": null,
|
| 20 |
+
"patch_size": 1,
|
| 21 |
+
"rope_theta": 2000,
|
| 22 |
+
"timestep_guidance_channels": 256
|
| 23 |
+
}
|
transformer_bf16/diffusion_pytorch_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1fa71fe800721fd1d3184a41ce0d8938b1c7d393a70247d9630bd0b8f3d60a85
|
| 3 |
+
size 7751109744
|
transformer_fp8_static/config.json
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "Flux2Transformer2DModel",
|
| 3 |
+
"_diffusers_version": "0.37.0.dev0",
|
| 4 |
+
"attention_head_dim": 128,
|
| 5 |
+
"axes_dims_rope": [
|
| 6 |
+
32,
|
| 7 |
+
32,
|
| 8 |
+
32,
|
| 9 |
+
32
|
| 10 |
+
],
|
| 11 |
+
"eps": 1e-06,
|
| 12 |
+
"guidance_embeds": false,
|
| 13 |
+
"in_channels": 128,
|
| 14 |
+
"joint_attention_dim": 7680,
|
| 15 |
+
"mlp_ratio": 3.0,
|
| 16 |
+
"num_attention_heads": 24,
|
| 17 |
+
"num_layers": 5,
|
| 18 |
+
"num_single_layers": 20,
|
| 19 |
+
"out_channels": null,
|
| 20 |
+
"patch_size": 1,
|
| 21 |
+
"rope_theta": 2000,
|
| 22 |
+
"timestep_guidance_channels": 256
|
| 23 |
+
}
|
transformer_fp8_static/model_fp8_static.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:13b37a5ca5cd9cf190236e7e99a3f086cf24618682f74e27a6f00cb173c308c8
|
| 3 |
+
size 4070791292
|