# Provenance: uploaded by David-PHR via huggingface_hub (commit 9e016c4, verified)
import torch
import torch.nn as nn
from typing import Callable, Union, Dict, Any
from torchao.quantization import quantize_, PerTensor, Float8StaticActivationFloat8WeightConfig
# Compatibility shim: newer torchao releases expose FqnToConfig, while older
# ones use the name ModuleFqnToConfig — alias whichever is available.
try:
    from torchao.quantization import FqnToConfig
except ImportError:
    from torchao.quantization import ModuleFqnToConfig as FqnToConfig
def _parse_fp8_dtype(dtype_str: str) -> torch.dtype:
    """Map a serialized dtype string (e.g. "torch.float8_e4m3fn") to a torch dtype."""
    if "float8_e4m3fn" in dtype_str:
        return torch.float8_e4m3fn
    if "float8_e5m2" in dtype_str:
        return torch.float8_e5m2
    raise ValueError(f"Unsupported fp8 dtype string: {dtype_str}")


def _normalize_scales(raw: Dict[str, Any]) -> Dict[str, torch.Tensor]:
    """Coerce each activation scale to a 0-dim float32 CPU tensor.

    Accepts tensors of any shape/dtype (first element is taken) or plain
    Python numbers.
    """
    scales: Dict[str, torch.Tensor] = {}
    for key, value in raw.items():
        if torch.is_tensor(value):
            scales[key] = value.detach().to(torch.float32).reshape(-1)[0]
        else:
            scales[key] = torch.tensor(float(value), dtype=torch.float32)
    return scales


def load_torchao_fp8_static_model(
    *,
    ckpt_path: str,
    base_model_or_factory: Union[nn.Module, Callable[[], nn.Module]],
    device: str = "cuda",
    strict: bool = True,
) -> nn.Module:
    """Rebuild a torchao FP8 static-activation quantized model from a checkpoint.

    The checkpoint (produced elsewhere) must be a dict with keys
    ``state_dict``, ``act_scales`` (per-Linear activation scales keyed by
    module FQN) and ``fp8_dtype`` (a dtype string).

    Args:
        ckpt_path: Path to the ``torch.save``-d checkpoint.
        base_model_or_factory: An instantiated model, or a zero-arg callable
            that builds one with the same architecture as the checkpoint.
        device: Target device for the returned model.
        strict: Forwarded to ``load_state_dict``; also enforced on the
            missing/unexpected key lists from the fallback path.

    Returns:
        The quantized model, in eval mode, on ``device``.

    Raises:
        ValueError: Missing checkpoint keys or an unsupported fp8 dtype.
        TypeError: The factory did not return an ``nn.Module``.
        RuntimeError: No activation-scale key matches any Linear layer, or a
            strict state-dict mismatch.
    """
    # NOTE(security): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    ckpt: Dict[str, Any] = torch.load(ckpt_path, map_location="cpu")
    if not all(k in ckpt for k in ("state_dict", "act_scales", "fp8_dtype")):
        raise ValueError(f"Checkpoint missing required keys. Found: {list(ckpt.keys())}")

    fp8_dtype = _parse_fp8_dtype(str(ckpt["fp8_dtype"]))
    act_scales_raw = _normalize_scales(ckpt["act_scales"])

    # -------------------------
    # Build model
    # -------------------------
    if isinstance(base_model_or_factory, nn.Module):
        model = base_model_or_factory
    else:
        model = base_model_or_factory()
        if not isinstance(model, nn.Module):
            raise TypeError("base_model_or_factory must return an nn.Module")
    model.eval().to(device)

    # -------------------------
    # Collect Linear FQNs
    # -------------------------
    linear_fqns = [fqn for fqn, m in model.named_modules() if isinstance(m, nn.Linear)]
    linear_set = set(linear_fqns)

    # -------------------------
    # Auto-fix FQN prefix mismatch: scale keys may or may not carry a
    # leading "model." relative to this model's module names. Try identity,
    # stripped, and prefixed variants and keep the best-matching one.
    # -------------------------
    def score(keys):
        return sum(1 for k in keys if k in linear_set)

    candidates = [
        act_scales_raw,  # identity
        {k[len("model."):]: v for k, v in act_scales_raw.items() if k.startswith("model.")},
        {f"model.{k}": v for k, v in act_scales_raw.items()},
    ]
    act_scales = max(candidates, key=lambda d: score(d.keys()))
    if score(act_scales.keys()) == 0:
        raise RuntimeError(
            "Could not match any activation scale keys to Linear layers.\n"
            f"Example Linear FQNs:\n{linear_fqns[:20]}\n\n"
            f"Example scale keys:\n{list(act_scales_raw.keys())[:20]}"
        )

    # -------------------------
    # Build torchao config map (one static-FP8 config per matched Linear)
    # -------------------------
    fqn_to_cfg = {
        fqn: Float8StaticActivationFloat8WeightConfig(
            scale=act_scales[fqn],
            activation_dtype=fp8_dtype,
            weight_dtype=fp8_dtype,
            granularity=PerTensor(),
        )
        for fqn in linear_fqns
        if fqn in act_scales
    }
    if not fqn_to_cfg:
        raise RuntimeError("No Linear layers matched activation scales.")
    try:
        cfg = FqnToConfig(fqn_to_config=fqn_to_cfg)
    except TypeError:
        # Older torchao: positional-only constructor.
        cfg = FqnToConfig(fqn_to_cfg)

    # -------------------------
    # Quantize structure first so the state dict's quantized tensors fit
    # -------------------------
    quantize_(model, cfg, filter_fn=None, device=device)

    # -------------------------
    # Load weights (CRITICAL: assign=True avoids copy_ dispatch errors on
    # quantized tensor subclasses)
    # -------------------------
    try:
        missing, unexpected = model.load_state_dict(
            ckpt["state_dict"],
            strict=strict,
            assign=True,
        )
    except TypeError:
        # Fallback for PyTorch < 2.1 (no assign=): install tensors manually.
        # Hoist the name->module map out of the loop (was rebuilt per tensor).
        modules = dict(model.named_modules())
        for name, tensor in ckpt["state_dict"].items():
            # rpartition handles top-level names with no dot ("" -> root module).
            module_name, _, attr = name.rpartition(".")
            mod = modules[module_name] if module_name else model
            if isinstance(getattr(mod, attr, None), nn.Parameter):
                setattr(mod, attr, nn.Parameter(tensor, requires_grad=False))
            else:
                setattr(mod, attr, tensor)
        missing, unexpected = [], []
    if strict and (missing or unexpected):
        raise RuntimeError(f"load_state_dict mismatch. missing={missing} unexpected={unexpected}")

    # assign=True adopts the checkpoint's CPU tensor storage, which undoes
    # the earlier .to(device) — move the loaded model to the target device.
    model.to(device)
    return model