yashvshetty committed on
Commit f4b5b5b · 1 Parent(s): 71b7cef

Add LoRA evaluation on startup (RUN_LORA_EVAL flag)

Files changed (2)
  1. scripts/eval_lora.py +335 -0
  2. scripts/start.sh +11 -0
scripts/eval_lora.py ADDED
@@ -0,0 +1,335 @@
+"""Evaluate LoRA adapter by generating letters for all 5 patients and computing BLEU/ROUGE."""
+import os
+os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/tmp/torch_cache"
+os.environ["USER"] = os.environ.get("USER", "appuser")
+
+import gc
+import json
+import re
+import math
+from collections import Counter
+from pathlib import Path
+from datetime import datetime, timezone
+
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+from peft import PeftModel
+from jinja2 import Template
+
+print("=" * 60)
+print("CLARKE LoRA EVALUATION")
+print("=" * 60)
+
+print(f"GPU: {torch.cuda.get_device_name(0)}")
+print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
+
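+# Note: the torch.cuda calls above assume a CUDA device is present; on
+# CPU-only hardware this script will fail before the model is loaded.
+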
+MODEL_ID = "google/medgemma-27b-text-it"
+ADAPTER_ID = "yashvshetty/clarke-medgemma-27b-lora"
+
+# Load prompt template
+template_text = Path("backend/prompts/document_generation.j2").read_text()
+TEMPLATE = Template(template_text)
+
+# Load gold standard references
+GOLD_DIR = Path("evaluation/gold_standards")
+REFERENCES = {}
+for ref_file in sorted(GOLD_DIR.glob("ref_*.txt")):
+    key = ref_file.stem.replace("ref_", "")
+    REFERENCES[key] = ref_file.read_text(encoding="utf-8").strip()
+print(f"Loaded {len(REFERENCES)} gold standard references: {list(REFERENCES.keys())}")
+
+# Load FHIR bundles for patient context
+FHIR_DIR = Path("data/fhir_bundles")
+PATIENTS = {
+    "mrs_thompson": "pt-001",
+    "mr_okafor": "pt-002",
+    "ms_patel": "pt-003",
+    "mr_williams": "pt-004",
+    "mrs_khan": "pt-005",
+}
+
+# Load transcripts
+TRANSCRIPT_DIR = Path("data/demo")
+TRANSCRIPTS = {}
+for name, pt_id in PATIENTS.items():
+    # Try different naming patterns
+    for pattern in [f"{pt_id}_transcript.txt", f"{name}_transcript.txt"]:
+        t_path = TRANSCRIPT_DIR / pattern
+        if t_path.exists():
+            TRANSCRIPTS[name] = t_path.read_text(encoding="utf-8").strip()
+            break
+print(f"Loaded {len(TRANSCRIPTS)} transcripts")
+
+# Load FHIR contexts
+def load_fhir_context(pt_id):
+    bundle_path = FHIR_DIR / f"{pt_id}.json"
+    if not bundle_path.exists():
+        print(f"WARNING: No FHIR bundle for {pt_id}")
+        return {}
+    bundle = json.loads(bundle_path.read_text())
+    # Extract key info from FHIR bundle
+    context = {
+        "patient_id": pt_id,
+        "demographics": {},
+        "problem_list": [],
+        "medications": [],
+        "allergies": [],
+        "recent_labs": [],
+        "recent_imaging": [],
+    }
+    if "entry" in bundle:
+        for entry in bundle["entry"]:
+            resource = entry.get("resource", {})
+            rtype = resource.get("resourceType", "")
+            if rtype == "Patient":
+                name_parts = resource.get("name", [{}])[0]
+                given = " ".join(name_parts.get("given", []))
+                family = name_parts.get("family", "")
+                prefix = name_parts.get("prefix", [""])[0] if name_parts.get("prefix") else ""
+                context["demographics"]["name"] = f"{prefix} {given} {family}".strip()
+                context["demographics"]["dob"] = resource.get("birthDate", "")
+                nhs = ""
+                for ident in resource.get("identifier", []):
+                    if "nhs" in ident.get("system", "").lower():
+                        nhs = ident.get("value", "")
+                context["demographics"]["nhs_number"] = nhs
+                context["demographics"]["sex"] = resource.get("gender", "").capitalize()
+            elif rtype == "Condition":
+                code = resource.get("code", {}).get("text", "")
+                if not code:
+                    codings = resource.get("code", {}).get("coding", [])
+                    code = codings[0].get("display", "") if codings else ""
+                if code:
+                    context["problem_list"].append(code)
+            elif rtype == "MedicationStatement" or rtype == "MedicationRequest":
+                med_code = resource.get("medicationCodeableConcept", {})
+                med_name = med_code.get("text", "")
+                if not med_name:
+                    codings = med_code.get("coding", [])
+                    med_name = codings[0].get("display", "") if codings else ""
+                dosage = resource.get("dosage", [{}])[0] if resource.get("dosage") else {}
+                dose_text = dosage.get("text", "")
+                context["medications"].append({"name": med_name, "dose": dose_text})
+            elif rtype == "AllergyIntolerance":
+                substance = resource.get("code", {}).get("text", "")
+                if not substance:
+                    codings = resource.get("code", {}).get("coding", [])
+                    substance = codings[0].get("display", "") if codings else ""
+                reaction_list = resource.get("reaction", [])
+                reaction = ""
+                if reaction_list:
+                    manifestations = reaction_list[0].get("manifestation", [])
+                    if manifestations:
+                        reaction = manifestations[0].get("coding", [{}])[0].get("display", "")
+                context["allergies"].append({"substance": substance, "reaction": reaction})
+            elif rtype == "Observation":
+                code = resource.get("code", {})
+                obs_name = code.get("text", "")
+                if not obs_name:
+                    codings = code.get("coding", [])
+                    obs_name = codings[0].get("display", "") if codings else ""
+                value = ""
+                unit = ""
+                if "valueQuantity" in resource:
+                    value = str(resource["valueQuantity"].get("value", ""))
+                    unit = resource["valueQuantity"].get("unit", "")
+                elif "valueString" in resource:
+                    value = resource["valueString"]
+                date = resource.get("effectiveDateTime", "")
+                context["recent_labs"].append({"name": obs_name, "value": value, "unit": unit, "date": date})
+            elif rtype == "DiagnosticReport":
+                code = resource.get("code", {})
+                report_name = code.get("text", "")
+                if not report_name:
+                    codings = code.get("coding", [])
+                    report_name = codings[0].get("display", "") if codings else ""
+                conclusion = resource.get("conclusion", "")
+                date = resource.get("effectiveDateTime", resource.get("issued", ""))
+                context["recent_imaging"].append({"type": report_name, "date": date, "summary": conclusion})
+    return context
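+
+# Note: the extractor above flattens a FHIR Bundle into the plain dict shape
+# the letter prompt expects (demographics, problem_list, medications, ...);
+# resource types it does not recognise are simply skipped.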
+
+CONTEXTS = {}
+for name, pt_id in PATIENTS.items():
+    CONTEXTS[name] = load_fhir_context(pt_id)
+print(f"Loaded {len(CONTEXTS)} FHIR contexts")
+
+# Evaluation functions
+def tokenize_text(text):
+    return re.findall(r'\b\w+\b', text.lower())
+
+def ngrams(tokens, n):
+    return [tuple(tokens[i:i+n]) for i in range(len(tokens)-n+1)]
+
+def bleu_score(reference, hypothesis, max_n=4):
+    ref_tokens = tokenize_text(reference)
+    hyp_tokens = tokenize_text(hypothesis)
+    if not hyp_tokens:
+        return {"bleu1": 0.0, "bleu4": 0.0}
+    log_avg = 0.0
+    bleu1_val = 0.0
+    for n in range(1, max_n+1):
+        ref_ng = Counter(ngrams(ref_tokens, n))
+        hyp_ng = Counter(ngrams(hyp_tokens, n))
+        clipped = sum(min(hyp_ng[ng], ref_ng[ng]) for ng in hyp_ng)
+        total = sum(hyp_ng.values())
+        precision = clipped / total if total > 0 else 0.0
+        if n == 1:
+            bleu1_val = round(precision, 4)
+        log_avg += math.log(precision) if precision > 0 else float('-inf')
+    bp = min(1.0, math.exp(1 - len(ref_tokens)/len(hyp_tokens))) if len(hyp_tokens) > 0 else 0.0
+    cumulative = bp * math.exp(log_avg / max_n) if log_avg > float('-inf') else 0.0
+    return {"bleu1": bleu1_val, "bleu4": round(cumulative, 4)}
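+
+# Note on the metric: this is a single-reference BLEU with uniform weights over
+# the 1- to 4-gram precisions plus the standard brevity penalty; if any n-gram
+# order has zero matches, the cumulative BLEU-4 collapses to 0 by construction.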
+
+def rouge_l_f1(reference, hypothesis):
+    ref_tokens = tokenize_text(reference)
+    hyp_tokens = tokenize_text(hypothesis)
+    if not ref_tokens or not hyp_tokens:
+        return 0.0
+    m, n = len(ref_tokens), len(hyp_tokens)
+    dp = [[0]*(n+1) for _ in range(m+1)]
+    for i in range(1, m+1):
+        for j in range(1, n+1):
+            if ref_tokens[i-1] == hyp_tokens[j-1]:
+                dp[i][j] = dp[i-1][j-1] + 1
+            else:
+                dp[i][j] = max(dp[i-1][j], dp[i][j-1])
+    lcs = dp[m][n]
+    precision = lcs / n
+    recall = lcs / m
+    if precision + recall == 0:
+        return 0.0
+    return round(2 * precision * recall / (precision + recall), 4)
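+
+# ROUGE-L here is the F1 score over the longest common subsequence, computed
+# with a standard O(m*n) dynamic-programming table; fine at letter length,
+# though quadratic in document size.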
+
+# Load model
+print("\nLoading tokenizer...")
+tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+
+print("Loading base model in 4-bit...")
+bnb_config = BitsAndBytesConfig(
+    load_in_4bit=True,
+    bnb_4bit_quant_type="nf4",
+    bnb_4bit_compute_dtype=torch.bfloat16,
+    bnb_4bit_use_double_quant=True,
+)
+model = AutoModelForCausalLM.from_pretrained(
+    MODEL_ID,
+    quantization_config=bnb_config,
+    device_map="auto",
+    torch_dtype=torch.bfloat16,
+)
+
+print("Loading LoRA adapter...")
+model = PeftModel.from_pretrained(model, ADAPTER_ID)
+model.eval()
+print(f"Model + adapter loaded. GPU memory: {torch.cuda.memory_allocated()/1e9:.1f} GB")
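+
+# Note: NF4 4-bit quantization with double quantization keeps the 27B base
+# model small enough to fit on a single GPU; the memory_allocated() figure
+# printed above shows the actual runtime footprint.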
+
+# Generate letters
+generated_letters = {}
+for name in PATIENTS:
+    if name not in TRANSCRIPTS:
+        print(f"SKIP {name}: no transcript")
+        continue
+    if name not in CONTEXTS:
+        print(f"SKIP {name}: no context")
+        continue
+
+    print(f"\nGenerating letter for: {name}")
+    context = CONTEXTS[name]
+    context_json = json.dumps(context, ensure_ascii=False, indent=2)
+    demo = context.get("demographics", {})
+
+    prompt = TEMPLATE.render(
+        letter_date=datetime.now(tz=timezone.utc).strftime("%d %b %Y"),
+        clinician_name="Dr Sarah Chen",
+        clinician_title="Consultant, General Practice",
+        gp_name="Dr Andrew Wilson",
+        gp_address="Riverside Medical Practice",
+        patient_name=demo.get("name", ""),
+        patient_dob=demo.get("dob", ""),
+        patient_nhs=demo.get("nhs_number", ""),
+        transcript=TRANSCRIPTS[name],
+        context_json=context_json,
+    )
+
+    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+    with torch.no_grad():
+        output_ids = model.generate(
+            **inputs,
+            max_new_tokens=2048,
+            do_sample=False,
+            repetition_penalty=1.1,
+        )
+
+    full_output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
+    if full_output.startswith(prompt):
+        letter = full_output[len(prompt):].strip()
+    else:
+        letter = full_output.strip()
+
+    generated_letters[name] = letter
+    word_count = len(tokenize_text(letter))
+    print(f" Generated {word_count} words")
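+
+# Note: decoding above is greedy (do_sample=False), so repeated runs yield the
+# same letters. generate() returns prompt + continuation, hence the startswith()
+# strip; the else branch keeps the full text when detokenization changes the
+# prompt's whitespace and the prefix no longer matches exactly.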
+
+# Evaluate
+BASELINE = {
+    "mrs_thompson": {"bleu1": 0.7970, "bleu4": 0.4882, "rouge_l": 0.6958},
+    "mr_okafor": {"bleu1": 0.7971, "bleu4": 0.6220, "rouge_l": 0.7247},
+    "ms_patel": {"bleu1": 0.8117, "bleu4": 0.5608, "rouge_l": 0.7119},
+    "mr_williams": {"bleu1": 0.8754, "bleu4": 0.7386, "rouge_l": 0.8139},
+    "mrs_khan": {"bleu1": 0.8244, "bleu4": 0.6425, "rouge_l": 0.7513},
+}
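+# The BASELINE values above are hardcoded scores for the base model (no
+# adapter) against the same gold-standard references; they feed the Base and
+# Delta columns printed below.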
+
+print("\n" + "="*80)
+print("EVALUATION RESULTS: LoRA Adapter vs Base Model (no adapter)")
+print("="*80)
+print(f"\n{'Patient':<20} {'Metric':<10} {'Base':<10} {'LoRA':<10} {'Delta':<10}")
+print("-"*60)
+
+lora_totals = {"bleu1": 0, "bleu4": 0, "rouge_l": 0}
+base_totals = {"bleu1": 0, "bleu4": 0, "rouge_l": 0}
+count = 0
+
+for name in PATIENTS:
+    if name not in generated_letters or name not in REFERENCES:
+        continue
+    ref = REFERENCES[name]
+    hyp = generated_letters[name]
+    bl = bleu_score(ref, hyp)
+    rl = rouge_l_f1(ref, hyp)
+    scores = {"bleu1": bl["bleu1"], "bleu4": bl["bleu4"], "rouge_l": rl}
+    base = BASELINE.get(name, {"bleu1": 0, "bleu4": 0, "rouge_l": 0})
+
+    for metric in ["bleu1", "bleu4", "rouge_l"]:
+        delta = scores[metric] - base[metric]
+        sign = "+" if delta >= 0 else ""
+        label = {"bleu1": "BLEU-1", "bleu4": "BLEU-4", "rouge_l": "ROUGE-L"}[metric]
+        print(f"{name:<20} {label:<10} {base[metric]:<10.4f} {scores[metric]:<10.4f} {sign}{delta:.4f}")
+        lora_totals[metric] += scores[metric]
+        base_totals[metric] += base[metric]
+    count += 1
+    print()
+
+if count > 0:
+    print("-"*60)
+    print(f"{'AVERAGE':<20} {'Metric':<10} {'Base':<10} {'LoRA':<10} {'Delta':<10}")
+    print("-"*60)
+    for metric in ["bleu1", "bleu4", "rouge_l"]:
+        avg_base = base_totals[metric] / count
+        avg_lora = lora_totals[metric] / count
+        delta = avg_lora - avg_base
+        sign = "+" if delta >= 0 else ""
+        label = {"bleu1": "BLEU-1", "bleu4": "BLEU-4", "rouge_l": "ROUGE-L"}[metric]
+        print(f"{'AVERAGE':<20} {label:<10} {avg_base:<10.4f} {avg_lora:<10.4f} {sign}{delta:.4f}")
+
+# Save generated letters
+for name, letter in generated_letters.items():
+    Path(f"/tmp/lora_{name}.txt").write_text(letter)
+    print(f"Saved: /tmp/lora_{name}.txt")
+
+print("\nEVALUATION COMPLETE.")
+
+# Cleanup
+del model
+gc.collect()
+torch.cuda.empty_cache()
+print("Memory freed.")
scripts/start.sh CHANGED
@@ -4,6 +4,17 @@ export USER="${USER:-appuser}"
 export TORCHINDUCTOR_CACHE_DIR="/tmp/torch_cache"
 
 echo "Starting Clarke..."
+
+if [ "${RUN_LORA_EVAL}" = "true" ]; then
+    echo "============================================"
+    echo "LoRA evaluation requested. Running..."
+    echo "============================================"
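+    # Deliberately non-fatal: '||' lets the container keep booting if eval fails.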
+    python scripts/eval_lora.py || echo "WARNING: Evaluation failed but app will start normally"
+    echo "============================================"
+    echo "Evaluation phase complete. Starting app..."
+    echo "============================================"
+fi
+
 echo "USE_MOCK_FHIR=${USE_MOCK_FHIR:-false}"
 echo "MEDASR_MODEL_ID=${MEDASR_MODEL_ID:-not set}"
20