halilolcay commited on
Commit
a7d786a
·
verified ·
1 Parent(s): b801753

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +79 -34
app.py CHANGED
@@ -4,7 +4,7 @@ import torch
4
  import random
5
  import os
6
  import gradio as gr
7
- from transformers import pipeline
8
  from datasets import load_dataset
9
  from sentence_transformers import SentenceTransformer, util
10
  from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
@@ -14,13 +14,31 @@ warnings.filterwarnings("ignore")
14
  # ============================================================================
15
  # 1. INITIALIZATION & MODELS
16
  # ============================================================================
17
- device = "mps" if torch.backends.mps.is_available() else "cpu"
18
 
19
- print("[INFO] Loading Expert Models...")
20
- nli_model = pipeline("text-classification", model="pritamdeka/PubMedBERT-MNLI-MedNLI", device=device,truncation=True, max_length=512)
21
  sim_model = SentenceTransformer("all-MiniLM-L6-v2", device=device)
22
  clf_model = pipeline("text-classification", model="cross-encoder/ms-marco-MiniLM-L-6-v2", device=device, truncation=True, max_length=512)
23
- correction_llm = pipeline("text2text-generation", model="google/flan-t5-large", device=device, max_length=512)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  # ============================================================================
26
  # 2. CORE FUNCTIONS
@@ -37,20 +55,41 @@ def detect_similarity(evidence, answer):
37
  def detect_uncertainty(evidence, answer):
38
  return clf_model(f"{evidence} [SEP] {answer}")[0]["score"]
39
 
40
- def build_correction_prompt(query, wrong, truth):
41
- return f"You are a doctor. Explain error in: {wrong}. Correct it using: {truth} for Question: {query}"
42
-
43
- def generate_correction(prompt):
44
- return correction_llm(prompt)[0]["generated_text"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
  # ============================================================================
47
- # 3. THE AUDIT ENGINE (Main Logic for Gradio)
48
  # ============================================================================
49
  def run_clinical_audit():
50
- # Load Dataset (Streaming)
51
  dataset = load_dataset("UTAustin-AIHealth/MedHallu", "pqa_labeled", split="train", streaming=True)
52
- data_pool = list(dataset.take(100))
53
- samples = random.sample(data_pool, 30)
54
 
55
  results = []
56
  y_true, y_pred = [], []
@@ -59,45 +98,53 @@ def run_clinical_audit():
59
  evidence = " ".join(sample["Knowledge"])
60
  query = sample["Question"]
61
  factual = sample["Ground Truth"]
 
62
 
63
- # Balanced flip
64
  label = 1 if i % 2 == 0 else 0
65
- llm_answer = sample["Hallucinated Answer"] if label == 1 else factual
66
 
67
- # Detection logic
68
  nli_label, _ = detect_nli(evidence, llm_answer)
69
  sim_score = detect_similarity(evidence, llm_answer)
70
  unc_score = detect_uncertainty(evidence, llm_answer)
71
 
72
  detected = 0
73
- reason = "Consistent"
74
  if nli_label == "contradiction" or sim_score < 0.30 or unc_score < 0.25:
75
  detected = 1
76
- reason = "Hallucination Detected"
77
 
78
  y_true.append(label)
79
  y_pred.append(detected)
80
 
81
  correction = None
82
  if detected:
83
- prompt = build_correction_prompt(query, llm_answer, factual)
84
- correction = {"corrected": generate_correction(prompt)}
 
 
 
85
 
86
  results.append({
87
  "case_id": i + 1,
88
  "query": query,
89
- "detection": {"label": label, "prediction": detected, "reason": reason},
 
 
 
 
 
 
 
90
  "correction": correction
91
  })
92
 
93
- # Metrics
94
  metrics = {
95
  "accuracy": accuracy_score(y_true, y_pred),
96
  "recall": recall_score(y_true, y_pred),
97
- "f1": f1_score(y_true, y_pred)
 
98
  }
99
 
100
- # Save File
101
  file_name = "final_clinical_hallucination_results.json"
102
  with open(file_name, "w") as f:
103
  json.dump({"metrics": metrics, "results": results}, f, indent=2)
@@ -105,18 +152,16 @@ def run_clinical_audit():
105
  return f"✅ Audit Complete!\nAccuracy: {metrics['accuracy']:.2f}\nRecall: {metrics['recall']:.2f}", file_name
106
 
107
  # ============================================================================
108
- # 4. GRADIO INTERFACE (To see and download file)
109
  # ============================================================================
110
  with gr.Blocks() as demo:
111
- gr.Markdown("# 🩺 Healthcare LLM Hallucination Audit System")
112
- gr.Markdown("Click the button below to start the 30-case randomized clinical evaluation.")
113
-
114
- with gr.Row():
115
- run_btn = gr.Button("🚀 Start Clinical Audit", variant="primary")
116
 
117
- output_text = gr.Textbox(label="Status & Summary")
118
- output_file = gr.File(label="📥 Download Result JSON")
 
119
 
120
- run_btn.click(fn=run_clinical_audit, inputs=None, outputs=[output_text, output_file])
121
 
122
  demo.launch()
 
4
  import random
5
  import os
6
  import gradio as gr
7
+ from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
8
  from datasets import load_dataset
9
  from sentence_transformers import SentenceTransformer, util
10
  from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
 
14
# ============================================================================
# 1. INITIALIZATION & MODELS
# ============================================================================
# Prefer CUDA, then Apple Silicon MPS, then CPU.
device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")

print("[INFO] Loading Expert Models (NLI, Similarity, Uncertainty)...")
# NLI model: detects contradiction between evidence and answer.
nli_model = pipeline("text-classification", model="pritamdeka/PubMedBERT-MNLI-MedNLI", device=device, truncation=True, max_length=512)
# Sentence embedding model: semantic similarity between evidence and answer.
sim_model = SentenceTransformer("all-MiniLM-L6-v2", device=device)
# Cross-encoder relevance score, used here as an "uncertainty" signal.
clf_model = pipeline("text-classification", model="cross-encoder/ms-marco-MiniLM-L-6-v2", device=device, truncation=True, max_length=512)

# Load Nous-Hermes-2-Mistral-7B-DPO with 4-bit quantization.
print("[INFO] Loading Nous-Hermes-2-Mistral-7B-DPO (4-bit optimized)...")
model_id = "NousResearch/Nous-Hermes-2-Mistral-7B-DPO"

# Critical settings for a free HF Space (16 GB VRAM): NF4 4-bit weights with
# double quantization and fp16 compute keep the 7B model within memory limits.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
# device_map="auto" lets accelerate place layers across devices; do not call
# .to(device) on this model — move INPUTS to correction_model.device instead.
correction_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quant_config,
    device_map="auto",
)
42
 
