smart-parking-api / src /train_classifier.py
rohanv56's picture
Upload 380 files
2d30bad verified
Raw
History Blame Contribute Delete
6.53 kB
import os
import time
from pathlib import Path

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset, random_split
from torchvision import datasets, transforms, models
# ---------------------------------------------------------------------------
# Paths
# ---------------------------------------------------------------------------
# Repository root: the parent of the directory containing this file.
_THIS_FILE = Path(__file__).resolve()
BASE_DIR = _THIS_FILE.parents[1]

# Input images, read with torchvision ImageFolder in build_loaders().
CROPS_DIR = BASE_DIR.joinpath("data", "crops")

# Checkpoint output directory; created (if missing) as soon as this module is imported.
MODELS_DIR = BASE_DIR.joinpath("models")
MODELS_DIR.mkdir(exist_ok=True)
MODEL_PATH = MODELS_DIR.joinpath("slot_classifier.pth")
# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
# Images are resized to IMG_SIZE x IMG_SIZE before entering the network.
IMG_SIZE = 224

# Optimisation hyper-parameters.
BATCH_SIZE = 32
EPOCHS = 10
LR = 1e-4          # Adam learning rate (halved every 4 epochs by StepLR in main)

# Fraction of the dataset held out for validation.
VAL_SPLIT = 0.2

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
# ---------------------------------------------------------------------------
# Transforms
# ---------------------------------------------------------------------------
# Training pipeline: resize, random horizontal/vertical flips and colour
# jitter for augmentation, then tensor conversion and per-channel
# normalisation (the usual ImageNet mean/std used with pretrained backbones).
train_transforms = transforms.Compose(
    [
        transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Validation pipeline: same resize and normalisation, but no random
# augmentation, so each image always produces the same tensor.
val_transforms = transforms.Compose(
    [
        transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# ---------------------------------------------------------------------------
# Dataset
# ---------------------------------------------------------------------------
def build_loaders():
    """Build train/validation DataLoaders from the image folders under CROPS_DIR.

    The images are split once with a fixed seed (42) into VAL_SPLIT validation
    and the remainder training. Training samples are loaded through
    ``train_transforms`` (with augmentation) and validation samples through
    ``val_transforms``.

    Returns:
        ``(train_loader, val_loader, classes)`` where ``classes`` is the
        ImageFolder class-name list (label ``i`` corresponds to ``classes[i]``).

    Raises:
        ValueError: if the dataset is too small for VAL_SPLIT to leave at least
            one image in both the training and the validation split.
    """
    # One ImageFolder per split. The Subsets produced by random_split all wrap
    # the *same* parent dataset, so assigning ``.transform`` through one of
    # them also changes it for the other (the last assignment wins). Separate
    # instances are what keep augmentation on the training split only.
    train_base = datasets.ImageFolder(CROPS_DIR, transform=train_transforms)
    val_base = datasets.ImageFolder(CROPS_DIR, transform=val_transforms)
    print(f"Classes : {train_base.classes}")
    print(f"Total : {len(train_base)} images")

    val_size = int(len(train_base) * VAL_SPLIT)
    train_size = len(train_base) - val_size
    if val_size == 0 or train_size == 0:
        # An empty split would otherwise surface later as a ZeroDivisionError
        # when per-epoch accuracy is computed.
        raise ValueError(
            f"Cannot split {len(train_base)} images from {CROPS_DIR} with "
            f"VAL_SPLIT={VAL_SPLIT}: one of the splits would be empty"
        )

    # random_split only contributes the seeded index permutation here. It
    # depends only on the dataset length, the lengths list and the generator,
    # so the same images land in each split as before. Both ImageFolder
    # instances scan the same directory, so index i names the same file in each.
    train_split, val_split = random_split(
        train_base,
        [train_size, val_size],
        generator=torch.Generator().manual_seed(42),
    )
    train_ds = Subset(train_base, train_split.indices)
    val_ds = Subset(val_base, val_split.indices)

    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
    print(f"Train : {train_size} | Val : {val_size}")
    return train_loader, val_loader, train_base.classes
# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------
def build_model(num_classes: int = 2) -> nn.Module:
    """Return a pretrained MobileNetV2 whose only trainable part is a new head.

    The torchvision default MobileNetV2 weights are loaded and every existing
    parameter is frozen. The classifier is then replaced by
    ``Dropout(0.3) -> Linear(in_features, num_classes)``, whose parameters stay
    trainable because they are created after the freeze. The model is moved
    to DEVICE before being returned.
    """
    net = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.DEFAULT)

    # Module.requires_grad_ flips requires_grad on every parameter, which is
    # equivalent to looping over net.parameters().
    net.requires_grad_(False)

    # classifier[1] is the pretrained Linear layer; reuse its input width.
    head_width = net.classifier[1].in_features
    new_head = nn.Sequential(
        nn.Dropout(0.3),
        nn.Linear(head_width, num_classes),
    )
    net.classifier = new_head

    return net.to(DEVICE)
# ---------------------------------------------------------------------------
# Train / Eval loops
# ---------------------------------------------------------------------------
def train_one_epoch(model, loader, criterion, optimizer):
    """Run one optimisation pass over every batch in ``loader``.

    Each batch is moved to DEVICE before a forward pass, backward pass and
    ``optimizer.step()``.

    Returns:
        ``(mean_loss, accuracy)`` over all samples seen. The loss is weighted
        by batch size, so it is a per-sample mean. Raises ZeroDivisionError
        if ``loader`` yields no batches.
    """
    model.train()
    running_loss, n_correct, n_seen = 0.0, 0, 0

    for batch_images, batch_labels in loader:
        batch_images = batch_images.to(DEVICE)
        batch_labels = batch_labels.to(DEVICE)

        optimizer.zero_grad()
        logits = model(batch_images)
        batch_loss = criterion(logits, batch_labels)
        batch_loss.backward()
        optimizer.step()

        # criterion returns a batch mean; scale back up to a batch sum.
        running_loss += batch_loss.item() * batch_images.size(0)
        hits = logits.argmax(dim=1) == batch_labels
        n_correct += hits.sum().item()
        n_seen += batch_labels.size(0)

    mean_loss = running_loss / n_seen
    accuracy = n_correct / n_seen
    return mean_loss, accuracy
def evaluate(model, loader, criterion):
    """Score ``model`` on every batch in ``loader`` without updating it.

    The model is put in eval mode and gradients are disabled for the pass.

    Returns:
        ``(mean_loss, accuracy)`` over all samples, with the loss weighted by
        batch size. Raises ZeroDivisionError if ``loader`` yields no batches.
    """
    model.eval()
    batch_loss_sums = []
    hit_counts = []
    sample_counts = []

    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(DEVICE)
            targets = targets.to(DEVICE)
            scores = model(inputs)
            # criterion returns a batch mean; keep the batch sum instead.
            batch_loss_sums.append(criterion(scores, targets).item() * inputs.size(0))
            hit_counts.append((scores.argmax(dim=1) == targets).sum().item())
            sample_counts.append(targets.size(0))

    n_samples = sum(sample_counts)
    return sum(batch_loss_sums, 0.0) / n_samples, sum(hit_counts) / n_samples
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
    """Train the slot classifier and keep the best-validation-accuracy weights.

    Each epoch trains once over the training split, evaluates on the
    validation split, then steps the LR scheduler. Whenever validation
    accuracy improves, a checkpoint dict with keys ``epoch``, ``model_state``,
    ``classes`` and ``val_acc`` is written to MODEL_PATH.
    """
    print(f"Device : {DEVICE}")
    print(f"Epochs : {EPOCHS} | Batch : {BATCH_SIZE} | LR : {LR}\n")

    train_loader, val_loader, classes = build_loaders()
    model = build_model(num_classes=len(classes))

    criterion = nn.CrossEntropyLoss()
    # Optimise only parameters that still require gradients (build_model
    # freezes the pretrained backbone).
    optimizer = torch.optim.Adam(
        (p for p in model.parameters() if p.requires_grad),
        lr=LR,
    )
    # Halve the learning rate every 4 epochs (scheduler.step() runs once per epoch).
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.5)

    # Start below any reachable accuracy so the first epoch always writes a
    # checkpoint. Starting from 0.0 with a strict ">" meant that a run whose
    # validation accuracy never rose above zero saved nothing, yet the final
    # message still reported a saved model.
    best_val_acc = float("-inf")

    for epoch in range(1, EPOCHS + 1):
        # perf_counter is monotonic, so wall-clock adjustments cannot skew timing.
        t0 = time.perf_counter()
        train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer)
        val_loss, val_acc = evaluate(model, val_loader, criterion)
        scheduler.step()
        elapsed = time.perf_counter() - t0

        improved = val_acc > best_val_acc
        marker = " *" if improved else ""
        print(
            f"Epoch {epoch:02d}/{EPOCHS} | "
            f"Train loss {train_loss:.4f} acc {train_acc:.4f} | "
            f"Val loss {val_loss:.4f} acc {val_acc:.4f} | "
            f"{elapsed:.1f}s{marker}"
        )

        if improved:
            best_val_acc = val_acc
            torch.save({
                "epoch"      : epoch,
                "model_state": model.state_dict(),
                "classes"    : classes,
                "val_acc"    : val_acc,
            }, MODEL_PATH)

    print(f"\nBest val accuracy : {best_val_acc:.4f}")
    print(f"Model saved to : {MODEL_PATH}")
if __name__ == "__main__":
    # Start training only when run as a script; importing this module
    # (e.g. to reuse build_model) does not launch a training run.
    main()