#!/usr/bin/env python3
"""
MiniMind Training Script

Train Mind2 models from scratch or with knowledge distillation.
"""

import argparse
import sys
from pathlib import Path

# Add parent directory to path so the project packages (configs, model,
# training) are importable when this script is run directly.
sys.path.insert(0, str(Path(__file__).parent.parent))

import torch
from torch.utils.data import DataLoader

from configs.model_config import get_config, estimate_params
from model import Mind2ForCausalLM
from training.trainer import Mind2Trainer, TrainingConfig
from training.distillation import DistillationTrainer, DistillationConfig


def parse_args():
    """Parse and validate the command-line arguments.

    Returns:
        argparse.Namespace with model, data, training, distillation, output
        and hardware options.

    Exits with a usage error (via ``parser.error``) when ``--alpha-kd`` is
    outside [0, 1].
    """
    parser = argparse.ArgumentParser(description="Train MiniMind (Mind2) models")

    # Model
    parser.add_argument("--model", type=str, default="mind2-lite",
                        choices=["mind2-nano", "mind2-lite", "mind2-pro"],
                        help="Model variant to train")

    # Data
    parser.add_argument("--train-data", type=str, required=True,
                        help="Path to training data (JSONL format)")
    parser.add_argument("--eval-data", type=str, default=None,
                        help="Path to evaluation data")

    # Training
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--grad-accum", type=int, default=4)
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--warmup-steps", type=int, default=1000)
    parser.add_argument("--max-steps", type=int, default=None)

    # Distillation
    parser.add_argument("--teacher-model", type=str, default=None,
                        help="Path to teacher model for distillation")
    parser.add_argument("--temperature", type=float, default=2.0)
    parser.add_argument("--alpha-kd", type=float, default=0.5)

    # Output
    parser.add_argument("--output-dir", type=str, default="./outputs")
    parser.add_argument("--save-steps", type=int, default=1000)

    # Hardware
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--dtype", type=str, default="float16",
                        choices=["float16", "bfloat16", "float32"])

    args = parser.parse_args()

    # main() derives alpha_ce = 1 - alpha_kd, so a value outside [0, 1]
    # would give one of the two loss terms a negative weight.
    if not 0.0 <= args.alpha_kd <= 1.0:
        parser.error(f"--alpha-kd must be in [0, 1], got {args.alpha_kd}")

    return args


class _DictDataLoader:
    """Adapt a loader that yields ``(input_ids, labels)`` tuples to dict batches.

    Each batch is re-emitted as ``{"input_ids": ..., "labels": ...}``, the
    format this script feeds to ``Mind2Trainer``.
    """

    def __init__(self, loader):
        self.loader = loader

    def __iter__(self):
        for input_ids, labels in self.loader:
            yield {
                "input_ids": input_ids,
                "labels": labels,
            }

    def __len__(self):
        return len(self.loader)


def _resolve_device(requested):
    """Return the device to use for *requested*.

    Falls back to ``"cpu"`` only when a CUDA device was requested and CUDA is
    unavailable; any other explicit device (e.g. ``"cpu"``, ``"mps"``) is
    passed through unchanged.
    """
    if requested.startswith("cuda") and not torch.cuda.is_available():
        return "cpu"
    return requested


def main():
    """Build the model, a (dummy) dataloader and a trainer, then run training."""
    args = parse_args()

    # Setup
    device = _resolve_device(args.device)
    dtype_name = args.dtype
    # Half precision training is poorly supported on CPU (the default dtype is
    # float16, which would otherwise be kept after a silent CUDA->CPU fallback).
    if torch.device(device).type == "cpu" and dtype_name == "float16":
        print("Warning: float16 is not well supported on CPU; using float32 instead.")
        dtype_name = "float32"
    dtype = getattr(torch, dtype_name)

    print("=" * 60)
    print("MiniMind Training")
    print("=" * 60)
    print(f"Model: {args.model}")
    print(f"Device: {device}, Dtype: {dtype_name}")

    # Create model
    config = get_config(args.model)
    model = Mind2ForCausalLM(config).to(device=device, dtype=dtype)

    # Print model info
    params = estimate_params(config)
    print(f"Total params: {params['total_params_b']:.2f}B")
    print(f"Active params: {params['active_params_b']:.2f}B")
    print(f"Activation ratio: {params['activation_ratio']:.1%}")

    # Create dummy dataloader (replace with actual data loading)
    print("\nNote: Using dummy data. Replace with actual data loading.")
    print(f"      --train-data={args.train_data!r} and "
          f"--eval-data={args.eval_data!r} are not used yet.")
    train_data = torch.randint(0, config.vocab_size, (1000, 512))
    # Labels are the inputs themselves; any shifting for next-token
    # prediction is presumably done by the model/trainer — TODO confirm.
    train_loader = DataLoader(
        torch.utils.data.TensorDataset(train_data, train_data),
        batch_size=args.batch_size,
        shuffle=True,
    )

    # Training configuration
    if args.teacher_model:
        # Knowledge distillation
        print(f"\nUsing knowledge distillation from: {args.teacher_model}")

        distill_config = DistillationConfig(
            learning_rate=args.lr,
            num_epochs=args.epochs,
            batch_size=args.batch_size,
            gradient_accumulation_steps=args.grad_accum,
            temperature=args.temperature,
            alpha_kd=args.alpha_kd,
            alpha_ce=1.0 - args.alpha_kd,
            warmup_steps=args.warmup_steps,
            max_steps=args.max_steps,
            save_steps=args.save_steps,
            output_dir=args.output_dir,
        )

        # Load teacher (placeholder)
        # TODO: load the actual teacher model from args.teacher_model.
        teacher = None
        print("Warning: teacher model loading is not implemented; "
              "passing teacher_model=None to DistillationTrainer.")

        # NOTE(review): this branch passes the raw tuple-yielding loader while
        # the standard branch wraps it in _DictDataLoader — confirm which batch
        # format DistillationTrainer expects.
        trainer = DistillationTrainer(
            student_model=model,
            teacher_model=teacher,
            train_dataloader=train_loader,
            config=distill_config,
        )
    else:
        # Standard training
        train_config = TrainingConfig(
            learning_rate=args.lr,
            num_epochs=args.epochs,
            batch_size=args.batch_size,
            gradient_accumulation_steps=args.grad_accum,
            warmup_steps=args.warmup_steps,
            max_steps=args.max_steps,
            save_steps=args.save_steps,
            output_dir=args.output_dir,
        )

        trainer = Mind2Trainer(
            model=model,
            train_dataloader=_DictDataLoader(train_loader),
            config=train_config,
        )

    # Train
    print("\nStarting training...")
    results = trainer.train()

    print("\nTraining complete!")
    print(f"Results: {results}")


if __name__ == "__main__":
    main()