"""Utilities for restoring fp8 static-activation quantized models saved with torchao."""

import torch
import torch.nn as nn
from typing import Callable, Union, Dict, Any, List, Tuple

from torchao.quantization import (
    quantize_,
    PerTensor,
    Float8StaticActivationFloat8WeightConfig,
)

try:
    # Newer torchao renamed ModuleFqnToConfig -> FqnToConfig.
    from torchao.quantization import FqnToConfig
except ImportError:
    from torchao.quantization import ModuleFqnToConfig as FqnToConfig

# Checkpoints produced by some exporters prefix every key with "model.".
_MODEL_PREFIX = "model."


def _parse_fp8_dtype(dtype_str: str) -> torch.dtype:
    """Map a serialized dtype string to the matching torch fp8 dtype.

    Raises:
        ValueError: if the string names neither e4m3fn nor e5m2.
    """
    if "float8_e4m3fn" in dtype_str:
        return torch.float8_e4m3fn
    if "float8_e5m2" in dtype_str:
        return torch.float8_e5m2
    raise ValueError(f"Unsupported fp8 dtype string: {dtype_str}")


def _normalize_scales(raw: Dict[str, Any]) -> Dict[str, torch.Tensor]:
    """Coerce every activation scale to a scalar float32 tensor.

    Tensors are detached and reduced to their first element; plain numbers
    are wrapped in a 0-d fp32 tensor.
    """
    normalized: Dict[str, torch.Tensor] = {}
    for key, value in raw.items():
        if torch.is_tensor(value):
            normalized[key] = value.detach().to(torch.float32).reshape(-1)[0]
        else:
            normalized[key] = torch.tensor(float(value), dtype=torch.float32)
    return normalized


def _match_scale_keys(
    act_scales_raw: Dict[str, torch.Tensor],
    linear_fqns: List[str],
) -> Dict[str, torch.Tensor]:
    """Auto-fix FQN prefix mismatches between scale keys and Linear FQNs.

    Tries the keys verbatim, with the "model." prefix stripped, and with it
    added, keeping whichever variant matches the most Linear layers.

    Raises:
        RuntimeError: if no variant matches any Linear layer at all.
    """
    linear_set = set(linear_fqns)

    def score(d: Dict[str, torch.Tensor]) -> int:
        return sum(1 for k in d if k in linear_set)

    candidates = [
        # 1) identity
        act_scales_raw,
        # 2) strip "model."
        {
            k[len(_MODEL_PREFIX):]: v
            for k, v in act_scales_raw.items()
            if k.startswith(_MODEL_PREFIX)
        },
        # 3) add "model."
        {_MODEL_PREFIX + k: v for k, v in act_scales_raw.items()},
    ]
    best = max(candidates, key=score)
    if score(best) == 0:
        raise RuntimeError(
            "Could not match any activation scale keys to Linear layers.\n"
            f"Example Linear FQNs:\n{linear_fqns[:20]}\n\n"
            f"Example scale keys:\n{list(act_scales_raw.keys())[:20]}"
        )
    return best


def _load_weights(
    model: nn.Module,
    state_dict: Dict[str, torch.Tensor],
    *,
    strict: bool,
) -> Tuple[List[str], List[str]]:
    """Load checkpoint tensors into the quantized model.

    Uses ``assign=True`` so quantized parameters are swapped in wholesale
    instead of being written through ``copy_`` (which can't dispatch on
    torchao tensor subclasses). Falls back to a manual swap on PyTorch
    versions without ``assign``.

    Returns:
        (missing_keys, unexpected_keys) relative to the model's state dict.
    """
    try:
        result = model.load_state_dict(state_dict, strict=strict, assign=True)
        return list(result.missing_keys), list(result.unexpected_keys)
    except TypeError:
        # PyTorch too old for assign=; replace tensors attribute-by-attribute.
        # Hoisted out of the loop: the original rebuilt this dict per key (O(n*m)).
        modules = dict(model.named_modules())
        expected = set(model.state_dict().keys())
        missing = sorted(expected - state_dict.keys())
        unexpected = sorted(k for k in state_dict if k not in expected)
        for name, tensor in state_dict.items():
            if name not in expected:
                # Skip keys the model doesn't have; previously this raised
                # KeyError even with strict=False.
                continue
            if "." in name:
                module_name, attr = name.rsplit(".", 1)
                mod = modules[module_name]
            else:
                # Top-level tensors ("pos_emb", ...) have no dot; the original
                # rsplit unpack crashed on them.
                mod, attr = model, name
            if isinstance(getattr(mod, attr), nn.Parameter):
                setattr(mod, attr, nn.Parameter(tensor, requires_grad=False))
            else:
                setattr(mod, attr, tensor)
        return missing, unexpected


