import argparse
import copy
import os

import torch  # type: ignore
import torch.nn as nn  # type: ignore
from torchinfo import summary  # type: ignore
import torchvision  # type: ignore
import torchvision.transforms as T  # type: ignore
from torchvision.models import efficientnet_v2_s, EfficientNet_V2_S_Weights  # type: ignore
from torch.utils.data import Subset, DataLoader  # type: ignore
import wandb  # type: ignore

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Load the pre-trained EfficientNetV2-S model and replace the final
# classifier layer with a fresh 101-way head for Food-101
model_weights = EfficientNet_V2_S_Weights.IMAGENET1K_V1
model = efficientnet_v2_s(weights=model_weights).to(device)
model.classifier[1] = nn.Linear(in_features=model.classifier[1].in_features, out_features=101).to(device)

# Load the dataset
dataset = torchvision.datasets.Food101(root='./data', split='train', download=True)

# Split the dataset into training and testing sets
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])

# Transform pipelines: augmentation for training, the weights' canonical
# preprocessing (resize, crop, normalize) for evaluation
train_transforms = T.Compose([
    T.RandomResizedCrop(224),
    T.RandomHorizontalFlip(),
    T.ColorJitter(0.2, 0.2, 0.2, 0.1),
    model_weights.transforms()
])
test_transforms = T.Compose([
    model_weights.transforms()
])

# Apply transforms to the datasets.
# NOTE: random_split returns Subsets that share one underlying dataset, so
# assigning .dataset.transform on both splits would leave only the last
# assignment in effect. Give each split its own shallow copy instead.
train_dataset.dataset = copy.copy(dataset)
train_dataset.dataset.transform = train_transforms
test_dataset.dataset = copy.copy(dataset)
test_dataset.dataset.transform = test_transforms

# Create DataLoaders for the training and testing sets
train_loader = DataLoader(
    train_dataset,
    batch_size=16,
    shuffle=True,
    num_workers=2,
    persistent_workers=True
)
test_loader = DataLoader(
    test_dataset,
    batch_size=16,
    shuffle=False,
    num_workers=2,
    persistent_workers=True
)


# Checkpoint helper
def save_checkpoint(epoch, model, optimizer, val_loss, path="checkpoints/best_model.pth"):
    # Make sure the checkpoint directory exists before saving
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'val_loss': val_loss
    }, path)
    print(f"Checkpoint saved at epoch {epoch} to {path}")


class CheckpointCallback:
    """Saves a checkpoint whenever the validation loss improves."""

    def __init__(self, path="checkpoints/best_model.pth"):
        self.best_loss = float('inf')
        self.path = path

    def __call__(self, epoch, model, optimizer, val_loss):
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            save_checkpoint(epoch, model, optimizer, val_loss, self.path)
            return True
        return False


# Early stopping callback
class EarlyStopping:
    """Flags early stop once validation loss has failed to improve by at
    least min_delta for `patience` consecutive epochs."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = float('inf')
        self.early_stop = False

    def __call__(self, val_loss):
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True


# Training function
def train_model(run, model, train_loader, val_loader, loss_fn, optimizer, device,
                epochs=5, checkpoint=None, early_stopping=None):
    global_step = 0
    model.to(device)
    for epoch in range(epochs):
        train_loss = 0.0
        train_accuracy = 0.0
        model.train()
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            y_preds = model(images)
            loss = loss_fn(y_preds, labels)
            loss.backward()
            optimizer.step()
            # Log the per-batch loss on every step
            run.log({"train/loss": loss.item()}, step=global_step)
            global_step += 1
            train_loss += loss.item() * labels.size(0)
            train_accuracy += (y_preds.argmax(dim=1) == labels).sum().item()
        train_loss /= len(train_loader.dataset)
        train_accuracy /= len(train_loader.dataset)
        print(f"Epoch [{epoch + 1}/{epochs}], Loss: {train_loss:.4f} | Accuracy: {train_accuracy:.4f}")

        # Validation phase
        model.eval()
        val_loss = 0.0
        val_accuracy = 0.0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                y_preds = model(images)
                loss = loss_fn(y_preds, labels)
                val_loss += loss.item() * images.size(0)
                val_accuracy += (y_preds.argmax(dim=1) == labels).sum().item()
        val_loss /= len(val_loader.dataset)
        val_accuracy /= len(val_loader.dataset)
        print(f"Validation Loss: {val_loss:.4f} | Validation Accuracy: {val_accuracy:.4f}")
        run.log({
            "val/loss": val_loss,
            "val/accuracy": val_accuracy,
            "train/accuracy": train_accuracy,
            "epoch": epoch + 1,
        }, step=global_step)

        # Callbacks
        if checkpoint:
            checkpoint(epoch, model, optimizer, val_loss)
        if early_stopping:
            early_stopping(val_loss)
            if early_stopping.early_stop:
                print("Early stopping triggered")
                break
    run.finish()


# Evaluation function
def evaluate_model(model, test_loader, loss_fn, device):
    model.eval()
    test_loss = 0.0
    test_accuracy = 0.0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            y_preds = model(images)
            loss = loss_fn(y_preds, labels)
            test_loss += loss.item() * images.size(0)
            test_accuracy += (y_preds.argmax(dim=1) == labels).sum().item()
    test_loss /= len(test_loader.dataset)
    test_accuracy /= len(test_loader.dataset)
    print(f"Test Loss: {test_loss:.4f} | Test Accuracy: {test_accuracy:.4f}")


# Initialization for wandb
def initialize_wandb(project_name, run_name, config):
    run = wandb.init(
        entity="i24106-code-i",
        project=project_name,
        name=run_name,
        config=config
    )
    return run


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train EfficientNetV2-S on the Food-101 dataset")
    parser.add_argument("--epochs", type=int, default=5, help="Number of training epochs")
    parser.add_argument("--learning_rate", type=float, default=0.001, help="Learning rate for the optimizer")
    parser.add_argument("--model_path", type=str, default="checkpoints/best_model.pth",
                        help="Path to save the best model checkpoint")
    parser.add_argument("--log_run_name", type=str, default="EfficientNetV2S_Run", help="WandB run name")
    args = parser.parse_args()

    # Resume from the saved checkpoint
    saved_model = torch.load(args.model_path, map_location=device)
    model.load_state_dict(saved_model['model_state_dict'])
    model.to(device)

    # Freeze all feature-extractor layers
    for p in model.features.parameters():
        p.requires_grad = False
    # Unfreeze the last 2 blocks (tune N = 1, 2, 3)
    for p in model.features[-2:].parameters():
        p.requires_grad = True

    # Define the loss function and an optimizer with discriminative learning
    # rates: a small lr for the unfrozen backbone blocks, a larger one for
    # the freshly initialized classifier head
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam([
        {"params": model.features[-2:].parameters(), "lr": 1e-5},
        {"params": model.classifier.parameters(), "lr": 1e-4},
    ], weight_decay=1e-4)

    # Create checkpoint and early stopping callbacks
    checkpoint = CheckpointCallback(path=args.model_path)
    early_stopping = EarlyStopping(patience=3, min_delta=0.01)

    # Validation loader: a random 10% subset of the test split
    indices = torch.randperm(len(test_dataset))[:int(0.1 * len(test_dataset))]
    val_set = Subset(test_dataset, indices)
    val_loader = DataLoader(val_set, batch_size=32, shuffle=False)

    # # Initialize wandb
    # config = {
    #     "epochs": args.epochs,
    #     "learning_rate": args.learning_rate,
    #     "model": "EfficientNetV2-S",
    #     "dataset": "Food-101"
    # }
    # run = initialize_wandb("Food101_Classification", args.log_run_name, config)

    # # Train the model
    # train_model(run, model, val_loader, val_loader, loss_fn, optimizer, device,
    #             epochs=args.epochs, checkpoint=checkpoint, early_stopping=early_stopping)

    # # Evaluate the model
    evaluate_model(model, test_loader, loss_fn, device)
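
# Example invocation (a sketch: the filename train_food101.py is an assumption,
# not taken from the source, and a checkpoint must already exist at the
# --model_path location, since the script loads it before fine-tuning):
#
#   python train_food101.py --epochs 5 --learning_rate 0.001 \
#       --model_path checkpoints/best_model.pth --log_run_name EfficientNetV2S_Run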