Spaces:
Running
Running
| """ | |
| Simple Training Script for Pest and Disease Classification | |
| Using Rich for progress display | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| from pathlib import Path | |
| import json | |
| import argparse | |
| from rich.console import Console | |
| from rich.progress import Progress, SpinnerColumn, BarColumn, TextColumn, TimeRemainingColumn | |
| from rich.table import Table | |
| from rich.panel import Panel | |
| from dataset import get_dataloaders, calculate_class_weights | |
| from model import create_model | |
# Shared Rich console: every status line, table and progress bar in this
# script renders through this single instance.
console = Console()
def train_epoch(model, dataloader, criterion, optimizer, device, progress, task):
    """Run one optimization pass over *dataloader*.

    Args:
        model: network to train (put into train mode here).
        dataloader: yields ``(inputs, labels)`` batches.
        criterion: loss function applied to model outputs.
        optimizer: updates parameters after each batch.
        device: device the batch tensors are moved to.
        progress: Rich ``Progress`` instance driving the display.
        task: task id inside *progress*, advanced once per batch.

    Returns:
        ``(epoch_loss, epoch_acc)``: sample-weighted mean loss and accuracy.
    """
    model.train()
    loss_sum = 0.0
    correct = 0
    seen = 0
    for batch_x, batch_y in dataloader:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)
        optimizer.zero_grad()
        logits = model(batch_x)
        batch_loss = criterion(logits, batch_y)
        predicted = logits.argmax(dim=1)
        batch_loss.backward()
        optimizer.step()
        # Weight the running loss by batch size so the epoch mean is per-sample.
        batch_size = batch_x.size(0)
        loss_sum += batch_loss.item() * batch_size
        correct += torch.sum(predicted == batch_y)
        seen += batch_size
        progress.update(task, advance=1)
    return loss_sum / seen, (correct.double() / seen).item()
def validate_epoch(model, dataloader, criterion, device, progress, task):
    """Evaluate *model* on *dataloader* without touching its weights.

    Args:
        model: network to evaluate (put into eval mode here).
        dataloader: yields ``(inputs, labels)`` batches.
        criterion: loss function applied to model outputs.
        device: device the batch tensors are moved to.
        progress: Rich ``Progress`` instance driving the display.
        task: task id inside *progress*, advanced once per batch.

    Returns:
        ``(epoch_loss, epoch_acc)``: sample-weighted mean loss and accuracy.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    seen = 0
    # No gradients needed for evaluation.
    with torch.no_grad():
        for batch_x, batch_y in dataloader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            logits = model(batch_x)
            batch_loss = criterion(logits, batch_y)
            predicted = logits.argmax(dim=1)
            batch_size = batch_x.size(0)
            loss_sum += batch_loss.item() * batch_size
            correct += torch.sum(predicted == batch_y)
            seen += batch_size
            progress.update(task, advance=1)
    return loss_sum / seen, (correct.double() / seen).item()
def train_model(model, train_loader, val_loader, criterion, optimizer,
                num_epochs, device, save_dir):
    """Train *model* for *num_epochs*, tracking metrics and saving checkpoints.

    Per epoch: train, validate, print a results table, record metrics. The
    best model (by validation accuracy) is written to ``best_model.pth``,
    a numbered checkpoint is written every 10 epochs, and metric history
    is dumped to ``training_history.json``.

    Args:
        model: network to optimize (assumed already on *device*).
        train_loader: training dataloader.
        val_loader: validation dataloader.
        criterion: loss function.
        optimizer: parameter optimizer.
        num_epochs: number of epochs to run.
        device: device passed through to the epoch helpers.
        save_dir: directory for checkpoints/history (created if missing).

    Returns:
        ``(model, history)``: the trained model and the per-epoch metrics dict.
    """
    save_dir = Path(save_dir)
    # parents=True so nested paths like "runs/exp1" work (previously crashed
    # with FileNotFoundError when the parent directory did not exist).
    save_dir.mkdir(parents=True, exist_ok=True)

    def _save_checkpoint(path, epoch, val_acc, val_loss):
        # Single place for the checkpoint payload so best/periodic saves
        # always carry the same fields.
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_acc': val_acc,
            'val_loss': val_loss,
        }, path)

    best_val_acc = 0.0
    history = {
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': []
    }
    console.print("\n[bold green]Starting Training[/bold green]")
    for epoch in range(num_epochs):
        console.print(f"\n[bold cyan]Epoch {epoch+1}/{num_epochs}[/bold cyan]")
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
            TimeRemainingColumn(),
            console=console
        ) as progress:
            # Training pass
            train_task = progress.add_task(
                "[red]Training...",
                total=len(train_loader)
            )
            train_loss, train_acc = train_epoch(
                model, train_loader, criterion, optimizer,
                device, progress, train_task
            )
            # Validation pass
            val_task = progress.add_task(
                "[green]Validating...",
                total=len(val_loader)
            )
            val_loss, val_acc = validate_epoch(
                model, val_loader, criterion, device,
                progress, val_task
            )
        # Per-epoch results table
        table = Table(show_header=True, header_style="bold magenta")
        table.add_column("Split", style="cyan")
        table.add_column("Loss", justify="right", style="yellow")
        table.add_column("Accuracy", justify="right", style="green")
        table.add_row("Train", f"{train_loss:.4f}", f"{train_acc:.4f}")
        table.add_row("Val", f"{val_loss:.4f}", f"{val_acc:.4f}")
        console.print(table)
        # Record metrics
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        # Keep the best model by validation accuracy
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            _save_checkpoint(save_dir / 'best_model.pth', epoch, val_acc, val_loss)
            console.print(f"[bold green]✓ Saved best model (Val Acc: {val_acc:.4f})[/bold green]")
        # Periodic checkpoint every 10 epochs
        if (epoch + 1) % 10 == 0:
            _save_checkpoint(save_dir / f'checkpoint_epoch_{epoch+1}.pth',
                             epoch, val_acc, val_loss)
            console.print(f"[yellow]Checkpoint saved at epoch {epoch+1}[/yellow]")
    # Persist the full metric history for later plotting/analysis
    with open(save_dir / 'training_history.json', 'w') as f:
        json.dump(history, f, indent=2)
    console.print("\n[bold green]Training Complete![/bold green]")
    console.print(f"[bold]Best Val Acc: {best_val_acc:.4f}[/bold]")
    console.print(f"[bold]Results saved to: {save_dir}/[/bold]")
    return model, history
