#!/usr/bin/env python3
"""
✅ OPTIMIZED Food101 + ResNet50 with major speed improvements
✅ Mixed precision training on CUDA (plain fp32 on other devices)
✅ Better data loading (persistent workers)
✅ Progress bars and better logging
✅ Robust error handling and resumable checkpointing
"""

import os
import time
import copy
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import logging

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
from torch.cuda.amp import autocast, GradScaler

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

OUTPUT_DIR = './outputs'
BEST_CHECKPOINT = './outputs/food101_resnet50_best.pth'
# Written every epoch so an interrupted run resumes exactly where it stopped.
LAST_CHECKPOINT = './outputs/food101_resnet50_last.pth'


# -------------------------
# OPTIMIZED Data Loaders
# -------------------------
def get_food101_loaders(batch_size=64, num_workers=8):
    """Build Food101 train/val/test loaders and return them with the class names.

    The official ``train`` split is divided 90/10 into train/val (fixed seed 42);
    the official ``test`` split is the test set. Only the train subset is
    augmented; val and test use the deterministic evaluation transform.

    Args:
        batch_size: Samples per batch for every loader.
        num_workers: DataLoader worker processes (0 loads in the main process).

    Returns:
        ``(train_loader, val_loader, test_loader, classes)``.

    Raises:
        Any exception from dataset download/construction is logged and re-raised.
    """
    # More aggressive data augmentation
    transform_train = transforms.Compose([
        # Squash to a fixed 256x256 square (aspect ratio is NOT preserved),
        # then take a random 224x224 crop.
        transforms.Resize((256, 256)),
        transforms.RandomCrop((224, 224)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Evaluate at the same 256 -> 224 scale used for training (center crop
    # instead of random crop) so objects have the same apparent size.
    transform_test = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.CenterCrop((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    try:
        # Full train split (75k images), augmented view.
        full_train = torchvision.datasets.Food101(
            root='./data', split='train', download=True, transform=transform_train
        )
        # Same images with the deterministic eval transform, used for validation
        # so val accuracy is not measured on randomly augmented inputs.
        full_train_eval = torchvision.datasets.Food101(
            root='./data', split='train', download=False, transform=transform_test
        )

        # 90/10 train/val split with fixed seed for reproducibility
        torch.manual_seed(42)
        train_size = int(0.9 * len(full_train))
        val_size = len(full_train) - train_size
        train_dataset, val_split = torch.utils.data.random_split(
            full_train, [train_size, val_size]
        )
        val_dataset = Subset(full_train_eval, val_split.indices)

        # Test split (25k images)
        test_dataset = torchvision.datasets.Food101(
            root='./data', split='test', download=True, transform=transform_test
        )

        logger.info(f"Dataset sizes - Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")

        # persistent_workers requires num_workers > 0; pinned memory only helps
        # (and only avoids warnings) when copying to a CUDA device.
        keep_workers = num_workers > 0
        pin = torch.cuda.is_available()

        train_loader = DataLoader(
            train_dataset, batch_size, shuffle=True, num_workers=num_workers,
            pin_memory=pin, persistent_workers=keep_workers, drop_last=True
        )
        val_loader = DataLoader(
            val_dataset, batch_size, shuffle=False, num_workers=num_workers,
            pin_memory=pin, persistent_workers=keep_workers
        )
        # The test set is iterated once, so keeping its workers alive is wasted.
        test_loader = DataLoader(
            test_dataset, batch_size, shuffle=False, num_workers=num_workers,
            pin_memory=pin
        )

        return train_loader, val_loader, test_loader, full_train.classes

    except Exception as e:
        logger.error(f"Error loading data: {e}")
        raise


# -------------------------
# ResNet Building Blocks (same as original but with better initialization)
# -------------------------
class BasicBlock(nn.Module):
    """Two 3x3 conv residual block (ResNet-18/34 style); not used by ResNet50."""

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, 3, stride, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, 3, 1, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        out = self.relu(out)
        return out


class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual block whose output has ``planes * 4`` channels."""

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, 3, stride, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        out = self.relu(out)
        return out


class ResNet50(nn.Module):
    """ResNet-50 (bottleneck layers 3-4-6-3) with a ``num_classes``-way linear head."""

    def __init__(self, num_classes=101):
        super().__init__()
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, 64, 7, 2, 3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(3, 2, 1)
        self.layer1 = self._make_layer(Bottleneck, 64, 3)
        self.layer2 = self._make_layer(Bottleneck, 128, 4, 2)
        self.layer3 = self._make_layer(Bottleneck, 256, 6, 2)
        self.layer4 = self._make_layer(Bottleneck, 512, 3, 2)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(512 * Bottleneck.expansion, num_classes)

        # Better initialization
        self._initialize_weights()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack ``blocks`` blocks; the first may downsample/project the shortcut."""
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion, 1, stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )
        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))
        return nn.Sequential(*layers)

    def _initialize_weights(self):
        """Kaiming-normal conv weights; BatchNorm weight=1, bias=0."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x


# -------------------------
# Checkpoint helpers
# -------------------------
def _unwrap_model(model):
    """Return the module wrapped by ``torch.compile`` (or ``model`` if not compiled)."""
    return getattr(model, '_orig_mod', model)


def _strip_compile_prefix(state_dict):
    """Drop the ``_orig_mod.`` key prefix left by saving a compiled model's state dict."""
    prefix = '_orig_mod.'
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


# -------------------------
# OPTIMIZED Training Function with Mixed Precision
# -------------------------
def _train_one_epoch(model, loader, criterion, optimizer, scaler, device, use_amp, desc):
    """Run one optimisation pass over ``loader`` and return the mean batch loss."""
    model.train()
    lr = optimizer.param_groups[0]['lr']  # constant within an epoch (scheduler steps per epoch)
    running_loss = 0.0
    pbar = tqdm(loader, desc=desc, leave=False)
    for images, labels in pbar:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad()

        # Mixed precision forward pass (autocast/GradScaler are no-ops when use_amp is False)
        with torch.autocast(device_type='cuda', enabled=use_amp):
            outputs = model(images)
            loss = criterion(outputs, labels)

        # Mixed precision backward pass
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # .item() synchronises with the device; call it once per step.
        loss_value = loss.item()
        running_loss += loss_value
        pbar.set_postfix({'loss': f'{loss_value:.4f}', 'lr': f'{lr:.6f}'})

    return running_loss / max(len(loader), 1)


def _validate(model, loader, criterion, device, use_amp, desc):
    """Return ``(mean_batch_loss, accuracy_percent)`` of ``model`` on ``loader``."""
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    pbar = tqdm(loader, desc=desc, leave=False)
    with torch.no_grad():
        for images, labels in pbar:
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            with torch.autocast(device_type='cuda', enabled=use_amp):
                outputs = model(images)
                loss = criterion(outputs, labels)
            total_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            pbar.set_postfix({'acc': f'{100. * correct / total:.2f}%'})
    avg_loss = total_loss / max(len(loader), 1)
    accuracy = 100. * correct / total if total else 0.0
    return avg_loss, accuracy


def train_model(model, train_loader, val_loader, test_loader, device, num_epochs=100, resume_from=None):
    """Train ``model`` with Nesterov SGD + cosine warm restarts, checkpointing as it goes.

    Mixed precision (autocast + GradScaler) is enabled only on CUDA devices.
    Every epoch writes ``LAST_CHECKPOINT`` (model, optimizer, scheduler and
    scaler state plus history) for exact resumption; the best-by-val-accuracy
    checkpoint and every-10-epoch snapshots are also written. After training,
    the final model is evaluated on ``test_loader`` and curves are plotted.

    Args:
        model: Network to train; may be wrapped by ``torch.compile``.
        train_loader: Training batches.
        val_loader: Validation batches (used for model selection).
        test_loader: Test batches (evaluated once after training).
        device: Device the model lives on (``torch.device`` or string).
        num_epochs: Total epoch count, including epochs already completed in
            the resumed checkpoint.
        resume_from: Optional checkpoint path; ignored if it does not exist.

    Returns:
        ``(best_val_acc, train_losses, val_accuracies)``.
    """
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    device = torch.device(device)
    use_amp = device.type == 'cuda'
    # Save/load through the unwrapped module so weight files have plain keys
    # (no ``_orig_mod.`` prefix) and load into an uncompiled ResNet50.
    base_model = _unwrap_model(model)

    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)  # Label smoothing for better generalization
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4, nesterov=True)
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)

    # Mixed precision scaler (disabled off-CUDA, where it would only warn)
    scaler = GradScaler(enabled=use_amp)

    best_val_acc = 0.0
    train_losses, val_accuracies, learning_rates = [], [], []
    start_epoch = 0
    total_train_time = 0.0

    # Resume from checkpoint if provided
    if resume_from and os.path.exists(resume_from):
        logger.info(f"Resuming from {resume_from}")
        checkpoint = torch.load(resume_from, map_location=device)
        base_model.load_state_dict(_strip_compile_prefix(checkpoint['model_state_dict']))
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch']
        if 'scheduler_state_dict' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        else:
            # Older checkpoints did not store the scheduler: replay its per-epoch
            # steps so the LR matches the resumed epoch (may emit a harmless
            # "scheduler.step() before optimizer.step()" warning).
            for _ in range(start_epoch):
                scheduler.step()
        # A disabled scaler saves an empty dict, which an enabled scaler refuses to load.
        if use_amp and checkpoint.get('scaler_state_dict'):
            scaler.load_state_dict(checkpoint['scaler_state_dict'])
        best_val_acc = checkpoint.get('best_val_accuracy', 0.0)
        train_losses = checkpoint.get('train_losses', [])
        val_accuracies = checkpoint.get('val_accuracies', [])
        learning_rates = checkpoint.get('learning_rates', [])
        total_train_time = checkpoint.get('total_train_time', 0.0)

    logger.info(f"🚀 Starting training from epoch {start_epoch+1} for {num_epochs} total epochs...")

    for epoch in range(start_epoch, num_epochs):
        epoch_start = time.time()
        epoch_lr = optimizer.param_groups[0]['lr']

        # Training phase
        avg_train_loss = _train_one_epoch(
            model, train_loader, criterion, optimizer, scaler, device, use_amp,
            desc=f'Epoch {epoch+1}/{num_epochs} [Train]',
        )
        train_losses.append(avg_train_loss)
        learning_rates.append(epoch_lr)

        # Validation phase
        avg_val_loss, val_acc = _validate(
            model, val_loader, criterion, device, use_amp,
            desc=f'Epoch {epoch+1}/{num_epochs} [Val]',
        )
        val_accuracies.append(val_acc)

        # Track best model
        is_best = val_acc > best_val_acc
        if is_best:
            best_val_acc = val_acc

        # Step BEFORE checkpointing so the saved optimizer/scheduler state is
        # the one the next epoch (checkpoint['epoch']) should start from.
        scheduler.step()

        epoch_time = time.time() - epoch_start
        total_train_time += epoch_time

        checkpoint = {
            'epoch': epoch + 1,
            'model_state_dict': base_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'scaler_state_dict': scaler.state_dict(),
            'best_val_accuracy': best_val_acc,
            'current_val_accuracy': val_acc,
            'train_losses': train_losses,
            'val_accuracies': val_accuracies,
            'learning_rates': learning_rates,
            'total_train_time': total_train_time,
        }
        torch.save(checkpoint, LAST_CHECKPOINT)
        if is_best:
            torch.save(checkpoint, BEST_CHECKPOINT)
            # Save just the weights for easier loading
            torch.save(base_model.state_dict(), './outputs/food101_resnet50_best_weights.pth')
        if (epoch + 1) % 10 == 0:
            torch.save(checkpoint, f'./outputs/food101_resnet50_epoch_{epoch+1}.pth')

        logger.info(f"Epoch {epoch+1:3d}/{num_epochs} | "
                    f"Train Loss: {avg_train_loss:.4f} | "
                    f"Val Loss: {avg_val_loss:.4f} | "
                    f"Val Acc: {val_acc:.2f}% | "
                    f"Best: {best_val_acc:.2f}% | "
                    f"LR: {epoch_lr:.6f} | "
                    f"Time: {epoch_time:.1f}s")

    # Save final model
    final_checkpoint = {
        'epoch': num_epochs,
        'model_state_dict': base_model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'scaler_state_dict': scaler.state_dict(),
        'final_val_accuracy': val_accuracies[-1] if val_accuracies else None,
        'best_val_accuracy': best_val_acc,
        'train_losses': train_losses,
        'val_accuracies': val_accuracies,
        'learning_rates': learning_rates,
        'total_train_time': total_train_time,
    }

    torch.save(final_checkpoint, './outputs/food101_resnet50_final.pth')
    torch.save(base_model.state_dict(), './outputs/food101_resnet50_final_weights.pth')

    logger.info(f"📊 Total training time: {total_train_time/3600:.2f} hours")

    # Test final accuracy
    test_acc = evaluate_model(model, test_loader, device, "Test")
    logger.info(f"🎯 Final Test Accuracy: {test_acc:.2f}%")

    # Save comprehensive plots
    plot_training_curves(train_losses, val_accuracies, learning_rates)

    return best_val_acc, train_losses, val_accuracies


