"""
Training and Evaluation Pipeline
"""
import json
import time
from pathlib import Path
from typing import Dict, List, Tuple, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import GradScaler, autocast
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, OneCycleLR
from tqdm import tqdm

import config
from dataset import create_data_loaders
from models import get_model
class EarlyStopping:
    """Stop training once validation loss has not improved for `patience` epochs.

    An epoch counts as an improvement only when the loss drops by at least
    `min_delta` below the best loss seen so far.
    """

    def __init__(self, patience: int = 5, min_delta: float = 0.001):
        self.patience = patience      # epochs to tolerate without improvement
        self.min_delta = min_delta    # minimum decrease that counts as improvement
        self.counter = 0              # consecutive non-improving epochs
        self.best_loss = None         # lowest validation loss observed so far
        self.early_stop = False       # latched once patience is exhausted

    def __call__(self, val_loss: float) -> bool:
        """Record this epoch's validation loss; return True when training should stop."""
        if self.best_loss is None:
            # First observation establishes the baseline.
            self.best_loss = val_loss
            return self.early_stop

        improved = val_loss <= self.best_loss - self.min_delta
        if improved:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        return self.early_stop
class Trainer:
    """Model trainer with mixed precision, OneCycle LR scheduling,
    gradient clipping, and early stopping.

    Per-epoch metrics are accumulated in ``self.history``; the weights of
    the best epoch (by validation accuracy) are snapshotted and restored
    at the end of training.
    """

    def __init__(
        self,
        model: nn.Module,
        model_name: str,
        train_loader,
        val_loader,
        num_classes: int,
        device: str = config.DEVICE
    ):
        self.model = model.to(device)
        self.model_name = model_name
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.num_classes = num_classes

        # Loss function with label smoothing (reduces overconfidence)
        self.criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

        # Optimizer - AdamW with decoupled weight decay
        self.optimizer = optim.AdamW(
            model.parameters(),
            lr=config.LEARNING_RATE,
            weight_decay=config.WEIGHT_DECAY
        )

        # OneCycle schedule: warm up over the first 10% of steps to 10x the
        # base LR, then cosine-anneal down. Stepped once per BATCH.
        self.scheduler = OneCycleLR(
            self.optimizer,
            max_lr=config.LEARNING_RATE * 10,
            epochs=config.EPOCHS,
            steps_per_epoch=len(train_loader),
            pct_start=0.1,
            anneal_strategy='cos'
        )

        # Mixed precision training
        self.scaler = GradScaler()

        # Early stopping monitors validation loss
        self.early_stopping = EarlyStopping(patience=config.EARLY_STOPPING_PATIENCE)

        # Per-epoch metric history
        self.history = {
            'train_loss': [],
            'train_acc': [],
            'val_loss': [],
            'val_acc': [],
            'lr': []
        }

        # Best model tracking (by validation accuracy)
        self.best_val_acc = 0.0
        self.best_model_state = None

    def train_epoch(self) -> Tuple[float, float]:
        """Train for one epoch.

        Returns:
            Tuple of (mean batch loss, accuracy in percent) over the training set.
        """
        self.model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        pbar = tqdm(self.train_loader, desc="Training", leave=False)
        for images, labels in pbar:
            images = images.to(self.device)
            labels = labels.to(self.device)

            self.optimizer.zero_grad()

            # Mixed precision forward pass
            with autocast():
                outputs = self.model(images)
                loss = self.criterion(outputs, labels)

            # Backward pass with gradient scaling
            self.scaler.scale(loss).backward()

            # Unscale before clipping so max_norm applies to the true gradients
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

            self.scaler.step(self.optimizer)
            self.scaler.update()
            self.scheduler.step()  # OneCycleLR advances per batch, not per epoch

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            pbar.set_postfix({
                'loss': f'{loss.item():.4f}',
                'acc': f'{100.*correct/total:.2f}%'
            })

        epoch_loss = running_loss / len(self.train_loader)
        epoch_acc = 100. * correct / total
        return epoch_loss, epoch_acc

    def validate(self) -> Tuple[float, float]:
        """Validate the model.

        Returns:
            Tuple of (mean batch loss, accuracy in percent) over the validation set.
        """
        self.model.eval()
        running_loss = 0.0
        correct = 0
        total = 0

        # FIX: evaluation must run under no_grad — the original built autograd
        # graphs for every validation batch, wasting memory and time.
        with torch.no_grad():
            for images, labels in tqdm(self.val_loader, desc="Validating", leave=False):
                images = images.to(self.device)
                labels = labels.to(self.device)

                with autocast():
                    outputs = self.model(images)
                    loss = self.criterion(outputs, labels)

                running_loss += loss.item()
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()

        epoch_loss = running_loss / len(self.val_loader)
        epoch_acc = 100. * correct / total
        return epoch_loss, epoch_acc

    def train(self, epochs: int = config.EPOCHS) -> Dict:
        """Full training loop with best-model tracking and early stopping.

        Args:
            epochs: Maximum number of epochs to run.

        Returns:
            The history dict, augmented with 'training_time' (seconds) and
            'best_val_acc' (percent).
        """
        print(f"\n{'='*60}")
        print(f"Training {self.model_name}")
        print(f"{'='*60}")
        print(f"Device: {self.device}")
        print(f"Epochs: {epochs}")
        print(f"Batch size: {config.BATCH_SIZE}")
        print(f"Learning rate: {config.LEARNING_RATE}")

        start_time = time.time()

        for epoch in range(epochs):
            print(f"\nEpoch [{epoch+1}/{epochs}]")

            # Train
            train_loss, train_acc = self.train_epoch()
            # Validate
            val_loss, val_acc = self.validate()

            # Current LR (reflects the per-batch scheduler steps)
            current_lr = self.optimizer.param_groups[0]['lr']

            # Update history
            self.history['train_loss'].append(train_loss)
            self.history['train_acc'].append(train_acc)
            self.history['val_loss'].append(val_loss)
            self.history['val_acc'].append(val_acc)
            self.history['lr'].append(current_lr)

            print(f" Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
            print(f" Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")
            print(f" LR: {current_lr:.6f}")

            # Save best model
            if val_acc > self.best_val_acc:
                self.best_val_acc = val_acc
                # FIX: state_dict().copy() is a SHALLOW copy — the stored
                # tensors alias the live parameters, so later epochs would
                # silently overwrite the "best" snapshot. Clone each tensor
                # so the snapshot is truly frozen.
                self.best_model_state = {
                    k: v.detach().clone()
                    for k, v in self.model.state_dict().items()
                }
                print(f" *** New best model! ***")

            # Early stopping (monitors validation loss)
            if self.early_stopping(val_loss):
                print(f"\nEarly stopping triggered at epoch {epoch+1}")
                break

        training_time = time.time() - start_time

        # Restore the best-performing weights before returning/saving
        if self.best_model_state is not None:
            self.model.load_state_dict(self.best_model_state)

        print(f"\nTraining completed in {training_time/60:.2f} minutes")
        print(f"Best validation accuracy: {self.best_val_acc:.2f}%")

        # Add summary stats to history
        self.history['training_time'] = training_time
        self.history['best_val_acc'] = self.best_val_acc

        return self.history

    def save_model(self, path: Optional[Path] = None):
        """Save the trained model checkpoint.

        Args:
            path: Destination file. Defaults to MODELS_DIR/<model_name>.pth
                (lowercased, spaces replaced with underscores).

        Returns:
            The path the checkpoint was written to.
        """
        if path is None:
            path = config.MODELS_DIR / f"{self.model_name.lower().replace(' ', '_')}.pth"
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'model_name': self.model_name,
            'num_classes': self.num_classes,
            'best_val_acc': self.best_val_acc,
            'history': self.history
        }, path)
        print(f"Model saved to {path}")
        return path
