Spaces:
Sleeping
Sleeping
| import os | |
| import time | |
| from pathlib import Path | |
| import torch | |
| import torch.nn as nn | |
| from torch.utils.data import DataLoader, random_split | |
| from torchvision import datasets, transforms, models | |
| # --------------------------------------------------------------------------- | |
| # Paths | |
| # --------------------------------------------------------------------------- | |
# Project root: the parent of the directory that contains this file.
BASE_DIR = Path(__file__).resolve().parents[1]

# Input images, read as a class-per-subfolder image dataset.
CROPS_DIR = BASE_DIR.joinpath("data", "crops")

# Output location for the trained checkpoint.
MODELS_DIR = BASE_DIR.joinpath("models")
MODEL_PATH = MODELS_DIR.joinpath("slot_classifier.pth")

# Create the output directory at import time so the checkpoint can be written
# later without a separate existence check.
MODELS_DIR.mkdir(exist_ok=True)
| # --------------------------------------------------------------------------- | |
| # Config | |
| # --------------------------------------------------------------------------- | |
# Side length (pixels) that every image is resized to before entering the model.
IMG_SIZE = 224

# Optimisation hyper-parameters.
BATCH_SIZE = 32
EPOCHS = 10
LR = 1e-4

# Fraction of the dataset held out for validation.
VAL_SPLIT = 0.2

# Train on the GPU when one is visible to PyTorch, otherwise fall back to CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
| # --------------------------------------------------------------------------- | |
| # Transforms | |
| # --------------------------------------------------------------------------- | |
def _to_normalized_tensor():
    """Return fresh ToTensor + Normalize steps shared by both pipelines.

    NOTE(review): the mean/std values look like the standard ImageNet channel
    statistics used with torchvision's pretrained models — confirm.
    """
    return [
        transforms.ToTensor(),
        transforms.Normalize(
            [0.485, 0.456, 0.406],
            [0.229, 0.224, 0.225],
        ),
    ]


# Training pipeline: resize, random flips and colour jitter, then normalise.
train_transforms = transforms.Compose(
    [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2),
        *_to_normalized_tensor(),
    ]
)

# Validation pipeline: deterministic resize + normalise only, no augmentation.
val_transforms = transforms.Compose(
    [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        *_to_normalized_tensor(),
    ]
)
| # --------------------------------------------------------------------------- | |
| # Dataset | |
| # --------------------------------------------------------------------------- | |
def build_loaders():
    """Build the training and validation data loaders from ``CROPS_DIR``.

    The images are split into train/validation subsets of sizes
    ``1 - VAL_SPLIT`` / ``VAL_SPLIT`` using a generator seeded with 42, so the
    split is reproducible between runs.

    Each split is backed by its own ``ImageFolder`` instance. ``random_split``
    returns ``Subset`` views that share one underlying dataset, so assigning
    ``subset.dataset.transform`` for one split also changes it for the other
    (the last assignment wins). Using separate instances keeps the training
    augmentations on the training split and the plain resize/normalise
    pipeline on the validation split.

    Returns:
        A ``(train_loader, val_loader, classes)`` tuple where ``classes`` is
        the list of class names reported by ``ImageFolder``.
    """
    # Two datasets over the same directory that differ only in their transform.
    train_source = datasets.ImageFolder(CROPS_DIR, transform=train_transforms)
    val_source = datasets.ImageFolder(CROPS_DIR, transform=val_transforms)

    total = len(train_source)
    print(f"Classes : {train_source.classes}")
    print(f"Total : {total} images")

    val_size = int(total * VAL_SPLIT)
    train_size = total - val_size

    # Draw the split once, then apply the validation indices to the dataset
    # that carries the validation transform; the two subsets stay disjoint.
    train_ds, val_split = random_split(
        train_source,
        [train_size, val_size],
        generator=torch.Generator().manual_seed(42),
    )
    val_ds = torch.utils.data.Subset(val_source, val_split.indices)

    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
    print(f"Train : {train_size} | Val : {val_size}")
    return train_loader, val_loader, train_source.classes
