|
|
import cv2 |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torchvision import models, transforms |
|
|
import streamlit as st |
|
|
from typing import Tuple |
|
|
from fpdf import FPDF |
|
|
import io |
|
|
|
|
|
|
|
|
# Compute device shared by both models: the default CUDA GPU when one is
# available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Classifier output labels, indexed by the network's output position.
# NOTE(review): presumably this matches the label order used at training time
# (it is alphabetical) — confirm against the training script.
CLASS_NAMES = ["Mild", "Moderate", "No_DR", "Proliferate_DR", "Severe"]

# RGB colour painted for each segmentation class id when building the overlay:
# 0 -> black (no lesion), 1 -> yellow (bright lesions), 2 -> red (red lesions).
LESION_COLORS = {0: [0, 0, 0], 1: [255, 255, 0], 2: [255, 0, 0]}

# Classifier label -> UK-style R0-R3 retinopathy grade shown to the user.
# Mild and Moderate both map to R1.
UK_GRADES = dict(
    No_DR="R0 - No retinopathy",
    Mild="R1 - Background DR",
    Moderate="R1 - Background DR",
    Severe="R2 - Pre-proliferative DR",
    Proliferate_DR="R3 - Proliferative DR",
)
|
|
|
|
|
|
|
|
|
|
|
class UNet(nn.Module):
    """Compact U-Net that maps an image batch to per-pixel class logits.

    Three encoder stages (32, 64 and 128 channels) are each followed by 2x max-pooling.
    A 256-channel bottleneck comes next. Three decoder stages then upsample with a 2x
    transposed convolution and concatenate the matching encoder output (skip connection)
    before a double convolution. A final 1x1 convolution produces ``num_classes`` logits
    per pixel.

    The submodule attribute names and the layer order inside each double-conv block
    determine the ``state_dict`` keys. Keep them unchanged or saved checkpoints will not load.

    Input height and width must be divisible by 8. Otherwise the upsampled tensors do not
    line up with the skip tensors in ``torch.cat``.
    """

    def __init__(self, input_channels=3, num_classes=3):
        super().__init__()

        # Contracting path.
        self.encoder1 = self._double_conv(input_channels, 32)
        self.pool1 = nn.MaxPool2d(2)
        self.encoder2 = self._double_conv(32, 64)
        self.pool2 = nn.MaxPool2d(2)
        self.encoder3 = self._double_conv(64, 128)
        self.pool3 = nn.MaxPool2d(2)

        self.bottleneck = self._double_conv(128, 256)

        # Expanding path: each decoder receives upsampled features concatenated
        # with the skip features, hence twice the channels on its input.
        self.up3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.decoder3 = self._double_conv(128 * 2, 128)
        self.up2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.decoder2 = self._double_conv(64 * 2, 64)
        self.up1 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.decoder1 = self._double_conv(32 * 2, 32)

        self.final_conv = nn.Conv2d(32, num_classes, kernel_size=1)

    @staticmethod
    def _double_conv(in_channels, out_channels):
        """Return Conv3x3 -> ReLU -> BN -> Conv3x3 -> ReLU -> BN (Sequential indices 0-5)."""
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(out_channels),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(out_channels),
        ]
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return logits of shape (N, num_classes, H, W) for an input of shape (N, C, H, W)."""
        skips = []
        stages_down = (
            (self.encoder1, self.pool1),
            (self.encoder2, self.pool2),
            (self.encoder3, self.pool3),
        )
        for encoder, pool in stages_down:
            x = encoder(x)
            skips.append(x)
            x = pool(x)

        x = self.bottleneck(x)

        stages_up = (
            (self.up3, self.decoder3),
            (self.up2, self.decoder2),
            (self.up1, self.decoder1),
        )
        for upsample, decoder in stages_up:
            # Skips are consumed deepest-first, mirroring the encoder order.
            x = decoder(torch.cat([upsample(x), skips.pop()], dim=1))

        return self.final_conv(x)
|
|
|
|
|
|
|
|
def create_classifier_model():
    """Build the untrained DR grading network.

    Start from a torchvision ResNet-152 with no pretrained weights. Replace its final
    fully connected layer with ``Linear(in_features, 512) -> ReLU -> Linear(512, 5)
    -> LogSoftmax``, so the model returns log-probabilities over the five grades.
    The head's layer order fixes the ``fc.0`` / ``fc.2`` state-dict keys.
    """
    network = models.resnet152(weights=None)
    head = nn.Sequential(
        nn.Linear(network.fc.in_features, 512),
        nn.ReLU(),
        nn.Linear(512, 5),
        nn.LogSoftmax(dim=1),
    )
    network.fc = head
    return network
|
|
|
|
|
@st.cache_resource
def load_classifier():
    """Return the DR classifier with weights from ``classifier.pt``, on ``device``, in eval mode.

    The checkpoint is read as a dict whose ``'model_state_dict'`` entry holds the weights.
    ``st.cache_resource`` makes Streamlit build and load the model once and reuse it on reruns.
    """
    net = create_classifier_model().to(device)
    net.load_state_dict(torch.load('classifier.pt', map_location=device)['model_state_dict'])
    return net.eval()
|
|
|
|
|
def preprocess_classifier(image: Image.Image) -> np.ndarray:
    """Contrast-enhance an image for the classifier.

    The image is first converted to RGB. CLAHE (clip limit 2.0, 8x8 tiles) is applied
    to the green channel. That single enhanced channel is then replicated into three
    channels to fit the classifier's 3-channel input.

    Args:
        image: PIL image in any mode. Grayscale, palette and RGBA images are
            converted to RGB before processing.

    Returns:
        ``uint8`` array of shape ``(H, W, 3)`` whose three channels are identical.
    """
    # convert("RGB") guarantees an (H, W, 3) uint8 array. A mode "L" or "P"
    # image would otherwise produce a 2-D array and break the channel indexing.
    rgb = np.array(image.convert("RGB"))
    green = rgb[:, :, 1]
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    enhanced = clahe.apply(green)
    return np.stack([enhanced] * 3, axis=-1)
|
|
|
|
|
def get_classifier_transform():
    """Return the classifier input pipeline: resize to 224x224, convert to tensor, normalise."""
    steps = [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        # mean = std = 0.5 per channel maps ToTensor's [0, 1] range onto [-1, 1].
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
    return transforms.Compose(steps)
|
|
|
|
|
|
|
|
@st.cache_resource
def load_segmenter():
    """Return the lesion U-Net with weights from ``best_unet_model.pth``, on ``device``, in eval mode.

    The file is loaded directly as a state dict. ``st.cache_resource`` makes Streamlit
    build and load the model once and reuse it on reruns.
    """
    weights = torch.load('best_unet_model.pth', map_location=device)
    segmenter = UNet().to(device)
    segmenter.load_state_dict(weights)
    segmenter.eval()
    return segmenter
|
|
|
|
|
def preprocess_segmenter(image: Image.Image) -> np.ndarray:
    """Denoise and contrast-enhance an image for the lesion segmenter.

    Steps:
    1. Convert the image to RGB.
    2. Apply a 3x3 median blur.
    3. Move to the LAB colour space and apply CLAHE (clip limit 2.0, 8x8 tiles) to the L channel only.
    4. Convert back to RGB.

    Args:
        image: PIL image in any mode. RGBA and grayscale inputs are converted to RGB
            first, because ``COLOR_RGB2LAB`` requires exactly three channels.

    Returns:
        ``uint8`` RGB array of shape ``(H, W, 3)``.
    """
    rgb = np.array(image.convert("RGB"))
    denoised = cv2.medianBlur(rgb, 3)
    lab = cv2.cvtColor(denoised, cv2.COLOR_RGB2LAB)
    l_chan, a_chan, b_chan = cv2.split(lab)
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    # Equalising only lightness leaves the colour (a, b) channels untouched.
    enhanced = cv2.merge((clahe.apply(l_chan), a_chan, b_chan))
    return cv2.cvtColor(enhanced, cv2.COLOR_LAB2RGB)
|
|
|
|
|
def get_segmenter_transform():
    """Return the segmenter input pipeline: resize to 512x512, convert to tensor, normalise."""
    # Standard ImageNet per-channel statistics.
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    steps = [
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ]
    return transforms.Compose(steps)
|
|
|
|
|
def process_segmentation_output(output: torch.Tensor) -> Tuple[np.ndarray, np.ndarray]:
    """Turn raw segmentation logits for one image into a class mask and probabilities.

    Args:
        output: Logits of shape ``(1, C, H, W)``, i.e. a batch holding exactly one image.

    Returns:
        ``(mask, probs)``. ``mask`` is a ``uint8`` array of shape ``(H, W)`` holding the
        arg-max class id per pixel. ``probs`` is an array of shape ``(C, H, W)`` holding
        the per-class softmax probabilities.

    Raises:
        ValueError: If ``output`` is not 4-D with batch size 1.
    """
    if output.ndim != 4 or output.shape[0] != 1:
        raise ValueError(
            f"expected segmentation logits of shape (1, C, H, W), got {tuple(output.shape)}"
        )
    # Index the batch axis explicitly rather than calling squeeze(). squeeze() would
    # also drop any C, H or W axis of size 1, and the class argmax would then run
    # over the wrong axis. detach() lets .numpy() work on grad-tracking tensors.
    probs = torch.softmax(output.detach(), dim=1)[0].cpu().numpy()
    final_mask = np.argmax(probs, axis=0).astype(np.uint8)
    return final_mask, probs
|
|
|
|
|
|
|
|
def create_lesion_overlay(original: Image.Image, mask: np.ndarray, alpha: float = 0.4) -> Image.Image:
    """Blend a colour-coded class mask onto ``original``.

    Every class id in ``LESION_COLORS`` is painted with its colour into a copy of the
    image. That copy is blended with the untouched image as
    ``alpha * painted + (1 - alpha) * original``.

    Args:
        original: Image to draw on. It is converted to RGB first, so RGBA, grayscale
            or palette images also work; their pixels would otherwise not accept
            the 3-value colours.
        mask: 2-D array of class ids at any resolution. It is resized to the image
            size with nearest-neighbour interpolation, so class ids are not blended.
        alpha: Weight of the painted colours in the blend. Defaults to 0.4.

    Returns:
        RGB PIL image with the same size as ``original``.

    Raises:
        ValueError: If ``alpha`` is outside ``[0, 1]``.

    NOTE(review): class 0 is also listed in ``LESION_COLORS`` (black), so non-lesion
    pixels end up at ``1 - alpha`` of their brightness. Confirm that dimming the
    background is intended.
    """
    if not 0.0 <= alpha <= 1.0:
        raise ValueError(f"alpha must be in [0, 1], got {alpha}")

    base = np.array(original.convert("RGB"))
    height, width = base.shape[:2]
    # cv2.resize takes (width, height); nearest-neighbour keeps ids exact.
    mask_full = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)

    painted = base.copy()
    for class_idx, color in LESION_COLORS.items():
        painted[mask_full == class_idx] = color
    return Image.fromarray(cv2.addWeighted(painted, alpha, base, 1.0 - alpha, 0))
|
|
|
|
|
def segment_image(image: Image.Image, model: nn.Module) -> dict:
    """Run the lesion segmenter on ``image`` and summarise the result.

    Returns a dict with these keys:
    - ``'mask'``: per-pixel class ids at the segmenter's input resolution.
    - ``'probs'``: per-class probabilities.
    - ``'bright_area'``: percentage of mask pixels predicted as class 1.
    - ``'red_area'``: percentage of mask pixels predicted as class 2.
    """
    enhanced = Image.fromarray(preprocess_segmenter(image))
    batch = get_segmenter_transform()(enhanced).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(batch)

    mask, probs = process_segmentation_output(logits)
    n_pixels = mask.size
    bright_pct = np.sum(mask == 1) / n_pixels * 100
    red_pct = np.sum(mask == 2) / n_pixels * 100

    return {
        'mask': mask,
        'probs': probs,
        'bright_area': bright_pct,
        'red_area': red_pct,
    }
|
|
|
|
|
|
|
|
def _pdf_add_image_section(pdf: "FPDF", heading: str, image: Image.Image) -> None:
    """Append a bold heading, then ``image`` as a PNG 100 mm wide at x=10 mm, then a 10 mm gap."""
    pdf.set_font("helvetica", "B", 12)
    pdf.cell(text=heading, new_x="LMARGIN", new_y="NEXT")
    png = io.BytesIO()
    image.save(png, format='PNG')
    pdf.image(io.BytesIO(png.getvalue()), x=10, w=100)
    pdf.ln(10)


def generate_pdf_report(original_img: Image.Image, mask: np.ndarray, overlay: Image.Image,
                        diagnosis: str, grade: str, bright_area: float, red_area: float):
    """Build a two-page PDF report of the diagnosis and lesion analysis.

    Page 1 holds a title, blank patient/date fields, the diagnosis and grade, the lesion
    area percentages, and the original image. Page 2 holds the segmentation mask and
    the lesion overlay. The mask's class ids are scaled by 85 into grey levels.

    Args:
        original_img: The uploaded retinal image.
        mask: Per-pixel class ids from the segmenter.
        overlay: Image with the mask blended onto it.
        diagnosis: Predicted class name.
        grade: Grade string for the diagnosis.
        bright_area: Percentage of pixels classified as bright lesions.
        red_area: Percentage of pixels classified as red lesions.

    Returns:
        The PDF as ``bytes``. If anything fails, returns ``None`` after reporting the
        error with ``st.error``.
    """
    try:
        pdf = FPDF()
        pdf.add_page()

        # Title. w=0 stretches the cell to the right margin so align='C' centres
        # the text on the page. With the default w (None), fpdf2 sizes the cell
        # to the text and centring inside it has no visible effect.
        pdf.set_font("helvetica", "B", 16)
        pdf.cell(w=0, text="Diabetic Retinopathy Diagnosis Report", new_x="LMARGIN", new_y="NEXT", align='C')
        pdf.ln(10)

        # Blank fields to be filled in by hand.
        pdf.set_font("helvetica", "", 12)
        pdf.cell(text="Patient: ___________________________", new_x="LMARGIN", new_y="NEXT")
        pdf.cell(text="Date: _____________________________", new_x="LMARGIN", new_y="NEXT")
        pdf.ln(10)

        pdf.set_font("helvetica", "B", 14)
        pdf.cell(text="Diagnosis:", new_x="LMARGIN", new_y="NEXT")
        pdf.set_font("helvetica", "", 12)
        pdf.cell(text=f"Stage: {diagnosis}", new_x="LMARGIN", new_y="NEXT")
        pdf.cell(text=f"Grading: {grade}", new_x="LMARGIN", new_y="NEXT")
        pdf.ln(10)

        pdf.set_font("helvetica", "B", 14)
        pdf.cell(text="Lesion Analysis:", new_x="LMARGIN", new_y="NEXT")
        pdf.set_font("helvetica", "", 12)
        pdf.cell(text=f"Bright Lesions: {bright_area:.2f}%", new_x="LMARGIN", new_y="NEXT")
        pdf.cell(text=f"Red Lesions: {red_area:.2f}%", new_x="LMARGIN", new_y="NEXT")
        pdf.cell(text=f"Total Affected Area: {bright_area + red_area:.2f}%", new_x="LMARGIN", new_y="NEXT")
        pdf.ln(15)

        _pdf_add_image_section(pdf, "Original Retinal Image:", original_img)

        pdf.add_page()
        # Scale class ids 0/1/2 to grey levels 0/85/170 so the mask is visible.
        _pdf_add_image_section(pdf, "Lesion Segmentation Mask:",
                               Image.fromarray((mask * 85).astype(np.uint8)))
        _pdf_add_image_section(pdf, "Lesion Overlay:", overlay)

        pdf.set_font("helvetica", "I", 10)
        pdf.cell(w=0, text="This report was generated by DR Analysis System", new_x="LMARGIN", new_y="NEXT", align='C')

        return bytes(pdf.output())

    except Exception as e:
        st.error(f"PDF generation failed: {str(e)}")
        return None
|
|
|
|
|
|
|
|
def main():
    """Render the Streamlit diabetic-retinopathy analysis page.

    Flow:
    1. The user uploads an image, which is shown and graded by the classifier.
    2. For any prediction other than "No_DR", the page also shows per-class
       confidences and runs the lesion segmenter.
    3. It then displays the overlay and lesion-area metrics.
    4. It offers the mask and a PDF report for download.

    Any exception raised while processing is reported with ``st.error`` instead
    of propagating.

    NOTE(review): confirm that confidences and segmentation are meant to be
    skipped for "No_DR" predictions, as the "no segmentation needed" message implies.
    """
    # Page-wide layout setting, issued before any element is drawn.
    st.set_page_config(layout="wide")
    st.title("Diabetic Retinopathy Analysis")

    uploaded_file = st.file_uploader("Upload retinal scan image",
                                     type=["jpg", "jpeg", "png"],
                                     label_visibility="visible")
    # Nothing to analyse until the user provides a file.
    if not uploaded_file:
        st.info("Please upload an image")
        return

    try:
        # Normalise to 3-channel RGB; the preprocessing helpers index channels directly.
        original_image = Image.open(uploaded_file).convert('RGB')
        col1, col2 = st.columns(2)

        with col1:
            st.image(original_image, caption="Original Image", use_container_width=True)

        # --- Classification ---------------------------------------------------
        classifier = load_classifier()
        clf_processed = preprocess_classifier(original_image)
        img_tensor = get_classifier_transform()(Image.fromarray(clf_processed)).unsqueeze(0).to(device)

        with torch.no_grad():
            # The classifier head ends in LogSoftmax, so exp() recovers probabilities.
            logps = classifier(img_tensor)
            ps = torch.exp(logps)
            # Flat argmax over a single-image batch yields the class index.
            pred_class = torch.argmax(ps).item()
            probabilities = ps[0].cpu().numpy() * 100

        st.subheader("Classification Results")
        predicted_class_name = CLASS_NAMES[pred_class]
        uk_grade = UK_GRADES[predicted_class_name]

        if predicted_class_name == "No_DR":
            st.success(f"""
            **Prediction:** {predicted_class_name}
            **Grade:** {uk_grade}
            """)
            st.write("No diabetic retinopathy detected - no segmentation needed.")
        else:
            st.error(f"""
            **Prediction:** {predicted_class_name}
            **Grade:** {uk_grade}
            """)

            st.write("**Confidence Levels:**")
            for name, prob in zip(CLASS_NAMES, probabilities):
                # int() because st.progress treats floats as 0.0-1.0 fractions,
                # whereas these values are percentages.
                st.progress(int(prob))
                st.write(f"{name}: {prob:.1f}%")

            # --- Lesion segmentation ------------------------------------------
            segmenter = load_segmenter()
            with st.spinner("Detecting lesions..."):
                seg_results = segment_image(original_image, segmenter)
                overlay = create_lesion_overlay(original_image, seg_results['mask'])

            # Show the overlay in the right-hand column, next to the original image.
            with col2:
                st.image(overlay, caption="Lesion Overlay", use_container_width=True)

            st.write("**Lesion Analysis:**")
            cols = st.columns(3)
            cols[0].metric("Bright Lesions", f"{seg_results['bright_area']:.2f}%")
            cols[1].metric("Red Lesions", f"{seg_results['red_area']:.2f}%")
            cols[2].metric("Total Affected",
                           f"{seg_results['bright_area'] + seg_results['red_area']:.2f}%")

            # --- Downloads ----------------------------------------------------
            col1, col2 = st.columns(2)
            with col1:
                st.download_button(
                    "Download Mask",
                    # Scale class ids 0/1/2 to grey levels 0/85/170 so the PNG is visible;
                    # imencode returns (ok, buffer) and only the buffer is used.
                    cv2.imencode('.png', seg_results['mask'] * 85)[1].tobytes(),
                    "dr_mask.png",
                    "image/png"
                )

            with col2:
                pdf_bytes = generate_pdf_report(
                    original_image,
                    seg_results['mask'],
                    overlay,
                    predicted_class_name,
                    uk_grade,
                    seg_results['bright_area'],
                    seg_results['red_area']
                )
                # generate_pdf_report returns None when PDF creation fails.
                if pdf_bytes is not None:
                    st.download_button(
                        "Download Full Report",
                        data=pdf_bytes,
                        file_name="dr_diagnosis_report.pdf",
                        mime="application/pdf"
                    )
                else:
                    st.warning("Failed to generate PDF report")

    except Exception as e:
        # Top-level UI boundary: surface any failure to the user instead of crashing the page.
        st.error(f"Error processing image: {str(e)}")
|
|
|
|
|
# Run the Streamlit page when this file is executed as the main script.
if __name__ == "__main__":
    main()