| | import torch |
| | import torch.nn as nn |
| | from torch.utils.data import DataLoader, Dataset |
| | from torch.optim import AdamW |
| | from torch.optim.lr_scheduler import CosineAnnealingLR |
| | import os |
| | import json |
| | from typing import Dict, List, Optional, Any, Tuple |
| | from pathlib import Path |
| | import wandb |
| | from accelerate import Accelerator |
| | from transformers import get_cosine_schedule_with_warmup |
| | import logging |
| | from ..configs.config import Config, TrainingConfig |
| | from ..architecture.model import CompactAIModel |
| |
|
| |
|
| | logging.basicConfig(level=logging.INFO) |
| | logger = logging.getLogger(__name__) |
| |
|
| |
|
class TextDataset(Dataset):
    """Dataset of raw text samples, optionally tokenized to fixed-length tensors.

    Each element of ``data`` is either a plain string or a dict with a
    ``"text"`` key.  When a tokenizer is supplied, ``__getitem__`` returns
    ``input_ids`` / ``attention_mask`` LongTensors padded/truncated to
    ``max_length``; otherwise it returns ``{"text": <str>}`` and tokenization
    is deferred to the caller (e.g. a collate_fn).
    """

    def __init__(self, data: List[Dict[str, Any]], tokenizer=None, max_length: int = 1024):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        item = self.data[idx]

        # Accept either {"text": ...} dicts or bare strings.
        if isinstance(item, dict) and "text" in item:
            text = item["text"]
        elif isinstance(item, str):
            text = item
        else:
            raise ValueError(f"Unsupported data format: {type(item)}")

        if self.tokenizer:
            tokens = self.tokenizer.encode(
                text, max_length=self.max_length, truncation=True, padding="max_length"
            )
            # Bug fix: with padding="max_length" the sequence ends in pad
            # tokens, so the attention mask must be 0 on those positions.
            # The original used [1] * len(tokens), attending to padding.
            pad_id = getattr(self.tokenizer, "pad_token_id", None)
            if pad_id is not None:
                mask = [0 if tok == pad_id else 1 for tok in tokens]
            else:
                # Tokenizer exposes no pad id: keep the old all-ones behavior.
                mask = [1] * len(tokens)
            return {
                "input_ids": torch.tensor(tokens, dtype=torch.long),
                "attention_mask": torch.tensor(mask, dtype=torch.long),
            }
        else:
            # No tokenizer configured: return raw text.
            return {"text": text}
| |
|
| |
|
def create_sample_data(num_samples: int = 1000) -> List[Dict[str, str]]:
    """Create sample training data for demonstration."""
    import random

    templates = [
        "Question: {question}\nAnswer: {answer}",
        "Solve: {problem}\nSolution: {solution}",
        "Explain: {topic}\nExplanation: {explanation}",
        "Translate: {text}\nTranslation: {translation}",
    ]

    questions = [
        "What is 2 + 2?", "What is the capital of France?", "How does photosynthesis work?",
        "What is machine learning?", "Explain quantum computing", "What is the speed of light?"
    ]

    answers = [
        "4", "Paris", "Plants convert sunlight into energy using chlorophyll",
        "A type of artificial intelligence", "Computing using quantum mechanics",
        "Approximately 299,792,458 meters per second"
    ]

    def _render(template: str) -> str:
        # Fill whichever placeholder family this template uses.
        if "{question}" in template:
            return template.format(
                question=random.choice(questions),
                answer=random.choice(answers),
            )
        if "{problem}" in template:
            return template.format(problem="2x + 5 = 15", solution="x = 5")
        if "{topic}" in template:
            return template.format(
                topic="gravity",
                explanation="The force that attracts objects with mass",
            )
        return template.format(text="Hello", translation="Hola")

    # One randomly-templated sample per requested item.
    return [{"text": _render(random.choice(templates))} for _ in range(num_samples)]
