overhaul
Browse files- myolmoe/config.json +4 -1
- myolmoe/modeling_myolmoe.py +145 -14
- scripts/train.py +60 -148
myolmoe/config.json
CHANGED
|
@@ -30,5 +30,8 @@
|
|
| 30 |
"torch_dtype": "float32",
|
| 31 |
"transformers_version": "4.52.4",
|
| 32 |
"use_cache": true,
|
| 33 |
-
"vocab_size": 50304
|
|
|
|
|
|
|
|
|
|
| 34 |
}
|
|
|
|
| 30 |
"torch_dtype": "float32",
|
| 31 |
"transformers_version": "4.52.4",
|
| 32 |
"use_cache": true,
|
| 33 |
+
"vocab_size": 50304,
|
| 34 |
+
"small_expert_intermediate_ratio": 0.5,
|
| 35 |
+
"small_expert_frequency": 4,
|
| 36 |
+
"small_expert_load_balancing_coef": 0.1
|
| 37 |
}
|
myolmoe/modeling_myolmoe.py
CHANGED
|
@@ -14,7 +14,103 @@ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_u
|
|
| 14 |
from transformers.modeling_utils import PreTrainedModel
|
| 15 |
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
|
| 16 |
from transformers.utils import logging
|
| 17 |
-
from transformers.models.olmoe.configuration_olmoe import OlmoeConfig
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
logger = logging.get_logger(__name__)
|
| 20 |
|
|
@@ -143,21 +239,25 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
|
| 143 |
|
| 144 |
|
| 145 |
class OlmoeMLP(nn.Module):
|
| 146 |
-
def __init__(self, config):
|
| 147 |
super().__init__()
|
| 148 |
self.config = config
|
| 149 |
self.hidden_size = config.hidden_size
|
| 150 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 152 |
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 153 |
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
| 154 |
self.act_fn = ACT2FN[config.hidden_act]
|
|
|
|
| 155 |
|
| 156 |
def forward(self, x):
|
| 157 |
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
| 158 |
return down_proj
|
| 159 |
|
| 160 |
-
|
| 161 |
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
| 162 |
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
| 163 |
if n_rep == 1:
|
|
@@ -446,6 +546,7 @@ OLMOE_ATTENTION_CLASSES = {
|
|
| 446 |
}
|
| 447 |
|
| 448 |
|
|
|
|
| 449 |
class OlmoeSparseMoeBlock(nn.Module):
|
| 450 |
def __init__(self, config, layer_idx: int):
|
| 451 |
super().__init__()
|
|
@@ -453,10 +554,21 @@ class OlmoeSparseMoeBlock(nn.Module):
|
|
| 453 |
self.num_experts = config.num_experts
|
| 454 |
self.top_k = config.num_experts_per_tok
|
| 455 |
self.norm_topk_prob = config.norm_topk_prob
|
| 456 |
-
self.routing_type = getattr(config, "routing_type", "topk")
|
| 457 |
-
self.n_step = getattr(config, "nth_step", 2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 458 |
self.gate = nn.Linear(config.hidden_size, self.num_experts, bias=False)
|
| 459 |
-
self.
|
| 460 |
|
| 461 |
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 462 |
batch_size, sequence_length, hidden_dim = hidden_states.shape
|
|
@@ -464,7 +576,6 @@ class OlmoeSparseMoeBlock(nn.Module):
|
|
| 464 |
router_logits = self.gate(hidden_states)
|
| 465 |
routing_probs = F.softmax(router_logits, dim=1, dtype=torch.float)
|
| 466 |
|
| 467 |
-
# === Routing ===
|
| 468 |
routing_weights, selected_experts = torch.topk(routing_probs, self.top_k, dim=-1)
|
| 469 |
|
| 470 |
if self.norm_topk_prob:
|
|
@@ -479,6 +590,18 @@ class OlmoeSparseMoeBlock(nn.Module):
|
|
| 479 |
|
| 480 |
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
|
| 481 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 482 |
for expert_idx in range(self.num_experts):
|
| 483 |
expert_layer = self.experts[expert_idx]
|
| 484 |
idx, top_x = torch.where(expert_mask[expert_idx])
|
|
@@ -489,8 +612,7 @@ class OlmoeSparseMoeBlock(nn.Module):
|
|
| 489 |
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
|
| 490 |
|
| 491 |
final_hidden_states = final_hidden_states.view(batch_size, sequence_length, hidden_dim)
|
| 492 |
-
return final_hidden_states, router_logits
|
| 493 |
-
|
| 494 |
|
| 495 |
class OlmoeDecoderLayer(nn.Module):
|
| 496 |
def __init__(self, config: OlmoeConfig, layer_idx: int):
|
|
@@ -536,9 +658,9 @@ class OlmoeDecoderLayer(nn.Module):
|
|
| 536 |
hidden_states = residual + hidden_states
|
| 537 |
residual = hidden_states
|
| 538 |
hidden_states = self.post_attention_layernorm(hidden_states)
|
| 539 |
-
hidden_states, router_logits = self.mlp(hidden_states)
|
| 540 |
-
hidden_states = residual + hidden_states
|
| 541 |
-
outputs = (hidden_states,)
|
| 542 |
if output_attentions:
|
| 543 |
outputs += (self_attn_weights,)
|
| 544 |
if use_cache:
|
|
@@ -942,6 +1064,15 @@ class MyOlmoeForCausalLM(OlmoePreTrainedModel, GenerationMixin):
|
|
| 942 |
if output_router_logits:
|
| 943 |
output = (aux_loss,) + output
|
| 944 |
return (loss,) + output if loss is not None else output
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 945 |
return MoeCausalLMOutputWithPast(
|
| 946 |
loss=loss,
|
| 947 |
aux_loss=aux_loss,
|
|
@@ -952,4 +1083,4 @@ class MyOlmoeForCausalLM(OlmoePreTrainedModel, GenerationMixin):
|
|
| 952 |
router_logits=outputs.router_logits,
|
| 953 |
)
|
| 954 |
|
| 955 |
-
__all__ = ["MyOlmoeForCausalLM", "OlmoeModel", "OlmoePreTrainedModel"]
|
|
|
|
| 14 |
from transformers.modeling_utils import PreTrainedModel
|
| 15 |
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
|
| 16 |
from transformers.utils import logging
|
| 17 |
+
# from transformers.models.olmoe.configuration_olmoe import OlmoeConfig
|
| 18 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 19 |
+
from transformers.modeling_rope_utils import rope_config_validation
|
| 20 |
+
|
| 21 |
+
class OlmoeConfig(PretrainedConfig):
    r"""
    Configuration class for the OLMoE model defined in this file.

    It stores the architecture hyper-parameters and, in addition to the usual
    OLMoE settings, three options that control the "small expert" variant in
    which some experts use a reduced intermediate width.

    Args:
        vocab_size (`int`, *optional*, defaults to 50304):
            Size of the token vocabulary.
        hidden_size (`int`, *optional*, defaults to 2048):
            Width of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 2048):
            Intermediate width of a regular (non-small) expert MLP.
        num_hidden_layers (`int`, *optional*, defaults to 16):
            Number of decoder layers.
        num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads per layer.
        num_key_value_heads (`int`, *optional*):
            Number of key/value heads. When `None`, falls back to
            `num_attention_heads`.
        hidden_act (`str`, *optional*, defaults to `"silu"`):
            Name of the activation function used inside the expert MLPs.
        max_position_embeddings (`int`, *optional*, defaults to 4096):
            Maximum sequence length the model is configured for.
        initializer_range (`float`, *optional*, defaults to 0.02):
            Standard deviation used for weight initialisation.
        rms_norm_eps (`float`, *optional*, defaults to 1e-05):
            Epsilon used by the RMS normalisation layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether the model should return key/value caches.
        pad_token_id (`int`, *optional*, defaults to 1):
            Padding token id, forwarded to `PretrainedConfig`.
        bos_token_id (`int`, *optional*):
            Beginning-of-sequence token id, forwarded to `PretrainedConfig`.
        eos_token_id (`int`, *optional*, defaults to 50279):
            End-of-sequence token id, forwarded to `PretrainedConfig`.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether input and output embeddings are tied; forwarded to
            `PretrainedConfig`.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            Base period of the rotary position embeddings.
        rope_scaling (`dict`, *optional*):
            Rotary-embedding scaling configuration. A legacy `"type"` key is
            copied to `"rope_type"` (the dict is modified in place) before
            the configuration is validated.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether the attention projections use a bias term.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            Dropout probability applied to attention weights.
        clip_qkv (`float`, *optional*):
            Optional clipping value for query/key/value activations.
        num_experts_per_tok (`int`, *optional*, defaults to 8):
            Number of experts selected for each token (router top-k).
        num_experts (`int`, *optional*, defaults to 64):
            Total number of routed experts per MoE block.
        output_router_logits (`bool`, *optional*, defaults to `False`):
            Whether the model should return router logits.
        router_aux_loss_coef (`float`, *optional*, defaults to 0.01):
            Coefficient of the router auxiliary loss.
        norm_topk_prob (`bool`, *optional*, defaults to `False`):
            Whether the selected top-k routing probabilities are renormalised.
        small_expert_intermediate_ratio (`float`, *optional*, defaults to 0.5):
            Ratio of intermediate size for small experts compared to regular
            experts. Must be strictly positive.
        small_expert_frequency (`int`, *optional*, defaults to 4):
            Frequency of small experts - every Nth expert will be small.
            Must be at least 1 because it is used as a modulus.
        small_expert_load_balancing_coef (`float`, *optional*, defaults to 0.1):
            Coefficient for small expert load balancing loss.

    Raises:
        ValueError: If `small_expert_frequency` is smaller than 1 or
            `small_expert_intermediate_ratio` is not strictly positive.
    """

    model_type = "olmoe"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=50304,
        hidden_size=2048,
        intermediate_size=2048,
        num_hidden_layers=16,
        num_attention_heads=16,
        num_key_value_heads=None,
        hidden_act="silu",
        max_position_embeddings=4096,
        initializer_range=0.02,
        rms_norm_eps=1e-05,
        use_cache=True,
        pad_token_id=1,
        bos_token_id=None,
        eos_token_id=50279,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        clip_qkv=None,
        num_experts_per_tok=8,
        num_experts=64,
        output_router_logits=False,
        router_aux_loss_coef=0.01,
        norm_topk_prob=False,
        small_expert_intermediate_ratio=0.5,
        small_expert_frequency=4,
        small_expert_load_balancing_coef=0.1,
        **kwargs,
    ):
        # Fail early with a clear message: the frequency is later used as a
        # modulus when deciding which experts are small, and the ratio scales
        # the intermediate width of those experts.
        if small_expert_frequency < 1:
            raise ValueError(
                f"`small_expert_frequency` must be >= 1, got {small_expert_frequency}."
            )
        if small_expert_intermediate_ratio <= 0:
            raise ValueError(
                "`small_expert_intermediate_ratio` must be > 0, "
                f"got {small_expert_intermediate_ratio}."
            )

        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.clip_qkv = clip_qkv
        self.num_experts_per_tok = num_experts_per_tok
        self.num_experts = num_experts
        self.output_router_logits = output_router_logits
        self.router_aux_loss_coef = router_aux_loss_coef
        self.norm_topk_prob = norm_topk_prob

        # Small expert parameters
        self.small_expert_intermediate_ratio = small_expert_intermediate_ratio
        self.small_expert_frequency = small_expert_frequency
        self.small_expert_load_balancing_coef = small_expert_load_balancing_coef

        # Validate the correctness of rotary position embeddings parameters.
        # Older configs store the scaling kind under "type"; mirror it to the
        # "rope_type" key before validating.
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
        rope_config_validation(self)

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
|
| 114 |
|
| 115 |
logger = logging.get_logger(__name__)
|
| 116 |
|
|
|
|
| 239 |
|
| 240 |
|
| 241 |
class OlmoeMLP(nn.Module):
    """Gated feed-forward expert: ``down_proj(act_fn(gate_proj(x)) * up_proj(x))``.

    Args:
        config: Model configuration. Reads ``hidden_size``,
            ``intermediate_size`` and ``hidden_act``; additionally reads
            ``small_expert_intermediate_ratio`` when ``is_small`` is true.
        is_small: When true, the intermediate width is
            ``int(config.intermediate_size * config.small_expert_intermediate_ratio)``
            instead of ``config.intermediate_size``.
    """

    def __init__(self, config, is_small=False):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size

        # Start from the regular width and shrink it only for small experts;
        # int() truncates the scaled value toward zero.
        width = config.intermediate_size
        if is_small:
            width = int(width * config.small_expert_intermediate_ratio)
        self.intermediate_size = width

        # All three projections are bias-free linear layers.
        self.gate_proj = nn.Linear(self.hidden_size, width, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, width, bias=False)
        self.down_proj = nn.Linear(width, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]
        self.is_small = is_small

    def forward(self, x):
        """Apply the gated MLP to ``x`` and return the down-projected result."""
        gate = self.act_fn(self.gate_proj(x))
        up = self.up_proj(x)
        return self.down_proj(gate * up)
