David-PHR committed on
Commit
9e016c4
·
verified ·
1 Parent(s): a40ff06

Upload folder using huggingface_hub

Browse files
load_torchao.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from typing import Callable, Union, Dict, Any
4
+
5
+ from torchao.quantization import quantize_, PerTensor, Float8StaticActivationFloat8WeightConfig
6
+ try:
7
+ from torchao.quantization import FqnToConfig
8
+ except ImportError:
9
+ from torchao.quantization import ModuleFqnToConfig as FqnToConfig
10
+
11
+
12
def load_torchao_fp8_static_model(
    *,
    ckpt_path: str,
    base_model_or_factory: Union[nn.Module, Callable[[], nn.Module]],
    device: str = "cuda",
    strict: bool = True,
) -> nn.Module:
    """Rebuild an FP8 static-activation quantized model from a checkpoint.

    The checkpoint (a ``torch.save`` file produced by a matching export
    script) must contain three keys:

    - ``"state_dict"``: the quantized model weights,
    - ``"act_scales"``: per-layer activation scales keyed by Linear FQN,
    - ``"fp8_dtype"``: a string naming the fp8 dtype used at export time.

    Args:
        ckpt_path: Path to the checkpoint file.
        base_model_or_factory: Either an already-constructed module or a
            zero-arg callable that builds one; its architecture must match
            the checkpoint's state dict.
        device: Device the model is moved to before quantization.
        strict: Forwarded to ``load_state_dict``; also enforced on the
            manual fallback path for older PyTorch versions.

    Returns:
        The quantized ``nn.Module`` with checkpoint weights loaded.

    Raises:
        ValueError: If a required checkpoint key is missing or the fp8
            dtype string is unsupported.
        TypeError: If the factory does not yield an ``nn.Module``.
        RuntimeError: If no activation-scale key matches any Linear layer,
            or a strict state-dict mismatch occurs.
    """
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    ckpt: Dict[str, Any] = torch.load(ckpt_path, map_location="cpu")

    if not all(k in ckpt for k in ("state_dict", "act_scales", "fp8_dtype")):
        raise ValueError(f"Checkpoint missing required keys. Found: {list(ckpt.keys())}")

    # -------------------------
    # Parse dtype (stored as a string such as "torch.float8_e4m3fn")
    # -------------------------
    dtype_str = str(ckpt["fp8_dtype"])
    if "float8_e4m3fn" in dtype_str:
        fp8_dtype = torch.float8_e4m3fn
    elif "float8_e5m2" in dtype_str:
        fp8_dtype = torch.float8_e5m2
    else:
        raise ValueError(f"Unsupported fp8 dtype string: {dtype_str}")

    # -------------------------
    # Normalize scales to fp32 scalar tensors
    # -------------------------
    act_scales_raw: Dict[str, torch.Tensor] = {}
    for k, v in ckpt["act_scales"].items():
        if torch.is_tensor(v):
            # Per-tensor quantization: collapse whatever shape was stored
            # down to a single scalar element.
            act_scales_raw[k] = v.detach().to(torch.float32).reshape(-1)[0]
        else:
            act_scales_raw[k] = torch.tensor(float(v), dtype=torch.float32)

    # -------------------------
    # Build model
    # -------------------------
    if isinstance(base_model_or_factory, nn.Module):
        model = base_model_or_factory
    else:
        model = base_model_or_factory()

    if model is None or not isinstance(model, nn.Module):
        raise TypeError("base_model_or_factory must return an nn.Module")

    model.eval().to(device)

    # -------------------------
    # Collect Linear FQNs
    # -------------------------
    linear_fqns = [fqn for fqn, m in model.named_modules() if isinstance(m, nn.Linear)]
    linear_set = set(linear_fqns)

    # -------------------------
    # Auto-fix FQN prefix mismatch: scale keys may or may not carry a
    # leading "model." depending on how the model was wrapped at export.
    # -------------------------
    def score(keys) -> int:
        """Count how many keys name an actual Linear module."""
        return sum(1 for k in keys if k in linear_set)

    candidates = [
        # 1) identity
        act_scales_raw,
        # 2) strip "model."
        {k[len("model."):]: v for k, v in act_scales_raw.items() if k.startswith("model.")},
        # 3) add "model."
        {("model." + k): v for k, v in act_scales_raw.items()},
    ]

    best = max(candidates, key=lambda d: score(d.keys()))
    best_score = score(best.keys())  # computed once (was evaluated twice)
    if best_score == 0:
        raise RuntimeError(
            "Could not match any activation scale keys to Linear layers.\n"
            f"Example Linear FQNs:\n{linear_fqns[:20]}\n\n"
            f"Example scale keys:\n{list(act_scales_raw.keys())[:20]}"
        )

    act_scales = best

    # -------------------------
    # Build torchao config map: one static-FP8 config per matched Linear
    # -------------------------
    fqn_to_cfg = {
        fqn: Float8StaticActivationFloat8WeightConfig(
            scale=act_scales[fqn],
            activation_dtype=fp8_dtype,
            weight_dtype=fp8_dtype,
            granularity=PerTensor(),
        )
        for fqn in linear_fqns
        if fqn in act_scales
    }

    if not fqn_to_cfg:
        raise RuntimeError("No Linear layers matched activation scales.")

    try:
        # Newer torchao FqnToConfig uses a keyword constructor.
        cfg = FqnToConfig(fqn_to_config=fqn_to_cfg)
    except TypeError:
        # Older ModuleFqnToConfig takes the mapping positionally.
        cfg = FqnToConfig(fqn_to_cfg)

    # -------------------------
    # Quantize structure first, so the quantized tensor subclasses exist
    # before the checkpoint weights are loaded into them.
    # -------------------------
    quantize_(model, cfg, filter_fn=None, device=device)

    # -------------------------
    # Load weights (CRITICAL: assign=True replaces parameters instead of
    # copy_-ing into quantized tensor subclasses, which would fail to
    # dispatch).
    # -------------------------
    try:
        missing, unexpected = model.load_state_dict(
            ckpt["state_dict"],
            strict=strict,
            assign=True,
        )
    except TypeError:
        # Fallback for PyTorch < 2.1 (no `assign` kwarg): assign manually.
        modules = dict(model.named_modules())  # hoisted: was rebuilt per tensor
        for name, tensor in ckpt["state_dict"].items():
            if "." in name:
                module_name, attr = name.rsplit(".", 1)
                mod = modules[module_name]
            else:
                # Top-level parameter/buffer registered directly on the
                # root module — rsplit would have raised here before.
                mod, attr = model, name
            if isinstance(getattr(mod, attr), nn.Parameter):
                setattr(mod, attr, nn.Parameter(tensor, requires_grad=False))
            else:
                setattr(mod, attr, tensor)
        missing, unexpected = [], []

    if strict and (missing or unexpected):
        raise RuntimeError(f"load_state_dict mismatch. missing={missing} unexpected={unexpected}")

    return model
