Charlie81 committed on
Commit
c4785c5
·
1 Parent(s): 36acce3
Files changed (3) hide show
  1. myolmoe/config.json +4 -1
  2. myolmoe/modeling_myolmoe.py +145 -14
  3. scripts/train.py +60 -148
myolmoe/config.json CHANGED
@@ -30,5 +30,8 @@
30
  "torch_dtype": "float32",
31
  "transformers_version": "4.52.4",
32
  "use_cache": true,
33
- "vocab_size": 50304
 
 
 
34
  }
 
30
  "torch_dtype": "float32",
31
  "transformers_version": "4.52.4",
32
  "use_cache": true,
33
+ "vocab_size": 50304,
34
+ "small_expert_intermediate_ratio": 0.5,
35
+ "small_expert_frequency": 4,
36
+ "small_expert_load_balancing_coef": 0.1
37
  }
myolmoe/modeling_myolmoe.py CHANGED
@@ -14,7 +14,103 @@ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_u
14
  from transformers.modeling_utils import PreTrainedModel
15
  from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
16
  from transformers.utils import logging
17
- from transformers.models.olmoe.configuration_olmoe import OlmoeConfig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
  logger = logging.get_logger(__name__)
20
 
@@ -143,21 +239,25 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
143
 
144
 
145
  class OlmoeMLP(nn.Module):
146
- def __init__(self, config):
147
  super().__init__()
148
  self.config = config
149
  self.hidden_size = config.hidden_size
150
- self.intermediate_size = config.intermediate_size
 
 
 
 
151
  self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
152
  self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
153
  self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
154
  self.act_fn = ACT2FN[config.hidden_act]
 
155
 
156
  def forward(self, x):
157
  down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
158
  return down_proj
159
 
160
-
161
  def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
162
  batch, num_key_value_heads, slen, head_dim = hidden_states.shape
163
  if n_rep == 1:
@@ -446,6 +546,7 @@ OLMOE_ATTENTION_CLASSES = {
446
  }
447
 
448
 
 
449
  class OlmoeSparseMoeBlock(nn.Module):
450
  def __init__(self, config, layer_idx: int):
451
  super().__init__()
@@ -453,10 +554,21 @@ class OlmoeSparseMoeBlock(nn.Module):
453
  self.num_experts = config.num_experts
454
  self.top_k = config.num_experts_per_tok
455
  self.norm_topk_prob = config.norm_topk_prob
456
- self.routing_type = getattr(config, "routing_type", "topk") # default to topk
457
- self.n_step = getattr(config, "nth_step", 2) # used in nth-descending
 
 
 
 
 
 
 
 
 
 
 
458
  self.gate = nn.Linear(config.hidden_size, self.num_experts, bias=False)
459
- self.experts = nn.ModuleList([OlmoeMLP(config) for _ in range(self.num_experts)])
460
 
461
  def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
462
  batch_size, sequence_length, hidden_dim = hidden_states.shape
@@ -464,7 +576,6 @@ class OlmoeSparseMoeBlock(nn.Module):
464
  router_logits = self.gate(hidden_states)
465
  routing_probs = F.softmax(router_logits, dim=1, dtype=torch.float)
466
 
467
- # === Routing ===
468
  routing_weights, selected_experts = torch.topk(routing_probs, self.top_k, dim=-1)
469
 
470
  if self.norm_topk_prob:
@@ -479,6 +590,18 @@ class OlmoeSparseMoeBlock(nn.Module):
479
 
480
  expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
481
 
 
 
 
 
 
 
 
 
 
 
 
 
482
  for expert_idx in range(self.num_experts):
483
  expert_layer = self.experts[expert_idx]
484
  idx, top_x = torch.where(expert_mask[expert_idx])
@@ -489,8 +612,7 @@ class OlmoeSparseMoeBlock(nn.Module):
489
  final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
490
 
