Rahil Parikh commited on
Commit
14c75f1
·
1 Parent(s): d371573

restore to original

Browse files
Files changed (1) hide show
  1. app.py +309 -180
app.py CHANGED
@@ -376,7 +376,7 @@ def preprocess_for_unet(image_path):
376
  img_t = img.transpose(2, 0, 1)
377
  img_t = torch.from_numpy(img_t).unsqueeze(0).to(DEVICE)
378
  img_t = normalize(img_t)
379
- return img_t, img
380
 
381
  def denormalize(tensor):
382
  means = torch.tensor([0.485, 0.456, 0.406], device=DEVICE).view(1, 3, 1, 1)
@@ -431,7 +431,7 @@ def extract_macula_area(image_pil):
431
 
432
  with torch.no_grad():
433
  output = segformer_macula_segmentation_model(img_tensor)
434
- mask = (output > 0.5).float()
435
 
436
  macula_area = mask.sum().item()
437
  return macula_area
@@ -445,16 +445,26 @@ def fit_min_enclosing_circle(mask_np):
445
  (x, y), radius = cv2.minEnclosingCircle(largest)
446
  return (x, y), radius, 2 * radius
447
 
 
 
 
 
 
 
 
 
 
448
  def refine_mask(mask):
449
  m = mask.cpu().numpy().squeeze()
450
  m = ndimage.binary_closing(m, structure=np.ones((3, 3)), iterations=1)
451
- return torch.from_numpy(m).float().to(DEVICE)
452
 
453
  def load_image_from_pil(image_pil, device):
454
  transform = transforms.Compose([
455
  transforms.Resize((128, 128)),
456
  transforms.ToTensor(),
457
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
 
458
  ])
459
  return transform(image_pil).unsqueeze(0).to(device)
460
 