43
  # ============================================================================
44
  # 2. CORE FUNCTIONS
 
55
  def detect_uncertainty(evidence, answer):
56
  return clf_model(f"{evidence} [SEP] {answer}")[0]["score"]
57
 
58
def generate_correction(query, wrong, truth):
    """Ask the correction LLM to explain and fix a hallucinated answer.

    Builds a Nous-Hermes-2 ChatML prompt, generates with low temperature,
    and returns only the assistant's reply.

    Args:
        query: The clinical question being audited.
        wrong: The (suspected hallucinated) answer to correct.
        truth: Verified evidence / ground-truth text to correct against.

    Returns:
        The model's correction as a plain string.
    """
    # Nous-Hermes-2 ChatML format.
    prompt = f"""<|im_start|>system
You are a board-certified medical doctor. Analyze the clinical error and provide a fix based ONLY on verified evidence.<|im_end|>
<|im_start|>user
QUESTION: {query}
INCORRECT ANSWER: {wrong}
VERIFIED EVIDENCE: {truth}

TASK:
1. Explain why the answer is incorrect.
2. Provide the clinically accurate correction.<|im_end|>
<|im_start|>assistant
"""
    # BUG FIX: with device_map="auto" the embedding layer may not live on the
    # global `device`; send inputs to the model's own device to avoid a
    # cross-device error.
    inputs = tokenizer(prompt, return_tensors="pt").to(correction_model.device)

    with torch.no_grad():
        outputs = correction_model.generate(
            **inputs,
            max_new_tokens=300,
            # BUG FIX: temperature is ignored (with a warning) unless sampling
            # is enabled; the original intent was low-temperature sampling.
            do_sample=True,
            temperature=0.1,  # low temperature for clinical accuracy
            eos_token_id=tokenizer.eos_token_id,
        )

    # BUG FIX: skip_special_tokens strips the ChatML markers, so splitting the
    # full decode on the literal word "assistant" is fragile (and breaks if the
    # reply itself contains "assistant"). Decode only the new tokens instead.
    new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
