initial stuff

Files changed:
- myolmoe/config.json  +4 -2
- myolmoe/modeling_myolmoe.py  +51 -5
- scripts/train.py  +85 -0
myolmoe/config.json  (CHANGED)

@@ -1,6 +1,6 @@
 {
   "architectures": [
-    "
+    "MyOlmoeForCausalLM"
   ],
   "attention_bias": false,
   "attention_dropout": 0.0,
@@ -15,6 +15,8 @@
   "norm_topk_prob": false,
   "num_attention_heads": 16,
   "num_experts": 64,
+  "num_small_experts": 64,
+  "small_expert_intermediate_size": 512,
   "num_experts_per_tok": 2,
   "num_hidden_layers": 16,
   "num_key_value_heads": 16,
@@ -29,4 +31,4 @@
   "transformers_version": "4.52.4",
   "use_cache": true,
   "vocab_size": 50304
-}
+}
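Both new fields are optional: the modeling change below reads them with getattr and falls back to defaults, so a config.json that omits them still loads. A minimal sketch of that lookup in plain json (it assumes intermediate_size is present in the full config, as in the base OLMoE config; values in the comments refer to this config.json):

import json

# Sketch: read the new optional fields the way the modeling code does,
# falling back to defaults when an older config.json omits them.
with open("myolmoe/config.json") as f:
    cfg = json.load(f)

num_small_experts = cfg.get("num_small_experts", 0)        # 64 with this config
small_intermediate_size = cfg.get(
    "small_expert_intermediate_size",
    cfg["intermediate_size"] // 2,                          # default: half the regular expert size
)                                                           # 512 with this config
print(num_small_experts, small_intermediate_size)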
myolmoe/modeling_myolmoe.py  (CHANGED)

@@ -156,6 +156,21 @@ class OlmoeMLP(nn.Module):
     def forward(self, x):
         down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
         return down_proj
+
+class SmallOlmoeMLP(nn.Module):
+    def __init__(self, config, small_expert_intermediate_size):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = small_expert_intermediate_size
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+        self.act_fn = ACT2FN[config.hidden_act]
+
+    def forward(self, x):
+        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+        return down_proj
 
 
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
@@ -446,15 +461,34 @@ OLMOE_ATTENTION_CLASSES = {
 }
 
 
+
 class OlmoeSparseMoeBlock(nn.Module):
     def __init__(self, config, layer_idx: int):
         super().__init__()
         self.layer_idx = layer_idx
         self.num_experts = config.num_experts
+        self.num_small_experts = getattr(config, "num_small_experts", 0)  # Default to 0 if not specified
+        self.total_experts = self.num_experts + self.num_small_experts
         self.top_k = config.num_experts_per_tok
         self.norm_topk_prob = config.norm_topk_prob
-        self.
+        self.routing_type = getattr(config, "routing_type", "topk")
+        self.n_step = getattr(config, "nth_step", 2)
+
+        # Gate now needs to handle both regular and small experts
+        self.gate = nn.Linear(config.hidden_size, self.total_experts, bias=False)
+
+        # Regular experts
         self.experts = nn.ModuleList([OlmoeMLP(config) for _ in range(self.num_experts)])
+
+        # Small experts (if any)
+        self.small_experts = nn.ModuleList()
+        if self.num_small_experts > 0:
+            small_expert_intermediate_size = getattr(config, "small_expert_intermediate_size",
+                                                     config.intermediate_size // 2)  # Default to half size
+            self.small_experts = nn.ModuleList([
+                SmallOlmoeMLP(config, small_expert_intermediate_size)
+                for _ in range(self.num_small_experts)
+            ])
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         batch_size, sequence_length, hidden_dim = hidden_states.shape
@@ -462,7 +496,6 @@ class OlmoeSparseMoeBlock(nn.Module):
         router_logits = self.gate(hidden_states)
         routing_probs = F.softmax(router_logits, dim=1, dtype=torch.float)
 
-        # === Routing ===
         routing_weights, selected_experts = torch.topk(routing_probs, self.top_k, dim=-1)
 
         if self.norm_topk_prob:
@@ -475,8 +508,9 @@ class OlmoeSparseMoeBlock(nn.Module):
             device=hidden_states.device,
         )
 
-        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.
+        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.total_experts).permute(2, 1, 0)
 
+        # Process regular experts
        for expert_idx in range(self.num_experts):
            expert_layer = self.experts[expert_idx]
            idx, top_x = torch.where(expert_mask[expert_idx])
@@ -486,10 +520,22 @@ class OlmoeSparseMoeBlock(nn.Module):
            current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
 
+        # Process small experts
+        for small_expert_idx in range(self.num_small_experts):
+            expert_layer = self.small_experts[small_expert_idx]
+            # Offset by num_experts since small experts come after regular ones
+            global_expert_idx = self.num_experts + small_expert_idx
+            idx, top_x = torch.where(expert_mask[global_expert_idx])
+            if top_x.numel() == 0:
+                continue
+            current_state = hidden_states[top_x]
+            current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
+            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
+
        final_hidden_states = final_hidden_states.view(batch_size, sequence_length, hidden_dim)
        return final_hidden_states, router_logits
-
-
+
+
 class OlmoeDecoderLayer(nn.Module):
     def __init__(self, config: OlmoeConfig, layer_idx: int):
         super().__init__()
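The widened gate now emits total_experts = num_experts + num_small_experts logits per token, and the small experts occupy indices num_experts .. total_experts - 1 of the one-hot expert mask, which is why the small-expert loop offsets by num_experts. A self-contained sketch of that dispatch convention with toy sizes (plain tensors, not the actual module):

import torch
import torch.nn.functional as F

# Toy dispatch check mirroring OlmoeSparseMoeBlock.forward:
# 4 regular + 4 small experts, top-2 routing over 6 token vectors.
num_experts, num_small_experts, top_k, tokens, hidden = 4, 4, 2, 6, 8
total_experts = num_experts + num_small_experts

hidden_states = torch.randn(tokens, hidden)
gate = torch.nn.Linear(hidden, total_experts, bias=False)

routing_probs = F.softmax(gate(hidden_states), dim=-1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_probs, top_k, dim=-1)

# one_hot -> (total_experts, top_k, tokens), same permute as in the diff above
expert_mask = F.one_hot(selected_experts, num_classes=total_experts).permute(2, 1, 0)

for global_idx in range(total_experts):
    idx, top_x = torch.where(expert_mask[global_idx])
    kind = "regular" if global_idx < num_experts else "small"
    print(f"expert {global_idx} ({kind}) handles tokens {top_x.tolist()}")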
scripts/train.py  (ADDED)

@@ -0,0 +1,85 @@
+# scripts/train_small_experts.py
+import torch
+from transformers import AutoTokenizer, TrainingArguments, Trainer
+from datasets import load_dataset
+from myolmoe import MyOlmoeForCausalLM, OlmoeConfig
+from torch.utils.data import Dataset
+
+class CustomDataset(Dataset):
+    def __init__(self, tokenizer, dataset_name="allenai/tulu-v2-sft-mixture", max_length=512):
+        self.dataset = load_dataset(dataset_name, split="train")
+        self.tokenizer = tokenizer
+        self.max_length = max_length
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def __getitem__(self, idx):
+        item = self.dataset[idx]
+        text = item["text"]  # Adjust based on your dataset structure
+        encoding = self.tokenizer(
+            text,
+            max_length=self.max_length,
+            padding="max_length",
+            truncation=True,
+            return_tensors="pt"
+        )
+        return {
+            "input_ids": encoding["input_ids"].squeeze(),
+            "attention_mask": encoding["attention_mask"].squeeze(),
+            "labels": encoding["input_ids"].squeeze()
+        }
+
+def main():
+    # Load base model
+    model_path = "myolmoe"
+    base_model = MyOlmoeForCausalLM.from_pretrained(model_path)
+
+    # Create new config with small experts
+    config = base_model.config
+    config.num_small_experts = 64  # Add 64 small experts
+    config.small_expert_intermediate_size = 512  # Half the size of regular experts
+
+    # Initialize new model with same weights but expanded architecture
+    model = MyOlmoeForCausalLM(config)
+
+    # Copy existing weights
+    model.load_state_dict(base_model.state_dict(), strict=False)
+
+    # Initialize small experts (they'll start with random weights)
+    # You might want to initialize them differently, perhaps with smaller variance
+
+    # Prepare dataset
+    tokenizer = AutoTokenizer.from_pretrained(model_path)
+    dataset = CustomDataset(tokenizer)
+
+    # Training arguments
+    training_args = TrainingArguments(
+        output_dir="./output",
+        per_device_train_batch_size=4,
+        gradient_accumulation_steps=8,
+        learning_rate=1e-4,
+        num_train_epochs=3,
+        logging_dir="./logs",
+        save_strategy="steps",
+        save_steps=1000,
+        evaluation_strategy="steps",
+        eval_steps=500,
+        fp16=True,
+        gradient_checkpointing=True,
+        report_to="tensorboard"
+    )
+
+    # Trainer
+    trainer = Trainer(
+        model=model,
+        args=training_args,
+        train_dataset=dataset,
+        eval_dataset=dataset,  # In practice, use a separate validation set
+    )
+
+    # Train
+    trainer.train()
+
+if __name__ == "__main__":
+    main()
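One caveat on the weight copy above: strict=False only tolerates missing keys (the new small experts) and unexpected keys; a tensor whose shape changed, such as the router gate that grew from num_experts to total_experts output rows, still raises a size-mismatch error in PyTorch. A minimal sketch of a shape-aware copy that could stand in for the load_state_dict call (the helper name is illustrative, not part of this commit):

def copy_compatible_weights(src_model, dst_model):
    # Illustrative helper: copy only tensors that exist in both models with
    # identical shapes; the widened gate and the new small experts keep their
    # fresh initialization and are trained on top of the copied base weights.
    src_state = src_model.state_dict()
    dst_state = dst_model.state_dict()
    compatible = {k: v for k, v in src_state.items()
                  if k in dst_state and dst_state[k].shape == v.shape}
    missing, unexpected = dst_model.load_state_dict(compatible, strict=False)
    return missing, unexpected

# e.g. in main(): copy_compatible_weights(base_model, model)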