@@ -463,16 +473,17 @@ def compute_vasculature_density(image_pil, model, device, threshold=0.05, radius
463
  with torch.no_grad():
464
  img_tensor = load_image_from_pil(image_pil, device)
465
  output = model(img_tensor)
466
- pred = (output > threshold).float()
467
  refined = refine_mask(pred)
468
  h, w = refined.shape[-2], refined.shape[-1]
469
- roi = create_circular_roi_mask((h, w), radius_ratio)
470
  roi_tensor = torch.from_numpy(roi).to(device)
471
  masked = refined * roi_tensor
472
  vessel_area = masked.sum().item()
473
- roi_area = roi_tensor.sum().item()
474
  density = vessel_area / roi_area if roi_area > 0 else 0.0
475
- return density
 
476
 
477
  def create_circular_roi_mask(image_shape, radius_ratio=0.95):
478
  h, w = image_shape
@@ -487,150 +498,136 @@ def create_dataset_statistics(df):
487
  df["optic_cup_diameter"] = pd.to_numeric(df["optic_cup_diameter"], errors="coerce")
488
  df["optic_disc_diameter"] = pd.to_numeric(df["optic_disc_diameter"], errors="coerce")
489
  df["macula_diameter"] = pd.to_numeric(df["macula_diameter"], errors="coerce")
490
- df["vasculature_density"] = pd.to_numeric(df["vasculature_density"], errors="coerce")
491
 
492
  df["retina_radius"] = df["retina_diameter"] / 2
493
  df["retina_area"] = np.pi * (df["retina_radius"] ** 2)
494
  df["retina_circumference"] = 2 * np.pi * df["retina_radius"]
 
495
  df["optic_cup_radius"] = df["optic_cup_diameter"] / 2
496
  df["optic_cup_area"] = np.pi * (df["optic_cup_radius"] ** 2)
497
  df["optic_cup_circumference"] = 2 * np.pi * df["optic_cup_radius"]
 
498
  df["optic_disc_radius"] = df["optic_disc_diameter"] / 2
499
  df["optic_disc_area"] = np.pi * (df["optic_disc_radius"] ** 2)
500
  df["optic_disc_circumference"] = 2 * np.pi * df["optic_disc_radius"]
 
501
  df["macula_radius"] = df["macula_diameter"] / 2
502
  df["macula_area"] = np.pi * (df["macula_radius"] ** 2)
503
  df["macula_circumference"] = 2 * np.pi * df["macula_radius"]
504
 
505
  df["optic_disc_to_retina_diameter_ratio"] = df["optic_disc_diameter"] / df["retina_diameter"]
506
  df["optic_disc_to_retina_area_ratio"] = df["optic_disc_area"] / df["retina_area"]
 
507
  df["optic_cup_to_disc_diameter_ratio"] = df["optic_cup_diameter"] / df["optic_disc_diameter"]
508
  df["optic_cup_to_disc_area_ratio"] = df["optic_cup_area"] / df["optic_disc_area"]
 
509
  df["optic_cup_to_retina_diameter_ratio"] = df["optic_cup_diameter"] / df["retina_diameter"]
510
  df["optic_cup_to_retina_area_ratio"] = df["optic_cup_area"] / df["retina_area"]
511
 
 
 
512
  return df
513
 
514
  def reorder_df(df):
515
- return df[["retina_diameter", "optic_disc_diameter", "optic_cup_diameter", "macula_diameter", "vasculature_density",
516
- "retina_radius", "retina_area", "retina_circumference", "optic_cup_radius", "optic_cup_area",
517
- "optic_cup_circumference", "optic_disc_radius", "optic_disc_area", "optic_disc_circumference",
518
- "macula_radius", "macula_area", "macula_circumference", "optic_disc_to_retina_diameter_ratio",
519
- "optic_disc_to_retina_area_ratio", "optic_cup_to_disc_diameter_ratio", "optic_cup_to_disc_area_ratio",
520
- "optic_cup_to_retina_diameter_ratio", "optic_cup_to_retina_area_ratio"]]
 
 
 
 
 
 
 
 
 
 
 
521
 
522
  def load_scaler(scaler_file_path):
523
- with open(scaler_file_path, 'rb') as f:
524
  scaler = pickle.load(f)
 
 
525
  return scaler
526
 
527
  def predict_multimodal(model, image_path, df_row, device=DEVICE):
528
  model.eval()
 
529
  image_pil = Image.open(image_path).convert("RGB")
 
530
  img = test_val_transform(image_pil).unsqueeze(0).to(device)
531
- numeric_data = torch.tensor(df_row.astype(float).values, dtype=torch.float32).unsqueeze(0).to(device)
 
 
 
 
 
532
  with torch.no_grad():
533
  logits = model(img, numeric_data)
534
  probs = torch.softmax(logits, dim=1)
535
  confidence, pred_class = torch.max(probs, dim=1)
 
 
 
536
  return {
537
  "pred_class": int(pred_class.item()),
538
  "confidence": float(confidence.item()),
539
  "probabilities": probs.cpu().numpy().flatten().tolist()
540
  }
541
 
542
- # ==================== NEW: Full Segmentation Visualization ====================
543
-
544
- def create_full_segmentation_visualization(image_path, retina_diameter, cup_mask, disc_mask, macula_mask, vessel_mask,
545
- cup_center=None, cup_radius=None, disc_center=None, disc_radius=None,
546
- retina_center=None, retina_radius=None):
547
- img = cv2.imread(image_path)
548
- img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
549
- h, w = img_rgb.shape[:2]
550
-
551
- # Resize masks
552
- cup_mask_r = cv2.resize(cup_mask, (w, h), interpolation=cv2.INTER_NEAREST)
553
- disc_mask_r = cv2.resize(disc_mask, (w, h), interpolation=cv2.INTER_NEAREST)
554
- macula_mask_r = cv2.resize(macula_mask, (w, h), interpolation=cv2.INTER_NEAREST)
555
- vessel_mask_r = cv2.resize(vessel_mask, (w, h), interpolation=cv2.INTER_NEAREST)
556
-
557
- vis = img_rgb.copy().astype(np.float32) / 255.0
558
-
559
- # Overlays (order matters)
560
- vis[vessel_mask_r > 0.5] = vis[vessel_mask_r > 0.5] * 0.5 + np.array([0, 1, 1]) * 0.5 # Cyan vasculature
561
- vis[macula_mask_r > 0.5] = vis[macula_mask_r > 0.5] * 0.6 + np.array([1, 1, 0]) * 0.4 # Yellow macula
562
- vis[cup_mask_r > 0.5] = vis[cup_mask_r > 0.5] * 0.5 + np.array([0, 0, 1]) * 0.5 # Blue cup
563
- vis[disc_mask_r > 0.5] = vis[disc_mask_r > 0.5] * 0.6 + np.array([0, 1, 0]) * 0.4 # Green disc
564
-
565
- # Circles
566
- if retina_center and retina_radius:
567
- cv2.circle(vis, (int(retina_center[0]), int(retina_center[1])), int(retina_radius), (1, 0, 0), 4) # Red
568
- if disc_center and disc_radius:
569
- cv2.circle(vis, (int(disc_center[0]), int(disc_center[1])), int(disc_radius), (0, 1, 0), 3) # Green
570
- if cup_center and cup_radius:
571
- cv2.circle(vis, (int(cup_center[0]), int(cup_center[1])), int(cup_radius), (0, 0, 1), 3) # Blue
572
-
573
- # Legend with matplotlib
574
- fig, ax = plt.subplots(1, 1, figsize=(12, 12))
575
- ax.imshow(vis)
576
- ax.axis('off')
577
- legend_elements = [
578
- patches.Patch(facecolor=(1,0,0,0.6), edgecolor='r', label='Retina Boundary'),
579
- patches.Patch(facecolor=(0,1,0,0.6), edgecolor='g', label='Optic Disc'),
580
- patches.Patch(facecolor=(0,0,1,0.6), edgecolor='b', label='Optic Cup'),
581
- patches.Patch(facecolor=(1,1,0,0.6), edgecolor='y', label='Macula'),
582
- patches.Patch(facecolor=(0,1,1,0.6), edgecolor='c', label='Vasculature')
583
- ]
584
- ax.legend(handles=legend_elements, loc='upper right', fontsize=14, framealpha=0.95)
585
-
586
- buf = io.BytesIO()
587
- plt.savefig(buf, format='png', bbox_inches='tight', dpi=150)
588
- plt.close(fig)
589
- buf.seek(0)
590
- return buf
591
-
592
- # ==================== MAIN PREDICTION FUNCTION (updated) ====================
593
-
594
- fmt = "{:.3f}"
595
 
596
  def predict_all_diameters(image_path):
597
  if image_path is None:
598
  return "Please upload an image.", None
599
 
600
  image_pil = Image.open(image_path).convert('RGB')
601
- img_tensor, _ = preprocess_for_unet(image_path)
602
 
603
- # Macula
604
  macula_area = extract_macula_area(image_pil)
605
  macula_radius = (macula_area / math.pi) ** 0.5
606
  macula_diameter = 2 * macula_radius
607
-
608
- # Retina
609
  retina_diameter = get_retina_statistics_from_image(image_path)
610
- if retina_diameter == "NA":
611
- retina_diameter = "NA"
 
 
612
 
613
- # Optic disc & cup
614
  with torch.no_grad():
615
  cup_out = optic_cup_segmentation_model(img_tensor)
616
  disc_out = optic_disc_segmentation_model(img_tensor)
617
  cup_mask = (cup_out > 0.5).float().cpu().numpy()[0, 0]
618
  disc_mask = (disc_out > 0.5).float().cpu().numpy()[0, 0]
619
 
620
- cup_center_full, cup_radius_full, cup_diameter = fit_min_enclosing_circle(cv2.resize(cup_mask, (cv2.imread(image_path).shape[1], cv2.imread(image_path).shape[0]), cv2.INTER_NEAREST))
621
- disc_center_full, disc_radius_full, disc_diameter = fit_min_enclosing_circle(cv2.resize(disc_mask, (cv2.imread(image_path).shape[1], cv2.imread(image_path).shape[0]), cv2.INTER_NEAREST))
622
 
623
- # Macula mask
624
- with torch.no_grad():
625
- macula_out = segformer_macula_segmentation_model(img_tensor)
626
- macula_mask = (macula_out > 0.5).float().cpu().numpy()[0, 0]
 
 
 
 
 
 
 
 
 
 
627
 
628
- # Vasculature
629
- vd = compute_vasculature_density(image_pil, retinal_vasculature_segmentation_model, DEVICE)
 
 
 
 
630
 
631
- # Retina center/radius for visualization
632
- retina_center = None
633
- retina_radius_val = None
634
  if isinstance(retina_diameter, float):
635
  img_full = cv2.imread(image_path)
636
  gray = cv2.cvtColor(img_full, cv2.COLOR_BGR2GRAY)
@@ -640,127 +637,258 @@ def predict_all_diameters(image_path):
640
  if circles is not None:
641
  circles = np.uint16(np.around(circles))
642
  largest = circles[0, np.argmax(circles[0, :, 2])]
643
- retina_center = (largest[0], largest[1])
644
- retina_radius_val = largest[2]
645
-
646
- # Visualization
647
- vis_buf = create_full_segmentation_visualization(
648
- image_path, retina_diameter, cup_mask, disc_mask, macula_mask,
649
- refine_mask((retinal_vasculature_segmentation_model(img_tensor) > 0.05).float()).cpu().numpy()[0,0],
650
- cup_center=cup_center_full, cup_radius=cup_radius_full,
651
- disc_center=disc_center_full, disc_radius=disc_radius_full,
652
- retina_center=retina_center, retina_radius=retina_radius_val
 
653
  )
654
 
655
- # Dataframe and prediction
 
 
 
 
 
 
 
 
 
 
 
656
  df = pd.DataFrame([{
657
- "retina_diameter": retina_diameter if isinstance(retina_diameter, float) else None,
658
  "optic_cup_diameter": cup_diameter,
659
  "optic_disc_diameter": disc_diameter,
660
  "macula_diameter": macula_diameter,
661
  "vasculature_density": vd
662
  }])
663
 
664
- original_df = create_dataset_statistics(df.copy())
665
  df = reorder_df(df)
 
666
 
667
- scaler_path = hf_hub_download(repo_id="rprkh/multimodal_glaucoma_classification", filename="scaler.pkl", use_auth_token=secret_value_0)
668
- scaler = load_scaler(scaler_path)
669
- scaled = scaler.transform(df)
670
- df_scaled = pd.DataFrame(scaled, columns=df.columns)
671
-
672
- result = predict_multimodal(multimodal_glaucoma_classification_model, image_path, df_scaled.iloc[0])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
673
 
674
  pred_class = result["pred_class"]
675
  confidence = result["confidence"]
676
- prediction = "Glaucomatous Retina" if pred_class == 1 else "Non-Glaucomatous Retina"
677
 
678
- def get_val(col, default="NA"):
679
- val = original_df[col].iloc[0]
680
- return fmt.format(val) if pd.notna(val) else default
681
-
682
- # Conversion helpers
683
- def convert_to_mm(x): return "N/A" if x in ["NA", None] else str(round(float(x) * 0.06818, 3))
684
- def convert_to_mm2(x): return "N/A" if x in ["NA", None] else str(round(float(x) * 0.06818**2, 3))
685
- def round3(x): return "N/A" if x in ["NA", None] else str(round(float(x), 3))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
686
 
687
- # Fixed result HTML
688
  result_text = f"""
689
- <div id="results_container">
690
- <div id="results_table" style="display:flex; gap:7px;">
691
- <div style="flex:1;"><table>
692
- <tr><th>Measurement</th><th>Value</th></tr>
693
- <tr><td><b>Retina Diameter</b></td><td>{convert_to_mm(retina_diameter)} mm</td></tr>
694
- <tr><td><b>Optic Cup Diameter</b></td><td>{convert_to_mm(cup_diameter)} mm</td></tr>
695
- <tr><td><b>Optic Disc Diameter</b></td><td>{convert_to_mm(disc_diameter)} mm</td></tr>
696
- <tr><td><b>Macular Diameter</b></td><td>{convert_to_mm(macula_diameter)} mm</td></tr>
697
- <tr><td><b>Vasculature Density</b></td><td>{round(vd * 100, 3)}%</td></tr>
698
- <tr><td><b>Retina Radius</b></td><td>{convert_to_mm(get_val('retina_radius', 'NA'))} mm</td></tr>
699
- <tr><td><b>Retina Area</b></td><td>{convert_to_mm2(get_val('retina_area', 'NA'))} mm<sup>2</sup></td></tr>
700
- <tr><td><b>Retina Circumference</b></td><td>{convert_to_mm(get_val('retina_circumference', 'NA'))} mm</td></tr>
701
- </table></div>
702
-
703
- <div style="flex:1;"><table>
704
- <tr><th>Measurement</th><th>Value</th></tr>
705
- <tr><td><b>Optic Cup Radius</b></td><td>{convert_to_mm(get_val('optic_cup_radius', 'NA'))} mm</td></tr>
706
- <tr><td><b>Optic Cup Area</b></td><td>{convert_to_mm2(get_val('optic_cup_area', 'NA'))} mm<sup>2</sup></td></tr>
707
- <tr><td><b>Optic Cup Circumference</b></td><td>{convert_to_mm(get_val('optic_cup_circumference', 'NA'))} mm</td></tr>
708
- <tr><td><b>Optic Disc Radius</b></td><td>{convert_to_mm(get_val('optic_disc_radius', 'NA'))} mm</td></tr>
709
- <tr><td><b>Optic Disc Area</b></td><td>{convert_to_mm2(get_val('optic_disc_area', 'NA'))} mm<sup>2</sup></td></tr>
710
- <tr><td><b>Optic Disc Circumference</b></td><td>{convert_to_mm(get_val('optic_disc_circumference', 'NA'))} mm</td></tr>
711
- <tr><td><b>Macula Radius</b></td><td>{convert_to_mm(get_val('macula_radius', 'NA'))} mm</td></tr>
712
- <tr><td><b>Macula Area</b></td><td>{convert_to_mm2(get_val('macula_area', 'NA'))} mm<sup>2</sup></td></tr>
713
- </table></div>
714
-
715
- <div style="flex:1;"><table>
716
- <tr><th>Measurement</th><th>Value</th></tr>
717
- <tr><td><b>Macula Circumference</b></td><td>{convert_to_mm(get_val('macula_circumference', 'NA'))} mm</td></tr>
718
- <tr><td><b>Optic Disc to Retina Diameter Ratio</b></td><td>{round3(get_val('optic_disc_to_retina_diameter_ratio', 'NA'))}</td></tr>
719
- <tr><td><b>Optic Disc to Retina Area Ratio</b></td><td>{round3(get_val('optic_disc_to_retina_area_ratio', 'NA'))}</td></tr>
720
- <tr><td><b>Optic Cup to Disc Diameter Ratio (CDR)</b></td><td>{round3(get_val('optic_cup_to_disc_diameter_ratio', 'NA'))}</td></tr>
721
- <tr><td><b>Optic Cup to Disc Area Ratio</b></td><td>{round3(get_val('optic_cup_to_disc_area_ratio', 'NA'))}</td></tr>
722
- <tr><td><b>Optic Cup to Retina Diameter Ratio</b></td><td>{round3(get_val('optic_cup_to_retina_diameter_ratio', 'NA'))}</td></tr>
723
- <tr><td><b>Optic Cup to Retina Area Ratio</b></td><td>{round3(get_val('optic_cup_to_retina_area_ratio', 'NA'))}</td></tr>
724
- </table></div>
725
- </div>
726
- <h3>Predicted Class: {prediction}</h3>
727
- <h3>Confidence: {round(confidence * 100, 3)}%</h3>
728
- </div>
729
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
730
 
731
- return result_text, Image.frombytes("PNG", vis_buf.size, vis_buf.getvalue())
732
 
733
- # ==================== GRADIO INTERFACE ====================
734
 
735
  custom_css = """
736
- #container_1100, #image_box, #prediction_button, #results_container { max-width: 1100px !important; margin-left: auto !important; margin-right: auto !important; }
737
- .center_text { text-align: center !important; }
738
- .abstract_block { text-align: justify !important; max-width: 1100px !important; margin-left: auto !important; margin-right: auto !important; font-size: 15px; }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
739
  """
740
 
741
  with gr.Blocks(title="Glaucoma Predictor", css=custom_css) as demo:
 
742
  gr.HTML("""
 
743
  <div id="container_1100">
744
  <div class="center_text" style="font-size: 24px; font-weight: bold;">
745
  Multimodal Glaucoma Classification Using Segmentation Based Biomarker Extraction
746
- </div><br>
 
 
747
  <div class="center_text" style="font-size: 16px;">
748
  Rahil Parikh<sup>a</sup>, Van Nguyen<sup>b</sup>, Anita Penkova<sup>c</sup><br>
749
  <sup>a</sup>Department of Computer Science, University of Southern California, Los Angeles, 90089, California, USA<br>
750
  <sup>b</sup>Roski Eye Institute, Keck School of Medicine, University of Southern California, Los Angeles, 90033, California, USA<br>
751
  <sup>c</sup>Department of Aerospace and Mechanical Engineering, University of Southern California, Los Angeles, 90089, California, USA<br>
752
  Email: {rahilpar, vann4675, penkova}@usc.edu
753
- </div><br>
754
- <div class="center_text" style="font-size: 15px; font-weight: bold;">Abstract</div>
 
 
 
 
 
755
  <div class="abstract_block">
756
  Glaucoma is a progressive eye disease that leads to irreversible vision loss and, if not addressed promptly,
757
- can result in blindness. While various treatment options exist, early detection and diagnosis can mitigate
 
758
  the devastating vision loss of glaucoma. The primary focus of the proposed study is the development of a
759
- robust pipeline for the efficient extraction of quantifiable data from fundus images. Our unique approach
760
- leverages segmentation-based models in conjunction with computer vision techniques to efficiently extract
761
- and compute clinical glaucoma biomarkers such as optic disc area and optic cup to disc ratio. The extracted
762
- biomarkers are consistent with trends observed in clinical practice. The weighted combination of image-based
763
- features with glaucoma biomarkers achieves a test accuracy of 91.38% for glaucoma classification.
 
 
 
 
 
 
764
  </div>
765
  </div>
766
  """)
@@ -769,15 +897,16 @@ with gr.Blocks(title="Glaucoma Predictor", css=custom_css) as demo:
769
  with gr.Column():
770
  image_input = gr.Image(type="filepath", label="Upload Fundus Image", elem_id="image_box")
771
  btn = gr.Button("Analyze Image", variant="primary", elem_id="prediction_button")
772
-
773
- with gr.Row():
774
- with gr.Column():
775
  result_md = gr.Markdown(elem_id="results_container")
776
- with gr.Column():
777
- vis_output = gr.Image(label="Segmentation Visualization", type="pil")
778
 
779
- btn.click(fn=lambda: ("Analyzing... Please wait.", None), outputs=[result_md, vis_output]) \
780
- .then(fn=predict_all_diameters, inputs=image_input, outputs=[result_md, vis_output])
 
 
 
 
 
 
781
 
782
  if __name__ == "__main__":
783
- demo.launch()
 
376
  img_t = img.transpose(2, 0, 1)
377
  img_t = torch.from_numpy(img_t).unsqueeze(0).to(DEVICE)
378
  img_t = normalize(img_t)
379
+ return img_t, img # tensor + [0,1] RGB
380
 
381
  def denormalize(tensor):
382
  means = torch.tensor([0.485, 0.456, 0.406], device=DEVICE).view(1, 3, 1, 1)
 
431
 
432
  with torch.no_grad():
433
  output = segformer_macula_segmentation_model(img_tensor)
434
+ mask = (output > 0.5).float() # (1,1,H,W)
435
 
436
  macula_area = mask.sum().item()
437
  return macula_area
 
445
  (x, y), radius = cv2.minEnclosingCircle(largest)
446
  return (x, y), radius, 2 * radius
447
 
448
+ def blend_image_with_mask(image, mask, alpha=0.8, mask_color=(0, 0, 255)):
449
+ image = image.cpu().permute(1, 2, 0).numpy()
450
+ mask = mask.cpu().numpy()
451
+ mask_rgb = np.zeros_like(image)
452
+ for i in range(3):
453
+ mask_rgb[..., i] = mask * mask_color[i]
454
+ blended = (1 - alpha) * image + alpha * mask_rgb
455
+ return np.clip(blended, 0, 1)
456
+
457
  def refine_mask(mask):
458
  m = mask.cpu().numpy().squeeze()
459
  m = ndimage.binary_closing(m, structure=np.ones((3, 3)), iterations=1)
460
+ return torch.from_numpy(m).float().to(DEVICE) # Ensure on correct device
461
 
462
  def load_image_from_pil(image_pil, device):
463
  transform = transforms.Compose([
464
  transforms.Resize((128, 128)),
465
  transforms.ToTensor(),
466
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
467
+ std=[0.229, 0.224, 0.225])
468
  ])
469
  return transform(image_pil).unsqueeze(0).to(device)
470
 
 
473
  with torch.no_grad():
474
  img_tensor = load_image_from_pil(image_pil, device)
475
  output = model(img_tensor)
476
+ pred = (output > threshold).float()
477
  refined = refine_mask(pred)
478
  h, w = refined.shape[-2], refined.shape[-1]
479
+ roi = create_circular_roi_mask((h, w), radius_ratio)
480
  roi_tensor = torch.from_numpy(roi).to(device)
481
  masked = refined * roi_tensor
482
  vessel_area = masked.sum().item()
483
+ roi_area = roi_tensor.sum().item()
484
  density = vessel_area / roi_area if roi_area > 0 else 0.0
485
+
486
+ return density
487
 
488
  def create_circular_roi_mask(image_shape, radius_ratio=0.95):
489
  h, w = image_shape
 
498
  df["optic_cup_diameter"] = pd.to_numeric(df["optic_cup_diameter"], errors="coerce")
499
  df["optic_disc_diameter"] = pd.to_numeric(df["optic_disc_diameter"], errors="coerce")
500
  df["macula_diameter"] = pd.to_numeric(df["macula_diameter"], errors="coerce")
501
+ df["vasculature_density"] = pd.to_numeric(df["macula_diameter"], errors="coerce")
502
 
503
  df["retina_radius"] = df["retina_diameter"] / 2
504
  df["retina_area"] = np.pi * (df["retina_radius"] ** 2)
505
  df["retina_circumference"] = 2 * np.pi * df["retina_radius"]
506
+
507
  df["optic_cup_radius"] = df["optic_cup_diameter"] / 2
508
  df["optic_cup_area"] = np.pi * (df["optic_cup_radius"] ** 2)
509
  df["optic_cup_circumference"] = 2 * np.pi * df["optic_cup_radius"]
510
+
511
  df["optic_disc_radius"] = df["optic_disc_diameter"] / 2
512
  df["optic_disc_area"] = np.pi * (df["optic_disc_radius"] ** 2)
513
  df["optic_disc_circumference"] = 2 * np.pi * df["optic_disc_radius"]
514
+
515
  df["macula_radius"] = df["macula_diameter"] / 2
516
  df["macula_area"] = np.pi * (df["macula_radius"] ** 2)
517
  df["macula_circumference"] = 2 * np.pi * df["macula_radius"]
518
 
519
  df["optic_disc_to_retina_diameter_ratio"] = df["optic_disc_diameter"] / df["retina_diameter"]
520
  df["optic_disc_to_retina_area_ratio"] = df["optic_disc_area"] / df["retina_area"]
521
+
522
  df["optic_cup_to_disc_diameter_ratio"] = df["optic_cup_diameter"] / df["optic_disc_diameter"]
523
  df["optic_cup_to_disc_area_ratio"] = df["optic_cup_area"] / df["optic_disc_area"]
524
+
525
  df["optic_cup_to_retina_diameter_ratio"] = df["optic_cup_diameter"] / df["retina_diameter"]
526
  df["optic_cup_to_retina_area_ratio"] = df["optic_cup_area"] / df["retina_area"]
527
 
528
+ print("Dataset statistics succesfully generated")
529
+
530
  return df
531
 
532
  def reorder_df(df):
533
+ df = df[["retina_diameter",
534
+ "optic_disc_diameter", "optic_cup_diameter",
535
+ "macula_diameter",
536
+ "vasculature_density",
537
+ "retina_radius", "retina_area",
538
+ "retina_circumference", "optic_cup_radius",
539
+ "optic_cup_area", "optic_cup_circumference",
540
+ "optic_disc_radius", "optic_disc_area",
541
+ "optic_disc_circumference",
542
+ "macula_radius", "macula_area", "macula_circumference",
543
+ "optic_disc_to_retina_diameter_ratio",
544
+ "optic_disc_to_retina_area_ratio", "optic_cup_to_disc_diameter_ratio",
545
+ "optic_cup_to_disc_area_ratio", "optic_cup_to_retina_diameter_ratio",
546
+ "optic_cup_to_retina_area_ratio"
547
+ ]]
548
+
549
+ return df
550
 
551
  def load_scaler(scaler_file_path):
552
+ with open(f'{scaler_file_path}', 'rb') as f:
553
  scaler = pickle.load(f)
554
+ print("Scaler loaded successfully")
555
+
556
  return scaler
557
 
558
  def predict_multimodal(model, image_path, df_row, device=DEVICE):
559
  model.eval()
560
+
561
  image_pil = Image.open(image_path).convert("RGB")
562
+
563
  img = test_val_transform(image_pil).unsqueeze(0).to(device)
564
+
565
+ numeric_data = torch.tensor(
566
+ df_row.iloc[:].astype(float).values,
567
+ dtype=torch.float32
568
+ ).unsqueeze(0).to(device)
569
+
570
  with torch.no_grad():
571
  logits = model(img, numeric_data)
572
  probs = torch.softmax(logits, dim=1)
573
  confidence, pred_class = torch.max(probs, dim=1)
574
+
575
+ print(probs)
576
+
577
  return {
578
  "pred_class": int(pred_class.item()),
579
  "confidence": float(confidence.item()),
580
  "probabilities": probs.cpu().numpy().flatten().tolist()
581
  }
582
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
583
 
584
  def predict_all_diameters(image_path):
585
  if image_path is None:
586
  return "Please upload an image.", None
587
 
588
  image_pil = Image.open(image_path).convert('RGB')
 
589
 
 
590
  macula_area = extract_macula_area(image_pil)
591
  macula_radius = (macula_area / math.pi) ** 0.5
592
  macula_diameter = 2 * macula_radius
593
+
 
594
  retina_diameter = get_retina_statistics_from_image(image_path)
595
+ if retina_diameter is None:
596
+ retina_diameter = "Not detected"
597
+
598
+ img_tensor, img_rgb = preprocess_for_unet(image_path)
599
 
 
600
  with torch.no_grad():
601
  cup_out = optic_cup_segmentation_model(img_tensor)
602
  disc_out = optic_disc_segmentation_model(img_tensor)
603
  cup_mask = (cup_out > 0.5).float().cpu().numpy()[0, 0]
604
  disc_mask = (disc_out > 0.5).float().cpu().numpy()[0, 0]
605
 
606
+ _, _, cup_diameter = fit_min_enclosing_circle(cup_mask)
607
+ _, _, disc_diameter = fit_min_enclosing_circle(disc_mask)
608
 
609
+ cdr = cup_diameter / disc_diameter if (cup_diameter and disc_diameter and disc_diameter > 0) else None
610
+
611
+ img_vis = denormalize(img_tensor).squeeze(0).permute(1, 2, 0).cpu().numpy()
612
+
613
+ h, w = img_vis.shape[:2]
614
+ cup_mask_resized = cv2.resize(cup_mask, (w, h), interpolation=cv2.INTER_NEAREST)
615
+ disc_mask_resized = cv2.resize(disc_mask, (w, h), interpolation=cv2.INTER_NEAREST)
616
+
617
+ vis = img_vis.copy()
618
+ vis[cup_mask_resized > 0.5] = vis[cup_mask_resized > 0.5] * 0.4 + np.array([1, 0, 0]) * 0.6
619
+ vis[disc_mask_resized > 0.5] = vis[disc_mask_resized > 0.5] * 0.6 + np.array([0, 1, 0]) * 0.4
620
+
621
+ vis_uint8 = (vis * 255).astype(np.uint8)
622
+ vis_bgr = cv2.cvtColor(vis_uint8, cv2.COLOR_RGB2BGR)
623
 
624
+ if cup_diameter:
625
+ (cx, cy), r, _ = fit_min_enclosing_circle(cup_mask_resized)
626
+ cv2.circle(vis_bgr, (int(cx), int(cy)), int(r), (255, 0, 0), 2)
627
+ if disc_diameter:
628
+ (cx, cy), r, _ = fit_min_enclosing_circle(disc_mask_resized)
629
+ cv2.circle(vis_bgr, (int(cx), int(cy)), int(r), (0, 255, 0), 2)
630
 
 
 
 
631
  if isinstance(retina_diameter, float):
632
  img_full = cv2.imread(image_path)
633
  gray = cv2.cvtColor(img_full, cv2.COLOR_BGR2GRAY)
 
637
  if circles is not None:
638
  circles = np.uint16(np.around(circles))
639
  largest = circles[0, np.argmax(circles[0, :, 2])]
640
+ x, y, r = largest
641
+ cv2.circle(vis_bgr, (x, y), r, (0, 0, 255), 2) # Red circle for retina
642
+
643
+ vis_rgb = cv2.cvtColor(vis_bgr, cv2.COLOR_BGR2RGB) / 255.0
644
+
645
+ vd = compute_vasculature_density(
646
+ image_pil=image_pil,
647
+ model=retinal_vasculature_segmentation_model,
648
+ device=DEVICE,
649
+ threshold=0.05,
650
+ radius_ratio=0.95
651
  )
652
 
653
+ if not retina_diameter:
654
+ retina_diameter = "NA"
655
+ if not cup_diameter:
656
+ cup_diameter = "NA"
657
+ if not disc_diameter:
658
+ disc_diameter = "NA"
659
+ if not vd:
660
+ vd = "NA"
661
+ if not macula_diameter:
662
+ macula_diameter = "NA"
663
+
664
+ df = pd.DataFrame()
665
  df = pd.DataFrame([{
666
+ "retina_diameter": retina_diameter,
667
  "optic_cup_diameter": cup_diameter,
668
  "optic_disc_diameter": disc_diameter,
669
  "macula_diameter": macula_diameter,
670
  "vasculature_density": vd
671
  }])
672
 
673
+ df = create_dataset_statistics(df)
674
  df = reorder_df(df)
675
+ print(df)
676
 
677
+ original_df = df
678
+
679
+ scaler_path = hf_hub_download(
680
+ repo_id="rprkh/multimodal_glaucoma_classification",
681
+ filename="scaler.pkl",
682
+ use_auth_token=secret_value_0
683
+ )
684
+ standard_scaler = load_scaler(f"{scaler_path}")
685
+ scaled_array = standard_scaler.transform(df.iloc[:, :])
686
+ print("Data scaled successfully")
687
+ print(df)
688
+
689
+ df = pd.DataFrame(scaled_array, columns=df.columns)
690
+ df_row = df.iloc[0]
691
+
692
+ result = predict_multimodal(
693
+ model=multimodal_glaucoma_classification_model,
694
+ image_path=image_path,
695
+ df_row=df_row,
696
+ device=DEVICE
697
+ )
698
 
699
  pred_class = result["pred_class"]
700
  confidence = result["confidence"]
 
701
 
702
+ if pred_class == 0:
703
+ prediction = "Non-Glaucomatous Retina"
704
+ if pred_class == 1:
705
+ prediction = "Glaucomatous Retina"
706
+
707
+ def get_df_value(df, col_name, default="NA"):
708
+ if col_name not in df.columns:
709
+ return default
710
+ val = df[col_name].iloc[0]
711
+ if pd.isna(val) or val == "NA":
712
+ return default
713
+ try:
714
+ return fmt.format(val)
715
+ except Exception:
716
+ return str(val)
717
+
718
+ retina_radius = get_df_value(original_df, "retina_radius")
719
+ retina_area = get_df_value(original_df, "retina_area")
720
+ retina_circumference = get_df_value(original_df, "retina_circumference")
721
+
722
+ optic_cup_radius = get_df_value(original_df, "optic_cup_radius")
723
+ optic_cup_area = get_df_value(original_df, "optic_cup_area")
724
+ optic_cup_circumference = get_df_value(original_df, "optic_cup_circumference")
725
+
726
+ optic_disc_radius = get_df_value(original_df, "optic_disc_radius")
727
+ optic_disc_area = get_df_value(original_df, "optic_disc_area")
728
+ optic_disc_circumference = get_df_value(original_df, "optic_disc_circumference")
729
+
730
+ macula_radius = get_df_value(original_df, "macula_radius")
731
+ macula_area = get_df_value(original_df, "macula_area")
732
+ macula_circumference = get_df_value(original_df, "macula_circumference")
733
+
734
+ optic_disc_to_retina_diameter_ratio = get_df_value(original_df, "optic_disc_to_retina_diameter_ratio")
735
+ optic_disc_to_retina_area_ratio = get_df_value(original_df, "optic_disc_to_retina_area_ratio")
736
+
737
+ optic_cup_to_disc_diameter_ratio = get_df_value(original_df, "optic_cup_to_disc_diameter_ratio")
738
+ optic_cup_to_disc_area_ratio = get_df_value(original_df, "optic_cup_to_disc_area_ratio")
739
+ optic_cup_to_retina_diameter_ratio = get_df_value(original_df, "optic_cup_to_retina_diameter_ratio")
740
+ optic_cup_to_retina_area_ratio = get_df_value(original_df, "optic_cup_to_retina_area_ratio")
741
+
742
+ def convert_to_mm(measurement):
743
+ try:
744
+ measurement = float(measurement)
745
+ measurement = measurement * 0.06818
746
+ measurement = round(measurement, 3)
747
+ except:
748
+ measurement = "N/A"
749
+
750
+ return measurement
751
+
752
+ def convert_to_mm2(measurement):
753
+ try:
754
+ measurement = float(measurement)
755
+ measurement = measurement * 0.06818 * 0.06818
756
+ measurement = round(measurement, 3)
757
+ except:
758
+ measurement = "N/A"
759
+
760
+ return measurement
761
+
762
+ def round_measurement_to_3_dp(measurement):
763
+ try:
764
+ measurement = float(measurement)
765
+ measurement = round(measurement, 3)
766
+ except:
767
+ measurement = "N/A"
768
+
769
+ return measurement
770
 
 
771
  result_text = f"""
772
+ <div id="results_container">
773
+ <div id="results_table" style="display:flex; gap:7px;">
774
+ <div style="flex:1;">
775
+ <table>
776
+ <tr><th>Measurement</th><th>Value</th></tr>
777
+
778
+ <tr><td><b>Retina Diameter</b></td><td>{convert_to_mm(retina_diameter)} mm</td></tr>
779
+ <tr><td><b>Optic Cup Diameter</b></td><td>{convert_to_mm(cup_diameter)} mm</td></tr>
780
+ <tr><td><b>Optic Disc Diameter</b></td><td>{convert_to_mm(disc_diameter)} mm</td></tr>
781
+ <tr><td><b>Macular Diameter</b></td><td>{convert_to_mm(macula_diameter)} mm</td></tr>
782
+ <tr><td><b>Vasculature Density</b></td><td>{round(vd * 100, 3)}%</td></tr>
783
+
784
+ <tr><td><b>Retina Radius</b></td><td>{convert_to_mm(retina_radius)} mm</td></tr>
785
+ <tr><td><b>Retina Area</b></td><td>{convert_to_mm2(retina_area)} mm<sup>2</sup></td></tr>
786
+ <tr><td><b>Retina Circumference</b></td><td>{convert_to_mm(retina_circumference)} mm</td></tr>
787
+ </table>
788
+ </div>
789
+
790
+ <div style="flex:1;">
791
+ <table>
792
+ <tr><th>Measurement</th><th>Value</th></tr>
793
+
794
+ <tr><td><b>Optic Cup Radius</b></td><td>{convert_to_mm(optic_cup_radius)} mm</td></tr>
795
+ <tr><td><b>Optic Cup Area</b></td><td>{convert_to_mm2(optic_disc_area)} mm<sup>2</sup></td></tr>
796
+ <tr><td><b>Optic Cup Circumference</b></td><td>{convert_to_mm(optic_cup_circumference)} mm</td></tr>
797
+
798
+ <tr><td><b>Optic Disc Radius</b></td><td>{convert_to_mm(optic_disc_radius)} mm</td></tr>
799
+ <tr><td><b>Optic Disc Area</b></td><td>{convert_to_mm2(optic_cup_area)} mm<sup>2</sup></td></tr>
800
+ <tr><td><b>Optic Disc Circumference</b></td><td>{convert_to_mm(optic_disc_circumference)} mm</td></tr>
801
+
802
+ <tr><td><b>Macula Radius</b></td><td>{convert_to_mm(macula_radius)} mm</td></tr>
803
+ <tr><td><b>Macula Area</b></td><td>{convert_to_mm(macula_area)} mm</td></tr>
804
+
805
+ </table>
806
+ </div>
807
+
808
+ <div style="flex:1;">
809
+ <table>
810
+ <tr><th>Measurement</th><th>Value</th></tr>
811
+
812
+ <tr><td><b>Macula Circumference</b></td><td>{convert_to_mm(macula_circumference)} mm</td></tr>
813
+
814
+ <tr><td><b>Optic Disc to Retina Diameter Ratio</b></td><td>{round_measurement_to_3_dp(optic_disc_to_retina_diameter_ratio)}</td></tr>
815
+ <tr><td><b>Optic Disc to Retina Area Ratio</b></td><td>{round_measurement_to_3_dp(optic_disc_to_retina_area_ratio)}</td></tr>
816
+
817
+ <tr><td><b>Optic Cup to Disc Diameter Ratio</b></td><td>{round_measurement_to_3_dp(optic_cup_to_disc_diameter_ratio)}</td></tr>
818
+ <tr><td><b>Optic Cup to Disc Area Ratio</b></td><td>{round_measurement_to_3_dp(optic_cup_to_disc_area_ratio)}</td></tr>
819
+ <tr><td><b>Optic Cup to Retina Diameter Ratio</b></td><td>{round_measurement_to_3_dp(optic_cup_to_retina_diameter_ratio)}</td></tr>
820
+ <tr><td><b>Optic Cup to Retina Area Ratio</b></td><td>{round_measurement_to_3_dp(optic_cup_to_retina_area_ratio)}</td></tr>
821
+ </table>
822
+ </div>
823
+ </div>
824
+
825
+ <h3>Predicted Class: {prediction}</h3>
826
+ <h3>Confidence: {round(confidence * 100, 3)}%</h3>
827
+ <div>
828
+ """
829
 
830
+ return result_text
831
 
 
832
 
833
  custom_css = """
834
+ #container_1100, #image_box, #prediction_button, #results_container {
835
+ max-width: 1100px !important;
836
+ margin-left: auto !important;
837
+ margin-right: auto !important;
838
+ }
839
+
840
+ .center_text {
841
+ text-align: center !important;
842
+ }
843
+
844
+ .abstract_block {
845
+ text-align: justify !important;
846
+ max-width: 1100px !important;
847
+ margin-left: auto !important;
848
+ margin-right: auto !important;
849
+ font-size: 15px;
850
+ }
851
  """
852
 
853
  with gr.Blocks(title="Glaucoma Predictor", css=custom_css) as demo:
854
+
855
  gr.HTML("""
