import os import time from pathlib import Path import torch import torch.nn as nn from torch.utils.data import DataLoader, random_split from torchvision import datasets, transforms, models # --------------------------------------------------------------------------- # Paths # --------------------------------------------------------------------------- BASE_DIR = Path(__file__).resolve().parents[1] CROPS_DIR = BASE_DIR / "data" / "crops" MODELS_DIR = BASE_DIR / "models" MODELS_DIR.mkdir(exist_ok=True) MODEL_PATH = MODELS_DIR / "slot_classifier.pth" # --------------------------------------------------------------------------- # Config # --------------------------------------------------------------------------- IMG_SIZE = 224 BATCH_SIZE = 32 EPOCHS = 10 LR = 1e-4 VAL_SPLIT = 0.2 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") # --------------------------------------------------------------------------- # Transforms # --------------------------------------------------------------------------- train_transforms = transforms.Compose([ transforms.Resize((IMG_SIZE, IMG_SIZE)), transforms.RandomHorizontalFlip(), transforms.RandomVerticalFlip(), transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), ]) val_transforms = transforms.Compose([ transforms.Resize((IMG_SIZE, IMG_SIZE)), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), ]) # --------------------------------------------------------------------------- # Dataset # --------------------------------------------------------------------------- def build_loaders(): full_dataset = datasets.ImageFolder(CROPS_DIR) print(f"Classes : {full_dataset.classes}") print(f"Total : {len(full_dataset)} images") val_size = int(len(full_dataset) * VAL_SPLIT) train_size = len(full_dataset) - val_size train_ds, val_ds = random_split( full_dataset, [train_size, val_size], generator=torch.Generator().manual_seed(42), ) # Apply correct transforms to each split train_ds.dataset.transform = train_transforms val_ds.dataset.transform = val_transforms train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=0) val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0) print(f"Train : {train_size} | Val : {val_size}") return train_loader, val_loader, full_dataset.classes # --------------------------------------------------------------------------- # Model # --------------------------------------------------------------------------- def build_model(num_classes: int = 2) -> nn.Module: model = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.DEFAULT) # Freeze all base layers for param in model.parameters(): param.requires_grad = False # Replace classifier head in_features = model.classifier[1].in_features model.classifier = nn.Sequential( nn.Dropout(0.3), nn.Linear(in_features, num_classes), ) return model.to(DEVICE) # --------------------------------------------------------------------------- # Train / Eval loops # --------------------------------------------------------------------------- def train_one_epoch(model, loader, criterion, optimizer): model.train() total_loss = 0.0 correct = 0 total = 0 for images, labels in loader: images, labels = images.to(DEVICE), labels.to(DEVICE) optimizer.zero_grad() outputs = model(images) loss = criterion(outputs, labels) loss.backward() optimizer.step() total_loss += loss.item() * images.size(0) preds = outputs.argmax(dim=1) correct += (preds == labels).sum().item() total += labels.size(0) return total_loss / total, correct / total def evaluate(model, loader, criterion): model.eval() total_loss = 0.0 correct = 0 total = 0 with torch.no_grad(): for images, labels in loader: images, labels = images.to(DEVICE), labels.to(DEVICE) outputs = model(images) loss = criterion(outputs, labels) total_loss += loss.item() * images.size(0) preds = outputs.argmax(dim=1) correct += (preds == labels).sum().item() total += labels.size(0) return total_loss / total, correct / total # --------------------------------------------------------------------------- # Main # --------------------------------------------------------------------------- def main(): print(f"Device : {DEVICE}") print(f"Epochs : {EPOCHS} | Batch : {BATCH_SIZE} | LR : {LR}\n") train_loader, val_loader, classes = build_loaders() model = build_model(num_classes=len(classes)) criterion = nn.CrossEntropyLoss() optimizer = torch.optim.Adam( filter(lambda p: p.requires_grad, model.parameters()), lr=LR, ) scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.5) best_val_acc = 0.0 for epoch in range(1, EPOCHS + 1): t0 = time.time() train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer) val_loss, val_acc = evaluate(model, val_loader, criterion) scheduler.step() elapsed = time.time() - t0 marker = " *" if val_acc > best_val_acc else "" print( f"Epoch {epoch:02d}/{EPOCHS} | " f"Train loss {train_loss:.4f} acc {train_acc:.4f} | " f"Val loss {val_loss:.4f} acc {val_acc:.4f} | " f"{elapsed:.1f}s{marker}" ) if val_acc > best_val_acc: best_val_acc = val_acc torch.save({ "epoch" : epoch, "model_state": model.state_dict(), "classes" : classes, "val_acc" : val_acc, }, MODEL_PATH) print(f"\nBest val accuracy : {best_val_acc:.4f}") print(f"Model saved to : {MODEL_PATH}") if __name__ == "__main__": main()