"""Streamlit app for diabetic retinopathy (DR) grading and lesion segmentation.

Pipeline:
1. A ResNet-152 classifier predicts a DR stage (one of ``CLASS_NAMES``), which
   is mapped to a UK screening grade label via ``UK_GRADES``.
2. When the prediction is not ``No_DR``, a U-Net segments bright and red
   lesions. The mask is shown as an overlay together with per-class area
   percentages.
3. The mask (grayscale PNG) and a PDF report can be downloaded.

Model weights are read from ``classifier.pt`` and ``best_unet_model.pth``,
relative to the current working directory.
"""
import io
from typing import Optional, Tuple

import cv2
import numpy as np
import streamlit as st
import torch
import torch.nn as nn
from fpdf import FPDF
from PIL import Image
from torchvision import models, transforms

# Device configuration: use the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ====================== CONSTANTS ======================
# Classifier output index -> stage name. The order is presumably the
# alphabetical label order used during training (e.g. torchvision
# ImageFolder) -- TODO confirm against the training code.
CLASS_NAMES = ["Mild", "Moderate", "No_DR", "Proliferate_DR", "Severe"]

# Segmentation class index -> RGB color used to visualize the mask.
LESION_COLORS = {
    0: [0, 0, 0],        # Background (black)
    1: [255, 255, 0],    # Bright lesions (yellow)
    2: [255, 0, 0],      # Red lesions (red)
}
BACKGROUND_CLASS = 0

# Multiplier that maps mask class indices (0, 1, 2) to visible gray levels
# (0, 85, 170) when the mask is exported as a grayscale image.
MASK_GRAY_SCALE = 85

# Stage name -> UK screening grade label shown to the user.
UK_GRADES = {
    "No_DR": "R0 - No retinopathy",
    "Mild": "R1 - Background DR",
    "Moderate": "R1 - Background DR",
    "Severe": "R2 - Pre-proliferative DR",
    "Proliferate_DR": "R3 - Proliferative DR",
}


# ====================== UNET ARCHITECTURE ======================
class UNet(nn.Module):
    """Three-level U-Net producing per-pixel class logits.

    Input height and width must be divisible by 8 (three 2x max-pools), so
    that upsampled decoder features line up with the encoder skip
    connections. The output has ``num_classes`` channels at the input
    resolution.

    Attribute names and layer order determine the ``state_dict`` keys that
    ``best_unet_model.pth`` must match; do not rename or reorder them.
    """

    def __init__(self, input_channels=3, num_classes=3):
        super().__init__()

        def conv_block(in_channels, out_channels):
            # Two 3x3 convs (padding=1 preserves spatial size). ReLU precedes
            # BatchNorm here; the module order fixes the state_dict keys.
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.BatchNorm2d(out_channels),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.BatchNorm2d(out_channels),
            )

        self.encoder1 = conv_block(input_channels, 32)
        self.pool1 = nn.MaxPool2d(2)
        self.encoder2 = conv_block(32, 64)
        self.pool2 = nn.MaxPool2d(2)
        self.encoder3 = conv_block(64, 128)
        self.pool3 = nn.MaxPool2d(2)

        self.bottleneck = conv_block(128, 256)

        # Each decoder block sees the upsampled features concatenated with
        # the matching encoder output, hence doubled input channels.
        self.up3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.decoder3 = conv_block(256, 128)
        self.up2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.decoder2 = conv_block(128, 64)
        self.up1 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.decoder1 = conv_block(64, 32)

        self.final_conv = nn.Conv2d(32, num_classes, kernel_size=1)

    def forward(self, x):
        enc1 = self.encoder1(x)
        x = self.pool1(enc1)
        enc2 = self.encoder2(x)
        x = self.pool2(enc2)
        enc3 = self.encoder3(x)
        x = self.pool3(enc3)

        x = self.bottleneck(x)

        x = self.up3(x)
        x = torch.cat([x, enc3], dim=1)
        x = self.decoder3(x)
        x = self.up2(x)
        x = torch.cat([x, enc2], dim=1)
        x = self.decoder2(x)
        x = self.up1(x)
        x = torch.cat([x, enc1], dim=1)
        x = self.decoder1(x)
        return self.final_conv(x)


