NagashreePai commited on
Commit
0a8e777
Β·
verified Β·
1 Parent(s): 561c3a7

Upload weed_test.py

Browse files
Files changed (1) hide show
  1. weed_test.py +135 -0
weed_test.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ from torchvision import transforms
5
+ from torchvision.datasets import ImageFolder
6
+ from torch.utils.data import DataLoader
7
+ from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, classification_report, roc_curve, auc
8
+ from sklearn.preprocessing import label_binarize
9
+ import seaborn as sns
10
+ import matplotlib.pyplot as plt
11
+ from tqdm import tqdm
12
+ import numpy as np
13
+ from PIL import Image
14
+ from torchvision.models import swin_t
15
+ import matplotlib
16
+ matplotlib.use("Agg") # Use non-interactive backend
17
+
18
+ # βœ… MMIM model definition (must match training script)
19
+ class MMIM(nn.Module):
20
+ def __init__(self, num_classes=9):
21
+ super(MMIM, self).__init__()
22
+ self.backbone = swin_t(weights='IMAGENET1K_V1')
23
+ self.backbone.head = nn.Identity()
24
+ self.classifier = nn.Sequential(
25
+ nn.Linear(768, 512),
26
+ nn.ReLU(),
27
+ nn.Dropout(0.3),
28
+ nn.Linear(512, num_classes)
29
+ )
30
+
31
+ def forward(self, x):
32
+ features = self.backbone(x)
33
+ return self.classifier(features)
34
+
35
+ # βœ… Config
36
+ model_path = 'MMIM_best.pth'
37
+ test_dir = 'test'
38
+ batch_size = 32
39
+
40
+ # βœ… Transforms
41
+ transform = transforms.Compose([
42
+ transforms.Resize((224, 224)),
43
+ transforms.ToTensor()
44
+ ])
45
+
46
+ # βœ… Load test dataset
47
+ test_dataset = ImageFolder(test_dir, transform=transform)
48
+ test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
49
+ class_names = test_dataset.classes
50
+
51
+ # βœ… Load model
52
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
53
+ model = MMIM(num_classes=len(class_names)).to(device)
54
+ model.load_state_dict(torch.load(model_path, map_location=device))
55
+ model.eval()
56
+
57
+ # βœ… Evaluate on test set
58
+ all_preds = []
59
+ all_labels = []
60
+ all_probs = []
61
+
62
+ with torch.no_grad():
63
+ for images, labels in tqdm(test_loader, desc="πŸ” Evaluating"):
64
+ images, labels = images.to(device), labels.to(device)
65
+ outputs = model(images)
66
+ probs = torch.nn.functional.softmax(outputs, dim=1)
67
+ _, preds = torch.max(probs, 1)
68
+
69
+ all_probs.extend(probs.cpu().numpy())
70
+ all_preds.extend(preds.cpu().numpy())
71
+ all_labels.extend(labels.cpu().numpy())
72
+
73
+ # βœ… Metrics
74
+ acc = accuracy_score(all_labels, all_preds)
75
+ f1 = f1_score(all_labels, all_preds, average='weighted')
76
+ cm = confusion_matrix(all_labels, all_preds)
77
+
78
+ print(f"\nβœ… Accuracy: {acc:.4f}")
79
+ print(f"🎯 F1 Score (weighted): {f1:.4f}")
80
+ print("\nπŸ“ Classification Report:\n")
81
+ print(classification_report(all_labels, all_preds, target_names=class_names))
82
+
83
+ # βœ… Plot confusion matrix
84
+ plt.figure(figsize=(10, 8))
85
+ sns.heatmap(cm, annot=True, fmt='d', cmap='Purples',
86
+ xticklabels=class_names,
87
+ yticklabels=class_names)
88
+ plt.xlabel("Predicted")
89
+ plt.ylabel("True")
90
+ plt.title("Confusion Matrix")
91
+ plt.tight_layout()
92
+ plt.savefig("confusion_matrix.png")
93
+ print("βœ… Confusion matrix saved as confusion_matrix.png")
94
+
95
+ # βœ… ROC Curve Plotting
96
+ y_true = label_binarize(all_labels, classes=list(range(len(class_names))))
97
+ all_probs = np.array(all_probs)
98
+
99
+ fpr = dict()
100
+ tpr = dict()
101
+ roc_auc = dict()
102
+
103
+ for i in range(len(class_names)):
104
+ fpr[i], tpr[i], _ = roc_curve(y_true[:, i], all_probs[:, i])
105
+ roc_auc[i] = auc(fpr[i], tpr[i])
106
+
107
+ plt.figure(figsize=(10, 8))
108
+ for i in range(len(class_names)):
109
+ plt.plot(fpr[i], tpr[i], lw=2, label=f'{class_names[i]} (AUC = {roc_auc[i]:.2f})')
110
+
111
+ plt.plot([0, 1], [0, 1], 'k--', lw=1)
112
+ plt.xlim([0.0, 1.0])
113
+ plt.ylim([0.0, 1.05])
114
+ plt.xlabel('False Positive Rate')
115
+ plt.ylabel('True Positive Rate')
116
+ plt.title('Multi-Class ROC Curve')
117
+ plt.legend(loc="lower right")
118
+ plt.tight_layout()
119
+ plt.savefig("roc_curve.png")
120
+ print("βœ… ROC curve saved as roc_curve.png")
121
+
122
+ # βœ… Predict a single image
123
+ def predict_image(image_path):
124
+ image = Image.open(image_path).convert('RGB')
125
+ image = transform(image).unsqueeze(0).to(device)
126
+ model.eval()
127
+ with torch.no_grad():
128
+ output = model(image)
129
+ _, predicted = torch.max(output, 1)
130
+ return class_names[predicted.item()]
131
+
132
+ # Example usage
133
+ example_image = os.path.join(test_dir, class_names[0], os.listdir(os.path.join(test_dir, class_names[0]))[0])
134
+ print(f"\nπŸ–ΌοΈ Example image prediction: {example_image}")
135
+ print("πŸ‘‰ Predicted class:", predict_image(example_image))