Update convert_safetensors.py
Browse files- convert_safetensors.py +128 -128
convert_safetensors.py
CHANGED
|
@@ -1,129 +1,129 @@
|
|
| 1 |
-
import os
import glob
from safetensors import safe_open
from safetensors.torch import save_file
import torch
import json

# Convert TP-sharded grok-2 safetensors checkpoints into per-layer
# safetensors files plus a HuggingFace-style weight index.
# NOTE: all shards are loaded into CPU RAM at once — requires a very
# large-memory host for this model.
model_dir = "xai-org/grok-2"
# Fix: the assignment was incomplete ("output_dir =") — a SyntaxError.
output_dir = "huihui-ai/grok-2"
os.makedirs(output_dir, exist_ok=True)

# Collect all safetensors files
print("Collecting safetensors files...", flush=True)
safetensors_files = glob.glob(os.path.join(model_dir, "pytorch_model-*.safetensors"))
if not safetensors_files:
    raise FileNotFoundError(f"No pytorch_model-*.safetensors files found in directory {model_dir}")

# Load all files into cache and build key-to-file mapping
file_cache = {}    # file path -> {key: tensor}
key_to_files = {}  # key -> [file paths]; >1 path means the key is TP-sharded
total_size = 0
print("Loading safetensors files...", flush=True)
for file_path in safetensors_files:
    try:
        with safe_open(file_path, framework="pt", device="cpu") as f:
            file_cache[file_path] = {key: f.get_tensor(key) for key in f.keys()}
            for key, tensor in file_cache[file_path].items():
                key_to_files.setdefault(key, []).append(file_path)
                total_size += tensor.element_size() * tensor.nelement()
    except Exception as e:
        print(f"Warning: Failed to load {file_path}: {e}")
print(f"Found {len(key_to_files)} unique keys, total size {total_size / 1e9:.2f} GB", flush=True)

# Merge TP shards
tp_count = 8  # TP=8
merged_state_dict = {}
print("Merging TP shards...", flush=True)
for key, file_paths in key_to_files.items():
    if len(file_paths) > 1:  # TP shards
        print(f"Merging {key} shards...", flush=True)
        # Sort by TP rank (parsed from "TP-<n>" in the filename) so shards
        # concatenate in rank order; files without a TP marker sort first.
        sorted_paths = sorted(file_paths, key=lambda x: int(x.split("TP-")[1].split(".")[0]) if "TP-" in x else -1)
        tensors = []
        for file_path in sorted_paths[:tp_count]:
            if file_path in file_cache and key in file_cache[file_path]:
                tensors.append(file_cache[file_path][key])
            else:
                print(f"Warning: Key {key} missing in {file_path}")
        if len(tensors) == tp_count:
            try:
                # w1/w3 are split along dim 0, w2 along dim 1; everything else
                # falls back to dim 0. TODO confirm against the TP layout.
                dim = 0 if "w1.weight" in key or "w3.weight" in key else 1 if "w2.weight" in key else 0
                merged_tensor = torch.cat(tensors, dim=dim)
                # Verify merged shapes for MoE expert weights.
                if "block_sparse_moe.experts" in key:
                    if "w1.weight" in key or "w3.weight" in key:
                        expected_shape = (16384, 8192)  # moe_intermediate_size, hidden_size
                        if merged_tensor.shape != expected_shape:
                            print(f"Warning: {key} merged shape {merged_tensor.shape} does not match expected {expected_shape}")
                    elif "w2.weight" in key:
                        expected_shape = (8192, 16384)  # hidden_size, moe_intermediate_size
                        if merged_tensor.shape != expected_shape:
                            print(f"Warning: {key} merged shape {merged_tensor.shape} does not match expected {expected_shape}")
                merged_state_dict[key] = merged_tensor
            except Exception as e:
                print(f"Failed to merge {key}: {e}")
                merged_state_dict[key] = tensors[0] if tensors else None
        else:
            print(f"Warning: Found {len(tensors)} shards for {key}, expected {tp_count}, using first tensor")
            merged_state_dict[key] = tensors[0] if tensors else None
    else:
        print(f"Processing {key} ...", flush=True)
        # Non-TP shard: copy straight through.
        file_path = file_paths[0]
        if file_path in file_cache and key in file_cache[file_path]:
            merged_state_dict[key] = file_cache[file_path][key]
        else:
            print(f"Warning: Key {key} missing in {file_path}")
            merged_state_dict[key] = None

