rprkh committed on
Commit
1909cd0
·
verified ·
1 Parent(s): 195e04a

update app.py

Browse files
Files changed (1) hide show
  1. app.py +308 -74
app.py CHANGED
@@ -33,22 +33,263 @@ import albumentations as A
33
 
34
  from huggingface_hub import login, hf_hub_download
35
 
36
- from unet_models import UNet
37
- from segformer_model import SegFormer
38
- from multimodal_model import MultimodalModel
39
-
40
 
41
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
42
 
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  model = UNet(3, 1).to(DEVICE)
45
  out = model(torch.randn(1, 3, 128, 128).to(DEVICE))
46
  print(out.shape)
47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  model = SegFormer(n_channels=3, n_classes=1).to(DEVICE)
49
  out = model(torch.randn(1, 3, 128, 128).to(DEVICE))
50
  print(out.shape)
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  secret_value_0 = os.getenv("carc_hf_token")
53
  login(token=secret_value_0)
54
  print("Logged in successfully")
@@ -102,7 +343,6 @@ segformer_macula_segmentation_model.eval()
102
  segformer_macula_segmentation_model = segformer_macula_segmentation_model.to(DEVICE)
103
  print("Loaded SegFormer macula segmentation model")
104
 
105
-
106
  multimodal_glaucoma_classification_model = MultimodalModel(num_numeric_features=23, num_classes=2)
