# clarke/scripts/eval_lora.py
# Author: yashvshetty — "Add LoRA evaluation on startup (RUN_LORA_EVAL flag)" (commit f4b5b5b)
"""Evaluate LoRA adapter by generating letters for all 5 patients and computing BLEU/ROUGE."""
import os
os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/tmp/torch_cache"
os.environ["USER"] = os.environ.get("USER", "appuser")
import gc
import json
import re
import math
from collections import Counter
from pathlib import Path
from datetime import datetime, timezone
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel
from jinja2 import Template
print("=" * 60)
print("CLARKE LoRA EVALUATION")
print("=" * 60)
print(f"GPU: {torch.cuda.get_device_name(0)}")
print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
MODEL_ID = "google/medgemma-27b-text-it"
ADAPTER_ID = "yashvshetty/clarke-medgemma-27b-lora"
# Load prompt template
template_text = Path("backend/prompts/document_generation.j2").read_text()
TEMPLATE = Template(template_text)
# Load gold standard references
GOLD_DIR = Path("evaluation/gold_standards")
REFERENCES = {}
for ref_file in sorted(GOLD_DIR.glob("ref_*.txt")):
key = ref_file.stem.replace("ref_", "")
REFERENCES[key] = ref_file.read_text(encoding="utf-8").strip()
print(f"Loaded {len(REFERENCES)} gold standard references: {list(REFERENCES.keys())}")
# Load FHIR bundles for patient context
FHIR_DIR = Path("data/fhir_bundles")
PATIENTS = {
"mrs_thompson": "pt-001",
"mr_okafor": "pt-002",
"ms_patel": "pt-003",
"mr_williams": "pt-004",
"mrs_khan": "pt-005",
}
# Load transcripts
TRANSCRIPT_DIR = Path("data/demo")
TRANSCRIPTS = {}
for name, pt_id in PATIENTS.items():
# Try different naming patterns
for pattern in [f"{pt_id}_transcript.txt", f"{name}_transcript.txt"]:
t_path = TRANSCRIPT_DIR / pattern
if t_path.exists():
TRANSCRIPTS[name] = t_path.read_text(encoding="utf-8").strip()
break
print(f"Loaded {len(TRANSCRIPTS)} transcripts")
# Load FHIR contexts
def load_fhir_context(pt_id, fhir_dir=None):
    """Build a compact prompt-context dict from a patient's FHIR bundle.

    Reads ``<fhir_dir>/<pt_id>.json`` and extracts demographics, problem list,
    medications, allergies, observations ("recent_labs") and diagnostic
    reports ("recent_imaging").

    Args:
        pt_id: Patient identifier; also the bundle file's stem (e.g. "pt-001").
        fhir_dir: Optional directory containing the bundles. Defaults to the
            module-level FHIR_DIR (parameterized mainly for testing).

    Returns:
        dict with keys patient_id, demographics, problem_list, medications,
        allergies, recent_labs, recent_imaging — or {} if no bundle exists.
    """
    bundle_dir = Path(fhir_dir) if fhir_dir is not None else FHIR_DIR
    bundle_path = bundle_dir / f"{pt_id}.json"
    if not bundle_path.exists():
        print(f"WARNING: No FHIR bundle for {pt_id}")
        return {}
    bundle = json.loads(bundle_path.read_text())
    context = {
        "patient_id": pt_id,
        "demographics": {},
        "problem_list": [],
        "medications": [],
        "allergies": [],
        "recent_labs": [],
        "recent_imaging": [],
    }

    def _concept_text(concept):
        # Prefer the human-readable CodeableConcept.text; fall back to the
        # first coding's display (empty string if neither is present).
        text = concept.get("text", "")
        if not text:
            codings = concept.get("coding", [])
            text = codings[0].get("display", "") if codings else ""
        return text

    for entry in bundle.get("entry", []):
        resource = entry.get("resource", {})
        rtype = resource.get("resourceType", "")
        if rtype == "Patient":
            name_parts = resource.get("name", [{}])[0]
            given = " ".join(name_parts.get("given", []))
            family = name_parts.get("family", "")
            prefix = name_parts.get("prefix", [""])[0] if name_parts.get("prefix") else ""
            context["demographics"]["name"] = f"{prefix} {given} {family}".strip()
            context["demographics"]["dob"] = resource.get("birthDate", "")
            nhs = ""
            for ident in resource.get("identifier", []):
                # Any identifier whose system mentions "nhs" (case-insensitive)
                # is treated as the NHS number; the last match wins.
                if "nhs" in ident.get("system", "").lower():
                    nhs = ident.get("value", "")
            context["demographics"]["nhs_number"] = nhs
            context["demographics"]["sex"] = resource.get("gender", "").capitalize()
        elif rtype == "Condition":
            code = _concept_text(resource.get("code", {}))
            if code:
                context["problem_list"].append(code)
        elif rtype in ("MedicationStatement", "MedicationRequest"):
            med_name = _concept_text(resource.get("medicationCodeableConcept", {}))
            # FHIR R4: MedicationStatement carries "dosage" while
            # MedicationRequest carries "dosageInstruction" — accept either
            # (the original only read "dosage" and lost MedicationRequest doses).
            dosages = resource.get("dosage") or resource.get("dosageInstruction") or [{}]
            dose_text = dosages[0].get("text", "")
            context["medications"].append({"name": med_name, "dose": dose_text})
        elif rtype == "AllergyIntolerance":
            substance = _concept_text(resource.get("code", {}))
            reaction_list = resource.get("reaction", [])
            reaction = ""
            if reaction_list:
                # Use the first manifestation's first coding display, if any.
                manifestations = reaction_list[0].get("manifestation", [])
                if manifestations:
                    reaction = manifestations[0].get("coding", [{}])[0].get("display", "")
            context["allergies"].append({"substance": substance, "reaction": reaction})
        elif rtype == "Observation":
            obs_name = _concept_text(resource.get("code", {}))
            value = ""
            unit = ""
            if "valueQuantity" in resource:
                value = str(resource["valueQuantity"].get("value", ""))
                unit = resource["valueQuantity"].get("unit", "")
            elif "valueString" in resource:
                value = resource["valueString"]
            date = resource.get("effectiveDateTime", "")
            context["recent_labs"].append({"name": obs_name, "value": value, "unit": unit, "date": date})
        elif rtype == "DiagnosticReport":
            report_name = _concept_text(resource.get("code", {}))
            conclusion = resource.get("conclusion", "")
            date = resource.get("effectiveDateTime", resource.get("issued", ""))
            context["recent_imaging"].append({"type": report_name, "date": date, "summary": conclusion})
    return context