491
  final_hidden_states = final_hidden_states.view(batch_size, sequence_length, hidden_dim)
492
- return final_hidden_states, router_logits
493
-
494
 
495
  class OlmoeDecoderLayer(nn.Module):
496
  def __init__(self, config: OlmoeConfig, layer_idx: int):
@@ -536,9 +658,9 @@ class OlmoeDecoderLayer(nn.Module):
536
  hidden_states = residual + hidden_states
537
  residual = hidden_states
538
  hidden_states = self.post_attention_layernorm(hidden_states)
539
- hidden_states, router_logits = self.mlp(hidden_states)
540
- hidden_states = residual + hidden_states
541
- outputs = (hidden_states,)
542
  if output_attentions:
543
  outputs += (self_attn_weights,)
544
  if use_cache:
@@ -942,6 +1064,15 @@ class MyOlmoeForCausalLM(OlmoePreTrainedModel, GenerationMixin):
942
  if output_router_logits:
943
  output = (aux_loss,) + output
944
  return (loss,) + output if loss is not None else output
 
 
 
 
 
 
 
 
 
945
  return MoeCausalLMOutputWithPast(
946
  loss=loss,
947
  aux_loss=aux_loss,
@@ -952,4 +1083,4 @@ class MyOlmoeForCausalLM(OlmoePreTrainedModel, GenerationMixin):
952
  router_logits=outputs.router_logits,
953
  )
954
 
955
- __all__ = ["MyOlmoeForCausalLM", "OlmoeModel", "OlmoePreTrainedModel"]
 
14
  from transformers.modeling_utils import PreTrainedModel
15
  from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
16
  from transformers.utils import logging
17
+ # from transformers.models.olmoe.configuration_olmoe import OlmoeConfig
18
+ from transformers.configuration_utils import PretrainedConfig
19
+ from transformers.modeling_rope_utils import rope_config_validation
20
+
21
class OlmoeConfig(PretrainedConfig):
    r"""
    Configuration class for [`OlmoeModel`].

    Mirrors the upstream OLMoE configuration and adds three knobs for
    heterogeneous ("small") experts:

    Args:
        small_expert_intermediate_ratio (`float`, *optional*, defaults to 0.5):
            Ratio of the intermediate size of small experts relative to
            regular experts.
        small_expert_frequency (`int`, *optional*, defaults to 4):
            Every Nth expert is instantiated as a small expert.
        small_expert_load_balancing_coef (`float`, *optional*, defaults to 0.1):
            Coefficient applied to the small-expert load-balancing loss.
    """

    model_type = "olmoe"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=50304,
        hidden_size=2048,
        intermediate_size=2048,
        num_hidden_layers=16,
        num_attention_heads=16,
        num_key_value_heads=None,
        hidden_act="silu",
        max_position_embeddings=4096,
        initializer_range=0.02,
        rms_norm_eps=1e-05,
        use_cache=True,
        pad_token_id=1,
        bos_token_id=None,
        eos_token_id=50279,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        clip_qkv=None,
        num_experts_per_tok=8,
        num_experts=64,
        output_router_logits=False,
        router_aux_loss_coef=0.01,
        norm_topk_prob=False,
        small_expert_intermediate_ratio=0.5,
        small_expert_frequency=4,
        small_expert_load_balancing_coef=0.1,
        **kwargs,
    ):
        # Model geometry.
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        # Backward compatibility: older (pure-MHA) checkpoints omit
        # num_key_value_heads, in which case it equals num_attention_heads.
        self.num_key_value_heads = (
            num_attention_heads if num_key_value_heads is None else num_key_value_heads
        )
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache

        # Rotary embeddings / attention details.
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.clip_qkv = clip_qkv

        # Mixture-of-experts routing.
        self.num_experts_per_tok = num_experts_per_tok
        self.num_experts = num_experts
        self.output_router_logits = output_router_logits
        self.router_aux_loss_coef = router_aux_loss_coef
        self.norm_topk_prob = norm_topk_prob

        # Small-expert extension.
        self.small_expert_intermediate_ratio = small_expert_intermediate_ratio
        self.small_expert_frequency = small_expert_frequency
        self.small_expert_load_balancing_coef = small_expert_load_balancing_coef

        # Normalize the legacy rope_scaling["type"] key, then validate the
        # rotary-embedding parameters before handing off to the base class.
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
        rope_config_validation(self)

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
114
 
