# Hosting-page residue (not code): chandu1617's picture — "Upload 10 files" — commit 7a59d7b (verified)
#!/usr/bin/env python3
"""
βœ… OPTIMIZED Food101 + ResNet50 with major speed improvements
βœ… Mixed precision training (2x faster)
βœ… Better data loading (persistent workers)
βœ… Progress bars and better logging
βœ… Robust error handling and checkpointing
"""
import os
import time
import copy
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import logging
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast, GradScaler
# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# -------------------------
# OPTIMIZED Data Loaders
# -------------------------
def get_food101_loaders(batch_size=64, num_workers=8):
    """Build optimized Food101 train/val/test DataLoaders.

    Args:
        batch_size: samples per batch for all three loaders.
        num_workers: worker processes per DataLoader.

    Returns:
        (train_loader, val_loader, test_loader, class_names) where
        class_names is the list of 101 Food101 category names.

    Raises:
        Exception: re-raised after logging if the dataset cannot be loaded.
    """
    # Aggressive train-time augmentation: resize larger first, then random-crop
    # to 224x224 so the crop does not distort the aspect ratio.
    transform_train = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.RandomCrop((224, 224)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    # Deterministic eval-time preprocessing (ImageNet normalization, no augmentation).
    transform_test = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    try:
        # Full train split (75k images) with training augmentations.
        full_train = torchvision.datasets.Food101(
            root='./data', split='train', download=True, transform=transform_train
        )
        # Second view of the SAME split with eval transforms. BUGFIX: the
        # original applied transform_train to the validation subset, so val
        # accuracy was measured on randomly augmented images.
        full_train_eval = torchvision.datasets.Food101(
            root='./data', split='train', download=False, transform=transform_test
        )
        # 90/10 train/val split, reproducible via a LOCAL generator so the
        # global RNG state is not clobbered as a side effect (the original
        # called torch.manual_seed(42) here).
        train_size = int(0.9 * len(full_train))
        split_gen = torch.Generator().manual_seed(42)
        indices = torch.randperm(len(full_train), generator=split_gen).tolist()
        train_dataset = torch.utils.data.Subset(full_train, indices[:train_size])
        val_dataset = torch.utils.data.Subset(full_train_eval, indices[train_size:])
        # Test split (25k images).
        test_dataset = torchvision.datasets.Food101(
            root='./data', split='test', download=True, transform=transform_test
        )
        logger.info(f"Dataset sizes - Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")
        # persistent_workers avoids re-forking workers every epoch; pin_memory
        # speeds up host->GPU copies; drop_last keeps BatchNorm statistics
        # stable by skipping a possibly tiny final training batch.
        train_loader = DataLoader(
            train_dataset, batch_size, shuffle=True, num_workers=num_workers,
            pin_memory=True, persistent_workers=True, drop_last=True
        )
        val_loader = DataLoader(
            val_dataset, batch_size, shuffle=False, num_workers=num_workers,
            pin_memory=True, persistent_workers=True
        )
        test_loader = DataLoader(
            test_dataset, batch_size, shuffle=False, num_workers=num_workers,
            pin_memory=True, persistent_workers=True
        )
        return train_loader, val_loader, test_loader, full_train.classes
    except Exception as e:
        logger.error(f"Error loading data: {e}")
        raise
# -------------------------
# ResNet Building Blocks (same as original but with better initialization)
# -------------------------
class BasicBlock(nn.Module):
    """Two-conv residual block (ResNet-18/34 style); output channels == planes."""
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        # Module creation order is kept fixed so weight init is reproducible.
        self.conv1 = nn.Conv2d(inplanes, planes, 3, stride, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, 3, 1, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        # Residual branch: conv-bn-relu-conv-bn.
        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))
        # Project the shortcut only when a downsample module was supplied.
        shortcut = self.downsample(x) if self.downsample else x
        return self.relu(branch + shortcut)