85
 
86
# ============================================================================
# 3. THE AUDIT ENGINE (N=20)
# ============================================================================
def run_clinical_audit():
    """Run the randomized 20-case clinical hallucination audit.

    Streams the MedHallu dataset, alternates factual/hallucinated answers,
    runs the three detectors, generates LLM corrections for detected cases,
    writes a JSON report, and returns a (summary string, file path) pair for
    the Gradio outputs.
    """
    dataset = load_dataset("UTAustin-AIHealth/MedHallu", "pqa_labeled", split="train", streaming=True)
    data_pool = list(dataset.take(150))
    samples = random.sample(data_pool, 20)

    results = []
    y_true, y_pred = [], []

    for i, sample in enumerate(samples):
        # NOTE(review): " ".join assumes "Knowledge" is a list of strings; if
        # it is a plain string this joins individual characters — TODO confirm
        # against the dataset schema.
        evidence = " ".join(sample["Knowledge"])
        query = sample["Question"]
        factual = sample["Ground Truth"]
        hallucinated = sample["Hallucinated Answer"]

        # Balanced flip: even indices get the hallucinated answer (label 1).
        label = 1 if i % 2 == 0 else 0
        llm_answer = hallucinated if label == 1 else factual

        # Three independent detection signals.
        nli_label, _ = detect_nli(evidence, llm_answer)
        sim_score = detect_similarity(evidence, llm_answer)
        unc_score = detect_uncertainty(evidence, llm_answer)

        detected = 0
        reason = "Factual"
        if nli_label == "contradiction" or sim_score < 0.30 or unc_score < 0.25:
            detected = 1
            reason = "Clinical Hallucination Detected"

        y_true.append(label)
        y_pred.append(detected)

        correction = None
        if detected:
            corrected_text = generate_correction(query, llm_answer, factual)
            correction = {
                "physician_prompt": "Nous-Hermes-2 ChatML Structure",
                "llm_corrected_answer": corrected_text
            }

        results.append({
            "case_id": i + 1,
            "query": query,
            "llm_original_answer": llm_answer,
            "ground_truth_answer": factual,
            "detection": {
                "label": label,
                "prediction": detected,
                "reason": reason,
                "signals": {"nli": nli_label, "similarity": round(sim_score, 3), "uncertainty": round(unc_score, 3)}
            },
            "correction": correction
        })

    # zero_division=0 avoids the sklearn "ill-defined" warning when one class
    # is never predicted; precision was imported but previously unreported.
    metrics = {
        "accuracy": accuracy_score(y_true, y_pred),
        "precision": precision_score(y_true, y_pred, zero_division=0),
        "recall": recall_score(y_true, y_pred, zero_division=0),
        "f1": f1_score(y_true, y_pred, zero_division=0),
        "confusion_matrix": confusion_matrix(y_true, y_pred).tolist()
    }

    file_name = "final_clinical_hallucination_results.json"
    with open(file_name, "w") as f:
        json.dump({"metrics": metrics, "results": results}, f, indent=2)

    return f"✅ Audit Complete!\nAccuracy: {metrics['accuracy']:.2f}\nRecall: {metrics['recall']:.2f}", file_name
153
 
154
# ============================================================================
# 4. GRADIO INTERFACE
# ============================================================================
with gr.Blocks() as demo:
    # Page header and description (creation order defines the layout).
    gr.Markdown("# 🩺 Healthcare LLM Auditor (Nous-Hermes-2 Engine)")
    gr.Markdown("Bu sistem 20 vakayı 4-bit optimize edilmiş Nous-Hermes-2 ile denetler.")

    # Trigger button, status text box, and downloadable JSON report.
    audit_button = gr.Button("🚀 Start Clinical Audit", variant="primary")
    status_box = gr.Textbox(label="Status Summary")
    report_file = gr.File(label="📥 Download JSON Results")

    # One click runs the full audit; it returns (summary, file path).
    audit_button.click(fn=run_clinical_audit, outputs=[status_box, report_file])

demo.launch()