# FoodClassifier / model.py
# Committed by: htetooyan
# Commit: 9a85b37 ("Initial commit", verified)
import argparse
import copy
import os

import torch  # type: ignore
import torch.nn as nn  # type: ignore
import torchvision  # type: ignore
import torchvision.transforms as T  # type: ignore
import wandb  # type: ignore
from torch.utils.data import DataLoader, Subset  # type: ignore
from torchinfo import summary  # type: ignore
from torchvision.models import EfficientNet_V2_S_Weights, efficientnet_v2_s  # type: ignore
# Pick the fastest available backend: CUDA first, then Apple-silicon MPS,
# then CPU. (Checking only MPS would silently fall back to CPU on CUDA boxes.)
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Food-101 has one output per food category.
NUM_CLASSES = 101

# Load the ImageNet-pretrained EfficientNetV2-S and swap the final Linear layer
# of its classifier head for a fresh NUM_CLASSES-way layer with the same input
# width.
model_weights = EfficientNet_V2_S_Weights.IMAGENET1K_V1
model = efficientnet_v2_s(weights=model_weights)
model.classifier[1] = nn.Linear(
    in_features=model.classifier[1].in_features,
    out_features=NUM_CLASSES,
)
model = model.to(device)
# Load the Food-101 split='train' data (downloaded on first use) and carve an
# 80/20 train / held-out partition out of it. This block only loads
# split='train'; the held-out 20% plays the role of the "test" set below.
dataset = torchvision.datasets.Food101(root='./data', split='train', download=True)

# Seed the split so every run - including a later evaluation-only run that
# loads a saved checkpoint - sees the same partition. With an unseeded split,
# images a checkpoint was trained on would leak into the "test" set.
SPLIT_SEED = 42
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
# Training pipeline: random augmentation followed by the preprocessing preset
# bundled with the pretrained weights.
# NOTE(review): the weights' preset applies its own resize/crop (reportedly to
# a larger resolution than 224 for this weight set - confirm in torchvision
# docs), so the 224px random crop is rescaled again afterwards. Confirm this
# is intended.
train_transforms = T.Compose([
    T.RandomResizedCrop(224),
    T.RandomHorizontalFlip(),
    T.ColorJitter(0.2, 0.2, 0.2, 0.1),
    model_weights.transforms(),
])
# Evaluation pipeline: only the deterministic preset from the weights.
test_transforms = T.Compose([
    model_weights.transforms(),
])

# random_split returns two Subsets wrapping the SAME underlying dataset object.
# Assigning `.transform` through both would let the second assignment win, so
# training images would silently get the test transforms (no augmentation).
# Give the training subset its own shallow copy so each side keeps its own
# transform while still sharing the file/label lists.
train_dataset.dataset = copy.copy(train_dataset.dataset)
train_dataset.dataset.transform = train_transforms
test_dataset.dataset.transform = test_transforms
# DataLoaders for the two splits. Batch size and worker settings are shared;
# only the training loader reshuffles its samples every epoch.
_loader_kwargs = dict(batch_size=16, num_workers=2, persistent_workers=True)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)
# checkpoint callback
def save_checkpoint(epoch, model, optimizer, val_loss, path="checkpoints/best_model.pth"):
    """Serialize model and optimizer state to ``path`` with ``torch.save``.

    The saved dict has the keys ``'epoch'``, ``'model_state_dict'``,
    ``'optimizer_state_dict'`` and ``'val_loss'``.

    Args:
        epoch: Epoch index stored alongside the weights.
        model: Module whose ``state_dict()`` is saved.
        optimizer: Optimizer whose ``state_dict()`` is saved.
        val_loss: Validation loss recorded with the checkpoint.
        path: Destination file. Missing parent directories are created.
    """
    # torch.save does not create directories; without this the default
    # "checkpoints/..." path fails on a fresh checkout.
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'val_loss': val_loss,
    }, path)
    print(f"Checkpoint saved at epoch {epoch} to {path}")
class CheckpointCallback:
    """Persist a checkpoint each time the validation loss hits a new minimum.

    Attributes:
        path: File passed to ``save_checkpoint``.
        best_loss: Lowest validation loss accepted so far; starts at +inf, so
            the first call with a finite loss always saves.
    """

    def __init__(self, path="checkpoints/best_model.pth"):
        self.path = path
        self.best_loss = float('inf')

    def __call__(self, epoch, model, optimizer, val_loss):
        """Save and return True if ``val_loss`` strictly beats ``best_loss``; else return False."""
        is_improvement = val_loss < self.best_loss
        if not is_improvement:
            return False
        self.best_loss = val_loss
        save_checkpoint(epoch, model, optimizer, val_loss, self.path)
        return True