class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual bottleneck; output channels == planes * 4."""
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        # Module creation order is kept fixed so weight init is reproducible.
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, 3, stride, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        # Reduce -> spatial conv (optionally strided) -> expand.
        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.relu(self.bn2(self.conv2(branch)))
        branch = self.bn3(self.conv3(branch))
        # Project the shortcut only when a downsample module was supplied.
        shortcut = self.downsample(x) if self.downsample else x
        return self.relu(branch + shortcut)
class ResNet50(nn.Module):
    """From-scratch ResNet-50 classifier: Bottleneck stages [3, 4, 6, 3]."""

    def __init__(self, num_classes=101):
        super().__init__()
        self.inplanes = 64
        # Stem: 7x7 stride-2 conv followed by 3x3 stride-2 max-pool.
        self.conv1 = nn.Conv2d(3, 64, 7, 2, 3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(3, 2, 1)
        # Four bottleneck stages; stages 2-4 halve spatial resolution.
        self.layer1 = self._make_layer(Bottleneck, 64, 3)
        self.layer2 = self._make_layer(Bottleneck, 128, 4, 2)
        self.layer3 = self._make_layer(Bottleneck, 256, 6, 2)
        self.layer4 = self._make_layer(Bottleneck, 512, 3, 2)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(512 * Bottleneck.expansion, num_classes)
        self._initialize_weights()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack `blocks` residual blocks; only the first may downsample."""
        out_channels = planes * block.expansion
        downsample = None
        if stride != 1 or self.inplanes != out_channels:
            # 1x1 projection so the shortcut matches the branch shape.
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_channels, 1, stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )
        stage = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = out_channels
        stage.extend(block(self.inplanes, planes) for _ in range(blocks - 1))
        return nn.Sequential(*stage)

    def _initialize_weights(self):
        """He-normal init for convs; unit-gain affine for batch norms."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

    def forward(self, x):
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        x = self.avgpool(x)
        return self.fc(torch.flatten(x, 1))
# -------------------------
# OPTIMIZED Training Function with Mixed Precision
# -------------------------
def train_model(model, train_loader, val_loader, test_loader, device, num_epochs=100, resume_from=None):
    """Train `model` with SGD + cosine warm restarts and AMP mixed precision.

    Args:
        model: network to optimize (already moved to `device` by the caller).
        train_loader / val_loader / test_loader: Food101 DataLoaders.
        device: torch.device used for all tensors.
        num_epochs: TOTAL epochs, including any already completed when resuming.
        resume_from: optional checkpoint path to resume training from.

    Returns:
        (best_val_acc, train_losses, val_accuracies)

    Side effects:
        Writes checkpoints, weight files and plots under ./outputs/.
    """
    os.makedirs('./outputs', exist_ok=True)
    # Label smoothing regularizes the 101-way classification head.
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4, nesterov=True)
    # Restart every 10 epochs, doubling the period each cycle (10, 20, 40, ...).
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)
    # Mixed precision loss scaler (prevents fp16 gradient underflow).
    scaler = GradScaler()
    best_val_acc = 0.0
    train_losses, val_accuracies, learning_rates = [], [], []
    start_epoch = 0
    # Resume from checkpoint if provided.
    if resume_from and os.path.exists(resume_from):
        logger.info(f"Resuming from {resume_from}")
        checkpoint = torch.load(resume_from, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        # BUGFIX: the original saved scaler (and no scheduler) state but never
        # restored it, so the AMP loss scale and the warm-restart LR schedule
        # silently reset on every resume. `.get` keeps old checkpoints loadable.
        if checkpoint.get('scaler_state_dict') is not None:
            scaler.load_state_dict(checkpoint['scaler_state_dict'])
        if checkpoint.get('scheduler_state_dict') is not None:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        start_epoch = checkpoint['epoch']  # stored as last completed epoch + 1... saved as epoch+1
        best_val_acc = checkpoint.get('best_val_accuracy', 0.0)
        train_losses = checkpoint.get('train_losses', [])
        val_accuracies = checkpoint.get('val_accuracies', [])
        learning_rates = checkpoint.get('learning_rates', [])
    logger.info(f"πŸš€ Starting training from epoch {start_epoch+1} for {num_epochs} total epochs...")
    # Track timing
    total_train_time = 0
    for epoch in range(start_epoch, num_epochs):
        epoch_start = time.time()
        # ---- Training phase ----
        model.train()
        running_loss = 0.0
        train_pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Train]', leave=False)
        for batch_idx, (images, labels) in enumerate(train_pbar):
            images, labels = images.to(device, non_blocking=True), labels.to(device, non_blocking=True)
            optimizer.zero_grad()
            # Mixed precision forward pass.
            with autocast():
                outputs = model(images)
                loss = criterion(outputs, labels)
            # Scaled backward + optimizer step (scaler skips steps on inf/nan grads).
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            running_loss += loss.item()
            train_pbar.set_postfix({'loss': f'{loss.item():.4f}', 'lr': f'{optimizer.param_groups[0]["lr"]:.6f}'})
        avg_train_loss = running_loss / len(train_loader)
        train_losses.append(avg_train_loss)
        learning_rates.append(optimizer.param_groups[0]['lr'])
        # ---- Validation phase ----
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        val_pbar = tqdm(val_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Val]', leave=False)
        with torch.no_grad():
            for images, labels in val_pbar:
                images, labels = images.to(device, non_blocking=True), labels.to(device, non_blocking=True)
                with autocast():
                    outputs = model(images)
                    loss = criterion(outputs, labels)
                val_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
                val_pbar.set_postfix({'acc': f'{100.*correct/total:.2f}%'})
        val_acc = 100. * correct / total
        val_accuracies.append(val_acc)
        avg_val_loss = val_loss / len(val_loader)
        # ---- Checkpointing ----
        is_best = val_acc > best_val_acc
        if is_best:
            best_val_acc = val_acc
        # Save every 10 epochs, on improvement, and at the final epoch.
        if (epoch + 1) % 10 == 0 or is_best or epoch == num_epochs - 1:
            checkpoint = {
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scaler_state_dict': scaler.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'best_val_accuracy': best_val_acc,
                'current_val_accuracy': val_acc,
                'train_losses': train_losses,
                'val_accuracies': val_accuracies,
                'learning_rates': learning_rates,
            }
            if is_best:
                torch.save(checkpoint, './outputs/food101_resnet50_best.pth')
                # Also save just the weights for easier loading.
                torch.save(model.state_dict(), './outputs/food101_resnet50_best_weights.pth')
            if (epoch + 1) % 10 == 0:
                torch.save(checkpoint, f'./outputs/food101_resnet50_epoch_{epoch+1}.pth')
        # Step once per epoch (warm-restart schedule advances by whole epochs).
        scheduler.step()
        epoch_time = time.time() - epoch_start
        total_train_time += epoch_time
        logger.info(f"Epoch {epoch+1:3d}/{num_epochs} | "
                    f"Train Loss: {avg_train_loss:.4f} | "
                    f"Val Loss: {avg_val_loss:.4f} | "
                    f"Val Acc: {val_acc:.2f}% | "
                    f"Best: {best_val_acc:.2f}% | "
                    f"LR: {optimizer.param_groups[0]['lr']:.6f} | "
                    f"Time: {epoch_time:.1f}s")
    # ---- Final artifacts ----
    final_checkpoint = {
        'epoch': num_epochs,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scaler_state_dict': scaler.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        # ROBUSTNESS: val_accuracies can be empty when resuming at num_epochs
        # from a weights-only checkpoint; the original indexed [-1] and crashed.
        'final_val_accuracy': val_accuracies[-1] if val_accuracies else 0.0,
        'best_val_accuracy': best_val_acc,
        'train_losses': train_losses,
        'val_accuracies': val_accuracies,
        'learning_rates': learning_rates,
        'total_train_time': total_train_time,
    }
    torch.save(final_checkpoint, './outputs/food101_resnet50_final.pth')
    torch.save(model.state_dict(), './outputs/food101_resnet50_final_weights.pth')
    logger.info(f"πŸ“Š Total training time: {total_train_time/3600:.2f} hours")
    # Final test accuracy on the held-out split.
    test_acc = evaluate_model(model, test_loader, device, "Test")
    logger.info(f"🎯 Final Test Accuracy: {test_acc:.2f}%")
    # Save comprehensive plots.
    plot_training_curves(train_losses, val_accuracies, learning_rates)
    return best_val_acc, train_losses, val_accuracies
def evaluate_model(model, test_loader, device, dataset_name="Test"):
    """Return the top-1 accuracy (%) of `model` over `test_loader`."""
    model.eval()
    n_correct, n_seen = 0, 0
    progress = tqdm(test_loader, desc=f'{dataset_name} Evaluation', leave=False)
    with torch.no_grad():
        for images, labels in progress:
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            # AMP inference matches the training-time numeric path.
            with autocast():
                logits = model(images)
            predictions = logits.argmax(dim=1)
            n_seen += labels.size(0)
            n_correct += (predictions == labels).sum().item()
            progress.set_postfix({'acc': f'{100.*n_correct/n_seen:.2f}%'})
    return 100. * n_correct / n_seen
def plot_training_curves(train_losses, val_accuracies, learning_rates):
    """Render and save training-curve figures under ./outputs/.

    Produces training_analysis.png (2x2 dashboard: loss, val accuracy,
    LR schedule, loss-vs-accuracy twin axes) and accuracy_detail.png
    (standalone validation-accuracy plot).
    """
    epoch_axis = np.arange(1, len(train_losses) + 1)
    best_acc = max(val_accuracies)
    plt.style.use('default')
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    fig.suptitle('Food101 ResNet50 Training Analysis', fontsize=16, fontweight='bold')
    # Top-left: training loss on a log scale.
    loss_ax = axes[0, 0]
    loss_ax.plot(epoch_axis, train_losses, 'b-', linewidth=2, alpha=0.8)
    loss_ax.set_title('Training Loss Over Time', fontweight='bold')
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')
    loss_ax.grid(True, alpha=0.3)
    loss_ax.set_yscale('log')
    # Top-right: validation accuracy with a best-value marker line.
    acc_ax = axes[0, 1]
    acc_ax.plot(epoch_axis, val_accuracies, 'r-', linewidth=2, alpha=0.8)
    acc_ax.set_title('Validation Accuracy Over Time', fontweight='bold')
    acc_ax.set_xlabel('Epoch')
    acc_ax.set_ylabel('Accuracy (%)')
    acc_ax.grid(True, alpha=0.3)
    acc_ax.axhline(y=best_acc, color='r', linestyle='--', alpha=0.7,
                   label=f'Best: {best_acc:.2f}%')
    acc_ax.legend()
    # Bottom-left: LR schedule; log scale makes the warm restarts visible.
    lr_ax = axes[1, 0]
    lr_ax.plot(epoch_axis, learning_rates, 'g-', linewidth=2, alpha=0.8)
    lr_ax.set_title('Learning Rate Schedule', fontweight='bold')
    lr_ax.set_xlabel('Epoch')
    lr_ax.set_ylabel('Learning Rate')
    lr_ax.grid(True, alpha=0.3)
    lr_ax.set_yscale('log')
    # Bottom-right: loss and accuracy overlaid on twin y-axes.
    combo_ax = axes[1, 1]
    combo_ax.plot(epoch_axis, train_losses, 'b-', label='Train Loss', linewidth=2, alpha=0.8)
    combo_ax.set_xlabel('Epoch')
    combo_ax.set_ylabel('Loss', color='b')
    combo_ax.tick_params(axis='y', labelcolor='b')
    combo_ax.set_yscale('log')
    twin_ax = combo_ax.twinx()
    twin_ax.plot(epoch_axis, val_accuracies, 'r-', label='Val Accuracy', linewidth=2, alpha=0.8)
    twin_ax.set_ylabel('Accuracy (%)', color='r')
    twin_ax.tick_params(axis='y', labelcolor='r')
    combo_ax.set_title('Loss vs Accuracy', fontweight='bold')
    combo_ax.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig('./outputs/training_analysis.png', dpi=300, bbox_inches='tight')
    plt.close()
    # Standalone, larger validation-accuracy figure.
    plt.figure(figsize=(12, 6))
    plt.plot(epoch_axis, val_accuracies, 'r-', linewidth=2, alpha=0.8)
    plt.fill_between(epoch_axis, val_accuracies, alpha=0.3)
    plt.title('Validation Accuracy Progress', fontsize=14, fontweight='bold')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.grid(True, alpha=0.3)
    plt.axhline(y=best_acc, color='r', linestyle='--', alpha=0.7,
                label=f'Peak Accuracy: {best_acc:.2f}%')
    plt.legend()
    plt.tight_layout()
    plt.savefig('./outputs/accuracy_detail.png', dpi=300, bbox_inches='tight')
    plt.close()
    logger.info("πŸ“Š Saved enhanced training visualizations")
def save_classes(classes):
    """Write the Food101 class names to ./outputs/ in two formats.

    food101_classes.txt is a numbered, human-readable list; the
    "_simple" variant has one raw class name per line for programmatic use.
    """
    os.makedirs('./outputs', exist_ok=True)
    ordered = sorted(classes)
    with open('./outputs/food101_classes.txt', 'w') as fh:
        fh.write("Food101 Classes (101 total)\n")
        fh.write("=" * 30 + "\n\n")
        for index, name in enumerate(ordered, 1):
            fh.write(f"{index:3d}. {name.replace('_', ' ').title()}\n")
    with open('./outputs/food101_classes_simple.txt', 'w') as fh:
        fh.writelines(f"{name}\n" for name in ordered)
    logger.info("πŸ“ Saved class names to ./outputs/")
def print_system_info():
    """Log PyTorch/CUDA/hardware details to aid debugging of training runs."""
    logger.info("πŸ–₯️ System Information:")
    logger.info(f"PyTorch version: {torch.__version__}")
    cuda_ok = torch.cuda.is_available()
    logger.info(f"CUDA available: {cuda_ok}")
    if cuda_ok:
        logger.info(f"CUDA version: {torch.version.cuda}")
        logger.info(f"GPU: {torch.cuda.get_device_name()}")
        logger.info(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    logger.info(f"Number of CPU cores: {os.cpu_count()}")
# -------------------------
# MAIN
# -------------------------
def main():
    """Script entry point: load Food101, build/compile ResNet50, train, log artifacts."""
    print_system_info()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f"Using device: {device}")
    try:
        # Load data with optimized settings
        logger.info("πŸ“₯ Loading Food101 dataset...")
        train_loader, val_loader, test_loader, classes = get_food101_loaders(batch_size=64, num_workers=8)
        save_classes(classes)
        # Model
        logger.info("πŸ—οΈ Building ResNet50...")
        model = ResNet50(num_classes=101).to(device)
        # Total and trainable counts are equal here since no layers are frozen.
        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        logger.info(f"Total parameters: {total_params/1e6:.1f}M")
        logger.info(f"Trainable parameters: {trainable_params/1e6:.1f}M")
        # Enable compilation for PyTorch 2.0+
        # NOTE(review): torch.compile wraps the module, so its state_dict keys
        # carry an `_orig_mod.` prefix; resuming below presumably works only
        # because the checkpoint was also written from a compiled model —
        # confirm before loading these weights into an uncompiled model.
        if hasattr(torch, 'compile'):
            logger.info("πŸš€ Compiling model for faster training...")
            model = torch.compile(model)
        # Train; auto-resumes from the best checkpoint when one already exists.
        best_val_acc, losses, accuracies = train_model(
            model, train_loader, val_loader, test_loader, device,
            num_epochs=100, resume_from='./outputs/food101_resnet50_best.pth' if os.path.exists('./outputs/food101_resnet50_best.pth') else None
        )
        logger.info(f"\nπŸŽ‰ TRAINING COMPLETE!")
        logger.info(f"πŸ† Best Validation Accuracy: {best_val_acc:.2f}%")
        logger.info(f"\nπŸ“ SAVED FILES:")
        logger.info(f" β€’ ./outputs/food101_resnet50_best.pth (best checkpoint)")
        logger.info(f" β€’ ./outputs/food101_resnet50_best_weights.pth (best weights only)")
        logger.info(f" β€’ ./outputs/food101_resnet50_final.pth (final checkpoint)")
        logger.info(f" β€’ ./outputs/food101_resnet50_final_weights.pth (final weights only)")
        logger.info(f" β€’ ./outputs/training_analysis.png (comprehensive plots)")
        logger.info(f" β€’ ./outputs/accuracy_detail.png (detailed accuracy)")
        logger.info(f" β€’ ./outputs/food101_classes.txt (formatted class list)")
        logger.info(f" β€’ ./outputs/food101_classes_simple.txt (simple class list)")
    except Exception as e:
        # Log-and-reraise so the failure reaches the process exit code.
        logger.error(f"❌ Training failed with error: {e}")
        raise
# Standard script entry point guard.
if __name__ == "__main__":
    main()