"""
Trainer: Main training loop for Vortex model.
Handles gradient accumulation, mixed precision, checkpointing.
"""

import os
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from typing import Optional, Dict, List, Callable
from pathlib import Path
import logging

from ..training.losses import VortexLoss
from ..training.curriculum import CurriculumScheduler


class VortexDataset(Dataset):
    """Map-style dataset that tokenizes text rows read from parquet shards.

    Every shard is read eagerly in ``__init__`` and kept in ``self.samples``;
    tokenization happens lazily, one sample per ``__getitem__`` call.
    """

    def __init__(
        self,
        shard_files: List[str],
        tokenizer,
        max_seq_len: int = 16384,
    ):
        """
        Initialize dataset.

        Args:
            shard_files: List of parquet shard files
            tokenizer: Tokenizer for encoding text
            max_seq_len: Maximum sequence length
        """
        self.shard_files = shard_files
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len

        # Load all shards into memory (for simplicity - would stream in practice)
        self.samples = []
        self._load_shards()

    def _load_shards(self):
        """Read every shard and append one dict per row to ``self.samples``.

        Each shard must have a ``text`` column (``KeyError`` otherwise).
        ``dataset`` and ``domain`` are optional and default to ``""``.
        """
        import pandas as pd

        for shard in self.shard_files:
            df = pd.read_parquet(shard)
            num_rows = len(df)

            # Pull whole columns instead of DataFrame.iterrows(), which builds
            # a Series object for every row and is very slow on large shards.
            texts = df["text"].tolist()
            if "dataset" in df.columns:
                datasets = df["dataset"].tolist()
            else:
                datasets = [""] * num_rows
            if "domain" in df.columns:
                domains = df["domain"].tolist()
            else:
                domains = [""] * num_rows

            self.samples.extend(
                {"text": text, "dataset": dataset, "domain": domain}
                for text, dataset, domain in zip(texts, datasets, domains)
            )

    def __len__(self) -> int:
        """Return the number of loaded samples."""
        return len(self.samples)

    def __getitem__(self, idx) -> Dict:
        """Tokenize sample ``idx`` and return its tensors.

        Returns:
            Dict with ``input_ids``, ``attention_mask``, ``labels`` (a clone of
            ``input_ids``) and the sample's ``domain`` value. Sequences longer
            than ``max_seq_len`` are truncated; shorter ones are not padded.
        """
        sample = self.samples[idx]
        text = sample["text"]

        # Tokenize
        # The tokenizer is expected to return a mapping with "input_ids" and
        # "attention_mask" tensors that carry a leading batch dimension of 1.
        encoding = self.tokenizer.encode(
            text,
            add_special_tokens=True,
            return_tensors="pt",
        )

        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)

        # Truncate if needed
        if len(input_ids) > self.max_seq_len:
            input_ids = input_ids[:self.max_seq_len]
            attention_mask = attention_mask[:self.max_seq_len]

        # Labels are same as input_ids (causal LM)
        # NOTE(review): no shift is applied here - presumably the model or
        # loss shifts labels; confirm against VortexLoss.
        labels = input_ids.clone()

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
            "domain": sample["domain"],
        }