| |
|
| |
|
class Trainer:
    """Training class for the compact AI model.

    Wraps a ``CompactAIModel`` with an AdamW optimizer, a cosine LR schedule
    with warmup, optional Weights & Biases logging, and HuggingFace Accelerate
    for device placement, mixed precision, and gradient accumulation.
    Checkpoints (model weights, optimizer state, training progress) are
    written under ``output_dir``.
    """

    def __init__(
        self,
        model: CompactAIModel,
        training_config: TrainingConfig,
        accelerator: Optional[Accelerator] = None,
        use_wandb: bool = False,
        output_dir: str = "checkpoints"
    ):
        self.model = model
        self.config = training_config
        self.output_dir = Path(output_dir)
        # NOTE(review): no parents=True — mkdir fails if output_dir is nested
        # under a directory that does not exist yet; confirm callers pass a
        # single-level path.
        self.output_dir.mkdir(exist_ok=True)

        # Build a default Accelerator when the caller does not supply one.
        if accelerator is None:
            accelerator = Accelerator(
                mixed_precision="fp16" if training_config.mixed_precision else "no",
                gradient_accumulation_steps=training_config.gradient_accumulation_steps,
            )
        self.accelerator = accelerator

        # Let Accelerate wrap the model (device placement, AMP, DDP, ...).
        self.model = self.accelerator.prepare(self.model)

        # Optimizer over all model parameters; also prepared so Accelerate
        # can coordinate it with gradient accumulation / mixed precision.
        self.optimizer = AdamW(
            self.model.parameters(),
            lr=training_config.learning_rate,
            weight_decay=training_config.weight_decay,
        )
        self.optimizer = self.accelerator.prepare(self.optimizer)

        # Cosine schedule with linear warmup.
        # NOTE(review): num_training_steps hard-codes ~1000 steps per epoch;
        # if the real DataLoader length differs, the cosine decay ends too
        # early or too late — confirm against the actual loader size.
        self.lr_scheduler = get_cosine_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=training_config.warmup_steps,
            num_training_steps=training_config.num_epochs * 1000,
        )

        # Next-token prediction loss.
        # NOTE(review): no ignore_index is set, so padding positions (if the
        # inputs are padded) contribute to the loss — confirm whether that is
        # intended.
        self.criterion = nn.CrossEntropyLoss()

        # Optional Weights & Biases experiment tracking.
        self.use_wandb = use_wandb
        if use_wandb:
            wandb.init(project="compact-ai-model", config=training_config.__dict__)

        # Training progress state (persisted in / restored from checkpoints).
        self.global_step = 0
        self.best_loss = float('inf')

    def save_checkpoint(self, epoch: int, loss: float) -> None:
        """Save model checkpoint."""
        # One directory per epoch: checkpoints/checkpoint_epoch_<N>/.
        checkpoint_path = self.output_dir / f"checkpoint_epoch_{epoch}"
        checkpoint_path.mkdir(exist_ok=True)

        # Save the raw (unwrapped) model weights so they can be loaded
        # outside of Accelerate.
        unwrapped_model = self.accelerator.unwrap_model(self.model)
        torch.save(unwrapped_model.state_dict(), checkpoint_path / "pytorch_model.bin")

        # Optimizer state for exact training resumption.
        torch.save(self.optimizer.state_dict(), checkpoint_path / "optimizer.bin")

        # Lightweight JSON record of training progress.
        # NOTE(review): the lr_scheduler state is not saved, so a resumed run
        # restarts the schedule — confirm whether that is acceptable.
        training_state = {
            "epoch": epoch,
            "global_step": self.global_step,
            "best_loss": self.best_loss,
            "current_loss": loss,
        }
        with open(checkpoint_path / "training_state.json", "w") as f:
            json.dump(training_state, f)

        logger.info(f"Saved checkpoint to {checkpoint_path}")

    def load_checkpoint(self, checkpoint_path: str) -> None:
        """Load model checkpoint."""
        checkpoint_path = Path(checkpoint_path)

        # Restore weights into the unwrapped model (CPU first; Accelerate
        # has already placed the live model on its device).
        model_state = torch.load(checkpoint_path / "pytorch_model.bin", map_location="cpu")
        unwrapped_model = self.accelerator.unwrap_model(self.model)
        unwrapped_model.load_state_dict(model_state)

        # Restore optimizer state.
        optimizer_state = torch.load(checkpoint_path / "optimizer.bin", map_location="cpu")
        self.optimizer.load_state_dict(optimizer_state)

        # Restore progress counters.
        with open(checkpoint_path / "training_state.json", "r") as f:
            training_state = json.load(f)

        self.global_step = training_state["global_step"]
        self.best_loss = training_state["best_loss"]

        logger.info(f"Loaded checkpoint from {checkpoint_path}")

    def train_epoch(self, train_loader: DataLoader) -> float:
        """Train for one epoch.

        Returns the mean per-batch loss over the epoch.
        NOTE(review): raises ZeroDivisionError for an empty loader.
        """
        self.model.train()
        total_loss = 0.0
        num_batches = 0

        for batch_idx, batch in enumerate(train_loader):
            # accumulate() lets Accelerate skip gradient sync/updates until
            # gradient_accumulation_steps micro-batches have been processed.
            with self.accelerator.accumulate(self.model):
                # Expects tokenized batches; attention_mask is optional.
                input_ids = batch["input_ids"]
                attention_mask = batch.get("attention_mask")

                outputs = self.model(input_ids, attention_mask, use_thinking=True)
                logits = outputs["logits"]

                # Shift for next-token prediction: position t predicts t+1.
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = input_ids[..., 1:].contiguous()

                # Flatten (batch, seq) into one axis for CrossEntropyLoss.
                loss = self.criterion(
                    shift_logits.view(-1, shift_logits.size(-1)),
                    shift_labels.view(-1)
                )

                # Accelerate-aware backward (handles AMP scaling etc.).
                self.accelerator.backward(loss)

                # Clip only on real update steps (when gradients are synced).
                if self.accelerator.sync_gradients:
                    self.accelerator.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)

                # NOTE(review): lr_scheduler was not passed through
                # accelerator.prepare(), so it steps on every micro-batch even
                # during accumulation — confirm this matches the intended
                # schedule length.
                self.optimizer.step()
                self.lr_scheduler.step()
                self.optimizer.zero_grad()

                total_loss += loss.item()
                num_batches += 1
                self.global_step += 1

                # Periodic console/W&B logging.
                if batch_idx % self.config.log_interval == 0:
                    current_lr = self.lr_scheduler.get_last_lr()[0]
                    logger.info(
                        f"Step {self.global_step}: Loss = {loss.item():.4f}, LR = {current_lr:.6f}"
                    )

                    if self.use_wandb:
                        wandb.log({
                            "train/loss": loss.item(),
                            "train/learning_rate": current_lr,
                            "train/global_step": self.global_step,
                        })

        return total_loss / num_batches

    def evaluate(self, eval_loader: DataLoader) -> float:
        """Evaluate the model.

        Returns the mean per-batch loss over the eval set.
        NOTE(review): raises ZeroDivisionError for an empty loader.
        """
        self.model.eval()
        total_loss = 0.0
        num_batches = 0

        with torch.no_grad():
            for batch in eval_loader:
                input_ids = batch["input_ids"]
                attention_mask = batch.get("attention_mask")

                # use_thinking=False at eval time (True during training).
                outputs = self.model(input_ids, attention_mask, use_thinking=False)
                logits = outputs["logits"]

                # Same next-token shift as in train_epoch.
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = input_ids[..., 1:].contiguous()

                loss = self.criterion(
                    shift_logits.view(-1, shift_logits.size(-1)),
                    shift_labels.view(-1)
                )

                total_loss += loss.item()
                num_batches += 1

        avg_loss = total_loss / num_batches

        if self.use_wandb:
            wandb.log({"eval/loss": avg_loss})

        return avg_loss

    def train(self, train_loader: DataLoader, eval_loader: Optional[DataLoader] = None) -> None:
        """Main training loop.

        Runs ``config.num_epochs`` epochs; when an eval loader is given, the
        best-eval-loss checkpoint is saved. A periodic checkpoint is also
        saved every 5 epochs regardless.
        """
        logger.info("Starting training...")

        for epoch in range(self.config.num_epochs):
            logger.info(f"Epoch {epoch + 1}/{self.config.num_epochs}")

            train_loss = self.train_epoch(train_loader)

            # Evaluate and track the best model, if an eval set is provided.
            if eval_loader is not None:
                eval_loss = self.evaluate(eval_loader)
                logger.info(f"Epoch {epoch + 1}: Train Loss = {train_loss:.4f}, Eval Loss = {eval_loss:.4f}")

                if eval_loss < self.best_loss:
                    self.best_loss = eval_loss
                    self.save_checkpoint(epoch, eval_loss)

            # Unconditional periodic checkpoint.
            if (epoch + 1) % 5 == 0:
                self.save_checkpoint(epoch, train_loss)

        logger.info("Training completed!")
