attempt fix and more prints
Browse files- scripts/train.py +22 -7
scripts/train.py
CHANGED
|
@@ -41,7 +41,8 @@ def expand_model_with_small_experts(base_model):
|
|
| 41 |
print("# DEBUG: Expanding model with small experts...")
|
| 42 |
config = base_model.config
|
| 43 |
config.num_small_experts = 64 # Add 64 small experts
|
| 44 |
-
|
|
|
|
| 45 |
expanded_model = MyOlmoeForCausalLM(config)
|
| 46 |
|
| 47 |
base_state_dict = base_model.state_dict()
|
|
@@ -61,11 +62,17 @@ def expand_model_with_small_experts(base_model):
|
|
| 61 |
key = f'model.layers.{i}.mlp.experts.{i}.{proj}.weight'
|
| 62 |
if key in base_state_dict:
|
| 63 |
orig_weight = base_state_dict[key]
|
|
|
|
|
|
|
| 64 |
if proj == 'down_proj':
|
| 65 |
-
|
|
|
|
| 66 |
else:
|
| 67 |
-
|
| 68 |
-
|
|
|
|
|
|
|
|
|
|
| 69 |
else:
|
| 70 |
print(f"# DEBUG: Missing {key} in base model")
|
| 71 |
|
|
@@ -75,9 +82,18 @@ def expand_model_with_small_experts(base_model):
|
|
| 75 |
if gate_key in base_state_dict:
|
| 76 |
original_gate = base_state_dict[gate_key]
|
| 77 |
new_gate = expanded_state_dict[gate_key]
|
|
|
|
|
|
|
| 78 |
new_gate[:, :config.num_experts].copy_(original_gate)
|
| 79 |
-
|
| 80 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
else:
|
| 82 |
print(f"# DEBUG: Missing gate weight {gate_key}")
|
| 83 |
|
|
@@ -85,7 +101,6 @@ def expand_model_with_small_experts(base_model):
|
|
| 85 |
expanded_model.load_state_dict(expanded_state_dict, strict=False)
|
| 86 |
|
| 87 |
return expanded_model
|
| 88 |
-
|
| 89 |
def main():
|
| 90 |
model_path = "myolmoe"
|
| 91 |
print("# DEBUG: Loading base model...")
|
|
|
|
| 41 |
print("# DEBUG: Expanding model with small experts...")
|
| 42 |
config = base_model.config
|
| 43 |
config.num_small_experts = 64 # Add 64 small experts
|
| 44 |
+
# Changed from //16 to //2 for more reasonable size
|
| 45 |
+
config.small_expert_intermediate_size = config.intermediate_size // 2
|
| 46 |
expanded_model = MyOlmoeForCausalLM(config)
|
| 47 |
|
| 48 |
base_state_dict = base_model.state_dict()
|
|
|
|
| 62 |
key = f'model.layers.{i}.mlp.experts.{i}.{proj}.weight'
|
| 63 |
if key in base_state_dict:
|
| 64 |
orig_weight = base_state_dict[key]
|
| 65 |
+
target_weight = expanded_state_dict[key]
|
| 66 |
+
|
| 67 |
if proj == 'down_proj':
|
| 68 |
+
# For down_proj, we copy the first part of the input dimension
|
| 69 |
+
target_weight.copy_(orig_weight[:, :config.small_expert_intermediate_size])
|
| 70 |
else:
|
| 71 |
+
# For gate_proj and up_proj, we copy the first part of the output dimension
|
| 72 |
+
target_weight.copy_(orig_weight[:config.small_expert_intermediate_size, :])
|
| 73 |
+
|
| 74 |
+
print(f"# DEBUG: Copied {proj} weights for expert {i} "
|
| 75 |
+
f"(original shape: {orig_weight.shape}, new shape: {target_weight.shape})")
|
| 76 |
else:
|
| 77 |
print(f"# DEBUG: Missing {key} in base model")
|
| 78 |
|
|
|
|
| 82 |
if gate_key in base_state_dict:
|
| 83 |
original_gate = base_state_dict[gate_key]
|
| 84 |
new_gate = expanded_state_dict[gate_key]
|
| 85 |
+
|
| 86 |
+
# Copy original gate weights
|
| 87 |
new_gate[:, :config.num_experts].copy_(original_gate)
|
| 88 |
+
|
| 89 |
+
# Initialize small experts gate weights
|
| 90 |
+
torch.nn.init.normal_(
|
| 91 |
+
new_gate[:, config.num_experts:],
|
| 92 |
+
mean=0.0,
|
| 93 |
+
std=config.initializer_range * 0.1
|
| 94 |
+
)
|
| 95 |
+
print(f"# DEBUG: Initialized gate for layer {i} "
|
| 96 |
+
f"(original shape: {original_gate.shape}, new shape: {new_gate.shape})")
|
| 97 |
else:
|
| 98 |
print(f"# DEBUG: Missing gate weight {gate_key}")
|
| 99 |
|
|
|
|
| 101 |
expanded_model.load_state_dict(expanded_state_dict, strict=False)
|
| 102 |
|
| 103 |
return expanded_model
|
|
|
|
| 104 |
def main():
|
| 105 |
model_path = "myolmoe"
|
| 106 |
print("# DEBUG: Loading base model...")
|