rnmee committed on
Commit
e299673
·
verified ·
1 Parent(s): 68295e1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +159 -85
app.py CHANGED
@@ -1,104 +1,178 @@
1
- import torch
2
- import streamlit as st
3
- import numpy as np
4
  import cv2
 
5
  from PIL import Image
6
- import torchvision.transforms as transforms
7
- import torchvision.models as models
8
- import io
9
-
10
- # Define the classification model structure (ResNet152 with modified FC layer)
11
- class ResNetClassifier(torch.nn.Module):
12
- def __init__(self, num_classes=5):
13
- super(ResNetClassifier, self).__init__()
14
- self.model = models.resnet152(pretrained=False)
15
- self.model.fc = torch.nn.Sequential(
16
- torch.nn.Linear(self.model.fc.in_features, 512),
17
- torch.nn.ReLU(),
18
- torch.nn.Linear(512, num_classes)
19
- )
20
-
21
- def forward(self, x):
22
- return self.model(x)
23
 
24
- # Load models
25
- @st.cache_resource
26
- def load_models():
27
- # Load classification model
28
- classification_model = ResNetClassifier()
29
- classifier_checkpoint = torch.load("classifier.pt", map_location=torch.device("cpu"))
30
- classification_model.load_state_dict(classifier_checkpoint["model_state_dict"])
31
- classification_model.eval()
32
-
33
- # Load segmentation model
34
- segmentation_model = torch.load("best_unet_model.pth", map_location=torch.device("cpu"))
35
- segmentation_model.eval()
36
-
37
- return classification_model, segmentation_model
38
 
39
- classification_model, segmentation_model = load_models()
 
 
 
 
 
 
40
 
41
- # Define preprocessing function for classification
42
- def preprocess_image(image):
43
- image = np.array(image)
44
- green_channel = image[:, :, 1] # Extract green channel
45
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
47
- img_clahe = clahe.apply(green_channel)
48
-
49
- transform = transforms.Compose([
 
 
50
  transforms.ToTensor(),
51
- transforms.Normalize(mean=[0.5], std=[0.5])
52
  ])
53
- return transform(img_clahe).unsqueeze(0)
54
 
55
- # Define function for segmentation preprocessing
56
- def preprocess_segmentation(image):
57
- transform = transforms.Compose([
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  transforms.Resize((512, 512)),
59
  transforms.ToTensor(),
60
- transforms.Normalize(mean=[0.5], std=[0.5])
61
  ])
62
- return transform(image).unsqueeze(0)
63
 
64
- # Streamlit UI
65
- st.title("Diabetic Retinopathy Detection & Segmentation")
 
 
 
 
 
 
66
 
67
- uploaded_file = st.file_uploader("Upload a Retinal Image", type=["jpg", "png", "jpeg"])
68
- if uploaded_file:
69
- image = Image.open(uploaded_file)
70
- st.image(image, caption="Uploaded Image", use_column_width=True)
71
-
72
- # Preprocess image for classification
73
- input_tensor = preprocess_image(image)
74
 
75
- # Run classification
76
- with torch.no_grad():
77
- output = classification_model(input_tensor)
78
- predicted_class = torch.argmax(output).item()
79
-
80
- dr_stages = ["No DR", "Mild", "Moderate", "Severe", "Proliferative DR"]
81
- st.write(f"**Diabetic Retinopathy Stage:** {dr_stages[predicted_class]}")
 
 
 
 
 
 
 
82
 
83
- # If DR detected, proceed to segmentation
84
- if predicted_class > 0:
85
- st.write("Lesion segmentation in progress...")
 
 
 
 
 
86
 
87
- segmentation_input = preprocess_segmentation(image)
 
 
 
 
 
 
 
88
 
89
  with torch.no_grad():
