halilolcay committed on
Commit
b2b69aa
·
verified ·
1 Parent(s): 825a0a4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +81 -156
app.py CHANGED
@@ -1,8 +1,9 @@
1
-
2
  import warnings
3
  import json
4
  import torch
5
  import random
 
 
6
  from transformers import pipeline
7
  from datasets import load_dataset
8
  from sentence_transformers import SentenceTransformer, util
@@ -11,51 +12,18 @@ from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_sc
11
  warnings.filterwarnings("ignore")
12
 
13
  # ============================================================================
14
- # 1.DATA
15
  # ============================================================================
16
  device = "mps" if torch.backends.mps.is_available() else "cpu"
17
 
18
- dataset = load_dataset(
19
- "UTAustin-AIHealth/MedHallu",
20
- "pqa_labeled",
21
- split="train",
22
- streaming=True
23
- )
24
-
25
- data_pool = list(dataset.take(200))
26
- samples = random.sample(data_pool, 30)
27
-
28
- # ============================================================================
29
- # 2. MODELS
30
- # ============================================================================
31
-
32
- nli_model = pipeline(#mantık
33
- "text-classification",
34
- model="pritamdeka/PubMedBERT-MNLI-MedNLI",
35
- device=device,
36
- truncation=True,
37
- max_length=512
38
- )
39
-
40
- sim_model = SentenceTransformer("all-MiniLM-L6-v2", device=device)#kapsam
41
-
42
- clf_model = pipeline(#alaka
43
- "text-classification",
44
- model="cross-encoder/ms-marco-MiniLM-L-6-v2",
45
- device=device,
46
- truncation=True
47
- )
48
-
49
- # Instruction-following correction model
50
- correction_llm = pipeline(
51
- "text2text-generation",
52
- model="google/flan-t5-large",
53
- device=device,
54
- max_length=512
55
- )
56
 
57
  # ============================================================================
58
- # 3. DETECTION FUNCTIONS
59
  # ============================================================================
60
  def detect_nli(evidence, answer):
61
  res = nli_model(f"{evidence} [SEP] {answer}")[0]
@@ -69,129 +37,86 @@ def detect_similarity(evidence, answer):
69
  def detect_uncertainty(evidence, answer):
70
  return clf_model(f"{evidence} [SEP] {answer}")[0]["score"]
71
 
72
- # ============================================================================
73
- # 4. CORRECTION PROMPT
74
- # ============================================================================
75
  def build_correction_prompt(query, wrong, truth):
76
- return f"""
77
- You are a board-certified medical doctor.
78
-
79
- A previous AI answer contains a clinical error.
80
-
81
- QUESTION:
82
- {query}
83
-
84
- INCORRECT ANSWER:
85
- {wrong}
86
-
87
- VERIFIED MEDICAL EVIDENCE:
88
- {truth}
89
-
90
- TASK:
91
- 1. Briefly explain why the original answer is incorrect.
92
- 2. Provide the corrected, clinically accurate answer.
93
- """
94
 
95
  def generate_correction(prompt):
96
  return correction_llm(prompt)[0]["generated_text"]
97
 
98
  # ============================================================================
99
- # 5. EVALUATION LOOP
100
  # ============================================================================
101
- results = []
102
- y_true, y_pred = [], []
103
-
104
- for i, sample in enumerate(samples):
105
- evidence = " ".join(sample["Knowledge"])
106
- query = sample["Question"]
107
- hallucinated = sample["Hallucinated Answer"]
108
- factual = sample["Ground Truth"]
109
-
110
- # Balanced evaluation
111
- if i % 2 == 0:
112
- llm_answer = hallucinated
113
- label = 1
114
- else:
115
- llm_answer = factual
116
- label = 0
117
-
118
-
119
- nli_label, _ = detect_nli(evidence, llm_answer)
120
- sim_score = detect_similarity(evidence, llm_answer)
121
- unc_score = detect_uncertainty(evidence, llm_answer)
122
-
123
- detected = 0
124
- reason = "Consistent with evidence"
125
-
126
- # Safety-first but calibrated thresholds
127
- if nli_label == "contradiction":
128
- detected = 1
129
- reason = "Logical contradiction with medical evidence"
130
- elif sim_score < 0.30:
131
  detected = 1
