NagashreePai commited on
Commit
40f2a01
Β·
verified Β·
1 Parent(s): e88ab31

Upload weed_test.py

Browse files
Files changed (1) hide show
  1. weed_test.py +105 -0
weed_test.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ from torchvision import transforms
5
+ from torchvision.datasets import ImageFolder
6
+ from torch.utils.data import DataLoader
7
+ from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, classification_report
8
+ import seaborn as sns
9
+ import matplotlib.pyplot as plt
10
+ from tqdm import tqdm
11
+ import numpy as np
12
+ from PIL import Image
13
+ from torchvision.models import swin_t
14
+ import matplotlib
15
+ matplotlib.use("Agg") # Use non-interactive backend
16
+
17
+
18
+ # βœ… MMIM model definition (must match training script)
19
+ class MMIM(nn.Module):
20
+ def __init__(self, num_classes=9):
21
+ super(MMIM, self).__init__()
22
+ self.backbone = swin_t(weights='IMAGENET1K_V1')
23
+ self.backbone.head = nn.Identity()
24
+ self.classifier = nn.Sequential(
25
+ nn.Linear(768, 512),
26
+ nn.ReLU(),
27
+ nn.Dropout(0.3),
28
+ nn.Linear(512, num_classes)
29
+ )
30
+
31
+ def forward(self, x):
32
+ features = self.backbone(x)
33
+ return self.classifier(features)
34
+
35
+ # βœ… Config
36
+ model_path = 'MMIM_best3.pth' # or full path like '/home/student/Desktop/wt/MMIM_best.pth'
37
+ test_dir = 'test' # or full path if needed
38
+ batch_size = 32
39
+
40
+ # βœ… Transforms (same as training)
41
+ transform = transforms.Compose([
42
+ transforms.Resize((224, 224)),
43
+ transforms.ToTensor()
44
+ ])
45
+
46
+ # βœ… Load test dataset
47
+ test_dataset = ImageFolder(test_dir, transform=transform)
48
+ test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
49
+ class_names = test_dataset.classes
50
+
51
+ # βœ… Load trained model
52
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
53
+ model = MMIM(num_classes=len(class_names)).to(device)
54
+ model.load_state_dict(torch.load(model_path, map_location=device))
55
+ model.eval()
56
+
57
+ # βœ… Evaluate on test set
58
+ all_preds = []
59
+ all_labels = []
60
+
61
+ with torch.no_grad():
62
+ for images, labels in tqdm(test_loader, desc="πŸ” Evaluating"):
63
+ images, labels = images.to(device), labels.to(device)
64
+ outputs = model(images)
65
+ _, preds = torch.max(outputs, 1)
66
+ all_preds.extend(preds.cpu().numpy())
67
+ all_labels.extend(labels.cpu().numpy())
68
+
69
+ # βœ… Metrics
70
+ acc = accuracy_score(all_labels, all_preds)
71
+ f1 = f1_score(all_labels, all_preds, average='weighted')
72
+ cm = confusion_matrix(all_labels, all_preds)
73
+
74
+ print(f"\nβœ… Accuracy: {acc:.4f}")
75
+ print(f"🎯 F1 Score (weighted): {f1:.4f}")
76
+ print("\nπŸ“ Classification Report:\n")
77
+ print(classification_report(all_labels, all_preds, target_names=class_names))
78
+
79
+ # βœ… Plot confusion matrix
80
+ plt.figure(figsize=(10, 8))
81
+ sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
82
+ xticklabels=class_names,
83
+ yticklabels=class_names)
84
+ plt.xlabel("Predicted")
85
+ plt.ylabel("True")
86
+ plt.title("Confusion Matrix")
87
+ plt.tight_layout()
88
+ plt.savefig("confusion_matrix.png")
89
+ print("βœ… Confusion matrix saved as confusion_matrix.png")
90
+
91
+
92
+ # βœ… Predict a single image
93
+ def predict_image(image_path):
94
+ image = Image.open(image_path).convert('RGB')
95
+ image = transform(image).unsqueeze(0).to(device)
96
+ model.eval()
97
+ with torch.no_grad():
98
+ output = model(image)
99
+ _, predicted = torch.max(output, 1)
100
+ return class_names[predicted.item()]
101
+
102
+ # Example usage:
103
+ example_image = os.path.join(test_dir, class_names[0], os.listdir(os.path.join(test_dir, class_names[0]))[0])
104
+ print(f"\nπŸ–ΌοΈ Example image prediction: {example_image}")
105
+ print("πŸ‘‰ Predicted class:", predict_image(example_image))