|
|
import argparse
import os

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as T
import wandb
from torch.utils.data import Subset, DataLoader
from torchinfo import summary
from torchvision.models import efficientnet_v2_s, EfficientNet_V2_S_Weights
|
|
|
|
|
# Prefer the Apple-Silicon GPU (MPS backend) when available, otherwise run on CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")


# EfficientNetV2-S pretrained on ImageNet-1k; the weights object also carries
# the canonical preprocessing pipeline (used for the transforms further down).
model_weights = EfficientNet_V2_S_Weights.IMAGENET1K_V1

model = efficientnet_v2_s(weights=model_weights).to(device)

# Swap the 1000-class ImageNet head for a fresh 101-class head (Food-101 has
# 101 categories); only this layer starts from random initialization.
model.classifier[1] = nn.Linear(in_features=model.classifier[1].in_features, out_features=101).to(device)
|
|
|
|
|
|
|
|
# Download (first run only) the official Food-101 training split. No transform
# is attached here; per-split preprocessing is wired up further down.
dataset = torchvision.datasets.Food101(root='./data', split='train', download=True)


# 80/20 random split of the training data into our own train/test subsets.
# NOTE(review): the split is unseeded, so it differs between runs — consider
# passing a fixed torch.Generator for reproducibility.
train_size = int(0.8 * len(dataset))

test_size = len(dataset) - train_size

train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])
|
|
|
|
|
|
|
|
|
|
|
# Training-time augmentation (crop/flip/jitter) followed by the normalization
# pipeline the pretrained weights expect.
train_transforms = T.Compose([
    T.RandomResizedCrop(224),
    T.RandomHorizontalFlip(),
    T.ColorJitter(0.2, 0.2, 0.2, 0.1),
    model_weights.transforms()
])

# Evaluation uses only the weights' canonical preprocessing — no augmentation.
test_transforms = T.Compose([
    model_weights.transforms()
])


class _TransformedSubset(torch.utils.data.Dataset):
    """Apply a fixed transform to every sample of a wrapped dataset/subset.

    BUG FIX: the original code did
        train_dataset.dataset.transform = train_transforms
        test_dataset.dataset.transform = test_transforms
    but both Subsets produced by random_split share the SAME underlying
    Food101 instance, so the second assignment overwrote the first and BOTH
    splits silently used ``test_transforms`` (no training augmentation at
    all).  Wrapping each subset with its own transform keeps the two
    pipelines independent.
    """

    def __init__(self, base, transform):
        self.base = base            # underlying split, yields untransformed (PIL image, label)
        self.transform = transform  # per-split preprocessing pipeline

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        image, label = self.base[idx]
        return self.transform(image), label


# Rebind the split names so the DataLoaders below pick up the wrapped datasets.
train_dataset = _TransformedSubset(train_dataset, train_transforms)
test_dataset = _TransformedSubset(test_dataset, test_transforms)
|
|
|
|
|
|
|
|
# DataLoader settings shared by both splits: small batches, two worker
# processes kept alive across epochs to avoid respawn overhead.
_loader_kwargs = dict(batch_size=16, num_workers=2, persistent_workers=True)

# Training batches are reshuffled every epoch; evaluation keeps a fixed order.
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)
|
|
|
|
|
|
|
|
def save_checkpoint(epoch, model, optimizer, val_loss, path="checkpoints/best_model.pth"):
    """Serialize the full training state (model, optimizer, epoch, val loss) to *path*.

    Creates the parent directory when it does not exist yet — the original
    version crashed with FileNotFoundError on a fresh clone where the
    ``checkpoints/`` directory was missing.

    Args:
        epoch: epoch index at which the checkpoint was taken.
        model: network whose ``state_dict`` is saved.
        optimizer: optimizer whose ``state_dict`` is saved (for exact resume).
        val_loss: validation loss at save time, stored for bookkeeping.
        path: destination file for ``torch.save``.
    """
    parent = os.path.dirname(path)
    if parent:  # path may be a bare filename in the current directory
        os.makedirs(parent, exist_ok=True)

    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'val_loss': val_loss
    }, path)

    print(f"Checkpoint saved at epoch {epoch} to {path}")
