ToastyPigeon committed on
Commit
a9979b5
·
verified ·
1 Parent(s): 24b5e5b

Upload convert_hf_to_scm.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. convert_hf_to_scm.py +179 -0
convert_hf_to_scm.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Convert GLM-4.7-Flash from HuggingFace format to ScatterMoE (SCM) format.
3
+
4
+ Usage:
5
+ python convert_hf_to_scm.py <input_model_path> <output_model_path>
6
+
7
+ Example:
8
+ python convert_hf_to_scm.py ~/.cache/huggingface/hub/models--zai-org--GLM-4.7-Flash/snapshots/<hash> ./GLM-4.7-Flash-SCM
9
+ """
import glob
import json
import os
import re
import shutil
import sys

import accelerate
import torch
from safetensors import safe_open

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from configuration_glm_scm import Glm4MoeLiteSCMConfig
from modeling_glm_scm import Glm4MoeLiteSCMForCausalLM
# CLI arguments: source HF checkpoint directory and destination SCM directory.
# Fail fast with a usage message instead of an IndexError on missing args.
if len(sys.argv) != 3:
    sys.exit(f"Usage: {sys.argv[0]} <input_model_path> <output_model_path>")

input_model = sys.argv[1]
output_model_path = sys.argv[2]

# auto_map is written into the saved config so `trust_remote_code` loading
# resolves the custom SCM classes shipped alongside the weights.
auto_map = {
    "AutoConfig": "configuration_glm_scm.Glm4MoeLiteSCMConfig",
    "AutoModel": "modeling_glm_scm.Glm4MoeLiteSCMModel",
    "AutoModelForCausalLM": "modeling_glm_scm.Glm4MoeLiteSCMForCausalLM",
}
# Load the original config as raw JSON and rebuild it field-by-field as an
# SCM config, so keys our config class does not know about are dropped and
# missing optional keys get explicit defaults.
with open(os.path.join(input_model, "config.json"), encoding="utf-8") as f:
    orig_config = json.load(f)

cfg_scm = Glm4MoeLiteSCMConfig(
    auto_map=auto_map,
    architectures=["Glm4MoeLiteSCMForCausalLM"],
    vocab_size=orig_config["vocab_size"],
    hidden_size=orig_config["hidden_size"],
    intermediate_size=orig_config["intermediate_size"],
    moe_intermediate_size=orig_config["moe_intermediate_size"],
    num_hidden_layers=orig_config["num_hidden_layers"],
    num_attention_heads=orig_config["num_attention_heads"],
    num_key_value_heads=orig_config["num_key_value_heads"],
    n_shared_experts=orig_config["n_shared_experts"],
    n_routed_experts=orig_config["n_routed_experts"],
    routed_scaling_factor=orig_config["routed_scaling_factor"],
    kv_lora_rank=orig_config["kv_lora_rank"],
    q_lora_rank=orig_config["q_lora_rank"],
    qk_rope_head_dim=orig_config["qk_rope_head_dim"],
    v_head_dim=orig_config["v_head_dim"],
    qk_nope_head_dim=orig_config["qk_nope_head_dim"],
    n_group=orig_config["n_group"],
    topk_group=orig_config["topk_group"],
    num_experts_per_tok=orig_config["num_experts_per_tok"],
    norm_topk_prob=orig_config["norm_topk_prob"],
    topk_method=orig_config["topk_method"],
    first_k_dense_replace=orig_config.get("first_k_dense_replace", 1),
    num_nextn_predict_layers=orig_config.get("num_nextn_predict_layers", 1),
    hidden_act=orig_config["hidden_act"],
    max_position_embeddings=orig_config["max_position_embeddings"],
    rms_norm_eps=orig_config["rms_norm_eps"],
    rope_theta=orig_config["rope_theta"],
    rope_scaling=orig_config.get("rope_scaling", None),
    rope_interleave=orig_config.get("rope_interleave", True),
    attention_bias=orig_config.get("attention_bias", False),
    attention_dropout=orig_config.get("attention_dropout", 0.0),
    tie_word_embeddings=orig_config.get("tie_word_embeddings", False),
    bos_token_id=orig_config.get("bos_token_id", 0),
    eos_token_id=orig_config.get("eos_token_id", 1),
    pad_token_id=orig_config.get("pad_token_id", None),
    # Newer HF configs store the weight dtype under "dtype", older ones under
    # "torch_dtype" — check both before defaulting to bfloat16.
    torch_dtype=orig_config.get("dtype", orig_config.get("torch_dtype", "bfloat16")),
)
# Instantiate the target SCM model on the "meta" device: init_empty_weights
# allocates no real storage, since every tensor is assigned later from the
# converted state dict.
with accelerate.init_empty_weights():
    model_scm = Glm4MoeLiteSCMForCausalLM(cfg_scm)
model_scm = model_scm.to(torch.bfloat16)

# Sizes used repeatedly while re-packing the expert weights below.
num_experts = cfg_scm.n_routed_experts
num_layers = cfg_scm.num_hidden_layers
# Load every tensor from the HF checkpoint's safetensors shard(s) into CPU
# memory. Sharded checkpoints are named "model-XXXXX-of-YYYYY.safetensors";
# a single-file checkpoint is just "model.safetensors".
new_state_dict = {}
files = sorted(glob.glob(os.path.join(input_model, "model-*-of-*.safetensors")))
if not files:
    files = sorted(glob.glob(os.path.join(input_model, "model.safetensors")))
if not files:
    raise FileNotFoundError(f"No safetensors files found in {input_model}")

tensors = {}
for file_path in files:
    print(f"Loading {file_path}")
    with safe_open(file_path, framework="pt", device="cpu") as f:
        for key in f.keys():
            tensors[key] = f.get_tensor(key)

print(f"Loaded {len(tensors)} tensors")
# Drop any weights whose layer index is >= num_hidden_layers: those belong to
# the extra next-token-prediction (MTP) layers, which the SCM model omits.
kept = {}
for key, tensor in tensors.items():
    match = re.search(r"layers\.(\d+)", key)
    if match is not None and int(match.group(1)) >= num_layers:
        print(f"Skipping next-token prediction layer key: {key}")
        continue
    kept[key] = tensor
tensors = kept
# Convert weights: routed-expert projections are re-packed into ScatterMoE's
# stacked layout; every other weight is copied through unchanged.


def _stack_down_projs(layer_num):
    """Stack per-expert down_proj weights for one layer.

    Returns the ScatterMoE output_experts.weight tensor of shape
    [n_experts, hidden_size, moe_intermediate_size].
    """
    return torch.stack(
        [
            tensors[f"model.layers.{layer_num}.mlp.experts.{i}.down_proj.weight"]
            for i in range(num_experts)
        ]
    )


def _stack_up_gate_projs(layer_num):
    """Stack per-expert cat(up_proj, gate_proj) weights for one layer.

    Returns the ScatterMoE experts.weight tensor of shape
    [n_experts, 2*moe_intermediate_size, hidden_size]; note the order is
    up_proj first, then gate_proj, matching the SCM modeling code.
    """
    return torch.stack(
        [
            torch.cat(
                [
                    tensors[f"model.layers.{layer_num}.mlp.experts.{i}.up_proj.weight"],
                    tensors[f"model.layers.{layer_num}.mlp.experts.{i}.gate_proj.weight"],
                ],
                dim=0,
            )
            for i in range(num_experts)
        ]
    )


processed_layers = set()
for key in tensors:
    if "mlp.experts" not in key or "shared_experts" in key:
        # Non-routed-expert weights (attention, norms, router gate, shared
        # experts, embeddings): copy directly under the same key.
        new_state_dict[key] = tensors[key]
    elif "experts.0." in key:
        # Seeing any of expert 0's keys triggers conversion of the whole
        # layer; expert 0 contributes several keys (up/gate/down), so guard
        # against re-packing the same layer more than once.
        layer_num = int(re.search(r"layers\.(\d+)", key).group(1))
        if layer_num in processed_layers:
            continue
        processed_layers.add(layer_num)

        print(f"Converting experts for layer {layer_num}")

        new_state_dict[
            f"model.layers.{layer_num}.mlp.moe_mlp.output_experts.weight"
        ] = _stack_down_projs(layer_num)

        new_state_dict[
            f"model.layers.{layer_num}.mlp.moe_mlp.experts.weight"
        ] = _stack_up_gate_projs(layer_num)

print(f"Converted state dict has {len(new_state_dict)} keys")
# Materialize the converted weights into the meta-device model (assign=True
# replaces the empty parameters with the real tensors), then write the model
# and config to the output directory.
model_scm.load_state_dict(new_state_dict, strict=True, assign=True)
model_scm.save_pretrained(output_model_path)
cfg_scm.save_pretrained(output_model_path)

# Ship the custom modeling/config source next to the weights so that
# trust_remote_code loading works from the output directory.
script_dir = os.path.dirname(os.path.abspath(__file__))
for code_file in ("modeling_glm_scm.py", "configuration_glm_scm.py"):
    shutil.copy(os.path.join(script_dir, code_file), os.path.join(output_model_path, code_file))
# Carry over tokenizer assets so the output directory is a complete,
# loadable model repo.
extra_files = ("special_tokens_map.json", "chat_template.jinja")
for fname in os.listdir(input_model):
    if not (fname.startswith("tokenizer") or fname in extra_files):
        continue
    src = os.path.join(input_model, fname)
    if os.path.isfile(src):
        shutil.copy(src, os.path.join(output_model_path, fname))

print(f"Model saved to {output_model_path}")