# Pre-build the FHIR-derived prompt context for every demo patient.
CONTEXTS = {name: load_fhir_context(pt_id) for name, pt_id in PATIENTS.items()}
print(f"Loaded {len(CONTEXTS)} FHIR contexts")
# Evaluation functions
def tokenize_text(text):
    """Lowercase *text* and return its word tokens (runs of \\w characters)."""
    return [match.group(0) for match in re.finditer(r'\b\w+\b', text.lower())]
def ngrams(tokens, n):
    """Return every contiguous n-gram of *tokens* as a tuple, in order."""
    window_count = len(tokens) - n + 1
    return [tuple(tokens[start:start + n]) for start in range(window_count)]
def bleu_score(reference, hypothesis, max_n=4):
    """Single-pair BLEU with clipped n-gram precisions.

    Returns {"bleu1": unigram precision, "bleu4": cumulative BLEU up to
    *max_n* with brevity penalty}, both rounded to 4 decimal places.
    """
    pattern = r'\b\w+\b'
    ref_tokens = re.findall(pattern, reference.lower())
    hyp_tokens = re.findall(pattern, hypothesis.lower())
    if not hyp_tokens:
        return {"bleu1": 0.0, "bleu4": 0.0}
    sum_log_precision = 0.0
    unigram_precision = 0.0
    for order in range(1, max_n + 1):
        ref_counts = Counter(
            tuple(ref_tokens[i:i + order]) for i in range(len(ref_tokens) - order + 1)
        )
        hyp_counts = Counter(
            tuple(hyp_tokens[i:i + order]) for i in range(len(hyp_tokens) - order + 1)
        )
        # Clip each hypothesis n-gram count by its count in the reference.
        matched = sum(min(cnt, ref_counts[gram]) for gram, cnt in hyp_counts.items())
        candidates = sum(hyp_counts.values())
        precision = matched / candidates if candidates > 0 else 0.0
        if order == 1:
            unigram_precision = round(precision, 4)
        sum_log_precision += math.log(precision) if precision > 0 else float('-inf')
    # Brevity penalty: < 1 only when the hypothesis is shorter than the reference.
    brevity = min(1.0, math.exp(1 - len(ref_tokens) / len(hyp_tokens)))
    if sum_log_precision > float('-inf'):
        cumulative = brevity * math.exp(sum_log_precision / max_n)
    else:
        cumulative = 0.0
    return {"bleu1": unigram_precision, "bleu4": round(cumulative, 4)}