# early stopping callback
class EarlyStopping:
    """Signal that training should stop once validation loss plateaus.

    A loss counts as an improvement only when it is strictly below
    ``best_loss - min_delta``. After ``patience`` consecutive non-improving
    calls, ``early_stop`` becomes True; nothing in this class resets it.
    """

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float('inf')
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_loss):
        """Feed one epoch's validation loss and update ``counter``/``early_stop``."""
        improved = val_loss < self.best_loss - self.min_delta
        if improved:
            self.best_loss = val_loss
            self.counter = 0
            return
        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True
# training function
def _train_one_epoch(run, model, train_loader, loss_fn, optimizer, device, global_step, log_every):
    """Run one optimization pass over ``train_loader``.

    Logs the per-batch loss as ``"train/loss"`` every ``log_every`` steps.

    Returns:
        ``(mean_loss, accuracy, global_step)``: sample-weighted mean loss,
        fraction of correct argmax predictions, and the step counter advanced
        by the number of batches processed.
    """
    model.train()
    total_loss = 0.0
    total_correct = 0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        y_preds = model(images)
        loss = loss_fn(y_preds, labels)
        loss.backward()
        optimizer.step()
        if global_step % log_every == 0:
            run.log({"train/loss": loss.item()}, step=global_step)
        global_step += 1
        # loss_fn is assumed to return a batch mean - TODO confirm if a
        # non-default reduction is ever used; weight by batch size to get a
        # per-sample average over the whole dataset.
        total_loss += loss.item() * labels.size(0)
        total_correct += (y_preds.argmax(dim=1) == labels).sum().item()
    num_samples = len(train_loader.dataset)
    return total_loss / num_samples, total_correct / num_samples, global_step


def _validate(model, val_loader, loss_fn, device):
    """Return ``(mean_loss, accuracy)`` of ``model`` on ``val_loader`` in eval mode without gradients."""
    model.eval()
    total_loss = 0.0
    total_correct = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            y_preds = model(images)
            loss = loss_fn(y_preds, labels)
            total_loss += loss.item() * images.size(0)
            total_correct += (y_preds.argmax(dim=1) == labels).sum().item()
    num_samples = len(val_loader.dataset)
    return total_loss / num_samples, total_correct / num_samples


def train_model(run, model, train_loader, val_loader, loss_fn, optimizer, device, epochs=5, checkpoint=None, early_stopping=None, log_every=1):
    """Fine-tune ``model`` for up to ``epochs`` epochs, logging to ``run``.

    Each epoch trains on ``train_loader``, evaluates on ``val_loader``, prints
    and logs the metrics, then invokes the optional callbacks.

    Args:
        run: Object with ``log(dict, step=int)`` and ``finish()`` (e.g. the
            value returned by ``wandb.init``). ``finish()`` is always called,
            even if training raises.
        model: Module to train; moved to ``device`` first.
        train_loader: Yields ``(images, labels)`` batches for optimization.
        val_loader: Yields ``(images, labels)`` batches for validation.
        loss_fn: Callable ``(logits, labels) -> scalar loss tensor``.
        optimizer: Optimizer over ``model``'s trainable parameters.
        device: Device the batches are moved to.
        epochs: Maximum number of epochs.
        checkpoint: Optional callable ``(epoch, model, optimizer, val_loss)``
            invoked after every epoch.
        early_stopping: Optional callable ``(val_loss)`` exposing an
            ``early_stop`` flag; training stops once the flag is True.
        log_every: Log the per-batch training loss every this many optimizer
            steps (default 1 = every step).

    Raises:
        ValueError: if ``log_every`` is less than 1.
    """
    if log_every < 1:
        raise ValueError(f"log_every must be >= 1, got {log_every}")
    global_step = 0
    model.to(device)
    try:
        for epoch in range(epochs):
            train_loss, train_accuracy, global_step = _train_one_epoch(
                run, model, train_loader, loss_fn, optimizer, device, global_step, log_every
            )
            print(f"Epoch [{epoch + 1}/{epochs}], Loss: {train_loss:.4f} | Accuracy: {train_accuracy:.4f}")

            # validation phase
            val_loss, val_accuracy = _validate(model, val_loader, loss_fn, device)
            print(f"Validation Loss: {val_loss:.4f} | Validation Accuracy: {val_accuracy:.4f}")
            run.log({
                "val/loss": val_loss,
                "val/accuracy": val_accuracy,
                "train/accuracy": train_accuracy,
                "epoch": epoch + 1,
            }, step=global_step)

            # callbacks
            if checkpoint is not None:
                checkpoint(epoch, model, optimizer, val_loss)
            if early_stopping is not None:
                early_stopping(val_loss)
                if early_stopping.early_stop:
                    print("Early stopping triggered")
                    break
    finally:
        # Close the run even when training fails part-way.
        run.finish()
