File size: 4,689 Bytes
9e016c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import torch
import torch.nn as nn
from typing import Callable, Union, Dict, Any

from torchao.quantization import quantize_, PerTensor, Float8StaticActivationFloat8WeightConfig
try:
    from torchao.quantization import FqnToConfig
except ImportError:
    from torchao.quantization import ModuleFqnToConfig as FqnToConfig


def load_torchao_fp8_static_model(
    *,
    ckpt_path: str,
    base_model_or_factory: Union[nn.Module, Callable[[], nn.Module]],
    device: str = "cuda",
    strict: bool = True,
) -> nn.Module:
    """Rebuild a torchao fp8 static-activation-quantized model from a checkpoint.

    The checkpoint (a ``torch.save`` dict) must contain:
      * ``state_dict`` - the quantized model weights,
      * ``act_scales`` - per-layer activation scales keyed by Linear FQN,
      * ``fp8_dtype``  - a string naming the fp8 dtype used at export time.

    Args:
        ckpt_path: Path to the checkpoint file.
        base_model_or_factory: An un-quantized model instance, or a zero-arg
            callable that builds one.
        device: Device the model is moved to before quantization.
        strict: Forwarded to ``load_state_dict``; mismatches also raise here.

    Returns:
        The quantized ``nn.Module`` with checkpoint weights loaded.

    Raises:
        ValueError: Required checkpoint keys are missing, or the fp8 dtype
            string is not recognized.
        TypeError: The factory did not produce an ``nn.Module``.
        RuntimeError: No scale keys match any Linear layer, or the state dict
            does not match under ``strict=True``.
    """
    # SECURITY: weights_only=False executes arbitrary pickle code on load —
    # only use checkpoints from trusted sources. It is explicit here because
    # torch >= 2.6 flipped the default to True, which would reject this
    # checkpoint format (it stores non-tensor objects).
    ckpt: Dict[str, Any] = torch.load(
        ckpt_path, map_location="cpu", weights_only=False
    )

    if not all(k in ckpt for k in ("state_dict", "act_scales", "fp8_dtype")):
        raise ValueError(f"Checkpoint missing required keys. Found: {list(ckpt.keys())}")

    fp8_dtype = _parse_fp8_dtype(str(ckpt["fp8_dtype"]))
    act_scales_raw = _normalize_act_scales(ckpt["act_scales"])
    model = _materialize_model(base_model_or_factory, device)

    # FQNs of every Linear layer — these are the quantization targets.
    linear_fqns = [fqn for fqn, m in model.named_modules() if isinstance(m, nn.Linear)]

    act_scales = _match_scale_keys(act_scales_raw, linear_fqns)

    # One static-fp8 config per Linear layer that has a recorded scale.
    fqn_to_cfg = {
        fqn: Float8StaticActivationFloat8WeightConfig(
            scale=act_scales[fqn],
            activation_dtype=fp8_dtype,
            weight_dtype=fp8_dtype,
            granularity=PerTensor(),
        )
        for fqn in linear_fqns
        if fqn in act_scales
    }
    if not fqn_to_cfg:
        raise RuntimeError("No Linear layers matched activation scales.")

    # Older torchao releases take the mapping positionally, newer ones by
    # keyword — probe the keyword form first.
    try:
        cfg = FqnToConfig(fqn_to_config=fqn_to_cfg)
    except TypeError:
        cfg = FqnToConfig(fqn_to_cfg)

    # Quantize the module structure first so state-dict keys line up with the
    # quantized parameter layout before loading.
    quantize_(model, cfg, filter_fn=None, device=device)

    missing, unexpected = _load_weights(model, ckpt["state_dict"], strict)
    if strict and (missing or unexpected):
        raise RuntimeError(f"load_state_dict mismatch. missing={missing} unexpected={unexpected}")

    return model


def _parse_fp8_dtype(dtype_str: str) -> torch.dtype:
    # Map the serialized dtype string onto the concrete torch fp8 dtype.
    if "float8_e4m3fn" in dtype_str:
        return torch.float8_e4m3fn
    if "float8_e5m2" in dtype_str:
        return torch.float8_e5m2
    raise ValueError(f"Unsupported fp8 dtype string: {dtype_str}")


def _normalize_act_scales(raw: Dict[str, Any]) -> Dict[str, torch.Tensor]:
    # Coerce every activation scale to a detached scalar fp32 tensor.
    scales: Dict[str, torch.Tensor] = {}
    for key, value in raw.items():
        if torch.is_tensor(value):
            scales[key] = value.detach().to(torch.float32).reshape(-1)[0]
        else:
            scales[key] = torch.tensor(float(value), dtype=torch.float32)
    return scales


def _materialize_model(
    base: Union[nn.Module, Callable[[], nn.Module]], device: str
) -> nn.Module:
    # Accept either a ready model or a zero-arg factory; eval + move to device.
    model = base if isinstance(base, nn.Module) else base()
    if model is None or not isinstance(model, nn.Module):
        raise TypeError("base_model_or_factory must return an nn.Module")
    model.eval().to(device)
    return model


def _match_scale_keys(
    act_scales_raw: Dict[str, torch.Tensor], linear_fqns: list
) -> Dict[str, torch.Tensor]:
    # Resolve FQN prefix mismatches between scale keys and module names by
    # trying: identity, stripping a leading "model.", and prepending "model.".
    linear_set = set(linear_fqns)

    def hits(d: Dict[str, Any]) -> int:
        return sum(1 for k in d if k in linear_set)

    candidates = [
        act_scales_raw,
        {k[len("model."):]: v for k, v in act_scales_raw.items() if k.startswith("model.")},
        {f"model.{k}": v for k, v in act_scales_raw.items()},
    ]
    best = max(candidates, key=hits)
    if hits(best) == 0:
        raise RuntimeError(
            "Could not match any activation scale keys to Linear layers.\n"
            f"Example Linear FQNs:\n{linear_fqns[:20]}\n\n"
            f"Example scale keys:\n{list(act_scales_raw.keys())[:20]}"
        )
    return best


def _load_weights(model: nn.Module, state_dict: Dict[str, Any], strict: bool):
    # Load with assign=True so quantized tensor subclasses are swapped in
    # wholesale instead of copy_()-ing into them (which fails dispatch).
    try:
        return model.load_state_dict(state_dict, strict=strict, assign=True)
    except TypeError:
        # PyTorch < 2.1 has no assign=; replace parameters/buffers by hand.
        modules = dict(model.named_modules())
        for name, tensor in state_dict.items():
            # rpartition handles top-level entries with no "." — they live on
            # the root module (named_modules maps "" to the model itself).
            module_name, _, attr = name.rpartition(".")
            mod = modules[module_name]
            if isinstance(getattr(mod, attr, None), nn.Parameter):
                setattr(mod, attr, nn.Parameter(tensor, requires_grad=False))
            else:
                setattr(mod, attr, tensor)
        return [], []