def rouge_l_f1(reference, hypothesis):
    """ROUGE-L: F1 over the longest common token subsequence, rounded to 4 dp."""
    ref_tokens = re.findall(r'\b\w+\b', reference.lower())
    hyp_tokens = re.findall(r'\b\w+\b', hypothesis.lower())
    if not ref_tokens or not hyp_tokens:
        return 0.0
    # LCS length via dynamic programming, keeping only two table rows.
    prev_row = [0] * (len(hyp_tokens) + 1)
    for ref_word in ref_tokens:
        curr_row = [0]
        for col, hyp_word in enumerate(hyp_tokens, start=1):
            if ref_word == hyp_word:
                curr_row.append(prev_row[col - 1] + 1)
            else:
                curr_row.append(max(prev_row[col], curr_row[col - 1]))
        prev_row = curr_row
    lcs_len = prev_row[-1]
    precision = lcs_len / len(hyp_tokens)
    recall = lcs_len / len(ref_tokens)
    if precision + recall == 0:
        return 0.0
    return round(2 * precision * recall / (precision + recall), 4)
# Load tokenizer, 4-bit quantized base model, then attach the LoRA adapter.
print("\nLoading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
print("Loading base model in 4-bit...")
# NF4 4-bit quantization with double quantization; matmuls computed in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    device_map="auto",  # let accelerate place layers on available devices
    torch_dtype=torch.bfloat16,
)
print("Loading LoRA adapter...")
# Wrap the base model with the fine-tuned adapter weights.
model = PeftModel.from_pretrained(model, ADAPTER_ID)
model.eval()  # inference mode: disables dropout etc.
print(f"Model + adapter loaded. GPU memory: {torch.cuda.memory_allocated()/1e9:.1f} GB")
# Generate a letter for every patient that has both a transcript and a context.
generated_letters = {}
for name in PATIENTS:
    if name not in TRANSCRIPTS:
        print(f"SKIP {name}: no transcript")
        continue
    if name not in CONTEXTS:
        print(f"SKIP {name}: no context")
        continue
    print(f"\nGenerating letter for: {name}")
    context = CONTEXTS[name]
    context_json = json.dumps(context, ensure_ascii=False, indent=2)
    demo = context.get("demographics", {})
    # Render the full generation prompt from the Jinja2 template.
    prompt = TEMPLATE.render(
        letter_date=datetime.now(tz=timezone.utc).strftime("%d %b %Y"),
        clinician_name="Dr Sarah Chen",
        clinician_title="Consultant, General Practice",
        gp_name="Dr Andrew Wilson",
        gp_address="Riverside Medical Practice",
        patient_name=demo.get("name", ""),
        patient_dob=demo.get("dob", ""),
        patient_nhs=demo.get("nhs_number", ""),
        transcript=TRANSCRIPTS[name],
        context_json=context_json,
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=2048,
            do_sample=False,  # greedy decoding for reproducible evaluation
            repetition_penalty=1.1,
        )
    # Decode only the newly generated tokens. Slicing by the input length is
    # more reliable than the previous startswith(prompt) string check:
    # decoding the prompt tokens does not always round-trip to the exact
    # prompt string, which left the whole prompt inside the "letter".
    prompt_len = inputs["input_ids"].shape[-1]
    letter = tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True).strip()
    generated_letters[name] = letter
    word_count = len(tokenize_text(letter))
    print(f" Generated {word_count} words")
