File size: 9,823 Bytes
957df8a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
#!/usr/bin/env python3
"""
Batch processing script for diabetic retinopathy detection.
Processes multiple OCT images and saves results in a structured format.
"""

import os
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
from torchvision import models, transforms
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
import csv
import datetime
from pathlib import Path

class BatchDRDetector:
    """Classify OCT images as DR / NoDR with a ResNet-50 and explain via Grad-CAM.

    Results are returned as plain dicts; Grad-CAM overlays and CSV reports are
    written under ``output_dir``.
    """

    # Suffixes scanned by ``process_directory`` when no explicit list is given.
    # Matching is case-insensitive.
    DEFAULT_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.tiff', '.bmp')

    # Output index -> label. Index 0 is "DR", as in the original mapping.
    # NOTE(review): assumes the checkpoint was trained with this class order
    # (e.g. ImageFolder's alphabetical class_to_idx) — confirm against training.
    LABELS = ("DR", "NoDR")

    def __init__(self, model_path="resnet50_dr_classifier.pth", *,
                 device=None, output_dir="batch_results"):
        """Load the trained model and prepare Grad-CAM and preprocessing.

        Args:
            model_path: Path to a ResNet-50 ``state_dict`` with a 2-way ``fc`` head.
            device: Torch device (or device string) to run on; defaults to CPU.
            output_dir: Directory for Grad-CAM overlays and CSV reports;
                created if it does not exist.

        Raises:
            Exception: Any error from loading the checkpoint is printed and
                re-raised by ``_load_model``.
        """
        self.device = torch.device(device) if device is not None else torch.device("cpu")
        self.model_path = model_path
        self.model = None
        self.cam = None
        self.transform = None
        self.output_dir = output_dir

        # Create the output directory up front so later saves cannot fail on it.
        os.makedirs(self.output_dir, exist_ok=True)

        self._load_model()
        self._setup_gradcam()
        self._setup_transforms()

    def _load_model(self):
        """Build ResNet-50 with a 2-class head and load weights from ``model_path``."""
        print("🔄 Loading model...")
        try:
            self.model = models.resnet50(weights=None)
            self.model.fc = torch.nn.Linear(self.model.fc.in_features, 2)
            # NOTE(review): torch.load unpickles the file; only load trusted
            # checkpoints (consider weights_only=True on torch >= 1.13).
            self.model.load_state_dict(torch.load(self.model_path, map_location=self.device))
            self.model.to(self.device)
            self.model.eval()
            print("✅ Model loaded successfully!")
        except Exception as e:
            print(f"❌ Error loading model: {e}")
            raise

    def _setup_gradcam(self):
        """Attach Grad-CAM to the last bottleneck block of ``layer4``."""
        target_layer = self.model.layer4[-1]
        self.cam = GradCAM(model=self.model, target_layers=[target_layer])
        print("✅ Grad-CAM setup complete!")

    def _setup_transforms(self):
        """Resize to 224x224, convert to tensor, normalise with ImageNet stats."""
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

    def process_single_image(self, image_path):
        """Classify one image and build its Grad-CAM overlay.

        Args:
            image_path: Path of the image file to process.

        Returns:
            dict with keys ``image_path``, ``prediction`` ("DR"/"NoDR", or
            "ERROR"), ``confidence`` (probability of the predicted class),
            ``dr_probability`` (probability of class index 0, "DR"),
            ``cam_image`` (overlay array or None) and ``status`` ("success" or
            "error: <message>"). Errors are reported in the dict, never raised,
            so one bad file does not stop a batch.
        """
        try:
            # The context manager closes the file handle promptly; convert()
            # decodes fully, so the RGB copy outlives the handle.
            with Image.open(image_path) as src:
                img = src.convert("RGB")
            img_tensor = self.transform(img).unsqueeze(0).to(self.device)

            with torch.no_grad():
                probs = F.softmax(self.model(img_tensor), dim=1)
            pred = int(torch.argmax(probs, dim=1).item())
            confidence = probs[0, pred].item()
            dr_probability = probs[0, 0].item()

            # Grad-CAM needs gradients, so it runs outside the no_grad block.
            rgb_img_np = np.ascontiguousarray(
                np.array(img.resize((224, 224))).astype(np.float32) / 255.0
            )
            grayscale_cam = self.cam(input_tensor=img_tensor,
                                     targets=[ClassifierOutputTarget(pred)])[0]
            cam_image = show_cam_on_image(rgb_img_np, grayscale_cam, use_rgb=True)

            return {
                'image_path': image_path,
                'prediction': self.LABELS[pred],
                'confidence': confidence,
                'dr_probability': dr_probability,
                'cam_image': cam_image,
                'status': 'success'
            }

        except Exception as e:
            return {
                'image_path': image_path,
                'prediction': 'ERROR',
                'confidence': 0.0,
                'dr_probability': 0.0,
                'cam_image': None,
                'status': f'error: {str(e)}'
            }

    def process_directory(self, input_dir, extensions=None):
        """Process every image file directly inside ``input_dir``.

        Args:
            input_dir: Directory to scan (not recursive).
            extensions: Iterable of filename suffixes, with or without the
                leading dot, matched case-insensitively. Defaults to
                ``DEFAULT_EXTENSIONS``.

        Returns:
            List of result dicts from ``process_single_image`` in filename
            order. Successful results also carry ``cam_saved_path`` once their
            overlay is written. Returns [] when ``input_dir`` is not a
            directory or contains no matching files.
        """
        print(f"🔍 Scanning directory: {input_dir}")

        root = Path(input_dir)
        if not root.is_dir():
            print(f"❌ Not a directory: {input_dir}")
            return []

        if extensions is None:
            extensions = self.DEFAULT_EXTENSIONS
        # Normalise once (lower-case, leading dot) and match on the lower-cased
        # name. A single directory pass cannot list a file twice: globbing both
        # "*.jpg" and "*.JPG" does on case-insensitive filesystems. It also
        # catches mixed-case names such as "scan.Jpg".
        suffixes = tuple(
            ext.lower() if ext.startswith('.') else f".{ext.lower()}"
            for ext in extensions
        )
        image_files = sorted(
            (p for p in root.iterdir()
             if p.is_file() and p.name.lower().endswith(suffixes)),
            key=lambda p: p.name,
        )

        if not image_files:
            print("❌ No image files found in the directory!")
            return []

        total = len(image_files)
        print(f"📁 Found {total} image files")

        results = []
        for i, image_path in enumerate(image_files, 1):
            print(f"🔄 Processing {i}/{total}: {image_path.name}")
            result = self.process_single_image(str(image_path))
            results.append(result)

            # Save the Grad-CAM overlay only for successful predictions.
            if result['status'] == 'success' and result['cam_image'] is not None:
                cam_filename = (f"cam_{image_path.stem}_{result['prediction']}_"
                                f"{result['confidence']:.3f}.png")
                cam_path = os.path.join(self.output_dir, cam_filename)
                try:
                    Image.fromarray(result['cam_image']).save(cam_path)
                except OSError as e:
                    # The prediction is still valid; don't abort the batch
                    # because one overlay could not be written.
                    print(f"⚠️ Could not save Grad-CAM for {image_path.name}: {e}")
                else:
                    result['cam_saved_path'] = cam_path

        return results

    def save_results_csv(self, results, filename=None):
        """Write ``results`` to a CSV in ``output_dir`` and return its path.

        Only the tabular fields are written; missing ones (e.g.
        ``cam_saved_path`` for failed images) are left empty. Without a
        ``filename``, a timestamped ``dr_results_YYYYmmdd_HHMMSS.csv`` is used.
        """
        if not filename:
            timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
            filename = f"dr_results_{timestamp}.csv"

        csv_path = os.path.join(self.output_dir, filename)

        with open(csv_path, 'w', newline='', encoding='utf-8') as csvfile:
            fieldnames = ['image_path', 'prediction', 'confidence', 'dr_probability', 'status', 'cam_saved_path']
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)

            writer.writeheader()
            for result in results:
                # Drop non-tabular entries such as the cam_image array.
                writer.writerow({k: v for k, v in result.items() if k in fieldnames})

        print(f"📊 Results saved to: {csv_path}")
        return csv_path

    def generate_summary(self, results):
        """Aggregate counts and averages over a list of result dicts.

        Anything whose ``status`` is not "success" counts as an error. Averages
        and the DR percentage are computed over successful results only and
        are 0 when there are none.
        """
        successful = [r for r in results if r['status'] == 'success']
        n_ok = len(successful)

        summary = {
            'total_images': len(results),
            'successful': n_ok,
            'errors': len(results) - n_ok,
            'dr_detected': 0,
            'no_dr_detected': 0,
            'dr_percentage': 0,
            'average_confidence': 0,
            'average_dr_probability': 0
        }
        if successful:
            dr_count = sum(r['prediction'] == 'DR' for r in successful)
            summary.update(
                dr_detected=dr_count,
                no_dr_detected=sum(r['prediction'] == 'NoDR' for r in successful),
                dr_percentage=(dr_count / n_ok) * 100,
                average_confidence=np.mean([r['confidence'] for r in successful]),
                average_dr_probability=np.mean([r['dr_probability'] for r in successful]),
            )

        return summary

