| | """
|
| | Trainer: Main training loop for Vortex model.
|
| | Handles gradient accumulation, mixed precision, checkpointing.
|
| | """
|
| |
|
| | import os
|
| | import json
|
| | import torch
|
| | import torch.nn as nn
|
| | from torch.utils.data import DataLoader, Dataset
|
| | from typing import Optional, Dict, List, Callable
|
| | from pathlib import Path
|
| | import logging
|
| |
|
| | from ..training.losses import VortexLoss
|
| | from ..training.curriculum import CurriculumScheduler
|
| |
|
| |
|
class VortexDataset(Dataset):
    """Simple dataset wrapper around parquet text shards.

    Loads every shard eagerly into memory and tokenizes lazily per item.
    """

    def __init__(
        self,
        shard_files: List[str],
        tokenizer,
        max_seq_len: int = 16384,
    ):
        """
        Initialize dataset.

        Args:
            shard_files: List of parquet shard files
            tokenizer: Tokenizer for encoding text
            max_seq_len: Maximum sequence length
        """
        self.shard_files = shard_files
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len

        self.samples = []
        self._load_shards()

    def _load_shards(self):
        """Read every shard into a flat in-memory list of sample dicts."""
        import pandas as pd

        for path in self.shard_files:
            frame = pd.read_parquet(path)
            for _, record in frame.iterrows():
                self.samples.append(
                    {
                        "text": record["text"],
                        "dataset": record.get("dataset", ""),
                        "domain": record.get("domain", ""),
                    }
                )

    def __len__(self) -> int:
        """Total number of samples across all loaded shards."""
        return len(self.samples)

    def __getitem__(self, idx) -> Dict:
        """Tokenize one sample and return model-ready tensors."""
        entry = self.samples[idx]

        encoded = self.tokenizer.encode(
            entry["text"],
            add_special_tokens=True,
            return_tensors="pt",
        )

        token_ids = encoded["input_ids"].squeeze(0)
        mask = encoded["attention_mask"].squeeze(0)

        # Hard-truncate anything beyond the context window.
        limit = self.max_seq_len
        if token_ids.size(0) > limit:
            token_ids = token_ids[:limit]
            mask = mask[:limit]

        return {
            "input_ids": token_ids,
            "attention_mask": mask,
            # Causal-LM style targets: labels mirror the inputs.
            "labels": token_ids.clone(),
            "domain": entry["domain"],
        }
