from pathlib import Path

import torch
import torch.nn as nn
from sklearn.metrics import classification_report, confusion_matrix
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchvision import models

from src.data.generator_loader import get_generator_dataloaders, CLASS_NAMES


def build_multiclass_model(num_classes=4, pretrained=True):
    """Build a ResNet-18 classifier with only ``layer4`` and a new head trainable.

    Args:
        num_classes: Number of output logits produced by the classification head.
        pretrained: If True, initialise the backbone with torchvision's
            ``IMAGENET1K_V1`` weights; otherwise start from random weights.

    Returns:
        A torchvision ResNet-18 whose ``fc`` layer has been replaced by a
        dropout-regularised two-layer MLP head.
    """
    model = models.resnet18(weights="IMAGENET1K_V1" if pretrained else None)

    # Freeze the backbone except for the last residual stage, which is left
    # trainable so it can adapt to the new task.
    for name, param in model.named_parameters():
        param.requires_grad = "layer4" in name

    in_features = model.fc.in_features
    # The replacement head is freshly constructed, so its parameters are
    # trainable by default (requires_grad=True).
    model.fc = nn.Sequential(
        nn.Dropout(0.5),
        nn.Linear(in_features, 256),
        nn.ReLU(),
        nn.Dropout(0.3),
        nn.Linear(256, num_classes),
    )
    return model


def evaluate(model, loader, device):
    """Run ``model`` over ``loader`` in eval mode and collect class predictions.

    Args:
        model: Classifier producing one logit per class along dim 1.
        loader: Iterable of ``(images, labels)`` batches.
        device: Device the images are moved to before the forward pass.

    Returns:
        ``(preds, labels)``: two equal-length lists of per-sample class indices
        (NumPy scalars), in loader order.
    """
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for images, labels in loader:
            outputs = model(images.to(device))
            all_preds.extend(outputs.argmax(dim=1).cpu().numpy())
            # Labels are only compared on the host, so they never need to
            # visit the accelerator.
            all_labels.extend(labels.cpu().numpy())
    return all_preds, all_labels


def _accuracy(preds, labels):
    """Return the fraction of positions where ``preds`` equals ``labels`` (0.0 if empty)."""
    if not labels:
        return 0.0
    return float(sum(p == l for p, l in zip(preds, labels))) / len(labels)


def _select_device():
    """Pick the best available backend: CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``.

    Returns:
        ``(mean_batch_loss, accuracy)`` for the epoch; each is 0.0 when the
        loader yields nothing.
    """
    model.train()
    running_loss, correct, total = 0.0, 0, 0
    for images, labels in loader:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += labels.size(0)

    num_batches = len(loader)
    avg_loss = running_loss / num_batches if num_batches else 0.0
    accuracy = correct / total if total else 0.0
    return avg_loss, accuracy


def train(epochs=15, batch_size=32, lr=1e-4,
          save_path="saved_models/generator_model.pth", early_stop_patience=3):
    """Fine-tune the generator classifier, checkpoint the best epoch, then test it.

    Args:
        epochs: Maximum number of training epochs.
        batch_size: Batch size passed to ``get_generator_dataloaders``.
        lr: Initial Adam learning rate for the trainable parameters.
        save_path: File the best (highest validation accuracy) weights are
            written to; its parent directory is created if missing.
        early_stop_patience: Stop after this many consecutive epochs without a
            validation-accuracy improvement.
    """
    device = _select_device()
    print(f"Using device: {device}")

    train_loader, val_loader, test_loader = get_generator_dataloaders(batch_size=batch_size)

    # Size the head from CLASS_NAMES so it always agrees with the label set
    # used for the classification report below.
    model = build_multiclass_model(num_classes=len(CLASS_NAMES)).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = Adam((p for p in model.parameters() if p.requires_grad), lr=lr)
    # ReduceLROnPlateau defaults to mode="min", so it is fed the validation
    # error rate (1 - accuracy) rather than the accuracy itself.
    scheduler = ReduceLROnPlateau(optimizer, patience=2)

    save_path = Path(save_path)
    save_path.parent.mkdir(parents=True, exist_ok=True)

    # -inf guarantees the first epoch is always checkpointed.
    best_val_acc = float("-inf")
    best_state = None
    no_improve_count = 0

    for epoch in range(epochs):
        avg_train_loss, train_acc = _train_one_epoch(
            model, train_loader, criterion, optimizer, device
        )

        val_preds, val_labels = evaluate(model, val_loader, device)
        val_acc = _accuracy(val_preds, val_labels)
        scheduler.step(1 - val_acc)

        print(f"Epoch {epoch+1}/{epochs} | "
              f"Train Loss: {avg_train_loss:.4f} | Train Acc: {train_acc:.4f} | "
              f"Val Acc: {val_acc:.4f}")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            no_improve_count = 0
            # Keep a CPU copy in memory (to restore before testing) and on
            # disk; CPU tensors load on any device without map_location.
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            torch.save(best_state, save_path)
            print(" -> Best model saved")
        else:
            no_improve_count += 1
            if no_improve_count >= early_stop_patience:
                print(f"Early stopping at epoch {epoch+1}")
                break

    # Test the best checkpoint, not the weights left by the final
    # (possibly non-improving) epoch.
    if best_state is not None:
        model.load_state_dict(best_state)

    print("\n--- Final Evaluation ---")
    test_preds, test_labels = evaluate(model, test_loader, device)
    test_acc = _accuracy(test_preds, test_labels)
    print(f"Test Accuracy: {test_acc:.4f}")

    print("\nClassification Report:")
    # NOTE(review): assumes CLASS_NAMES values are ordered by class index — confirm.
    print(classification_report(test_labels, test_preds, target_names=list(CLASS_NAMES.values())))
    print("Confusion Matrix:")
    print(confusion_matrix(test_labels, test_preds))


if __name__ == "__main__":
    train()