def main():
    """Interactive entry point: load the model, process a directory, report.

    Prompts on stdin for an input directory (empty input means the current
    working directory). Grad-CAM overlays and a CSV report are written under
    the detector's output directory, and a summary is printed. Problems are
    reported on stdout and end the run early; the function returns None.
    """
    # Single source of truth for the checkpoint path. It is both checked and
    # passed to the detector, so the two cannot drift apart.
    model_path = "resnet50_dr_classifier.pth"

    print("🚀 Diabetic Retinopathy Detection - Batch Processing")
    print("=" * 60)

    # isfile, not exists: a directory with this name is not a usable model.
    if not os.path.isfile(model_path):
        print(f"❌ Model file '{model_path}' not found!")
        print("   Please ensure the model file is in the current directory.")
        return

    try:
        detector = BatchDRDetector(model_path)
    except Exception as e:
        print(f"❌ Failed to initialize detector: {e}")
        return

    print("\n📁 Enter the path to the directory containing OCT images:")
    print("   (or press Enter to use current directory)")

    input_dir = input("Directory path: ").strip() or os.getcwd()

    # isdir, not exists: an existing *file* path would otherwise pass and then
    # produce a misleading "no images found".
    if not os.path.isdir(input_dir):
        print(f"❌ Directory not found: {input_dir}")
        return

    print(f"\n🎯 Processing images from: {input_dir}")

    results = detector.process_directory(input_dir)

    if not results:
        print("❌ No results to process!")
        return

    csv_path = detector.save_results_csv(results)
    summary = detector.generate_summary(results)

    print("\n📊 Batch Processing Summary")
    print("=" * 40)
    print(f"Total images: {summary['total_images']}")
    print(f"Successfully processed: {summary['successful']}")
    print(f"Errors: {summary['errors']}")

    if summary['successful'] > 0:
        print(f"DR detected: {summary['dr_detected']} ({summary['dr_percentage']:.1f}%)")
        print(f"No DR detected: {summary['no_dr_detected']}")
        print(f"Average confidence: {summary['average_confidence']:.3f}")
        print(f"Average DR probability: {summary['average_dr_probability']:.3f}")

    print(f"\n📁 Results saved to: {detector.output_dir}/")
    print(f"📊 CSV report: {csv_path}")

# Run the interactive batch workflow only when executed as a script, not on import.
if __name__ == "__main__":
    main()