# DR / app.py
# Scraped file-view header: author rnmee, commit 3eaadc1 "Update app.py" (verified), 13.2 kB.
import cv2
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
from torchvision import models, transforms
import streamlit as st
from typing import Tuple
from fpdf import FPDF
import io
# Device configuration -- run on CUDA when PyTorch can see a GPU, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# ====================== CONSTANTS ======================
# Classifier output index -> DR stage label (main() indexes this list with the
# argmax of the classifier output). Presumably this matches the class order used
# at training time (it is alphabetical) -- TODO confirm against the training code.
CLASS_NAMES = [
    "Mild",
    "Moderate",
    "No_DR",
    "Proliferate_DR",
    "Severe",
]

# Segmentation label -> RGB colour painted by create_lesion_overlay().
LESION_COLORS = {
    0: [0, 0, 0],        # background -> black
    1: [255, 255, 0],    # bright lesions -> yellow
    2: [255, 0, 0],      # red lesions -> red
}

# DR stage label -> UK grading string shown in the UI and the PDF report.
UK_GRADES = {
    "No_DR": "R0 - No retinopathy",
    "Mild": "R1 - Background DR",
    "Moderate": "R1 - Background DR",
    "Severe": "R2 - Pre-proliferative DR",
    "Proliferate_DR": "R3 - Proliferative DR",
}
# ====================== UNET ARCHITECTURE ======================
class UNet(nn.Module):
    """Three-level U-Net that returns per-pixel class logits.

    Every stage is a double 3x3 convolution block laid out as
    Conv -> ReLU -> BatchNorm, twice. The encoder halves the resolution three
    times with 2x2 max-pooling. The decoder upsamples with stride-2 transposed
    convolutions and concatenates the matching encoder features (skip
    connections) before each decoder block.

    Args:
        input_channels: number of channels of the input tensor (default 3).
        num_classes: number of output channels of the final 1x1 conv (default 3).

    Note:
        Input height and width must be divisible by 8. Otherwise the upsampled
        maps do not line up with the skip connections and ``torch.cat`` fails.
    """

    def __init__(self, input_channels=3, num_classes=3):
        super().__init__()

        def double_conv(c_in, c_out):
            # Two (Conv -> ReLU -> BatchNorm) triples. The Sequential indices
            # (0..5) are part of the checkpoint's state_dict keys: keep the order.
            layers = []
            for src in (c_in, c_out):
                layers += [
                    nn.Conv2d(src, c_out, kernel_size=3, padding=1),
                    nn.ReLU(inplace=True),
                    nn.BatchNorm2d(c_out),
                ]
            return nn.Sequential(*layers)

        # Attribute names double as state_dict prefixes of the saved weights.
        # Encoder path: 32 -> 64 -> 128 channels, each followed by a 2x downsample.
        self.encoder1 = double_conv(input_channels, 32)
        self.pool1 = nn.MaxPool2d(2)
        self.encoder2 = double_conv(32, 64)
        self.pool2 = nn.MaxPool2d(2)
        self.encoder3 = double_conv(64, 128)
        self.pool3 = nn.MaxPool2d(2)

        self.bottleneck = double_conv(128, 256)

        # Decoder path: each up-conv halves the channels. The decoder block then
        # consumes the concatenation with the same-level encoder output
        # (hence twice the channels at its input).
        self.up3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.decoder3 = double_conv(256, 128)
        self.up2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.decoder2 = double_conv(128, 64)
        self.up1 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.decoder1 = double_conv(64, 32)

        self.final_conv = nn.Conv2d(32, num_classes, kernel_size=1)

    def forward(self, x):
        """Map an (N, input_channels, H, W) batch to (N, num_classes, H, W) logits."""
        skip1 = self.encoder1(x)
        skip2 = self.encoder2(self.pool1(skip1))
        skip3 = self.encoder3(self.pool2(skip2))
        out = self.bottleneck(self.pool3(skip3))

        # Walk back up, deepest level first: upsample, concat skip, decode.
        for upsample, skip, decode in (
            (self.up3, skip3, self.decoder3),
            (self.up2, skip2, self.decoder2),
            (self.up1, skip1, self.decoder1),
        ):
            out = decode(torch.cat([upsample(out), skip], dim=1))

        return self.final_conv(out)
# ====================== CLASSIFIER ======================
def create_classifier_model():
    """Build the DR-stage classifier: an un-initialised ResNet-152 with a custom head.

    The stock ``fc`` layer is replaced by
    Linear(in_features, 512) -> ReLU -> Linear(512, 5) -> LogSoftmax(dim=1),
    so the model outputs log-probabilities over the 5 stages. ``weights=None``
    means no pretrained weights are fetched; load_classifier() loads them from a
    checkpoint afterwards.
    """
    backbone = models.resnet152(weights=None)
    in_features = backbone.fc.in_features
    # Keep this exact layer order: the checkpoint's keys are fc.0.* and fc.2.*.
    head_layers = [
        nn.Linear(in_features, 512),
        nn.ReLU(),
        nn.Linear(512, 5),
        nn.LogSoftmax(dim=1),
    ]
    backbone.fc = nn.Sequential(*head_layers)
    return backbone
