import torch
import torch.nn as nn
from typing import Any, Callable, Dict, Union

from torchao.quantization import (
    Float8StaticActivationFloat8WeightConfig,
    PerTensor,
    quantize_,
)

try:
    from torchao.quantization import FqnToConfig
except ImportError:
    # Older torchao releases expose the same mapping config under a
    # different name; alias it so the rest of the file is version-agnostic.
    from torchao.quantization import ModuleFqnToConfig as FqnToConfig
|
def _parse_fp8_dtype(dtype_str: str) -> torch.dtype:
    """Map a serialized dtype string to the concrete torch fp8 dtype.

    Raises:
        ValueError: If the string names neither e4m3fn nor e5m2.
    """
    if "float8_e4m3fn" in dtype_str:
        return torch.float8_e4m3fn
    if "float8_e5m2" in dtype_str:
        return torch.float8_e5m2
    raise ValueError(f"Unsupported fp8 dtype string: {dtype_str}")


def _normalize_act_scales(raw: Dict[str, Any]) -> Dict[str, torch.Tensor]:
    """Coerce every activation scale to a scalar float32 tensor.

    Tensors are detached and reduced to their first element (per-tensor
    scales may have been saved with a spurious shape); plain numbers are
    wrapped in a 0-dim tensor.
    """
    scales: Dict[str, torch.Tensor] = {}
    for key, value in raw.items():
        if torch.is_tensor(value):
            scales[key] = value.detach().to(torch.float32).reshape(-1)[0]
        else:
            scales[key] = torch.tensor(float(value), dtype=torch.float32)
    return scales


def _match_scale_keys(
    act_scales_raw: Dict[str, torch.Tensor],
    linear_set: set,
    linear_fqns: list,
) -> Dict[str, torch.Tensor]:
    """Pick the scale-key remapping that matches the most Linear FQNs.

    Tries the keys as-is, with a leading "model." prefix stripped, and with
    the prefix added, to tolerate checkpoints saved from a wrapped/unwrapped
    module relative to the model being loaded.

    Raises:
        RuntimeError: If no candidate remapping matches any Linear layer.
    """
    def score(keys) -> int:
        return sum(1 for k in keys if k in linear_set)

    prefix = "model."
    candidates = [
        # 1) identity mapping
        act_scales_raw,
        # 2) strip the prefix (checkpoint saved from a wrapper module)
        {k[len(prefix):]: v for k, v in act_scales_raw.items() if k.startswith(prefix)},
        # 3) add the prefix (checkpoint saved from the bare inner module)
        {prefix + k: v for k, v in act_scales_raw.items()},
    ]

    best = max(candidates, key=lambda d: score(d.keys()))
    if score(best.keys()) == 0:
        raise RuntimeError(
            "Could not match any activation scale keys to Linear layers.\n"
            f"Example Linear FQNs:\n{linear_fqns[:20]}\n\n"
            f"Example scale keys:\n{list(act_scales_raw.keys())[:20]}"
        )
    return best


def _assign_state_dict_fallback(model: nn.Module, state_dict: Dict[str, Any]) -> None:
    """Manually assign checkpoint tensors onto the model's modules.

    Fallback for torch versions whose ``load_state_dict`` lacks the
    ``assign=`` keyword: replaces each parameter/buffer object outright so
    quantized tensor subclasses are preserved instead of being copied into
    pre-existing storage.
    """
    # Build the FQN -> module table once (the original rebuilt it per tensor,
    # which is O(modules * tensors)). named_modules() includes ("", model),
    # so top-level names resolve to the model itself.
    modules = dict(model.named_modules())
    for name, tensor in state_dict.items():
        # rpartition (unlike rsplit-unpack) is safe for dot-free top-level
        # names: module_name becomes "" and resolves to the root module.
        module_name, _, attr = name.rpartition(".")
        mod = modules[module_name]
        if isinstance(getattr(mod, attr), nn.Parameter):
            setattr(mod, attr, nn.Parameter(tensor, requires_grad=False))
        else:
            setattr(mod, attr, tensor)


def load_torchao_fp8_static_model(
    *,
    ckpt_path: str,
    base_model_or_factory: Union[nn.Module, Callable[[], nn.Module]],
    device: str = "cuda",
    strict: bool = True,
) -> nn.Module:
    """Rebuild an fp8 static-activation quantized model from a checkpoint.

    The checkpoint (a ``torch.save`` dict) must carry ``state_dict``,
    ``act_scales`` (per-Linear activation scales) and ``fp8_dtype``. The base
    model is quantized with torchao's static fp8 config per matching Linear
    layer, then the saved weights are loaded over it.

    Args:
        ckpt_path: Path to the saved checkpoint.
        base_model_or_factory: An ``nn.Module`` instance, or a zero-argument
            callable that constructs one.
        device: Device the model is moved to before quantization.
        strict: Forwarded to ``load_state_dict``; when True, key mismatches
            raise ``RuntimeError``.

    Returns:
        The quantized model in eval mode with checkpoint weights loaded.

    Raises:
        ValueError: Missing checkpoint keys or unsupported fp8 dtype string.
        TypeError: The factory did not produce an ``nn.Module``.
        RuntimeError: No scale key matched any Linear layer, or a strict
            state-dict mismatch occurred.
    """
    # NOTE(security): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    ckpt: Dict[str, Any] = torch.load(ckpt_path, map_location="cpu")

    if not all(k in ckpt for k in ("state_dict", "act_scales", "fp8_dtype")):
        raise ValueError(f"Checkpoint missing required keys. Found: {list(ckpt.keys())}")

    fp8_dtype = _parse_fp8_dtype(str(ckpt["fp8_dtype"]))
    act_scales_raw = _normalize_act_scales(ckpt["act_scales"])

    # Materialize the base model (accepting either an instance or a factory).
    if isinstance(base_model_or_factory, nn.Module):
        model = base_model_or_factory
    else:
        model = base_model_or_factory()
    if not isinstance(model, nn.Module):
        raise TypeError("base_model_or_factory must return an nn.Module")

    model.eval().to(device)

    linear_fqns = [fqn for fqn, m in model.named_modules() if isinstance(m, nn.Linear)]
    act_scales = _match_scale_keys(act_scales_raw, set(linear_fqns), linear_fqns)

    # One static fp8 config per Linear layer that has a matching scale.
    fqn_to_cfg = {
        fqn: Float8StaticActivationFloat8WeightConfig(
            scale=act_scales[fqn],
            activation_dtype=fp8_dtype,
            weight_dtype=fp8_dtype,
            granularity=PerTensor(),
        )
        for fqn in linear_fqns
        if fqn in act_scales
    }
    if not fqn_to_cfg:
        raise RuntimeError("No Linear layers matched activation scales.")

    try:
        cfg = FqnToConfig(fqn_to_config=fqn_to_cfg)
    except TypeError:
        # Older torchao signature takes the mapping positionally.
        cfg = FqnToConfig(fqn_to_cfg)

    # filter_fn=None: the FQN mapping already selects the target modules.
    quantize_(model, cfg, filter_fn=None, device=device)

    try:
        missing, unexpected = model.load_state_dict(
            ckpt["state_dict"],
            strict=strict,
            assign=True,
        )
    except TypeError:
        # torch without load_state_dict(assign=...): assign tensors by hand.
        _assign_state_dict_fallback(model, ckpt["state_dict"])
        missing, unexpected = [], []

    if strict and (missing or unexpected):
        raise RuntimeError(f"load_state_dict mismatch. missing={missing} unexpected={unexpected}")

    return model
| |
|
| |
|