Charlie81 committed on
Commit
a82f934
·
1 Parent(s): 438a56a

attempt fix and more prints

Browse files
Files changed (1) hide show
  1. scripts/train.py +22 -7
scripts/train.py CHANGED
@@ -41,7 +41,8 @@ def expand_model_with_small_experts(base_model):
41
  print("# DEBUG: Expanding model with small experts...")
42
  config = base_model.config
43
  config.num_small_experts = 64 # Add 64 small experts
44
- config.small_expert_intermediate_size = config.intermediate_size // 16 # Half size
 
45
  expanded_model = MyOlmoeForCausalLM(config)
46
 
47
  base_state_dict = base_model.state_dict()
@@ -61,11 +62,17 @@ def expand_model_with_small_experts(base_model):
61
  key = f'model.layers.{i}.mlp.experts.{i}.{proj}.weight'
62
  if key in base_state_dict:
63
  orig_weight = base_state_dict[key]
 
 
64
  if proj == 'down_proj':
65
- expanded_state_dict[key].copy_(orig_weight[:, :config.small_expert_intermediate_size])
 
66
  else:
67
- expanded_state_dict[key].copy_(orig_weight[:config.small_expert_intermediate_size])
68
- print(f"# DEBUG: Copied {proj} weights for expert {i}")
 
 
 
69
  else:
70
  print(f"# DEBUG: Missing {key} in base model")
71
 
@@ -75,9 +82,18 @@ def expand_model_with_small_experts(base_model):
75
  if gate_key in base_state_dict:
76
  original_gate = base_state_dict[gate_key]
77
  new_gate = expanded_state_dict[gate_key]
 
 
78
  new_gate[:, :config.num_experts].copy_(original_gate)
79
- torch.nn.init.normal_(new_gate[:, config.num_experts:], mean=0.0, std=config.initializer_range * 0.1)
80
- print(f"# DEBUG: Initialized gate for layer {i}")
 
 
 
 
 
 
 
81
  else:
82
  print(f"# DEBUG: Missing gate weight {gate_key}")
83
 
@@ -85,7 +101,6 @@ def expand_model_with_small_experts(base_model):
85
  expanded_model.load_state_dict(expanded_state_dict, strict=False)
86
 
87
  return expanded_model
88
-
89
  def main():
90
  model_path = "myolmoe"
91
  print("# DEBUG: Loading base model...")
 
41
  print("# DEBUG: Expanding model with small experts...")
42
  config = base_model.config
43
  config.num_small_experts = 64 # Add 64 small experts
44
+ # Changed from //16 to //2 for more reasonable size
45
+ config.small_expert_intermediate_size = config.intermediate_size // 2
46
  expanded_model = MyOlmoeForCausalLM(config)
47
 
48
  base_state_dict = base_model.state_dict()
 
62
  key = f'model.layers.{i}.mlp.experts.{i}.{proj}.weight'
63
  if key in base_state_dict:
64
  orig_weight = base_state_dict[key]
65
+ target_weight = expanded_state_dict[key]
66
+
67
  if proj == 'down_proj':
68
+ # For down_proj, we copy the first part of the input dimension
69
+ target_weight.copy_(orig_weight[:, :config.small_expert_intermediate_size])
70
  else:
71
+ # For gate_proj and up_proj, we copy the first part of the output dimension
72
+ target_weight.copy_(orig_weight[:config.small_expert_intermediate_size, :])
73
+
74
+ print(f"# DEBUG: Copied {proj} weights for expert {i} "
75
+ f"(original shape: {orig_weight.shape}, new shape: {target_weight.shape})")
76
  else:
77
  print(f"# DEBUG: Missing {key} in base model")
78
 
 
82
  if gate_key in base_state_dict:
83
  original_gate = base_state_dict[gate_key]
84
  new_gate = expanded_state_dict[gate_key]
85
+
86
+ # Copy original gate weights
87
  new_gate[:, :config.num_experts].copy_(original_gate)
88
+
89
+ # Initialize small experts gate weights
90
+ torch.nn.init.normal_(
91
+ new_gate[:, config.num_experts:],
92
+ mean=0.0,
93
+ std=config.initializer_range * 0.1
94
+ )
95
+ print(f"# DEBUG: Initialized gate for layer {i} "
96
+ f"(original shape: {original_gate.shape}, new shape: {new_gate.shape})")
97
  else:
98
  print(f"# DEBUG: Missing gate weight {gate_key}")
99
 
 
101
  expanded_model.load_state_dict(expanded_state_dict, strict=False)
102
 
103
  return expanded_model
 
104
  def main():
105
  model_path = "myolmoe"
106
  print("# DEBUG: Loading base model...")