Trouter-Library committed
Commit a13f30f · verified · 1 parent: bbad13f

Create train.py

Files changed (1): train.py (+446, -0)
train.py ADDED
@@ -0,0 +1,446 @@
"""
Helion-OSC Training Script

Fine-tuning and training utilities for the Helion-OSC model.
"""

import logging
from dataclasses import dataclass, field
from typing import Dict, List, Optional

import torch
from datasets import Dataset, DatasetDict, load_dataset
from peft import (
    LoraConfig,
    TaskType,
    get_peft_model,
    prepare_model_for_kbit_training,
)
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    DataCollatorForLanguageModeling,
    EarlyStoppingCallback,
    Trainer,
    TrainingArguments,
)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

@dataclass
class ModelArguments:
    """Arguments for model configuration"""
    model_name_or_path: str = field(
        default="DeepXR/Helion-OSC",
        metadata={"help": "Path to pretrained model or model identifier"}
    )
    use_lora: bool = field(
        default=True,
        metadata={"help": "Whether to use LoRA for efficient fine-tuning"}
    )
    lora_r: int = field(
        default=16,
        metadata={"help": "LoRA attention dimension"}
    )
    lora_alpha: int = field(
        default=32,
        metadata={"help": "LoRA alpha parameter"}
    )
    lora_dropout: float = field(
        default=0.05,
        metadata={"help": "LoRA dropout probability"}
    )
    load_in_8bit: bool = field(
        default=False,
        metadata={"help": "Load model in 8-bit precision"}
    )
    load_in_4bit: bool = field(
        default=False,
        metadata={"help": "Load model in 4-bit precision"}
    )

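# Note: with the defaults above, the LoRA scaling applied to the adapter
# update is lora_alpha / lora_r = 32 / 16 = 2.0.
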
@dataclass
class DataArguments:
    """Arguments for data processing"""
    dataset_name: Optional[str] = field(
        default=None,
        metadata={"help": "Name of the dataset to use"}
    )
    dataset_path: Optional[str] = field(
        default=None,
        metadata={"help": "Path to local dataset"}
    )
    train_file: Optional[str] = field(
        default=None,
        metadata={"help": "Path to training data file"}
    )
    validation_file: Optional[str] = field(
        default=None,
        metadata={"help": "Path to validation data file"}
    )
    max_seq_length: int = field(
        default=2048,
        metadata={"help": "Maximum sequence length"}
    )
    preprocessing_num_workers: int = field(
        default=4,
        metadata={"help": "Number of workers for preprocessing"}
    )

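# Illustrative note (not part of the original script): JSON/JSONL files passed
# via --train_file / --validation_file are expected to contain records in one
# of the two shapes handled by _preprocess_function below, e.g.:
#
#   {"prompt": "Write a function that reverses a string.",
#    "completion": "def reverse(s: str) -> str:\n    return s[::-1]"}
#
#   {"text": "Any raw pretraining-style text."}
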
class HelionOSCTrainer:
    """Trainer class for Helion-OSC model"""

    def __init__(
        self,
        model_args: ModelArguments,
        data_args: DataArguments,
        training_args: TrainingArguments
    ):
        self.model_args = model_args
        self.data_args = data_args
        self.training_args = training_args

        # Initialize tokenizer
        self.tokenizer = self._load_tokenizer()

        # Initialize model
        self.model = self._load_model()

        # Load and preprocess data
        self.datasets = self._load_datasets()

        logger.info("Trainer initialized successfully")

    def _load_tokenizer(self):
        """Load and configure tokenizer"""
        logger.info("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(
            self.model_args.model_name_or_path,
            trust_remote_code=True,
            padding_side="right"
        )

        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        return tokenizer

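    # Caveat (an observation about standard transformers behavior, not stated
    # in the original file): because the pad token falls back to eos above,
    # DataCollatorForLanguageModeling will mask every pad/eos position in the
    # labels to -100, so genuine end-of-sequence tokens get no loss signal.
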
    def _load_model(self):
        """Load and configure model"""
        logger.info("Loading model...")

        model_kwargs = {
            "trust_remote_code": True,
            "low_cpu_mem_usage": True
        }

        # Configure quantization via BitsAndBytesConfig; the bnb_4bit_* options
        # are fields of the quantization config, not from_pretrained kwargs
        if self.model_args.load_in_8bit:
            model_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_8bit=True
            )
        elif self.model_args.load_in_4bit:
            model_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4"
            )
        else:
            model_kwargs["torch_dtype"] = torch.bfloat16

        model = AutoModelForCausalLM.from_pretrained(
            self.model_args.model_name_or_path,
            **model_kwargs
        )

        # Apply LoRA if requested
        if self.model_args.use_lora:
            logger.info("Applying LoRA configuration...")

            if self.model_args.load_in_8bit or self.model_args.load_in_4bit:
                model = prepare_model_for_kbit_training(model)

            lora_config = LoraConfig(
                r=self.model_args.lora_r,
                lora_alpha=self.model_args.lora_alpha,
                target_modules=[
                    "q_proj",
                    "k_proj",
                    "v_proj",
                    "o_proj",
                    "gate_proj",
                    "up_proj",
                    "down_proj"
                ],
                lora_dropout=self.model_args.lora_dropout,
                bias="none",
                task_type=TaskType.CAUSAL_LM
            )

            model = get_peft_model(model, lora_config)
            model.print_trainable_parameters()

        return model

    def _load_datasets(self) -> DatasetDict:
        """Load and preprocess datasets"""
        logger.info("Loading datasets...")

        if self.data_args.dataset_name:
            # Load from the Hugging Face Hub
            datasets = load_dataset(self.data_args.dataset_name)
        elif self.data_args.train_file:
            # Load from local JSON/JSONL files
            data_files = {"train": self.data_args.train_file}
            if self.data_args.validation_file:
                data_files["validation"] = self.data_args.validation_file

            datasets = load_dataset("json", data_files=data_files)
        else:
            raise ValueError("Must provide either dataset_name or train_file")

        # Preprocess datasets
        logger.info("Preprocessing datasets...")
        datasets = datasets.map(
            self._preprocess_function,
            batched=True,
            num_proc=self.data_args.preprocessing_num_workers,
            remove_columns=datasets["train"].column_names,
            desc="Preprocessing datasets"
        )

        return datasets

    def _preprocess_function(self, examples):
        """Preprocess examples for training"""
        if "prompt" in examples and "completion" in examples:
            # Instruction-following format
            texts = [
                f"{prompt}\n{completion}"
                for prompt, completion in zip(examples["prompt"], examples["completion"])
            ]
        elif "text" in examples:
            # Raw text format
            texts = examples["text"]
        else:
            raise ValueError("Dataset must contain 'text' or 'prompt'/'completion' columns")

        # Tokenize
        tokenized = self.tokenizer(
            texts,
            truncation=True,
            max_length=self.data_args.max_seq_length,
            padding="max_length",
            return_tensors=None
        )

        # Create labels (same as input_ids for causal LM); note that
        # DataCollatorForLanguageModeling(mlm=False) rebuilds labels from
        # input_ids at collation time and masks padding positions to -100
        tokenized["labels"] = tokenized["input_ids"].copy()

        return tokenized

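    # Minimal debugging sketch (illustrative, not part of the original
    # script): once a HelionOSCTrainer is constructed, its preprocessing can
    # be exercised on a hand-built batch to sanity-check tokenized lengths:
    #
    #   batch = {"prompt": ["2 + 2 ="], "completion": ["4"]}
    #   features = helion_trainer._preprocess_function(batch)
    #   assert len(features["input_ids"][0]) == data_args.max_seq_length
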
    def train(self):
        """Train the model"""
        logger.info("Starting training...")

        # Data collator
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=self.tokenizer,
            mlm=False
        )

        # Early stopping requires an evaluation set and best-model tracking;
        # only attach the callback when a validation split exists
        callbacks = []
        if self.datasets.get("validation") is not None:
            callbacks.append(EarlyStoppingCallback(early_stopping_patience=3))

        # Initialize trainer
        trainer = Trainer(
            model=self.model,
            args=self.training_args,
            train_dataset=self.datasets["train"],
            eval_dataset=self.datasets.get("validation"),
            tokenizer=self.tokenizer,
            data_collator=data_collator,
            callbacks=callbacks
        )

        # Train
        train_result = trainer.train()

        # Save model
        trainer.save_model()

        # Save metrics
        metrics = train_result.metrics
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

        logger.info("Training completed successfully!")

        return trainer, metrics

    def evaluate(self, trainer: Optional[Trainer] = None):
        """Evaluate the model"""
        if self.datasets.get("validation") is None:
            logger.warning("No validation split available; skipping evaluation")
            return {}

        if trainer is None:
            data_collator = DataCollatorForLanguageModeling(
                tokenizer=self.tokenizer,
                mlm=False
            )

            trainer = Trainer(
                model=self.model,
                args=self.training_args,
                eval_dataset=self.datasets.get("validation"),
                tokenizer=self.tokenizer,
                data_collator=data_collator
            )

        logger.info("Evaluating model...")
        metrics = trainer.evaluate()

        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

        return metrics


def create_code_dataset(examples: List[Dict[str, str]]) -> Dataset:
    """
    Create a dataset from code examples

    Args:
        examples: List of dictionaries with 'prompt' and 'completion' keys

    Returns:
        Dataset object
    """
    return Dataset.from_dict({
        "prompt": [ex["prompt"] for ex in examples],
        "completion": [ex["completion"] for ex in examples]
    })


def create_math_dataset(examples: List[Dict[str, str]]) -> Dataset:
    """
    Create a dataset from math examples

    Args:
        examples: List of dictionaries with 'problem' and 'solution' keys

    Returns:
        Dataset object
    """
    return Dataset.from_dict({
        "prompt": [f"Problem: {ex['problem']}\nSolution:" for ex in examples],
        "completion": [ex["solution"] for ex in examples]
    })

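# Illustrative usage of the helpers above (assumed, not part of the original
# script): build a small dataset and serialize it to JSON lines so it can be
# fed back in through --train_file.
#
#   ds = create_math_dataset([
#       {"problem": "What is 7 * 6?", "solution": "7 * 6 = 42"},
#   ])
#   ds.to_json("math_train.jsonl")
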

def main():
    """Main training script"""
    import argparse

    parser = argparse.ArgumentParser(description="Train Helion-OSC model")

    # Model arguments
    parser.add_argument("--model_name_or_path", type=str, default="DeepXR/Helion-OSC")
    # BooleanOptionalAction also generates --no-use_lora, so LoRA can be
    # disabled (a plain store_true flag with default=True never could be)
    parser.add_argument("--use_lora", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--lora_r", type=int, default=16)
    parser.add_argument("--lora_alpha", type=int, default=32)
    parser.add_argument("--lora_dropout", type=float, default=0.05)
    parser.add_argument("--load_in_8bit", action="store_true")
    parser.add_argument("--load_in_4bit", action="store_true")

    # Data arguments
    parser.add_argument("--dataset_name", type=str, default=None)
    parser.add_argument("--dataset_path", type=str, default=None)
    parser.add_argument("--train_file", type=str, default=None)
    parser.add_argument("--validation_file", type=str, default=None)
    parser.add_argument("--max_seq_length", type=int, default=2048)
    parser.add_argument("--preprocessing_num_workers", type=int, default=4)

    # Training arguments
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument("--num_train_epochs", type=int, default=3)
    parser.add_argument("--per_device_train_batch_size", type=int, default=4)
    parser.add_argument("--per_device_eval_batch_size", type=int, default=4)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
    parser.add_argument("--learning_rate", type=float, default=2e-5)
    parser.add_argument("--warmup_steps", type=int, default=100)
    parser.add_argument("--logging_steps", type=int, default=10)
    parser.add_argument("--save_steps", type=int, default=500)
    parser.add_argument("--eval_steps", type=int, default=500)
    parser.add_argument("--save_total_limit", type=int, default=3)
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--bf16", action="store_true")
    parser.add_argument("--gradient_checkpointing", action="store_true")
    parser.add_argument("--use_wandb", action="store_true")

    args = parser.parse_args()

    # The dataset loader accepts either a Hub dataset or a local training file
    if not args.dataset_name and not args.train_file:
        parser.error("Must provide either --dataset_name or --train_file")

    # Create argument objects
    model_args = ModelArguments(
        model_name_or_path=args.model_name_or_path,
        use_lora=args.use_lora,
        lora_r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        load_in_8bit=args.load_in_8bit,
        load_in_4bit=args.load_in_4bit
    )

    data_args = DataArguments(
        dataset_name=args.dataset_name,
        dataset_path=args.dataset_path,
        train_file=args.train_file,
        validation_file=args.validation_file,
        max_seq_length=args.max_seq_length,
        preprocessing_num_workers=args.preprocessing_num_workers
    )

    # Step-based evaluation and best-model tracking both require a validation
    # set; fall back gracefully when none is given
    has_validation = args.validation_file is not None

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        warmup_steps=args.warmup_steps,
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        eval_steps=args.eval_steps,
        save_total_limit=args.save_total_limit,
        fp16=args.fp16,
        bf16=args.bf16,
        gradient_checkpointing=args.gradient_checkpointing,
        report_to="wandb" if args.use_wandb else "none",
        load_best_model_at_end=has_validation,
        metric_for_best_model="eval_loss" if has_validation else None,
        greater_is_better=False,
        evaluation_strategy="steps" if has_validation else "no",
        save_strategy="steps",
        logging_dir=f"{args.output_dir}/logs",
        remove_unused_columns=False
    )

    # Initialize trainer
    helion_trainer = HelionOSCTrainer(
        model_args=model_args,
        data_args=data_args,
        training_args=training_args
    )

    # Train
    trainer, metrics = helion_trainer.train()

    # Evaluate
    if args.validation_file:
        eval_metrics = helion_trainer.evaluate(trainer)
        logger.info(f"Evaluation metrics: {eval_metrics}")

    logger.info("Training pipeline completed!")


if __name__ == "__main__":
    main()
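
# Example invocation (illustrative; file paths and the particular flag
# choices are assumptions, not taken from the original repository):
#
#   python train.py \
#       --train_file data/train.jsonl \
#       --validation_file data/val.jsonl \
#       --output_dir checkpoints/helion-osc-lora \
#       --load_in_4bit \
#       --bf16 \
#       --gradient_checkpointing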