132
- reason = "Semantic drift from clinical context"
133
- elif unc_score < 0.25:
134
- detected = 1
135
- reason = "Low relevance / high uncertainty"
136
-
137
- y_true.append(label)
138
- y_pred.append(detected)
139
-
140
- correction = None
141
- if detected:
142
- prompt = build_correction_prompt(query, llm_answer, factual)
143
- corrected_answer = generate_correction(prompt)
144
- correction = {
145
- "physician_prompt": prompt,
146
- "llm_corrected_answer": corrected_answer
147
- }
148
-
149
- results.append({
150
- "case_id": i + 1,
151
- "query": query,
152
- "llm_original_answer": llm_answer,
153
- "ground_truth_answer": factual,
154
- "detection": {
155
- "label": label,
156
- "prediction": detected,
157
- "reason": reason,
158
- "signals": {
159
- "nli": nli_label,
160
- "similarity": round(sim_score, 3),
161
- "uncertainty": round(unc_score, 3)
162
- }
163
- },
164
- "correction": correction
165
- })
166
-
167
- print(f"Case {i+1:02}: {'⚠️ Hallucination' if detected else '✅ Factual'}")
168
 
169
  # ============================================================================
170
- # 6. METRICS
171
  # ============================================================================
172
- acc = accuracy_score(y_true, y_pred)
173
- prec = precision_score(y_true, y_pred)
174
- rec = recall_score(y_true, y_pred)
175
- f1 = f1_score(y_true, y_pred)
176
- cm = confusion_matrix(y_true, y_pred)
177
-
178
- print("\n=== FINAL RESULTS ===")
179
- print(f"Accuracy : {acc:.3f}")
180
- print(f"Precision: {prec:.3f}")
181
- print(f"Recall : {rec:.3f}")
182
- print(f"F1-score : {f1:.3f}")
183
- print("Confusion Matrix:\n", cm)
184
-
185
- with open("final_clinical_hallucination_results.json", "w") as f:
186
- json.dump({
187
- "metrics": {
188
- "accuracy": acc,
189
- "precision": prec,
190
- "recall": rec,
191
- "f1": f1,
192
- "confusion_matrix": cm.tolist()
193
- },
194
- "results": results
195
- }, f, indent=2)
196
-
197
- print("\n✓ FINAL audit complete. Results saved.")
 
 
1
  import warnings
2
  import json
3
  import torch
4
  import random
5
+ import os
6
+ import gradio as gr
7
  from transformers import pipeline
8
  from datasets import load_dataset
9
  from sentence_transformers import SentenceTransformer, util
 
12
  warnings.filterwarnings("ignore")
13
 
14
  # ============================================================================
15
+ # 1. INITIALIZATION & MODELS
16
  # ============================================================================
17
# Prefer the Apple-Silicon GPU (MPS) when available; otherwise run on CPU.
device = "mps" if torch.backends.mps.is_available() else "cpu"

print("[INFO] Loading Expert Models...")

# NLI model: judges entailment/contradiction between evidence and answer.
# truncation/max_length guard against clinical evidence longer than the
# model's 512-token limit — without them long inputs raise a runtime error
# (the previous revision of this file set these explicitly).
nli_model = pipeline(
    "text-classification",
    model="pritamdeka/PubMedBERT-MNLI-MedNLI",
    device=device,
    truncation=True,
    max_length=512,
)

# Sentence embeddings used for the semantic-similarity ("coverage") signal.
sim_model = SentenceTransformer("all-MiniLM-L6-v2", device=device)

# Cross-encoder relevance score used as the uncertainty signal; truncation
# restored for the same over-length reason as nli_model.
clf_model = pipeline(
    "text-classification",
    model="cross-encoder/ms-marco-MiniLM-L-6-v2",
    device=device,
    truncation=True,
)

# Instruction-following model that generates the corrected answer text.
correction_llm = pipeline(
    "text2text-generation",
    model="google/flan-t5-large",
    device=device,
    max_length=512,
)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  # ============================================================================
26
+ # 2. CORE FUNCTIONS
27
  # ============================================================================
28
  def detect_nli(evidence, answer):
29
  res = nli_model(f"{evidence} [SEP] {answer}")[0]
 
37
def detect_uncertainty(evidence, answer):
    """Return the cross-encoder relevance score for an evidence/answer pair."""
    pair_input = f"{evidence} [SEP] {answer}"
    top_prediction = clf_model(pair_input)[0]
    return top_prediction["score"]
39
 
 
 
 
40
def build_correction_prompt(query, wrong, truth):
    """Compose the instruction prompt asking the LLM to correct a wrong answer.

    The prompt embeds the incorrect answer, the verified ground truth, and the
    original question in a single instruction sentence.
    """
    return (
        "You are a doctor. Explain error in: "
        f"{wrong}. Correct it using: {truth} for Question: {query}"
    )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