115
  logger = logging.get_logger(__name__)
116
 
 
239
 
240
 
241
class OlmoeMLP(nn.Module):
    """Gated (SwiGLU-style) feed-forward block used as a single MoE expert.

    When ``is_small`` is set, the intermediate width is scaled down by
    ``config.small_expert_intermediate_ratio``; otherwise the full
    ``config.intermediate_size`` is used.
    """

    def __init__(self, config, is_small=False):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = (
            int(config.intermediate_size * config.small_expert_intermediate_ratio)
            if is_small
            else config.intermediate_size
        )
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]
        self.is_small = is_small

    def forward(self, x):
        """Return down_proj(act(gate_proj(x)) * up_proj(x))."""
        gated = self.act_fn(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
260
 
 
261
  def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
262
  batch, num_key_value_heads, slen, head_dim = hidden_states.shape
263
  if n_rep == 1:
 
546
  }
547
 
548
 
549
+
550
  class OlmoeSparseMoeBlock(nn.Module):
551
  def __init__(self, config, layer_idx: int):
552
  super().__init__()
 
554
  self.num_experts = config.num_experts
555
  self.top_k = config.num_experts_per_tok
556
  self.norm_topk_prob = config.norm_topk_prob
557
+ self.routing_type = getattr(config, "routing_type", "topk")
558
+ self.n_step = getattr(config, "nth_step", 2)
559
+
560
+ # Track which experts are small
561
+ self.small_expert_indices = []
562
+ self.experts = nn.ModuleList()
563
+
564
+ for i in range(self.num_experts):
565
+ is_small = (i % config.small_expert_frequency == 0)
566
+ if is_small:
567
+ self.small_expert_indices.append(i)
568
+ self.experts.append(OlmoeMLP(config, is_small=is_small))
569
+
570
  self.gate = nn.Linear(config.hidden_size, self.num_experts, bias=False)
571
+ self.small_expert_load_balancing_coef = config.small_expert_load_balancing_coef
572
 
573
  def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
574
  batch_size, sequence_length, hidden_dim = hidden_states.shape
 
576
  router_logits = self.gate(hidden_states)
577
  routing_probs = F.softmax(router_logits, dim=1, dtype=torch.float)
578
 
 
579
  routing_weights, selected_experts = torch.topk(routing_probs, self.top_k, dim=-1)
580
 
581
  if self.norm_topk_prob:
 
590
 
591
  expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
592
 
593
+ # Calculate small expert load balancing loss
594
+ small_expert_mask = torch.zeros_like(expert_mask)
595
+ for idx in self.small_expert_indices:
596
+ small_expert_mask[idx] = expert_mask[idx]
597
+
598
+ small_expert_loss = load_balancing_loss_func(
599
+ router_logits,
600
+ self.num_experts,
601
+ self.top_k,
602
+ None
603
+ ) * self.small_expert_load_balancing_coef
604
+
605
  for expert_idx in range(self.num_experts):
606
  expert_layer = self.experts[expert_idx]
607
  idx, top_x = torch.where(expert_mask[expert_idx])
 
612
  final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
613
 
614
  final_hidden_states = final_hidden_states.view(batch_size, sequence_length, hidden_dim)
615
+ return final_hidden_states, router_logits, small_expert_loss
 
616
 
617
  class OlmoeDecoderLayer(nn.Module):
