minpeter committed on
Commit
d9f34c1
·
verified ·
1 Parent(s): ebb29ac

Upload extract_llm.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. extract_llm.py +302 -0
extract_llm.py ADDED
@@ -0,0 +1,302 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Extract text-only LLM from HyperCLOVAX-SEED-Think-32B VLM.
4
+ Converts to LLaMA-compatible format for standard inference engines.
5
+
6
+ Usage:
7
+ python extract_llm.py --input ./HyperCLOVAX-SEED-Think-32B --output ./HyperCLOVAX-SEED-Text-Think-32B
8
+
9
+ Requirements:
10
+ pip install safetensors torch tqdm
11
+ """
12
+
13
+ import argparse
14
+ import json
15
+ import os
16
+ import shutil
17
+ from pathlib import Path
18
+ from collections import defaultdict
19
+ from safetensors import safe_open
20
+ from safetensors.torch import save_file
21
+ import torch
22
+ from tqdm import tqdm
23
+
24
+
25
def load_weight_index(model_path: Path) -> dict:
    """Read and parse ``model.safetensors.index.json`` from *model_path*."""
    index_file = model_path / "model.safetensors.index.json"
    return json.loads(index_file.read_text())
30
+
31
+
32
def extract_llm_weights(model_path: Path, output_path: Path):
    """
    Extract the language-model weights from the VLM checkpoint at
    *model_path* and write them (sharded) under *output_path*.

    Key mapping — every tensor under the ``model.language_model.`` prefix
    is kept with that prefix stripped:

      - model.language_model.model.*   -> model.*
      - model.language_model.lm_head.* -> lm_head.*

    All vision-encoder and MM-projector weights are excluded.

    Returns the list of remapped tensor names that were written.
    """
    output_path.mkdir(parents=True, exist_ok=True)

    weight_index = load_weight_index(model_path)
    weight_map = weight_index["weight_map"]

    # Filter and remap LLM weights.  The original three-way branch
    # (model.*, lm_head.*, other) was redundant: stripping the prefix from
    # "model.language_model.model.X" already yields "model.X", so a single
    # prefix strip covers every case.
    prefix = "model.language_model."
    llm_weights = {}
    for key, shard_file in weight_map.items():
        if key.startswith(prefix):
            llm_weights[key[len(prefix):]] = (key, shard_file)

    print(f"Found {len(llm_weights)} LLM weight tensors")
    print(f"Excluded {len(weight_map) - len(llm_weights)} vision/projector tensors")

    # Group by source shard so each shard file is opened only once.
    shard_to_weights = defaultdict(list)
    for new_key, (old_key, shard_file) in llm_weights.items():
        shard_to_weights[shard_file].append((old_key, new_key))

    # Load all LLM tensors into CPU memory.
    all_tensors = {}
    shard_files = sorted(shard_to_weights.keys())  # dict keys are unique; no set() needed

    print(f"\nLoading weights from {len(shard_files)} shards...")
    for shard_file in tqdm(shard_files, desc="Loading shards"):
        shard_path = model_path / shard_file
        with safe_open(shard_path, framework="pt", device="cpu") as f:
            for old_key, new_key in shard_to_weights[shard_file]:
                all_tensors[new_key] = f.get_tensor(old_key)

    print(f"\nTotal tensors extracted: {len(all_tensors)}")

    total_size = sum(t.numel() * t.element_size() for t in all_tensors.values())
    print(f"Total size: {total_size / 1e9:.2f} GB")

    # Save as sharded safetensors (~5GB per shard).
    max_shard_size = 5 * 1024 * 1024 * 1024

    print("\nSaving extracted weights...")
    save_sharded_safetensors(all_tensors, output_path, max_shard_size)

    return list(all_tensors.keys())
91
+
92
+
93
def save_sharded_safetensors(tensors: dict, output_path: Path, max_shard_size: int):
    """Write *tensors* to *output_path* as sharded safetensors files plus an index."""
    # Partition tensor names (in sorted order) into buckets of at most
    # max_shard_size bytes each; a bucket is flushed before it would overflow.
    shards = []
    bucket, bucket_bytes = {}, 0
    for name in sorted(tensors):
        t = tensors[name]
        nbytes = t.numel() * t.element_size()
        if bucket and bucket_bytes + nbytes > max_shard_size:
            shards.append(bucket)
            bucket, bucket_bytes = {}, 0
        bucket[name] = t
        bucket_bytes += nbytes
    if bucket:
        shards.append(bucket)

    total_shards = len(shards)
    total_size = sum(t.numel() * t.element_size() for t in tensors.values())

    # Emit each shard and record which shard holds each tensor.
    weight_map = {}
    for idx, shard_tensors in enumerate(tqdm(shards, desc="Saving shards"), start=1):
        shard_name = f"model-{idx:05d}-of-{total_shards:05d}.safetensors"
        save_file(shard_tensors, output_path / shard_name)
        for name in shard_tensors:
            weight_map[name] = shard_name

    # Index file maps tensor name -> shard file, plus total byte size metadata.
    index = {
        "metadata": {"total_size": total_size},
        "weight_map": weight_map
    }
    index_path = output_path / "model.safetensors.index.json"
    with open(index_path, "w") as f:
        json.dump(index, f, indent=2)

    print(f"Saved {total_shards} shards to {output_path}")
140
+
141
+
142
def create_llama_config(original_config_path: Path, output_path: Path):
    """
    Build a LLaMA-compatible ``config.json`` (and ``generation_config.json``)
    from the VLM config at *original_config_path*, writing both under
    *output_path*.  Returns the LLaMA config dict.

    Note: HyperCLOVAX uses attention_multiplier ≈ 1/sqrt(head_dim)
    which matches standard LLaMA scaled dot-product attention.
    """
    vlm_config = json.loads(Path(original_config_path).read_text())
    text_cfg = vlm_config["text_config"]

    def pick(field, default):
        # Fall back to the HyperCLOVAX-SEED-Think-32B value when the
        # field is absent from the source config.
        return text_cfg.get(field, default)

    llama_config = {
        "architectures": ["LlamaForCausalLM"],
        "attention_bias": pick("attention_bias", False),
        "attention_dropout": pick("attention_dropout", 0.0),
        "bos_token_id": pick("bos_token_id", 128000),
        "eos_token_id": pick("eos_token_id", 128001),
        "head_dim": pick("head_dim", 128),
        "hidden_act": pick("hidden_act", "silu"),
        "hidden_size": pick("hidden_size", 5120),
        "initializer_range": pick("initializer_range", 0.006),
        "intermediate_size": pick("intermediate_size", 24192),
        "max_position_embeddings": pick("max_position_embeddings", 131072),
        "mlp_bias": pick("mlp_bias", False),
        "model_type": "llama",
        "num_attention_heads": pick("num_attention_heads", 40),
        "num_hidden_layers": pick("num_hidden_layers", 72),
        "num_key_value_heads": pick("num_key_value_heads", 8),
        "pad_token_id": pick("pad_token_id", 0),
        "pretraining_tp": 1,
        "rms_norm_eps": pick("rms_norm_eps", 1e-05),
        "rope_scaling": pick("rope_scaling", None),
        "rope_theta": pick("rope_theta", 50000000),
        "tie_word_embeddings": pick("tie_word_embeddings", False),
        "torch_dtype": "bfloat16",
        "transformers_version": "4.52.4",
        "use_cache": True,
        "vocab_size": pick("vocab_size", 128256),
    }

    config_path = output_path / "config.json"
    config_path.write_text(json.dumps(llama_config, indent=2))

    print(f"Saved LLaMA config to {config_path}")

    # Generation config derived from the token ids chosen above.
    gen_config = {
        "bos_token_id": llama_config["bos_token_id"],
        "eos_token_id": llama_config["eos_token_id"],
        "pad_token_id": llama_config["pad_token_id"],
        "do_sample": True,
        "temperature": 0.7,
        "top_p": 0.9,
        "max_length": 4096
    }
    (output_path / "generation_config.json").write_text(json.dumps(gen_config, indent=2))

    return llama_config
204
+
205
+
206
def copy_tokenizer_files(original_path: Path, output_path: Path):
    """Copy any tokenizer-related files found in *original_path* into *output_path*."""
    # Known tokenizer artifacts; only those actually present are copied.
    candidates = (
        "tokenizer.json",
        "tokenizer_config.json",
        "special_tokens_map.json",
        "added_tokens.json",
        "vocab.json",
        "merges.txt",
        "chat_template.jinja"
    )

    copied = []
    for name in candidates:
        source = original_path / name
        if source.exists():
            shutil.copy2(source, output_path / name)
            copied.append(name)

    print(f"Copied tokenizer files: {copied}")
227
+
228
+
229
def main():
    """Command-line driver: validate inputs, then run the three extraction steps.

    Returns 0 on success, 1 on a missing input path or index file.
    """
    parser = argparse.ArgumentParser(
        description="Extract text-only LLM from HyperCLOVAX-SEED-Think-32B VLM",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Example:
  # Download original VLM
  huggingface-cli download naver-hyperclovax/HyperCLOVAX-SEED-Think-32B \\
    --local-dir ./HyperCLOVAX-SEED-Think-32B

  # Extract text-only LLM
  python extract_llm.py \\
    --input ./HyperCLOVAX-SEED-Think-32B \\
    --output ./HyperCLOVAX-SEED-Text-Think-32B
"""
    )
    parser.add_argument(
        "--input", "-i",
        type=Path,
        required=True,
        help="Path to original HyperCLOVAX-SEED-Think-32B VLM"
    )
    parser.add_argument(
        "--output", "-o",
        type=Path,
        required=True,
        help="Output path for extracted text-only LLM"
    )
    args = parser.parse_args()

    # Fail fast if the checkpoint directory or its weight index is missing.
    if not args.input.exists():
        print(f"Error: Input path does not exist: {args.input}")
        return 1
    if not (args.input / "model.safetensors.index.json").exists():
        print(f"Error: model.safetensors.index.json not found in {args.input}")
        return 1

    rule = "=" * 60
    print(rule)
    print("HyperCLOVAX VLM → Text-only LLM Extraction")
    print(rule)
    print(f"Input: {args.input}")
    print(f"Output: {args.output}")

    print("\n[Step 1] Extracting LLM weights...")
    extract_llm_weights(args.input, args.output)

    print("\n[Step 2] Creating LLaMA-compatible config...")
    config = create_llama_config(args.input / "config.json", args.output)

    print("\n[Step 3] Copying tokenizer files...")
    copy_tokenizer_files(args.input, args.output)

    print("\n" + rule)
    print("Extraction complete!")
    print(f"Output: {args.output}")
    print(rule)

    print("\nModel summary:")
    print(" - Architecture: LlamaForCausalLM")
    print(f" - Hidden size: {config['hidden_size']}")
    print(f" - Layers: {config['num_hidden_layers']}")
    print(f" - Attention heads: {config['num_attention_heads']}")
    print(f" - KV heads: {config['num_key_value_heads']}")
    print(f" - Vocab size: {config['vocab_size']}")
    print(f" - Max context: {config['max_position_embeddings']}")

    print("\nYou can now use the model with vLLM, transformers, or other LLaMA-compatible frameworks.")
    return 0
299
+
300
+
301
+ if __name__ == "__main__":
302
+ exit(main())