def evaluate_model(model, test_loader, device, dataset_name="Test"):
    """Return the top-1 accuracy (percent) of ``model`` on ``test_loader``.

    Uses autocast only on CUDA devices. Returns 0.0 for an empty loader.
    """
    device = torch.device(device)
    use_amp = device.type == 'cuda'
    model.eval()
    correct = 0
    total = 0

    test_pbar = tqdm(test_loader, desc=f'{dataset_name} Evaluation', leave=False)
    with torch.no_grad():
        for images, labels in test_pbar:
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            with torch.autocast(device_type='cuda', enabled=use_amp):
                outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            test_pbar.set_postfix({'acc': f'{100.*correct/total:.2f}%'})

    return 100. * correct / total if total else 0.0


def plot_training_curves(train_losses, val_accuracies, learning_rates):
    """Save training-loss / val-accuracy / LR plots under ``./outputs``.

    Each series is plotted against its own 1-based epoch index, so histories of
    different lengths (e.g. from an older checkpoint lacking LR history) do not
    crash. Does nothing if there is no history yet.
    """
    if not train_losses or not val_accuracies:
        logger.warning("No training history to plot; skipping visualizations")
        return
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    loss_epochs = np.arange(1, len(train_losses) + 1)
    acc_epochs = np.arange(1, len(val_accuracies) + 1)
    lr_epochs = np.arange(1, len(learning_rates) + 1)
    best_acc = max(val_accuracies)

    plt.style.use('default')
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    fig.suptitle('Food101 ResNet50 Training Analysis', fontsize=16, fontweight='bold')

    # Training Loss
    axes[0, 0].plot(loss_epochs, train_losses, 'b-', linewidth=2, alpha=0.8)
    axes[0, 0].set_title('Training Loss Over Time', fontweight='bold')
    axes[0, 0].set_xlabel('Epoch')
    axes[0, 0].set_ylabel('Loss')
    axes[0, 0].grid(True, alpha=0.3)
    axes[0, 0].set_yscale('log')

    # Validation Accuracy
    axes[0, 1].plot(acc_epochs, val_accuracies, 'r-', linewidth=2, alpha=0.8)
    axes[0, 1].set_title('Validation Accuracy Over Time', fontweight='bold')
    axes[0, 1].set_xlabel('Epoch')
    axes[0, 1].set_ylabel('Accuracy (%)')
    axes[0, 1].grid(True, alpha=0.3)
    axes[0, 1].axhline(y=best_acc, color='r', linestyle='--', alpha=0.7,
                       label=f'Best: {best_acc:.2f}%')
    axes[0, 1].legend()

    # Learning Rate Schedule
    axes[1, 0].plot(lr_epochs, learning_rates, 'g-', linewidth=2, alpha=0.8)
    axes[1, 0].set_title('Learning Rate Schedule', fontweight='bold')
    axes[1, 0].set_xlabel('Epoch')
    axes[1, 0].set_ylabel('Learning Rate')
    axes[1, 0].grid(True, alpha=0.3)
    if learning_rates:
        axes[1, 0].set_yscale('log')

    # Combined view
    ax_combined = axes[1, 1]
    ax_combined.plot(loss_epochs, train_losses, 'b-', label='Train Loss', linewidth=2, alpha=0.8)
    ax_combined.set_xlabel('Epoch')
    ax_combined.set_ylabel('Loss', color='b')
    ax_combined.tick_params(axis='y', labelcolor='b')
    ax_combined.set_yscale('log')

    ax2 = ax_combined.twinx()
    ax2.plot(acc_epochs, val_accuracies, 'r-', label='Val Accuracy', linewidth=2, alpha=0.8)
    ax2.set_ylabel('Accuracy (%)', color='r')
    ax2.tick_params(axis='y', labelcolor='r')

    ax_combined.set_title('Loss vs Accuracy', fontweight='bold')
    ax_combined.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('./outputs/training_analysis.png', dpi=300, bbox_inches='tight')
    plt.close()

    # Additional detailed accuracy plot
    plt.figure(figsize=(12, 6))
    plt.plot(acc_epochs, val_accuracies, 'r-', linewidth=2, alpha=0.8)
    plt.fill_between(acc_epochs, val_accuracies, alpha=0.3)
    plt.title('Validation Accuracy Progress', fontsize=14, fontweight='bold')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.grid(True, alpha=0.3)
    plt.axhline(y=best_acc, color='r', linestyle='--', alpha=0.7,
                label=f'Peak Accuracy: {best_acc:.2f}%')
    plt.legend()
    plt.tight_layout()
    plt.savefig('./outputs/accuracy_detail.png', dpi=300, bbox_inches='tight')
    plt.close()

    logger.info("📊 Saved enhanced training visualizations")


