| import torch
|
| from torch.utils.data import DataLoader
|
| from torchvision import transforms
|
| from models.resnet import resnet18
|
| from models.openmax import OpenMax
|
| from models.metamax import MetaMax
|
| from train import GameDataset
|
| from utils.data_stats import load_dataset_stats
|
| from utils.eval_utils import evaluate_known_classes, evaluate_openmax, evaluate_metamax
|
| import os
|
| from pprint import pprint
|
|
|
def test_models():
    """Evaluate the trained ResNet-18 and its OpenMax companion on the eval set.

    Builds the evaluation data loader, restores the ResNet-18 checkpoint and
    the serialized OpenMax object, then prints the known-class accuracy (with
    any reported errors) followed by the OpenMax evaluation output.

    Returns:
        None. All results are printed. If the OpenMax file cannot be loaded,
        the error is printed and the function returns before evaluating.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Normalise inputs with the statistics returned by load_dataset_stats().
    mean, std = load_dataset_stats()
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

    # NOTE(review): the dataset is created with 21 labels while the network
    # below has 20 outputs; presumably the extra label marks unknown samples
    # -- confirm against GameDataset.
    test_dataset = GameDataset('jk_zfls/round0_eval', num_labels=21, transform=transform)
    test_loader = DataLoader(
        test_dataset,
        batch_size=400,
        shuffle=False,
        num_workers=4,
        # Pinned host memory only speeds up host-to-GPU copies, so skip it
        # when running on the CPU fallback.
        pin_memory=device.type == "cuda",
    )

    model = resnet18(num_classes=20)
    # map_location lets a checkpoint saved from a CUDA device be restored on a
    # CPU-only machine, matching the device fallback chosen above.
    checkpoint = torch.load('models/best_model_99.92_02.pth', map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    model.eval()

    # The loaded object is handed to evaluate_openmax() as-is, so this file is
    # presumably a fully pickled OpenMax instance rather than a state dict.
    # NOTE(review): PyTorch releases where torch.load defaults to
    # weights_only=True may refuse such a file unless weights_only=False is
    # passed -- only do that for files you trust.
    try:
        openmax = torch.load('models/best_openmax_94.71_01.pth', map_location=device)
    except Exception as e:
        print(f"Error loading models: {e}")
        return
    else:
        # Only the OpenMax file is loaded here, so report exactly that.
        print("Successfully loaded OpenMax model")

    print("\n=== Testing ResNet (Known Classes Only) ===")
    _, accuracy, errors = evaluate_known_classes(
        model, test_loader, torch.nn.CrossEntropyLoss(), device
    )
    print(f"Known Classes Accuracy: {accuracy:.2f}%")
    if errors:
        print("\nErrors in known classes:")
        pprint(errors)

    print("\n=== Testing ResNet + OpenMax ===")
    evaluate_openmax(
        openmax, model, test_loader, device,
        multiplier=0.5, fraction=0.2, verbose=True,
    )
|
|
|
|
|
|
|
|
|
|
|
# Script entry point: run the evaluation only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    test_models()
|
|
|