rnmee commited on
Commit
750718f
·
verified ·
1 Parent(s): 931d277

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +90 -16
app.py CHANGED
@@ -11,13 +11,74 @@ from typing import Tuple
11
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
12
 
13
  # Constants
14
- CLASS_NAMES = ["No DR", "Mild", "Moderate", "Severe", "Proliferative DR"]
15
  LESION_COLORS = {
16
  0: [0, 0, 0], # Background (black)
17
  1: [255, 255, 0], # Bright lesions (yellow)
18
  2: [255, 0, 0] # Red lesions (red)
19
  }
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  # ====================== CLASSIFIER ======================
22
  def create_classifier_model():
23
  model = models.resnet152(pretrained=False)
@@ -38,7 +99,6 @@ def load_classifier():
38
  return model
39
 
40
  def preprocess_classifier(image: Image.Image) -> np.ndarray:
41
- """Green channel + CLAHE preprocessing"""
42
  img_np = np.array(image)
43
  green_channel = img_np[:, :, 1]
44
  clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
@@ -54,12 +114,12 @@ def get_classifier_transform():
54
  # ====================== SEGMENTATION ======================
55
  @st.cache_resource
56
  def load_segmenter():
57
- model = torch.load('best_unet_model.pth', map_location=device)
 
58
  model.eval()
59
  return model
60
 
61
  def preprocess_segmenter(image: Image.Image) -> np.ndarray:
62
- """LAB + CLAHE + Median filtering"""
63
  img_np = np.array(image)
64
  img_filtered = cv2.medianBlur(img_np, 3)
65
  lab = cv2.cvtColor(img_filtered, cv2.COLOR_RGB2LAB)
@@ -76,20 +136,18 @@ def get_segmenter_transform():
76
  ])
77
 
78
  def process_segmentation_output(output: torch.Tensor) -> Tuple[np.ndarray, np.ndarray]:
79
- """Convert 5-class output to 3-class mask"""
80
  probs = torch.softmax(output, dim=1).cpu().numpy().squeeze()
81
  pred_class = np.argmax(probs, axis=0)
82
  final_mask = np.zeros_like(pred_class, dtype=np.uint8)
83
- final_mask[(pred_class == 1) | (pred_class == 4)] = 1 # Bright
84
- final_mask[(pred_class == 2) | (pred_class == 3)] = 2 # Red
85
  return final_mask, probs
86
 
87
  # ====================== VISUALIZATION ======================
88
  def create_lesion_overlay(original: Image.Image, mask: np.ndarray) -> Image.Image:
89
- """Color-coded lesion overlay"""
90
  original_np = np.array(original)
91
  mask_resized = cv2.resize(mask, (original_np.shape[1], original_np.shape[0]),
92
- interpolation=cv2.INTER_NEAREST)
93
 
94
  overlay = original_np.copy()
95
  for class_idx, color in LESION_COLORS.items():
@@ -97,10 +155,27 @@ def create_lesion_overlay(original: Image.Image, mask: np.ndarray) -> Image.Imag
97
  return Image.fromarray(cv2.addWeighted(overlay, 0.4, original_np, 0.6, 0))
98
 
99
  def create_heatmap(prob_map: np.ndarray, original_size: Tuple[int, int]) -> np.ndarray:
100
- """Probability heatmap visualization"""
101
  resized = cv2.resize(prob_map, original_size)
102
  return cv2.applyColorMap((resized * 255).astype(np.uint8), cv2.COLORMAP_JET)
103
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
  # ====================== MAIN APP ======================
105
  def main():
106
  st.set_page_config(layout="wide")
@@ -116,13 +191,12 @@ def main():
116
  col1, col2 = st.columns(2)
117
 
118
  with col1:
119
- st.image(original_image, caption="Original Image", use_column_width=True)
120
 
121
  # Classification
122
  classifier = load_classifier()
123
  clf_processed = preprocess_classifier(original_image)
124
- clf_transform = get_classifier_transform()
125
- img_tensor = clf_transform(Image.fromarray(clf_processed)).unsqueeze(0).to(device)
126
 
127
  with torch.no_grad():
128
  logps = classifier(img_tensor)
@@ -151,9 +225,9 @@ def main():
151
  original_image.size)
152
 
153
  with col2:
154
- st.image(overlay, caption="Lesion Overlay", use_column_width=True)
155
- st.image(heat_bright, caption="Bright Lesion Probability", use_column_width=True)
156
- st.image(heat_red, caption="Red Lesion Probability", use_column_width=True)
157
 
158
  # Metrics
159
  st.write("**Lesion Analysis:**")
 
11
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
12
 
13
  # Constants
14
+ CLASS_NAMES = ["No_DR", "Mild", "Moderate", "Severe", "Proliferate_DR"]
15
  LESION_COLORS = {
16
  0: [0, 0, 0], # Background (black)
17
  1: [255, 255, 0], # Bright lesions (yellow)
18
  2: [255, 0, 0] # Red lesions (red)
19
  }
20
 