144
+
transformer_bf16/config.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "Flux2Transformer2DModel",
3
+ "_diffusers_version": "0.37.0.dev0",
4
+ "attention_head_dim": 128,
5
+ "axes_dims_rope": [
6
+ 32,
7
+ 32,
8
+ 32,
9
+ 32
10
+ ],
11
+ "eps": 1e-06,
12
+ "guidance_embeds": false,
13
+ "in_channels": 128,
14
+ "joint_attention_dim": 7680,
15
+ "mlp_ratio": 3.0,
16
+ "num_attention_heads": 24,
17
+ "num_layers": 5,
18
+ "num_single_layers": 20,
19
+ "out_channels": null,
20
+ "patch_size": 1,
21
+ "rope_theta": 2000,
22
+ "timestep_guidance_channels": 256
23
+ }
transformer_bf16/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1fa71fe800721fd1d3184a41ce0d8938b1c7d393a70247d9630bd0b8f3d60a85
3
+ size 7751109744
transformer_fp8_static/config.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "Flux2Transformer2DModel",
3
+ "_diffusers_version": "0.37.0.dev0",
4
+ "attention_head_dim": 128,
5
+ "axes_dims_rope": [
6
+ 32,
7
+ 32,
8
+ 32,
9
+ 32
10
+ ],
11
+ "eps": 1e-06,
12
+ "guidance_embeds": false,
13
+ "in_channels": 128,
14
+ "joint_attention_dim": 7680,
15
+ "mlp_ratio": 3.0,
16
+ "num_attention_heads": 24,
17
+ "num_layers": 5,
18
+ "num_single_layers": 20,
19
+ "out_channels": null,
20
+ "patch_size": 1,
21
+ "rope_theta": 2000,
22
+ "timestep_guidance_channels": 256
23
+ }
transformer_fp8_static/model_fp8_static.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:13b37a5ca5cd9cf190236e7e99a3f086cf24618682f74e27a6f00cb173c308c8
3
+ size 4070791292