File size: 9,623 Bytes
c8df794
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
"""
Model Testing Script for Crop Disease Detection
Test the trained model with sample images
"""

import torch
import torch.nn.functional as F
from PIL import Image
import torchvision.transforms as transforms
import sys
import os
import json
from pathlib import Path

# Add src to path
sys.path.append('src')

def load_model():
    """Load the trained model"""
    try:
        from src.model import CropDiseaseResNet50
        
        # Class names (updated for V3 model: Pepper, Potato, Tomato)
        class_names = [
            'Pepper__bell___Bacterial_spot',
            'Pepper__bell___healthy',
            'Potato___Early_blight',
            'Potato___healthy',
            'Potato___Late_blight',
            'Tomato__Target_Spot',
            'Tomato__Tomato_mosaic_virus',
            'Tomato__Tomato_YellowLeaf__Curl_Virus',
            'Tomato_Bacterial_spot',
            'Tomato_Early_blight',
            'Tomato_healthy',
            'Tomato_Late_blight',
            'Tomato_Leaf_Mold',
            'Tomato_Septoria_leaf_spot',
            'Tomato_Spider_mites_Two_spotted_spider_mite'
        ]
        
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Using device: {device}")
        
        # Load the specified model
        model_path = 'models/crop_disease_v3_model.pth'
        
        model = None
        loaded_path = None
        
        if os.path.exists(model_path):
            try:
                print(f"Trying to load model from: {model_path}")
                
                # Create model
                model = CropDiseaseResNet50(num_classes=len(class_names), pretrained=False)
                
                # Load checkpoint
                checkpoint = torch.load(model_path, map_location=device)
                
                # Handle checkpoint format from crop_disease_v3_model.pth
                if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
                    state_dict = checkpoint['model_state_dict']
                    # Use class names from checkpoint if available
                    if 'class_names' in checkpoint:
                        class_names = checkpoint['class_names']
                        print(f"Loaded class names from checkpoint: {len(class_names)} classes")
                elif isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
                    state_dict = checkpoint['state_dict']
                else:
                    state_dict = checkpoint
                
                # Load state dict
                model.load_state_dict(state_dict, strict=True)
                model.to(device)
                model.eval()
                print(f"βœ… Model loaded successfully from {model_path}")
                loaded_path = model_path
                
            except Exception as e:
                print(f"❌ Failed to load {model_path}: {e}")
                model = None
        else:
            print(f"❌ Model file not found: {model_path}")
            model = None
        
        if model is None:
            print("⚠️ No trained model found. Creating new model with pretrained weights.")
            model = CropDiseaseResNet50(num_classes=len(class_names), pretrained=True)
            loaded_path = "pretrained_imagenet"
        
        model.to(device)
        model.eval()
        
        return model, class_names, device, loaded_path
        
    except Exception as e:
        print(f"❌ Error loading model: {e}")
        return None, None, None, None

def create_test_image():
    """Create a simple test image"""
    # Create a green leaf-like image
    img = Image.new('RGB', (224, 224), color=(34, 139, 34))  # Forest green
    
    # Add some texture (simple pattern)
    from PIL import ImageDraw
    draw = ImageDraw.Draw(img)
    
    # Add some leaf-like patterns
    for i in range(0, 224, 20):
        draw.line([(i, 0), (i, 224)], fill=(0, 100, 0), width=1)
    for i in range(0, 224, 20):
        draw.line([(0, i), (224, i)], fill=(0, 100, 0), width=1)
    
    return img

def test_single_prediction(model, class_names, device):
    """Test single image prediction"""
    print("\nπŸ” Testing single image prediction...")
    
    # Create test image
    test_image = create_test_image()
    test_image.save('test_leaf_sample.jpg')
    print("βœ… Test image created: test_leaf_sample.jpg")
    
    # Preprocessing
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    
    # Process image
    input_tensor = transform(test_image).unsqueeze(0).to(device)
    print(f"βœ… Input tensor shape: {input_tensor.shape}")
    
    # Make prediction
    with torch.no_grad():
        outputs = model(input_tensor)
        probabilities = F.softmax(outputs, dim=1)
        confidence, predicted_idx = torch.max(probabilities, 1)
        
        predicted_class = class_names[predicted_idx.item()]
        confidence_score = confidence.item()
    
    print(f"\nπŸ“Š Prediction Results:")
    print(f"   Predicted Class: {predicted_class}")
    print(f"   Confidence: {confidence_score:.2%}")
    
    # Show top 3 predictions
    top_probs, top_indices = torch.topk(probabilities[0], 3)
    print(f"\nπŸ† Top 3 Predictions:")
    for i, (prob, idx) in enumerate(zip(top_probs, top_indices)):
        print(f"   {i+1}. {class_names[idx]}: {prob.item():.2%}")
    
    return predicted_class, confidence_score