# Group by layer; special (non-layer) weights stay in merged_state_dict
# and are appended to the final layer's file below.
layer_dicts = {}
special_weights = ["lm_head.weight", "model.embed_tokens.weight", "model.norm.weight"]
last_layer_idx = None
print("Grouping weights by layer...", flush=True)
for key in list(merged_state_dict.keys()):
    if merged_state_dict[key] is None:
        continue
    if key in special_weights:
        continue
    if "model.layers." in key:
        layer_num = int(key.split(".")[2])  # "model.layers.<n>...."
        layer_dicts.setdefault(layer_num, {})[key] = merged_state_dict.pop(key)
        last_layer_idx = layer_num if last_layer_idx is None else max(last_layer_idx, layer_num)

# Fix: without this guard, "last_layer_idx + 1" below raised TypeError
# when no "model.layers.*" keys were found.
if last_layer_idx is None:
    raise RuntimeError("No 'model.layers.*' weights found; nothing to save.")

# Save weights for each layer (files are numbered starting at 00001).
print("Saving weight files...", flush=True)
for layer_num in sorted(layer_dicts.keys()):
    output_file = os.path.join(output_dir, f"pytorch_model-{layer_num + 1:05d}.safetensors")
    save_file(layer_dicts[layer_num], output_file)
    print(f"Saved layer {layer_num} to {output_file}")

# Re-save the final layer's file with the special weights folded in.
last_layer_file = os.path.join(output_dir, f"pytorch_model-{last_layer_idx + 1:05d}.safetensors")
last_layer_dict = layer_dicts.get(last_layer_idx, {})
for key in special_weights:
    if key in merged_state_dict and merged_state_dict[key] is not None:
        last_layer_dict[key] = merged_state_dict[key]
save_file(last_layer_dict, last_layer_file)
print(f"Saved final layer (including lm_head, embed_tokens, norm) to {last_layer_file}", flush=True)

# Generate new HF-style index mapping every weight key to its shard file.
new_index = {"metadata": {"total_size": total_size}, "weight_map": {}}
for layer_num in sorted(layer_dicts.keys()):
    file_name = f"pytorch_model-{layer_num + 1:05d}.safetensors"
    for key in layer_dicts[layer_num]:
        new_index["weight_map"][key] = file_name
for key in special_weights:
    if key in merged_state_dict and merged_state_dict[key] is not None:
        new_index["weight_map"][key] = f"pytorch_model-{last_layer_idx + 1:05d}.safetensors"

with open(os.path.join(output_dir, "pytorch_model.bin.index.json"), "w") as f:
    json.dump(new_index, f, indent=2)
print(f"Saved new index file to {os.path.join(output_dir, 'pytorch_model.bin.index.json')}", flush=True)
|
|
|
|
| 1 |
+
import os
import glob
from safetensors import safe_open
from safetensors.torch import save_file
import torch
import json

# Convert TP-sharded grok-2 safetensors checkpoints into per-layer
# safetensors files plus a HuggingFace-style weight index.
# NOTE: all shards are loaded into CPU RAM at once — requires a very
# large-memory host for this model.
model_dir = "xai-org/grok-2"
output_dir = "huihui-ai/grok-2"
os.makedirs(output_dir, exist_ok=True)

# Collect all safetensors files
print("Collecting safetensors files...", flush=True)
safetensors_files = glob.glob(os.path.join(model_dir, "pytorch_model-*.safetensors"))
if not safetensors_files:
    raise FileNotFoundError(f"No pytorch_model-*.safetensors files found in directory {model_dir}")

# Load all files into cache and build key-to-file mapping
file_cache = {}    # file path -> {key: tensor}
key_to_files = {}  # key -> [file paths]; >1 path means the key is TP-sharded
total_size = 0
print("Loading safetensors files...", flush=True)
for file_path in safetensors_files:
    try:
        with safe_open(file_path, framework="pt", device="cpu") as f:
            file_cache[file_path] = {key: f.get_tensor(key) for key in f.keys()}
            for key, tensor in file_cache[file_path].items():
                key_to_files.setdefault(key, []).append(file_path)
                total_size += tensor.element_size() * tensor.nelement()
    except Exception as e:
        print(f"Warning: Failed to load {file_path}: {e}")
print(f"Found {len(key_to_files)} unique keys, total size {total_size / 1e9:.2f} GB", flush=True)

