Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| from torch.optim import Adam | |
| from torch.optim.lr_scheduler import ReduceLROnPlateau | |
| from torchvision import models | |
| from sklearn.metrics import classification_report, confusion_matrix | |
| from src.data.generator_loader import get_generator_dataloaders, CLASS_NAMES | |
| def build_multiclass_model(num_classes=4, pretrained=True): | |
| model = models.resnet18(weights="IMAGENET1K_V1" if pretrained else None) | |
| for param in model.parameters(): | |
| param.requires_grad = False | |
| # Unfreeze layer4 | |
| for name, param in model.named_parameters(): | |
| if "layer4" in name or "fc" in name: | |
| param.requires_grad = True | |
| in_features = model.fc.in_features | |
| model.fc = nn.Sequential( | |
| nn.Dropout(0.5), | |
| nn.Linear(in_features, 256), | |
| nn.ReLU(), | |
| nn.Dropout(0.3), | |
| nn.Linear(256, num_classes) | |
| ) | |
| return model | |
| def evaluate(model, loader, device): | |
| model.eval() | |
| all_preds, all_labels = [], [] | |
| with torch.no_grad(): | |
| for images, labels in loader: | |
| images = images.to(device) | |
| labels = labels.to(device) | |
| outputs = model(images) | |
| preds = outputs.argmax(dim=1) | |
| all_preds.extend(preds.cpu().numpy()) | |
| all_labels.extend(labels.cpu().numpy()) | |
| return all_preds, all_labels | |
| def train(epochs=15, batch_size=32, lr=1e-4): | |
| device = torch.device("mps" if torch.backends.mps.is_available() else "cpu") | |
| print(f"Using device: {device}") | |
| train_loader, val_loader, test_loader = get_generator_dataloaders(batch_size=batch_size) | |
| model = build_multiclass_model().to(device) | |
| criterion = nn.CrossEntropyLoss() | |
| optimizer = Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr) | |
| scheduler = ReduceLROnPlateau(optimizer, patience=2) | |
| best_val_acc = 0 | |
| early_stop_patience = 3 | |
| no_improve_count = 0 | |
| for epoch in range(epochs): | |
| model.train() | |
| train_loss, correct, total = 0, 0, 0 | |
| for images, labels in train_loader: | |
| images = images.to(device) | |
| labels = labels.to(device) | |
| optimizer.zero_grad() | |
| outputs = model(images) | |
| loss = criterion(outputs, labels) | |
| loss.backward() | |
| optimizer.step() | |
| train_loss += loss.item() | |
| preds = outputs.argmax(dim=1) | |
| correct += (preds == labels).sum().item() | |
| total += labels.size(0) | |
| train_acc = correct / total | |
| avg_train_loss = train_loss / len(train_loader) | |
| val_preds, val_labels = evaluate(model, val_loader, device) | |
| val_acc = sum(p == l for p, l in zip(val_preds, val_labels)) / len(val_labels) | |
| scheduler.step(1 - val_acc) | |
| print(f"Epoch {epoch+1}/{epochs} | " | |
| f"Train Loss: {avg_train_loss:.4f} | Train Acc: {train_acc:.4f} | " | |
| f"Val Acc: {val_acc:.4f}") | |
| if val_acc > best_val_acc: | |
| best_val_acc = val_acc | |
| no_improve_count = 0 | |
| torch.save(model.state_dict(), "saved_models/generator_model.pth") | |
| print(f" -> Best model saved") | |
| else: | |
| no_improve_count += 1 | |
| if no_improve_count >= early_stop_patience: | |
| print(f"Early stopping at epoch {epoch+1}") | |
| break | |
| # Final evaluation | |
| print("\n--- Final Evaluation ---") | |
| test_preds, test_labels = evaluate(model, test_loader, device) | |
| test_acc = sum(p == l for p, l in zip(test_preds, test_labels)) / len(test_labels) | |
| print(f"Test Accuracy: {test_acc:.4f}") | |
| print("\nClassification Report:") | |
| print(classification_report(test_labels, test_preds, | |
| target_names=list(CLASS_NAMES.values()))) | |
| print("Confusion Matrix:") | |
| print(confusion_matrix(test_labels, test_preds)) | |
| if __name__ == "__main__": | |
| train() |