Update app.py
Browse files
app.py
CHANGED
|
@@ -1,104 +1,178 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
import streamlit as st
|
| 3 |
-
import numpy as np
|
| 4 |
import cv2
|
|
|
|
| 5 |
from PIL import Image
|
| 6 |
-
import
|
| 7 |
-
import
|
| 8 |
-
import
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
class ResNetClassifier(torch.nn.Module):
|
| 12 |
-
def __init__(self, num_classes=5):
|
| 13 |
-
super(ResNetClassifier, self).__init__()
|
| 14 |
-
self.model = models.resnet152(pretrained=False)
|
| 15 |
-
self.model.fc = torch.nn.Sequential(
|
| 16 |
-
torch.nn.Linear(self.model.fc.in_features, 512),
|
| 17 |
-
torch.nn.ReLU(),
|
| 18 |
-
torch.nn.Linear(512, num_classes)
|
| 19 |
-
)
|
| 20 |
-
|
| 21 |
-
def forward(self, x):
|
| 22 |
-
return self.model(x)
|
| 23 |
|
| 24 |
-
#
|
| 25 |
-
|
| 26 |
-
def load_models():
|
| 27 |
-
# Load classification model
|
| 28 |
-
classification_model = ResNetClassifier()
|
| 29 |
-
classifier_checkpoint = torch.load("classifier.pt", map_location=torch.device("cpu"))
|
| 30 |
-
classification_model.load_state_dict(classifier_checkpoint["model_state_dict"])
|
| 31 |
-
classification_model.eval()
|
| 32 |
-
|
| 33 |
-
# Load segmentation model
|
| 34 |
-
segmentation_model = torch.load("best_unet_model.pth", map_location=torch.device("cpu"))
|
| 35 |
-
segmentation_model.eval()
|
| 36 |
-
|
| 37 |
-
return classification_model, segmentation_model
|
| 38 |
|
| 39 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
|
| 41 |
-
#
|
| 42 |
-
def
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
|
|
|
|
|
|
| 50 |
transforms.ToTensor(),
|
| 51 |
-
transforms.Normalize(mean=[0.5], std=[0.5])
|
| 52 |
])
|
| 53 |
-
return transform(img_clahe).unsqueeze(0)
|
| 54 |
|
| 55 |
-
#
|
| 56 |
-
|
| 57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
transforms.Resize((512, 512)),
|
| 59 |
transforms.ToTensor(),
|
| 60 |
-
transforms.Normalize(mean=[0.
|
| 61 |
])
|
| 62 |
-
return transform(image).unsqueeze(0)
|
| 63 |
|
| 64 |
-
|
| 65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
input_tensor = preprocess_image(image)
|
| 74 |
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
|
| 83 |
-
|
| 84 |
-
if
|
| 85 |
-
st.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
|
| 87 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
|
| 89 |
with torch.no_grad():
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
st.
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import cv2
|
| 2 |
+
import numpy as np
|
| 3 |
from PIL import Image
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
from torchvision import models, transforms
|
| 7 |
+
import streamlit as st
|
| 8 |
+
from typing import Tuple
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
+
# Run on GPU when one is present, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# DR severity grades, indexed by classifier output class.
CLASS_NAMES = ["No DR", "Mild", "Moderate", "Severe", "Proliferative DR"]

# RGB paint colours for each merged lesion class in the overlay.
LESION_COLORS = {
    0: [0, 0, 0],        # Background (black)
    1: [255, 255, 0],    # Bright lesions (yellow)
    2: [255, 0, 0],      # Red lesions (red)
}
|
| 20 |
|
| 21 |
+
# ====================== CLASSIFIER ======================
|
| 22 |
+
def create_classifier_model():
    """Build the ResNet-152 DR-grading classifier architecture.

    The stock fully-connected head is replaced with a small MLP mapping the
    ResNet features to the 5 severity classes, ending in LogSoftmax — the
    app recovers probabilities later via torch.exp, so the checkpoint was
    presumably trained against NLLLoss (confirm against training code).

    Returns:
        An untrained torchvision ResNet-152; weights are loaded afterwards
        from 'classifier.pt' by load_classifier().
    """
    model = models.resnet152(pretrained=False)  # weights come from the checkpoint
    num_ftrs = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Linear(num_ftrs, 512),
        nn.ReLU(),
        nn.Linear(512, 5),
        nn.LogSoftmax(dim=1),
    )  # BUG FIX: this Sequential(...) call was left unclosed (SyntaxError)
    return model
|
| 31 |
+
|
| 32 |
+
@st.cache_resource
def load_classifier():
    """Load the DR classifier once per session and cache it.

    Reads 'classifier.pt' (expected to hold a 'model_state_dict' entry),
    moves the model to the configured device, and switches it to eval mode
    for inference.
    """
    classifier = create_classifier_model().to(device)
    state = torch.load('classifier.pt', map_location=device)
    classifier.load_state_dict(state['model_state_dict'])
    classifier.eval()
    return classifier
