# Vortex-7b-V1 / training/trainer.py
# Uploaded by Zandy-Wandy ("Upload Vortex model"), commit bf64b03 (verified).
"""
Trainer: Main training loop for Vortex model.
Handles gradient accumulation, mixed precision, checkpointing.
"""
import os
import json
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from typing import Optional, Dict, List, Callable
from pathlib import Path
import logging
from ..training.losses import VortexLoss
from ..training.curriculum import CurriculumScheduler
class VortexDataset(Dataset):
    """Map-style dataset over parquet text shards for causal-LM training.

    All shards are read eagerly into memory when the dataset is constructed.
    Each item is tokenized lazily in ``__getitem__`` and truncated to
    ``max_seq_len``.

    NOTE(review): items have variable length. PyTorch's default collate stacks
    tensors and therefore needs equal lengths within a batch, so a padding
    ``collate_fn`` is presumably required when ``micro_batch_size > 1``.
    TODO confirm how batches are built for this dataset.
    """

    def __init__(
        self,
        shard_files: List[str],
        tokenizer,
        max_seq_len: int = 16384,
    ):
        """
        Initialize dataset.

        Args:
            shard_files: List of parquet shard files. Each shard must have a
                ``text`` column. ``dataset`` and ``domain`` columns are
                optional and default to ``""`` when absent.
            tokenizer: Tokenizer whose ``encode(text, add_special_tokens=True,
                return_tensors="pt")`` returns a mapping with ``input_ids``
                and ``attention_mask`` tensors (that is how it is used below).
            max_seq_len: Maximum sequence length. Longer encodings are
                truncated.
        """
        self.shard_files = shard_files
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        # Load all shards into memory (for simplicity - would stream in practice)
        self.samples: List[Dict] = []
        self._load_shards()

    def _load_shards(self):
        """Append one ``{"text", "dataset", "domain"}`` record per parquet row."""
        import pandas as pd

        for shard in self.shard_files:
            df = pd.read_parquet(shard)
            # to_dict("records") yields plain dicts. DataFrame.iterrows()
            # instead builds a pandas Series for every row, which is very slow
            # on large shards.
            self.samples.extend(
                {
                    "text": record["text"],
                    "dataset": record.get("dataset", ""),
                    "domain": record.get("domain", ""),
                }
                for record in df.to_dict("records")
            )

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx) -> Dict:
        """Tokenize sample ``idx`` and return model inputs plus LM labels.

        Returns:
            Dict with 1-D ``input_ids``, ``attention_mask`` and ``labels``
            tensors (each at most ``max_seq_len`` long), plus the sample's
            ``domain`` string.
        """
        sample = self.samples[idx]
        encoding = self.tokenizer.encode(
            sample["text"],
            add_special_tokens=True,
            return_tensors="pt",
        )
        # Drop the leading batch dim the tokenizer adds. Slicing past the end
        # is a no-op, so this only truncates sequences that are too long.
        input_ids = encoding["input_ids"].squeeze(0)[: self.max_seq_len]
        attention_mask = encoding["attention_mask"].squeeze(0)[: self.max_seq_len]
        # Labels are same as input_ids (causal LM)
        labels = input_ids.clone()
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
            "domain": sample["domain"],
        }