# ====================== CLASSIFIER ======================
def create_classifier_model():
    """Build the untrained ResNet-152 DR-stage classifier.

    The stock ``fc`` layer is replaced by a two-layer head that outputs
    log-probabilities (``LogSoftmax``) over ``CLASS_NAMES``.
    """
    model = models.resnet152(weights=None)  # weights come from classifier.pt
    num_ftrs = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Linear(num_ftrs, 512),
        nn.ReLU(),
        nn.Linear(512, len(CLASS_NAMES)),
        nn.LogSoftmax(dim=1),
    )
    return model


@st.cache_resource
def load_classifier():
    """Load the classifier from ``classifier.pt`` (cached by Streamlit).

    The checkpoint is a dict whose ``'model_state_dict'`` entry holds the
    weights. ``torch.load`` unpickles its input: only load trusted files.
    """
    model = create_classifier_model().to(device)
    checkpoint = torch.load('classifier.pt', map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    return model


def preprocess_classifier(image: Image.Image) -> np.ndarray:
    """Apply CLAHE to the green channel and replicate it into 3 channels.

    Expects an RGB image (index 1 is taken as the green channel).
    """
    img_np = np.array(image)
    green_channel = img_np[:, :, 1]
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    return np.stack([clahe.apply(green_channel)] * 3, axis=-1)


def get_classifier_transform():
    """Resize to 224x224 and normalize each channel to roughly [-1, 1]."""
    return transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])


def classify_image(image: Image.Image, model: nn.Module) -> Tuple[int, np.ndarray]:
    """Predict the DR stage of an RGB fundus image.

    Returns:
        ``(pred_idx, probabilities)``: ``pred_idx`` indexes ``CLASS_NAMES``;
        ``probabilities`` holds each class's confidence in percent, in
        ``CLASS_NAMES`` order.
    """
    processed = preprocess_classifier(image)
    tensor = get_classifier_transform()(Image.fromarray(processed)).unsqueeze(0).to(device)
    with torch.no_grad():
        log_probs = model(tensor)
    probs = torch.exp(log_probs)  # model ends in LogSoftmax
    pred_idx = int(torch.argmax(probs[0]).item())
    return pred_idx, probs[0].cpu().numpy() * 100


# ====================== SEGMENTATION ======================
@st.cache_resource
def load_segmenter():
    """Load U-Net weights from ``best_unet_model.pth`` (cached by Streamlit)."""
    model = UNet().to(device)
    model.load_state_dict(torch.load('best_unet_model.pth', map_location=device))
    model.eval()
    return model


def preprocess_segmenter(image: Image.Image) -> np.ndarray:
    """Denoise (3x3 median) and apply CLAHE to the L channel in LAB space.

    Expects an RGB image and returns an RGB uint8 array.
    """
    img_np = np.array(image)
    img_filtered = cv2.medianBlur(img_np, 3)
    lab = cv2.cvtColor(img_filtered, cv2.COLOR_RGB2LAB)
    l_channel, a_channel, b_channel = cv2.split(lab)
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    lab_clahe = cv2.merge((clahe.apply(l_channel), a_channel, b_channel))
    return cv2.cvtColor(lab_clahe, cv2.COLOR_LAB2RGB)


def get_segmenter_transform():
    """Resize to 512x512 (divisible by 8, as UNet needs) with ImageNet normalization."""
    return transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])


def process_segmentation_output(output: torch.Tensor) -> Tuple[np.ndarray, np.ndarray]:
    """Turn U-Net logits for one image into a class mask and probabilities.

    Args:
        output: Logits of shape ``(1, C, H, W)``; only batch item 0 is used.

    Returns:
        ``(mask, probs)``: ``mask`` is an ``(H, W)`` uint8 array of class
        indices (0 = background, 1 = bright lesion, 2 = red lesion), and
        ``probs`` is the ``(C, H, W)`` softmax.
    """
    # Index the batch explicitly: squeeze() would also drop any other
    # singleton dimension (e.g. C == 1) and shift the argmax axis.
    probs = torch.softmax(output, dim=1)[0].cpu().numpy()
    mask = np.argmax(probs, axis=0).astype(np.uint8)
    return mask, probs


def mask_to_grayscale(mask: np.ndarray) -> np.ndarray:
    """Scale class indices to visible gray levels (0 -> 0, 1 -> 85, 2 -> 170)."""
    # Widen before multiplying so an unexpected large label cannot wrap around.
    return np.clip(mask.astype(np.uint16) * MASK_GRAY_SCALE, 0, 255).astype(np.uint8)


