File size: 3,751 Bytes
068d828
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import os
import json
import re
from pathlib import Path
from safetensors import safe_open
from safetensors.torch import save_file
import torch

def main(
  src_dir: Path = Path("../GLM-4.7-Flash"),
  dst_path: Path = Path("model.safetensors"),
  num_experts_to_keep: int = 2,
) -> None:
  """Write a pruned copy of a sharded MoE checkpoint that keeps only the first experts.

  Every ``*.safetensors`` shard in ``src_dir`` is read and merged into a single
  file at ``dst_path``:

  * expert weights (``model.layers.<L>.mlp.experts.<E>.*``) are kept only when
    ``E < num_experts_to_keep``;
  * router tensors (``...mlp.gate.weight`` and
    ``...mlp.gate.e_score_correction_bias``) are truncated to their first
    ``num_experts_to_keep`` entries along dimension 0;
  * all other tensors are copied unchanged.

  A ``<dst_path>.index.json`` file mapping every saved key to the output file
  is written next to ``dst_path``.

  ``config.json`` is NOT rewritten: the expert-count field of the model config
  (e.g. ``num_experts`` / ``n_routed_experts``) must be updated separately to
  match ``num_experts_to_keep``.

  Args:
    src_dir: Directory containing the source ``*.safetensors`` shards.
    dst_path: Path of the single output safetensors file.
    num_experts_to_keep: Number of leading experts to retain per MoE layer.

  Raises:
    ValueError: If ``num_experts_to_keep`` is smaller than 1.
    FileNotFoundError: If ``src_dir`` contains no ``*.safetensors`` files.
  """
  src_dir = Path(src_dir)
  dst_path = Path(dst_path)
  if num_experts_to_keep < 1:
    raise ValueError(
      f"num_experts_to_keep must be >= 1, got {num_experts_to_keep}"
    )

  # Find all safetensors files
  safetensor_files = sorted(src_dir.glob("*.safetensors"))
  print(f"Found {len(safetensor_files)} safetensors files")
  if not safetensor_files:
    # Without this guard an empty checkpoint would be written silently.
    raise FileNotFoundError(f"No .safetensors files found in {src_dir}")

  # Per-expert weights; group 2 is the expert index.
  expert_pattern = re.compile(r"model\.layers\.(\d+)\.mlp\.experts\.(\d+)\..+")
  # Router tensors whose leading dimension is indexed by expert.
  router_pattern = re.compile(
    r"model\.layers\.(\d+)\.mlp\.gate\.(?:weight|e_score_correction_bias)"
  )

  new_tensors = {}

  for sf_path in safetensor_files:
    print(f"Processing {sf_path.name}...")

    with safe_open(sf_path, framework="pt", device="cpu") as f:
      for key in f.keys():
        expert_match = expert_pattern.search(key)
        if expert_match is not None:
          expert_idx = int(expert_match.group(2))
          if expert_idx >= num_experts_to_keep:
            # Decide from the key alone so dropped experts are never read
            # from disk.
            print(f"  Skipping {key} (expert {expert_idx} >= {num_experts_to_keep})")
            continue
          new_tensors[key] = f.get_tensor(key)
          continue

        tensor = f.get_tensor(key)
        if router_pattern.search(key) is not None:
          original_shape = tensor.shape
          # Keep the first `num_experts_to_keep` rows/entries. clone() gives
          # the slice its own storage so the full-size tensor can be freed.
          tensor = tensor[:num_experts_to_keep].clone()
          print(f"  Resizing {key}: {original_shape} -> {tensor.shape}")

        new_tensors[key] = tensor

  print(f"\nTotal tensors to save: {len(new_tensors)}")
  print(f"Saving to {dst_path}...")

  # The "format" metadata entry is what Hugging Face loaders look for to
  # identify a PyTorch safetensors checkpoint.
  save_file(new_tensors, dst_path, metadata={"format": "pt"})
  print("Done!")

  # Create safetensors index file. Entries use the bare file name because the
  # index is written into the same directory as the weights file.
  index_data = {
    "metadata": {
      "total_size": sum(t.numel() * t.element_size() for t in new_tensors.values())
    },
    "weight_map": {key: dst_path.name for key in new_tensors},
  }

  index_dst = dst_path.with_name(dst_path.name + ".index.json")
  with open(index_dst, "w", encoding="utf-8") as f:
    json.dump(index_data, f, indent=2)
  print(f"Saved safetensors index to {index_dst}")

# Script entry point: run the conversion only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
  main()