mineself2016 committed on
Commit
dc14676
·
verified ·
1 Parent(s): 9681e2a

Remove duplicate examples/4_pretrain_from_scratch.py

Browse files
Files changed (1) hide show
  1. examples/4_pretrain_from_scratch.py +0 -301
examples/4_pretrain_from_scratch.py DELETED
@@ -1,301 +0,0 @@
1
- """
2
- Phase 3: Train from Scratch (Next-Token Objective)
3
- Demonstrates how to initialize and train GeneMamba with next-token prediction.
4
- If a checkpoint exists, training resumes from checkpoint automatically.
5
-
6
- Usage:
7
- python examples/4_pretrain_from_scratch.py
8
- """
9
-
10
- import torch
11
- import numpy as np
12
- from torch.utils.data import Dataset
13
- from pathlib import Path
14
- from transformers import (
15
- AutoTokenizer,
16
- AutoConfig,
17
- AutoModelForMaskedLM,
18
- Trainer,
19
- TrainingArguments,
20
- )
21
- from transformers.trainer_utils import get_last_checkpoint
22
-
23
-
24
class PretrainingDataset(Dataset):
    """Dataset of pre-tokenized gene sequences for pretraining.

    Each item is truncated or right-padded to ``max_length`` and returned
    as a LongTensor under the key ``"input_ids"``.

    Args:
        input_ids_list: Sequence of 1-D integer arrays/lists of token ids.
        max_length: Fixed output length per example.
        pad_token_id: Id used for right-padding. Defaults to 1, matching
            the previously hard-coded value (presumably the [PAD] id of
            the GeneMamba tokenizer — confirm against the tokenizer).
    """

    def __init__(self, input_ids_list, max_length=2048, pad_token_id=1):
        self.input_ids_list = input_ids_list
        self.max_length = max_length
        # Generalized: pad id is now configurable instead of hard-coded.
        self.pad_token_id = pad_token_id

    def __len__(self):
        return len(self.input_ids_list)

    def __getitem__(self, idx):
        # Accept plain lists as well as ndarrays.
        input_ids = np.asarray(self.input_ids_list[idx])

        # Truncate long sequences; right-pad short ones with pad_token_id.
        if len(input_ids) >= self.max_length:
            input_ids = input_ids[:self.max_length]
        else:
            input_ids = np.pad(
                input_ids,
                (0, self.max_length - len(input_ids)),
                constant_values=self.pad_token_id,
            )

        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
        }