def train_all_models():
    """Train all 5 models and return results"""
    print("\n" + "="*70)
    print("TRAINING 5 MODELS FOR INDONESIAN HERBAL PLANTS CLASSIFICATION")
    print("="*70)

    # Build the shared data loaders once; every model trains on the same splits.
    train_loader, val_loader, test_loader, class_names = create_data_loaders()
    num_classes = len(class_names)

    # Persist the label vocabulary next to the other outputs.
    with open(config.OUTPUT_DIR / "class_names.json", 'w') as f:
        json.dump(class_names, f, indent=2)

    results = {}
    for model_name in config.MODEL_NAMES:
        print(f"\n{'#'*70}")
        print(f"# Model: {model_name.upper()}")
        print(f"{'#'*70}")

        # Instantiate the architecture with pretrained weights.
        model = get_model(model_name, num_classes, pretrained=True)

        # Parameter statistics for reporting.
        total_count = sum(p.numel() for p in model.parameters())
        trainable_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"Total parameters: {total_count:,}")
        print(f"Trainable parameters: {trainable_count:,}")

        # Run the full training loop for this model.
        trainer = Trainer(
            model=model,
            model_name=model_name,
            train_loader=train_loader,
            val_loader=val_loader,
            num_classes=num_classes
        )
        history = trainer.train(epochs=config.EPOCHS)
        saved_path = trainer.save_model()

        results[model_name] = {
            'history': history,
            'model_path': str(saved_path),
            'params': total_count,
            'trainable_params': trainable_count
        }

    # Write a compact, JSON-serializable summary of all runs.
    summary = {
        name: {
            'best_val_acc': entry['history']['best_val_acc'],
            'training_time': entry['history']['training_time'],
            'params': entry['params'],
            'model_path': entry['model_path']
        }
        for name, entry in results.items()
    }
    with open(config.OUTPUT_DIR / "training_results.json", 'w') as f:
        json.dump(summary, f, indent=2)

    return results, test_loader, class_names
# Script entry point: train every configured model when run directly.
if __name__ == "__main__":
    results, test_loader, class_names = train_all_models()