File size: 3,871 Bytes
957df8a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
#!/usr/bin/env python3
"""
Test script to verify the diabetic retinopathy detection model can be loaded correctly.
Run this before starting the main app to check for any issues.
"""

import torch
import torch.nn as nn
from torchvision import models
import os

def test_model_loading(model_path="resnet50_dr_classifier.pth"):
    """Check that the DR classifier checkpoint loads and runs one forward pass.

    Builds a torchvision ResNet-50 whose ``fc`` layer is replaced by a
    2-output linear layer, loads the state dict stored at *model_path* onto
    the CPU, and feeds the model one random 1x3x224x224 tensor. Progress and
    results are printed.

    Args:
        model_path: Path of the saved state dict. Defaults to
            ``resnet50_dr_classifier.pth`` relative to the working directory.

    Returns:
        True if the file exists, loads into the model, and inference succeeds;
        False otherwise (the reason is printed).
    """
    import inspect

    print("πŸ” Testing model loading...")

    try:
        if not os.path.exists(model_path):
            print(f"❌ Error: Model file '{model_path}' not found!")
            print("   Please ensure the model file is in the current directory.")
            return False

        print(f"βœ… Model file found: {model_path}")

        file_size = os.path.getsize(model_path) / (1024 * 1024)  # MB
        print(f"πŸ“ Model file size: {file_size:.2f} MB")

        device = torch.device("cpu")
        print("πŸ”„ Loading model...")

        model = models.resnet50(weights=None)
        model.fc = nn.Linear(model.fc.in_features, 2)

        # torch.load unpickles the file. weights_only=True restricts it to
        # tensors and plain containers, so a tampered checkpoint cannot run
        # arbitrary code. Older torch releases lack the keyword, so pass it
        # only when torch.load accepts it.
        load_kwargs = {"map_location": device}
        if "weights_only" in inspect.signature(torch.load).parameters:
            load_kwargs["weights_only"] = True
        state_dict = torch.load(model_path, **load_kwargs)
        model.load_state_dict(state_dict)
        model.to(device)
        model.eval()

        print("βœ… Model loaded successfully!")

        print("πŸ§ͺ Testing with dummy input...")
        dummy_input = torch.randn(1, 3, 224, 224, device=device)

        with torch.no_grad():
            output = model(dummy_input)
            probs = torch.nn.functional.softmax(output, dim=1)
            pred = torch.argmax(probs, dim=1).item()
            confidence = probs[0][pred].item()

        print("βœ… Model inference successful!")
        # Index 0 is reported as DR and 1 as NoDR; presumably this matches the
        # label order used in training -- confirm against the training code.
        print(f"   Prediction: {'DR' if pred == 0 else 'NoDR'}")
        print(f"   Confidence: {confidence:.4f}")

        return True

    except Exception as e:
        # Top-level boundary of a diagnostic: report and signal failure rather
        # than crash. Include the type because str(e) alone can be empty.
        print(f"❌ Error loading model: {type(e).__name__}: {e}")
        return False

def test_dependencies(requirements=None):
    """Check that the packages the app needs can be imported.

    Every requirement is checked, rather than stopping at the first failure,
    so a single run lists all missing packages.

    Args:
        requirements: Optional iterable of ``(module, attribute, label,
            pip_name)`` tuples. ``module`` is imported with
            ``importlib.import_module``. If ``attribute`` is not None, it must
            also exist on that module. ``label`` is the name shown to the user
            and ``pip_name`` is the package suggested for installation.
            Defaults to Gradio, Pillow, NumPy and PyTorch Grad-CAM.

    Returns:
        True if every requirement is importable, False if any is missing.
    """
    import importlib

    if requirements is None:
        requirements = (
            ("gradio", None, "Gradio", "gradio"),
            ("PIL", None, "PIL/Pillow", "pillow"),
            ("numpy", None, "NumPy", "numpy"),
            # The pytorch_grad_cam module is published on PyPI as "grad-cam".
            ("pytorch_grad_cam", "GradCAM", "PyTorch Grad-CAM", "grad-cam"),
        )

    print("\nπŸ” Testing dependencies...")

    all_ok = True
    for module_name, attr_name, label, pip_name in requirements:
        try:
            module = importlib.import_module(module_name)
        except ImportError:
            module = None

        # Mirrors `from module import attr`: the module must import and, when
        # an attribute is named, expose it.
        found = module is not None and (attr_name is None or hasattr(module, attr_name))
        if found:
            print(f"βœ… {label} imported successfully")
        else:
            print(f"❌ {label} not found. Install with: pip install {pip_name}")
            all_ok = False

    return all_ok

def main():
    """Run the dependency and model checks and print an overall verdict.

    Returns:
        0 if every check passed, 1 otherwise, so the result can be used as a
        process exit status (e.g. ``raise SystemExit(main())``).
    """
    print("πŸš€ Diabetic Retinopathy Detection - Model Test")
    print("=" * 50)

    # Skip the model check entirely if any app dependency is missing.
    if not test_dependencies():
        print("\n❌ Dependency test failed. Please install missing packages.")
        return 1

    model_ok = test_model_loading()

    print("\n" + "=" * 50)
    if not model_ok:
        print("❌ Model test failed. Please check the error messages above.")
        print("   Make sure the model file is correct and all dependencies are installed.")
        return 1

    print("πŸŽ‰ All tests passed! The app should work correctly.")
    print("   You can now run: python app.py")
    return 0

if __name__ == "__main__":
    # Use main()'s return value as the process exit code so shells/CI can see
    # a failed check; SystemExit(None) exits with status 0.
    raise SystemExit(main())