|
| 260 |
|
|
|
|
| 261 |
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
| 262 |
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
| 263 |
if n_rep == 1:
|
|
|
|
| 546 |
}
|
| 547 |
|
| 548 |
|
| 549 |
+
|
| 550 |
class OlmoeSparseMoeBlock(nn.Module):
|
| 551 |
def __init__(self, config, layer_idx: int):
|
| 552 |
super().__init__()
|
|
|
|
| 554 |
self.num_experts = config.num_experts
|
| 555 |
self.top_k = config.num_experts_per_tok
|
| 556 |
self.norm_topk_prob = config.norm_topk_prob
|
| 557 |
+
self.routing_type = getattr(config, "routing_type", "topk")
|
| 558 |
+
self.n_step = getattr(config, "nth_step", 2)
|
| 559 |
+
|
| 560 |
+
# Track which experts are small
|
| 561 |
+
self.small_expert_indices = []
|
| 562 |
+
self.experts = nn.ModuleList()
|
| 563 |
+
|
| 564 |
+
for i in range(self.num_experts):
|
| 565 |
+
is_small = (i % config.small_expert_frequency == 0)
|
| 566 |
+
if is_small:
|
| 567 |
+
self.small_expert_indices.append(i)
|
| 568 |
+
self.experts.append(OlmoeMLP(config, is_small=is_small))
|
| 569 |
+
|
| 570 |
self.gate = nn.Linear(config.hidden_size, self.num_experts, bias=False)
|
| 571 |
+
self.small_expert_load_balancing_coef = config.small_expert_load_balancing_coef
|
| 572 |
|
| 573 |
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 574 |
batch_size, sequence_length, hidden_dim = hidden_states.shape
|
|
|
|
| 576 |
router_logits = self.gate(hidden_states)
|
| 577 |
routing_probs = F.softmax(router_logits, dim=1, dtype=torch.float)
|
| 578 |
|
|
|
|
| 579 |
routing_weights, selected_experts = torch.topk(routing_probs, self.top_k, dim=-1)
|
| 580 |
|
| 581 |
if self.norm_topk_prob:
|
|
|
|
| 590 |
|
| 591 |
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
|
| 592 |
|
| 593 |
+
# Calculate small expert load balancing loss
|
| 594 |
+
small_expert_mask = torch.zeros_like(expert_mask)
|
| 595 |
+
for idx in self.small_expert_indices:
|
| 596 |
+
small_expert_mask[idx] = expert_mask[idx]
|
| 597 |
+
|
| 598 |
+
small_expert_loss = load_balancing_loss_func(
|
| 599 |
+
router_logits,
|
| 600 |
+
self.num_experts,
|
| 601 |
+
self.top_k,
|
| 602 |
+
None
|
| 603 |
+
) * self.small_expert_load_balancing_coef
|
| 604 |
+
|
| 605 |
for expert_idx in range(self.num_experts):
|
| 606 |
expert_layer = self.experts[expert_idx]
|
| 607 |
idx, top_x = torch.where(expert_mask[expert_idx])
|
|
|
|
| 612 |
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
|
| 613 |
|
| 614 |
final_hidden_states = final_hidden_states.view(batch_size, sequence_length, hidden_dim)
|
| 615 |
+
return final_hidden_states, router_logits, small_expert_loss
|
|
|
|
| 616 |
|
| 617 |
class OlmoeDecoderLayer(nn.Module):
|
| 618 |
def __init__(self, config: OlmoeConfig, layer_idx: int):
|
|
|
|
| 658 |
hidden_states = residual + hidden_states
|
| 659 |
residual = hidden_states
|
| 660 |
hidden_states = self.post_attention_layernorm(hidden_states)
|
| 661 |
+
hidden_states, router_logits, small_expert_loss = self.mlp(hidden_states) #
|
| 662 |
+
hidden_states = residual + hidden_states #
|
| 663 |
+
outputs = (hidden_states, small_expert_loss) #
|
| 664 |
if output_attentions:
|
| 665 |
outputs += (self_attn_weights,)
|
| 666 |
if use_cache:
|
|
|
|
| 1064 |
if output_router_logits:
|
| 1065 |
output = (aux_loss,) + output
|
| 1066 |
return (loss,) + output if loss is not None else output
|
| 1067 |
+
#
|
| 1068 |
+
total_small_expert_loss = 0
|
| 1069 |
+
for layer_output in outputs:
|
| 1070 |
+
if len(layer_output) > 1 and isinstance(layer_output[1], torch.Tensor):
|
| 1071 |
+
total_small_expert_loss += layer_output[1]
|
| 1072 |
+
|
| 1073 |
+
if labels is not None:
|
| 1074 |
+
loss += total_small_expert_loss.to(loss.device)
|
| 1075 |
+
#
|
| 1076 |
return MoeCausalLMOutputWithPast(
|
| 1077 |
loss=loss,
|
| 1078 |
aux_loss=aux_loss,
|
|
|
|
| 1083 |
router_logits=outputs.router_logits,
|
| 1084 |
)
|
| 1085 |
|
| 1086 |
+
__all__ = ["MyOlmoeForCausalLM", "OlmoeModel", "OlmoePreTrainedModel", "OlmoeConfig"]
|
scripts/train.py
CHANGED
|
@@ -1,170 +1,82 @@
|
|
| 1 |
-
|
| 2 |
import torch
|
| 3 |
-
from
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
from datasets import load_dataset
|
| 5 |
-
from myolmoe
|
| 6 |
-
from torch.utils.data import Dataset
|
| 7 |
import os
|
| 8 |
-
from tqdm import tqdm
|
| 9 |
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
def __getitem__(self, idx):
|
| 21 |
-
item = self.dataset[idx]
|
| 22 |
-
text = item["text"] # Adjust based on your dataset structure
|
| 23 |
-
encoding = self.tokenizer(
|
| 24 |
-
text,
|
| 25 |
-
max_length=self.max_length,
|
| 26 |
-
padding="max_length",
|
| 27 |
-
truncation=True,
|
| 28 |
-
return_tensors="pt"
|
| 29 |
-
)
|
| 30 |
-
# DEBUG: Print the first few token IDs for inspection
|
| 31 |
-
if idx == 0:
|
| 32 |
-
print(f"# DEBUG: Sample input text: {text[:100]}")
|
| 33 |
-
print(f"# DEBUG: Tokenized input_ids[:10]: {encoding['input_ids'][0][:10]}")
|
| 34 |
-
return {
|
| 35 |
-
"input_ids": encoding["input_ids"].squeeze(),
|
| 36 |
-
"attention_mask": encoding["attention_mask"].squeeze(),
|
| 37 |
-
"labels": encoding["input_ids"].squeeze()
|
| 38 |
-
}
|
| 39 |
-
|
| 40 |
-
def expand_model_with_small_experts(base_model):
|
| 41 |
-
print("# DEBUG: Expanding model with small experts...")
|
| 42 |
-
config = base_model.config
|
| 43 |
-
config.num_small_experts = 64 # Add 64 small experts
|
| 44 |
-
config.small_expert_intermediate_size = config.intermediate_size // 32
|
| 45 |
-
expanded_model = MyOlmoeForCausalLM(config)
|
| 46 |
-
|
| 47 |
-
base_state_dict = base_model.state_dict()
|
| 48 |
-
expanded_state_dict = expanded_model.state_dict()
|
| 49 |
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
if name in expanded_state_dict:
|
| 54 |
-
expanded_state_dict[name].copy_(param)
|
| 55 |
-
else:
|
| 56 |
-
print(f"# DEBUG: Skipped non-expert param {name} (not found in expanded model)")
|
| 57 |
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
for proj in ['gate_proj', 'up_proj', 'down_proj']:
|
| 61 |
-
key = f'model.layers.{i}.mlp.experts.{i}.{proj}.weight'
|
| 62 |
-
if key in base_state_dict:
|
| 63 |
-
orig_weight = base_state_dict[key]
|
| 64 |
-
target_weight = expanded_state_dict[key]
|
| 65 |
-
|
| 66 |
-
if proj == 'down_proj':
|
| 67 |
-
# For down_proj, we copy the first part of the input dimension
|
| 68 |
-
target_weight.copy_(orig_weight[:, :config.small_expert_intermediate_size])
|
| 69 |
-
else:
|
| 70 |
-
# For gate_proj and up_proj, we copy the first part of the output dimension
|
| 71 |
-
target_weight.copy_(orig_weight[:config.small_expert_intermediate_size, :])
|
| 72 |
-
|
| 73 |
-
print(f"# DEBUG: Copied {proj} weights for expert {i} "
|
| 74 |
-
f"(original shape: {orig_weight.shape}, new shape: {target_weight.shape})")
|
| 75 |
-
else:
|
| 76 |
-
print(f"# DEBUG: Missing {key} in base model")
|
| 77 |
-
|
| 78 |
-
print("# DEBUG: Expanding and initializing gate weights...")
|
| 79 |
-
for i in range(config.num_hidden_layers):
|
| 80 |
-
gate_key = f'model.layers.{i}.mlp.gate.weight'
|
| 81 |
-
if gate_key in base_state_dict:
|
| 82 |
-
original_gate = base_state_dict[gate_key]
|
| 83 |
-
new_gate = expanded_state_dict[gate_key]
|
| 84 |
-
|
| 85 |
-
# Copy original gate weights
|
| 86 |
-
new_gate[:, :config.num_experts].copy_(original_gate)
|
| 87 |
-
|
| 88 |
-
# Initialize small experts gate weights
|
| 89 |
-
torch.nn.init.normal_(
|
| 90 |
-
new_gate[:, config.num_experts:],
|
| 91 |
-
mean=0.0,
|
| 92 |
-
std=config.initializer_range * 0.1
|
| 93 |
-
)
|
| 94 |
-
print(f"# DEBUG: Initialized gate for layer {i} "
|
| 95 |
-
f"(original shape: {original_gate.shape}, new shape: {new_gate.shape})")
|
| 96 |
-
else:
|
| 97 |
-
print(f"# DEBUG: Missing gate weight {gate_key}")
|
| 98 |
-
|
| 99 |
-
print("# DEBUG: Loading expanded state dict into model...")
|
| 100 |
-
expanded_model.load_state_dict(expanded_state_dict, strict=False)
|
| 101 |
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
|
|
|
|
|
|
| 107 |
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
|
|
|
|
|
|
| 112 |
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
print("# DEBUG: Loading tokenizer and preparing dataset...")
|
| 116 |
-
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
| 117 |
-
dataset = CustomDataset(tokenizer)
|
| 118 |
-
|
| 119 |
-
print("# DEBUG: Setting up training arguments...")
|
| 120 |
training_args = TrainingArguments(
|
| 121 |
output_dir="./output",
|
| 122 |
-
per_device_train_batch_size=
|
| 123 |
gradient_accumulation_steps=8,
|
| 124 |
-
learning_rate=1e-
|
| 125 |
-
num_train_epochs=
|
| 126 |
logging_dir="./logs",
|
| 127 |
-
|
| 128 |
save_steps=1000,
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
fp16=True,
|
| 132 |
gradient_checkpointing=True,
|
| 133 |
-
report_to="tensorboard"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
)
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
self.freeze_existing = kwargs.pop('freeze_existing_experts', False)
|
| 139 |
-
super().__init__(*args, **kwargs)
|
| 140 |
-
|
| 141 |
-
if self.freeze_existing:
|
| 142 |
-
print("# DEBUG: Freezing original expert parameters...")
|
| 143 |
-
frozen_count = 0
|
| 144 |
-
for name, param in self.model.named_parameters():
|
| 145 |
-
if "mlp.experts" in name and "small_experts" not in name:
|
| 146 |
-
param.requires_grad = False
|
| 147 |
-
frozen_count += 1
|
| 148 |
-
print(f"# DEBUG: Total frozen expert parameters: {frozen_count}")
|
| 149 |
-
|
| 150 |
-
print("# DEBUG: Initializing trainer...")
|
| 151 |
-
trainer = MoETrainer(
|
| 152 |
model=model,
|
| 153 |
args=training_args,
|
| 154 |
-
train_dataset=
|
| 155 |
-
|
| 156 |
-
|
| 157 |
)
|
| 158 |
-
|
| 159 |
-
|
| 160 |
trainer.train()
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
print(f"# DEBUG: Saving final model to {output_dir}...")
|
| 165 |
-
model.save_pretrained(output_dir)
|
| 166 |
-
tokenizer.save_pretrained(output_dir)
|
| 167 |
-
print("# DEBUG: Training complete!")
|
| 168 |
|
| 169 |
if __name__ == "__main__":
|
| 170 |
-
main()
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
import torch
|
| 3 |
+
from torch.utils.data import DataLoader
|
| 4 |
+
from transformers import (
|
| 5 |
+
AutoTokenizer,
|
| 6 |
+
TrainingArguments,
|
| 7 |
+
Trainer,
|
| 8 |
+
default_data_collator,
|
| 9 |
+
)
|
| 10 |
from datasets import load_dataset
|
| 11 |
+
from myolmoe import MyOlmoeForCausalLM, OlmoeConfig
|
|
|
|
| 12 |
import os
|
|
|
|
| 13 |
|
| 14 |
+
def main():
    """Fine-tune the local ``myolmoe`` checkpoint on the Tulu v2 SFT mixture.

    Steps: load config, model and tokenizer from the ``myolmoe`` directory,
    tokenize the dataset to fixed-length sequences with causal-LM labels,
    train with the Hugging Face ``Trainer`` and save the result to
    ``./final_model``.
    """
    # Load config and model
    config = OlmoeConfig.from_pretrained("myolmoe/config.json")
    model = MyOlmoeForCausalLM.from_pretrained(
        "myolmoe",
        config=config,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )

    # Load tokenizer; reuse EOS as the padding token so that
    # padding="max_length" below has a token to pad with.
    tokenizer = AutoTokenizer.from_pretrained("myolmoe")
    tokenizer.pad_token = tokenizer.eos_token

    # Load dataset
    dataset = load_dataset("allenai/tulu-v2-sft-mixture", split="train")

    def tokenize_function(examples):
        """Tokenize a batch and attach causal-LM labels.

        Uses the ``text`` column when present. Otherwise each conversation in
        ``messages`` is flattened into one string of ``<|role|>`` / content
        turns (NOTE(review): assumes each message is a dict with ``role`` and
        ``content`` keys — confirm against the dataset card).
        """
        if "text" in examples:
            texts = examples["text"]
        else:
            texts = [
                "\n".join(
                    f"<|{turn['role']}|>\n{turn['content']}" for turn in conversation
                )
                for conversation in examples["messages"]
            ]

        batch = tokenizer(
            texts,
            truncation=True,
            max_length=4096,
            padding="max_length",
        )
        # Without a "labels" column the collated batch carries no targets and
        # the model cannot return a training loss. Labels mirror input_ids;
        # padded positions are set to -100 so they are ignored by the loss.
        # The attention mask is used (not the pad id) because pad == eos here.
        batch["labels"] = [
            [token if keep else -100 for token, keep in zip(ids, mask)]
            for ids, mask in zip(batch["input_ids"], batch["attention_mask"])
        ]
        return batch

    tokenized_dataset = dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=dataset.column_names,
        num_proc=4,
    )

    # Training arguments
    training_args = TrainingArguments(
        output_dir="./output",
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,
        learning_rate=1e-5,
        num_train_epochs=1,
        logging_dir="./logs",
        logging_steps=10,
        save_steps=1000,
        save_total_limit=2,
        bf16=True,
        gradient_checkpointing=True,
        report_to="tensorboard",
        optim="adamw_torch",
        lr_scheduler_type="cosine",
        warmup_ratio=0.1,
        max_grad_norm=1.0,
    )

    # Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
        tokenizer=tokenizer,
        data_collator=default_data_collator,
    )

    # Train
    trainer.train()

    # Save
    trainer.save_model("./final_model")


if __name__ == "__main__":
    main()
|