mineself2016 committed on
Commit
6a44f5b
·
verified ·
1 Parent(s): f141719

Update Phase 3 to next-token train-from-scratch with checkpoint auto-resume

Browse files
Files changed (1) hide show
  1. examples/4_pretrain_from_scratch.py +23 -26
examples/4_pretrain_from_scratch.py CHANGED
@@ -13,10 +13,12 @@ from torch.utils.data import Dataset
13
  from pathlib import Path
14
  from transformers import (
15
  AutoTokenizer,
 
16
  AutoModelForMaskedLM,
17
  Trainer,
18
  TrainingArguments,
19
  )
 
20
 
21
 
22
  class PretrainingDataset(Dataset):
@@ -96,7 +98,8 @@ def main():
96
  print("=" * 80)
97
 
98
  model_id = "mineself2016/GeneMamba"
99
- checkpoint_dir = Path("./from_scratch_pretrain/checkpoint-last")
 
100
 
101
  # ============================================================
102
  # Step 1: Load tokenizer spec
@@ -117,33 +120,29 @@ def main():
117
  # ============================================================
118
  print("\n[Step 2] Building model (resume if checkpoint exists)...")
119
 
120
- from configuration_genemamba import GeneMambaConfig
121
- from modeling_genemamba import GeneMambaForMaskedLM
122
-
123
- model_config = GeneMambaConfig(
124
- vocab_size=25426,
125
- hidden_size=256, # Smaller for faster demo
126
- num_hidden_layers=12, # Reduced for demo
127
- intermediate_size=1024,
128
- max_position_embeddings=2048,
129
- mamba_mode="mean",
130
- embedding_pooling="mean",
131
- num_labels=2,
132
- hidden_dropout_prob=0.1,
133
- initializer_range=0.02,
134
- )
135
-
136
  if checkpoint_dir.exists():
 
 
 
 
 
137
  model = AutoModelForMaskedLM.from_pretrained(
138
- str(checkpoint_dir),
139
  trust_remote_code=True,
140
  local_files_only=True,
141
  )
142
- resume_from_checkpoint = str(checkpoint_dir)
143
- print(f"✓ Found checkpoint, resume from: {checkpoint_dir}")
144
  else:
145
- model = GeneMambaForMaskedLM(model_config)
146
- resume_from_checkpoint = None
147
  print("✓ No checkpoint found, start from scratch")
148
 
149
  # Count parameters
@@ -186,8 +185,6 @@ def main():
186
  # ============================================================
187
  print("\n[Step 5] Setting up training...")
188
 
189
- output_dir = "./from_scratch_pretrain"
190
-
191
  training_args = TrainingArguments(
192
  output_dir=output_dir,
193
  num_train_epochs=5,
@@ -297,8 +294,8 @@ def main():
297
  print("Phase 3 Complete! Model trained from scratch and ready to use.")
298
  print("=" * 80)
299
 
300
- return model, trainer, config
301
 
302
 
303
  if __name__ == "__main__":
304
- model, trainer, config = main()
 
13
  from pathlib import Path
14
  from transformers import (
15
  AutoTokenizer,
16
+ AutoConfig,
17
  AutoModelForMaskedLM,
18
  Trainer,
19
  TrainingArguments,
20
  )
21
+ from transformers.trainer_utils import get_last_checkpoint
22
 
23
 
24
  class PretrainingDataset(Dataset):
 
98
  print("=" * 80)
99
 
100
  model_id = "mineself2016/GeneMamba"
101
+ output_dir = "./from_scratch_pretrain"
102
+ checkpoint_dir = Path(output_dir) / "checkpoint-last"
103
 
104
  # ============================================================
105
  # Step 1: Load tokenizer spec
 
120
  # ============================================================
121
  print("\n[Step 2] Building model (resume if checkpoint exists)...")
122
 
123
+ model_config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
124
+ model_config.vocab_size = 25426
125
+ model_config.hidden_size = 256
126
+ model_config.num_hidden_layers = 12
127
+ model_config.intermediate_size = 1024
128
+ model_config.max_position_embeddings = 2048
129
+ model_config.mamba_mode = "mean"
130
+
131
+ resume_from_checkpoint = None
 
 
 
 
 
 
 
132
  if checkpoint_dir.exists():
133
+ resume_from_checkpoint = str(checkpoint_dir)
134
+ else:
135
+ resume_from_checkpoint = get_last_checkpoint(output_dir)
136
+
137
+ if resume_from_checkpoint is not None:
138
  model = AutoModelForMaskedLM.from_pretrained(
139
+ resume_from_checkpoint,
140
  trust_remote_code=True,
141
  local_files_only=True,
142
  )
143
+ print(f"✓ Found checkpoint, resume from: {resume_from_checkpoint}")
 
144
  else:
145
+ model = AutoModelForMaskedLM.from_config(model_config, trust_remote_code=True)
 
146
  print("✓ No checkpoint found, start from scratch")
147
 
148
  # Count parameters
 
185
  # ============================================================
186
  print("\n[Step 5] Setting up training...")
187
 
 
 
188
  training_args = TrainingArguments(
189
  output_dir=output_dir,
190
  num_train_epochs=5,
 
294
  print("Phase 3 Complete! Model trained from scratch and ready to use.")
295
  print("=" * 80)
296
 
297
+ return model, trainer, model_config
298
 
299
 
300
  if __name__ == "__main__":
301
+ model, trainer, model_config = main()