618
  def __init__(self, config: OlmoeConfig, layer_idx: int):
 
658
  hidden_states = residual + hidden_states
659
  residual = hidden_states
660
  hidden_states = self.post_attention_layernorm(hidden_states)
661
+ hidden_states, router_logits, small_expert_loss = self.mlp(hidden_states) #
662
+ hidden_states = residual + hidden_states #
663
+ outputs = (hidden_states, small_expert_loss) #
664
  if output_attentions:
665
  outputs += (self_attn_weights,)
666
  if use_cache:
 
1064
  if output_router_logits:
1065
  output = (aux_loss,) + output
1066
  return (loss,) + output if loss is not None else output
1067
+ #
1068
+ total_small_expert_loss = 0
1069
+ for layer_output in outputs:
1070
+ if len(layer_output) > 1 and isinstance(layer_output[1], torch.Tensor):
1071
+ total_small_expert_loss += layer_output[1]
1072
+
1073
+ if labels is not None:
1074
+ loss += total_small_expert_loss.to(loss.device)
1075
+ #
1076
  return MoeCausalLMOutputWithPast(
1077
  loss=loss,
1078
  aux_loss=aux_loss,
 
1083
  router_logits=outputs.router_logits,
1084
  )
1085
 
1086
+ __all__ = ["MyOlmoeForCausalLM", "OlmoeModel", "OlmoePreTrainedModel", "OlmoeConfig"]
scripts/train.py CHANGED
@@ -1,170 +1,82 @@
1
- # scripts/train_small_experts.py
2
  import torch
3
- from transformers import TrainingArguments, Trainer, AutoTokenizer
 
 
 
 
 
 
4
  from datasets import load_dataset
5
- from myolmoe.modeling_myolmoe import MyOlmoeForCausalLM, OlmoeConfig
6
- from torch.utils.data import Dataset
7
  import os
8
- from tqdm import tqdm
9
 
10
- class CustomDataset(Dataset):
11
- def __init__(self, tokenizer, dataset_name="allenai/tulu-v2-sft-mixture", max_length=512):
12
- print(f"# DEBUG: Loading dataset '{dataset_name}' with max length {max_length}")
13
- self.dataset = load_dataset(dataset_name, split="train") # Use train split
14
- self.tokenizer = tokenizer
15
- self.max_length = max_length
16
-
17
- def __len__(self):
18
- return len(self.dataset)
19
-
20
- def __getitem__(self, idx):
21
- item = self.dataset[idx]
22
- text = item["text"] # Adjust based on your dataset structure
23
- encoding = self.tokenizer(
24
- text,
25
- max_length=self.max_length,
26
- padding="max_length",
27
- truncation=True,
28
- return_tensors="pt"
29
- )
30
- # DEBUG: Print the first few token IDs for inspection
31
- if idx == 0:
32
- print(f"# DEBUG: Sample input text: {text[:100]}")
33
- print(f"# DEBUG: Tokenized input_ids[:10]: {encoding['input_ids'][0][:10]}")
34
- return {
35
- "input_ids": encoding["input_ids"].squeeze(),
36
- "attention_mask": encoding["attention_mask"].squeeze(),
37
- "labels": encoding["input_ids"].squeeze()
38
- }
39
-
40
- def expand_model_with_small_experts(base_model):
41
- print("# DEBUG: Expanding model with small experts...")
42
- config = base_model.config
43
- config.num_small_experts = 64 # Add 64 small experts
44
- config.small_expert_intermediate_size = config.intermediate_size // 32
45
- expanded_model = MyOlmoeForCausalLM(config)
46
-
47
- base_state_dict = base_model.state_dict()
48
- expanded_state_dict = expanded_model.state_dict()
49
 
50
- print("# DEBUG: Copying non-expert parameters...")
51
- for name, param in base_state_dict.items():
52
- if "experts" not in name and "gate" not in name:
53
- if name in expanded_state_dict:
54
- expanded_state_dict[name].copy_(param)
55
- else:
56
- print(f"# DEBUG: Skipped non-expert param {name} (not found in expanded model)")
57
 