107
  multimodal_glaucoma_classification_model_path = hf_hub_download(
108
  repo_id="rprkh/multimodal_glaucoma_classification",
@@ -191,7 +431,7 @@ def extract_macula_area(image_pil):
191
 
192
  with torch.no_grad():
193
  output = segformer_macula_segmentation_model(img_tensor)
194
- mask = (output > 0.5).float()
195
 
196
  macula_area = mask.sum().item()
197
  return macula_area
@@ -243,10 +483,7 @@ def compute_vasculature_density(image_pil, model, device, threshold=0.05, radius
243
  roi_area = roi_tensor.sum().item()
244
  density = vessel_area / roi_area if roi_area > 0 else 0.0
245
 
246
- overlay_image = blend_image_with_mask(image_pil, masked) # Green vessels
247
- vasculature_density_masked = Image.fromarray((overlay_image * 255).astype(np.uint8))
248
-
249
- return density, vasculature_density_masked
250
 
251
  def create_circular_roi_mask(image_shape, radius_ratio=0.95):
252
  h, w = image_shape
@@ -405,7 +642,7 @@ def predict_all_diameters(image_path):
405
 
406
  vis_rgb = cv2.cvtColor(vis_bgr, cv2.COLOR_BGR2RGB) / 255.0
407
 
408
- vd, vasculature_density_masked = compute_vasculature_density(
409
  image_pil=image_pil,
410
  model=retinal_vasculature_segmentation_model,
411
  device=DEVICE,
@@ -532,65 +769,65 @@ def predict_all_diameters(image_path):
532
  return measurement
533
 
534
  result_text = f"""
535
- <div id="results_container">
536
- <div id="results_table" style="display:flex; gap:7px;">
537
- <div style="flex:1;">
538
- <table>
539
- <tr><th>Measurement</th><th>Value</th></tr>
540
-
541
- <tr><td><b>Retina Diameter</b></td><td>{convert_to_mm(retina_diameter)} mm</td></tr>
542
- <tr><td><b>Optic Cup Diameter</b></td><td>{convert_to_mm(cup_diameter)} mm</td></tr>
543
- <tr><td><b>Optic Disc Diameter</b></td><td>{convert_to_mm(disc_diameter)} mm</td></tr>
544
- <tr><td><b>Macular Diameter</b></td><td>{convert_to_mm(macula_diameter)} mm</td></tr>
545
- <tr><td><b>Vasculature Density</b></td><td>{round(vd * 100, 3)}%</td></tr>
546
-
547
- <tr><td><b>Retina Radius</b></td><td>{convert_to_mm(retina_radius)} mm</td></tr>
548
- <tr><td><b>Retina Area</b></td><td>{convert_to_mm2(retina_area)} mm<sup>2</sup></td></tr>
549
- <tr><td><b>Retina Circumference</b></td><td>{convert_to_mm(retina_circumference)} mm</td></tr>
550
- </table>
551
- </div>
552
-
553
- <div style="flex:1;">
554
- <table>
555
- <tr><th>Measurement</th><th>Value</th></tr>
556
-
557
- <tr><td><b>Optic Cup Radius</b></td><td>{convert_to_mm(optic_cup_radius)} mm</td></tr>
558
- <tr><td><b>Optic Cup Area</b></td><td>{convert_to_mm2(optic_disc_area)} mm<sup>2</sup></td></tr>
559
- <tr><td><b>Optic Cup Circumference</b></td><td>{convert_to_mm(optic_cup_circumference)} mm</td></tr>
560
-
561
- <tr><td><b>Optic Disc Radius</b></td><td>{convert_to_mm(optic_disc_radius)} mm</td></tr>
562
- <tr><td><b>Optic Disc Area</b></td><td>{convert_to_mm2(optic_cup_area)} mm<sup>2</sup></td></tr>
563
- <tr><td><b>Optic Disc Circumference</b></td><td>{convert_to_mm(optic_disc_circumference)} mm</td></tr>
564
-
565
- <tr><td><b>Macula Radius</b></td><td>{convert_to_mm(macula_radius)} mm</td></tr>
566
- <tr><td><b>Macula Area</b></td><td>{convert_to_mm(macula_area)} mm</td></tr>
567
-
568
- </table>
569
- </div>
570
-
571
- <div style="flex:1;">
572
- <table>
573
- <tr><th>Measurement</th><th>Value</th></tr>
574
-
575
- <tr><td><b>Macula Circumference</b></td><td>{convert_to_mm(macula_circumference)} mm</td></tr>
576
-
577
- <tr><td><b>Optic Disc to Retina Diameter Ratio</b></td><td>{round_measurement_to_3_dp(optic_disc_to_retina_diameter_ratio)}</td></tr>
578
- <tr><td><b>Optic Disc to Retina Area Ratio</b></td><td>{round_measurement_to_3_dp(optic_disc_to_retina_area_ratio)}</td></tr>
579
-
580
- <tr><td><b>Optic Cup to Disc Diameter Ratio</b></td><td>{round_measurement_to_3_dp(optic_cup_to_disc_diameter_ratio)}</td></tr>
581
- <tr><td><b>Optic Cup to Disc Area Ratio</b></td><td>{round_measurement_to_3_dp(optic_cup_to_disc_area_ratio)}</td></tr>
582
- <tr><td><b>Optic Cup to Retina Diameter Ratio</b></td><td>{round_measurement_to_3_dp(optic_cup_to_retina_diameter_ratio)}</td></tr>
583
- <tr><td><b>Optic Cup to Retina Area Ratio</b></td><td>{round_measurement_to_3_dp(optic_cup_to_retina_area_ratio)}</td></tr>
584
- </table>
585
- </div>
586
- </div>
587
-
588
- <h3>Predicted Class: {prediction}</h3>
589
- <h3>Confidence: {round(confidence * 100, 3)}%</h3>
590
- <div>
591
- """
592
 
593
- return result_text, vasculature_density_masked
594
 
595
 
596
  custom_css = """
@@ -661,17 +898,14 @@ with gr.Blocks(title="Glaucoma Predictor", css=custom_css) as demo:
661
  image_input = gr.Image(type="filepath", label="Upload Fundus Image", elem_id="image_box")
662
  btn = gr.Button("Analyze Image", variant="primary", elem_id="prediction_button")
663
  result_md = gr.Markdown(elem_id="results_container")
664
-
665
- with gr.Row():
666
- vasculature_density_masked = gr.Image(label="Segmentation Visualization", elem_id="Retina Diameter")
667
 
668
  btn.click(
669
  fn=lambda: ("Analyzing... Please wait.",),
670
- outputs=[result_md, vasculature_density_masked]
671
  ).then(
672
  fn=predict_all_diameters,
673
  inputs=image_input,
674
- outputs=[result_md, vasculature_density_masked]
675
  )
676
 
677
  if __name__ == "__main__":
 
33
 
34
  from huggingface_hub import login, hf_hub_download
35
 
 
 
 
 
36
 
37
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
38
 
39
 
40
class DoubleConv(nn.Module):
    """Two consecutive (Conv2d 3x3 -> BatchNorm2d -> ReLU) stages."""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        # The intermediate width defaults to the output width.
        mid_channels = mid_channels or out_channels
        stages = [
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.double_conv(x)
54
+
55
class Down(nn.Module):
    """Downscaling step: halve the spatial resolution, then convolve."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        conv = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, conv)

    def forward(self, x):
        return self.maxpool_conv(x)
63
+
64
class Up(nn.Module):
    """Upscaling step: upsample, pad to the skip tensor's size, concat, convolve."""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        if bilinear:
            # Parameter-free upsampling; DoubleConv halves the width internally.
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # Pad the upsampled tensor so it matches the skip connection exactly
        # (handles odd input sizes).
        dh = x2.size(3) - x1.size(3), x2.size(2) - x1.size(2)
        dw, dy = dh
        x1 = F.pad(x1, [dw // 2, dw - dw // 2, dy // 2, dy - dy // 2])
        merged = torch.cat([x2, x1], dim=1)
        return self.conv(merged)
82
+
83
class OutConv(nn.Module):
    """Final 1x1 projection to class channels, squashed through a sigmoid."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        head = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.conv = nn.Sequential(head, nn.Sigmoid())

    def forward(self, x):
        return self.conv(x)
91
+
92
class UNet(nn.Module):
    """Four-level U-Net; the output head applies a sigmoid (see OutConv)."""

    def __init__(self, n_channels, n_classes, bilinear=True):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        # With bilinear upsampling the bottleneck/decoder widths are halved,
        # since the Up blocks then carry no transposed-conv parameters.
        factor = 2 if bilinear else 1
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        # Encoder: keep every level for the skip connections.
        enc1 = self.inc(x)
        enc2 = self.down1(enc1)
        enc3 = self.down2(enc2)
        enc4 = self.down3(enc3)
        bottleneck = self.down4(enc4)
        # Decoder: fuse with the matching encoder level at each step.
        dec = self.up1(bottleneck, enc4)
        dec = self.up2(dec, enc3)
        dec = self.up3(dec, enc2)
        dec = self.up4(dec, enc1)
        return self.outc(dec)
123
+
124
# Smoke test: build the U-Net and run one dummy forward pass to confirm the
# architecture assembles on the selected device; prints the output shape.
model = UNet(3, 1).to(DEVICE)
out = model(torch.randn(1, 3, 128, 128).to(DEVICE))
print(out.shape)
127
 
128
+
129
class SqueezeExcitation(nn.Module):
    """Channel attention: squeeze to 1x1, excite through a bottleneck, rescale."""

    def __init__(self, channels: int, reduction: int = 16):
        super().__init__()
        squeezed = channels // reduction
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, squeezed, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(squeezed, channels, kernel_size=1),
            nn.Sigmoid(),
        )

    def forward(self, x: Tensor) -> Tensor:
        # Per-channel gates in (0, 1) scale the input feature map.
        return x * self.se(x)
142
+
143
class SegFormer(nn.Module):
    """Segmentation model built on a pretrained Hugging Face SegFormer.

    Wraps ``SegformerForSemanticSegmentation`` (default backbone
    "nvidia/mit-b5"), replaces its decode-head classifier with a deeper head,
    adds an SE-gated FPN over the encoder stages, and fuses the FPN features
    with the SegFormer logits through a refinement head. The output is a raw
    (un-activated) ``n_classes``-channel map at the input resolution.
    """

    def __init__(self, n_channels: int, n_classes: int, pretrained_model: str = "nvidia/mit-b5"):
        super(SegFormer, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes

        # Backbone config with heavier dropout / drop-path than the defaults.
        config = SegformerConfig.from_pretrained(
            pretrained_model,
            num_channels=n_channels,
            num_labels=n_classes,
            hidden_dropout_prob=0.3,
            attention_probs_dropout_prob=0.3,
            drop_path_rate=0.1
        )

        self.segformer = SegformerForSemanticSegmentation.from_pretrained(
            pretrained_model,
            config=config,
            ignore_mismatched_sizes=True
        )

        # Swap the first patch-embedding projection when the input is not RGB.
        if n_channels != 3:
            self.segformer.segformer.encoder.patch_embeddings[0].proj = nn.Conv2d(
                n_channels, config.hidden_sizes[0], kernel_size=7, stride=4, padding=2
            )

        # Replace the stock classifier with a 3x3 conv head + dropout.
        self.segformer.decode_head.classifier = nn.Sequential(
            nn.Conv2d(config.decoder_hidden_size, config.decoder_hidden_size // 2, kernel_size=3, padding=1),
            nn.BatchNorm2d(config.decoder_hidden_size // 2, momentum=0.05),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.4),
            nn.Conv2d(config.decoder_hidden_size // 2, n_classes, kernel_size=1)
        )

        # One 1x1 projection (+ SE gate) per encoder stage, all mapped to 128 ch.
        self.fpn = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(h, 128, kernel_size=1),
                nn.BatchNorm2d(128, momentum=0.05),
                nn.ReLU(inplace=True),
                SqueezeExcitation(128)
            ) for h in config.hidden_sizes
        ])

        # Fuses the concatenated FPN features; a 1x1 residual path is added in
        # forward() via self.fusion_residual.
        self.fusion = nn.Sequential(
            nn.Conv2d(128 * len(config.hidden_sizes), 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256, momentum=0.05),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.3),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256, momentum=0.05),
            nn.ReLU(inplace=True)
        )

        self.fusion_residual = nn.Conv2d(128 * len(config.hidden_sizes), 256, kernel_size=1)

        # Final head over [fused FPN features ++ upsampled SegFormer logits].
        self.refinement = nn.Sequential(
            nn.Conv2d(256 + n_classes, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128, momentum=0.05),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.2),
            nn.Conv2d(128, n_classes, kernel_size=1)
        )

    def forward(self, x: Tensor) -> Tensor:
        input_size = x.size()[2:]

        outputs = self.segformer(pixel_values=x)
        logits = outputs.logits

        # NOTE(review): the encoder runs a second time here just to expose its
        # hidden states; requesting output_hidden_states from the call above
        # would avoid the duplicate pass — confirm and simplify.
        encoder_outputs = self.segformer.segformer.encoder(pixel_values=x, output_hidden_states=True)
        hidden_states = encoder_outputs.hidden_states

        # Project every encoder stage and resize to the logits' spatial size.
        fpn_feats = []
        for i, (feat, layer) in enumerate(zip(hidden_states, self.fpn)):
            f = layer(feat)
            f = F.interpolate(f, size=logits.shape[2:], mode="bilinear", align_corners=False)
            fpn_feats.append(f)

        fused = torch.cat(fpn_feats, dim=1)
        residual = self.fusion_residual(fused)
        fused = self.fusion(fused)
        fused = fused + residual

        # Bring both streams back to the original input resolution.
        logits = F.interpolate(logits, size=input_size, mode="bilinear", align_corners=False)
        fused = F.interpolate(fused, size=input_size, mode="bilinear", align_corners=False)

        concat = torch.cat([fused, logits], dim=1)
        out = self.refinement(concat)

        return out
233
+
234
# Smoke test: build the SegFormer wrapper and run one dummy forward pass;
# prints the output shape.
# NOTE(review): constructing the model fetches nvidia/mit-b5 weights, so this
# top-level statement performs network I/O at import time.
model = SegFormer(n_channels=3, n_classes=1).to(DEVICE)
out = model(torch.randn(1, 3, 128, 128).to(DEVICE))
print(out.shape)
237
 
238
+
239
class MultimodalModel(nn.Module):
    """Classifier fusing three pretrained image backbones with a small 1-D CNN
    over numeric features.

    The image branch concatenates ViT-B/16 (768-d), Swin-B (1024-d) and
    SwinV2-B (1024-d) embeddings and maps them to ``num_classes`` logits; the
    numeric branch produces ``num_classes`` logits from
    ``num_numeric_features`` scalars. The result is a fixed 0.95/0.05 weighted
    sum of the two branches.
    """

    def __init__(self, num_numeric_features, num_classes):
        super(MultimodalModel, self).__init__()

        # NOTE(review): `pretrained=True` is deprecated in newer torchvision
        # releases in favour of `weights=` — confirm the pinned version.
        self.vit = models.vit_b_16(pretrained=True)
        self.vit.heads = nn.Identity()  # expose the backbone embedding

        self.swin_b = models.swin_b(pretrained=True)
        self.swin_b.head = nn.Identity()

        self.swinv2_b = models.swin_v2_b(pretrained=True)
        self.swinv2_b.head = nn.Identity()

        # Numeric branch: two conv/pool stages halve the feature length twice,
        # hence the (num_numeric_features // 4) * 32 flattened width below.
        self.numeric_branch = nn.Sequential(
            nn.Conv1d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm1d(16),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2),
            nn.Conv1d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2),
            nn.Flatten(),
            nn.Linear((num_numeric_features // 4) * 32, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Linear(64, num_classes)
        )

        # Head over the concatenated ViT + Swin + SwinV2 embeddings.
        self.image_fc = nn.Sequential(
            nn.Linear(768 + 1024 + 1024, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes)
        )

    def forward(self, image, numeric_data):
        vit_features = self.vit(image)

        swin_b_features = self.swin_b(image)

        swinv2_b_features = self.swinv2_b(image)

        combined_image_features = torch.cat((vit_features, swin_b_features, swinv2_b_features), dim=1) # Shape: (N, 2816)
        combined_image_output = self.image_fc(combined_image_features)

        # Add the channel dimension expected by Conv1d: (N, F) -> (N, 1, F).
        numeric_data = numeric_data.unsqueeze(1)
        numeric_output = self.numeric_branch(numeric_data)

        # Fixed late-fusion weighting: the image branch dominates.
        final_output = 0.95 * combined_image_output + 0.05 * numeric_output

        return final_output
291
+
292
+
293
# Authenticate with the Hugging Face Hub using the token stored in the
# "carc_hf_token" environment variable (needed for the hf_hub_download calls
# below).
# NOTE(review): os.getenv returns None when the variable is unset — confirm
# the desired behaviour of login(token=None) in that case.
secret_value_0 = os.getenv("carc_hf_token")
login(token=secret_value_0)
print("Logged in successfully")
 
343
  segformer_macula_segmentation_model = segformer_macula_segmentation_model.to(DEVICE)
344
  print("Loaded SegFormer macula segmentation model")
345
 
 
346
  multimodal_glaucoma_classification_model = MultimodalModel(num_numeric_features=23, num_classes=2)
347
  multimodal_glaucoma_classification_model_path = hf_hub_download(
348
  repo_id="rprkh/multimodal_glaucoma_classification",
 
431
 
432
  with torch.no_grad():
433
  output = segformer_macula_segmentation_model(img_tensor)
434
+ mask = (output > 0.5).float() # (1,1,H,W)
435
 
436
  macula_area = mask.sum().item()
437
  return macula_area
 
483
  roi_area = roi_tensor.sum().item()
484
  density = vessel_area / roi_area if roi_area > 0 else 0.0
485
 
486
+ return density
 
 
 
487
 
488
  def create_circular_roi_mask(image_shape, radius_ratio=0.95):
489
  h, w = image_shape
 
642
 
643
  vis_rgb = cv2.cvtColor(vis_bgr, cv2.COLOR_BGR2RGB) / 255.0
644
 
645
+ vd = compute_vasculature_density(
646
  image_pil=image_pil,
647
  model=retinal_vasculature_segmentation_model,
648
  device=DEVICE,
 
769
  return measurement
770
 
771
  result_text = f"""
772
+ <div id="results_container">
773
+ <div id="results_table" style="display:flex; gap:7px;">
774
+ <div style="flex:1;">
775
+ <table>
776
+ <tr><th>Measurement</th><th>Value</th></tr>
777
+
778
+ <tr><td><b>Retina Diameter</b></td><td>{convert_to_mm(retina_diameter)} mm</td></tr>
779
+ <tr><td><b>Optic Cup Diameter</b></td><td>{convert_to_mm(cup_diameter)} mm</td></tr>
780
+ <tr><td><b>Optic Disc Diameter</b></td><td>{convert_to_mm(disc_diameter)} mm</td></tr>
781
+ <tr><td><b>Macular Diameter</b></td><td>{convert_to_mm(macula_diameter)} mm</td></tr>
782
+ <tr><td><b>Vasculature Density</b></td><td>{round(vd * 100, 3)}%</td></tr>
783
+
784
+ <tr><td><b>Retina Radius</b></td><td>{convert_to_mm(retina_radius)} mm</td></tr>
785
+ <tr><td><b>Retina Area</b></td><td>{convert_to_mm2(retina_area)} mm<sup>2</sup></td></tr>
786
+ <tr><td><b>Retina Circumference</b></td><td>{convert_to_mm(retina_circumference)} mm</td></tr>
787
+ </table>
788
+ </div>
789
+
790
+ <div style="flex:1;">
791
+ <table>
792
+ <tr><th>Measurement</th><th>Value</th></tr>
793
+
794
+ <tr><td><b>Optic Cup Radius</b></td><td>{convert_to_mm(optic_cup_radius)} mm</td></tr>
795
+ <tr><td><b>Optic Cup Area</b></td><td>{convert_to_mm2(optic_disc_area)} mm<sup>2</sup></td></tr>
796
+ <tr><td><b>Optic Cup Circumference</b></td><td>{convert_to_mm(optic_cup_circumference)} mm</td></tr>
797
+
798
+ <tr><td><b>Optic Disc Radius</b></td><td>{convert_to_mm(optic_disc_radius)} mm</td></tr>
799
+ <tr><td><b>Optic Disc Area</b></td><td>{convert_to_mm2(optic_cup_area)} mm<sup>2</sup></td></tr>
800
+ <tr><td><b>Optic Disc Circumference</b></td><td>{convert_to_mm(optic_disc_circumference)} mm</td></tr>
801
+
802
+ <tr><td><b>Macula Radius</b></td><td>{convert_to_mm(macula_radius)} mm</td></tr>
803
+ <tr><td><b>Macula Area</b></td><td>{convert_to_mm(macula_area)} mm</td></tr>
804
+
805
+ </table>
806
+ </div>
807
+
808
+ <div style="flex:1;">
809
+ <table>
810
+ <tr><th>Measurement</th><th>Value</th></tr>
811
+
812
+ <tr><td><b>Macula Circumference</b></td><td>{convert_to_mm(macula_circumference)} mm</td></tr>
813
+
814
+ <tr><td><b>Optic Disc to Retina Diameter Ratio</b></td><td>{round_measurement_to_3_dp(optic_disc_to_retina_diameter_ratio)}</td></tr>
815
+ <tr><td><b>Optic Disc to Retina Area Ratio</b></td><td>{round_measurement_to_3_dp(optic_disc_to_retina_area_ratio)}</td></tr>
816
+
817
+ <tr><td><b>Optic Cup to Disc Diameter Ratio</b></td><td>{round_measurement_to_3_dp(optic_cup_to_disc_diameter_ratio)}</td></tr>
818
+ <tr><td><b>Optic Cup to Disc Area Ratio</b></td><td>{round_measurement_to_3_dp(optic_cup_to_disc_area_ratio)}</td></tr>
819
+ <tr><td><b>Optic Cup to Retina Diameter Ratio</b></td><td>{round_measurement_to_3_dp(optic_cup_to_retina_diameter_ratio)}</td></tr>
820
+ <tr><td><b>Optic Cup to Retina Area Ratio</b></td><td>{round_measurement_to_3_dp(optic_cup_to_retina_area_ratio)}</td></tr>
821
+ </table>
822
+ </div>
823
+ </div>
824
+
825
+ <h3>Predicted Class: {prediction}</h3>
826
+ <h3>Confidence: {round(confidence * 100, 3)}%</h3>
827
+ <div>
828
+ """
829
 
830
+ return result_text
831
 
832
 
833
  custom_css = """
 
898
  image_input = gr.Image(type="filepath", label="Upload Fundus Image", elem_id="image_box")
899
  btn = gr.Button("Analyze Image", variant="primary", elem_id="prediction_button")
900
  result_md = gr.Markdown(elem_id="results_container")
 
 
 
901
 
902
  btn.click(
903
  fn=lambda: ("Analyzing... Please wait.",),
904
+ outputs=[result_md]
905
  ).then(
906
  fn=predict_all_diameters,
907
  inputs=image_input,
908
+ outputs=[result_md]
909
  )
910
 
911
  if __name__ == "__main__":