|
|
import os |
|
|
import json |
|
|
import re |
|
|
from pathlib import Path |
|
|
from safetensors import safe_open |
|
|
from safetensors.torch import save_file |
|
|
import torch |
|
|
|
|
|
# Key patterns for the three tensor families affected by expert pruning.
# Per-expert weights: model.layers.<L>.mlp.experts.<E>.<rest>
_EXPERT_RE = re.compile(r"model\.layers\.(\d+)\.mlp\.experts\.(\d+)\..+")
# Router gate weight; rows index experts, so it is truncated along dim 0.
_GATE_RE = re.compile(r"model\.layers\.(\d+)\.mlp\.gate\.weight")
# Per-expert routing-score correction bias; 1-D over experts.
_BIAS_RE = re.compile(r"model\.layers\.(\d+)\.mlp\.gate\.e_score_correction_bias")


def main(
    src_dir=Path("../GLM-4.7-Flash"),
    dst_path=Path("model.safetensors"),
    num_experts_to_keep=2,
):
    """Prune a sharded MoE safetensors checkpoint to its first N routed experts.

    Reads every ``*.safetensors`` shard in *src_dir*, drops per-expert tensors
    whose expert index is >= *num_experts_to_keep*, truncates the router gate
    weight and its e-score correction bias to match, and writes a single
    consolidated file at *dst_path* plus a matching
    ``model.safetensors.index.json``.  If ``config.json`` exists in *src_dir*,
    its routed-expert count is updated to match and written to the current
    directory.

    Args:
        src_dir: Directory containing the source checkpoint shards.
        dst_path: Destination path for the merged, pruned safetensors file.
        num_experts_to_keep: Number of leading routed experts to retain.
    """
    src_dir = Path(src_dir)
    dst_path = Path(dst_path)

    shard_paths = sorted(src_dir.glob("*.safetensors"))
    print(f"Found {len(shard_paths)} safetensors files")

    new_tensors = _prune_tensors(shard_paths, num_experts_to_keep)

    print(f"\nTotal tensors to save: {len(new_tensors)}")
    print(f"Saving to {dst_path}...")
    # str() for compatibility with safetensors versions that predate PathLike
    # support in save_file.
    save_file(new_tensors, str(dst_path))
    print("Done!")

    _update_config(src_dir, num_experts_to_keep)
    _write_index(new_tensors, dst_path)


def _prune_tensors(shard_paths, num_experts_to_keep):
    """Load all shards and return {key: tensor} with experts pruned.

    Expert tensors with index >= num_experts_to_keep are dropped; the gate
    weight and its correction bias are sliced down to the kept experts; every
    other tensor (attention, norms, embeddings, dense MLP) passes through
    unchanged.
    """
    new_tensors = {}
    for sf_path in shard_paths:
        print(f"Processing {sf_path.name}...")
        with safe_open(sf_path, framework="pt", device="cpu") as f:
            for key in f.keys():
                tensor = f.get_tensor(key)

                expert_match = _EXPERT_RE.search(key)
                if expert_match:
                    expert_idx = int(expert_match.group(2))
                    # Drop experts outside the kept range entirely.
                    if expert_idx >= num_experts_to_keep:
                        print(f" Skipping {key} (expert {expert_idx} >= {num_experts_to_keep})")
                        continue
                    new_tensors[key] = tensor
                    continue

                if _GATE_RE.search(key):
                    # Gate weight is (num_experts, hidden); keep the leading
                    # rows so routing logits align with the kept experts.
                    original_shape = tensor.shape
                    new_tensor = tensor[:num_experts_to_keep, :]
                    print(f" Resizing {key}: {original_shape} -> {new_tensor.shape}")
                    new_tensors[key] = new_tensor
                    continue

                if _BIAS_RE.search(key):
                    # Correction bias is 1-D over experts; truncate to match.
                    original_shape = tensor.shape
                    new_tensor = tensor[:num_experts_to_keep]
                    print(f" Resizing {key}: {original_shape} -> {new_tensor.shape}")
                    new_tensors[key] = new_tensor
                    continue

                # Non-expert tensor: keep as-is.
                new_tensors[key] = tensor
    return new_tensors


def _update_config(src_dir, num_experts_to_keep):
    """Sync the expert count in config.json with the pruned checkpoint.

    The original script loaded the config and then discarded it; here the
    routed-expert count is updated and the config written next to the pruned
    weights.  NOTE(review): the field name varies by architecture
    (GLM/DeepSeek-style MoE configs use "n_routed_experts") — confirm against
    the actual source config; unknown fields are left untouched.
    """
    config_src = src_dir / "config.json"
    if not config_src.exists():
        return
    with open(config_src, "r") as f:
        config = json.load(f)
    # Update whichever expert-count field(s) this config actually defines.
    for field in ("n_routed_experts", "num_experts", "num_local_experts"):
        if field in config:
            config[field] = num_experts_to_keep
    config_dst = Path("config.json")
    with open(config_dst, "w") as f:
        json.dump(config, f, indent=2)
    print(f"Saved updated config to {config_dst}")


def _write_index(new_tensors, dst_path):
    """Write model.safetensors.index.json mapping every key to dst_path."""
    index_data = {
        "metadata": {
            # Total byte size of all saved tensors (numel * bytes-per-element).
            "total_size": sum(t.numel() * t.element_size() for t in new_tensors.values())
        },
        "weight_map": {key: str(dst_path) for key in new_tensors},
    }
    index_dst = Path("model.safetensors.index.json")
    with open(index_dst, "w") as f:
        json.dump(index_data, f, indent=2)
    print(f"Saved safetensors index to {index_dst}")
|
|
|
|
|
# Entry point: run the pruning script when executed directly
# (no side effects on import).
if __name__ == "__main__":
    main()
|
|
|