# evaluation function
def evaluate_model(model, test_loader, loss_fn, device):
    """Compute and print the mean loss and accuracy of ``model`` on ``test_loader``.

    Runs in eval mode without gradients. The loss is a per-sample average,
    computed by weighting each batch loss by its batch size (assumes
    ``loss_fn`` returns a batch mean - TODO confirm for custom losses).

    Args:
        model: Module producing class logits.
        test_loader: Yields ``(images, labels)`` batches.
        loss_fn: Callable ``(logits, labels) -> scalar loss tensor``.
        device: Device the batches are moved to.

    Returns:
        ``(test_loss, test_accuracy)`` as floats.

    Raises:
        ValueError: if ``test_loader.dataset`` is empty.
    """
    num_samples = len(test_loader.dataset)
    if num_samples == 0:
        raise ValueError("evaluate_model: test_loader has an empty dataset")
    model.eval()
    total_loss = 0.0
    total_correct = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            y_preds = model(images)
            loss = loss_fn(y_preds, labels)
            total_loss += loss.item() * images.size(0)
            total_correct += (y_preds.argmax(dim=1) == labels).sum().item()
    test_loss = total_loss / num_samples
    test_accuracy = total_correct / num_samples
    print(f"Test Loss: {test_loss:.4f} | Test Accuracy: {test_accuracy:.4f}")
    return test_loss, test_accuracy
# initialization for wandb
def initialize_wandb(project_name, run_name, config, entity="i24106-code-i"):
    """Start a Weights & Biases run via ``wandb.init`` and return it.

    Args:
        project_name: W&B project to log into.
        run_name: Display name of the run.
        config: Hyperparameters recorded with the run.
        entity: W&B user/team that owns the project. Defaults to the value
            that used to be hard-coded here, so existing calls are unchanged.

    Returns:
        The run object returned by ``wandb.init``.
    """
    return wandb.init(
        entity=entity,
        project=project_name,
        name=run_name,
        config=config,
    )
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train EfficientNetV2-S on Food-101 dataset")
    parser.add_argument("--epochs", type=int, default=5, help="Number of training epochs")
    parser.add_argument("--learning_rate", type=float, default=0.001, help="Learning rate for optimizer")
    parser.add_argument("--model_path", type=str, default="checkpoints/best_model.pth", help="Path to save the best model checkpoint")
    parser.add_argument("--log_run_name", type=str, default="EfficientNetV2S_Run", help="WandB run name")
    parser.add_argument(
        "--train",
        action="store_true",
        help="Fine-tune (logging to WandB) before evaluating; by default only the saved checkpoint is evaluated",
    )
    args = parser.parse_args()

    # Resume from the saved checkpoint when there is one. An evaluation-only
    # run requires it; a training run may start from the ImageNet weights.
    if os.path.exists(args.model_path):
        saved_model = torch.load(args.model_path, map_location=device)
        model.load_state_dict(saved_model['model_state_dict'])
    elif not args.train:
        parser.error(f"checkpoint not found: {args.model_path} (pass --train to fine-tune from ImageNet weights)")
    model.to(device)

    # Freeze the whole feature extractor, then unfreeze its last 2 stages
    # (tune N = 1, 2, 3). Classifier parameters are not touched here and keep
    # their default requires_grad=True.
    for p in model.features.parameters():
        p.requires_grad = False
    for p in model.features[-2:].parameters():
        p.requires_grad = True

    # Loss and optimizer: a smaller learning rate for the unfrozen backbone
    # stages than for the classifier head.
    # NOTE(review): --learning_rate is only recorded in the WandB config; it
    # is not wired into these parameter groups.
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam([
        {"params": model.features[-2:].parameters(), "lr": 1e-5},
        {"params": model.classifier.parameters(), "lr": 1e-4},
    ], weight_decay=1e-4)

    # Callbacks. NOTE(review): CheckpointCallback starts from best_loss=inf,
    # so the first training epoch overwrites any checkpoint loaded above.
    checkpoint = CheckpointCallback(path=args.model_path)
    early_stopping = EarlyStopping(patience=3, min_delta=0.01)

    # Validation loader: a fixed (seeded) 10% sample of the held-out split.
    # NOTE(review): this sample comes from test_dataset, which is also the
    # final evaluation set, so checkpoint selection sees test data.
    val_size = int(0.1 * len(test_dataset))
    indices = torch.randperm(len(test_dataset), generator=torch.Generator().manual_seed(0))[:val_size]
    val_set = Subset(test_dataset, indices)
    val_loader = DataLoader(val_set, batch_size=32, shuffle=False)

    if args.train:
        config = {
            "epochs": args.epochs,
            "learning_rate": args.learning_rate,
            "model": "EfficientNetV2-S",
            "dataset": "Food-101",
        }
        run = initialize_wandb("Food101_Classification", args.log_run_name, config)
        # Optimize on the training split; val_loader is used only for
        # validation, checkpointing and early stopping.
        train_model(run, model, train_loader, val_loader, loss_fn, optimizer, device,
                    epochs=args.epochs, checkpoint=checkpoint, early_stopping=early_stopping)
        # Evaluate the best (lowest validation loss) weights, not the last epoch's.
        if os.path.exists(args.model_path):
            best = torch.load(args.model_path, map_location=device)
            model.load_state_dict(best['model_state_dict'])

    evaluate_model(model, test_loader, loss_fn, device)