Preetham22 committed on
Commit
1ff0d2d
·
1 Parent(s): 42e56c5

Medi-LLM: public demo

Browse files
app/demo/demo.py CHANGED
@@ -70,7 +70,7 @@ def classify(role, mode, normalize_mode, emr_text, image, use_rollout):
70
 
71
  # Model caching
72
  if mode not in model_cache:
73
- model_cache[mode] = load_model(mode, MODEL_PATHS[mode])
74
  model = model_cache[mode]
75
 
76
  # Run prediction
@@ -529,301 +529,306 @@ def reset_ui():
529
  )
530
 
531
 
532
- # --- Gradio UI ---
533
- style_path = Path(__file__).resolve().parent / "style.css"
534
- with open(style_path, "r") as f:
535
- custom_css = f.read()
536
-
537
- with gr.Blocks(css=custom_css) as demo:
538
- # Centered title and subtitle
539
- gr.Markdown("<h2 class='centered'>🩺 Medi-LLM: Clinical Triage Assistant 🩻</h2>")
540
- gr.Markdown("<p class='centered'>Upload a chest X-ray and/or enter EMR text to get a triage level prediction.</p>")
541
- gr.HTML(
542
- """
543
- <div class='welcome-banner' style="background-color: #24283b; border-left: 4px solid #7aa2f7; padding: 16px; border-radius: 8px; margin-bottom: 16px;">
544
- <h3 style="margin-top: 0; color: #c0caf5;">👋 Welcome to Medi-LLM</h3>
545
- <p style="color: #a9b1d6; line-height: 1.6;">
546
- This AI assistant helps triage patients using <strong>EMR text</strong> and <strong>chest X-rays</strong>.<br>
547
- 📝 Enter EMR notes, 📷 upload a chest X-ray, or use both for a multimodal diagnosis.<br>
548
- 👩‍⚕️ Select <strong>Doctor</strong> mode to view insights like Grad-CAM heatmaps and token-level attention.<br>
549
- 💾 Save your results for later by exporting them to a CSV file.
550
- </p>
551
- </div>
552
- """
553
- )
554
-
555
- # Hidden State
556
- role_state = gr.State(value="User")
557
- mode_state = gr.State(value=DEFAULT_MODE)
558
- rollout_state = gr.State(value=False)
559
- normaliza_mode_state = gr.State(value="visual")
560
- inference_done = gr.State(value=False)
561
-
562
- # Role and Mode selection
563
- with gr.Row(equal_height=True):
564
- with gr.Column():
565
- role = gr.Radio(["User", "Doctor"], value="User", label="Select Role", info="Doctors see insights like Grad-CAM and token attention", elem_id="role_selector")
566
- mode = gr.Radio(["text", "image", "multimodal"], value=DEFAULT_MODE, label="Select Input Mode", info="Choose Diagnosis input type", elem_id="mode_selector")
567
- with gr.Column(visible=False) as normalize_mode_column:
568
- normalize_mode = gr.Radio(
569
- ["visual", "probabilistic"],
570
- value="visual",
571
- label="Attention Normalization",
572
- info="Softmax sums to 1 (probabilistic). Visual uses gamma-boosted scaling for color clarity."
573
- )
574
- use_rollout = gr.Checkbox(
575
- label="Use attention rollout (CLS -> inputs)",
576
- value=False,
577
- info="Includes residuals and multiplies attention across layers. Slower but often more faithful."
578
- )
579
-
580
- normalize_mode.change(
581
- fn=lambda val: val,
582
- inputs=[normalize_mode],
583
- outputs=[normaliza_mode_state]
584
- )
585
-
586
- use_rollout.change(
587
- fn=lambda v: v,
588
- inputs=[use_rollout],
589
- outputs=[rollout_state]
590
- )
591
 
592
- # Input: EMR text and/or image
593
- with gr.Row():
594
- with gr.Column(scale=3, elem_id="text_col") as text_col:
595
- emr_text, image, max_file_note = render_inputs(DEFAULT_MODE)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
596
 
597
- # Submit button
598
- with gr.Row():
599
- submit_btn = gr.Button(
600
- "🔍 Run Inference",
601
- elem_id="inference_btn"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
602
  )
603
- reset_btn = gr.Button(
604
- "↩️ Reset",
605
- elem_id="reset_btn"
 
 
 
606
  )
607
 