# Merge TP shards
tp_count = 8  # TP=8
merged_state_dict = {}
print("Merging TP shards...", flush=True)
for key, file_paths in key_to_files.items():
    if len(file_paths) > 1:  # TP shards
        print(f"Merging {key} shards...", flush=True)
        # Sort by TP rank (parsed from "TP-<n>" in the filename) so shards
        # concatenate in rank order; files without a TP marker sort first.
        sorted_paths = sorted(file_paths, key=lambda x: int(x.split("TP-")[1].split(".")[0]) if "TP-" in x else -1)
        tensors = []
        for file_path in sorted_paths[:tp_count]:
            if file_path in file_cache and key in file_cache[file_path]:
                tensors.append(file_cache[file_path][key])
            else:
                print(f"Warning: Key {key} missing in {file_path}")
        if len(tensors) == tp_count:
            try:
                # w1/w3 are split along dim 0, w2 along dim 1; everything else
                # falls back to dim 0. TODO confirm against the TP layout.
                dim = 0 if "w1.weight" in key or "w3.weight" in key else 1 if "w2.weight" in key else 0
                merged_tensor = torch.cat(tensors, dim=dim)
                # Verify merged shapes for MoE expert weights.
                if "block_sparse_moe.experts" in key:
                    if "w1.weight" in key or "w3.weight" in key:
                        expected_shape = (16384, 8192)  # moe_intermediate_size, hidden_size
                        if merged_tensor.shape != expected_shape:
                            print(f"Warning: {key} merged shape {merged_tensor.shape} does not match expected {expected_shape}")
                    elif "w2.weight" in key:
                        expected_shape = (8192, 16384)  # hidden_size, moe_intermediate_size
                        if merged_tensor.shape != expected_shape:
                            print(f"Warning: {key} merged shape {merged_tensor.shape} does not match expected {expected_shape}")
                merged_state_dict[key] = merged_tensor
            except Exception as e:
                print(f"Failed to merge {key}: {e}")
                merged_state_dict[key] = tensors[0] if tensors else None
        else:
            print(f"Warning: Found {len(tensors)} shards for {key}, expected {tp_count}, using first tensor")
            merged_state_dict[key] = tensors[0] if tensors else None
    else:
        print(f"Processing {key} ...", flush=True)
        # Non-TP shard: copy straight through.
        file_path = file_paths[0]
        if file_path in file_cache and key in file_cache[file_path]:
            merged_state_dict[key] = file_cache[file_path][key]
        else:
            print(f"Warning: Key {key} missing in {file_path}")
            merged_state_dict[key] = None

# Group by layer; special (non-layer) weights stay in merged_state_dict
# and are appended to the final layer's file below.
layer_dicts = {}
special_weights = ["lm_head.weight", "model.embed_tokens.weight", "model.norm.weight"]
last_layer_idx = None
print("Grouping weights by layer...", flush=True)
for key in list(merged_state_dict.keys()):
    if merged_state_dict[key] is None:
        continue
    if key in special_weights:
        continue
    if "model.layers." in key:
        layer_num = int(key.split(".")[2])  # "model.layers.<n>...."
        layer_dicts.setdefault(layer_num, {})[key] = merged_state_dict.pop(key)
        last_layer_idx = layer_num if last_layer_idx is None else max(last_layer_idx, layer_num)

# Fix: without this guard, "last_layer_idx + 1" below raised TypeError
# when no "model.layers.*" keys were found.
if last_layer_idx is None:
    raise RuntimeError("No 'model.layers.*' weights found; nothing to save.")

# Save weights for each layer (files are numbered starting at 00001).
print("Saving weight files...", flush=True)
for layer_num in sorted(layer_dicts.keys()):
    output_file = os.path.join(output_dir, f"pytorch_model-{layer_num + 1:05d}.safetensors")
    save_file(layer_dicts[layer_num], output_file)
    print(f"Saved layer {layer_num} to {output_file}")

# Re-save the final layer's file with the special weights folded in.
last_layer_file = os.path.join(output_dir, f"pytorch_model-{last_layer_idx + 1:05d}.safetensors")
last_layer_dict = layer_dicts.get(last_layer_idx, {})
for key in special_weights:
    if key in merged_state_dict and merged_state_dict[key] is not None:
        last_layer_dict[key] = merged_state_dict[key]
save_file(last_layer_dict, last_layer_file)
print(f"Saved final layer (including lm_head, embed_tokens, norm) to {last_layer_file}", flush=True)

# Generate new HF-style index mapping every weight key to its shard file.
new_index = {"metadata": {"total_size": total_size}, "weight_map": {}}
for layer_num in sorted(layer_dicts.keys()):
    file_name = f"pytorch_model-{layer_num + 1:05d}.safetensors"
    for key in layer_dicts[layer_num]:
        new_index["weight_map"][key] = file_name
for key in special_weights:
    if key in merged_state_dict and merged_state_dict[key] is not None:
        new_index["weight_map"][key] = f"pytorch_model-{last_layer_idx + 1:05d}.safetensors"

with open(os.path.join(output_dir, "pytorch_model.bin.index.json"), "w") as f:
    json.dump(new_index, f, indent=2)
print(f"Saved new index file to {os.path.join(output_dir, 'pytorch_model.bin.index.json')}", flush=True)
|