def save_classes(classes):
    """Write the class names (sorted) to a numbered, prettified file and a plain one-per-line file."""
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    ordered = sorted(classes)

    with open('./outputs/food101_classes.txt', 'w', encoding='utf-8') as f:
        f.write(f"Food101 Classes ({len(ordered)} total)\n")
        f.write("=" * 30 + "\n\n")
        for i, cls in enumerate(ordered, 1):
            f.write(f"{i:3d}. {cls.replace('_', ' ').title()}\n")

    # Also save as a simple list for easy loading
    with open('./outputs/food101_classes_simple.txt', 'w', encoding='utf-8') as f:
        for cls in ordered:
            f.write(f"{cls}\n")

    logger.info("📝 Saved class names to ./outputs/")


def print_system_info():
    """Log PyTorch/CUDA versions, the first GPU's name and memory, and the CPU core count."""
    logger.info("🖥️ System Information:")
    logger.info(f"PyTorch version: {torch.__version__}")
    logger.info(f"CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        logger.info(f"CUDA version: {torch.version.cuda}")
        logger.info(f"GPU: {torch.cuda.get_device_name()}")
        logger.info(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    logger.info(f"Number of CPU cores: {os.cpu_count()}")


# -------------------------
# MAIN
# -------------------------
def main():
    """Load data, build ResNet50, train (resuming if a checkpoint exists), and report outputs."""
    print_system_info()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f"Using device: {device}")

    try:
        # Load data with optimized settings
        logger.info("📥 Loading Food101 dataset...")
        train_loader, val_loader, test_loader, classes = get_food101_loaders(batch_size=64, num_workers=8)
        save_classes(classes)

        # Model
        logger.info("🏗️ Building ResNet50...")
        model = ResNet50(num_classes=len(classes)).to(device)

        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        logger.info(f"Total parameters: {total_params/1e6:.1f}M")
        logger.info(f"Trainable parameters: {trainable_params/1e6:.1f}M")

        # Enable compilation for PyTorch 2.0+
        if hasattr(torch, 'compile'):
            logger.info("🚀 Compiling model for faster training...")
            model = torch.compile(model)

        # Prefer the per-epoch "last" checkpoint (exact resume); fall back to
        # the best checkpoint written by older runs.
        resume_from = next(
            (path for path in (LAST_CHECKPOINT, BEST_CHECKPOINT) if os.path.exists(path)),
            None,
        )

        # Train
        best_val_acc, losses, accuracies = train_model(
            model, train_loader, val_loader, test_loader, device,
            num_epochs=100,
            resume_from=resume_from,
        )

        logger.info(f"\n🎉 TRAINING COMPLETE!")
        logger.info(f"🏆 Best Validation Accuracy: {best_val_acc:.2f}%")
        logger.info(f"\n📁 SAVED FILES:")
        logger.info(f"  • ./outputs/food101_resnet50_best.pth (best checkpoint)")
        logger.info(f"  • ./outputs/food101_resnet50_best_weights.pth (best weights only)")
        logger.info(f"  • ./outputs/food101_resnet50_last.pth (latest checkpoint, for resuming)")
        logger.info(f"  • ./outputs/food101_resnet50_final.pth (final checkpoint)")
        logger.info(f"  • ./outputs/food101_resnet50_final_weights.pth (final weights only)")
        logger.info(f"  • ./outputs/training_analysis.png (comprehensive plots)")
        logger.info(f"  • ./outputs/accuracy_detail.png (detailed accuracy)")
        logger.info(f"  • ./outputs/food101_classes.txt (formatted class list)")
        logger.info(f"  • ./outputs/food101_classes_simple.txt (simple class list)")

    except Exception as e:
        logger.error(f"❌ Training failed with error: {e}")
        raise


if __name__ == "__main__":
    main()