import torch
import torch.nn as nn
from typing import Callable, Union, Dict, Any
from torchao.quantization import quantize_, PerTensor, Float8StaticActivationFloat8WeightConfig
try:
from torchao.quantization import FqnToConfig
except ImportError:
from torchao.quantization import ModuleFqnToConfig as FqnToConfig
def load_torchao_fp8_static_model(
    *,
    ckpt_path: str,
    base_model_or_factory: Union[nn.Module, Callable[[], nn.Module]],
    device: str = "cuda",
    strict: bool = True,
) -> nn.Module:
    """Rebuild an FP8 static-activation quantized model from a checkpoint.

    The checkpoint must be a dict containing:
      - ``"state_dict"``: the quantized model weights,
      - ``"act_scales"``: per-layer activation scales keyed by Linear FQN,
      - ``"fp8_dtype"``:  the fp8 dtype used (string or ``torch.dtype``).

    Args:
        ckpt_path: Path to the ``torch.save``'d checkpoint.
        base_model_or_factory: An instantiated (unquantized) model, or a
            zero-arg callable that builds one.
        device: Device the model is moved to before quantization.
        strict: Forwarded to ``load_state_dict``; also makes any residual
            key mismatch raise.

    Returns:
        The quantized ``nn.Module`` with weights loaded, in eval mode.

    Raises:
        ValueError: Required checkpoint keys are missing, or the fp8 dtype
            string is unrecognized.
        TypeError: The factory did not produce an ``nn.Module``.
        RuntimeError: No scale keys match any Linear layer, or key mismatch
            under ``strict=True``.
    """
    # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True, which
    # may reject non-tensor payloads in this checkpoint. If loading fails on a
    # trusted checkpoint, pass weights_only=False explicitly — confirm against
    # the export script that produced these files.
    ckpt: Dict[str, Any] = torch.load(ckpt_path, map_location="cpu")
    if not all(k in ckpt for k in ("state_dict", "act_scales", "fp8_dtype")):
        raise ValueError(f"Checkpoint missing required keys. Found: {list(ckpt.keys())}")

    fp8_dtype = _parse_fp8_dtype(ckpt["fp8_dtype"])
    act_scales_raw = _normalize_act_scales(ckpt["act_scales"])
    model = _materialize_model(base_model_or_factory, device)

    linear_fqns = [fqn for fqn, m in model.named_modules() if isinstance(m, nn.Linear)]
    act_scales = _resolve_scale_prefix(act_scales_raw, linear_fqns)

    # Map each matched Linear FQN to a per-tensor static-activation FP8 config.
    fqn_to_cfg = {
        fqn: Float8StaticActivationFloat8WeightConfig(
            scale=act_scales[fqn],
            activation_dtype=fp8_dtype,
            weight_dtype=fp8_dtype,
            granularity=PerTensor(),
        )
        for fqn in linear_fqns
        if fqn in act_scales
    }
    if not fqn_to_cfg:
        raise RuntimeError("No Linear layers matched activation scales.")

    # Newer torchao takes the keyword 'fqn_to_config'; older versions accept
    # the mapping positionally.
    try:
        cfg = FqnToConfig(fqn_to_config=fqn_to_cfg)
    except TypeError:
        cfg = FqnToConfig(fqn_to_cfg)

    # Quantize the module structure first so the quantized tensor subclasses
    # exist before weights are loaded into them.
    quantize_(model, cfg, filter_fn=None, device=device)

    # assign=True swaps tensors instead of copying; copy_ dispatch fails when
    # copying plain tensors into quantized tensor subclasses.
    try:
        missing, unexpected = model.load_state_dict(
            ckpt["state_dict"],
            strict=strict,
            assign=True,
        )
    except TypeError:
        # PyTorch < 2.1 has no 'assign' kwarg — set attributes directly.
        _assign_state_dict_fallback(model, ckpt["state_dict"])
        missing, unexpected = [], []

    if strict and (missing or unexpected):
        raise RuntimeError(f"load_state_dict mismatch. missing={missing} unexpected={unexpected}")
    return model


def _parse_fp8_dtype(raw: Any) -> torch.dtype:
    """Map a checkpoint dtype entry (string or torch.dtype) to a torch fp8 dtype."""
    dtype_str = str(raw)
    if "float8_e4m3fn" in dtype_str:
        return torch.float8_e4m3fn
    if "float8_e5m2" in dtype_str:
        return torch.float8_e5m2
    raise ValueError(f"Unsupported fp8 dtype string: {dtype_str}")


def _normalize_act_scales(raw: Dict[str, Any]) -> Dict[str, torch.Tensor]:
    """Convert every scale value to a 0-dim float32 CPU tensor."""
    out: Dict[str, torch.Tensor] = {}
    for key, val in raw.items():
        if torch.is_tensor(val):
            # Flatten then take element 0 so scalar, shape-(1,) and other
            # per-tensor scale layouts are handled uniformly.
            out[key] = val.detach().to(torch.float32).reshape(-1)[0]
        else:
            out[key] = torch.tensor(float(val), dtype=torch.float32)
    return out


def _materialize_model(
    base_model_or_factory: Union[nn.Module, Callable[[], nn.Module]],
    device: str,
) -> nn.Module:
    """Return the base model (calling the factory if given), in eval mode on device."""
    if isinstance(base_model_or_factory, nn.Module):
        model = base_model_or_factory
    else:
        model = base_model_or_factory()
        # isinstance(None, nn.Module) is False, so this also rejects None.
        if not isinstance(model, nn.Module):
            raise TypeError("base_model_or_factory must return an nn.Module")
    model.eval().to(device)
    return model


def _resolve_scale_prefix(
    act_scales_raw: Dict[str, torch.Tensor],
    linear_fqns: "list[str]",
) -> Dict[str, torch.Tensor]:
    """Pick the key remapping (identity / strip 'model.' / add 'model.') that
    matches the most Linear FQNs; raise if none match at all."""
    linear_set = set(linear_fqns)

    def hits(d: Dict[str, torch.Tensor]) -> int:
        return sum(1 for k in d if k in linear_set)

    prefix = "model."
    candidates = [
        act_scales_raw,  # 1) keys already match
        {k[len(prefix):]: v for k, v in act_scales_raw.items() if k.startswith(prefix)},
        {prefix + k: v for k, v in act_scales_raw.items()},
    ]
    # max() keeps the earliest candidate on ties, so identity wins when equal.
    best = max(candidates, key=hits)
    if hits(best) == 0:
        raise RuntimeError(
            "Could not match any activation scale keys to Linear layers.\n"
            f"Example Linear FQNs:\n{linear_fqns[:20]}\n\n"
            f"Example scale keys:\n{list(act_scales_raw.keys())[:20]}"
        )
    return best


def _assign_state_dict_fallback(model: nn.Module, state_dict: Dict[str, Any]) -> None:
    """Emulate ``load_state_dict(assign=True)`` on PyTorch < 2.1 by setting
    each tensor directly on its owning module."""
    modules = dict(model.named_modules())  # built once, not per key (was O(n^2))
    for name, tensor in state_dict.items():
        # Keys without a dot belong to the root module itself; the original
        # unconditional rsplit(".", 1) crashed on such keys.
        if "." in name:
            module_name, attr = name.rsplit(".", 1)
            mod = modules[module_name]
        else:
            mod, attr = model, name
        # Default of None avoids AttributeError for not-yet-set attributes.
        if isinstance(getattr(mod, attr, None), nn.Parameter):
            setattr(mod, attr, nn.Parameter(tensor, requires_grad=False))
        else:
            setattr(mod, attr, tensor)