import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
import os
import json
from typing import Dict, List, Optional, Any
from pathlib import Path

import wandb
from accelerate import Accelerator
from transformers import get_cosine_schedule_with_warmup
import logging

from ..configs.config import TrainingConfig
from ..architecture.model import CompactAIModel

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class TextDataset(Dataset):
    """Dataset for text training data."""

    def __init__(self, data: List[Dict[str, Any]], tokenizer=None, max_length: int = 1024):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]

        # Handle different data formats
        if isinstance(item, dict) and "text" in item:
            text = item["text"]
        elif isinstance(item, str):
            text = item
        else:
            raise ValueError(f"Unsupported data format: {type(item)}")

        # Tokenize if a (Hugging Face style) tokenizer is provided
        if self.tokenizer:
            encoded = self.tokenizer(
                text,
                max_length=self.max_length,
                truncation=True,
                padding="max_length",
            )
            # Use the tokenizer's attention mask so padding positions are
            # masked out (an all-ones mask would attend to pad tokens).
            return {
                "input_ids": torch.tensor(encoded["input_ids"], dtype=torch.long),
                "attention_mask": torch.tensor(encoded["attention_mask"], dtype=torch.long),
            }
        else:
            # Return raw text for processing later
            return {"text": text}


def create_sample_data(num_samples: int = 1000) -> List[Dict[str, str]]:
    """Create sample training data for demonstration."""
    import random

    templates = [
        "Question: {question}\nAnswer: {answer}",
        "Solve: {problem}\nSolution: {solution}",
        "Explain: {topic}\nExplanation: {explanation}",
        "Translate: {text}\nTranslation: {translation}",
    ]

    questions = [
        "What is 2 + 2?",
        "What is the capital of France?",
        "How does photosynthesis work?",
        "What is machine learning?",
        "Explain quantum computing",
        "What is the speed of light?",
    ]
    # answers[i] pairs with questions[i]; they are sampled together below.
    answers = [
        "4",
        "Paris",
        "Plants convert sunlight into energy using chlorophyll",
        "A type of artificial intelligence",
        "Computing using quantum mechanics",
        "Approximately 299,792,458 meters per second",
    ]

    data = []
    for i in range(num_samples):
        template = random.choice(templates)
        if "{question}" in template:
            # Pick a single index so each question keeps its matching answer.
            idx = random.randrange(len(questions))
            text = template.format(question=questions[idx], answer=answers[idx])
        elif "{problem}" in template:
            text = template.format(problem="2x + 5 = 15", solution="x = 5")
        elif "{topic}" in template:
            text = template.format(topic="gravity", explanation="The force that attracts objects with mass")
        else:
            text = template.format(text="Hello", translation="Hola")
        data.append({"text": text})

    return data


class Trainer:
    """Training class for the compact AI model."""

    def __init__(
        self,
        model: CompactAIModel,
        training_config: TrainingConfig,
        accelerator: Optional[Accelerator] = None,
        use_wandb: bool = False,
        output_dir: str = "checkpoints",
    ):
        self.model = model
        self.config = training_config
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Initialize accelerator
        if accelerator is None:
            accelerator = Accelerator(
                mixed_precision="fp16" if training_config.mixed_precision else "no",
                gradient_accumulation_steps=training_config.gradient_accumulation_steps,
            )
        self.accelerator = accelerator

        # Prepare model and optimizer for the current device/precision setup
        self.model = self.accelerator.prepare(self.model)

        self.optimizer = AdamW(
            self.model.parameters(),
            lr=training_config.learning_rate,
            weight_decay=training_config.weight_decay,
        )
        self.optimizer = self.accelerator.prepare(self.optimizer)

        # Learning rate scheduler. num_training_steps is a rough estimate;
        # ideally it would be len(train_loader) * num_epochs divided by the
        # gradient accumulation steps, but the loader is not known here yet.
        self.lr_scheduler = get_cosine_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=training_config.warmup_steps,
            num_training_steps=training_config.num_epochs * 1000,
        )

        # Loss function
        self.criterion = nn.CrossEntropyLoss()

        # Initialize wandb
        self.use_wandb = use_wandb
        if use_wandb:
            wandb.init(project="compact-ai-model", config=training_config.__dict__)

        # Training state
        self.global_step = 0
        self.best_loss = float("inf")

    def save_checkpoint(self, epoch: int, loss: float):
        """Save model checkpoint."""
        checkpoint_path = self.output_dir / f"checkpoint_epoch_{epoch}"
        checkpoint_path.mkdir(exist_ok=True)

        # Save model weights without the Accelerate wrapper
        unwrapped_model = self.accelerator.unwrap_model(self.model)
        torch.save(unwrapped_model.state_dict(), checkpoint_path / "pytorch_model.bin")

        # Save optimizer state
        torch.save(self.optimizer.state_dict(), checkpoint_path / "optimizer.bin")

        # Save training state
        training_state = {
            "epoch": epoch,
            "global_step": self.global_step,
            "best_loss": self.best_loss,
            "current_loss": loss,
        }
        with open(checkpoint_path / "training_state.json", "w") as f:
            json.dump(training_state, f)

        logger.info(f"Saved checkpoint to {checkpoint_path}")

    def load_checkpoint(self, checkpoint_path: str):
        """Load model checkpoint."""
        checkpoint_path = Path(checkpoint_path)

        # Load model state
        model_state = torch.load(checkpoint_path / "pytorch_model.bin", map_location="cpu")
        unwrapped_model = self.accelerator.unwrap_model(self.model)
        unwrapped_model.load_state_dict(model_state)

        # Load optimizer state
        optimizer_state = torch.load(checkpoint_path / "optimizer.bin", map_location="cpu")
        self.optimizer.load_state_dict(optimizer_state)

        # Load training state
        with open(checkpoint_path / "training_state.json", "r") as f:
            training_state = json.load(f)

        self.global_step = training_state["global_step"]
        self.best_loss = training_state["best_loss"]

        logger.info(f"Loaded checkpoint from {checkpoint_path}")
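
    # Note on gradient accumulation in train_epoch below: inside
    # `accelerator.accumulate(...)`, the Accelerate-prepared optimizer skips
    # step()/zero_grad() on non-sync micro-batches, so calling them every
    # iteration is safe. The LR scheduler is *not* prepared by Accelerate,
    # so it is stepped only when gradients actually sync.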
{checkpoint_path}") def train_epoch(self, train_loader: DataLoader) -> float: """Train for one epoch.""" self.model.train() total_loss = 0.0 num_batches = 0 for batch_idx, batch in enumerate(train_loader): with self.accelerator.accumulate(self.model): # Forward pass input_ids = batch["input_ids"] attention_mask = batch.get("attention_mask") outputs = self.model(input_ids, attention_mask, use_thinking=True) logits = outputs["logits"] # Shift for next token prediction shift_logits = logits[..., :-1, :].contiguous() shift_labels = input_ids[..., 1:].contiguous() # Compute loss loss = self.criterion( shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1) ) # Backward pass self.accelerator.backward(loss) # Gradient clipping if self.accelerator.sync_gradients: self.accelerator.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm) # Optimizer step self.optimizer.step() self.lr_scheduler.step() self.optimizer.zero_grad() total_loss += loss.item() num_batches += 1 self.global_step += 1 # Logging if batch_idx % self.config.log_interval == 0: current_lr = self.lr_scheduler.get_last_lr()[0] logger.info( f"Step {self.global_step}: Loss = {loss.item():.4f}, LR = {current_lr:.6f}" ) if self.use_wandb: wandb.log({ "train/loss": loss.item(), "train/learning_rate": current_lr, "train/global_step": self.global_step, }) return total_loss / num_batches def evaluate(self, eval_loader: DataLoader) -> float: """Evaluate the model.""" self.model.eval() total_loss = 0.0 num_batches = 0 with torch.no_grad(): for batch in eval_loader: input_ids = batch["input_ids"] attention_mask = batch.get("attention_mask") outputs = self.model(input_ids, attention_mask, use_thinking=False) # Eval without thinking for speed logits = outputs["logits"] # Shift for next token prediction shift_logits = logits[..., :-1, :].contiguous() shift_labels = input_ids[..., 1:].contiguous() loss = self.criterion( shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1) ) total_loss += loss.item() num_batches += 1 avg_loss = total_loss / num_batches if self.use_wandb: wandb.log({"eval/loss": avg_loss}) return avg_loss def train(self, train_loader: DataLoader, eval_loader: Optional[DataLoader] = None): """Main training loop.""" logger.info("Starting training...") for epoch in range(self.config.num_epochs): logger.info(f"Epoch {epoch + 1}/{self.config.num_epochs}") # Train train_loss = self.train_epoch(train_loader) # Evaluate if eval_loader is not None: eval_loss = self.evaluate(eval_loader) logger.info(f"Epoch {epoch + 1}: Train Loss = {train_loss:.4f}, Eval Loss = {eval_loss:.4f}") # Save best model if eval_loss < self.best_loss: self.best_loss = eval_loss self.save_checkpoint(epoch, eval_loss) # Save regular checkpoints if (epoch + 1) % 5 == 0: self.save_checkpoint(epoch, train_loss) logger.info("Training completed!") def main(): """Main training function.""" import argparse parser = argparse.ArgumentParser(description="Train Compact AI Model") parser.add_argument("--data_path", type=str, default="training_data.json", help="Path to training data") parser.add_argument("--batch_size", type=int, default=8, help="Batch size") parser.add_argument("--num_epochs", type=int, default=10, help="Number of epochs") parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate") parser.add_argument("--max_length", type=int, default=1024, help="Maximum sequence length") parser.add_argument("--output_dir", type=str, default="checkpoints", help="Output directory") parser.add_argument("--use_wandb", 
action="store_true", help="Use Weights & Biases logging") parser.add_argument("--model_size", type=str, default="small", choices=["tiny", "small", "medium"], help="Model size") parser.add_argument("--resume_from", type=str, help="Resume training from checkpoint") args = parser.parse_args() # Create model from ..architecture.model import create_compact_model model = create_compact_model(args.model_size) # Create training config training_config = TrainingConfig( learning_rate=args.learning_rate, batch_size=args.batch_size, num_epochs=args.num_epochs, ) # Initialize trainer trainer = Trainer( model=model, training_config=training_config, use_wandb=args.use_wandb, output_dir=args.output_dir, ) # Load data if os.path.exists(args.data_path): with open(args.data_path, "r") as f: data = json.load(f) else: logger.info("Creating sample training data...") data = create_sample_data(10000) with open(args.data_path, "w") as f: json.dump(data, f) # Create dataset and dataloader dataset = TextDataset(data, max_length=args.max_length) train_loader = DataLoader( dataset, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True, ) # Resume training if specified if args.resume_from: trainer.load_checkpoint(args.resume_from) # Start training trainer.train(train_loader) if __name__ == "__main__": main()