|
| |
|
| |
|
class VortexTrainer:
    """
    Main trainer for Vortex model.

    Handles gradient accumulation, mixed precision (AMP), LR scheduling,
    periodic evaluation, and checkpointing.
    """

    def __init__(
        self,
        model: nn.Module,
        tokenizer,
        train_dataset: Dataset,
        config: Dict,
        eval_dataset: Optional[Dataset] = None,
        optimizer: Optional[torch.optim.Optimizer] = None,
        scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
    ):
        """
        Initialize trainer.

        Args:
            model: VortexModel
            tokenizer: VortexScienceTokenizer
            train_dataset: Training dataset
            config: Training configuration. Required keys: "device",
                "max_steps", "micro_batch_size",
                "gradient_accumulation_steps", "learning_rate", "beta1",
                "beta2", "weight_decay".
            eval_dataset: Optional evaluation dataset
            optimizer: Optional optimizer (AdamW created if None)
            scheduler: Optional LR scheduler (cosine annealing if None)
        """
        self.model = model
        self.tokenizer = tokenizer
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.config = config

        self.device = torch.device(config["device"])
        self.use_amp = config.get("use_amp", True)
        self.amp_dtype = getattr(torch, config.get("amp_dtype", "bfloat16"))

        self.model.to(self.device)

        if optimizer is None:
            self.optimizer = self._create_optimizer()
        else:
            self.optimizer = optimizer

        if scheduler is None:
            self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer,
                T_max=config["max_steps"],
            )
        else:
            self.scheduler = scheduler

        # GradScaler is only meaningful for CUDA mixed-precision runs.
        self.scaler = torch.cuda.amp.GradScaler() if self.use_amp and self.device.type == "cuda" else None

        self.loss_fn = VortexLoss(config)

        self.curriculum = CurriculumScheduler(config, config["max_steps"])

        # Logging setup.
        self.log_dir = Path(config.get("log_dir", "logs"))
        self.log_dir.mkdir(parents=True, exist_ok=True)
        self.log_interval = config.get("log_interval", 100)

        # Checkpointing setup.
        self.checkpoint_dir = Path(config.get("checkpoint_dir", "checkpoints"))
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.save_interval = config.get("save_interval", 5000)

        self.global_step = 0
        self.best_eval_loss = float('inf')

        # BUGFIX: DataLoader raises ValueError if prefetch_factor is passed
        # with num_workers=0, so only forward it for worker-backed loading.
        num_workers = config.get("num_workers", 4)
        extra_loader_kwargs = {}
        if num_workers > 0:
            extra_loader_kwargs["prefetch_factor"] = config.get("prefetch_factor", 2)

        self.train_loader = DataLoader(
            train_dataset,
            batch_size=config["micro_batch_size"],
            shuffle=True,
            num_workers=num_workers,
            pin_memory=config.get("pin_memory", True),
            **extra_loader_kwargs,
        )

        if eval_dataset:
            self.eval_loader = DataLoader(
                eval_dataset,
                batch_size=config["micro_batch_size"],
                shuffle=False,
                num_workers=num_workers,
            )

    def _create_optimizer(self) -> torch.optim.Optimizer:
        """Create AdamW optimizer from config hyperparameters."""
        return torch.optim.AdamW(
            self.model.parameters(),
            lr=self.config["learning_rate"],
            betas=(self.config["beta1"], self.config["beta2"]),
            weight_decay=self.config["weight_decay"],
        )

    def train_step(
        self,
        batch: Dict,
        current_step: int,
    ) -> Dict[str, torch.Tensor]:
        """
        Single training step: forward pass, loss, and backward pass.

        The optimizer step happens in train_epoch on accumulation
        boundaries, not here.

        Args:
            batch: Batch dictionary with "input_ids", "attention_mask",
                "labels"
            current_step: Current step number (reserved for curriculum use)

        Returns:
            Dictionary of (unscaled) losses, including "total_loss"
        """
        self.model.train()

        input_ids = batch["input_ids"].to(self.device)
        attention_mask = batch["attention_mask"].to(self.device)
        labels = batch["labels"].to(self.device)

        # Domain conditioning is not wired up yet.
        domain_ids = None
        domain_tags = None

        with torch.cuda.amp.autocast(enabled=self.use_amp and self.device.type == "cuda"):
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                domain_ids=domain_ids,
                domain_tags=domain_tags,
                return_dict=True,
            )
            logits = outputs["logits"]

            losses = self.loss_fn(
                logits=logits,
                labels=labels,
            )

        # BUGFIX: divide by the accumulation factor so accumulated gradients
        # average over micro-batches instead of summing (the original made the
        # effective gradient scale with gradient_accumulation_steps).
        # The returned losses stay unscaled for logging.
        accum_steps = self.config.get("gradient_accumulation_steps", 1)
        backward_loss = losses["total_loss"] / accum_steps

        if self.scaler:
            self.scaler.scale(backward_loss).backward()
        else:
            backward_loss.backward()

        return losses

    def train_epoch(self):
        """Train for one epoch (returns early once max_steps is reached)."""
        self.model.train()

        for batch_idx, batch in enumerate(self.train_loader):

            losses = self.train_step(batch, self.global_step)

            # Step the optimizer only on gradient-accumulation boundaries.
            if (self.global_step + 1) % self.config["gradient_accumulation_steps"] == 0:

                if self.config.get("clip_grad_norm", 0) > 0:
                    if self.scaler:
                        # Gradients must be unscaled before clipping so the
                        # threshold applies to true gradient magnitudes.
                        self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(),
                        self.config["clip_grad_norm"],
                    )

                if self.scaler:
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                else:
                    self.optimizer.step()

                self.optimizer.zero_grad()
                self.scheduler.step()

            if self.global_step % self.log_interval == 0:
                self._log_losses(losses, batch_idx)

            if self.eval_dataset and self.global_step % self.config.get("eval_interval", 1000) == 0:
                self.evaluate()

            if self.global_step % self.save_interval == 0:
                self.save_checkpoint()

            self.global_step += 1

            if self.global_step >= self.config["max_steps"]:
                print("Reached max steps")
                return

    def evaluate(self) -> Dict[str, float]:
        """Run evaluation over the full eval set.

        Returns:
            Dict with mean cross-entropy loss under key "eval_loss".
        """
        self.model.eval()
        total_loss = 0.0
        num_batches = 0

        with torch.no_grad():
            for batch in self.eval_loader:
                input_ids = batch["input_ids"].to(self.device)
                attention_mask = batch["attention_mask"].to(self.device)
                labels = batch["labels"].to(self.device)

                with torch.cuda.amp.autocast(enabled=self.use_amp and self.device.type == "cuda"):
                    outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
                    logits = outputs["logits"]
                    # BUGFIX: the original called `F.cross_entropy` but never
                    # imported torch.nn.functional as F, so the first eval
                    # raised NameError.
                    # NOTE(review): labels are used unshifted here — confirm
                    # the model aligns logits for next-token prediction
                    # internally (as VortexLoss presumably does).
                    loss = nn.functional.cross_entropy(
                        logits.view(-1, logits.size(-1)),
                        labels.view(-1),
                        ignore_index=-100,
                    )

                total_loss += loss.item()
                num_batches += 1

        avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
        print(f"Evaluation at step {self.global_step}: loss = {avg_loss:.4f}")

        # BUGFIX: track the best eval loss (the original initialized and
        # checkpointed best_eval_loss but never updated it).
        if avg_loss < self.best_eval_loss:
            self.best_eval_loss = avg_loss

        return {"eval_loss": avg_loss}

    def save_checkpoint(self, is_best: bool = False):
        """Save model checkpoint.

        Writes a step-numbered checkpoint, refreshes "latest.pt", and
        additionally writes "best_model.pt" when is_best is True.

        Args:
            is_best: Whether this checkpoint is the best-so-far model
        """
        checkpoint = {
            "step": self.global_step,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scheduler_state_dict": self.scheduler.state_dict(),
            "config": self.config,
            "best_eval_loss": self.best_eval_loss,
        }

        if self.scaler:
            checkpoint["scaler_state_dict"] = self.scaler.state_dict()

        checkpoint_path = self.checkpoint_dir / f"checkpoint_{self.global_step:06d}.pt"
        torch.save(checkpoint, checkpoint_path)
        print(f"Saved checkpoint to {checkpoint_path}")

        if is_best:
            best_path = self.checkpoint_dir / "best_model.pt"
            torch.save(checkpoint, best_path)
            print(f"Saved best model to {best_path}")

        # Always refresh the "latest" pointer for easy resumption.
        latest_path = self.checkpoint_dir / "latest.pt"
        torch.save(checkpoint, latest_path)

    def load_checkpoint(self, checkpoint_path: str):
        """Load checkpoint and restore full training state.

        Args:
            checkpoint_path: Path to a checkpoint produced by save_checkpoint
        """
        # weights_only=False: the checkpoint contains the config dict, not
        # just tensors. Only load checkpoints from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        self.global_step = checkpoint["step"]
        self.best_eval_loss = checkpoint.get("best_eval_loss", float('inf'))

        if self.scaler and "scaler_state_dict" in checkpoint:
            self.scaler.load_state_dict(checkpoint["scaler_state_dict"])

        print(f"Loaded checkpoint from {checkpoint_path} at step {self.global_step}")

    def _log_losses(self, losses: Dict[str, torch.Tensor], batch_idx: int):
        """Log losses to console.

        Args:
            losses: Named scalar loss tensors
            batch_idx: Batch index within the current epoch (currently unused)
        """
        loss_str = " | ".join([f"{k}: {v.item():.4f}" for k, v in losses.items()])
        print(f"Step {self.global_step} | {loss_str}")

    def train(self):
        """Main training loop: run epochs until max_steps is reached."""
        print("Starting training...")
        print(f"Total steps: {self.config['max_steps']}")
        print(f"Device: {self.device}")
        print(f"Batch size: {self.config['micro_batch_size']}")
        print(f"Gradient accumulation steps: {self.config['gradient_accumulation_steps']}")

        try:
            # BUGFIX: the original ran exactly one epoch regardless of
            # max_steps, ending training prematurely on small datasets.
            while self.global_step < self.config["max_steps"]:
                steps_before = self.global_step
                self.train_epoch()
                if self.global_step == steps_before:
                    # No progress this epoch (empty loader) — bail out
                    # instead of spinning forever.
                    break
        except KeyboardInterrupt:
            print("Training interrupted")
        finally:
            self.save_checkpoint()
