# MiniMind / scripts / train.py
# MiniMind Max2 - Efficient MoE Language Model
# (upstream commit 8b187bb)
#!/usr/bin/env python3
"""
MiniMind Training Script
Train Mind2 models from scratch or with knowledge distillation.
"""
import argparse
import sys
from pathlib import Path
# Add parent directory to path
sys.path.insert(0, str(Path(__file__).parent.parent))
import torch
from torch.utils.data import DataLoader
from configs.model_config import get_config, estimate_params
from model import Mind2ForCausalLM
from training.trainer import Mind2Trainer, TrainingConfig
from training.distillation import DistillationTrainer, DistillationConfig
def parse_args():
    """Parse command-line options for a MiniMind training run.

    Returns:
        argparse.Namespace with model/data/training/distillation/output/
        hardware settings; ``--train-data`` is the only required flag.
    """
    p = argparse.ArgumentParser(description="Train MiniMind (Mind2) models")

    # -- Model selection -------------------------------------------------
    p.add_argument(
        "--model",
        type=str,
        default="mind2-lite",
        choices=["mind2-nano", "mind2-lite", "mind2-pro"],
        help="Model variant to train",
    )

    # -- Dataset paths ---------------------------------------------------
    p.add_argument(
        "--train-data",
        type=str,
        required=True,
        help="Path to training data (JSONL format)",
    )
    p.add_argument(
        "--eval-data",
        type=str,
        default=None,
        help="Path to evaluation data",
    )

    # -- Optimization schedule -------------------------------------------
    p.add_argument("--epochs", type=int, default=3)
    p.add_argument("--batch-size", type=int, default=8)
    p.add_argument("--grad-accum", type=int, default=4)
    p.add_argument("--lr", type=float, default=3e-4)
    p.add_argument("--warmup-steps", type=int, default=1000)
    p.add_argument("--max-steps", type=int, default=None)

    # -- Knowledge distillation (optional) -------------------------------
    p.add_argument(
        "--teacher-model",
        type=str,
        default=None,
        help="Path to teacher model for distillation",
    )
    p.add_argument("--temperature", type=float, default=2.0)
    p.add_argument("--alpha-kd", type=float, default=0.5)

    # -- Checkpointing / output ------------------------------------------
    p.add_argument("--output-dir", type=str, default="./outputs")
    p.add_argument("--save-steps", type=int, default=1000)

    # -- Hardware --------------------------------------------------------
    p.add_argument("--device", type=str, default="cuda")
    p.add_argument(
        "--dtype",
        type=str,
        default="float16",
        choices=["float16", "bfloat16", "float32"],
    )

    return p.parse_args()
class _DictDataLoader:
    """Adapt a loader yielding (input_ids, labels) tuples to the dict
    batches that Mind2Trainer consumes."""

    def __init__(self, loader):
        self.loader = loader

    def __iter__(self):
        for input_ids, labels in self.loader:
            yield {
                "input_ids": input_ids,
                "labels": labels,
            }

    def __len__(self):
        return len(self.loader)


def _build_dummy_loader(config, batch_size):
    """Build a DataLoader over random token ids as placeholder data.

    Labels are identical to the inputs — the usual causal-LM setup where
    the model shifts targets internally (assumed; confirm in Mind2ForCausalLM).
    Replace this with real JSONL loading of ``args.train_data``.
    """
    train_data = torch.randint(0, config.vocab_size, (1000, 512))
    return DataLoader(
        torch.utils.data.TensorDataset(train_data, train_data),
        batch_size=batch_size,
        shuffle=True,
    )


def main():
    """Entry point: build model, data, and trainer from CLI args, then train.

    Chooses DistillationTrainer when --teacher-model is given, otherwise
    the standard Mind2Trainer.
    """
    args = parse_args()

    # Fall back to CPU when CUDA is unavailable. float16 training is not
    # generally supported on CPU, so widen to float32 in that case instead
    # of crashing later inside the training loop.
    device = args.device if torch.cuda.is_available() else "cpu"
    dtype_name = args.dtype
    if device == "cpu" and dtype_name == "float16":
        print("Warning: float16 is unsupported on CPU; falling back to float32.")
        dtype_name = "float32"
    dtype = getattr(torch, dtype_name)

    print("=" * 60)
    print("MiniMind Training")
    print("=" * 60)
    print(f"Model: {args.model}")
    print(f"Device: {device}, Dtype: {dtype_name}")

    # Create model on the selected device/dtype.
    config = get_config(args.model)
    model = Mind2ForCausalLM(config).to(device=device, dtype=dtype)

    # Report parameter counts (MoE models: total vs. active per token).
    params = estimate_params(config)
    print(f"Total params: {params['total_params_b']:.2f}B")
    print(f"Active params: {params['active_params_b']:.2f}B")
    print(f"Activation ratio: {params['activation_ratio']:.1%}")

    # Placeholder data pipeline; args.train_data is not read yet.
    print("\nNote: Using dummy data. Replace with actual data loading.")
    train_loader = _build_dummy_loader(config, args.batch_size)

    if args.teacher_model:
        # Knowledge distillation: student learns from a (larger) teacher.
        print(f"\nUsing knowledge distillation from: {args.teacher_model}")
        distill_config = DistillationConfig(
            learning_rate=args.lr,
            num_epochs=args.epochs,
            batch_size=args.batch_size,
            gradient_accumulation_steps=args.grad_accum,
            temperature=args.temperature,
            alpha_kd=args.alpha_kd,
            # KD and CE weights are complementary so the total loss weight is 1.
            alpha_ce=1.0 - args.alpha_kd,
            warmup_steps=args.warmup_steps,
            max_steps=args.max_steps,
            save_steps=args.save_steps,
            output_dir=args.output_dir,
        )
        # NOTE(review): teacher loading is still a placeholder — the trainer
        # receives None and will presumably fail at first use; wire up actual
        # teacher checkpoint loading from args.teacher_model.
        teacher = None
        trainer = DistillationTrainer(
            student_model=model,
            teacher_model=teacher,
            train_dataloader=train_loader,
            config=distill_config,
        )
    else:
        # Standard causal-LM training.
        train_config = TrainingConfig(
            learning_rate=args.lr,
            num_epochs=args.epochs,
            batch_size=args.batch_size,
            gradient_accumulation_steps=args.grad_accum,
            warmup_steps=args.warmup_steps,
            max_steps=args.max_steps,
            save_steps=args.save_steps,
            output_dir=args.output_dir,
        )
        trainer = Mind2Trainer(
            model=model,
            train_dataloader=_DictDataLoader(train_loader),
            config=train_config,
        )

    print("\nStarting training...")
    results = trainer.train()
    print("\nTraining complete!")
    print(f"Results: {results}")
# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()