def load_torchao_fp8_static_model(
    *,
    ckpt_path: str,
    base_model_or_factory: Union[nn.Module, Callable[[], nn.Module]],
    device: str = "cuda",
    strict: bool = True,
) -> nn.Module:
    """Rebuild an fp8 static-activation quantized model from a checkpoint.

    The checkpoint must be a ``torch.save`` dict with keys ``state_dict``,
    ``act_scales`` (per-Linear activation scales) and ``fp8_dtype``.

    Args:
        ckpt_path: Path to the checkpoint file.
        base_model_or_factory: An already-built module, or a zero-arg
            callable that constructs the float model.
        device: Device the model is moved to before quantization.
        strict: Whether state-dict key mismatches raise.

    Returns:
        The quantized model with checkpoint weights loaded, in eval mode.

    Raises:
        ValueError: checkpoint keys missing, or fp8 dtype string unknown.
        TypeError: the factory did not produce an ``nn.Module``.
        RuntimeError: no scale key matches a Linear layer, no Linear layer
            received a config, or key mismatch under ``strict=True``.
    """
    # SECURITY: torch.load deserializes via pickle -- only load checkpoints
    # from trusted sources.
    ckpt: Dict[str, Any] = torch.load(ckpt_path, map_location="cpu")
    if not all(k in ckpt for k in ("state_dict", "act_scales", "fp8_dtype")):
        raise ValueError(f"Checkpoint missing required keys. Found: {list(ckpt.keys())}")

    fp8_dtype = _parse_fp8_dtype(str(ckpt["fp8_dtype"]))
    act_scales_raw = _normalize_scales(ckpt["act_scales"])

    # Build (or accept) the float model.
    if isinstance(base_model_or_factory, nn.Module):
        model = base_model_or_factory
    else:
        model = base_model_or_factory()
        if model is None or not isinstance(model, nn.Module):
            raise TypeError("base_model_or_factory must return an nn.Module")
    model.eval().to(device)

    # Collect Linear FQNs and reconcile scale-key prefixes against them.
    linear_fqns = [
        fqn for fqn, m in model.named_modules() if isinstance(m, nn.Linear)
    ]
    act_scales = _match_scale_keys(act_scales_raw, linear_fqns)

    # One static-fp8 config per Linear that has a matching activation scale.
    fqn_to_cfg = {
        fqn: Float8StaticActivationFloat8WeightConfig(
            scale=act_scales[fqn],
            activation_dtype=fp8_dtype,
            weight_dtype=fp8_dtype,
            granularity=PerTensor(),
        )
        for fqn in linear_fqns
        if fqn in act_scales
    }
    if not fqn_to_cfg:
        raise RuntimeError("No Linear layers matched activation scales.")

    try:
        cfg = FqnToConfig(fqn_to_config=fqn_to_cfg)
    except TypeError:
        # Older torchao takes the mapping positionally.
        cfg = FqnToConfig(fqn_to_cfg)

    # Quantize the module structure first so fp8 weights can be assigned
    # into the already-quantized parameters.
    quantize_(model, cfg, filter_fn=None, device=device)

    missing, unexpected = _load_weights(model, ckpt["state_dict"], strict=strict)
    if strict and (missing or unexpected):
        raise RuntimeError(
            f"load_state_dict mismatch. missing={missing} unexpected={unexpected}"
        )
    return model