|
| |
|
| |
|
def test_trainer():
    """Smoke-test the trainer on a tiny CPU model with dummy data."""
    from models.vortex_model import VortexModel
    from tokenizer.vortex_tokenizer import VortexScienceTokenizer
    from configs.vortex_7b_config import VORTEX_7B_CONFIG

    # Shrink the 7B config down to something that runs in seconds on CPU.
    config = VORTEX_7B_CONFIG.copy()
    config.update(
        d_model=256,
        num_layers=2,
        num_heads=4,
        vocab_size=1000,
        max_steps=10,
        device="cpu",
    )

    model = VortexModel(config)

    class DummyTokenizer:
        """Returns fixed-shape random token tensors for any text."""

        def encode(self, text, add_special_tokens=True, return_tensors="pt"):
            return {
                "input_ids": torch.randint(0, 1000, (1, 10)),
                "attention_mask": torch.ones(1, 10),
            }

    class DummyDataset(torch.utils.data.Dataset):
        """Ten random 32-token samples tagged with a physics domain."""

        def __len__(self):
            return 10

        def __getitem__(self, idx):
            return {
                "input_ids": torch.randint(0, 1000, (32,)),
                "attention_mask": torch.ones(32),
                "labels": torch.randint(0, 1000, (32,)),
                "domain": "physics",
            }

    trainer = VortexTrainer(
        model=model,
        tokenizer=DummyTokenizer(),
        train_dataset=DummyDataset(),
        config=config,
        eval_dataset=DummyDataset(),
    )

    trainer.train()

    print("Trainer test passed!")
|
| |
|
| |
|
# Run the smoke test only when executed directly, not on import.
if __name__ == "__main__":
    test_trainer()
|
| |
|