def main(args):
    """Build data, model, loss and optimizer from parsed *args*, then train.

    Args:
        args: argparse namespace from the ``__main__`` block below.

    Returns:
        ``(model, history)`` from :func:`train_model` so callers can use
        the results programmatically (previously discarded).

    Raises:
        ValueError: if ``args.optimizer`` is not a supported name (the old
            code fell through to a confusing NameError instead).
    """
    # Show the run configuration up front.
    config_panel = Panel.fit(
        f"""[bold]Configuration[/bold]
Backbone: {args.backbone}
Batch Size: {args.batch_size}
Image Size: {args.img_size}
Epochs: {args.epochs}
Learning Rate: {args.lr}
Optimizer: {args.optimizer}
Device: {args.device}
Class Weights: {args.use_class_weights}""",
        title="Training Settings",
        border_style="blue"
    )
    console.print(config_panel)
    # Fall back to CPU when CUDA was requested but is unavailable.
    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
    console.print(f"\n[bold]Using device: {device}[/bold]")
    # Load data
    console.print("\n[bold]Loading datasets...[/bold]")
    loaders = get_dataloaders(
        csv_file=args.csv_file,
        label_mapping_file=args.label_mapping,
        batch_size=args.batch_size,
        img_size=args.img_size,
        num_workers=args.num_workers
    )
    # Create model
    console.print(f"\n[bold]Creating model: {args.backbone}[/bold]")
    model = create_model(
        num_classes=loaders['num_classes'],
        backbone=args.backbone,
        pretrained=True,
        dropout=args.dropout
    )
    model = model.to(device)
    # Loss: optionally weight classes to counter dataset imbalance.
    if args.use_class_weights:
        class_weights = calculate_class_weights(args.csv_file, args.label_mapping)
        class_weights = class_weights.to(device)
        criterion = nn.CrossEntropyLoss(weight=class_weights)
        console.print("[bold]Using weighted CrossEntropyLoss[/bold]")
    else:
        criterion = nn.CrossEntropyLoss()
        console.print("[bold]Using CrossEntropyLoss[/bold]")
    # Optimizer selection; fail loudly on an unsupported name.
    if args.optimizer == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    elif args.optimizer == 'adamw':
        optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    elif args.optimizer == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9,
                              weight_decay=args.weight_decay)
    else:
        raise ValueError(f"Unsupported optimizer: {args.optimizer!r}")
    # Train model
    model, history = train_model(
        model=model,
        train_loader=loaders['train'],
        val_loader=loaders['val'],
        criterion=criterion,
        optimizer=optimizer,
        num_epochs=args.epochs,
        device=device,
        save_dir=args.save_dir
    )
    return model, history
if __name__ == "__main__":
    # CLI entry point: parse arguments and launch training. Argument names,
    # order and defaults define the CLI surface (and --help output), so they
    # are kept exactly as-is.
    parser = argparse.ArgumentParser(description='Simple Training for Pest and Disease Classifier')
    # Data parameters
    parser.add_argument('--csv_file', type=str, default='dataset.csv')  # CSV listing samples; consumed by get_dataloaders
    parser.add_argument('--label_mapping', type=str, default='label_mapping.json')  # class-name <-> index mapping file
    # Model parameters
    parser.add_argument('--backbone', type=str, default='resnet50',
                        choices=['resnet50', 'resnet101', 'efficientnet_b0',
                                 'efficientnet_b3', 'mobilenet_v2'])
    parser.add_argument('--dropout', type=float, default=0.3)  # classifier-head dropout passed to create_model
    # Training parameters
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--img_size', type=int, default=224)
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--optimizer', type=str, default='adamw',
                        choices=['adam', 'adamw', 'sgd'])
    parser.add_argument('--weight_decay', type=float, default=0.01)
    parser.add_argument('--use_class_weights', action='store_true')  # weight CrossEntropyLoss by class frequency
    # System parameters
    parser.add_argument('--device', type=str, default='cuda',
                        choices=['cuda', 'cpu'])  # main() falls back to CPU if CUDA is unavailable
    parser.add_argument('--num_workers', type=int, default=8)  # dataloader worker processes
    parser.add_argument('--save_dir', type=str, default='checkpoints')  # output dir for checkpoints/history
    args = parser.parse_args()
    main(args)