# ====================== VISUALIZATION ======================
def create_lesion_overlay(original: Image.Image, mask: np.ndarray,
                          alpha: float = 0.4) -> Image.Image:
    """Tint lesion pixels of ``original`` with their ``LESION_COLORS`` color.

    The mask is resized to the original resolution using nearest-neighbour
    interpolation, so labels stay discrete. Lesion pixels become
    ``alpha * color + (1 - alpha) * original``; background pixels are left
    unchanged.
    """
    original_np = np.array(original)
    height, width = original_np.shape[:2]
    mask_resized = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)

    overlay = original_np.copy()
    for class_idx, color in LESION_COLORS.items():
        # Painting the background black would darken the whole image once
        # blended, so only lesion classes are colored.
        if class_idx == BACKGROUND_CLASS:
            continue
        overlay[mask_resized == class_idx] = color
    return Image.fromarray(cv2.addWeighted(overlay, alpha, original_np, 1 - alpha, 0))


def segment_image(image: Image.Image, model: nn.Module) -> dict:
    """Segment lesions in an RGB fundus image.

    Returns a dict with:
        ``'mask'``: ``(512, 512)`` uint8 class-index mask.
        ``'probs'``: ``(C, 512, 512)`` per-class softmax.
        ``'bright_area'`` / ``'red_area'``: share of mask pixels (in percent)
        labelled bright / red. This share is taken over the whole resized
        image, including any pixels outside the retina.
    """
    processed_img = preprocess_segmenter(image)
    image_tensor = get_segmenter_transform()(Image.fromarray(processed_img)).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(image_tensor)

    final_mask, class_probs = process_segmentation_output(output)
    total_pixels = final_mask.size
    return {
        'mask': final_mask,
        'probs': class_probs,
        'bright_area': np.sum(final_mask == 1) / total_pixels * 100,
        'red_area': np.sum(final_mask == 2) / total_pixels * 100,
    }


# ====================== PDF REPORT GENERATION ======================
def _png_stream(img: Image.Image) -> io.BytesIO:
    """Encode ``img`` as PNG into an in-memory stream positioned at the start."""
    buffer = io.BytesIO()
    img.save(buffer, format='PNG')
    buffer.seek(0)
    return buffer


def _pdf_line(pdf: FPDF, text: str, **cell_kwargs) -> None:
    """Write one line of text and move to the left margin of the next line."""
    pdf.cell(text=text, new_x="LMARGIN", new_y="NEXT", **cell_kwargs)


def _pdf_section(pdf: FPDF, title: str, lines) -> None:
    """Write a bold 14pt section title followed by regular 12pt lines."""
    pdf.set_font("helvetica", "B", 14)
    _pdf_line(pdf, title)
    pdf.set_font("helvetica", "", 12)
    for text in lines:
        _pdf_line(pdf, text)


def _pdf_image(pdf: FPDF, caption: str, img: Image.Image) -> None:
    """Write a bold caption, then embed ``img`` 100 mm wide at x = 10 mm."""
    pdf.set_font("helvetica", "B", 12)
    _pdf_line(pdf, caption)
    pdf.image(_png_stream(img), x=10, w=100)


def generate_pdf_report(original_img: Image.Image, mask: np.ndarray, overlay: Image.Image,
                        diagnosis: str, grade: str, bright_area: float,
                        red_area: float) -> Optional[bytes]:
    """Render the diagnosis report as PDF bytes.

    Page 1 holds the header, diagnosis, lesion metrics and the original
    image. Page 2 holds the segmentation mask and the overlay.

    Returns:
        The PDF bytes, or ``None`` when rendering fails (the error is shown
        in the app via ``st.error``).
    """
    try:
        pdf = FPDF()
        pdf.add_page()

        # Header and patient info. w=0 makes the cell span to the right
        # margin; with the default w=None the cell is only as wide as its
        # text, so align='C' would have no effect.
        pdf.set_font("helvetica", "B", 16)
        _pdf_line(pdf, "Diabetic Retinopathy Diagnosis Report", w=0, align='C')
        pdf.ln(10)
        pdf.set_font("helvetica", "", 12)
        _pdf_line(pdf, "Patient: ___________________________")
        _pdf_line(pdf, "Date: _____________________________")
        pdf.ln(10)

        _pdf_section(pdf, "Diagnosis:", [
            f"Stage: {diagnosis}",
            f"Grading: {grade}",
        ])
        pdf.ln(10)

        _pdf_section(pdf, "Lesion Analysis:", [
            f"Bright Lesions: {bright_area:.2f}%",
            f"Red Lesions: {red_area:.2f}%",
            f"Total Affected Area: {bright_area + red_area:.2f}%",
        ])
        pdf.ln(15)

        _pdf_image(pdf, "Original Retinal Image:", original_img)
        pdf.ln(10)

        # Segmentation results go on a second page.
        pdf.add_page()
        _pdf_image(pdf, "Lesion Segmentation Mask:", Image.fromarray(mask_to_grayscale(mask)))
        pdf.ln(10)
        _pdf_image(pdf, "Lesion Overlay:", overlay)

        # Footer on the last page.
        pdf.ln(10)
        pdf.set_font("helvetica", "I", 10)
        _pdf_line(pdf, "This report was generated by DR Analysis System", w=0, align='C')

        return bytes(pdf.output())
    except Exception as e:  # best-effort: report failure, keep the app running
        st.error(f"PDF generation failed: {str(e)}")
        return None


