Spaces:
Running
Running
| """ | |
| Training script for crop disease detection model | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR | |
| import time | |
| import copy | |
| import json | |
| from pathlib import Path | |
| import matplotlib.pyplot as plt | |
| from sklearn.metrics import classification_report, confusion_matrix | |
| import numpy as np | |
| from dataset import create_data_loaders, get_class_weights | |
| from model import create_model, ModelCheckpoint, get_model_summary | |
class Trainer:
    """Training class for crop disease detection model.

    Wraps a model together with its training/validation loaders and records
    per-epoch metrics in ``self.history`` (keys: ``train_loss``, ``train_acc``,
    ``val_loss``, ``val_acc``, ``lr``).
    """

    def __init__(self, model, train_loader, val_loader, class_names, device='cpu'):
        """
        Args:
            model: ``nn.Module`` to train. ``train()`` additionally calls its
                ``unfreeze_features()`` method when fine-tuning starts.
            train_loader: Iterable of ``(inputs, labels)`` batches for training.
                ``train()`` also reads ``train_loader.dataset`` for logging.
            val_loader: Iterable of ``(inputs, labels)`` batches for validation.
                ``train()`` also reads ``val_loader.dataset`` for logging.
            class_names: Sequence of class names (only its length is used here).
            device: Device that every batch is moved to before the forward pass.
        """
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.class_names = class_names
        self.device = device
        # Training history, one entry appended per epoch by train()
        self.history = {
            'train_loss': [],
            'train_acc': [],
            'val_loss': [],
            'val_acc': [],
            'lr': []
        }

    def _run_epoch(self, loader, criterion, optimizer=None):
        """Make one pass over ``loader`` and return ``(mean_loss, accuracy)``.

        When ``optimizer`` is given, the model is put in training mode and
        updated after every batch. Otherwise it is put in eval mode and
        gradients are disabled.

        Raises:
            ValueError: if ``loader`` yields no samples; the per-sample
                averages would otherwise divide by zero.
        """
        training = optimizer is not None
        self.model.train(training)
        running_loss = 0.0
        running_corrects = 0
        total_samples = 0
        grad_mode = torch.enable_grad() if training else torch.no_grad()
        with grad_mode:
            for inputs, labels in loader:
                inputs = inputs.to(self.device)
                labels = labels.to(self.device)
                if training:
                    optimizer.zero_grad()
                outputs = self.model(inputs)
                loss = criterion(outputs, labels)
                if training:
                    loss.backward()
                    optimizer.step()
                preds = outputs.argmax(dim=1)
                batch_size = inputs.size(0)
                # Assumes criterion uses mean reduction (the nn.CrossEntropyLoss
                # default): re-weight by batch size so a short final batch does
                # not skew the epoch average.
                running_loss += loss.item() * batch_size
                running_corrects += (preds == labels).sum().item()
                total_samples += batch_size
        if total_samples == 0:
            raise ValueError("data loader yielded no samples")
        return running_loss / total_samples, running_corrects / total_samples

    def train_epoch(self, criterion, optimizer):
        """Train for one epoch on ``train_loader``; return ``(loss, accuracy)``."""
        return self._run_epoch(self.train_loader, criterion, optimizer)

    def validate_epoch(self, criterion):
        """Evaluate one epoch on ``val_loader`` without gradients; return ``(loss, accuracy)``."""
        return self._run_epoch(self.val_loader, criterion)

    def _build_criterion(self, use_class_weights, data_dir):
        """Return a (optionally class-weighted) CrossEntropyLoss on ``self.device``."""
        if use_class_weights:
            class_weights = get_class_weights(data_dir).to(self.device)
            criterion = nn.CrossEntropyLoss(weight=class_weights)
            print("Using weighted CrossEntropyLoss")
        else:
            criterion = nn.CrossEntropyLoss()
            print("Using standard CrossEntropyLoss")
        return criterion

    def _build_optimizer(self, learning_rate, weight_decay):
        """Create Adam plus a ReduceLROnPlateau scheduler over the currently trainable params.

        Only parameters with ``requires_grad=True`` at call time are included,
        so this must be called again after layers are unfrozen.
        """
        optimizer = optim.Adam(
            [p for p in self.model.parameters() if p.requires_grad],
            lr=learning_rate,
            weight_decay=weight_decay
        )
        scheduler = ReduceLROnPlateau(
            optimizer, mode='max', factor=0.5, patience=5
        )
        return optimizer, scheduler

    def train(self, num_epochs=25, learning_rate=1e-4, weight_decay=1e-4,
              use_class_weights=True, checkpoint_path='models/crop_disease_resnet50.pth',
              fine_tune_epoch=10, data_dir='data'):
        """
        Train the model

        Args:
            num_epochs: Number of training epochs
            learning_rate: Initial learning rate
            weight_decay: Weight decay for regularization
            use_class_weights: Use class weights for imbalanced data
            checkpoint_path: Path to save best model
            fine_tune_epoch: Epoch (0-based) to start fine-tuning (unfreeze all
                layers and drop the learning rate to ``learning_rate * 0.1``).
                A value outside ``range(num_epochs)`` disables fine-tuning.
            data_dir: Directory passed to ``get_class_weights`` when
                ``use_class_weights`` is True.

        Returns:
            Tuple ``(model, history)``.
        """
        print("Starting training...")
        print(f"Device: {self.device}")
        print(f"Number of classes: {len(self.class_names)}")
        print(f"Training samples: {len(self.train_loader.dataset)}")
        print(f"Validation samples: {len(self.val_loader.dataset)}")

        criterion = self._build_criterion(use_class_weights, data_dir)
        optimizer, scheduler = self._build_optimizer(learning_rate, weight_decay)

        # Setup model checkpoint
        checkpoint = ModelCheckpoint(
            filepath=checkpoint_path,
            monitor='val_accuracy',
            mode='max'
        )

        best_acc = 0.0
        start_time = time.time()

        for epoch in range(num_epochs):
            epoch_start = time.time()

            # Fine-tuning: unfreeze all layers after specified epoch
            if epoch == fine_tune_epoch:
                print(f"\nEpoch {epoch}: Starting fine-tuning (unfreezing all layers)")
                self.model.unfreeze_features()
                # The existing optimizer only holds the parameters that were
                # trainable when it was built, so the newly unfrozen layers
                # would never be updated (and their gradients never zeroed).
                # Rebuild optimizer + scheduler at the reduced fine-tuning LR;
                # this also resets Adam moments and the plateau counters.
                optimizer, scheduler = self._build_optimizer(
                    learning_rate * 0.1, weight_decay
                )
                print(f"Reduced learning rate to: {optimizer.param_groups[0]['lr']}")

            train_loss, train_acc = self.train_epoch(criterion, optimizer)
            val_loss, val_acc = self.validate_epoch(criterion)

            # Scheduler is in 'max' mode, so it is driven by validation accuracy
            scheduler.step(val_acc)
            current_lr = optimizer.param_groups[0]['lr']

            self.history['train_loss'].append(train_loss)
            self.history['train_acc'].append(train_acc)
            self.history['val_loss'].append(val_loss)
            self.history['val_acc'].append(val_acc)
            self.history['lr'].append(current_lr)

            metrics = {
                'val_accuracy': val_acc,
                'val_loss': val_loss,
                'train_accuracy': train_acc,
                'train_loss': train_loss
            }
            checkpoint(self.model, optimizer, epoch, metrics)

            best_acc = max(best_acc, val_acc)

            epoch_time = time.time() - epoch_start
            print(f'Epoch {epoch+1:2d}/{num_epochs} | '
                  f'Train Loss: {train_loss:.4f} Acc: {train_acc:.4f} | '
                  f'Val Loss: {val_loss:.4f} Acc: {val_acc:.4f} | '
                  f'LR: {current_lr:.2e} | Time: {epoch_time:.1f}s')

        total_time = time.time() - start_time
        print(f'\nTraining completed in {total_time//60:.0f}m {total_time%60:.0f}s')
        print(f'Best validation accuracy: {best_acc:.4f}')

        self.save_training_history()
        return self.model, self.history

    def save_training_history(self, filepath='outputs/training_history.json'):
        """Write ``self.history`` as indented JSON to ``filepath``, creating parent dirs."""
        Path(filepath).parent.mkdir(parents=True, exist_ok=True)
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(self.history, f, indent=2)
        print(f"Training history saved to: {filepath}")

    def plot_training_curves(self, save_path='outputs/training_curves.png'):
        """Plot loss, accuracy and learning-rate curves and save them to ``save_path``.

        The parent directory of ``save_path`` is created if missing, and the
        figure is closed after saving.
        """
        Path(save_path).parent.mkdir(parents=True, exist_ok=True)
        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))
        epochs = range(1, len(self.history['train_loss']) + 1)

        # Loss curves
        ax1.plot(epochs, self.history['train_loss'], 'b-', label='Training Loss')
        ax1.plot(epochs, self.history['val_loss'], 'r-', label='Validation Loss')
        ax1.set_title('Training and Validation Loss')
        ax1.set_xlabel('Epoch')
        ax1.set_ylabel('Loss')
        ax1.legend()
        ax1.grid(True)

        # Accuracy curves
        ax2.plot(epochs, self.history['train_acc'], 'b-', label='Training Accuracy')
        ax2.plot(epochs, self.history['val_acc'], 'r-', label='Validation Accuracy')
        ax2.set_title('Training and Validation Accuracy')
        ax2.set_xlabel('Epoch')
        ax2.set_ylabel('Accuracy')
        ax2.legend()
        ax2.grid(True)

        # Learning rate (log scale: the plateau scheduler halves it)
        ax3.plot(epochs, self.history['lr'], 'g-', label='Learning Rate')
        ax3.set_title('Learning Rate Schedule')
        ax3.set_xlabel('Epoch')
        ax3.set_ylabel('Learning Rate')
        ax3.set_yscale('log')
        ax3.legend()
        ax3.grid(True)

        # Combined accuracy
        ax4.plot(epochs, self.history['train_acc'], 'b-', label='Training')
        ax4.plot(epochs, self.history['val_acc'], 'r-', label='Validation')
        ax4.set_title('Model Accuracy Comparison')
        ax4.set_xlabel('Epoch')
        ax4.set_ylabel('Accuracy')
        ax4.legend()
        ax4.grid(True)

        fig.tight_layout()
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.close(fig)
        print(f"Training curves saved to: {save_path}")
def main():
    """Run the end-to-end training pipeline.

    Builds the data loaders and model, trains with the settings defined
    below, saves the training curves, and reports the checkpoint path.
    """
    # Hyperparameters and paths for this run
    config = dict(
        data_dir='data',
        batch_size=32,  # Increased for GPU training
        num_epochs=20,
        learning_rate=1e-4,
        weight_decay=1e-4,
        fine_tune_epoch=10,
        checkpoint_path='models/crop_disease_resnet50.pth',
    )

    # Select CUDA when available and report what was found
    cuda_ok = torch.cuda.is_available()
    device = torch.device('cuda' if cuda_ok else 'cpu')
    print(f"Using device: {device}")
    if cuda_ok:
        gpu_name = torch.cuda.get_device_name(0)
        print(f"GPU: {gpu_name}")
        total_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3
        print(f"GPU Memory: {total_gb:.1f} GB")

    print("Loading dataset...")
    # 0 loader workers on CPU, 2 when training on GPU
    workers = 2 if device.type != 'cpu' else 0
    train_loader, val_loader, _test_loader, class_names = create_data_loaders(
        data_dir=config['data_dir'],
        batch_size=config['batch_size'],
        num_workers=workers,
    )
    num_classes = len(class_names)
    print(f"Dataset loaded: {num_classes} classes")
    print(f"Classes: {class_names}")

    print("Creating model...")
    model = create_model(num_classes=num_classes, device=device)
    get_model_summary(model)

    trainer = Trainer(model, train_loader, val_loader, class_names, device=device)
    train_kwargs = {
        key: config[key]
        for key in ('num_epochs', 'learning_rate', 'weight_decay',
                    'checkpoint_path', 'fine_tune_epoch')
    }
    trainer.train(**train_kwargs)

    trainer.plot_training_curves()

    print("\nTraining completed successfully!")
    print(f"Best model saved at: {config['checkpoint_path']}")


# Only start training when run as a script, not when imported
if __name__ == "__main__":
    main()