58
- print("# DEBUG: Copying expert weights...")
59
- for i in range(config.num_experts):
60
- for proj in ['gate_proj', 'up_proj', 'down_proj']:
61
- key = f'model.layers.{i}.mlp.experts.{i}.{proj}.weight'
62
- if key in base_state_dict:
63
- orig_weight = base_state_dict[key]
64
- target_weight = expanded_state_dict[key]
65
-
66
- if proj == 'down_proj':
67
- # For down_proj, we copy the first part of the input dimension
68
- target_weight.copy_(orig_weight[:, :config.small_expert_intermediate_size])
69
- else:
70
- # For gate_proj and up_proj, we copy the first part of the output dimension
71
- target_weight.copy_(orig_weight[:config.small_expert_intermediate_size, :])
72
-
73
- print(f"# DEBUG: Copied {proj} weights for expert {i} "
74
- f"(original shape: {orig_weight.shape}, new shape: {target_weight.shape})")
75
- else:
76
- print(f"# DEBUG: Missing {key} in base model")
77
-
78
- print("# DEBUG: Expanding and initializing gate weights...")
79
- for i in range(config.num_hidden_layers):
80
- gate_key = f'model.layers.{i}.mlp.gate.weight'
81
- if gate_key in base_state_dict:
82
- original_gate = base_state_dict[gate_key]
83
- new_gate = expanded_state_dict[gate_key]
84
-
85
- # Copy original gate weights
86
- new_gate[:, :config.num_experts].copy_(original_gate)
87
-
88
- # Initialize small experts gate weights
89
- torch.nn.init.normal_(
90
- new_gate[:, config.num_experts:],
91
- mean=0.0,
92
- std=config.initializer_range * 0.1
93
- )
94
- print(f"# DEBUG: Initialized gate for layer {i} "
95
- f"(original shape: {original_gate.shape}, new shape: {new_gate.shape})")
96
- else:
97
- print(f"# DEBUG: Missing gate weight {gate_key}")
98
-
99
- print("# DEBUG: Loading expanded state dict into model...")
100
- expanded_model.load_state_dict(expanded_state_dict, strict=False)
101
 
102
- return expanded_model
103
- def main():
104
- model_path = "myolmoe"
105
- print("# DEBUG: Loading base model...")
106
- base_model = MyOlmoeForCausalLM.from_pretrained(model_path)
 
 
107
 
108
- print(f"# DEBUG: Base model has {base_model.config.num_experts} experts")
109
-
110
- print("# DEBUG: Calling expand_model_with_small_experts()...")
111
- model = expand_model_with_small_experts(base_model)
 
 
112
 
113
- print(f"# DEBUG: Expanded model has {model.config.num_experts} regular experts and {model.config.num_small_experts} small experts")
114
-
115
- print("# DEBUG: Loading tokenizer and preparing dataset...")
116
- tokenizer = AutoTokenizer.from_pretrained(model_path)
117
- dataset = CustomDataset(tokenizer)
118
-
119
- print("# DEBUG: Setting up training arguments...")
120
  training_args = TrainingArguments(
121
  output_dir="./output",
122
- per_device_train_batch_size=4,
123
  gradient_accumulation_steps=8,
124
- learning_rate=1e-4,
125
- num_train_epochs=3,
126
  logging_dir="./logs",
127
- save_strategy="steps",
128
  save_steps=1000,
129
- evaluation_strategy="steps",
130
- eval_steps=500,
131
- fp16=True,
132
  gradient_checkpointing=True,
133
- report_to="tensorboard"
 
 
 
 
134
  )
