# DR / app.py
# Author: rnmee — last commit "Update app.py" (57b7130, verified)
import cv2
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
from torchvision import models, transforms
import streamlit as st
from typing import Tuple
from fpdf import FPDF
import io
# Device configuration: use the GPU when PyTorch can see one, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# ====================== CONSTANTS ======================
# Class labels in classifier output order: logit index i is CLASS_NAMES[i].
CLASS_NAMES = [
    "Mild",
    "Moderate",
    "No_DR",
    "Proliferate_DR",
    "Severe",
]
# Segmentation class id -> RGB colour used when drawing the lesion overlay.
LESION_COLORS = dict(enumerate((
    [0, 0, 0],       # 0: background (black)
    [255, 255, 0],   # 1: bright lesions (yellow)
    [255, 0, 0],     # 2: red lesions (red)
)))
# Classifier label -> UK-style R0-R3 retinopathy grade shown to the user.
UK_GRADES = dict(
    No_DR="R0 - No retinopathy",
    Mild="R1 - Background DR",
    Moderate="R1 - Background DR",
    Severe="R2 - Pre-proliferative DR",
    Proliferate_DR="R3 - Proliferative DR",
)
# ====================== UNET ARCHITECTURE ======================
class UNet(nn.Module):
    """Three-level U-Net that maps an image to per-pixel class logits.

    Each encoder stage halves height and width with 2x2 max pooling, and each
    decoder stage doubles them with a stride-2 transposed conv. Input height and
    width must therefore be divisible by 8, or the skip concatenations will not
    line up (the app feeds 512x512 tensors).

    The attribute names and the Conv -> ReLU -> BatchNorm order inside every
    double-conv block must not change. Together they form the ``state_dict`` keys
    (e.g. ``encoder1.2.running_mean``) that saved weights are loaded against.

    Args:
        input_channels: channels of the input tensor.
        num_classes: channels of the output logits (one per class).
    """

    def __init__(self, input_channels=3, num_classes=3):
        super().__init__()
        double_conv = self._double_conv
        # Contracting path: 32 -> 64 -> 128 channels, then a 256-channel bottleneck.
        self.encoder1 = double_conv(input_channels, 32)
        self.pool1 = nn.MaxPool2d(2)
        self.encoder2 = double_conv(32, 64)
        self.pool2 = nn.MaxPool2d(2)
        self.encoder3 = double_conv(64, 128)
        self.pool3 = nn.MaxPool2d(2)
        self.bottleneck = double_conv(128, 256)
        # Expanding path: each decoder sees [upsampled, skip] concatenated, hence
        # its input width is twice the upsampler's output width.
        self.up3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.decoder3 = double_conv(256, 128)
        self.up2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.decoder2 = double_conv(128, 64)
        self.up1 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.decoder1 = double_conv(64, 32)
        # 1x1 conv projects the final features onto the class logits.
        self.final_conv = nn.Conv2d(32, num_classes, kernel_size=1)

    @staticmethod
    def _double_conv(in_channels, out_channels):
        """Two (3x3 conv, ReLU, BatchNorm) triples; padding=1 keeps H and W unchanged."""
        layers = []
        for c_in in (in_channels, out_channels):
            layers += [
                nn.Conv2d(c_in, out_channels, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.BatchNorm2d(out_channels),
            ]
        return nn.Sequential(*layers)

    def forward(self, x):
        skips = []
        encoder_stages = (
            (self.encoder1, self.pool1),
            (self.encoder2, self.pool2),
            (self.encoder3, self.pool3),
        )
        for encode, pool in encoder_stages:
            x = encode(x)
            skips.append(x)  # full-resolution features for the matching decoder
            x = pool(x)
        x = self.bottleneck(x)
        decoder_stages = (
            (self.up3, self.decoder3),
            (self.up2, self.decoder2),
            (self.up1, self.decoder1),
        )
        for upsample, decode in decoder_stages:
            # Deepest skip first: pop() walks the saved features in reverse order.
            x = decode(torch.cat([upsample(x), skips.pop()], dim=1))
        return self.final_conv(x)
# ====================== CLASSIFIER ======================
def create_classifier_model():
    """Build an untrained ResNet-152 with a 5-way log-probability head.

    No pretrained weights are downloaded (``weights=None``); ``load_classifier``
    fills the parameters from a checkpoint afterwards. The replacement ``fc`` is
    Linear(in_features -> 512) -> ReLU -> Linear(512 -> 5) -> LogSoftmax, so the
    network returns log-probabilities over the five CLASS_NAMES.
    """
    backbone = models.resnet152(weights=None)
    head_layers = [
        nn.Linear(backbone.fc.in_features, 512),
        nn.ReLU(),
        nn.Linear(512, 5),
        nn.LogSoftmax(dim=1),
    ]
    # Sequential indices 0..3 give the fc.0.* / fc.2.* keys the checkpoint uses.
    backbone.fc = nn.Sequential(*head_layers)
    return backbone
@st.cache_resource
def load_classifier():
    """Build the ResNet-152 classifier and load its trained weights.

    Reads ``classifier.pt``, a checkpoint dict whose ``'model_state_dict'`` entry
    holds the parameters. Tensors are mapped onto ``device``, and the model is
    returned in eval mode. ``st.cache_resource`` caches the returned model, so
    Streamlit reruns reuse it instead of reloading the file.

    NOTE: ``torch.load`` is pickle-based; only point it at trusted checkpoint files.
    """
    net = create_classifier_model().to(device)
    # map_location lets a checkpoint saved on GPU load on a CPU-only host.
    state_dict = torch.load('classifier.pt', map_location=device)['model_state_dict']
    net.load_state_dict(state_dict)
    return net.eval()
def preprocess_classifier(image: Image.Image) -> np.ndarray:
    """Build the classifier input: a CLAHE-enhanced green channel copied into 3 channels.

    Any non-RGB image is converted to RGB first. Without this, a greyscale
    (``L``) or palette (``P``) image yields a 2-D array and indexing channel 1
    fails. A greyscale+alpha (``LA``) image would silently use its alpha channel
    as "green".

    Args:
        image: PIL image in any mode.

    Returns:
        uint8 array of shape (H, W, 3) whose three channels are identical.
    """
    if image.mode != 'RGB':
        image = image.convert('RGB')
    # NOTE(review): presumably the green channel is used because it carries the
    # most vessel/lesion contrast in fundus photos — confirm against training code.
    green_channel = np.array(image)[:, :, 1]
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    enhanced = clahe.apply(green_channel)
    # Replicate to 3 channels so the 3-channel ResNet input layer accepts it.
    return np.stack([enhanced] * 3, axis=-1)
def get_classifier_transform():
    """Return the torchvision pipeline applied to classifier inputs.

    Resizes to 224x224, converts to a float tensor in [0, 1], then maps every
    channel to [-1, 1] via (x - 0.5) / 0.5.
    """
    steps = [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
    return transforms.Compose(steps)
# ====================== SEGMENTATION ======================
@st.cache_resource
def load_segmenter():
    """Return the lesion-segmentation U-Net with weights from ``best_unet_model.pth``.

    The file is loaded directly as a state dict (no wrapping checkpoint dict)
    and mapped onto ``device``. The model is switched to eval mode so BatchNorm
    uses its running statistics. ``st.cache_resource`` caches the result across
    Streamlit reruns.

    NOTE: ``torch.load`` is pickle-based; only point it at trusted weight files.
    """
    unet = UNet().to(device)
    weights = torch.load('best_unet_model.pth', map_location=device)
    unet.load_state_dict(weights)
    return unet.eval()
def preprocess_segmenter(image: Image.Image) -> np.ndarray:
    """Denoise and contrast-enhance a fundus image for the U-Net.

    Steps:
    1. Apply a 3x3 median blur.
    2. Convert to LAB.
    3. Apply CLAHE (clip limit 2.0, 8x8 tiles) to the L (lightness) channel
       only, leaving a/b untouched.
    4. Convert back to RGB.

    Non-RGB inputs (e.g. greyscale ``L`` or palette ``P``) are converted to RGB
    first. Otherwise they give a 2-D array, which ``cv2.cvtColor(...,
    COLOR_RGB2LAB)`` rejects.

    Args:
        image: PIL image in any mode.

    Returns:
        uint8 array of shape (H, W, 3) in RGB order.
    """
    if image.mode != 'RGB':
        image = image.convert('RGB')
    rgb = np.array(image)
    denoised = cv2.medianBlur(rgb, 3)
    lightness, chroma_a, chroma_b = cv2.split(cv2.cvtColor(denoised, cv2.COLOR_RGB2LAB))
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    lab_enhanced = cv2.merge((clahe.apply(lightness), chroma_a, chroma_b))
    return cv2.cvtColor(lab_enhanced, cv2.COLOR_LAB2RGB)
def get_segmenter_transform():
    """Return the torchvision pipeline applied to U-Net inputs.

    Resizes to 512x512, converts to a float tensor in [0, 1], and normalises with
    the standard ImageNet per-channel mean and standard deviation.
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    resize = transforms.Resize((512, 512))
    normalize = transforms.Normalize(mean=imagenet_mean, std=imagenet_std)
    return transforms.Compose([resize, transforms.ToTensor(), normalize])
def process_segmentation_output(output: torch.Tensor) -> Tuple[np.ndarray, np.ndarray]:
    """Convert raw U-Net logits into a class mask and per-class probabilities.

    Args:
        output: logits shaped (1, C, H, W), as produced for a single image. A
            bare (C, H, W) tensor is also accepted. For a 4-D tensor only the
            first sample of the batch is used.

    Returns:
        ``(mask, probs)``:
        - ``mask`` is a uint8 (H, W) array of argmax class ids
          (0 = background, 1 = bright lesion, 2 = red lesion).
        - ``probs`` is the (C, H, W) softmax over the class axis.

    Raises:
        ValueError: if ``output`` is not 3-D or 4-D.
    """
    # detach() so .numpy() also works on tensors that track gradients.
    logits = output.detach()
    # Index the batch axis explicitly instead of calling squeeze(). squeeze()
    # would also drop H or W when either is 1, and the argmax below would then
    # run over the wrong axis.
    if logits.dim() == 4:
        logits = logits[0]
    if logits.dim() != 3:
        raise ValueError(
            f"expected logits of shape (1, C, H, W) or (C, H, W), got {tuple(output.shape)}"
        )
    probs = torch.softmax(logits, dim=0).cpu().numpy()
    final_mask = np.argmax(probs, axis=0).astype(np.uint8)
    return final_mask, probs
# ====================== VISUALIZATION ======================
def create_lesion_overlay(original: Image.Image, mask: np.ndarray,
                          alpha: float = 0.4) -> Image.Image:
    """Tint the lesion pixels of ``original`` with their LESION_COLORS colour.

    Args:
        original: image to draw on; converted to RGB if it is in another mode.
        mask: 2-D class map (0 = background, 1 = bright lesion, 2 = red lesion).
            It is resized to the image size with nearest-neighbour interpolation
            so class ids are never blended.
        alpha: weight of the lesion colour in the blend (the original pixel gets
            ``1 - alpha``). Defaults to 0.4, the previously hard-coded value.

    Returns:
        RGB image in which lesion pixels are ``alpha * colour + (1 - alpha) *
        original``. Background pixels are left exactly as in ``original``.
    """
    if original.mode != 'RGB':
        original = original.convert('RGB')
    original_np = np.array(original)
    height, width = original_np.shape[:2]
    # cv2.resize takes (width, height); uint8 is a dtype it supports for the class map.
    mask_resized = cv2.resize(np.asarray(mask, dtype=np.uint8), (width, height),
                              interpolation=cv2.INTER_NEAREST)
    overlay = original_np.copy()
    for class_idx, color in LESION_COLORS.items():
        # Skip the background: painting it black and blending would dim every
        # non-lesion pixel of the image.
        if class_idx == 0:
            continue
        overlay[mask_resized == class_idx] = color
    blended = cv2.addWeighted(overlay, alpha, original_np, 1.0 - alpha, 0)
    lesion_pixels = (mask_resized != 0)[..., np.newaxis]
    return Image.fromarray(np.where(lesion_pixels, blended, original_np))
def segment_image(image: Image.Image, model: nn.Module) -> dict:
    """Segment lesions in one fundus image with ``model``.

    The image goes through ``preprocess_segmenter`` and
    ``get_segmenter_transform`` (512x512, ImageNet-normalised), then a single
    forward pass is run without gradient tracking.

    Returns:
        dict with:
        - ``'mask'``: uint8 class map from ``process_segmentation_output``
          (0 = background, 1 = bright lesion, 2 = red lesion).
        - ``'probs'``: the matching per-class softmax probabilities.
        - ``'bright_area'`` / ``'red_area'``: percentage of *all* mask pixels
          (not just the retinal disc) assigned to class 1 / class 2.
    """
    pil_input = Image.fromarray(preprocess_segmenter(image))
    batch = get_segmenter_transform()(pil_input).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(batch)
    mask, probs = process_segmentation_output(logits)
    pixel_count = mask.size
    result = {'mask': mask, 'probs': probs}
    for key, class_id in (('bright_area', 1), ('red_area', 2)):
        result[key] = np.sum(mask == class_id) / pixel_count * 100
    return result
# ====================== PDF REPORT GENERATION ======================
def generate_pdf_report(original_img: Image.Image, mask: np.ndarray, overlay: Image.Image,
                        diagnosis: str, grade: str, bright_area: float, red_area: float):
    """Render a two-page PDF summarising the DR classification and lesion segmentation.

    Page 1 contains:
    - the title and blank patient/date fields;
    - the diagnosis stage and grade;
    - the lesion area percentages;
    - the original image.

    Page 2 contains the grey-scale mask (class ids scaled by 85), the colour
    overlay and a footer line.

    Returns:
        The PDF as ``bytes``, or ``None`` after showing an ``st.error`` message
        if anything while building the PDF raises.
    """
    try:
        pdf = FPDF()

        def put_line(text, **extra):
            # One full-width line; the cursor moves to the left margin of the next line.
            pdf.cell(text=text, new_x="LMARGIN", new_y="NEXT", **extra)

        def put_section(title, lines):
            # Bold 14pt heading followed by regular 12pt body lines.
            pdf.set_font("helvetica", "B", 14)
            put_line(title)
            pdf.set_font("helvetica", "", 12)
            for text in lines:
                put_line(text)

        def put_figure(caption, picture):
            # Bold caption, then the image re-encoded as PNG, 100 mm wide at x=10.
            pdf.set_font("helvetica", "B", 12)
            put_line(caption)
            png = io.BytesIO()
            picture.save(png, format='PNG')
            pdf.image(io.BytesIO(png.getvalue()), x=10, w=100)

        pdf.add_page()
        pdf.set_font("helvetica", "B", 16)
        put_line("Diabetic Retinopathy Diagnosis Report", align='C')
        pdf.ln(10)
        pdf.set_font("helvetica", "", 12)
        put_line("Patient: ___________________________")
        put_line("Date: _____________________________")
        pdf.ln(10)

        put_section("Diagnosis:", [f"Stage: {diagnosis}", f"Grading: {grade}"])
        pdf.ln(10)
        put_section("Lesion Analysis:", [
            f"Bright Lesions: {bright_area:.2f}%",
            f"Red Lesions: {red_area:.2f}%",
            f"Total Affected Area: {bright_area + red_area:.2f}%",
        ])
        pdf.ln(15)

        put_figure("Original Retinal Image:", original_img)
        pdf.ln(10)

        # Segmentation results go on their own page.
        pdf.add_page()
        put_figure("Lesion Segmentation Mask:", Image.fromarray((mask * 85).astype(np.uint8)))
        pdf.ln(10)
        put_figure("Lesion Overlay:", overlay)

        pdf.ln(10)
        pdf.set_font("helvetica", "I", 10)
        put_line("This report was generated by DR Analysis System", align='C')
        return bytes(pdf.output())
    except Exception as e:
        st.error(f"PDF generation failed: {str(e)}")
        return None
# ====================== MAIN APP ======================
def main():
    """Run the Streamlit diabetic-retinopathy analysis page.

    Flow:
    1. Upload an image.
    2. Classify the DR stage with the ResNet classifier.
    3. If the prediction is anything other than "No_DR":
       - segment lesions with the U-Net;
       - show the overlay and lesion-area metrics;
       - offer the mask and a PDF report for download.

    Any exception in that pipeline is shown with ``st.error`` instead of
    crashing the page.

    NOTE(review): this block was recovered from a copy with all indentation
    stripped. The nesting below is a reconstruction; confirm it matches the
    intended page layout:
    - only the original image sits inside ``col1``;
    - only the overlay sits inside ``col2``;
    - everything from "Confidence Levels" onward is in the non-No_DR branch.
    """
    st.set_page_config(layout="wide")
    st.title("Diabetic Retinopathy Analysis")
    uploaded_file = st.file_uploader("Upload retinal scan image",
                                     type=["jpg", "jpeg", "png"],
                                     label_visibility="visible")
    if not uploaded_file:
        st.info("Please upload an image")
        return
    try:
        # Force 3-channel RGB; the preprocessing helpers index colour channels directly.
        original_image = Image.open(uploaded_file).convert('RGB')
        col1, col2 = st.columns(2)
        with col1:
            st.image(original_image, caption="Original Image", use_container_width=True)
        # Classification
        classifier = load_classifier()
        clf_processed = preprocess_classifier(original_image)
        img_tensor = get_classifier_transform()(Image.fromarray(clf_processed)).unsqueeze(0).to(device)
        with torch.no_grad():
            logps = classifier(img_tensor)
        # The classifier head ends in LogSoftmax, so exp() turns its output into probabilities.
        ps = torch.exp(logps)
        pred_class = torch.argmax(ps).item()
        # Per-class probabilities as percentages, in CLASS_NAMES order.
        probabilities = ps[0].cpu().numpy() * 100
        st.subheader("Classification Results")
        predicted_class_name = CLASS_NAMES[pred_class]
        uk_grade = UK_GRADES[predicted_class_name]
        # NOTE(review): in Markdown a single newline is a soft break, so the
        # Prediction and Grade lines below may render on one line — confirm.
        if predicted_class_name == "No_DR":
            st.success(f"""
**Prediction:** {predicted_class_name}
**Grade:** {uk_grade}
""")
            st.write("No diabetic retinopathy detected - no segmentation needed.")
        else:
            st.error(f"""
**Prediction:** {predicted_class_name}
**Grade:** {uk_grade}
""")
            st.write("**Confidence Levels:**")
            for name, prob in zip(CLASS_NAMES, probabilities):
                # prob is a percentage, so int(prob) is the 0-100 integer st.progress accepts.
                st.progress(int(prob))
                st.write(f"{name}: {prob:.1f}%")
            # Segmentation (ONLY if not "No_DR")
            segmenter = load_segmenter()
            with st.spinner("Detecting lesions..."):
                seg_results = segment_image(original_image, segmenter)
                overlay = create_lesion_overlay(original_image, seg_results['mask'])
            # Show the overlay in the right-hand column, next to the original image.
            with col2:
                st.image(overlay, caption="Lesion Overlay", use_container_width=True)
            # Metrics
            st.write("**Lesion Analysis:**")
            cols = st.columns(3)
            cols[0].metric("Bright Lesions", f"{seg_results['bright_area']:.2f}%")
            cols[1].metric("Red Lesions", f"{seg_results['red_area']:.2f}%")
            cols[2].metric("Total Affected",
                           f"{seg_results['bright_area'] + seg_results['red_area']:.2f}%")
            # Download buttons
            col1, col2 = st.columns(2)
            with col1:
                # Scale class ids 0/1/2 to grey levels 0/85/170 so the PNG is visible.
                # NOTE(review): imencode's success flag ([0]) is not checked.
                st.download_button(
                    "Download Mask",
                    cv2.imencode('.png', seg_results['mask'] * 85)[1].tobytes(),
                    "dr_mask.png",
                    "image/png"
                )
            with col2:
                # Generate and download PDF report
                pdf_bytes = generate_pdf_report(
                    original_image,
                    seg_results['mask'],
                    overlay,
                    predicted_class_name,
                    uk_grade,
                    seg_results['bright_area'],
                    seg_results['red_area']
                )
                # generate_pdf_report returns None (after its own st.error) on failure.
                if pdf_bytes is not None:
                    st.download_button(
                        "Download Full Report",
                        data=pdf_bytes,
                        file_name="dr_diagnosis_report.pdf",
                        mime="application/pdf"
                    )
                else:
                    st.warning("Failed to generate PDF report")
    except Exception as e:
        # Top-level UI boundary: report any pipeline failure on the page.
        st.error(f"Error processing image: {str(e)}")
# Render the page when this file is executed as the main script.
if __name__ == "__main__":
    main()