# saudi-date-classifier / src/train_efficientnet.py
# Author: Rashidbm — initial deployment (commit 6276d4c)
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights
from tqdm import tqdm
from src.dataset import create_dataloaders
def set_seed(seed=42):
    """Seed every RNG we rely on (Python, NumPy, PyTorch CPU + CUDA) for reproducible runs."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
#training loop
def train_one_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimization pass over `dataloader`.

    Expects the loader to yield (images, labels, extra) triples; the third
    element is ignored. Returns (mean batch loss, accuracy in percent).
    """
    model.train()
    running_loss, n_correct, n_seen = 0.0, 0, 0
    for images, labels, _ in tqdm(dataloader):
        images, labels = images.to(device), labels.to(device).long()
        optimizer.zero_grad()
        logits = model(images)
        batch_loss = criterion(logits, labels)
        batch_loss.backward()
        optimizer.step()
        running_loss += batch_loss.item()
        n_correct += (logits.argmax(dim=1) == labels).sum().item()
        n_seen += labels.size(0)
    return running_loss / len(dataloader), 100.0 * n_correct / n_seen
#validation loop
def validate(model, dataloader, criterion, device):
    """Evaluate `model` on `dataloader` with gradients disabled.

    Expects the loader to yield (images, labels, extra) triples; the third
    element is ignored. Returns (mean batch loss, accuracy in percent).
    """
    model.eval()
    running_loss, n_correct, n_seen = 0.0, 0, 0
    with torch.no_grad():
        for images, labels, _ in dataloader:
            images, labels = images.to(device), labels.to(device).long()
            logits = model(images)
            running_loss += criterion(logits, labels).item()
            n_correct += (logits.argmax(dim=1) == labels).sum().item()
            n_seen += labels.size(0)
    return running_loss / len(dataloader), 100.0 * n_correct / n_seen
def plot_curves(train_losses, val_losses, train_accuracies, val_accuracies):
    """Show side-by-side loss and accuracy curves over all recorded epochs."""
    xs = list(range(1, len(train_losses) + 1))
    plt.figure(figsize=(12, 5))
    # (subplot index, title, y-label, [(series, legend label), ...])
    panels = [
        (1, "Loss Curve", "Loss",
         [(train_losses, "Train Loss"), (val_losses, "Val Loss")]),
        (2, "Accuracy Curve", "Accuracy (%)",
         [(train_accuracies, "Train Accuracy"), (val_accuracies, "Val Accuracy")]),
    ]
    for idx, title, ylabel, series in panels:
        plt.subplot(1, 2, idx)
        for values, label in series:
            plt.plot(xs, values, marker="o", label=label)
        plt.xlabel("Epoch")
        plt.ylabel(ylabel)
        plt.title(title)
        plt.legend()
    plt.show()
set_seed(42)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

train_loader, val_loader, test_loader, class_names = create_dataloaders()
num_classes = len(class_names)
print("Number of classes:", num_classes)
print("Classes:", class_names)

# Pretrained EfficientNet-B0 backbone.
weights = EfficientNet_B0_Weights.DEFAULT
model = efficientnet_b0(weights=weights)

# Freeze the feature extractor and replace the classifier head so its
# output size matches our class count.
for param in model.features.parameters():
    param.requires_grad = False
in_features = model.classifier[1].in_features
model.classifier = nn.Sequential(
    nn.Dropout(p=0.3),
    nn.Linear(in_features, num_classes),
)
model = model.to(device)

criterion = nn.CrossEntropyLoss()

# Metric history shared by both phases, consumed by plot_curves at the end.
train_losses = []
val_losses = []
train_accuracies = []
val_accuracies = []


def _run_phase(tag, epochs, optimizer):
    """Train for `epochs` epochs with `optimizer`, logging per-epoch metrics.

    Appends each epoch's train/val loss and accuracy to the module-level
    history lists so the curves span both phases continuously.
    """
    for epoch in range(epochs):
        train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc = validate(model, val_loader, criterion, device)
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        train_accuracies.append(train_acc)
        val_accuracies.append(val_acc)
        print(
            f"[{tag}] Epoch {epoch + 1}/{epochs} | "
            f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}% | "
            f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%"
        )


# Phase 1: train the new classifier head only (backbone frozen), higher LR.
optimizer = optim.AdamW(
    (p for p in model.parameters() if p.requires_grad),
    lr=1e-3,
    weight_decay=1e-4,
)
_run_phase("Phase 1", 10, optimizer)

# Phase 2: unfreeze the backbone and fine-tune everything at a low LR
# (fresh optimizer so the newly unfrozen parameters are included).
for param in model.features.parameters():
    param.requires_grad = True
optimizer = optim.AdamW(
    model.parameters(),
    lr=1e-5,
    weight_decay=1e-4,
)
_run_phase("Phase 2", 20, optimizer)

plot_curves(train_losses, val_losses, train_accuracies, val_accuracies)