#!/usr/bin/env python3
"""
Batch processing script for diabetic retinopathy detection.

Processes multiple OCT images and saves results in a structured format:
Grad-CAM overlay PNGs plus a CSV report in the ``batch_results`` directory.
"""

import csv
import datetime
import os
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from torchvision import models, transforms

# Default location of the trained ResNet-50 weights (a state_dict).
DEFAULT_MODEL_PATH = "resnet50_dr_classifier.pth"
# Image suffixes scanned by process_directory (matched case-insensitively).
DEFAULT_EXTENSIONS = (".jpg", ".jpeg", ".png", ".tiff", ".bmp")


class BatchDRDetector:
    """Run a 2-class ResNet-50 DR classifier with Grad-CAM over many images.

    Grad-CAM overlays and CSV reports are written to ``self.output_dir``.
    """

    # Model output index -> label. Index 0 is "DR", as in the original label
    # mapping (presumably the training-time class order; confirm against the
    # training code if the model is retrained).
    CLASS_LABELS = ("DR", "NoDR")
    DR_CLASS_INDEX = 0

    def __init__(self, model_path=DEFAULT_MODEL_PATH):
        """Load the model, set up Grad-CAM and preprocessing.

        Args:
            model_path: Path to the saved ResNet-50 state_dict.

        Raises:
            Exception: whatever model loading raises (re-raised after a
                message is printed).
        """
        self.device = torch.device("cpu")
        self.model_path = model_path
        self.model = None
        self.cam = None
        self.transform = None
        self.output_dir = "batch_results"

        # Create output directory up front so later saves cannot fail on it.
        os.makedirs(self.output_dir, exist_ok=True)

        self._load_model()
        self._setup_gradcam()
        self._setup_transforms()

    def _load_model(self):
        """Build a 2-output ResNet-50 and load the trained weights."""
        print("🔄 Loading model...")
        try:
            self.model = models.resnet50(weights=None)
            # Replace the ImageNet head with a 2-class head to match the
            # saved state_dict.
            self.model.fc = torch.nn.Linear(self.model.fc.in_features, 2)
            # NOTE(review): torch.load unpickles the file; only load model
            # files from a trusted source.
            state_dict = torch.load(self.model_path, map_location=self.device)
            self.model.load_state_dict(state_dict)
            self.model.to(self.device)
            self.model.eval()
            print("✅ Model loaded successfully!")
        except Exception as e:
            print(f"❌ Error loading model: {e}")
            raise

    def _setup_gradcam(self):
        """Attach Grad-CAM to the last bottleneck block of ``layer4``."""
        target_layer = self.model.layer4[-1]
        self.cam = GradCAM(model=self.model, target_layers=[target_layer])
        print("✅ Grad-CAM setup complete!")

    def _setup_transforms(self):
        """Resize to 224x224, convert to tensor, normalize with ImageNet stats."""
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    def process_single_image(self, image_path):
        """Classify one image and build its Grad-CAM overlay.

        Args:
            image_path: Path (str or path-like) of the image file.

        Returns:
            dict with keys ``image_path``, ``prediction`` ("DR", "NoDR" or
            "ERROR"), ``confidence`` (softmax probability of the predicted
            class), ``dr_probability`` (softmax probability of the DR class),
            ``cam_image`` (RGB overlay array, or None on error) and
            ``status`` ("success" or "error: <message>"). Failures are
            captured per image instead of raised so one bad file does not
            abort a whole batch.
        """
        try:
            # The with-block closes the file handle; convert() returns a new,
            # fully loaded image that stays valid after the file is closed.
            with Image.open(image_path) as raw:
                img = raw.convert("RGB")
            img_tensor = self.transform(img).unsqueeze(0).to(self.device)

            # Prediction (no gradients needed here).
            with torch.no_grad():
                probs = F.softmax(self.model(img_tensor), dim=1)[0]
            pred = int(torch.argmax(probs).item())
            confidence = probs[pred].item()
            dr_probability = probs[self.DR_CLASS_INDEX].item()

            # Grad-CAM: overlay is drawn on the resized image scaled to [0, 1].
            rgb_img_np = np.ascontiguousarray(
                np.array(img.resize((224, 224))).astype(np.float32) / 255.0
            )
            grayscale_cam = self.cam(
                input_tensor=img_tensor,
                targets=[ClassifierOutputTarget(pred)],
            )[0]
            cam_image = show_cam_on_image(rgb_img_np, grayscale_cam, use_rgb=True)

            return {
                'image_path': image_path,
                'prediction': self.CLASS_LABELS[pred],
                'confidence': confidence,
                'dr_probability': dr_probability,
                'cam_image': cam_image,
                'status': 'success',
            }
        except Exception as e:  # deliberate per-image error boundary
            return {
                'image_path': image_path,
                'prediction': 'ERROR',
                'confidence': 0.0,
                'dr_probability': 0.0,
                'cam_image': None,
                'status': f'error: {str(e)}',
            }

    def process_directory(self, input_dir, extensions=DEFAULT_EXTENSIONS):
        """Process every image file directly inside ``input_dir``.

        Args:
            input_dir: Directory to scan (not recursive).
            extensions: Iterable of file suffixes to include, e.g. ".png".
                Matching is case-insensitive; a missing leading dot is added.

        Returns:
            List of result dicts from ``process_single_image``, in sorted
            path order. Successful results also get ``cam_saved_path``.
            Returns an empty list if no matching files are found.
        """
        print(f"🔍 Scanning directory: {input_dir}")

        # Normalize suffixes once and match case-insensitively in a single
        # pass. Globbing "*.jpg" and "*.JPG" separately lists each file twice
        # on case-insensitive filesystems and misses mixed-case suffixes.
        wanted = {
            ext.lower() if ext.startswith(".") else f".{ext.lower()}"
            for ext in extensions
        }
        directory = Path(input_dir)
        image_files = []
        if directory.is_dir():
            image_files = sorted(
                p for p in directory.iterdir()
                if p.is_file() and p.suffix.lower() in wanted
            )

        if not image_files:
            print("❌ No image files found in the directory!")
            return []

        total = len(image_files)
        print(f"📁 Found {total} image files")

        results = []
        for i, image_path in enumerate(image_files, 1):
            print(f"🔄 Processing {i}/{total}: {image_path.name}")
            result = self.process_single_image(str(image_path))
            results.append(result)

            # Save the Grad-CAM overlay for successful predictions only.
            if result['status'] == 'success' and result['cam_image'] is not None:
                cam_filename = (
                    f"cam_{image_path.stem}_{result['prediction']}_"
                    f"{result['confidence']:.3f}.png"
                )
                cam_path = os.path.join(self.output_dir, cam_filename)
                Image.fromarray(result['cam_image']).save(cam_path)
                result['cam_saved_path'] = cam_path

        return results

    def save_results_csv(self, results, filename=None):
        """Write ``results`` to a CSV file inside ``self.output_dir``.

        Args:
            results: Result dicts as returned by ``process_directory``.
            filename: CSV file name; defaults to a timestamped name.

        Returns:
            Path of the written CSV file.
        """
        if not filename:
            timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
            filename = f"dr_results_{timestamp}.csv"

        csv_path = os.path.join(self.output_dir, filename)
        fieldnames = ['image_path', 'prediction', 'confidence',
                      'dr_probability', 'status', 'cam_saved_path']

        with open(csv_path, 'w', newline='', encoding='utf-8') as csvfile:
            # extrasaction="ignore" drops non-CSV keys such as cam_image;
            # missing keys (cam_saved_path on errors) are written empty.
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames,
                                    extrasaction='ignore')
            writer.writeheader()
            writer.writerows(results)

        print(f"📊 Results saved to: {csv_path}")
        return csv_path

    def generate_summary(self, results):
        """Aggregate counts, DR percentage and mean scores over ``results``.

        Averages and the DR percentage cover successful results only and
        are 0 when nothing succeeded.
        """
        successful = [r for r in results if r['status'] == 'success']
        n_ok = len(successful)
        dr_count = sum(r['prediction'] == 'DR' for r in successful)
        nodr_count = sum(r['prediction'] == 'NoDR' for r in successful)

        return {
            'total_images': len(results),
            'successful': n_ok,
            'errors': len(results) - n_ok,
            'dr_detected': dr_count,
            'no_dr_detected': nodr_count,
            'dr_percentage': (dr_count / n_ok) * 100 if n_ok else 0,
            'average_confidence': (
                float(np.mean([r['confidence'] for r in successful])) if n_ok else 0
            ),
            'average_dr_probability': (
                float(np.mean([r['dr_probability'] for r in successful])) if n_ok else 0
            ),
        }