# ====================== MAIN APP ======================
def _show_confidences(probabilities: np.ndarray) -> None:
    """Render a progress bar and a percentage for each class in CLASS_NAMES."""
    st.write("**Confidence Levels:**")
    for name, prob in zip(CLASS_NAMES, probabilities):
        # st.progress requires an int in [0, 100]; clamp against float rounding.
        st.progress(min(100, max(0, int(prob))))
        st.write(f"{name}: {prob:.1f}%")


def main():
    """Streamlit entry point: upload, classify, and (if DR) segment and report."""
    st.set_page_config(layout="wide")
    st.title("Diabetic Retinopathy Analysis")

    uploaded_file = st.file_uploader("Upload retinal scan image",
                                     type=["jpg", "jpeg", "png"],
                                     label_visibility="visible")
    if not uploaded_file:
        st.info("Please upload an image")
        return

    try:
        original_image = Image.open(uploaded_file).convert('RGB')

        col1, col2 = st.columns(2)
        with col1:
            st.image(original_image, caption="Original Image", use_container_width=True)

        # Classification
        pred_idx, probabilities = classify_image(original_image, load_classifier())
        predicted_class_name = CLASS_NAMES[pred_idx]
        uk_grade = UK_GRADES[predicted_class_name]

        st.subheader("Classification Results")
        # Two trailing spaces force a markdown line break between the fields.
        summary = f"**Prediction:** {predicted_class_name}  \n**Grade:** {uk_grade}"

        if predicted_class_name == "No_DR":
            st.success(summary)
            st.write("No diabetic retinopathy detected - no segmentation needed.")
            return

        st.error(summary)
        _show_confidences(probabilities)

        # Segmentation (only reached when DR was predicted)
        segmenter = load_segmenter()
        with st.spinner("Detecting lesions..."):
            seg_results = segment_image(original_image, segmenter)
            overlay = create_lesion_overlay(original_image, seg_results['mask'])

        with col2:
            st.image(overlay, caption="Lesion Overlay", use_container_width=True)

        # Metrics
        bright_area = seg_results['bright_area']
        red_area = seg_results['red_area']
        st.write("**Lesion Analysis:**")
        cols = st.columns(3)
        cols[0].metric("Bright Lesions", f"{bright_area:.2f}%")
        cols[1].metric("Red Lesions", f"{red_area:.2f}%")
        cols[2].metric("Total Affected", f"{bright_area + red_area:.2f}%")

        # Download buttons
        mask_col, report_col = st.columns(2)
        with mask_col:
            encoded_ok, png_buffer = cv2.imencode('.png', mask_to_grayscale(seg_results['mask']))
            if not encoded_ok:
                raise RuntimeError("PNG encoding of the lesion mask failed")
            st.download_button(
                "Download Mask",
                png_buffer.tobytes(),
                "dr_mask.png",
                "image/png",
            )

        with report_col:
            pdf_bytes = generate_pdf_report(
                original_image,
                seg_results['mask'],
                overlay,
                predicted_class_name,
                uk_grade,
                bright_area,
                red_area,
            )
            if pdf_bytes is not None:
                st.download_button(
                    "Download Full Report",
                    data=pdf_bytes,
                    file_name="dr_diagnosis_report.pdf",
                    mime="application/pdf",
                )
            else:
                st.warning("Failed to generate PDF report")

    except Exception as e:  # top-level UI boundary: surface any failure to the user
        st.error(f"Error processing image: {str(e)}")


if __name__ == "__main__":
    main()