50
-
51
-
52
class NextTokenTrainer(Trainer):
    """Trainer subclass computing a causal (next-token) LM loss.

    Shifts logits against the inputs so position t predicts token t+1:
    ``logits[:, :-1]`` vs ``labels[:, 1:]``.
    """

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the next-token cross-entropy loss for one batch.

        ``num_items_in_batch`` is accepted (and ignored) for compatibility
        with transformers >= 4.46, which passes it to ``compute_loss``;
        without it, newer versions raise a TypeError.
        """
        input_ids = inputs["input_ids"]
        outputs = model(input_ids=input_ids)
        logits = outputs.logits

        # Drop the last logit and the first label so they align as
        # (prediction at t) vs (token at t+1).
        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = input_ids[:, 1:].contiguous().to(shift_logits.device)

        # NOTE(review): padding positions are NOT masked out of the loss
        # here — every shifted position contributes; confirm this is
        # intended for the pad id used by the dataset.
        loss_fct = torch.nn.CrossEntropyLoss()
        loss = loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
        )

        return (loss, outputs) if return_outputs else loss
70
-
71
-
72
class NextTokenDataCollator:
    """Collate pre-tokenized examples into one stacked batch tensor.

    No MLM-style masking is applied; the trainer derives next-token
    labels directly from the input ids.
    """

    def __call__(self, batch):
        stacked = torch.stack(tuple(example["input_ids"] for example in batch))
        return {"input_ids": stacked}
78
-
79
-
80
def create_mock_pretraining_data(n_sequences=5000, seq_len=2048):
    """Generate random token-id sequences as stand-in pretraining data.

    Returns a list of ``n_sequences`` integer arrays of length
    ``seq_len`` with ids drawn uniformly from [2, 25426).
    """

    print("Creating mock pretraining dataset for from-scratch training...")

    sequences = [np.random.randint(2, 25426, seq_len) for _ in range(n_sequences)]

    print(f"✓ Created {n_sequences} sequences")

    return sequences
93
-
94
-
95
def main():
    """Run the full from-scratch pretraining pipeline end to end.

    Loads the GeneMamba tokenizer, builds a model config (resuming from a
    local checkpoint under ``output_dir`` when one exists), trains with a
    next-token objective on mock data, evaluates, saves to disk, and
    reloads the saved model for a smoke-test forward pass.

    Returns:
        Tuple of (model, trainer, model_config) for interactive use.
    """
    print("=" * 80)
    print("GeneMamba Phase 3: Train from Scratch (Next-Token)")
    print("=" * 80)

    # Hub repo providing tokenizer + config; local dirs for outputs.
    model_id = "mineself2016/GeneMamba"
    output_dir = "./from_scratch_pretrain"
    checkpoint_dir = Path(output_dir) / "checkpoint-last"

    # ============================================================
    # Step 1: Load tokenizer spec
    # ============================================================
    print("\n[Step 1] Loading tokenizer...")

    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

    print("✓ Tokenizer loaded:")
    print(f" - vocab_size: {tokenizer.vocab_size}")
    print(f" - [UNK] token/id: {tokenizer.unk_token}/{tokenizer.unk_token_id}")
    print(f" - [PAD] token/id: {tokenizer.pad_token}/{tokenizer.pad_token_id}")
    print(f" - [CLS] token/id: {tokenizer.cls_token}/{tokenizer.cls_token_id}")
    print(f" - [MASK] token/id: {tokenizer.mask_token}/{tokenizer.mask_token_id}")

    # ============================================================
    # Step 2: Build config and initialize/resume model
    # ============================================================
    print("\n[Step 2] Building model (resume if checkpoint exists)...")

    # Start from the hub config, then override to a small model so the
    # example trains in reasonable time.
    model_config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
    model_config.vocab_size = 25426
    model_config.hidden_size = 256
    model_config.num_hidden_layers = 12
    model_config.intermediate_size = 1024
    model_config.max_position_embeddings = 2048
    model_config.mamba_mode = "mean"

    # Prefer an explicit "checkpoint-last" dir; otherwise fall back to
    # the newest Trainer checkpoint found in output_dir (may be None).
    resume_from_checkpoint = None
    if checkpoint_dir.exists():
        resume_from_checkpoint = str(checkpoint_dir)
    else:
        resume_from_checkpoint = get_last_checkpoint(output_dir)

    if resume_from_checkpoint is not None:
        # local_files_only: a checkpoint path must never trigger a hub
        # download.
        model = AutoModelForMaskedLM.from_pretrained(
            resume_from_checkpoint,
            trust_remote_code=True,
            local_files_only=True,
        )
        print(f"✓ Found checkpoint, resume from: {resume_from_checkpoint}")
    else:
        # from_config gives randomly initialized weights (true
        # from-scratch start).
        # NOTE(review): the MaskedLM head is used here even though the
        # training objective is next-token prediction — confirm this is
        # the intended head for GeneMamba.
        model = AutoModelForMaskedLM.from_config(model_config, trust_remote_code=True)
        print("✓ No checkpoint found, start from scratch")

    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print(f"✓ Model initialized:")
    print(f" - Total parameters: {total_params / 1e6:.2f}M")
    print(f" - Trainable parameters: {trainable_params / 1e6:.2f}M")

    # ============================================================
    # Step 3: Prepare data
    # ============================================================
    print("\n[Step 3] Preparing training data...")

    sequences = create_mock_pretraining_data(n_sequences=5000, seq_len=2048)

    # Split 80/20 into train/eval (no shuffling; sequences are random
    # anyway).
    train_size = int(0.8 * len(sequences))
    train_sequences = sequences[:train_size]
    eval_sequences = sequences[train_size:]

    train_dataset = PretrainingDataset(train_sequences)
    eval_dataset = PretrainingDataset(eval_sequences)

    print(f"✓ Datasets created:")
    print(f" - Train: {len(train_dataset)}")
    print(f" - Eval: {len(eval_dataset)}")

    # ============================================================
    # Step 4: Data collator for next-token training
    # ============================================================
    print("\n[Step 4] Setting up data collator...")

    data_collator = NextTokenDataCollator()
    print(f"✓ Data collator ready")

    # ============================================================
    # Step 5: Training arguments
    # ============================================================
    print("\n[Step 5] Setting up training...")

    # NOTE(review): `eval_strategy` is the post-rename name (transformers
    # >= 4.41; older versions use `evaluation_strategy`) — confirm the
    # pinned transformers version.
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=5,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        learning_rate=5e-4,
        weight_decay=0.01,
        warmup_steps=500,
        logging_steps=50,
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        report_to="none",
        seed=42,
        optim="adamw_torch",
        gradient_accumulation_steps=1,
        max_grad_norm=1.0,
    )

    print(f"✓ Training config:")
    print(f" - Output: {output_dir}")
    print(f" - Epochs: {training_args.num_train_epochs}")
    print(f" - Batch size: {training_args.per_device_train_batch_size}")
    print(f" - Learning rate: {training_args.learning_rate}")

    # ============================================================
    # Step 6: Train
    # ============================================================
    print("\n[Step 6] Starting training...")
    print("(This may take a while. In practice, use more GPUs/data for real pretraining)")

    trainer = NextTokenTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
    )

    # Passing resume_from_checkpoint=None simply starts fresh.
    train_result = trainer.train(resume_from_checkpoint=resume_from_checkpoint)

    print(f"✓ Training complete!")
    print(f" - Final training loss: {train_result.training_loss:.4f}")

    # ============================================================
    # Step 7: Evaluate
    # ============================================================
    print("\n[Step 7] Evaluating...")

    eval_results = trainer.evaluate()

    print(f"✓ Evaluation Results:")
    for metric, value in eval_results.items():
        if isinstance(value, (int, float)):
            print(f" - {metric}: {value:.4f}")

    # ============================================================
    # Step 8: Save model and config
    # ============================================================
    print("\n[Step 8] Saving model...")

    # Saves the best model (load_best_model_at_end=True restored it
    # after training).
    save_dir = "./my_genemamba_from_scratch"
    model.save_pretrained(save_dir)
    model_config.save_pretrained(save_dir)

    print(f"✓ Model and config saved to '{save_dir}'")
    print(f" Files created:")
    print(f" - config.json")
    print(f" - model.safetensors (or pytorch_model.bin)")

    # ============================================================
    # Step 9: Reload and verify
    # ============================================================
    print("\n[Step 9] Reloading model from checkpoint...")

    loaded_model = AutoModelForMaskedLM.from_pretrained(
        save_dir,
        trust_remote_code=True,
    )

    loaded_model.eval()

    # Test inference: random token ids in the vocab range at the
    # configured max sequence length.
    with torch.no_grad():
        sample_input = torch.randint(2, 25426, (2, 2048))
        outputs = loaded_model(sample_input)
        logits = outputs.logits

    print(f"✓ Model reloaded and tested!")
    print(f" - Input shape: {sample_input.shape}")
    print(f" - Logits shape: {logits.shape}")

    # ============================================================
    # Step 10: Optional - Convert to different format
    # ============================================================
    print("\n[Step 10] Model ready for conversion/deployment!")
    print(f"✓ You can now:")
    print(f" 1. Push to Hugging Face Hub:")
    print(f" model.push_to_hub('your-username/GeneMamba-custom')")
    print(f" 2. Use with downstream tasks:")
    print(f" AutoModelForSequenceClassification.from_pretrained('{save_dir}', num_labels=N)")
    print(f" 3. Extract embeddings:")
    print(f" AutoModel.from_pretrained('{save_dir}')")

    print("\n" + "=" * 80)
    print("Phase 3 Complete! Model trained from scratch and ready to use.")
    print("=" * 80)

    return model, trainer, model_config
298
-
299
-
300
if __name__ == "__main__":
    # Keep the trained artifacts bound at module scope so an interactive
    # session (e.g. `python -i`) can inspect them after the run.
    model, trainer, model_config = main()