bdtimuhammad commited on
Commit
fc05593
·
verified ·
1 Parent(s): 1d89d63

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +104 -7
app.py CHANGED
@@ -1,7 +1,21 @@
1
- # Initialize Model Loader globally so it stays in VRAM
 
 
 
 
 
 
 
 
 
 
 
2
  loader = ModelLoader()
3
- # Load the model once during startup
4
- clinical_model = loader.load_model()
 
 
 
5
 
6
  def process_clinical_assessment(image):
7
  if image is None:
@@ -10,11 +24,94 @@ def process_clinical_assessment(image):
10
 
11
  yield image, "Running CNN Inference (utilizing pre-loaded model)..."
12
 
13
- # Use the globally loaded model instead of re-loading
14
  results = run_inference(clinical_model, image)
15
- # ... (rest of your logic)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
- # ... inside your demo block ...
18
  if __name__ == "__main__":
19
- # Disable SSR for maximum stability on Hugging Face
20
  demo.launch(ssr_mode=False)
 
1
+ import gradio as gr
2
+ import dotenv
3
+ import os
4
+ from PIL import Image
5
+
6
+ dotenv.load_dotenv()
7
+
8
+ from loader import ModelLoader
9
+ from inference import run_inference
10
+ from tutor import generate_socratic_assessment
11
+
12
+ # Initialize Model Loader
13
  loader = ModelLoader()
14
+ # Load model once at startup to save VRAM and time
15
+ clinical_model = loader.load_model()
16
+
17
+ # Global state to hold results between steps
18
+ current_ai_results = {}
19
 
20
  def process_clinical_assessment(image):
21
  if image is None:
 
24
 
25
  yield image, "Running CNN Inference (utilizing pre-loaded model)..."
26
 
27
+ # Use the globally loaded model
28
  results = run_inference(clinical_model, image)
29
+
30
+ global current_ai_results
31
+ current_ai_results = results
32
+
33
+ diagnosis_data = {
34
+ "modality": "X-Ray",
35
+ "probabilities": results["all_probabilities"]
36
+ }
37
+
38
+ yield image, "Generating Medical Tutor Socratic Feedback..."
39
+
40
+ stream_generator = generate_socratic_assessment(diagnosis_data)
41
+
42
+ for partial_text in stream_generator:
43
+ yield image, partial_text
44
+
45
+ def reveal_ai_analysis():
46
+ global current_ai_results
47
+ if not current_ai_results:
48
+ return None, "No results yet. Please run the Clinical Assessment first.", {}
49
+
50
+ heatmap = current_ai_results.get("heatmap_image", None)
51
+ confidence = current_ai_results.get("confidence", 0.0)
52
+ probs = current_ai_results.get("all_probabilities", {})
53
+
54
+ top_diagnosis = current_ai_results.get("top_diagnosis", "Unknown")
55
+ confidence_text = f"## Top Diagnosis: **{top_diagnosis}** ({confidence*100:.1f}%)\n\nReview the heatmap to audit for spatial mismatch."
56
+
57
+ return heatmap, confidence_text, probs
58
+
59
+ def log_audit(audit_status, user_notes):
60
+ if not audit_status:
61
+ return "Please select an Audit Status."
62
+ return f"Audit Logged successfully!\n\nStatus: {audit_status}\nNotes: {user_notes}"
63
+
64
+ with gr.Blocks(theme=gr.themes.Soft(), title="AI-VECINNA") as demo:
65
+ gr.Markdown("# AI-VECINNA: Dual-Model Medical Auditing (Powered by Local MedGemma)")
66
+ gr.Markdown("**Status: Running on NVIDIA T4 GPU (Persistent)**")
67
+
68
+ with gr.Accordion("Step 1: Clinical Assessment & Tutor", open=True):
69
+ with gr.Row():
70
+ with gr.Column():
71
+ image_input = gr.Image(type="pil", label="Upload Chest X-Ray")
72
+ analyze_btn = gr.Button("Analyze & Generate Tutor Scenario", variant="primary")
73
+ with gr.Column():
74
+ gr.Markdown("### MedGemma Socratic Tutor Feedback")
75
+ scenario_output = gr.Markdown("Tutor instructions will appear here.")
76
+ human_hypothesis = gr.Textbox(label="Record your initial differential diagnosis here...", lines=3)
77
+
78
+ with gr.Accordion("Step 2: AI Reveal & Audit", open=False):
79
+ reveal_btn = gr.Button("Reveal AI Analysis", variant="secondary")
80
+ with gr.Row():
81
+ with gr.Column():
82
+ heatmap_output = gr.Image(type="pil", label="GradCAM AI Heatmap")
83
+ with gr.Column():
84
+ confidence_output = gr.Markdown("Confidence results here.")
85
+ probs_output = gr.Label(label="Full Probability Distribution")
86
+
87
+ audit_radio = gr.Radio(
88
+ label="AI Safety Audit",
89
+ choices=[
90
+ "AI Verified: Heatmap & Confidence clinically align.",
91
+ "AI Flagged: Spatial mismatch (Hallucination).",
92
+ "AI Flagged: Low confidence/over-reliance risk."
93
+ ]
94
+ )
95
+ audit_notes = gr.Textbox(label="Audit Notes", lines=2)
96
+ audit_btn = gr.Button("Submit Safety Audit", variant="primary")
97
+ audit_result = gr.Textbox(label="Status", interactive=False)
98
+
99
+ analyze_btn.click(
100
+ fn=process_clinical_assessment,
101
+ inputs=[image_input],
102
+ outputs=[image_input, scenario_output]
103
+ )
104
+
105
+ reveal_btn.click(
106
+ fn=reveal_ai_analysis,
107
+ outputs=[heatmap_output, confidence_output, probs_output]
108
+ )
109
+
110
+ audit_btn.click(
111
+ fn=log_audit,
112
+ inputs=[audit_radio, audit_notes],
113
+ outputs=audit_result
114
+ )
115
 
 
116
  if __name__ == "__main__":
 
117
  demo.launch(ssr_mode=False)