class VortexTrainer:
    """
    Main trainer for Vortex model.

    Owns the optimizer, LR scheduler, AMP gradient scaler, data loaders and
    checkpoint I/O. ``global_step`` counts micro-batches, not optimizer
    updates: an optimizer step is taken every
    ``config["gradient_accumulation_steps"]`` micro-batches.
    """

    def __init__(
        self,
        model: nn.Module,
        tokenizer,
        train_dataset: Dataset,
        config: Dict,
        eval_dataset: Optional[Dataset] = None,
        optimizer: Optional[torch.optim.Optimizer] = None,
        scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
    ):
        """
        Initialize trainer.

        Args:
            model: VortexModel
            tokenizer: VortexScienceTokenizer
            train_dataset: Training dataset
            config: Training configuration
            eval_dataset: Optional evaluation dataset
            optimizer: Optional optimizer (created if None)
            scheduler: Optional LR scheduler
        """
        self.model = model
        self.tokenizer = tokenizer
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.config = config

        self.device = torch.device(config["device"])
        self.use_amp = config.get("use_amp", True)
        self.amp_dtype = getattr(torch, config.get("amp_dtype", "bfloat16"))

        # Autocast and the gradient scaler are only used on CUDA devices.
        self._amp_on_cuda = bool(self.use_amp and self.device.type == "cuda")
        if (
            self._amp_on_cuda
            and self.amp_dtype is torch.bfloat16
            and not torch.cuda.is_bf16_supported()
        ):
            # autocast raises for bfloat16 on GPUs without support; fall back
            # to float16 so the default config keeps working on older GPUs.
            print("bfloat16 autocast is not supported on this GPU; using float16")
            self.amp_dtype = torch.float16

        # Move model to device
        self.model.to(self.device)

        # Setup optimizer
        if optimizer is None:
            self.optimizer = self._create_optimizer()
        else:
            self.optimizer = optimizer

        # Setup scheduler
        # NOTE(review): T_max is expressed in micro-steps, but scheduler.step()
        # only runs once per accumulation window, so the cosine cycle does not
        # complete when gradient_accumulation_steps > 1 - confirm intent.
        if scheduler is None:
            self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer,
                T_max=config["max_steps"],
            )
        else:
            self.scheduler = scheduler

        # Setup AMP scaler
        self.scaler = torch.cuda.amp.GradScaler() if self._amp_on_cuda else None

        # Loss function
        self.loss_fn = VortexLoss(config)

        # Curriculum scheduler
        self.curriculum = CurriculumScheduler(config, config["max_steps"])

        # Logging
        self.log_dir = Path(config.get("log_dir", "logs"))
        self.log_dir.mkdir(parents=True, exist_ok=True)
        self.log_interval = config.get("log_interval", 100)

        # Checkpointing
        self.checkpoint_dir = Path(config.get("checkpoint_dir", "checkpoints"))
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.save_interval = config.get("save_interval", 5000)

        # Training state
        self.global_step = 0
        # NOTE(review): best_eval_loss is checkpointed but never updated, and
        # save_checkpoint(is_best=True) is never called in this file.
        self.best_eval_loss = float('inf')

        # Data loader
        # NOTE(review): no collate_fn is supplied, so samples of different
        # lengths cannot be batched by the default collate - confirm that the
        # datasets used here yield fixed-length tensors.
        num_workers = config.get("num_workers", 4)
        train_loader_kwargs = {
            "num_workers": num_workers,
            "pin_memory": config.get("pin_memory", True),
        }
        # DataLoader rejects prefetch_factor when num_workers == 0, so it is
        # only forwarded when worker processes are actually used.
        if num_workers > 0:
            train_loader_kwargs["prefetch_factor"] = config.get("prefetch_factor", 2)

        self.train_loader = DataLoader(
            train_dataset,
            batch_size=config["micro_batch_size"],
            shuffle=True,
            **train_loader_kwargs,
        )

        if eval_dataset:
            self.eval_loader = DataLoader(
                eval_dataset,
                batch_size=config["micro_batch_size"],
                shuffle=False,
                num_workers=num_workers,
            )

    def _create_optimizer(self) -> torch.optim.Optimizer:
        """Create AdamW optimizer from the learning-rate, beta and decay config keys."""
        return torch.optim.AdamW(
            self.model.parameters(),
            lr=self.config["learning_rate"],
            betas=(self.config["beta1"], self.config["beta2"]),
            weight_decay=self.config["weight_decay"],
        )

    def _autocast(self):
        """Return the autocast context shared by training and evaluation.

        It is a no-op unless AMP is enabled and the device is CUDA; when
        active it uses the configured ``amp_dtype``.
        """
        return torch.cuda.amp.autocast(
            enabled=self._amp_on_cuda,
            dtype=self.amp_dtype,
        )

    def train_step(
        self,
        batch: Dict,
        current_step: int,
    ) -> Dict[str, torch.Tensor]:
        """
        Single training step: forward pass, loss, and backward pass.

        The optimizer is NOT stepped here; ``train_epoch`` does that once per
        accumulation window.

        Args:
            batch: Batch dictionary
            current_step: Current step number (currently unused)

        Returns:
            Dictionary of losses
        """
        self.model.train()

        # Move batch to device
        input_ids = batch["input_ids"].to(self.device)
        attention_mask = batch["attention_mask"].to(self.device)
        labels = batch["labels"].to(self.device)

        # Domain info (placeholder - would extract from batch)
        domain_ids = None
        domain_tags = None

        # Forward pass with AMP
        with self._autocast():
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                domain_ids=domain_ids,
                domain_tags=domain_tags,
                return_dict=True,
            )
            logits = outputs["logits"]

            # Compute losses
            losses = self.loss_fn(
                logits=logits,
                labels=labels,
                # Pass modules and masks for auxiliary losses
            )

        # Backward pass
        # NOTE(review): the loss is not divided by gradient_accumulation_steps,
        # so gradients are summed (not averaged) over the window unless
        # VortexLoss already scales it - confirm.
        if self.scaler:
            self.scaler.scale(losses["total_loss"]).backward()
        else:
            losses["total_loss"].backward()

        return losses

    def train_epoch(self):
        """Train for one pass over ``train_loader``.

        Returns early once ``global_step`` reaches ``config["max_steps"]``.
        """
        self.model.train()

        for batch_idx, batch in enumerate(self.train_loader):
            # Train step
            losses = self.train_step(batch, self.global_step)

            # Gradient accumulation
            if (self.global_step + 1) % self.config["gradient_accumulation_steps"] == 0:
                # Gradient clipping
                if self.config.get("clip_grad_norm", 0) > 0:
                    if self.scaler:
                        # Gradients must be unscaled before their norm is clipped.
                        self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(),
                        self.config["clip_grad_norm"],
                    )

                # Optimizer step
                if self.scaler:
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                else:
                    self.optimizer.step()

                self.optimizer.zero_grad()
                self.scheduler.step()

            # Logging
            if self.global_step % self.log_interval == 0:
                self._log_losses(losses, batch_idx)

            # Evaluation
            if self.eval_dataset and self.global_step % self.config.get("eval_interval", 1000) == 0:
                self.evaluate()

            # Checkpointing
            if self.global_step % self.save_interval == 0:
                self.save_checkpoint()

            self.global_step += 1

            if self.global_step >= self.config["max_steps"]:
                print("Reached max steps")
                return

    def evaluate(self) -> Dict[str, float]:
        """Run evaluation over ``eval_loader``.

        Returns:
            ``{"eval_loss": mean}`` where the mean is taken over per-batch
            cross-entropy values (each batch weighs the same regardless of
            size); ``0.0`` if the loader yields no batches.
        """
        self.model.eval()

        total_loss = 0.0
        num_batches = 0

        with torch.no_grad():
            for batch in self.eval_loader:
                input_ids = batch["input_ids"].to(self.device)
                attention_mask = batch["attention_mask"].to(self.device)
                labels = batch["labels"].to(self.device)

                with self._autocast():
                    outputs = self.model(
                        input_ids=input_ids,
                        attention_mask=attention_mask,
                        return_dict=True,
                    )
                    logits = outputs["logits"]
                    # Flatten every dimension except the last one of the logits
                    # so each position is scored as an independent prediction.
                    loss = F.cross_entropy(
                        logits.view(-1, logits.size(-1)),
                        labels.view(-1),
                        ignore_index=-100,
                    )

                total_loss += loss.item()
                num_batches += 1

        avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
        print(f"Evaluation at step {self.global_step}: loss = {avg_loss:.4f}")

        return {"eval_loss": avg_loss}

    def save_checkpoint(self, is_best: bool = False):
        """Save model checkpoint.

        Writes ``checkpoint_<step>.pt`` and ``latest.pt`` into
        ``checkpoint_dir``, plus ``best_model.pt`` when ``is_best`` is true.

        Args:
            is_best: Also store this checkpoint as ``best_model.pt``.
        """
        checkpoint = {
            "step": self.global_step,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scheduler_state_dict": self.scheduler.state_dict(),
            "config": self.config,
            "best_eval_loss": self.best_eval_loss,
        }

        if self.scaler:
            checkpoint["scaler_state_dict"] = self.scaler.state_dict()

        # Save latest
        checkpoint_path = self.checkpoint_dir / f"checkpoint_{self.global_step:06d}.pt"
        torch.save(checkpoint, checkpoint_path)
        print(f"Saved checkpoint to {checkpoint_path}")

        # Save best
        if is_best:
            best_path = self.checkpoint_dir / "best_model.pt"
            torch.save(checkpoint, best_path)
            print(f"Saved best model to {best_path}")

        # Save latest link (a full copy, not a symlink)
        latest_path = self.checkpoint_dir / "latest.pt"
        torch.save(checkpoint, latest_path)

    def load_checkpoint(self, checkpoint_path: str):
        """Load checkpoint and restore model, optimizer, scheduler and step.

        Args:
            checkpoint_path: Path of a file written by ``save_checkpoint``.
        """
        # SECURITY: weights_only=False unpickles arbitrary objects, which can
        # execute code. Only load checkpoints from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)

        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        self.global_step = checkpoint["step"]
        self.best_eval_loss = checkpoint.get("best_eval_loss", float('inf'))

        if self.scaler and "scaler_state_dict" in checkpoint:
            self.scaler.load_state_dict(checkpoint["scaler_state_dict"])

        print(f"Loaded checkpoint from {checkpoint_path} at step {self.global_step}")

    def _log_losses(self, losses: Dict[str, torch.Tensor], batch_idx: int):
        """Print every entry of ``losses`` for the current step to the console.

        Args:
            losses: Mapping of loss name to a single-element tensor.
            batch_idx: Index of the batch within the epoch (currently unused).
        """
        loss_str = " | ".join([f"{k}: {v.item():.4f}" for k, v in losses.items()])
        print(f"Step {self.global_step} | {loss_str}")

    def train(self):
        """Main training loop.

        Runs epochs back to back until ``global_step`` reaches
        ``config["max_steps"]``. A final checkpoint is always written, also
        after an error or a keyboard interrupt.
        """
        print("Starting training...")
        print(f"Total steps: {self.config['max_steps']}")
        print(f"Device: {self.device}")
        print(f"Batch size: {self.config['micro_batch_size']}")
        print(f"Gradient accumulation steps: {self.config['gradient_accumulation_steps']}")

        max_steps = self.config["max_steps"]

        try:
            # One epoch may hold fewer than max_steps batches, so keep
            # iterating over the data until the step budget is used up.
            while self.global_step < max_steps:
                step_before = self.global_step
                self.train_epoch()
                if self.global_step == step_before:
                    # An empty loader would otherwise spin here forever.
                    print("Training data loader produced no batches; stopping")
                    break
        except KeyboardInterrupt:
            print("Training interrupted")
        finally:
            self.save_checkpoint()


