bdtimuhammad commited on
Commit
1fbbb20
·
verified ·
1 Parent(s): a5bd0dd

Upload 4 files

Browse files
Files changed (4) hide show
  1. app.py +135 -0
  2. requirements.txt +11 -0
  3. test_access.py +1 -0
  4. test_load.py +28 -0
app.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import dotenv
3
+ from PIL import Image
4
+
5
+ dotenv.load_dotenv()
6
+
7
+ from models.loader import ModelLoader
8
+ from models.inference import run_inference
9
+ from reporting.tutor import generate_socratic_assessment
10
+
11
+ # Initialize Model Loader
12
+ loader = ModelLoader()
13
+
14
+ # Global state to hold results between steps
15
+ current_ai_results = {}
16
+
17
+ def process_clinical_assessment(image):
18
+ if image is None:
19
+ yield image, "Please upload a Chest X-Ray image first."
20
+ return
21
+
22
+ yield image, "Loading Model and running CNN Inference (this may take a moment)..."
23
+ model = loader.load_model()
24
+
25
+ if model is None:
26
+ yield image, "### Error\nModel failed to load. Please check console logs."
27
+ return
28
+
29
+ results = run_inference(model, image)
30
+
31
+ global current_ai_results
32
+ current_ai_results = results
33
+
34
+ diagnosis_data = {
35
+ "modality": "X-Ray",
36
+ "probabilities": results["all_probabilities"]
37
+ }
38
+
39
+ yield image, "Generating Medical Tutor Socratic Feedack..."
40
+
41
+ stream_generator = generate_socratic_assessment(diagnosis_data)
42
+
43
+ for partial_text in stream_generator:
44
+ yield image, partial_text
45
+
46
+ def reveal_ai_analysis():
47
+ global current_ai_results
48
+ if not current_ai_results:
49
+ return None, "No results yet. Please run the Clinical Assessment first.", {}
50
+
51
+ heatmap = current_ai_results.get("heatmap_image", None)
52
+ confidence = current_ai_results.get("confidence", 0.0)
53
+ probs = current_ai_results.get("all_probabilities", {})
54
+
55
+ top_diagnosis = current_ai_results.get("top_diagnosis", "Unknown")
56
+ confidence_text = f"## Top Diagnosis: **{top_diagnosis}** ({confidence*100:.1f}%)\n\nReview the heatmap to audit for spatial mismatch."
57
+
58
+ return heatmap, confidence_text, probs
59
+
60
+ def log_audit(audit_status, user_notes):
61
+ if not audit_status:
62
+ return "Please select an Audit Status."
63
+ return f"Audit Logged successfully!\n\nStatus: {audit_status}\nNotes: {user_notes}\n\nProceed to the next case to continue your education."
64
+
65
+ with gr.Blocks(theme=gr.themes.Soft(), title="AI-VECINNA") as demo:
66
+ gr.Markdown("# AI-VECINNA: Dual-Model Medical Auditing (Powered by Local MedGemma)")
67
+ gr.Markdown("**Status: Running on NVIDIA T4 GPU (Persistent)**")
68
+ gr.Markdown("Welcome! This system aims to train you in identifying discrepancies between AI predictions (heatmaps/probabilities) and your clinical knowledge. Remember, AI is fallible; you are the human-in-the-loop.")
69
+
70
+ with gr.Accordion("Step 1: Clinical Assessment & Tutor", open=True):
71
+ with gr.Row():
72
+ with gr.Column():
73
+ image_input = gr.Image(type="pil", label="Upload Chest X-Ray")
74
+ analyze_btn = gr.Button("Analyze & Generate Tutor Scenario", variant="primary")
75
+
76
+ with gr.Column():
77
+ gr.Markdown("### MedGemma Socratic Tutor Feedback")
78
+ scenario_output = gr.Markdown("Tutor instructions and questions will appear here after analysis.")
79
+
80
+ gr.Markdown("### Your Human Hypothesis")
81
+ human_hypothesis = gr.Textbox(label="Record your initial differential diagnosis here...", lines=3)
82
+
83
+ with gr.Accordion("Step 2: AI Reveal & Audit (The Safety Standard)", open=False):
84
+ reveal_btn = gr.Button("Reveal AI Analysis", variant="secondary")
85
+
86
+ with gr.Row():
87
+ with gr.Column():
88
+ heatmap_output = gr.Image(type="pil", label="GradCAM AI Heatmap")
89
+ with gr.Column():
90
+ confidence_output = gr.Markdown("Confidence results will appear here.")
91
+ probs_output = gr.Label(label="Full Probability Distribution")
92
+
93
+ gr.Markdown("### Audit Form")
94
+ audit_radio = gr.Radio(
95
+ label="AI Safety Audit (Human-in-the-Loop Validation)",
96
+ choices=[
97
+ "AI Verified: Heatmap & Confidence clinically align.",
98
+ "AI Flagged: Spatial mismatch (Hallucination).",
99
+ "AI Flagged: Low confidence/over-reliance risk."
100
+ ]
101
+ )
102
+ audit_notes = gr.Textbox(label="Additional Audit Notes", lines=2)
103
+ audit_btn = gr.Button("Submit Safety Audit", variant="primary")
104
+
105
+ audit_result = gr.Textbox(label="Audit Submission Status", interactive=False)
106
+
107
+ # Wiring with loading state
108
+ analyze_btn.click(
109
+ fn=lambda: gr.update(interactive=False, value="Analyzing & Generating (Streaming)..."),
110
+ inputs=None,
111
+ outputs=analyze_btn
112
+ ).then(
113
+ fn=process_clinical_assessment,
114
+ inputs=[image_input],
115
+ outputs=[image_input, scenario_output]
116
+ ).then(
117
+ fn=lambda: gr.update(interactive=True, value="Analyze & Generate Tutor Scenario"),
118
+ inputs=None,
119
+ outputs=analyze_btn
120
+ )
121
+
122
+ reveal_btn.click(
123
+ fn=reveal_ai_analysis,
124
+ inputs=[],
125
+ outputs=[heatmap_output, confidence_output, probs_output]
126
+ )
127
+
128
+ audit_btn.click(
129
+ fn=log_audit,
130
+ inputs=[audit_radio, audit_notes],
131
+ outputs=audit_result
132
+ )
133
+
134
+ if __name__ == "__main__":
135
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ torchxrayvision
4
+ grad-cam
5
+ gradio
6
+ Pillow
7
+ numpy
8
+ python-dotenv
9
+ transformers
10
+ accelerate
11
+ huggingface_hub
test_access.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from transformers import AutoConfig; import os; from dotenv import load_dotenv; load_dotenv(); print("Testing access..."); config = AutoConfig.from_pretrained("google/medgemma-1.5-4b-it", token=os.getenv("HF_TOKEN")); print("Access Verified: ", config.model_type)
test_load.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM
4
+ from dotenv import load_dotenv
5
+
6
+ load_dotenv()
7
+ hf_token = os.getenv("HF_TOKEN")
8
+ model_id = "google/medgemma-1.5-4b-it"
9
+
10
+ print(f"Testing HF_TOKEN: {hf_token[:5]}...{hf_token[-5:] if hf_token else 'None'}")
11
+ print(f"Model ID: {model_id}")
12
+
13
+ try:
14
+ print("Attempting to load tokenizer...")
15
+ tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
16
+ print("Tokenizer loaded successfully.")
17
+
18
+ print("Attempting to load model config (not weights yet)...")
19
+ model = AutoModelForCausalLM.from_pretrained(
20
+ model_id,
21
+ token=hf_token,
22
+ torch_dtype=torch.bfloat16,
23
+ device_map="cpu",
24
+ low_cpu_mem_usage=True
25
+ )
26
+ print("Model loaded successfully.")
27
+ except Exception as e:
28
+ print(f"DIAGNOSTIC FAILURE: {e}")