856
+ <!--html-->
857
  <div id="container_1100">
858
  <div class="center_text" style="font-size: 24px; font-weight: bold;">
859
  Multimodal Glaucoma Classification Using Segmentation Based Biomarker Extraction
860
+ </div>
861
+ <br>
862
+
863
  <div class="center_text" style="font-size: 16px;">
864
  Rahil Parikh<sup>a</sup>, Van Nguyen<sup>b</sup>, Anita Penkova<sup>c</sup><br>
865
  <sup>a</sup>Department of Computer Science, University of Southern California, Los Angeles, 90089, California, USA<br>
866
  <sup>b</sup>Roski Eye Institute, Keck School of Medicine, University of Southern California, Los Angeles, 90033, California, USA<br>
867
  <sup>c</sup>Department of Aerospace and Mechanical Engineering, University of Southern California, Los Angeles, 90089, California, USA<br>
868
  Email: {rahilpar, vann4675, penkova}@usc.edu
869
+ </div>
870
+
871
+ <br>
872
+ <div class="center_text" style="font-size: 15px; font-weight: bold;">
873
+ Abstract
874
+ </div>
875
+
876
  <div class="abstract_block">
877
  Glaucoma is a progressive eye disease that leads to irreversible vision loss and, if not addressed promptly,
878
+ can result in blindness. While various treatment options such as eye drops, oral medications, and surgical
879
+ interventions exist, the disease may still progress. Therefore, early detection and diagnosis can mitigate
880
  the devastating vision loss of glaucoma. The primary focus of the proposed study is the development of a