class VortexTrainer:
"""
Main trainer for Vortex model.
"""
def __init__(
self,
model: nn.Module,
tokenizer,
train_dataset: Dataset,
config: Dict,
eval_dataset: Optional[Dataset] = None,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
):
"""
Initialize trainer.
Args:
model: VortexModel
tokenizer: VortexScienceTokenizer
train_dataset: Training dataset
config: Training configuration
eval_dataset: Optional evaluation dataset
optimizer: Optional optimizer (created if None)
scheduler: Optional LR scheduler
"""
self.model = model
self.tokenizer = tokenizer
self.train_dataset = train_dataset
self.eval_dataset = eval_dataset
self.config = config
self.device = torch.device(config["device"])
self.use_amp = config.get("use_amp", True)
self.amp_dtype = getattr(torch, config.get("amp_dtype", "bfloat16"))
# Move model to device
self.model.to(self.device)
# Setup optimizer
if optimizer is None:
self.optimizer = self._create_optimizer()
else:
self.optimizer = optimizer
# Setup scheduler
if scheduler is None:
self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
self.optimizer,
T_max=config["max_steps"],
)
else:
self.scheduler = scheduler
# Setup AMP scaler
self.scaler = torch.cuda.amp.GradScaler() if self.use_amp and self.device.type == "cuda" else None
# Loss function
self.loss_fn = VortexLoss(config)
# Curriculum scheduler
self.curriculum = CurriculumScheduler(config, config["max_steps"])
# Logging
self.log_dir = Path(config.get("log_dir", "logs"))
self.log_dir.mkdir(parents=True, exist_ok=True)
self.log_interval = config.get("log_interval", 100)
# Checkpointing
self.checkpoint_dir = Path(config.get("checkpoint_dir", "checkpoints"))
self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
self.save_interval = config.get("save_interval", 5000)
# Training state
self.global_step = 0
self.best_eval_loss = float('inf')
# Data loader
self.train_loader = DataLoader(
train_dataset,
batch_size=config["micro_batch_size"],
shuffle=True,
num_workers=config.get("num_workers", 4),
pin_memory=config.get("pin_memory", True),
prefetch_factor=config.get("prefetch_factor", 2),
)
if eval_dataset:
self.eval_loader = DataLoader(
eval_dataset,
batch_size=config["micro_batch_size"],
shuffle=False,
num_workers=config.get("num_workers", 4),
)
def _create_optimizer(self) -> torch.optim.Optimizer:
"""Create AdamW optimizer."""
return torch.optim.AdamW(
self.model.parameters(),
lr=self.config["learning_rate"],
betas=(self.config["beta1"], self.config["beta2"]),
weight_decay=self.config["weight_decay"],
)
def train_step(
self,
batch: Dict,
current_step: int,
) -> Dict[str, torch.Tensor]:
"""
Single training step.
Args:
batch: Batch dictionary
current_step: Current step number
Returns:
Dictionary of losses
"""
self.model.train()
# Move batch to device
input_ids = batch["input_ids"].to(self.device)
attention_mask = batch["attention_mask"].to(self.device)
labels = batch["labels"].to(self.device)
# Domain info (placeholder - would extract from batch)
domain_ids = None
domain_tags = None
# Forward pass with AMP
with torch.cuda.amp.autocast(enabled=self.use_amp and self.device.type == "cuda"):
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
domain_ids=domain_ids,
domain_tags=domain_tags,
return_dict=True,
)
logits = outputs["logits"]
# Compute losses
losses = self.loss_fn(
logits=logits,
labels=labels,
# Pass modules and masks for auxiliary losses
)
# Backward pass
if self.scaler:
self.scaler.scale(losses["total_loss"]).backward()
else:
losses["total_loss"].backward()
return losses
def train_epoch(self):
"""Train for one epoch."""
self.model.train()
for batch_idx, batch in enumerate(self.train_loader):
# Train step
losses = self.train_step(batch, self.global_step)
# Gradient accumulation
if (self.global_step + 1) % self.config["gradient_accumulation_steps"] == 0:
# Gradient clipping
if self.config.get("clip_grad_norm", 0) > 0:
if self.scaler:
self.scaler.unscale_(self.optimizer)
torch.nn.utils.clip_grad_norm_(
self.model.parameters(),
self.config["clip_grad_norm"],
)
# Optimizer step
if self.scaler:
self.scaler.step(self.optimizer)
self.scaler.update()
else:
self.optimizer.step()
self.optimizer.zero_grad()
self.scheduler.step()
# Logging
if self.global_step % self.log_interval == 0:
self._log_losses(losses, batch_idx)
# Evaluation
if self.eval_dataset and self.global_step % self.config.get("eval_interval", 1000) == 0:
self.evaluate()
# Checkpointing
if self.global_step % self.save_interval == 0:
self.save_checkpoint()
self.global_step += 1
if self.global_step >= self.config["max_steps"]:
print("Reached max steps")
return
def evaluate(self) -> Dict[str, float]:
"""Run evaluation."""
self.model.eval()
total_loss = 0.0
num_batches = 0
with torch.no_grad():
for batch in self.eval_loader:
input_ids = batch["input_ids"].to(self.device)
attention_mask = batch["attention_mask"].to(self.device)
labels = batch["labels"].to(self.device)
with torch.cuda.amp.autocast(enabled=self.use_amp and self.device.type == "cuda"):
outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
logits = outputs["logits"]
loss = F.cross_entropy(
logits.view(-1, logits.size(-1)),
labels.view(-1),
ignore_index=-100,
)
total_loss += loss.item()
num_batches += 1
avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
print(f"Evaluation at step {self.global_step}: loss = {avg_loss:.4f}")
return {"eval_loss": avg_loss}
def save_checkpoint(self, is_best: bool = False):
"""Save model checkpoint."""
checkpoint = {
"step": self.global_step,
"model_state_dict": self.model.state_dict(),
"optimizer_state_dict": self.optimizer.state_dict(),
"scheduler_state_dict": self.scheduler.state_dict(),
"config": self.config,
"best_eval_loss": self.best_eval_loss,
}
if self.scaler:
checkpoint["scaler_state_dict"] = self.scaler.state_dict()
# Save latest
checkpoint_path = self.checkpoint_dir / f"checkpoint_{self.global_step:06d}.pt"
torch.save(checkpoint, checkpoint_path)
print(f"Saved checkpoint to {checkpoint_path}")
# Save best
if is_best:
best_path = self.checkpoint_dir / "best_model.pt"
torch.save(checkpoint, best_path)
print(f"Saved best model to {best_path}")
# Save latest link
latest_path = self.checkpoint_dir / "latest.pt"
torch.save(checkpoint, latest_path)
def load_checkpoint(self, checkpoint_path: str):
"""Load checkpoint."""
checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)
self.model.load_state_dict(checkpoint["model_state_dict"])
self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
self.global_step = checkpoint["step"]
self.best_eval_loss = checkpoint.get("best_eval_loss", float('inf'))
if self.scaler and "scaler_state_dict" in checkpoint:
self.scaler.load_state_dict(checkpoint["scaler_state_dict"])
print(f"Loaded checkpoint from {checkpoint_path} at step {self.global_step}")
def _log_losses(self, losses: Dict[str, torch.Tensor], batch_idx: int):
"""Log losses to console and file."""
loss_str = " | ".join([f"{k}: {v.item():.4f}" for k, v in losses.items()])
print(f"Step {self.global_step} | {loss_str}")
def train(self):
"""Main training loop."""
print("Starting training...")
print(f"Total steps: {self.config['max_steps']}")
print(f"Device: {self.device}")
print(f"Batch size: {self.config['micro_batch_size']}")
print(f"Gradient accumulation steps: {self.config['gradient_accumulation_steps']}")
try:
self.train_epoch()
except KeyboardInterrupt:
print("Training interrupted")
finally:
self.save_checkpoint()
def test_trainer():
    """Smoke-test the trainer end to end with a tiny model on CPU.

    Uses absolute imports of the project packages, so it must be run with
    ``models``, ``tokenizer`` and ``configs`` importable. Logs and checkpoints
    go to a temporary directory that is removed afterwards.
    """
    import tempfile

    from models.vortex_model import VortexModel
    from tokenizer.vortex_tokenizer import VortexScienceTokenizer
    from configs.vortex_7b_config import VORTEX_7B_CONFIG

    with tempfile.TemporaryDirectory() as tmp_dir:
        # Small config for testing
        config = VORTEX_7B_CONFIG.copy()
        config["d_model"] = 256
        config["num_layers"] = 2
        config["num_heads"] = 4
        config["vocab_size"] = 1000
        config["max_steps"] = 10
        config["device"] = "cpu"
        # The dummy dataset below is a local class. Local classes cannot be
        # pickled into DataLoader worker processes under the "spawn" start
        # method (macOS/Windows), so load data in the main process.
        config["num_workers"] = 0
        config["prefetch_factor"] = None
        config["pin_memory"] = False
        # Keep test artifacts out of the working directory.
        config["log_dir"] = str(Path(tmp_dir) / "logs")
        config["checkpoint_dir"] = str(Path(tmp_dir) / "checkpoints")

        # Create model
        model = VortexModel(config)

        # Create dummy tokenizer
        class DummyTokenizer:
            def encode(self, text, add_special_tokens=True, return_tensors="pt"):
                return {"input_ids": torch.randint(0, 1000, (1, 10)), "attention_mask": torch.ones(1, 10)}

        tokenizer = DummyTokenizer()

        # Create dummy dataset of fixed-length samples
        class DummyDataset(torch.utils.data.Dataset):
            def __len__(self):
                return 10

            def __getitem__(self, idx):
                return {
                    "input_ids": torch.randint(0, 1000, (32,)),
                    "attention_mask": torch.ones(32),
                    "labels": torch.randint(0, 1000, (32,)),
                    "domain": "physics",
                }

        train_dataset = DummyDataset()
        eval_dataset = DummyDataset()

        # Create trainer
        trainer = VortexTrainer(
            model=model,
            tokenizer=tokenizer,
            train_dataset=train_dataset,
            config=config,
            eval_dataset=eval_dataset,
        )

        # Run a few steps
        trainer.train()

    print("Trainer test passed!")


if __name__ == "__main__":
    test_trainer()