21
+ # ====================== UNET ARCHITECTURE ======================
22
+ class UNet(nn.Module):
23
+ def __init__(self, input_channels=3, num_classes=5):
24
+ super(UNet, self).__init__()
25
+
26
+ def conv_block(in_channels, out_channels):
27
+ return nn.Sequential(
28
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
29
+ nn.ReLU(inplace=True),
30
+ nn.BatchNorm2d(out_channels),
31
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
32
+ nn.ReLU(inplace=True),
33
+ nn.BatchNorm2d(out_channels),
34
+ )
35
+
36
+ self.encoder1 = conv_block(input_channels, 32)
37
+ self.pool1 = nn.MaxPool2d(2)
38
+ self.encoder2 = conv_block(32, 64)
39
+ self.pool2 = nn.MaxPool2d(2)
40
+ self.encoder3 = conv_block(64, 128)
41
+ self.pool3 = nn.MaxPool2d(2)
42
+
43
+ self.bottleneck = conv_block(128, 256)
44
+
45
+ self.up3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
46
+ self.decoder3 = conv_block(256, 128)
47
+
48
+ self.up2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
49
+ self.decoder2 = conv_block(128, 64)
50
+
51
+ self.up1 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
52
+ self.decoder1 = conv_block(64, 32)
53
+
54
+ self.final_conv = nn.Conv2d(32, num_classes, kernel_size=1)
55
+
56
+ def forward(self, x):
57
+ enc1 = self.encoder1(x)
58
+ x = self.pool1(enc1)
59
+
60
+ enc2 = self.encoder2(x)
61
+ x = self.pool2(enc2)
62
+
63
+ enc3 = self.encoder3(x)
64
+ x = self.pool3(enc3)
65
+
66
+ x = self.bottleneck(x)
67
+
68
+ x = self.up3(x)
69
+ x = torch.cat([x, enc3], dim=1)
70
+ x = self.decoder3(x)
71
+
72
+ x = self.up2(x)
73
+ x = torch.cat([x, enc2], dim=1)
74
+ x = self.decoder2(x)
75
+
76
+ x = self.up1(x)
77
+ x = torch.cat([x, enc1], dim=1)
78
+ x = self.decoder1(x)
79
+
80
+ return self.final_conv(x)
81
+
82
  # ====================== CLASSIFIER ======================
83
  def create_classifier_model():
84
  model = models.resnet152(pretrained=False)
 
99
  return model
100
 
101
  def preprocess_classifier(image: Image.Image) -> np.ndarray:
 
102
  img_np = np.array(image)
103
  green_channel = img_np[:, :, 1]
104
  clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
 
114
  # ====================== SEGMENTATION ======================
115
  @st.cache_resource
116
  def load_segmenter():
117
+ model = UNet().to(device)
118
+ model.load_state_dict(torch.load('best_unet_model.pth', map_location=device))
119
  model.eval()
120
  return model
121
 
122
  def preprocess_segmenter(image: Image.Image) -> np.ndarray:
 
123
  img_np = np.array(image)
124
  img_filtered = cv2.medianBlur(img_np, 3)
125
  lab = cv2.cvtColor(img_filtered, cv2.COLOR_RGB2LAB)
 
136
  ])
137
 
138
  def process_segmentation_output(output: torch.Tensor) -> Tuple[np.ndarray, np.ndarray]:
 
139
  probs = torch.softmax(output, dim=1).cpu().numpy().squeeze()
140
  pred_class = np.argmax(probs, axis=0)
141
  final_mask = np.zeros_like(pred_class, dtype=np.uint8)
142
+ final_mask[(pred_class == 1) | (pred_class == 4)] = 1
143
+ final_mask[(pred_class == 2) | (pred_class == 3)] = 2
144
  return final_mask, probs
145
 
146
  # ====================== VISUALIZATION ======================
147
  def create_lesion_overlay(original: Image.Image, mask: np.ndarray) -> Image.Image:
 
148
  original_np = np.array(original)
149
  mask_resized = cv2.resize(mask, (original_np.shape[1], original_np.shape[0]),
150
+ interpolation=cv2.INTER_NEAREST)
151
 
152
  overlay = original_np.copy()
153
  for class_idx, color in LESION_COLORS.items():
 
155
  return Image.fromarray(cv2.addWeighted(overlay, 0.4, original_np, 0.6, 0))
156
 
157
  def create_heatmap(prob_map: np.ndarray, original_size: Tuple[int, int]) -> np.ndarray:
 
158
  resized = cv2.resize(prob_map, original_size)
159
  return cv2.applyColorMap((resized * 255).astype(np.uint8), cv2.COLORMAP_JET)
160
 
161
+ def segment_image(image: Image.Image, model: nn.Module) -> dict:
162
+ processed_img = preprocess_segmenter(image)
163
+ img_pil = Image.fromarray(processed_img)
164
+ transform = get_segmenter_transform()
165
+ image_tensor = transform(img_pil).unsqueeze(0).to(device)
166
+
167
+ with torch.no_grad():
168
+ output = model(image_tensor)
169
+
170
+ final_mask, class_probs = process_segmentation_output(output)
171
+ total_pixels = final_mask.size
172
+ return {
173
+ 'mask': final_mask,
174
+ 'probs': class_probs,
175
+ 'bright_area': (np.sum(final_mask == 1) / total_pixels * 100),
176
+ 'red_area': (np.sum(final_mask == 2) / total_pixels * 100)
177
+ }
178
+
179
  # ====================== MAIN APP ======================
180
  def main():
181
  st.set_page_config(layout="wide")
 
191
  col1, col2 = st.columns(2)
192
 
193
  with col1:
194
+ st.image(original_image, caption="Original Image", use_container_width=True)
195
 
196
  # Classification
197
  classifier = load_classifier()
198
  clf_processed = preprocess_classifier(original_image)
199
+ img_tensor = get_classifier_transform()(Image.fromarray(clf_processed)).unsqueeze(0).to(device)
 
200
 
201
  with torch.no_grad():
202
  logps = classifier(img_tensor)
 
225
  original_image.size)
226
 
227
  with col2:
228
+ st.image(overlay, caption="Lesion Overlay", use_container_width=True)
229
+ st.image(heat_bright, caption="Bright Lesion Probability", use_container_width=True)
230
+ st.image(heat_red, caption="Red Lesion Probability", use_container_width=True)
231
 
232
  # Metrics
233
  st.write("**Lesion Analysis:**")