|
|
|
|
|
class CheckpointCallback:
    """Persist the model whenever validation loss improves on its best so far."""

    def __init__(self, path="checkpoints/best_model.pth"):
        # Start at +inf so the very first epoch always triggers a save.
        self.best_loss = float('inf')
        self.path = path

    def __call__(self, epoch, model, optimizer, val_loss):
        # Guard clause: no improvement -> nothing to persist.
        if val_loss >= self.best_loss:
            return False
        # New best: remember it and write the checkpoint to disk.
        self.best_loss = val_loss
        save_checkpoint(epoch, model, optimizer, val_loss, self.path)
        return True
|
|
|
|
|
|
|
|
class EarlyStopping:
    """Raise a stop flag once validation loss stops improving for `patience` epochs."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience       # epochs tolerated without improvement
        self.min_delta = min_delta     # smallest decrease that counts as improvement
        self.counter = 0               # consecutive non-improving epochs so far
        self.best_loss = float('inf')  # best validation loss observed
        self.early_stop = False        # set True once patience is exhausted

    def __call__(self, val_loss):
        # An epoch only counts as improvement if it beats the best by min_delta.
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            return
        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True
|
|
|
|
|
|
|
|
|
|
|
def train_model(run, model, train_loader, val_loader, loss_fn, optimizer, device, epochs=5, checkpoint=None, early_stopping=None, log_interval=1):
    """Train *model* with per-epoch validation, W&B logging, optional
    checkpointing and optional early stopping.

    Args:
        run: object exposing ``log(dict, step=...)`` and ``finish()`` (a W&B run).
        model: network to optimize; moved onto *device* here.
        train_loader: DataLoader yielding (images, labels) training batches.
        val_loader: DataLoader yielding (images, labels) validation batches.
        loss_fn: criterion, e.g. ``nn.CrossEntropyLoss``.
        optimizer: optimizer updating ``model``'s parameters.
        device: device for the model and all batches.
        epochs: maximum number of epochs to run.
        checkpoint: optional callable ``(epoch, model, optimizer, val_loss)``
            invoked after every epoch (e.g. ``CheckpointCallback``).
        early_stopping: optional ``EarlyStopping``; training breaks out of the
            epoch loop once its ``early_stop`` flag is set.
        log_interval: log the per-batch training loss every N optimizer steps.
            The original code used the no-op condition ``global_step % 1 == 0``
            (i.e. logged every step); the default of 1 preserves that behavior
            while letting callers throttle the logging.

    Side effects: prints per-epoch metrics, logs to ``run``, and calls
    ``run.finish()`` before returning.
    """
    global_step = 0
    model.to(device)

    for epoch in range(epochs):
        # ---- training phase ----
        train_loss = 0.0
        train_accuracy = 0.0
        model.train()

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            y_preds = model(images)
            loss = loss_fn(y_preds, labels)
            loss.backward()
            optimizer.step()

            if global_step % log_interval == 0:
                run.log({
                    "train/loss": loss.item()
                }, step=global_step)

            global_step += 1

            # Weight by batch size so the epoch averages below are per-sample
            # (the last batch may be smaller than the others).
            train_loss += loss.item() * labels.size(0)
            train_accuracy += (y_preds.argmax(dim=1) == labels).sum().item()

        train_loss /= len(train_loader.dataset)
        train_accuracy /= len(train_loader.dataset)
        print(f"Epoch [{epoch + 1}/{epochs}], Loss: {train_loss:.4f} | Accuracy: {train_accuracy:.4f}")

        # ---- validation phase (no gradients) ----
        model.eval()
        val_loss = 0.0
        val_accuracy = 0.0

        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                y_preds = model(images)
                loss = loss_fn(y_preds, labels)
                val_loss += loss.item() * images.size(0)
                val_accuracy += (y_preds.argmax(dim=1) == labels).sum().item()

        val_loss /= len(val_loader.dataset)
        val_accuracy /= len(val_loader.dataset)
        print(f"Validation Loss: {val_loss:.4f} | Validation Accuracy: {val_accuracy:.4f}")

        run.log({
            "val/loss": val_loss,
            "val/accuracy": val_accuracy,
            "train/accuracy": train_accuracy,
            "epoch": epoch + 1,
        }, step=global_step)

        # Let the checkpoint callback decide whether this epoch improved.
        if checkpoint:
            checkpoint(epoch, model, optimizer, val_loss)

        # Stop once validation loss has stagnated for `patience` epochs.
        if early_stopping:
            early_stopping(val_loss)
            if early_stopping.early_stop:
                print("Early stopping triggered")
                break

    run.finish()
|
|
|
|
|
|
|
|
def evaluate_model(model, test_loader, loss_fn, device):
    """Evaluate *model* on *test_loader*; print and return the metrics.

    The original version only printed the metrics and returned ``None``;
    returning them as well is backward compatible and lets callers log,
    store, or compare results programmatically.

    Args:
        model: trained network (switched to eval mode here).
        test_loader: DataLoader yielding (images, labels) batches.
        loss_fn: criterion used for the reported loss.
        device: device to run inference on.

    Returns:
        Tuple ``(test_loss, test_accuracy)``, both averaged per sample.
    """
    model.eval()
    test_loss = 0.0
    test_accuracy = 0.0

    # Inference only — no gradient tracking needed.
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            y_preds = model(images)
            loss = loss_fn(y_preds, labels)
            # Weight by batch size so the averages below are per-sample.
            test_loss += loss.item() * images.size(0)
            test_accuracy += (y_preds.argmax(dim=1) == labels).sum().item()

    test_loss /= len(test_loader.dataset)
    test_accuracy /= len(test_loader.dataset)
    print(f"Test Loss: {test_loss:.4f} | Test Accuracy: {test_accuracy:.4f}")
    return test_loss, test_accuracy
|
|
|
|
|
|
|
|
def initialize_wandb(project_name, run_name, config):
    """Start a Weights & Biases run under the fixed team entity and return it."""
    # Assemble the run settings first, then hand them to wandb in one call.
    init_kwargs = {
        "entity": "i24106-code-i",
        "project": project_name,
        "name": run_name,
        "config": config,
    }
    return wandb.init(**init_kwargs)
|
|
|
|
|
if __name__ == "__main__":
    # ---- command-line interface ----
    parser = argparse.ArgumentParser(description="Train EfficientNetV2-S on Food-101 dataset")
    parser.add_argument("--epochs", type=int, default=5, help="Number of training epochs")
    parser.add_argument("--learning_rate", type=float, default=0.001, help="Learning rate for optimizer")
    parser.add_argument("--model_path", type=str, default="checkpoints/best_model.pth", help="Path to save the best model checkpoint")
    parser.add_argument("--log_run_name", type=str, default="EfficientNetV2S_Run", help="WandB run name")
    args = parser.parse_args()
    # NOTE(review): --epochs, --learning_rate and --log_run_name are parsed but
    # never used below — the optimizer hard-codes its learning rates and
    # train_model is never invoked in this script. Confirm whether a training
    # call was removed intentionally.

    # Resume from a previously saved checkpoint. This raises FileNotFoundError
    # if args.model_path does not exist yet (i.e. on a first-ever run).
    saved_model = torch.load(args.model_path, map_location=device)
    model.load_state_dict(saved_model['model_state_dict'])
    model.to(device)

    # Freeze the entire feature extractor...
    for p in model.features.parameters():
        p.requires_grad = False

    # ...then unfreeze only the last two feature stages for fine-tuning.
    for p in model.features[-2:].parameters():
        p.requires_grad = True

    loss_fn = nn.CrossEntropyLoss()
    # Discriminative learning rates: tiny for the unfrozen backbone stages,
    # larger for the re-initialized classifier head.
    optimizer = torch.optim.Adam([
        {"params": model.features[-2:].parameters(), "lr": 1e-5},
        {"params": model.classifier.parameters(), "lr": 1e-4},
    ], weight_decay=1e-4)

    # Callbacks (currently unused here, since train_model is not called below).
    checkpoint = CheckpointCallback(path=args.model_path)
    early_stopping = EarlyStopping(patience=3, min_delta=0.01)

    # Carve a random 10% of the held-out split into a small validation set.
    # NOTE(review): torch.randperm is unseeded — the validation subset changes
    # on every run; seed it for reproducible validation metrics.
    indices = torch.randperm(len(test_dataset))[:int(0.1 * len(test_dataset))]
    val_set = Subset(test_dataset, indices)
    val_loader = DataLoader(val_set, batch_size=32, shuffle=False)

    # Final evaluation of the loaded checkpoint on the held-out split.
    evaluate_model(model, test_loader, loss_fn, device)
|
|
|
|
|
|
|
|
|