608
- # Outputs
609
- with gr.Column(elem_classes=["output-box"]):
610
- result_box = gr.Textbox(label="🧪 Triage Prediction", interactive=False)
611
- confidence_label = gr.Label(label="📊 Confidence", visible=False)
612
- prediction_count_box = gr.Textbox(value="Predictions: 0", interactive=False, label="🧮 Count", elem_id="prediction_count_box")
613
- insights_tab = gr.Tabs(visible=False)
614
- class_probs_json = gr.JSON(label="🔍 Class Probabilities", visible=True, elem_classes=["json-box"])
615
- with insights_tab:
616
- with gr.Tab("📷 Grad-CAM"):
617
- gradcam_img = gr.Image(visible=False, elem_classes=["gr-image-box"])
618
- with gr.Tab("🔬 Token Attention"):
619
- token_attention = gr.HighlightedText(
620
- visible=False,
621
- show_legend=False,
622
- color_map={
623
- "0.0": "#7aa2f7", # blue
624
- "0.25": "#80deea", # cyan
625
- "0.5": "#fbc02d", # yellow
626
- "0.75": "#ff8a65", # orange
627
- "1.0": "#f7768e", # red
628
- },
629
- elem_classes=["token-attn-box"]
630
- )
631
- top5_html = gr.HTML(value="", visible=False)
632
 
633
- inject_tooltips()
 
 
 
 
634
 
635
- gr.HTML("""
636
- <div class="attention-legend">
637
- <div style="display: flex; align-items: center; gap: 8px;">
638
- <span style="font-size: 14px; color: #c0caf5;">0.0</span>
639
- <div class="attention-gradient-bar"></div>
640
- <span style="font-size: 14px; color: #c0caf5;">1.0</span>
641
- </div>
642
- </div>
643
- """)
644
-
645
- with gr.Row():
646
- loading_msg = gr.Markdown(value="", visible=False, elem_classes=["loading-msg"])
647
-
648
- # Bind inference
649
- submit_btn.click(
650
- fn=show_loading_msg,
651
- outputs=[loading_msg]
652
- ).then(
653
- fn=classify,
654
- inputs=[role_state, mode_state, normaliza_mode_state, emr_text, image, rollout_state],
655
- outputs=[
656
- result_box,
657
- gradcam_img,
658
- token_attention,
659
- top5_html,
660
- confidence_label,
661
- insights_tab,
662
- prediction_count_box,
663
- class_probs_json,
664
- ]
665
- ).then(
666
- fn=lambda: gr.update(value="", visible=False),
667
- outputs=[loading_msg]
668
- ).then(
669
- fn=lambda: True,
670
- outputs=[inference_done]
671
- )
672
 
673
- # Input Updates
674
- mode.change(
675
- fn=lambda m: (*render_inputs(m), m),
676
- inputs=[mode],
677
- outputs=[emr_text, image, max_file_note, mode_state]
678
- )
 
 
 
 
 
 
 
 
679
 
680
- role.change(
681
- fn=update_role_state,
682
- inputs=[role],
683
- outputs=[role_state, normalize_mode_column, insights_tab, token_attention, gradcam_img, use_rollout, top5_html]
684
- )
 
 
 
 
 
 
 
 
 
685
 
686
- normalize_mode.change(
687
- fn=rerun_if_done,
688
- inputs=[inference_done, role_state, mode_state, normalize_mode, emr_text, image, rollout_state],
689
- outputs=[
690
- result_box,
691
- gradcam_img,
692
- token_attention,
693
- top5_html,
694
- confidence_label,
695
- insights_tab,
696
- prediction_count_box,
697
- class_probs_json,
698
- ]
699
- )
700
 
701
- use_rollout.change(
702
- fn=rerun_if_done,
703
- inputs=[inference_done, role_state, mode_state, normalize_mode, emr_text, image, rollout_state],
704
- outputs=[
705
- result_box,
706
- gradcam_img,
707
- token_attention,
708
- top5_html,
709
- confidence_label,
710
- insights_tab,
711
- prediction_count_box,
712
- class_probs_json
713
- ]
714
- )
715
 
716
- # CSV Export UI
717
- gr.Markdown("### 📁 Export Prediction Log")
 
 
 
 
 
718
 