def main():
    """Main function for batch processing (interactive CLI entry point)."""
    print("🚀 Diabetic Retinopathy Detection - Batch Processing")
    print("=" * 60)

    # Check if model exists
    if not os.path.exists(DEFAULT_MODEL_PATH):
        print(f"❌ Model file '{DEFAULT_MODEL_PATH}' not found!")
        print(" Please ensure the model file is in the current directory.")
        return

    # Initialize detector
    try:
        detector = BatchDRDetector()
    except Exception as e:
        print(f"❌ Failed to initialize detector: {e}")
        return

    # Get input directory from user
    print("\n📁 Enter the path to the directory containing OCT images:")
    print(" (or press Enter to use current directory)")
    user_input = input("Directory path: ").strip()
    input_dir = user_input or os.getcwd()

    # isdir (not exists): a plain file path is not a usable input directory.
    if not os.path.isdir(input_dir):
        print(f"❌ Directory not found: {input_dir}")
        return

    print(f"\n🎯 Processing images from: {input_dir}")

    # Process images
    results = detector.process_directory(input_dir)
    if not results:
        print("❌ No results to process!")
        return

    # Save results
    csv_path = detector.save_results_csv(results)

    # Generate and display summary
    summary = detector.generate_summary(results)

    print("\n📊 Batch Processing Summary")
    print("=" * 40)
    print(f"Total images: {summary['total_images']}")
    print(f"Successfully processed: {summary['successful']}")
    print(f"Errors: {summary['errors']}")

    if summary['successful'] > 0:
        print(f"DR detected: {summary['dr_detected']} ({summary['dr_percentage']:.1f}%)")
        print(f"No DR detected: {summary['no_dr_detected']}")
        print(f"Average confidence: {summary['average_confidence']:.3f}")
        print(f"Average DR probability: {summary['average_dr_probability']:.3f}")

    print(f"\n📁 Results saved to: {detector.output_dir}/")
    print(f"📊 CSV report: {csv_path}")


if __name__ == "__main__":
    main()