# Uploaded by ngxson (HF Staff) via huggingface_hub — commit 068d828 (verified)
import os
import json
import re
from pathlib import Path
from safetensors import safe_open
from safetensors.torch import save_file
import torch
def main():
    """Prune a GLM MoE checkpoint down to its first N routed experts.

    Reads every ``*.safetensors`` shard in ``src_dir`` and:
      * drops per-expert MLP weights whose expert index is
        >= ``num_experts_to_keep``,
      * truncates the router gate weight (``[num_experts, hidden_size]``)
        and its ``e_score_correction_bias`` (``[num_experts]``) to the
        kept expert count,
      * copies every other tensor unchanged.

    Writes the result as a single ``model.safetensors`` plus a matching
    ``model.safetensors.index.json`` in the current directory.
    """
    src_dir = Path("../GLM-4.7-Flash")
    dst_path = Path("model.safetensors")
    num_experts_to_keep = 2

    # Find all safetensors shards in the source checkpoint.
    safetensor_files = sorted(src_dir.glob("*.safetensors"))
    print(f"Found {len(safetensor_files)} safetensors files")

    # Keys whose content depends on the routed-expert count.
    expert_pattern = re.compile(r"model\.layers\.(\d+)\.mlp\.experts\.(\d+)\..+")
    gate_pattern = re.compile(r"model\.layers\.(\d+)\.mlp\.gate\.weight")
    bias_pattern = re.compile(r"model\.layers\.(\d+)\.mlp\.gate\.e_score_correction_bias")

    new_tensors = {}
    for sf_path in safetensor_files:
        print(f"Processing {sf_path.name}...")
        with safe_open(sf_path, framework="pt", device="cpu") as f:
            for key in f.keys():
                tensor = f.get_tensor(key)

                # Per-expert weight: keep only experts below the cutoff.
                expert_match = expert_pattern.search(key)
                if expert_match:
                    expert_idx = int(expert_match.group(2))
                    if expert_idx >= num_experts_to_keep:
                        print(f" Skipping {key} (expert {expert_idx} >= {num_experts_to_keep})")
                        continue
                    new_tensors[key] = tensor
                    continue

                # Router gate weight is [num_experts, hidden_size]:
                # keep the first num_experts_to_keep rows.
                if gate_pattern.search(key):
                    original_shape = tensor.shape
                    new_tensor = tensor[:num_experts_to_keep, :]
                    print(f" Resizing {key}: {original_shape} -> {new_tensor.shape}")
                    new_tensors[key] = new_tensor
                    continue

                # Correction bias is [num_experts]:
                # keep the first num_experts_to_keep entries.
                if bias_pattern.search(key):
                    original_shape = tensor.shape
                    new_tensor = tensor[:num_experts_to_keep]
                    print(f" Resizing {key}: {original_shape} -> {new_tensor.shape}")
                    new_tensors[key] = new_tensor
                    continue

                # Keep all other tensors as-is.
                new_tensors[key] = tensor

    print(f"\nTotal tensors to save: {len(new_tensors)}")
    print(f"Saving to {dst_path}...")
    save_file(new_tensors, dst_path)
    print("Done!")

    # NOTE(review): updating config.json (num_experts / n_routed_experts)
    # was deliberately left disabled here; the config must be adjusted
    # separately to match num_experts_to_keep.

    # Emit a single-shard index so loaders that expect
    # model.safetensors.index.json still work.
    index_data = {
        "metadata": {
            "total_size": sum(t.numel() * t.element_size() for t in new_tensors.values())
        },
        "weight_map": {key: str(dst_path) for key in new_tensors.keys()}
    }
    index_dst = Path("model.safetensors.index.json")
    with open(index_dst, "w") as f:
        json.dump(index_data, f, indent=2)
    print(f"Saved safetensors index to {index_dst}")
# Script entry point: run the pruning pipeline when executed directly.
if __name__ == "__main__":
    main()