|
| 39 |
+
|
| 40 |
+
def preprocess_classifier(image: Image.Image) -> np.ndarray:
    """Prepare an RGB image for the classifier.

    Extracts the green channel, applies CLAHE contrast enhancement, and
    replicates the enhanced channel across all 3 channels so the ResNet's
    expected 3-channel input shape is preserved.
    """
    rgb = np.array(image)
    green = rgb[:, :, 1]
    equalizer = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    enhanced = equalizer.apply(green)
    return np.stack((enhanced, enhanced, enhanced), axis=-1)
|
| 46 |
+
|
| 47 |
+
def get_classifier_transform():
    """Tensor pipeline for classifier input: resize to 224x224, scale to [-1, 1]."""
    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    return transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize,
    ])
|
|
|
|
| 53 |
|
| 54 |
+
# ====================== SEGMENTATION ======================
|
| 55 |
+
@st.cache_resource
def load_segmenter():
    """Load the full pickled U-Net model once per session and cache it.

    NOTE(review): torch.load here unpickles an entire model object from
    'best_unet_model.pth', which executes arbitrary code on load — only use
    checkpoints from a trusted source.
    """
    unet = torch.load('best_unet_model.pth', map_location=device)
    unet.eval()
    return unet
|
| 60 |
+
|
| 61 |
+
def preprocess_segmenter(image: Image.Image) -> np.ndarray:
    """Denoise and contrast-enhance an RGB image for the segmenter.

    Steps: 3x3 median filter to suppress speckle noise, then CLAHE applied
    to the lightness (L) channel in LAB space so local contrast improves
    without shifting colour, then conversion back to RGB.
    """
    equalizer = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    denoised = cv2.medianBlur(np.array(image), 3)
    lightness, chan_a, chan_b = cv2.split(cv2.cvtColor(denoised, cv2.COLOR_RGB2LAB))
    enhanced_lab = cv2.merge((equalizer.apply(lightness), chan_a, chan_b))
    return cv2.cvtColor(enhanced_lab, cv2.COLOR_LAB2RGB)
|
| 70 |
+
|
| 71 |
+
def get_segmenter_transform():
    """Tensor pipeline for segmenter input: resize to 512x512, ImageNet normalisation."""
    return transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
|
|
|
|
| 77 |
|
| 78 |
+
def process_segmentation_output(output: torch.Tensor) -> Tuple[np.ndarray, np.ndarray]:
    """Collapse raw 5-class segmentation logits into a 3-class lesion mask.

    Original classes 1 and 4 become "bright" lesions (label 1); classes 2
    and 3 become "red" lesions (label 2); everything else stays background
    (label 0).

    Returns:
        (mask, probs): the merged uint8 label mask and the softmax
        probability maps for all five original classes.
    """
    prob_maps = torch.softmax(output, dim=1).cpu().numpy().squeeze()
    labels = prob_maps.argmax(axis=0)
    merged = np.zeros(labels.shape, dtype=np.uint8)
    bright = (labels == 1) | (labels == 4)
    red = (labels == 2) | (labels == 3)
    merged[bright] = 1
    merged[red] = 2
    return merged, prob_maps
|
| 86 |
|
| 87 |
+
# ====================== VISUALIZATION ======================
|
| 88 |
+
def create_lesion_overlay(original: Image.Image, mask: np.ndarray) -> Image.Image:
    """Blend colour-coded lesion classes onto the original image.

    The mask is upscaled to the photo's resolution with nearest-neighbour
    interpolation (so class labels stay discrete), painted with
    LESION_COLORS, then alpha-blended 40% paint / 60% photo.

    NOTE(review): class 0 paints black, so the blend also dims the whole
    non-lesion background — confirm this is the intended look.
    """
    base = np.array(original)
    height, width = base.shape[:2]
    full_mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)
    painted = base.copy()
    for label, colour in LESION_COLORS.items():
        painted[full_mask == label] = colour
    blended = cv2.addWeighted(painted, 0.4, base, 0.6, 0)
    return Image.fromarray(blended)
|
| 98 |
+
|
| 99 |
+
def create_heatmap(prob_map: np.ndarray, original_size: Tuple[int, int]) -> np.ndarray:
    """Render a probability map as a JET-colormapped heatmap at image size.

    `original_size` is a PIL-style (width, height) pair, which matches
    cv2.resize's dsize convention.

    NOTE(review): cv2.applyColorMap returns a BGR image while st.image
    assumes RGB — red/blue channels may appear swapped; verify on screen.
    """
    upscaled = cv2.resize(prob_map, original_size)
    intensity = (upscaled * 255).astype(np.uint8)
    return cv2.applyColorMap(intensity, cv2.COLORMAP_JET)
