Spaces:
Sleeping
Sleeping
import os

import torch
import torch.nn as nn
from sklearn.metrics import (
    accuracy_score,
    confusion_matrix,
    f1_score,
    precision_score,
    recall_score,
)
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau

from src.data.loader import get_cross_dataset_loaders
from src.models.model import build_model
def evaluate(model, loader, device, criterion, threshold=0.5):
    """Score a single-logit binary classifier on every batch of ``loader``.

    The model is put in eval mode and run with gradients disabled. Its
    outputs are turned into hard 0/1 predictions via
    ``sigmoid(outputs) >= threshold`` and compared with the labels using
    scikit-learn metrics.

    Args:
        model: Module mapping a batch of images to logits. Labels are reshaped
            to ``(batch, 1)`` before the loss, so the output is presumably
            ``(batch, 1)`` as well — TODO confirm against ``build_model``.
        loader: Iterable of ``(images, labels)`` batches with 0/1 labels.
        device: Device that images and labels are moved to.
        criterion: Loss called as ``criterion(outputs, labels)``. Assumed to
            average over the batch (the default for ``nn.BCEWithLogitsLoss``,
            which ``train`` uses).
        threshold: Probability at or above which a sample is predicted as
            class 1. Defaults to 0.5.

    Returns:
        Tuple ``(avg_loss, accuracy, precision, recall, f1, confusion_matrix)``.
        ``avg_loss`` is the per-sample mean loss over the whole loader.
        Precision/recall/F1 are 0.0 when undefined (e.g. no positive
        predictions). The confusion matrix is always 2x2, with rows (true)
        and columns (predicted) ordered ``[0, 1]``.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.eval()
    pred_batches, label_batches = [], []
    loss_sum = 0.0
    num_samples = 0
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            labels = labels.float().unsqueeze(1).to(device)
            outputs = model(images)
            batch_size = labels.size(0)
            # Re-weight the batch-mean loss by batch size so the final average
            # is per sample and a short last batch is not over-counted.
            loss_sum += criterion(outputs, labels).item() * batch_size
            num_samples += batch_size
            preds = torch.sigmoid(outputs) >= threshold
            # Keep integer 0/1 values so sklearn sees two discrete classes.
            pred_batches.append(preds.long().flatten().cpu())
            label_batches.append(labels.long().flatten().cpu())

    if num_samples == 0:
        raise ValueError("evaluate() got a loader that yielded no samples")

    y_pred = torch.cat(pred_batches).numpy()
    y_true = torch.cat(label_batches).numpy()

    avg_loss = loss_sum / num_samples
    acc = accuracy_score(y_true, y_pred)
    # zero_division=0 returns the same 0.0 sklearn falls back to, but without
    # emitting UndefinedMetricWarning when a class is never predicted.
    prec = precision_score(y_true, y_pred, zero_division=0)
    rec = recall_score(y_true, y_pred, zero_division=0)
    f1 = f1_score(y_true, y_pred, zero_division=0)
    # Pin the label set so the matrix stays 2x2 even if only one class occurs.
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    return avg_loss, acc, prec, rec, f1, cm
def _select_device():
    """Return the preferred available torch device: CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def _train_one_epoch(model, loader, device, criterion, optimizer, threshold=0.5):
    """Run one optimisation pass over ``loader``.

    Returns:
        Tuple ``(mean_loss, accuracy)``. Both are per sample; the loss assumes
        ``criterion`` averages over the batch.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.train()
    loss_sum, correct, seen = 0.0, 0, 0
    for images, labels in loader:
        images = images.to(device)
        labels = labels.float().unsqueeze(1).to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        loss_sum += loss.item() * batch_size
        # Detach so accuracy bookkeeping does not extend the autograd graph.
        preds = (torch.sigmoid(outputs.detach()) >= threshold).float()
        correct += (preds == labels).sum().item()
        seen += batch_size

    if seen == 0:
        raise ValueError("training loader yielded no samples")
    return loss_sum / seen, correct / seen


def train(epochs=10, batch_size=32, lr=1e-4, *,
          save_path="saved_models/cross_val_best.pth", early_stop_patience=3):
    """Fine-tune ``build_model()`` and report metrics on the held-out loader.

    Uses the loaders from ``get_cross_dataset_loaders``; the second one is
    reported as the "unseen generators" test set. Each epoch it trains, then
    evaluates on that set. It steps a ``ReduceLROnPlateau`` scheduler on the
    test loss, saves the weights whenever that loss improves, and stops early
    after ``early_stop_patience`` epochs without improvement. Before the final
    report, the best weights seen are restored, so the printed metrics match
    the checkpoint written to ``save_path``.

    Args:
        epochs: Maximum number of training epochs.
        batch_size: Forwarded to ``get_cross_dataset_loaders``.
        lr: Adam learning rate.
        save_path: Where the best ``state_dict`` is written; its parent
            directory is created if missing.
        early_stop_patience: Consecutive non-improving epochs before stopping.
    """
    device = _select_device()
    print(f"Using device: {device}")

    train_loader, test_loader = get_cross_dataset_loaders(batch_size=batch_size)
    model = build_model().to(device)

    # Mark parameters whose names contain "layer4" or "fc" as trainable
    # (presumably a ResNet's last stage and classifier head).
    # NOTE(review): this only ever sets requires_grad=True, so it relies on
    # build_model() having frozen the other layers — confirm. The substring
    # test also matches any parameter name that merely contains "fc".
    for name, param in model.named_parameters():
        if "layer4" in name or "fc" in name:
            param.requires_grad = True

    criterion = nn.BCEWithLogitsLoss()
    optimizer = Adam([p for p in model.parameters() if p.requires_grad], lr=lr)
    scheduler = ReduceLROnPlateau(optimizer, patience=2)

    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    best_loss = float("inf")
    best_state = None
    no_improve_count = 0
    for epoch in range(epochs):
        avg_train_loss, train_acc = _train_one_epoch(
            model, train_loader, device, criterion, optimizer
        )

        # Validate on unseen generators.
        # NOTE(review): the same test set drives checkpoint selection, the LR
        # schedule and early stopping, and is then used for the final report,
        # so the final metrics are optimistically biased. A separate
        # validation split would avoid this.
        val_loss, acc, _, _, f1, _ = evaluate(model, test_loader, device, criterion)
        scheduler.step(val_loss)

        print(f"Epoch {epoch+1}/{epochs} | "
              f"Train Loss: {avg_train_loss:.4f} | Train Acc: {train_acc:.4f} | "
              f"Unseen Test Acc: {acc:.4f} | F1: {f1:.4f}")

        if val_loss < best_loss:
            best_loss = val_loss
            no_improve_count = 0
            # In-memory copy so the final evaluation can restore these weights.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            torch.save(model.state_dict(), save_path)
            print(" -> Best model saved")
        else:
            no_improve_count += 1
            if no_improve_count >= early_stop_patience:
                print(f"Early stopping at epoch {epoch+1}")
                break

    # Without this, the report below would describe the last-epoch weights,
    # which can be several epochs past the best checkpoint.
    if best_state is not None:
        model.load_state_dict(best_state)

    # Final evaluation
    print("\n--- Final Evaluation on Unseen Generators ---")
    _, acc, prec, rec, f1, cm = evaluate(model, test_loader, device, criterion)
    print(f"Accuracy: {acc:.4f}")
    print(f"Precision: {prec:.4f}")
    print(f"Recall: {rec:.4f}")
    print(f"F1 Score: {f1:.4f}")
    print(f"Confusion Matrix:\n{cm}")
if __name__ == "__main__":
    # Script entry point: run training with its default hyperparameters.
    train()