719
- with gr.Row(equal_height=True):
720
- with gr.Column(scale=3):
721
- filename_input = gr.Textbox(
722
- label="CSV filename (optional)",
723
- placeholder="e.g., triage_results.csv",
724
- info="Set filename as needed or leave blank for auto-naming",
725
- elem_id="csv_filename"
726
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
727
 
728
- export_status_box = gr.Textbox(
729
- value="",
730
- visible=False,
731
- interactive=False,
732
- label="",
733
- elem_id="export_status"
734
- )
735
 
736
- with gr.Column(scale=4):
737
- gr.Markdown(
738
- "📑 **Summary**\n\nDownload your triage results for clinical review or research.",
739
- elem_classes="centered"
740
- )
741
- with gr.Row():
742
- with gr.Column(scale=1, min_width=200):
743
- download_btn = gr.Button("💾 Export CSV", elem_id="export_button")
744
- with gr.Column(scale=1, min_width=200):
745
- clear_btn = gr.Button("🗑️ Clear Logs", elem_id="clear_button")
746
- confirm_clear_btn = gr.Button("✅ Confirm Clear", visible=False, elem_id="confirm_button")
747
- confirm_box = gr.Textbox(label="Status", interactive=False, visible=False, elem_id="confirm_box")
748
-
749
- with gr.Column(scale=3):
750
- csv_output = gr.File(label="📂 Download Link", elem_id="download_box")
751
-
752
- download_btn.click(
753
- fn=export_csv,
754
- inputs=[filename_input, role_state],
755
- outputs=[
756
- csv_output,
757
- csv_output,
758
- export_status_box
759
- ]
760
- ).then(
761
- fn=blink_box_effect,
762
- inputs=[csv_output],
763
- outputs=[csv_output]
764
- ).then(
765
- fn=disable_filename_input,
766
- outputs=[filename_input]
767
- )
768
 
769
- clear_btn.click(
770
- fn=lambda: (
771
- confirm_clear(),
772
- gr.Button(visible=True),
773
- ),
774
- outputs=[confirm_box, confirm_clear_btn]
775
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
776
 
777
- confirm_clear_btn.click(
778
- fn=clear_confirmed,
779
- inputs=[role_state],
780
- outputs=[
781
- prediction_count_box, # reset prediction count
782
- confirm_box, # show success message
783
- csv_output, # hide CSV output file
784
- filename_input # re-enable input box
785
- ]
786
- ).then(
787
- fn=lambda: gr.update(visible=False), # Hide confirm button
788
- outputs=[confirm_clear_btn]
789
- ).then(
790
- fn=reset_confirm_box,
791
- outputs=[confirm_box]
792
- )
793
 
794
- # Reset UI
795
- reset_btn.click(
796
- fn=reset_ui,
797
- outputs=[
798
- emr_text, # 1
799
- image, # 2
800
- max_file_note, # 3
801
- result_box, # 4
802
- gradcam_img, # 5
803
- token_attention, # 6
804
- top5_html, # 7
805
- confidence_label, # 8
806
- insights_tab, # 9
807
- class_probs_json, # 10
808
- role_state, # 11
809
- mode_state, # 12
810
- normaliza_mode_state, # 13
811
- role, # 14 (radio)
812
- mode, # 15 (radio)
813
- normalize_mode, # 16 (radio)
814
- normalize_mode_column, # 17 (column visibility)
815
- use_rollout, # 18
816
- rollout_state, # 19
817
- loading_msg, # 20
818
- inference_done, # 21
819
- export_status_box # 22
820
- ]
821
- )
822
 
823
  if __name__ == "__main__":
824
- for mode, path in MODEL_PATHS.items():
825
- if not os.path.exists(path):
826
- print(f" Missing model for mode {mode}: {path}")
827
- print("Please download or train your models before launching the demo.")
828
- exit(1)
829
- demo.launch()
 
70
 
71
  # Model caching
72
  if mode not in model_cache:
73
+ model_cache[mode] = load_model(mode)
74
  model = model_cache[mode]
75
 
76
  # Run prediction
 
529
  )
530
 
531
 
532
+ def build_ui():
533
+ # Load CSS safely (don't crash if file is missing on remote)
534
+ style_path = Path(__file__).resolve().parent / "style.css"
535
+ custom_css = style_path.read_text(encoding="utf-8") if style_path.exists() else ""
536
+
537
+ with gr.Blocks(css=custom_css) as demo:
538
+ # ----- Header -----
539
+ gr.Markdown("<h2 class='centered'>🩺 Medi-LLM: Clinical Triage Assistant 🩻</h2>")
540
+ gr.Markdown("<p class='centered'>Upload a chest X-ray and/or enter EMR text to get a triage level prediction.</p>")
541
+ gr.HTML(
542
+ """
543
+ <div class='welcome-banner' style="background-color: #24283b; border-left: 4px solid #7aa2f7; padding: 16px; border-radius: 8px; margin-bottom: 16px;">
544
+ <h3 style="margin-top: 0; color: #c0caf5;">👋 Welcome to Medi-LLM</h3>
545
+ <p style="color: #a9b1d6; line-height: 1.6;">
546
+ This AI assistant helps triage patients using <strong>EMR text</strong> and <strong>chest X-rays</strong>.<br>
547
+ 📝 Enter EMR notes, 📷 upload a chest X-ray, or use both for a multimodal diagnosis.<br>
548
+ 👩‍⚕️ Select <strong>Doctor</strong> mode to view insights like Grad-CAM heatmaps and token-level attention.<br>
549
+ 💾 Save your results for later by exporting them to a CSV file.
550
+ </p>
551
+ </div>
552
+ """
553
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
554
 
555
+ # ----- Hidden State -----
556
+ role_state = gr.State(value="User")
557
+ mode_state = gr.State(value=DEFAULT_MODE)
558
+ rollout_state = gr.State(value=False)
559
+ normaliza_mode_state = gr.State(value="visual")
560
+ inference_done = gr.State(value=False)
561
+
562
+ # ----- Role and Mode selection -----
563
+ with gr.Row(equal_height=True):
564
+ with gr.Column():
565
+ role = gr.Radio(["User", "Doctor"], value="User", label="Select Role", info="Doctors see insights like Grad-CAM and token attention", elem_id="role_selector")
566
+ mode = gr.Radio(["text", "image", "multimodal"], value=DEFAULT_MODE, label="Select Input Mode", info="Choose Diagnosis input type", elem_id="mode_selector")
567
+ with gr.Column(visible=False) as normalize_mode_column:
568
+ normalize_mode = gr.Radio(
569
+ ["visual", "probabilistic"],
570
+ value="visual",
571
+ label="Attention Normalization",
572
+ info="Softmax sums to 1 (probabilistic). Visual uses gamma-boosted scaling for color clarity."
573
+ )
574
+ use_rollout = gr.Checkbox(
575
+ label="Use attention rollout (CLS -> inputs)",
576
+ value=False,
577
+ info="Includes residuals and multiplies attention across layers. Slower but often more faithful."
578
+ )
579
+
580
+ # ----- Inputs -----
581
+ with gr.Row():
582
+ with gr.Column(scale=3, elem_id="text_col"):
583
+ emr_text, image, max_file_note = render_inputs(DEFAULT_MODE)
584
+
585
+ # ----- Actions -----
586
+ with gr.Row():
587
+ submit_btn = gr.Button(
588
+ "🔍 Run Inference",
589
+ elem_id="inference_btn"
590
+ )
591
+ reset_btn = gr.Button(
592
+ "↩️ Reset",
593
+ elem_id="reset_btn"
594
+ )
595
 
596
+ # ----- Outputs -----
597
+ with gr.Column(elem_classes=["output-box"]):
598
+ result_box = gr.Textbox(label="🧪 Triage Prediction", interactive=False)
599
+ confidence_label = gr.Label(label="📊 Confidence", visible=False)
600
+ prediction_count_box = gr.Textbox(value="Predictions: 0", interactive=False, label="🧮 Count", elem_id="prediction_count_box")
601
+ insights_tab = gr.Tabs(visible=False)
602
+ class_probs_json = gr.JSON(label="🔍 Class Probabilities", visible=True, elem_classes=["json-box"])
603
+ with insights_tab:
604
+ with gr.Tab("📷 Grad-CAM"):
605
+ gradcam_img = gr.Image(visible=False, elem_classes=["gr-image-box"])
606
+ with gr.Tab("🔬 Token Attention"):
607
+ token_attention = gr.HighlightedText(
608
+ visible=False,
609
+ show_legend=False,
610
+ color_map={
611
+ "0.0": "#7aa2f7", # blue
612
+ "0.25": "#80deea", # cyan
613
+ "0.5": "#fbc02d", # yellow
614
+ "0.75": "#ff8a65", # orange
615
+ "1.0": "#f7768e", # red
616
+ },
617
+ elem_classes=["token-attn-box"]
618
+ )
619
+ top5_html = gr.HTML(value="", visible=False)
620
+
621
+ inject_tooltips()
622
+
623
+ gr.HTML("""
624
+ <div class="attention-legend">
625
+ <div style="display: flex; align-items: center; gap: 8px;">
626
+ <span style="font-size: 14px; color: #c0caf5;">0.0</span>
627
+ <div class="attention-gradient-bar"></div>
628
+ <span style="font-size: 14px; color: #c0caf5;">1.0</span>
629
+ </div>
630
+ </div>
631
+ """)
632
+
633
+ with gr.Row():
634
+ loading_msg = gr.Markdown(value="", visible=False, elem_classes=["loading-msg"])
635
+
636
+ # ----- Inference Wiring -----
637
+ submit_btn.click(
638
+ fn=show_loading_msg,
639
+ outputs=[loading_msg]
640
+ ).then(
641
+ fn=classify,
642
+ inputs=[role_state, mode_state, normaliza_mode_state, emr_text, image, rollout_state],
643
+ outputs=[
644
+ result_box,
645
+ gradcam_img,
646
+ token_attention,
647
+ top5_html,
648
+ confidence_label,
649
+ insights_tab,
650
+ prediction_count_box,
651
+ class_probs_json,
652
+ ]
653
+ ).then(
654
+ fn=lambda: gr.update(value="", visible=False),
655
+ outputs=[loading_msg]
656
+ ).then(
657
+ fn=lambda: True,
658
+ outputs=[inference_done]
659
  )
660
+
661
+ # ----- Role/Mode/Param Change Wiring -----
662
+ role.change(
663
+ fn=update_role_state,
664
+ inputs=[role],
665
+ outputs=[role_state, normalize_mode_column, insights_tab, token_attention, gradcam_img, use_rollout, top5_html]
666
  )
667
 
668
+ # Input Updates
669
+ mode.change(
670
+ fn=lambda m: (*render_inputs(m), m),
671
+ inputs=[mode],
672
+ outputs=[emr_text, image, max_file_note, mode_state]
673
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
674
 
675
+ normalize_mode.change(
676
+ fn=lambda val: val,
677
+ inputs=[normalize_mode],
678
+ outputs=[normaliza_mode_state]
679
+ )
680
 
681
+ use_rollout.change(
682
+ fn=lambda v: v,
683
+ inputs=[use_rollout],
684
+ outputs=[rollout_state]
685
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
686
 
687
+ normalize_mode.change(
688
+ fn=rerun_if_done,
689
+ inputs=[inference_done, role_state, mode_state, normalize_mode, emr_text, image, rollout_state],
690
+ outputs=[
691
+ result_box,
692
+ gradcam_img,
693
+ token_attention,
694
+ top5_html,
695
+ confidence_label,
696
+ insights_tab,
697
+ prediction_count_box,
698
+ class_probs_json,
699
+ ]
700
+ )
701
 
702
+ use_rollout.change(
703
+ fn=rerun_if_done,
704
+ inputs=[inference_done, role_state, mode_state, normalize_mode, emr_text, image, rollout_state],
705
+ outputs=[
706
+ result_box,
707
+ gradcam_img,
708
+ token_attention,
709
+ top5_html,
710
+ confidence_label,
711
+ insights_tab,
712
+ prediction_count_box,
713
+ class_probs_json
714
+ ]
715
+ )
716
 
717
+ # ----- CSV Export & Log Controls -----
718
+ gr.Markdown("### 📁 Export Prediction Log")
 
 
 
 
 
 
 
 
 
 
 
 
719
 
720
+ with gr.Row(equal_height=True):
721
+ with gr.Column(scale=3):
722
+ filename_input = gr.Textbox(
723
+ label="CSV filename (optional)",
724
+ placeholder="e.g., triage_results.csv",
725
+ info="Set filename as needed or leave blank for auto-naming",
726
+ elem_id="csv_filename"
727
+ )
 
 
 
 
 
 
728
 
729
+ export_status_box = gr.Textbox(
730
+ value="",
731
+ visible=False,
732
+ interactive=False,
733
+ label="",
734
+ elem_id="export_status"
735
+ )
736
 
737
+ with gr.Column(scale=4):
738
+ gr.Markdown(
739
+ "📑 **Summary**\n\nDownload your triage results for clinical review or research.",
740
+ elem_classes="centered"
741
+ )
742
+ with gr.Row():
743
+ with gr.Column(scale=1, min_width=200):
744
+ download_btn = gr.Button("💾 Export CSV", elem_id="export_button")
745
+ with gr.Column(scale=1, min_width=200):
746
+ clear_btn = gr.Button("🗑️ Clear Logs", elem_id="clear_button")
747
+ confirm_clear_btn = gr.Button("✅ Confirm Clear", visible=False, elem_id="confirm_button")
748
+ confirm_box = gr.Textbox(label="Status", interactive=False, visible=False, elem_id="confirm_box")
749
+
750
+ with gr.Column(scale=3):
751
+ csv_output = gr.File(label="📂 Download Link", elem_id="download_box")
752
+
753
+ download_btn.click(
754
+ fn=export_csv,
755
+ inputs=[filename_input, role_state],
756
+ outputs=[
757
+ csv_output,
758
+ csv_output,
759
+ export_status_box
760
+ ]
761
+ ).then(
762
+ fn=blink_box_effect,
763
+ inputs=[csv_output],
764
+ outputs=[csv_output]
765
+ ).then(
766
+ fn=disable_filename_input,
767
+ outputs=[filename_input]
768
+ )
769
 
770
+ clear_btn.click(
771
+ fn=lambda: (
772
+ confirm_clear(),
773
+ gr.Button(visible=True),
774
+ ),
775
+ outputs=[confirm_box, confirm_clear_btn]
776
+ )
777
 
778
+ confirm_clear_btn.click(
779
+ fn=clear_confirmed,
780
+ inputs=[role_state],
781
+ outputs=[
782
+ prediction_count_box, # reset prediction count
783
+ confirm_box, # show success message
784
+ csv_output, # hide CSV output file
785
+ filename_input # re-enable input box
786
+ ]
787
+ ).then(
788
+ fn=lambda: gr.update(visible=False), # Hide confirm button
789
+ outputs=[confirm_clear_btn]
790
+ ).then(
791
+ fn=reset_confirm_box,
792
+ outputs=[confirm_box]
793
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
794
 
795
+ # ----- Reset Wiring -----
796
+ reset_btn.click(
797
+ fn=reset_ui,
798
+ outputs=[
799
+ emr_text, # 1
800
+ image, # 2
801
+ max_file_note, # 3
802
+ result_box, # 4
803
+ gradcam_img, # 5
804
+ token_attention, # 6
805
+ top5_html, # 7
806
+ confidence_label, # 8
807
+ insights_tab, # 9
808
+ class_probs_json, # 10
809
+ role_state, # 11
810
+ mode_state, # 12
811
+ normaliza_mode_state, # 13
812
+ role, # 14 (radio)
813
+ mode, # 15 (radio)
814
+ normalize_mode, # 16 (radio)
815
+ normalize_mode_column, # 17 (column visibility)
816
+ use_rollout, # 18
817
+ rollout_state, # 19
818
+ loading_msg, # 20
819
+ inference_done, # 21
820
+ export_status_box # 22
821
+ ]
822
+ )
823
+ return demo
824
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
825
 
826
+ # Expose for Spaces & imports
827
+ demo = build_ui()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
828
 
829
  if __name__ == "__main__":
830
+ demo.launch(
831
+ server_name=os.getenv("GRADIO_SERVER_NAME", "127.0.0.1"),
832
+ server_port=int(os.getenv("GRADIO_SERVER_PORT", "7860")),
833
+ show_error=True,
834
+ )
 
app/utils/create_hf_space.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import HfApi
2
+ HfApi().create_repo(
3
+ repo_id="Preetham22/medi-llm",
4
+ repo_type="space",
5
+ space_sdk="gradio",
6
+ exist_ok=True,
7
+ )
8
+
9
+ print("✅ Space ready: Preetham22/medi-llm")
app/utils/inference_utils.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import sys
2
  import torch
3
  import yaml
@@ -5,6 +6,7 @@ import numpy as np
5
  from pathlib import Path
6
  from transformers import AutoTokenizer
7
  from torchvision import transforms
 
8
 
9
  ROOT_DIR = Path(__file__).resolve().parent.parent.parent
10
  sys.path.append(str(ROOT_DIR))
@@ -12,10 +14,42 @@ sys.path.append(str(ROOT_DIR))
12
  from src.multimodal_model import MediLLMModel
13
  from app.utils.gradcam_utils import register_hooks, generate_gradcam
14
 
15
-
 
 
16
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
 
18
- # Label map
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  inv_map = {0: "low", 1: "medium", 2: "high"}
20
 
21
  # Tokenizer and image transform
@@ -27,22 +61,63 @@ image_transform = transforms.Compose([
27
  ])
28
 
29
 
30
- def load_model(mode, model_path, config_path=str(Path("config/config.yaml").resolve())):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  with open(config_path, "r") as f:
32
- config = yaml.safe_load(f)[mode]
 
 
 
33
 
 
34
  model = MediLLMModel(
35
  mode=mode,
36
  dropout=config["dropout"],
37
  hidden_dim=config["hidden_dim"]
38
  )
39
- state = torch.load(model_path, map_location=DEVICE)
40
- model.load_state_dict(state)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  model.to(DEVICE)
42
  model.eval()
43
  return model
44
 
45
 
 
 
 
46
  def attention_rollout(attentions, last_k=4, residual_alpha=0.5):
47
  """
48
  attentions_tuple: tuple/list of layer attentions; each is (B,H,S,S)
@@ -139,6 +214,9 @@ def _normalize_for_display_wordlevel(attn_scores, normalize_mode="visual", tempe
139
  return attn_array0, labels
140
 
141
 
 
 
 
142
  def predict(
143
  model,
144
  mode,
 
1
+ import os
2
  import sys
3
  import torch
4
  import yaml
 
6
  from pathlib import Path
7
  from transformers import AutoTokenizer
8
  from torchvision import transforms
9
+ from huggingface_hub import hf_hub_download
10
 
11
  ROOT_DIR = Path(__file__).resolve().parent.parent.parent
12
  sys.path.append(str(ROOT_DIR))
 
14
  from src.multimodal_model import MediLLMModel
15
  from app.utils.gradcam_utils import register_hooks, generate_gradcam
16
 
17
+ # --------------------
18
+ # Runtime / Hub config
19
+ # --------------------
20
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
 
22
+ # Map modes -> filenames in HF model repo
23
+ HF_MODEL_REPO = os.getenv("HF_MODEL_REPO", "Preetham22/medi-llm-weights")
24
+ HF_WEIGHTS_REV = os.getenv("HF_WEIGHTS_REV", None) # optional (commit/tag/branch)
25
+ FILENAMES = {
26
+ "text": "medi_llm_state_dict_text.pth",
27
+ "image": "medi_llm_state_dict_image.pth",
28
+ "multimodal": "medi_llm_state_dict_multimodal.pth"
29
+ }
30
+
31
+
32
+ def resolve_weights_path(mode: str) -> str:
33
+ """Download (or reuse cached) weights for the given mode from HF Hub."""
34
+ if mode not in FILENAMES:
35
+ raise ValueError(f"Unknown mode '{mode}'. Expected one of {list(FILENAMES)}.")
36
+ filename = FILENAMES[mode]
37
+ try:
38
+ return hf_hub_download(
39
+ repo_id=HF_MODEL_REPO,
40
+ filename=filename,
41
+ revision=HF_WEIGHTS_REV # can be None
42
+ )
43
+ except Exception as e:
44
+ raise RuntimeError(
45
+ f"Failed to fetch weights '{filename}' from repo '{HF_MODEL_REPO}'. "
46
+ f"Set HF_MODEL_REPO or check filenames. Original error: {e}"
47
+ )
48
+
49
+
50
+ # ----------------------
51
+ # Labels / preprocessing
52
+ # ----------------------
53
  inv_map = {0: "low", 1: "medium", 2: "high"}
54
 
55
  # Tokenizer and image transform
 
61
  ])
62
 
63
 
64
+ # ----------------------
65
+ # Model load
66
+ # ----------------------
67
+ def _safe_torch_load(path: str, map_location: torch.device):
68
+ """
69
+ Prefer weights_only=True (newer PyTorch), but fall back if not supported.
70
+ """
71
+ try:
72
+ return torch.load(path, map_location=map_location, weights_only=True) # PyTorch >= 2.2/2.3
73
+ except TypeError:
74
+ return torch.load(path, map_location=map_location)
75
+
76
+
77
+ def load_model(mode: str, config_path: str = str(Path("config/config.yaml").resolve())):
78
+ """
79
+ Load MediLLMModel for the given mode and populate weights from HF Hub.
80
+ Expects config/config.yaml with keys per mode (dropout, hidden_dim).
81
+ """
82
  with open(config_path, "r") as f:
83
+ cfg_all = yaml.safe_load(f)
84
+ if mode not in cfg_all:
85
+ raise KeyError(f"Mode '{mode}' not found in {config_path}. Keys: {list(cfg_all.keys())}")
86
+ config = cfg_all[mode]
87
 
88
+ # Build model
89
  model = MediLLMModel(
90
  mode=mode,
91
  dropout=config["dropout"],
92
  hidden_dim=config["hidden_dim"]
93
  )
94
+
95
+ # Download weights & load
96
+ weights_path = resolve_weights_path(mode)
97
+ state = _safe_torch_load(weights_path, map_location=DEVICE)
98
+
99
+ # Sometimes checkpoints save as {'state_dict': ...}
100
+ if isinstance(state, dict) and "state_dict" in state:
101
+ state = state["state_dict"]
102
+
103
+ try:
104
+ model.load_state_dict(state) # strict by default
105
+ except RuntimeError as e:
106
+ # allow non-strict if minor mismatches (buffer names)
107
+ try:
108
+ model.load_state_dict(state, strict=False)
109
+ print(f"⚠️ Loaded with strict=False due to: {e}")
110
+ except Exception:
111
+ raise
112
+
113
  model.to(DEVICE)
114
  model.eval()
115
  return model
116
 
117
 
118
+ # -----------------------
119
+ # Attention rollout utils
120
+ # -----------------------
121
  def attention_rollout(attentions, last_k=4, residual_alpha=0.5):
122
  """
123
  attentions_tuple: tuple/list of layer attentions; each is (B,H,S,S)
 
214
  return attn_array0, labels
215
 
216
 
217
+ # ------------------
218
+ # Prediction
219
+ # ------------------
220
  def predict(
221
  model,
222
  mode,
requirements-dev.txt ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # (Bring in everything you listed, with safer caps)
2
+
3
+ # PyTorch stack (match minor versions!)
4
+ torch>=2.2,<2.4
5
+ torchvision>=0.17,<0.19
6
+ torchaudio>=2.2,<2.4
7
+
8
+ # Transformers & NLP
9
+ transformers>=4.41,<4.46
10
+ datasets>=2.19,<2.21
11
+ accelerate>=0.28,<0.31
12
+ peft>=0.11,<0.13
13
+ tokenizers>=0.15 # usually pulled by transformers, but explicit helps
14
+ safetensors>=0.4.3
15
+ huggingface_hub>=0.24,<0.26
16
+ # If you ever use T5/ALBERT etc., add:
17
+ # sentencepiece>=0.1.99
18
+
19
+ # Vision models
20
+ timm>=0.9.7,<1.0
21
+ opencv-python-headless>=4.8
22
+ Pillow>=10.0,<11
23
+
24
+ # Optimization / tracking
25
+ optuna>=3.5,<4
26
+ wandb>=0.16.6,<0.18
27
+
28
+ # Eval / viz
29
+ scikit-learn>=1.3,<1.6
30
+ matplotlib>=3.8,<3.9
31
+ seaborn>=0.13,<0.14
32
+ tqdm>=4.66,<5
33
+
34
+ # Data processing
35
+ pandas>=2.2,<2.3
36
+ numpy>=1.26,<2.2
37
+ pyyaml>=6.0
38
+ scipy>=1.11,<1.14 # sklearn relies on it; make it explicit to avoid surprises
39
+
40
+ # Optional deployment (FastAPI)
41
+ fastapi>=0.110,<0.114
42
+ pydantic>=2.5,<3
43
+ uvicorn>=0.27,<0.31
44
+ python-multipart>=0.0.6
45
+ # Optional perf:
46
+ orjson>=3.9
47
+
48
+ # Linting & testing
49
+ pytest>=7.4,<9
50
+ pytest-cov>=4.1,<5
51
+ pre-commit>=3.5,<4
52
+ flake8>=6.1,<7
53
+ # Optional modern linter:
54
+ ruff>=0.4,<0.7
requirements.txt CHANGED
@@ -1,41 +1,19 @@
1
- # Core PyTorch stack (CPU or GPU version to be installed separately)
2
- torch>=2.1.0
3
- torchvision>=0.16.0
4
- torchaudio>=2.1.0
5
 
6
- # Transformers and NLP
7
- transformers>=4.35.0
8
- datasets>=2.14.0
9
- accelerate>=0.25.0
10
- peft>=0.9.0
11
 
12
- # Vision and image models
13
- opencv-python>=4.8.0
14
- Pillow>=10.0.0
15
- timm>=0.9.2
16
-
17
- # Optimization and hyperparameter tuning
18
- optuna>=3.3.0
19
- wandb>=0.15.0
20
-
21
- # Evaluation and visualization
22
- scikit-learn>=1.3.0
23
- matplotlib>=3.8.0
24
- seaborn>=0.13.0
25
- tqdm>=4.65.0
26
-
27
- # Data processing
28
- pandas>=2.1.0
29
- numpy>=1.25.0
30
  pyyaml>=6.0
31
-
32
- # FastAPI for deployment
33
- fastapi>=0.100.0
34
- uvicorn>=0.27.0
35
- python-multipart>0.0.6
36
-
37
- # Linting and testing
38
- pytest>=7.4.0
39
- pytest-cov>=4.1
40
- pre-commit>=3.5.0
41
- flake8>=6.1.0
 
1
+ # Core runtime (CPU)
2
+ torch>=2.2,<2.4
3
+ torchvision>=0.17,<0.19
4
+ # torchaudio not needed for this app; add if you really use it
5
 
6
+ # UI + model fetch
7
+ gradio>=3.45.2,<3.47
8
+ huggingface_hub>=0.24,<0.26
9
+ safetensors>=0.4.3
 
10
 
11
+ # Image / utils
12
+ opencv-python-headless>=4.8
13
+ Pillow>=10.0,<11
14
+ pandas>=2.2,<2.3
15
+ numpy>=1.26,<2.2
16
+ scikit-learn>=1.3,<1.6
17
+ tqdm>=4.66,<5
18
+ matplotlib>=3.8,<3.9
 
 
 
 
 
 
 
 
 
 
19
  pyyaml>=6.0