90
- segmentation_output = segmentation_model(segmentation_input)
91
-
92
- segmentation_mask = segmentation_output.squeeze().cpu().numpy()
93
- segmentation_mask = (segmentation_mask > 0.5).astype(np.uint8) * 255
94
-
95
- st.image(segmentation_mask, caption="Segmented Lesions", use_column_width=True)
96
-
97
- # Provide download option
98
- mask_image = Image.fromarray(segmentation_mask)
99
- buf = io.BytesIO()
100
- mask_image.save(buf, format="PNG")
101
- byte_im = buf.getvalue()
102
- st.download_button("Download Segmentation Mask", data=byte_im, file_name="segmentation_mask.png", mime="image/png")
103
- else:
104
- st.write("No DR detected. Segmentation not required.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import cv2
2
+ import numpy as np
3
  from PIL import Image
4
+ import torch
5
+ import torch.nn as nn
6
+ from torchvision import models, transforms
7
+ import streamlit as st
8
+ from typing import Tuple
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
# Device configuration: prefer GPU when available, else fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Constants
# DR severity grades, index-aligned with the classifier's 5 outputs.
CLASS_NAMES = ["No DR", "Mild", "Moderate", "Severe", "Proliferative DR"]
# RGB colors used when painting the merged 3-class lesion mask.
LESION_COLORS = {
    0: [0, 0, 0],       # Background (black)
    1: [255, 255, 0],   # Bright lesions (yellow)
    2: [255, 0, 0]      # Red lesions (red)
}
20
 
21
+ # ====================== CLASSIFIER ======================
22
+ def create_classifier_model():
23
+ model = models.resnet152(pretrained=False)
24
+ num_ftrs = model.fc.in_features
25
+ model.fc = nn.Sequential(
26
+ nn.Linear(num_ftrs, 512),
27
+ nn.ReLU(),
28
+ nn.Linear(512, 5),
29
+ nn.LogSoftmax(dim=1)
30
+ return model
31
+
32
@st.cache_resource
def load_classifier():
    """Load the DR grading model from 'classifier.pt' (cached across Streamlit reruns)."""
    clf = create_classifier_model().to(device)
    state = torch.load('classifier.pt', map_location=device)
    clf.load_state_dict(state['model_state_dict'])
    clf.eval()  # inference mode: freezes dropout/batch-norm behavior
    return clf
39
+
40
def preprocess_classifier(image: Image.Image) -> np.ndarray:
    """Green channel + CLAHE preprocessing, replicated to 3 channels for ResNet."""
    green = np.array(image)[:, :, 1]  # green channel: best fundus contrast
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    enhanced = clahe.apply(green)
    return np.stack((enhanced, enhanced, enhanced), axis=-1)
46
+
47
def get_classifier_transform():
    """Transform pipeline for classifier input: 224x224, tensor, [-1, 1] range."""
    return transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
 
53
 
54
# ====================== SEGMENTATION ======================
@st.cache_resource
def load_segmenter():
    """Load the whole pickled U-Net from 'best_unet_model.pth' (cached).

    NOTE(review): torch.load of a full model object executes pickle —
    only load checkpoints from a trusted source.
    """
    net = torch.load('best_unet_model.pth', map_location=device)
    net.eval()
    return net
60
+
61
def preprocess_segmenter(image: Image.Image) -> np.ndarray:
    """LAB + CLAHE + Median filtering

    Median-blurs the RGB image, equalizes only the L (lightness) channel
    in LAB space, then converts back to RGB.
    """
    denoised = cv2.medianBlur(np.array(image), 3)
    lightness, a_chan, b_chan = cv2.split(cv2.cvtColor(denoised, cv2.COLOR_RGB2LAB))
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    enhanced_lab = cv2.merge((clahe.apply(lightness), a_chan, b_chan))
    return cv2.cvtColor(enhanced_lab, cv2.COLOR_LAB2RGB)
70
+
71
def get_segmenter_transform():
    """Transform pipeline for U-Net input: 512x512, tensor, ImageNet normalization."""
    return transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
 
77
 
78
def process_segmentation_output(output: torch.Tensor) -> Tuple[np.ndarray, np.ndarray]:
    """Convert 5-class output to 3-class mask

    Takes the argmax over per-pixel softmax probabilities, then merges:
    classes 1 and 4 -> label 1 (bright lesions); classes 2 and 3 ->
    label 2 (red lesions); class 0 stays background.
    Returns (merged uint8 mask, per-class probability array).
    """
    probs = torch.softmax(output, dim=1).cpu().numpy().squeeze()
    winners = probs.argmax(axis=0)
    merged = np.zeros(winners.shape, dtype=np.uint8)
    merged[(winners == 1) | (winners == 4)] = 1  # Bright
    merged[(winners == 2) | (winners == 3)] = 2  # Red
    return merged, probs
86
 
87
# ====================== VISUALIZATION ======================
def create_lesion_overlay(original: Image.Image, mask: np.ndarray) -> Image.Image:
    """Color-coded lesion overlay

    Resizes the class mask to the original resolution (nearest-neighbour
    keeps labels crisp), paints each class with its LESION_COLORS entry,
    and alpha-blends 40% overlay / 60% original.
    """
    base = np.array(original)
    height, width = base.shape[0], base.shape[1]
    full_mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)

    painted = base.copy()
    for label, rgb in LESION_COLORS.items():
        painted[full_mask == label] = rgb
    blended = cv2.addWeighted(painted, 0.4, base, 0.6, 0)
    return Image.fromarray(blended)
98
+
99
def create_heatmap(prob_map: np.ndarray, original_size: Tuple[int, int]) -> np.ndarray:
    """Render a probability map as a JET heatmap at the original image size.

    Args:
        prob_map: 2-D float array of per-pixel probabilities in [0, 1].
        original_size: (width, height), as given by PIL's Image.size.

    Returns:
        RGB uint8 heatmap array suitable for st.image.

    BUG FIX: cv2.applyColorMap emits BGR, but st.image interprets numpy
    arrays as RGB — the original displayed heatmaps with red and blue
    swapped. Convert BGR -> RGB before returning.
    """
    resized = cv2.resize(prob_map, original_size)
    heat_bgr = cv2.applyColorMap((resized * 255).astype(np.uint8), cv2.COLORMAP_JET)
    return cv2.cvtColor(heat_bgr, cv2.COLOR_BGR2RGB)
103
+
104
# ====================== MAIN APP ======================
def segment_image(image: Image.Image, model) -> dict:
    """Run lesion segmentation and return mask, probabilities, and area stats.

    BUG FIX: main() called segment_image(), but the function was never
    defined anywhere in this file — every upload with this code path hit a
    NameError (surfaced as "Error processing image: ..."). Reconstructed
    here from the segmentation helpers above.

    Returns a dict with:
        'mask':        uint8 merged 3-class mask (0=bg, 1=bright, 2=red)
        'probs':       per-class softmax probability array (5, H, W)
        'bright_area': percentage of pixels labelled bright lesions
        'red_area':    percentage of pixels labelled red lesions
    """
    processed = preprocess_segmenter(image)
    tensor = get_segmenter_transform()(Image.fromarray(processed)).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(tensor)
    mask, probs = process_segmentation_output(output)
    total_px = mask.size
    return {
        'mask': mask,
        'probs': probs,
        'bright_area': float((mask == 1).sum()) / total_px * 100.0,
        'red_area': float((mask == 2).sum()) / total_px * 100.0,
    }


def main():
    """Streamlit entry point: classify DR stage, then visualize lesion segmentation."""
    st.set_page_config(layout="wide")
    st.title("Diabetic Retinopathy Analysis")

    uploaded_file = st.file_uploader("Upload retinal image", type=["jpg", "jpeg", "png"])
    if not uploaded_file:
        st.info("Please upload an image")
        return

    try:
        # Force RGB so downstream channel indexing/color conversions are valid.
        original_image = Image.open(uploaded_file).convert('RGB')
        col1, col2 = st.columns(2)

        with col1:
            st.image(original_image, caption="Original Image", use_column_width=True)

        # Classification: CLAHE-enhanced green channel -> ResNet-152 head.
        classifier = load_classifier()
        clf_processed = preprocess_classifier(original_image)
        clf_transform = get_classifier_transform()
        img_tensor = clf_transform(Image.fromarray(clf_processed)).unsqueeze(0).to(device)

        with torch.no_grad():
            logps = classifier(img_tensor)
            ps = torch.exp(logps)  # model head is LogSoftmax -> exp gives probabilities
            pred_class = torch.argmax(ps).item()
            probabilities = ps[0].cpu().numpy() * 100

        st.subheader("Classification Results")
        if pred_class == 0:
            st.success(f"**Prediction:** {CLASS_NAMES[pred_class]}")
        else:
            st.error(f"**Prediction:** {CLASS_NAMES[pred_class]}")
        st.write("**Confidence Levels:**")
        for name, prob in zip(CLASS_NAMES, probabilities):
            st.progress(int(prob))
            st.write(f"{name}: {prob:.1f}%")

        # Segmentation: merged mask plus per-group probability heatmaps.
        segmenter = load_segmenter()
        with st.spinner("Detecting lesions..."):
            seg_results = segment_image(original_image, segmenter)
            overlay = create_lesion_overlay(original_image, seg_results['mask'])
            heat_bright = create_heatmap(seg_results['probs'][1] + seg_results['probs'][4],
                                         original_image.size)
            heat_red = create_heatmap(seg_results['probs'][2] + seg_results['probs'][3],
                                      original_image.size)

        with col2:
            st.image(overlay, caption="Lesion Overlay", use_column_width=True)
            st.image(heat_bright, caption="Bright Lesion Probability", use_column_width=True)
            st.image(heat_red, caption="Red Lesion Probability", use_column_width=True)

        # Metrics: percentage of pixels in each lesion group.
        st.write("**Lesion Analysis:**")
        cols = st.columns(3)
        cols[0].metric("Bright Lesions", f"{seg_results['bright_area']:.2f}%")
        cols[1].metric("Red Lesions", f"{seg_results['red_area']:.2f}%")
        cols[2].metric("Total Affected",
                       f"{seg_results['bright_area'] + seg_results['red_area']:.2f}%")

        # Download: scale labels {0,1,2} by 85 so classes are visible in a PNG.
        st.download_button(
            "Download Mask",
            cv2.imencode('.png', seg_results['mask'] * 85)[1].tobytes(),
            "dr_mask.png",
            "image/png"
        )

    except Exception as e:
        st.error(f"Error processing image: {str(e)}")


if __name__ == "__main__":
    main()