881
+ robust pipeline for the efficient extraction of quantifiable data from fundus images. What sets our approach
882
+ apart from other approaches is the emphasis on automated feature extraction of glaucoma biomarkers from
883
+ particular regions of interest (ROI) along with the utilization of a weighted average of image features and
884
+ clinical measurements. Our unique approach leverages segmentation based models in conjunction with
885
+ computer vision techniques to efficiently extract and compute clinical glaucoma biomarkers such as optic disc
886
+ area and optic cup to disc ratio. The extracted biomarkers are consistent with trends observed in clinical
887
+ practice, thus supporting the validity of the feature extraction approach. While subtle disease progression,
888
+ inconsistencies in image quality and a general lack of metadata impact classification performance, the
889
+ weighted combination of image based features with glaucoma biomarkers achieves a test accuracy of 91.38%
890
+ for glaucoma classification, successfully addressing the limitations of traditional single-modality approaches
891
+ such as fundus imaging and optical coherence tomography (OCT).
892
  </div>
893
  </div>
894
  """)
 
897
  with gr.Column():
898
  image_input = gr.Image(type="filepath", label="Upload Fundus Image", elem_id="image_box")
899
  btn = gr.Button("Analyze Image", variant="primary", elem_id="prediction_button")
 
 
 
900
  result_md = gr.Markdown(elem_id="results_container")
 
 
901
 
902
+ btn.click(
903
+ fn=lambda: ("Analyzing... Please wait.",),
904
+ outputs=[result_md]
905
+ ).then(
906
+ fn=predict_all_diameters,
907
+ inputs=image_input,
908
+ outputs=[result_md]
909
+ )
910
 
911
  if __name__ == "__main__":
912
+ demo.launch()