| # --------------------------------------------------------------------------- | |
| # Model | |
| # --------------------------------------------------------------------------- | |
def build_model(num_classes: int = 2) -> nn.Module:
    """Create a MobileNetV2 classifier with a frozen backbone and a new head.

    Args:
        num_classes: Number of output logits of the replacement classifier.

    Returns:
        The model, moved to ``DEVICE``. Only the new classifier head has
        trainable parameters.
    """
    net = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.DEFAULT)

    # Freeze every pretrained parameter; the head added below is created
    # afterwards and therefore stays trainable.
    net.requires_grad_(False)

    # Swap the stock classifier for dropout + a linear layer sized to our classes.
    head_inputs = net.classifier[1].in_features
    net.classifier = nn.Sequential(
        nn.Dropout(0.3),
        nn.Linear(head_inputs, num_classes),
    )

    net = net.to(DEVICE)
    return net
| # --------------------------------------------------------------------------- | |
| # Train / Eval loops | |
| # --------------------------------------------------------------------------- | |
def train_one_epoch(model, loader, criterion, optimizer):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Network to train; switched to training mode.
        loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per batch.

    Returns:
        ``(mean_loss, accuracy)`` averaged over all samples seen.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    for batch, targets in loader:
        batch = batch.to(DEVICE)
        targets = targets.to(DEVICE)

        optimizer.zero_grad()
        logits = model(batch)
        batch_loss = criterion(logits, targets)
        batch_loss.backward()
        optimizer.step()

        # Weight the batch loss by the batch size so that dividing by the
        # sample count below yields a per-sample average.
        loss_sum += batch_loss.item() * batch.size(0)
        n_correct += (logits.argmax(dim=1) == targets).sum().item()
        n_seen += targets.size(0)

    return loss_sum / n_seen, n_correct / n_seen
def evaluate(model, loader, criterion):
    """Compute loss and accuracy over ``loader`` without updating the model.

    Args:
        model: Network to evaluate; switched to evaluation mode.
        loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss function called as ``criterion(outputs, labels)``.

    Returns:
        ``(mean_loss, accuracy)`` averaged over all samples seen.
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    # Gradients are not needed for evaluation.
    with torch.no_grad():
        for batch, targets in loader:
            batch = batch.to(DEVICE)
            targets = targets.to(DEVICE)

            logits = model(batch)
            batch_loss = criterion(logits, targets)

            # Weight by batch size so the final division is a per-sample mean.
            loss_sum += batch_loss.item() * batch.size(0)
            n_correct += (logits.argmax(dim=1) == targets).sum().item()
            n_seen += targets.size(0)

    return loss_sum / n_seen, n_correct / n_seen
| # --------------------------------------------------------------------------- | |
| # Main | |
| # --------------------------------------------------------------------------- | |
def main():
    """Train the classifier and keep the checkpoint with the best val accuracy.

    Builds the loaders and model, trains for ``EPOCHS`` epochs with Adam and a
    step learning-rate schedule, prints one summary line per epoch, and writes
    a checkpoint to ``MODEL_PATH`` whenever validation accuracy improves.
    """
    print(f"Device : {DEVICE}")
    print(f"Epochs : {EPOCHS} | Batch : {BATCH_SIZE} | LR : {LR}\n")

    train_loader, val_loader, classes = build_loaders()
    model = build_model(num_classes=len(classes))
    criterion = nn.CrossEntropyLoss()

    # Only hand the optimizer the parameters that are still trainable.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(trainable_params, lr=LR)
    # Halve the learning rate every 4 epochs.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.5)

    best_val_acc = 0.0
    for epoch in range(1, EPOCHS + 1):
        started = time.time()
        train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer)
        val_loss, val_acc = evaluate(model, val_loader, criterion)
        scheduler.step()
        elapsed = time.time() - started

        # An epoch counts as an improvement only on a strict accuracy gain;
        # improved epochs are flagged with " *" in the log line.
        improved = val_acc > best_val_acc
        marker = " *" if improved else ""
        print(
            f"Epoch {epoch:02d}/{EPOCHS} | "
            f"Train loss {train_loss:.4f} acc {train_acc:.4f} | "
            f"Val loss {val_loss:.4f} acc {val_acc:.4f} | "
            f"{elapsed:.1f}s{marker}"
        )

        if not improved:
            continue

        best_val_acc = val_acc
        checkpoint = {
            "epoch": epoch,
            "model_state": model.state_dict(),
            "classes": classes,
            "val_acc": val_acc,
        }
        torch.save(checkpoint, MODEL_PATH)

    print(f"\nBest val accuracy : {best_val_acc:.4f}")
    print(f"Model saved to : {MODEL_PATH}")
# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()