# Evaluate each generated letter against its gold reference and print a
# per-patient and averaged comparison table.
# Baseline metrics for the BASE model without the adapter.
# NOTE(review): hard-coded — presumably captured from an earlier base-model
# run against the same gold standards; confirm they are still current.
BASELINE = {
    "mrs_thompson": {"bleu1": 0.7970, "bleu4": 0.4882, "rouge_l": 0.6958},
    "mr_okafor": {"bleu1": 0.7971, "bleu4": 0.6220, "rouge_l": 0.7247},
    "ms_patel": {"bleu1": 0.8117, "bleu4": 0.5608, "rouge_l": 0.7119},
    "mr_williams": {"bleu1": 0.8754, "bleu4": 0.7386, "rouge_l": 0.8139},
    "mrs_khan": {"bleu1": 0.8244, "bleu4": 0.6425, "rouge_l": 0.7513},
}
print("\n" + "="*80)
print("EVALUATION RESULTS: LoRA Adapter vs Base Model (no adapter)")
print("="*80)
print(f"\n{'Patient':<20} {'Metric':<10} {'Base':<10} {'LoRA':<10} {'Delta':<10}")
print("-"*60)
# Running per-metric totals for the averages printed after the table.
lora_totals = {"bleu1": 0, "bleu4": 0, "rouge_l": 0}
base_totals = {"bleu1": 0, "bleu4": 0, "rouge_l": 0}
count = 0
for name in PATIENTS:
    # Score only patients with both a generated letter and a gold reference.
    if name not in generated_letters or name not in REFERENCES:
        continue
    ref = REFERENCES[name]
    hyp = generated_letters[name]
    bl = bleu_score(ref, hyp)
    rl = rouge_l_f1(ref, hyp)
    scores = {"bleu1": bl["bleu1"], "bleu4": bl["bleu4"], "rouge_l": rl}
    base = BASELINE.get(name, {"bleu1": 0, "bleu4": 0, "rouge_l": 0})
    for metric in ["bleu1", "bleu4", "rouge_l"]:
        delta = scores[metric] - base[metric]
        sign = "+" if delta >= 0 else ""
        label = {"bleu1": "BLEU-1", "bleu4": "BLEU-4", "rouge_l": "ROUGE-L"}[metric]
        print(f"{name:<20} {label:<10} {base[metric]:<10.4f} {scores[metric]:<10.4f} {sign}{delta:.4f}")
        lora_totals[metric] += scores[metric]
        base_totals[metric] += base[metric]
    count += 1
    print()  # blank line between patients
if count > 0:
    print("-"*60)
    print(f"{'AVERAGE':<20} {'Metric':<10} {'Base':<10} {'LoRA':<10} {'Delta':<10}")
    print("-"*60)
    for metric in ["bleu1", "bleu4", "rouge_l"]:
        avg_base = base_totals[metric] / count
        avg_lora = lora_totals[metric] / count
        delta = avg_lora - avg_base
        sign = "+" if delta >= 0 else ""
        label = {"bleu1": "BLEU-1", "bleu4": "BLEU-4", "rouge_l": "ROUGE-L"}[metric]
        print(f"{'AVERAGE':<20} {label:<10} {avg_base:<10.4f} {avg_lora:<10.4f} {sign}{delta:.4f}")
# Persist generated letters to /tmp for manual inspection.
for name, letter in generated_letters.items():
    # Write UTF-8 explicitly: letters may contain non-ASCII clinical text, and
    # the rest of this script also reads files with encoding="utf-8" — relying
    # on the platform default encoding here was inconsistent and can fail.
    Path(f"/tmp/lora_{name}.txt").write_text(letter, encoding="utf-8")
    print(f"Saved: /tmp/lora_{name}.txt")
print("\nEVALUATION COMPLETE.")
# Cleanup: drop the model reference and release cached GPU memory.
del model
gc.collect()
torch.cuda.empty_cache()
print("Memory freed.")