"""
MiniMind Training Script

Train Mind2 models from scratch or with knowledge distillation.
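
Example usage (script and data paths are illustrative):

    python train.py --model mind2-lite --train-data data/train.jsonl
    python train.py --model mind2-nano --train-data data/train.jsonl \
        --teacher-model checkpoints/teacher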
"""

import argparse
import sys
from pathlib import Path

# Make the repository root importable when this script is run directly.
sys.path.insert(0, str(Path(__file__).parent.parent))

import torch
from torch.utils.data import DataLoader, TensorDataset

from configs.model_config import get_config, estimate_params
from model import Mind2ForCausalLM
from training.trainer import Mind2Trainer, TrainingConfig
from training.distillation import DistillationTrainer, DistillationConfig


def parse_args():
    parser = argparse.ArgumentParser(description="Train MiniMind (Mind2) models")

    # Model selection
    parser.add_argument("--model", type=str, default="mind2-lite",
                        choices=["mind2-nano", "mind2-lite", "mind2-pro"],
                        help="Model variant to train")

    # Data
    parser.add_argument("--train-data", type=str, required=True,
                        help="Path to training data (JSONL format)")
    parser.add_argument("--eval-data", type=str, default=None,
                        help="Path to evaluation data (optional)")

    # Optimization
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--grad-accum", type=int, default=4)
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--warmup-steps", type=int, default=1000)
    parser.add_argument("--max-steps", type=int, default=None)

    # Knowledge distillation
    parser.add_argument("--teacher-model", type=str, default=None,
                        help="Path to teacher model for distillation")
    parser.add_argument("--temperature", type=float, default=2.0,
                        help="Softmax temperature for the distillation loss")
    parser.add_argument("--alpha-kd", type=float, default=0.5,
                        help="Weight of the distillation loss (CE weight is 1 - alpha_kd)")

    # Checkpointing
    parser.add_argument("--output-dir", type=str, default="./outputs")
    parser.add_argument("--save-steps", type=int, default=1000)

    # Hardware
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--dtype", type=str, default="float16",
                        choices=["float16", "bfloat16", "float32"])

    return parser.parse_args()


def main():
    args = parse_args()

    # Fall back to CPU when CUDA is unavailable, and map the dtype string to
    # the corresponding torch.dtype (e.g. "float16" -> torch.float16).
    device = args.device if torch.cuda.is_available() else "cpu"
    dtype = getattr(torch, args.dtype)

    print("=" * 60)
    print("MiniMind Training")
    print("=" * 60)
    print(f"Model: {args.model}")
    print(f"Device: {device}, Dtype: {args.dtype}")

    # Build the model and move it to the target device and precision.
    config = get_config(args.model)
    model = Mind2ForCausalLM(config).to(device=device, dtype=dtype)
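
    # Caution: training directly in float16 usually needs loss scaling to avoid
    # gradient underflow; whether Mind2Trainer handles that internally is not
    # shown here, so bfloat16 (or float32) may be the safer choice on hardware
    # that supports it.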

    # Report estimated parameter counts (total vs. active).
    params = estimate_params(config)
    print(f"Total params: {params['total_params_b']:.2f}B")
    print(f"Active params: {params['active_params_b']:.2f}B")
    print(f"Activation ratio: {params['activation_ratio']:.1%}")

    # Placeholder data pipeline: random token ids stand in for a real corpus.
    print(f"\nNote: Using dummy data; --train-data ({args.train_data}) is not "
          "loaded yet. Replace with actual data loading.")
    train_data = torch.randint(0, config.vocab_size, (1000, 512))
    train_loader = DataLoader(
        TensorDataset(train_data, train_data),  # inputs double as labels
        batch_size=args.batch_size,
        shuffle=True,
    )
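
    # A real pipeline would tokenize args.train_data instead. A rough sketch,
    # assuming one {"text": ...} JSON object per line and a `tokenizer` object
    # (both are assumptions; neither is defined in this script):
    #
    #   import json
    #   with open(args.train_data) as f:
    #       texts = [json.loads(line)["text"] for line in f]
    #   # ...tokenize `texts` into fixed-length id tensors, then build the
    #   # TensorDataset from (input_ids, labels) as above.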

    if args.teacher_model:
        print(f"\nUsing knowledge distillation from: {args.teacher_model}")
        distill_config = DistillationConfig(
            learning_rate=args.lr,
            num_epochs=args.epochs,
            batch_size=args.batch_size,
            gradient_accumulation_steps=args.grad_accum,
            temperature=args.temperature,
            alpha_kd=args.alpha_kd,
            alpha_ce=1.0 - args.alpha_kd,
            warmup_steps=args.warmup_steps,
            max_steps=args.max_steps,
            save_steps=args.save_steps,
            output_dir=args.output_dir,
        )
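
        # NOTE: the teacher is never actually loaded here, so the trainer below
        # receives teacher_model=None. A minimal loading sketch, assuming the
        # checkpoint at args.teacher_model is a Mind2 state dict (an assumption,
        # not a verified interface of this repo):
        #
        #   teacher = Mind2ForCausalLM(get_config("mind2-pro"))
        #   teacher.load_state_dict(torch.load(args.teacher_model, map_location=device))
        #   teacher = teacher.to(device=device, dtype=dtype).eval()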
        teacher = None  # placeholder; see the loading sketch above

        trainer = DistillationTrainer(
            student_model=model,
            teacher_model=teacher,
            train_dataloader=train_loader,
            config=distill_config,
        )
    else:
        train_config = TrainingConfig(
            learning_rate=args.lr,
            num_epochs=args.epochs,
            batch_size=args.batch_size,
            gradient_accumulation_steps=args.grad_accum,
            warmup_steps=args.warmup_steps,
            max_steps=args.max_steps,
            save_steps=args.save_steps,
            output_dir=args.output_dir,
        )
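
        # Thin adapter: the TensorDataset loader yields (input_ids, labels)
        # tuples, while Mind2Trainer is fed dict-style batches below.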
        class DictDataLoader:
            def __init__(self, loader):
                self.loader = loader

            def __iter__(self):
                for input_ids, labels in self.loader:
                    yield {"input_ids": input_ids, "labels": labels}

            def __len__(self):
                return len(self.loader)

        trainer = Mind2Trainer(
            model=model,
            train_dataloader=DictDataLoader(train_loader),
            config=train_config,
        )

    print("\nStarting training...")
    results = trainer.train()
    print("\nTraining complete!")
    print(f"Results: {results}")


if __name__ == "__main__":
    main()