Spaces:
Sleeping
Sleeping
| import os | |
| import torch | |
| import torch.nn as nn | |
| from torchvision import transforms | |
| from torchvision.datasets import ImageFolder | |
| from torch.utils.data import DataLoader | |
| from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, classification_report | |
| import seaborn as sns | |
| import matplotlib.pyplot as plt | |
| from tqdm import tqdm | |
| import numpy as np | |
| from PIL import Image | |
| from torchvision.models import swin_t | |
| import matplotlib | |
| matplotlib.use("Agg") # Use non-interactive backend | |
# MMIM model definition (must match training script)
class MMIM(nn.Module):
    """Swin-T backbone followed by a small MLP classification head.

    The backbone's own ``head`` is replaced with ``nn.Identity`` so the
    backbone returns its features directly; ``classifier`` then maps those
    features to ``num_classes`` logits.

    Args:
        num_classes: Number of logits produced by the final linear layer.
        weights: Forwarded unchanged to ``swin_t``. The default keeps the
            original behaviour of starting from the ImageNet checkpoint.
            Pass ``None`` to build the backbone without fetching those
            weights, e.g. when a complete state dict is loaded right after
            construction.
    """

    def __init__(self, num_classes=9, *, weights='IMAGENET1K_V1'):
        super().__init__()
        self.backbone = swin_t(weights=weights)
        # Take the feature width from the head that is about to be removed
        # rather than hard-coding it, so the classifier always matches the
        # backbone's output size.
        feature_dim = self.backbone.head.in_features
        self.backbone.head = nn.Identity()
        self.classifier = nn.Sequential(
            nn.Linear(feature_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the classifier logits for the image batch ``x``."""
        features = self.backbone(x)
        return self.classifier(features)
# Evaluation settings
# model_path: checkpoint file; a full path such as
#   '/home/student/Desktop/wt/MMIM_best.pth' works as well.
# test_dir: root folder of the test images; may also be a full path.
model_path = 'MMIM_best.pth'
test_dir = 'test'
batch_size = 32

# Preprocessing applied to every image (noted by the author as identical to
# the training pipeline): resize to 224x224, then convert to a tensor.
_preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
]
transform = transforms.Compose(_preprocessing_steps)
# Test data: class names come from the dataset built over test_dir, and the
# loader keeps the original sample order (no shuffling).
test_dataset = ImageFolder(root=test_dir, transform=transform)
class_names = test_dataset.classes
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=False,
)

# Restore the trained network on the GPU when one is available, else on CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = MMIM(num_classes=len(class_names))
model = model.to(device)
# NOTE(review): torch.load unpickles the file - only load checkpoints that
# come from a trusted source.
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()
# Run the model over the whole test set, collecting predicted and true labels
all_preds = []
all_labels = []
with torch.no_grad():
    for images, labels in tqdm(test_loader, desc="π Evaluating"):
        # Only the inputs have to be on the model's device. The labels are
        # consumed on the host, so they are not sent to the device and back.
        outputs = model(images.to(device))
        # Index of the highest logit per sample is the predicted class.
        preds = outputs.argmax(dim=1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())
# Summary metrics computed from the collected labels and predictions
acc = accuracy_score(y_true=all_labels, y_pred=all_preds)
f1 = f1_score(y_true=all_labels, y_pred=all_preds, average='weighted')
cm = confusion_matrix(y_true=all_labels, y_pred=all_preds)

print(f"\nβ Accuracy: {acc:.4f}")
print(f"π― F1 Score (weighted): {f1:.4f}")
print("\nπ Classification Report:\n")
# Per-class precision / recall / F1, labelled with the dataset's class names
report = classification_report(
    y_true=all_labels,
    y_pred=all_preds,
    target_names=class_names,
)
print(report)
# Draw the confusion matrix as an annotated heatmap and write it to disk
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(
    cm,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=class_names,
    yticklabels=class_names,
    ax=ax,
)
# Columns are predicted classes, rows are true classes
ax.set_xlabel("Predicted")
ax.set_ylabel("True")
ax.set_title("Confusion Matrix")
fig.tight_layout()
fig.savefig("confusion_matrix.png")
print("β Confusion matrix saved as confusion_matrix.png")
# Predict a single image
def predict_image(image_path) -> str:
    """Classify one image file and return its predicted class name.

    Relies on the module-level ``transform``, ``model``, ``device`` and
    ``class_names``.

    Args:
        image_path: Location of the image; passed directly to
            ``PIL.Image.open``.

    Returns:
        The entry of ``class_names`` at the index of the highest logit.
    """
    # Open inside a context manager so the file handle is released
    # deterministically. convert() yields an independent in-memory image
    # that stays valid after the source file is closed.
    with Image.open(image_path) as source:
        rgb_image = source.convert('RGB')

    # Add the batch dimension the model expects and move to its device.
    batch = transform(rgb_image).unsqueeze(0).to(device)

    model.eval()
    with torch.no_grad():
        logits = model(batch)
    predicted = logits.argmax(dim=1)
    return class_names[predicted.item()]
# Example usage: classify the first sample of the test set.
# The path is taken from the dataset's own sample list instead of a raw
# os.listdir() entry: listdir returns names in arbitrary order and may yield
# a hidden/non-image file or a sub-directory, which Image.open cannot read.
example_image = test_dataset.samples[0][0]
print(f"\nπΌοΈ Example image prediction: {example_image}")
print("π Predicted class:", predict_image(example_image))