def test_model_components(model, device):
    """Test model components"""
    print("\nπŸ”§ Testing model components...")
    
    # Model info
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    
    print(f"βœ… Model Architecture: {model.__class__.__name__}")
    print(f"βœ… Total Parameters: {total_params:,}")
    print(f"βœ… Trainable Parameters: {trainable_params:,}")
    print(f"βœ… Device: {device}")
    
    # Test forward pass with random input
    try:
        dummy_input = torch.randn(1, 3, 224, 224).to(device)
        with torch.no_grad():
            output = model(dummy_input)
        print(f"βœ… Forward pass successful: Output shape {output.shape}")
        return True
    except Exception as e:
        print(f"❌ Forward pass failed: {e}")
        return False

def test_with_real_images():
    """Test with real images from dataset if available"""
    print("\nπŸ–ΌοΈ Testing with real dataset images...")
    
    # Look for test images in data folder
    test_dirs = [
        'data/test',
        'data/val',
        'data/train'
    ]
    
    test_images = []
    for test_dir in test_dirs:
        if os.path.exists(test_dir):
            for class_dir in os.listdir(test_dir):
                class_path = os.path.join(test_dir, class_dir)
                if os.path.isdir(class_path):
                    for img_file in os.listdir(class_path)[:2]:  # Take first 2 images
                        if img_file.lower().endswith(('.jpg', '.jpeg', '.png')):
                            test_images.append({
                                'path': os.path.join(class_path, img_file),
                                'true_class': class_dir
                            })
            break  # Use first available directory
    
    if test_images:
        print(f"βœ… Found {len(test_images)} test images")
        return test_images[:5]  # Return first 5
    else:
        print("⚠️ No real test images found in data folder")
        return []

def main():
    """Main testing function"""
    print("πŸ§ͺ CROP DISEASE MODEL TESTING")
    print("=" * 50)
    
    # Load model
    model, class_names, device, model_path = load_model()
    
    if model is None:
        print("❌ Failed to load model. Cannot proceed with testing.")
        return False
    
    print(f"βœ… Model loaded from: {model_path}")
    print(f"βœ… Classes: {len(class_names)}")
    
    # Test model components
    if not test_model_components(model, device):
        print("❌ Model component test failed.")
        return False
    
    # Test single prediction
    try:
        predicted_class, confidence = test_single_prediction(model, class_names, device)
        print("βœ… Single prediction test passed")
    except Exception as e:
        print(f"❌ Single prediction test failed: {e}")
        return False
    
    # Test with real images if available
    real_images = test_with_real_images()
    if real_images:
        print(f"\n🎯 Testing with {len(real_images)} real images...")
        # You can add real image testing here if needed
    
    # Test risk assessment
    try:
        from src.risk_level import RiskLevelCalculator
        risk_calc = RiskLevelCalculator()
        risk = risk_calc.calculate_base_risk(predicted_class, confidence)
        print(f"βœ… Risk assessment test passed: {risk}")
    except Exception as e:
        print(f"⚠️ Risk assessment test failed: {e}")
    
    print("\n" + "=" * 50)
    print("πŸŽ‰ MODEL TESTING COMPLETED!")
    print("βœ… Your model is working and ready for use")
    print("\nπŸš€ Next steps:")
    print("   1. Start API server: cd api && python main.py")
    print("   2. Test API endpoints at http://localhost:8000/docs")
    print("   3. Upload real crop images for testing")
    
    return True

if __name__ == "__main__":
    success = main()
    if not success:
        print("\n❌ Testing failed. Check the errors above.")
        sys.exit(1)