import os
import json
import re
from pathlib import Path

import torch

# Tensor-name patterns for the MoE components whose contents depend on the
# number of routed experts.
EXPERT_PATTERN = re.compile(r"model\.layers\.(\d+)\.mlp\.experts\.(\d+)\..+")
GATE_PATTERN = re.compile(r"model\.layers\.(\d+)\.mlp\.gate\.weight")
BIAS_PATTERN = re.compile(r"model\.layers\.(\d+)\.mlp\.gate\.e_score_correction_bias")


def prune_tensor(key: str, tensor: torch.Tensor, num_experts_to_keep: int):
    """Return the tensor to store for *key* after expert pruning, or None to drop it.

    Rules (mirroring the checkpoint layout this script processes):
      - Per-expert weights (``...mlp.experts.<i>...``): kept unchanged when the
        expert index ``i`` is < ``num_experts_to_keep``, dropped (None) otherwise.
      - Router gate weight, shape ``[num_experts, hidden]``: sliced to the first
        ``num_experts_to_keep`` rows.
      - Router ``e_score_correction_bias``, shape ``[num_experts]``: sliced to
        the first ``num_experts_to_keep`` entries.
      - Anything else: returned untouched.
    """
    expert_match = EXPERT_PATTERN.search(key)
    if expert_match:
        expert_idx = int(expert_match.group(2))
        return tensor if expert_idx < num_experts_to_keep else None
    if GATE_PATTERN.search(key):
        return tensor[:num_experts_to_keep, :]
    if BIAS_PATTERN.search(key):
        return tensor[:num_experts_to_keep]
    return tensor


def main(src_dir=Path("../GLM-4.7-Flash"),
         dst_path=Path("model.safetensors"),
         num_experts_to_keep=2):
    """Prune a sharded MoE checkpoint down to its first few routed experts.

    Reads every ``*.safetensors`` shard in *src_dir*, keeps only experts
    ``0 .. num_experts_to_keep - 1`` (slicing the router gate weight and its
    correction bias to match), writes one consolidated safetensors file to
    *dst_path* plus a matching ``model.safetensors.index.json``, and rewrites
    ``config.json`` so the expert count agrees with the pruned weights.
    """
    # Imported lazily so this module can be imported without safetensors installed.
    from safetensors import safe_open
    from safetensors.torch import save_file

    safetensor_files = sorted(src_dir.glob("*.safetensors"))
    print(f"Found {len(safetensor_files)} safetensors files")

    # NOTE: every surviving tensor is held in RAM at once; fine for a pruned
    # model, but large checkpoints may need shard-by-shard output instead.
    new_tensors = {}
    for sf_path in safetensor_files:
        print(f"Processing {sf_path.name}...")
        with safe_open(sf_path, framework="pt", device="cpu") as f:
            for key in f.keys():
                tensor = f.get_tensor(key)
                pruned = prune_tensor(key, tensor, num_experts_to_keep)
                if pruned is None:
                    expert_idx = int(EXPERT_PATTERN.search(key).group(2))
                    print(f"  Skipping {key} (expert {expert_idx} >= {num_experts_to_keep})")
                    continue
                if pruned.shape != tensor.shape:
                    print(f"  Resizing {key}: {tensor.shape} -> {pruned.shape}")
                new_tensors[key] = pruned

    print(f"\nTotal tensors to save: {len(new_tensors)}")
    print(f"Saving to {dst_path}...")
    save_file(new_tensors, dst_path)
    print("Done!")

    # Rewrite config.json so the declared expert count matches the pruned
    # weights (previously the config was loaded but never updated or saved,
    # leaving the checkpoint inconsistent with its config).
    config_src = src_dir / "config.json"
    if config_src.exists():
        with open(config_src, "r") as f:
            config = json.load(f)
        for cfg_key in ("num_experts", "n_routed_experts"):
            if cfg_key in config:
                print(f"Updating {cfg_key}: {config[cfg_key]} -> {num_experts_to_keep}")
                config[cfg_key] = num_experts_to_keep
        config_dst = Path("config.json")
        with open(config_dst, "w") as f:
            json.dump(config, f, indent=2)
        print(f"Saved modified config to {config_dst}")

    # Emit a single-shard index so loaders that expect a sharded layout
    # (weight name -> file) still work.
    total_size = sum(t.numel() * t.element_size() for t in new_tensors.values())
    index_data = {
        "metadata": {"total_size": total_size},
        "weight_map": {key: str(dst_path) for key in new_tensors},
    }
    index_dst = Path("model.safetensors.index.json")
    with open(index_dst, "w") as f:
        json.dump(index_data, f, indent=2)
    print(f"Saved safetensors index to {index_dst}")


if __name__ == "__main__":
    main()