@st.cache_resource
def load_classifier():
    """Create the classifier, load ``classifier.pt`` onto ``device`` and set eval mode.

    The checkpoint file is a dict whose ``'model_state_dict'`` entry holds the
    weights. ``st.cache_resource`` keeps the returned model across Streamlit
    reruns, so this body runs once.
    """
    checkpoint = torch.load('classifier.pt', map_location=device)
    classifier = create_classifier_model().to(device)
    classifier.load_state_dict(checkpoint['model_state_dict'])
    # nn.Module.eval() returns the module itself.
    return classifier.eval()
def preprocess_classifier(image: Image.Image) -> np.ndarray:
    """Classifier preprocessing: CLAHE-equalised green channel, replicated 3 times.

    Args:
        image: input image, indexed as (H, W, C) with channel 1 taken as green.
            An RGB(-like) image is therefore assumed.

    Returns:
        (H, W, 3) array whose three channels are identical copies of the
        CLAHE-enhanced green channel.
    """
    pixels = np.array(image)
    equaliser = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    enhanced_green = equaliser.apply(pixels[:, :, 1])
    return np.stack((enhanced_green, enhanced_green, enhanced_green), axis=-1)
def get_classifier_transform():
    """Return the classifier's input pipeline.

    Steps: resize to 224x224, convert to a float tensor in [0, 1], then normalise
    every channel with mean 0.5 and std 0.5, which maps values into [-1, 1].
    """
    steps = [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
    return transforms.Compose(steps)
# ====================== SEGMENTATION ======================
@st.cache_resource
def load_segmenter():
    """Build the lesion U-Net, load ``best_unet_model.pth`` onto ``device``, set eval mode.

    The file is loaded directly as a state dict. Cached by ``st.cache_resource``
    so the weights are read only once across reruns.
    """
    weights = torch.load('best_unet_model.pth', map_location=device)
    segmenter = UNet().to(device)
    segmenter.load_state_dict(weights)
    segmenter.eval()
    return segmenter
def preprocess_segmenter(image: Image.Image) -> np.ndarray:
    """Segmenter preprocessing: 3x3 median denoise, then CLAHE on the LAB lightness.

    Args:
        image: image treated as RGB (it is converted with COLOR_RGB2LAB).

    Returns:
        RGB array of the same size in which only the L (lightness) channel was
        contrast-enhanced. The a/b colour channels pass through unchanged.
    """
    denoised = cv2.medianBlur(np.array(image), 3)
    lightness, chan_a, chan_b = cv2.split(cv2.cvtColor(denoised, cv2.COLOR_RGB2LAB))
    equaliser = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    enhanced_lab = cv2.merge((equaliser.apply(lightness), chan_a, chan_b))
    return cv2.cvtColor(enhanced_lab, cv2.COLOR_LAB2RGB)
def get_segmenter_transform():
    """Return the segmenter's input pipeline.

    Steps: resize to 512x512, convert to a float tensor, then normalise with the
    standard ImageNet channel means/stds.
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    return transforms.Compose(
        [
            transforms.Resize((512, 512)),
            transforms.ToTensor(),
            transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
        ]
    )
def process_segmentation_output(output: torch.Tensor) -> Tuple[np.ndarray, np.ndarray]:
    """Turn raw segmenter logits into a label mask plus per-class probabilities.

    Args:
        output: logits. Softmax is taken over dim 1 and every size-1 dimension
            is then squeezed away, so a (1, C, H, W) tensor is assumed.

    Returns:
        ``(mask, probs)``:
        - ``mask`` is an (H, W) uint8 array with the most probable class per pixel
          (0 = background, 1 = bright lesion, 2 = red lesion).
        - ``probs`` is the squeezed (C, H, W) softmax output as a numpy array.
    """
    probabilities = np.squeeze(torch.softmax(output, dim=1).cpu().numpy())
    labels = probabilities.argmax(axis=0).astype(np.uint8)
    return labels, probabilities
# ====================== VISUALIZATION ======================
def create_lesion_overlay(original: Image.Image, mask: np.ndarray) -> Image.Image:
    """Blend a colour-coded lesion mask over the original image.

    Steps:
    - Resize the mask to the image size with nearest-neighbour interpolation,
      so the labels are never interpolated.
    - Paint every pixel whose label appears in LESION_COLORS with that colour.
      This includes the background label 0, which is painted black.
    - Blend 40% painted colour with 60% original.

    As a result, lesions are tinted and all background pixels are dimmed to about
    60% of their original brightness.
    """
    base = np.array(original)
    height, width = base.shape[:2]
    labels = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)
    painted = base.copy()
    for label, rgb in LESION_COLORS.items():
        painted[labels == label] = rgb
    blended = cv2.addWeighted(painted, 0.4, base, 0.6, 0)
    return Image.fromarray(blended)
def segment_image(image: Image.Image, model: nn.Module) -> dict:
    """Run lesion segmentation on ``image`` and summarise lesion coverage.

    Returns a dict with:
        'mask'        -- uint8 label map at the model's input resolution (512x512)
        'probs'       -- per-class softmax probabilities
        'bright_area' -- percentage of mask pixels labelled 1 (bright lesions)
        'red_area'    -- percentage of mask pixels labelled 2 (red lesions)

    Percentages are relative to the whole resized image, not only the retina.
    """
    enhanced = Image.fromarray(preprocess_segmenter(image))
    batch = get_segmenter_transform()(enhanced).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(batch)
    mask, probs = process_segmentation_output(logits)

    n_pixels = mask.size
    bright_pct = np.sum(mask == 1) / n_pixels * 100
    red_pct = np.sum(mask == 2) / n_pixels * 100
    return {
        'mask': mask,
        'probs': probs,
        'bright_area': bright_pct,
        'red_area': red_pct,
    }
# ====================== PDF REPORT GENERATION ======================
def generate_pdf_report(original_img: Image.Image, mask: np.ndarray, overlay: Image.Image,
                        diagnosis: str, grade: str, bright_area: float, red_area: float):
    """Render the analysis as a PDF document and return it as bytes.

    Args:
        original_img: the uploaded retinal image.
        mask: segmentation label map (0 = background, 1 = bright, 2 = red). It is
            multiplied by 85 so the labels become visible grey levels (0/85/170).
        overlay: lesion overlay image shown alongside the mask.
        diagnosis: predicted DR stage name.
        grade: UK grading string for that stage.
        bright_area: bright-lesion coverage, in percent.
        red_area: red-lesion coverage, in percent.

    Returns:
        The PDF as ``bytes``, or ``None`` if generation failed. On failure the
        error is also shown in the Streamlit UI via ``st.error``.
    """
    try:
        pdf = FPDF()
        pdf.add_page()

        def write_line(text: str, align: str = "L") -> None:
            # w=0 stretches the cell to the right margin, so align="C" centres the
            # text on the page. fpdf2's default w=None sizes the cell to the text,
            # which makes centring a no-op. Each call continues on the next row.
            pdf.cell(w=0, text=text, new_x="LMARGIN", new_y="NEXT", align=align)

        def add_image(image: Image.Image, title: str) -> None:
            # Embed a PIL image (as PNG) under a bold caption, 180 mm wide.
            buffer = io.BytesIO()
            image.save(buffer, format='PNG')
            buffer.seek(0)
            pdf.set_font("helvetica", "B", 12)
            write_line(title)
            pdf.image(buffer, x=10, w=180)
            pdf.ln(5)

        # Title
        pdf.set_font("helvetica", "B", 16)
        write_line("Diabetic Retinopathy Diagnosis Report", align="C")
        pdf.ln(10)

        # Patient info placeholders
        pdf.set_font("helvetica", "", 12)
        write_line("Patient: ___________________________")
        write_line("Date: _____________________________")
        pdf.ln(10)

        # Diagnosis
        pdf.set_font("helvetica", "B", 14)
        write_line("Diagnosis:")
        pdf.set_font("helvetica", "", 12)
        write_line(f"Stage: {diagnosis}")
        write_line(f"UK Grading: {grade}")
        pdf.ln(10)

        # Lesion analysis
        pdf.set_font("helvetica", "B", 14)
        write_line("Lesion Analysis:")
        pdf.set_font("helvetica", "", 12)
        write_line(f"Bright Lesions: {bright_area:.2f}%")
        write_line(f"Red Lesions: {red_area:.2f}%")
        write_line(f"Total Affected Area: {bright_area + red_area:.2f}%")
        pdf.ln(15)

        # Images
        add_image(original_img, "Original Retinal Image:")
        add_image(Image.fromarray((mask * 85).astype(np.uint8)), "Lesion Segmentation Mask:")
        add_image(overlay, "Lesion Overlay:")

        # Footer
        pdf.ln(10)
        pdf.set_font("helvetica", "I", 10)
        write_line("This report was generated by DR Analysis System", align="C")

        # fpdf2's output() returns a bytearray; the legacy dest='S' argument is
        # deprecated and ignored. bytearray has no .encode(), so calling it raised
        # AttributeError on every call. Convert straight to immutable bytes.
        return bytes(pdf.output())
    except Exception as e:
        # Best-effort: surface the failure in the UI and let the caller skip the download.
        st.error(f"PDF generation failed: {str(e)}")
        return None
# ====================== MAIN APP ======================
def _classify(image: Image.Image):
    """Run the DR-stage classifier on ``image``; return (class name, per-class % array)."""
    classifier = load_classifier()
    enhanced = Image.fromarray(preprocess_classifier(image))
    batch = get_classifier_transform()(enhanced).unsqueeze(0).to(device)
    with torch.no_grad():
        log_probs = classifier(batch)
    # The head ends in LogSoftmax, so exp() yields probabilities. Index [0] takes
    # the single sample, so argmax is over classes, not a flattened batch.
    probs = torch.exp(log_probs)[0]
    pred_idx = int(torch.argmax(probs).item())
    return CLASS_NAMES[pred_idx], probs.cpu().numpy() * 100


def _show_lesion_analysis(image: Image.Image, predicted_class_name: str,
                          uk_grade: str, column) -> None:
    """Segment lesions, show the overlay in ``column``, then metrics and downloads below."""
    segmenter = load_segmenter()
    with st.spinner("Detecting lesions..."):
        seg_results = segment_image(image, segmenter)
        overlay = create_lesion_overlay(image, seg_results['mask'])

    bright = seg_results['bright_area']
    red = seg_results['red_area']

    with column:
        st.image(overlay, caption="Lesion Overlay", use_container_width=True)

    # Metrics (full width, so the metric columns are not nested in another column)
    st.write("**Lesion Analysis:**")
    metric_cols = st.columns(3)
    metric_cols[0].metric("Bright Lesions", f"{bright:.2f}%")
    metric_cols[1].metric("Red Lesions", f"{red:.2f}%")
    metric_cols[2].metric("Total Affected", f"{bright + red:.2f}%")

    # Download buttons
    mask_col, report_col = st.columns(2)
    with mask_col:
        # Labels 0/1/2 -> grey levels 0/85/170 so the PNG is viewable.
        encoded_ok, png = cv2.imencode('.png', seg_results['mask'] * 85)
        if encoded_ok:
            st.download_button(
                "Download Mask",
                png.tobytes(),
                "dr_mask.png",
                "image/png"
            )
        else:
            st.warning("Failed to encode mask image")
    with report_col:
        pdf_bytes = generate_pdf_report(
            image,
            seg_results['mask'],
            overlay,
            predicted_class_name,
            uk_grade,
            bright,
            red
        )
        if pdf_bytes is not None:
            st.download_button(
                "Download Full Report",
                data=pdf_bytes,
                file_name="dr_diagnosis_report.pdf",
                mime="application/pdf"
            )
        else:
            st.warning("Failed to generate PDF report")


def main():
    """Streamlit entry point: classify an uploaded fundus image and analyse lesions.

    Flow:
    1. Ask for an image upload. Stop with an info message if there is none.
    2. Classify the image. Show the stage, the UK grade and per-class confidences.
    3. Unless the prediction is "No_DR", segment lesions. Show the overlay and
       area metrics, and offer the mask PNG and a PDF report for download.

    Any exception raised while processing is reported in the UI via ``st.error``.
    """
    # Must be the first Streamlit call of the script run.
    st.set_page_config(layout="wide")
    st.title("Diabetic Retinopathy Analysis")
    uploaded_file = st.file_uploader("Upload retinal scan image",
                                     type=["jpg", "jpeg", "png"],
                                     label_visibility="visible")
    if not uploaded_file:
        st.info("Please upload an image")
        return

    try:
        original_image = Image.open(uploaded_file).convert('RGB')
        image_col, result_col = st.columns(2)

        with image_col:
            st.image(original_image, caption="Original Image", use_container_width=True)

            # Classification
            predicted_class_name, probabilities = _classify(original_image)
            uk_grade = UK_GRADES[predicted_class_name]

            st.subheader("Classification Results")
            # Two trailing spaces before "\n" force a Markdown hard line break.
            # A bare newline would merge both lines into one paragraph.
            summary = f"**Prediction:** {predicted_class_name}  \n**Grade:** {uk_grade}"
            if predicted_class_name == "No_DR":
                st.success(summary)
                st.write("No diabetic retinopathy detected - no segmentation needed.")
            else:
                st.error(summary)

            st.write("**Confidence Levels:**")
            for name, prob in zip(CLASS_NAMES, probabilities):
                # st.progress takes an int in [0, 100]. Clamp guards float rounding.
                st.progress(min(max(int(prob), 0), 100))
                st.write(f"{name}: {prob:.1f}%")

        # Segmentation (ONLY if not "No_DR")
        if predicted_class_name != "No_DR":
            _show_lesion_analysis(original_image, predicted_class_name, uk_grade, result_col)
    except Exception as e:
        st.error(f"Error processing image: {str(e)}")


if __name__ == "__main__":
    main()