def generate_correction(prompt):
    """Run the correction LLM on *prompt* and return only the generated text."""
    outputs = correction_llm(prompt)
    return outputs[0]["generated_text"]
45
 
46
  # ============================================================================
47
+ # 3. THE AUDIT ENGINE (Main Logic for Gradio)
48
  # ============================================================================
49
def run_clinical_audit():
    """Run the 30-case hallucination audit and return (summary_text, json_path).

    Streams MedHallu samples, alternates hallucinated/factual answers to build
    a balanced label set, flags hallucinations with three detectors (NLI
    contradiction, embedding similarity, cross-encoder relevance), generates a
    correction for each flagged case, and writes metrics plus per-case results
    to a JSON file that the Gradio UI offers for download.

    Returns:
        tuple[str, str]: human-readable status summary and the JSON file name.
    """
    # Load dataset lazily (streaming) and materialize a fixed pool so random
    # sampling stays cheap and bounded.
    dataset = load_dataset("UTAustin-AIHealth/MedHallu", "pqa_labeled", split="train", streaming=True)
    data_pool = list(dataset.take(100))
    # Guard: random.sample raises ValueError if the stream yielded < 30 rows.
    samples = random.sample(data_pool, min(30, len(data_pool)))

    results = []
    y_true, y_pred = [], []

    for i, sample in enumerate(samples):
        # NOTE(review): assumes "Knowledge" is a list of strings — confirm schema.
        evidence = " ".join(sample["Knowledge"])
        query = sample["Question"]
        factual = sample["Ground Truth"]

        # Alternate hallucinated/factual answers so labels stay balanced.
        label = 1 if i % 2 == 0 else 0
        llm_answer = sample["Hallucinated Answer"] if label == 1 else factual

        # Three independent detection signals.
        nli_label, _ = detect_nli(evidence, llm_answer)
        sim_score = detect_similarity(evidence, llm_answer)
        unc_score = detect_uncertainty(evidence, llm_answer)

        detected = 0
        reason = "Consistent"
        # Safety-first rule: any single tripped signal flags the answer.
        if nli_label == "contradiction" or sim_score < 0.30 or unc_score < 0.25:
            detected = 1
            reason = "Hallucination Detected"

        y_true.append(label)
        y_pred.append(detected)

        correction = None
        if detected:
            prompt = build_correction_prompt(query, llm_answer, factual)
            correction = {"corrected": generate_correction(prompt)}

        results.append({
            "case_id": i + 1,
            "query": query,
            "detection": {"label": label, "prediction": detected, "reason": reason},
            "correction": correction,
        })

    # Aggregate metrics. Precision is reported again (the module imports
    # precision_score and the previous revision included it); zero_division=0
    # makes recall/precision/F1 deterministic when a class is absent instead
    # of emitting an undefined-metric warning.
    metrics = {
        "accuracy": accuracy_score(y_true, y_pred),
        "precision": precision_score(y_true, y_pred, zero_division=0),
        "recall": recall_score(y_true, y_pred, zero_division=0),
        "f1": f1_score(y_true, y_pred, zero_division=0),
    }

    # Persist everything so the Gradio File component can serve it.
    file_name = "final_clinical_hallucination_results.json"
    with open(file_name, "w") as f:
        json.dump({"metrics": metrics, "results": results}, f, indent=2)

    return f"✅ Audit Complete!\nAccuracy: {metrics['accuracy']:.2f}\nRecall: {metrics['recall']:.2f}", file_name
 
 
 
 
 
 
106
 
107
  # ============================================================================
108
+ # 4. GRADIO INTERFACE (To see and download file)
109
  # ============================================================================
110
# Assemble the dashboard: a single button launches the audit; results land in
# a status textbox plus a downloadable JSON file.
with gr.Blocks() as demo:
    gr.Markdown("# 🩺 Healthcare LLM Hallucination Audit System")
    gr.Markdown("Click the button below to start the 30-case randomized clinical evaluation.")

    with gr.Row():
        start_button = gr.Button("🚀 Start Clinical Audit", variant="primary")

    status_box = gr.Textbox(label="Status & Summary")
    download_box = gr.File(label="📥 Download Result JSON")

    # Wire the button to the audit engine: no inputs, two outputs.
    start_button.click(
        fn=run_clinical_audit,
        inputs=None,
        outputs=[status_box, download_box],
    )

demo.launch()