|
| 103 |
+
|
| 104 |
+
# ====================== MAIN APP ======================
|
| 105 |
+
def segment_image(image: Image.Image, model) -> dict:
    """Run lesion segmentation on a full-resolution image.

    BUG FIX: main() called segment_image(), but no such function existed
    anywhere in the file, so every upload died with a NameError (masked by
    the broad except as a generic error message). This helper implements it:
    preprocess, run the U-Net, merge the 5-class output into the 3-class
    lesion mask, and compute affected area per lesion class.

    Returns:
        dict with keys 'mask' (uint8 label mask), 'probs' (5-class softmax
        maps), 'bright_area' and 'red_area' (percent of image pixels).
    """
    processed = preprocess_segmenter(image)
    tensor = get_segmenter_transform()(Image.fromarray(processed)).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(tensor)
    mask, probs = process_segmentation_output(output)
    total_pixels = mask.size
    return {
        'mask': mask,
        'probs': probs,
        'bright_area': 100.0 * np.count_nonzero(mask == 1) / total_pixels,
        'red_area': 100.0 * np.count_nonzero(mask == 2) / total_pixels,
    }

def main():
    """Streamlit entry point: upload -> classify DR grade -> segment lesions."""
    st.set_page_config(layout="wide")
    st.title("Diabetic Retinopathy Analysis")

    uploaded_file = st.file_uploader("Upload retinal image", type=["jpg", "jpeg", "png"])
    if not uploaded_file:
        st.info("Please upload an image")
        return

    try:
        original_image = Image.open(uploaded_file).convert('RGB')
        col1, col2 = st.columns(2)

        with col1:
            st.image(original_image, caption="Original Image", use_column_width=True)

        # --- Classification: green-channel/CLAHE preprocessing + ResNet-152 ---
        classifier = load_classifier()
        clf_processed = preprocess_classifier(original_image)
        clf_transform = get_classifier_transform()
        img_tensor = clf_transform(Image.fromarray(clf_processed)).unsqueeze(0).to(device)

        with torch.no_grad():
            logps = classifier(img_tensor)   # model head ends in LogSoftmax
            ps = torch.exp(logps)            # back to probabilities
            pred_class = torch.argmax(ps).item()
            probabilities = ps[0].cpu().numpy() * 100

        st.subheader("Classification Results")
        if pred_class == 0:
            st.success(f"**Prediction:** {CLASS_NAMES[pred_class]}")
        else:
            st.error(f"**Prediction:** {CLASS_NAMES[pred_class]}")
        st.write("**Confidence Levels:**")
        for name, prob in zip(CLASS_NAMES, probabilities):
            # st.progress expects 0-100; clamp against float rounding overshoot.
            st.progress(min(int(prob), 100))
            st.write(f"{name}: {prob:.1f}%")

        # --- Segmentation: U-Net lesion mask + probability heatmaps ---
        segmenter = load_segmenter()
        with st.spinner("Detecting lesions..."):
            seg_results = segment_image(original_image, segmenter)
            overlay = create_lesion_overlay(original_image, seg_results['mask'])
            # Bright lesions = original classes 1+4; red = classes 2+3.
            heat_bright = create_heatmap(seg_results['probs'][1] + seg_results['probs'][4],
                                         original_image.size)
            heat_red = create_heatmap(seg_results['probs'][2] + seg_results['probs'][3],
                                      original_image.size)

        with col2:
            st.image(overlay, caption="Lesion Overlay", use_column_width=True)
            st.image(heat_bright, caption="Bright Lesion Probability", use_column_width=True)
            st.image(heat_red, caption="Red Lesion Probability", use_column_width=True)

        # --- Metrics: affected area as a percentage of image pixels ---
        st.write("**Lesion Analysis:**")
        cols = st.columns(3)
        cols[0].metric("Bright Lesions", f"{seg_results['bright_area']:.2f}%")
        cols[1].metric("Red Lesions", f"{seg_results['red_area']:.2f}%")
        cols[2].metric("Total Affected",
                       f"{seg_results['bright_area'] + seg_results['red_area']:.2f}%")

        # --- Download the raw class mask (labels scaled by 85 for visibility) ---
        st.download_button(
            "Download Mask",
            cv2.imencode('.png', seg_results['mask'] * 85)[1].tobytes(),
            "dr_mask.png",
            "image/png"
        )

    except Exception as e:
        # Top-level boundary: surface any failure to the user instead of crashing.
        st.error(f"Error processing image: {str(e)}")

if __name__ == "__main__":
    main()
|