def test_trainer():
    """Smoke-test the trainer with a tiny model on CPU.

    Logs and checkpoints go to a temporary directory so the run leaves no
    files behind.
    """
    import tempfile

    from models.vortex_model import VortexModel
    from tokenizer.vortex_tokenizer import VortexScienceTokenizer
    from configs.vortex_7b_config import VORTEX_7B_CONFIG

    # Small config for testing
    config = VORTEX_7B_CONFIG.copy()
    config["d_model"] = 256
    config["num_layers"] = 2
    config["num_heads"] = 4
    config["vocab_size"] = 1000
    config["max_steps"] = 10
    config["device"] = "cpu"
    # The dummy dataset below is a local class and cannot be pickled for
    # spawned worker processes, so load data in the main process.
    config["num_workers"] = 0
    config["pin_memory"] = False

    # Create model
    model = VortexModel(config)

    # Create dummy tokenizer
    class DummyTokenizer:
        def encode(self, text, add_special_tokens=True, return_tensors="pt"):
            return {"input_ids": torch.randint(0, 1000, (1, 10)), "attention_mask": torch.ones(1, 10)}

    tokenizer = DummyTokenizer()

    # Create dummy dataset
    class DummyDataset(torch.utils.data.Dataset):
        def __len__(self):
            return 10

        def __getitem__(self, idx):
            return {
                "input_ids": torch.randint(0, 1000, (32,)),
                "attention_mask": torch.ones(32),
                "labels": torch.randint(0, 1000, (32,)),
                "domain": "physics",
            }

    train_dataset = DummyDataset()
    eval_dataset = DummyDataset()

    with tempfile.TemporaryDirectory() as tmp_dir:
        config["log_dir"] = str(Path(tmp_dir) / "logs")
        config["checkpoint_dir"] = str(Path(tmp_dir) / "checkpoints")

        # Create trainer
        trainer = VortexTrainer(
            model=model,
            tokenizer=tokenizer,
            train_dataset=train_dataset,
            config=config,
            eval_dataset=eval_dataset,
        )

        # Run a few steps
        trainer.train()

    print("Trainer test passed!")


if __name__ == "__main__":
    test_trainer()