Charlie81 committed
Commit 48f0e60
1 Parent(s): ac9f1eb

initial stuff

myolmoe/config.json CHANGED
@@ -1,6 +1,6 @@
 {
   "architectures": [
-    "OlmoeForCausalLM"
+    "MyOlmoeForCausalLM"
   ],
   "attention_bias": false,
   "attention_dropout": 0.0,
@@ -15,6 +15,8 @@
   "norm_topk_prob": false,
   "num_attention_heads": 16,
   "num_experts": 64,
+  "num_small_experts": 64,
+  "small_expert_intermediate_size": 512,
   "num_experts_per_tok": 2,
   "num_hidden_layers": 16,
   "num_key_value_heads": 16,
@@ -29,4 +31,4 @@
   "transformers_version": "4.52.4",
   "use_cache": true,
   "vocab_size": 50304
-}
+}
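
The two new keys travel as plain JSON fields; the modified MoE block below reads them with getattr and falls back to defaults when they are absent. A minimal sketch of inspecting them, assuming the repo layout above:

import json

# Path and key names are taken from the diff above; this only inspects the raw JSON.
with open("myolmoe/config.json") as f:
    cfg = json.load(f)

print(cfg["num_experts"])                         # 64 regular experts
print(cfg.get("num_small_experts", 0))            # 64 small experts; the model defaults to 0
print(cfg.get("small_expert_intermediate_size"))  # 512, vs. the regular experts' intermediate_size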
myolmoe/modeling_myolmoe.py CHANGED
@@ -156,6 +156,21 @@ class OlmoeMLP(nn.Module):
     def forward(self, x):
         down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
         return down_proj
+
+class SmallOlmoeMLP(nn.Module):
+    def __init__(self, config, small_expert_intermediate_size):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = small_expert_intermediate_size
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+        self.act_fn = ACT2FN[config.hidden_act]
+
+    def forward(self, x):
+        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+        return down_proj
 
 
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
@@ -446,15 +461,34 @@ OLMOE_ATTENTION_CLASSES = {
 }
 
 
+
 class OlmoeSparseMoeBlock(nn.Module):
     def __init__(self, config, layer_idx: int):
         super().__init__()
         self.layer_idx = layer_idx
         self.num_experts = config.num_experts
+        self.num_small_experts = getattr(config, "num_small_experts", 0)  # Default to 0 if not specified
+        self.total_experts = self.num_experts + self.num_small_experts
         self.top_k = config.num_experts_per_tok
         self.norm_topk_prob = config.norm_topk_prob
-        self.gate = nn.Linear(config.hidden_size, self.num_experts, bias=False)
+        self.routing_type = getattr(config, "routing_type", "topk")
+        self.n_step = getattr(config, "nth_step", 2)
+
+        # Gate now needs to handle both regular and small experts
+        self.gate = nn.Linear(config.hidden_size, self.total_experts, bias=False)
+
+        # Regular experts
         self.experts = nn.ModuleList([OlmoeMLP(config) for _ in range(self.num_experts)])
+
+        # Small experts (if any)
+        self.small_experts = nn.ModuleList()
+        if self.num_small_experts > 0:
+            small_expert_intermediate_size = getattr(config, "small_expert_intermediate_size",
+                                                     config.intermediate_size // 2)  # Default to half size
+            self.small_experts = nn.ModuleList([
+                SmallOlmoeMLP(config, small_expert_intermediate_size)
+                for _ in range(self.num_small_experts)
+            ])
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         batch_size, sequence_length, hidden_dim = hidden_states.shape
@@ -462,7 +496,6 @@ class OlmoeSparseMoeBlock(nn.Module):
         router_logits = self.gate(hidden_states)
         routing_probs = F.softmax(router_logits, dim=1, dtype=torch.float)
 
-        # === Routing ===
        routing_weights, selected_experts = torch.topk(routing_probs, self.top_k, dim=-1)
 
         if self.norm_topk_prob:
@@ -475,8 +508,9 @@ class OlmoeSparseMoeBlock(nn.Module):
             device=hidden_states.device,
         )
 
-        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
+        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.total_experts).permute(2, 1, 0)
 
+        # Process regular experts
         for expert_idx in range(self.num_experts):
             expert_layer = self.experts[expert_idx]
             idx, top_x = torch.where(expert_mask[expert_idx])
@@ -486,10 +520,22 @@ class OlmoeSparseMoeBlock(nn.Module):
             current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
             final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
 
+        # Process small experts
+        for small_expert_idx in range(self.num_small_experts):
+            expert_layer = self.small_experts[small_expert_idx]
+            # Offset by num_experts since small experts come after regular ones
+            global_expert_idx = self.num_experts + small_expert_idx
+            idx, top_x = torch.where(expert_mask[global_expert_idx])
+            if top_x.numel() == 0:
+                continue
+            current_state = hidden_states[top_x]
+            current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
+            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
+
         final_hidden_states = final_hidden_states.view(batch_size, sequence_length, hidden_dim)
         return final_hidden_states, router_logits
-
-
+
+
 class OlmoeDecoderLayer(nn.Module):
     def __init__(self, config: OlmoeConfig, layer_idx: int):
         super().__init__()
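
To make the routing change concrete: one gate now scores num_experts + num_small_experts candidates, top-k selects among all of them jointly, and indices at or above num_experts map to small experts. A standalone sketch of that scheme in plain PyTorch (toy sizes, not code from this repo), mirroring the topk/one_hot/permute steps in the diff:

import torch
import torch.nn.functional as F

num_experts, num_small_experts, top_k = 4, 4, 2
total_experts = num_experts + num_small_experts
tokens, hidden = 6, 8

gate = torch.nn.Linear(hidden, total_experts, bias=False)
hidden_states = torch.randn(tokens, hidden)  # already flattened to (tokens, hidden)

router_logits = gate(hidden_states)                                 # (tokens, total_experts)
routing_probs = F.softmax(router_logits, dim=-1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_probs, top_k, dim=-1)

# (total_experts, top_k, tokens) mask, as built by one_hot(...).permute(2, 1, 0) above
expert_mask = F.one_hot(selected_experts, num_classes=total_experts).permute(2, 1, 0)

for global_idx in range(total_experts):
    idx, top_x = torch.where(expert_mask[global_idx])  # which tokens chose this expert
    if global_idx < num_experts:
        print(f"regular expert {global_idx}: {top_x.numel()} token(s)")
    else:
        print(f"small expert {global_idx - num_experts}: {top_x.numel()} token(s)")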
scripts/train.py ADDED
@@ -0,0 +1,95 @@
+# scripts/train.py
+import torch
+from transformers import AutoTokenizer, TrainingArguments, Trainer
+from datasets import load_dataset
+from myolmoe import MyOlmoeForCausalLM, OlmoeConfig
+from torch.utils.data import Dataset
+
+class CustomDataset(Dataset):
+    def __init__(self, tokenizer, dataset_name="allenai/tulu-v2-sft-mixture", max_length=512):
+        # load_dataset returns a DatasetDict; pick a split so __len__ and
+        # __getitem__ operate on individual examples
+        self.dataset = load_dataset(dataset_name, split="train")
+        self.tokenizer = tokenizer
+        self.max_length = max_length
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def __getitem__(self, idx):
+        item = self.dataset[idx]
+        # tulu-v2-sft-mixture stores chat turns under "messages";
+        # adjust this based on your dataset structure
+        text = "\n".join(m["content"] for m in item["messages"])
+        encoding = self.tokenizer(
+            text,
+            max_length=self.max_length,
+            padding="max_length",
+            truncation=True,
+            return_tensors="pt"
+        )
+        return {
+            "input_ids": encoding["input_ids"].squeeze(),
+            "attention_mask": encoding["attention_mask"].squeeze(),
+            # In practice, pad positions should usually be set to -100
+            "labels": encoding["input_ids"].squeeze()
+        }
+
+def main():
+    # Load base model
+    model_path = "myolmoe"
+    base_model = MyOlmoeForCausalLM.from_pretrained(model_path)
+
+    # Create new config with small experts
+    config = base_model.config
+    config.num_small_experts = 64  # Add 64 small experts
+    config.small_expert_intermediate_size = 512  # Half the size of regular experts
+
+    # Initialize new model with same weights but expanded architecture
+    model = MyOlmoeForCausalLM(config)
+
+    # Copy existing weights, skipping shape mismatches (the router gate grew
+    # from num_experts to total_experts outputs); the gate and the new small
+    # experts keep their fresh random initialization
+    new_state = model.state_dict()
+    old_state = {k: v for k, v in base_model.state_dict().items()
+                 if k in new_state and v.shape == new_state[k].shape}
+    model.load_state_dict(old_state, strict=False)
+
+    # You might want to initialize the small experts differently,
+    # perhaps with smaller variance
+
+    # Prepare dataset
+    tokenizer = AutoTokenizer.from_pretrained(model_path)
+    dataset = CustomDataset(tokenizer)
+
+    # Training arguments
+    training_args = TrainingArguments(
+        output_dir="./output",
+        per_device_train_batch_size=4,
+        gradient_accumulation_steps=8,
+        learning_rate=1e-4,
+        num_train_epochs=3,
+        logging_dir="./logs",
+        save_strategy="steps",
+        save_steps=1000,
+        eval_strategy="steps",  # renamed from evaluation_strategy in recent transformers
+        eval_steps=500,
+        fp16=True,
+        gradient_checkpointing=True,
+        report_to="tensorboard"
+    )
+
+    # Trainer
+    trainer = Trainer(
+        model=model,
+        args=training_args,
+        train_dataset=dataset,
+        eval_dataset=dataset,  # In practice, use a separate validation set
+    )
+
+    # Train
+    trainer.train()
+
+if __name__ == "__main__":
+    main()
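
One likely follow-up, not part of this commit: since only the small experts (and the re-initialized router gate) start from random weights, it may make sense to freeze everything else before constructing the Trainer. A hedged sketch; the exact parameter paths depend on how the decoder layers name the MoE block:

for name, param in model.named_parameters():
    # ".small_experts." matches the new ModuleList; "gate.weight" matches the
    # router (but not gate_proj.weight, which belongs to the expert MLPs)
    param.requires_grad = ".small_experts." in name or name.endswith("gate.weight")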