135
-
136
- class MoETrainer(Trainer):
137
- def __init__(self, *args, **kwargs):
138
- self.freeze_existing = kwargs.pop('freeze_existing_experts', False)
139
- super().__init__(*args, **kwargs)
140
-
141
- if self.freeze_existing:
142
- print("# DEBUG: Freezing original expert parameters...")
143
- frozen_count = 0
144
- for name, param in self.model.named_parameters():
145
- if "mlp.experts" in name and "small_experts" not in name:
146
- param.requires_grad = False
147
- frozen_count += 1
148
- print(f"# DEBUG: Total frozen expert parameters: {frozen_count}")
149
-
150
- print("# DEBUG: Initializing trainer...")
151
- trainer = MoETrainer(
152
  model=model,
153
  args=training_args,
154
- train_dataset=dataset,
155
- eval_dataset=dataset,
156
- freeze_existing_experts=True
157
  )
158
-
159
- print("# DEBUG: Starting training...")
160
  trainer.train()
161
-
162
- output_dir = "./final_model"
163
- os.makedirs(output_dir, exist_ok=True)
164
- print(f"# DEBUG: Saving final model to {output_dir}...")
165
- model.save_pretrained(output_dir)
166
- tokenizer.save_pretrained(output_dir)
167
- print("# DEBUG: Training complete!")
168
 
169
  if __name__ == "__main__":
170
- main()
 
1
+ #!/usr/bin/env python3
2
  import torch
3
+ from torch.utils.data import DataLoader
4
+ from transformers import (
5
+ AutoTokenizer,
6
+ TrainingArguments,
7
+ Trainer,
8
+ default_data_collator,
9
+ )
10
  from datasets import load_dataset
11
+ from myolmoe import MyOlmoeForCausalLM, OlmoeConfig
 
12
  import os
 
13
 
14
def main():
    """Fine-tune MyOlmoe on the tulu-v2 SFT mixture with the HF Trainer."""
    # Load config and model. The config file carries the small-expert
    # hyperparameters added in myolmoe/config.json.
    config = OlmoeConfig.from_pretrained("myolmoe/config.json")
    model = MyOlmoeForCausalLM.from_pretrained(
        "myolmoe",
        config=config,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )

    # Tokenizer: reuse EOS as the pad token since none is defined.
    tokenizer = AutoTokenizer.from_pretrained("myolmoe")
    tokenizer.pad_token = tokenizer.eos_token

    # Dataset. NOTE(review): tulu-v2-sft-mixture ships conversations in a
    # "messages" column — confirm a "text" column exists (or render one from
    # the messages) before training.
    dataset = load_dataset("allenai/tulu-v2-sft-mixture", split="train")

    def tokenize_function(examples):
        # Batched tokenization; `examples["text"]` is a list of strings.
        encoded = tokenizer(
            examples["text"],
            truncation=True,
            max_length=4096,
            padding="max_length",
        )
        # Bug fix: causal-LM training needs `labels`, and
        # `default_data_collator` does not create them. Without this the
        # model receives no labels and the Trainer has no loss to optimize.
        encoded["labels"] = [ids.copy() for ids in encoded["input_ids"]]
        return encoded

    tokenized_dataset = dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=dataset.column_names,
        num_proc=4,
    )

    # Training arguments.
    training_args = TrainingArguments(
        output_dir="./output",
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,
        learning_rate=1e-5,
        num_train_epochs=1,
        logging_dir="./logs",
        logging_steps=10,
        save_steps=1000,
        save_total_limit=2,
        bf16=True,
        gradient_checkpointing=True,
        report_to="tensorboard",
        optim="adamw_torch",
        lr_scheduler_type="cosine",
        warmup_ratio=0.1,
        max_grad_norm=1.0,
    )

    # Trainer.
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
        tokenizer=tokenizer,
        data_collator=default_data_collator,
    )

    # Train.
    trainer.train()

    # Save the final checkpoint (model + training args).
    trainer.save_model("./final_model")


if __name__ == "__main__":
    main()