| |
|
| |
|
def main():
    """CLI entry point: build model + trainer, load/generate data, train."""
    import argparse

    parser = argparse.ArgumentParser(description="Train Compact AI Model")
    parser.add_argument("--data_path", type=str, default="training_data.json", help="Path to training data")
    parser.add_argument("--batch_size", type=int, default=8, help="Batch size")
    parser.add_argument("--num_epochs", type=int, default=10, help="Number of epochs")
    parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate")
    parser.add_argument("--max_length", type=int, default=1024, help="Maximum sequence length")
    parser.add_argument("--output_dir", type=str, default="checkpoints", help="Output directory")
    parser.add_argument("--use_wandb", action="store_true", help="Use Weights & Biases logging")
    parser.add_argument("--model_size", type=str, default="small", choices=["tiny", "small", "medium"], help="Model size")
    parser.add_argument("--resume_from", type=str, help="Resume training from checkpoint")

    args = parser.parse_args()

    # Build the model at the requested size preset.
    from ..architecture.model import create_compact_model
    model = create_compact_model(args.model_size)

    # NOTE(review): only lr / batch_size / num_epochs are forwarded; other
    # TrainingConfig fields (warmup, grad accumulation, ...) use defaults —
    # confirm that is intended.
    training_config = TrainingConfig(
        learning_rate=args.learning_rate,
        batch_size=args.batch_size,
        num_epochs=args.num_epochs,
    )

    trainer = Trainer(
        model=model,
        training_config=training_config,
        use_wandb=args.use_wandb,
        output_dir=args.output_dir,
    )

    # Load training data from disk, or generate-and-cache a synthetic set.
    if os.path.exists(args.data_path):
        with open(args.data_path, "r") as f:
            data = json.load(f)
    else:
        logger.info("Creating sample training data...")
        data = create_sample_data(10000)
        with open(args.data_path, "w") as f:
            json.dump(data, f)

    # NOTE(review): no tokenizer is passed, so TextDataset yields
    # {"text": ...} items, while Trainer.train_epoch reads
    # batch["input_ids"] — as written this path looks like it would raise
    # KeyError on the first batch; confirm a tokenizer or collate_fn is
    # wired in elsewhere.
    dataset = TextDataset(data, max_length=args.max_length)
    train_loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
    )

    # Optionally resume model/optimizer/progress from a prior checkpoint.
    if args.resume_from:
        trainer.load_checkpoint(args.resume_from)

    # NOTE(review): no eval_loader is supplied, so the best-checkpoint logic
    # in Trainer.train never runs; only the every-5-epochs checkpoints are
    # written.
    trainer.train(train_loader)


if __name__ == "__main__":
    main()