"""Evaluate LoRA adapter by generating letters for all 5 patients and computing BLEU/ROUGE.""" import os os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/tmp/torch_cache" os.environ["USER"] = os.environ.get("USER", "appuser") import gc import json import re import math from collections import Counter from pathlib import Path from datetime import datetime, timezone import torch from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig from peft import PeftModel from jinja2 import Template print("=" * 60) print("CLARKE LoRA EVALUATION") print("=" * 60) print(f"GPU: {torch.cuda.get_device_name(0)}") print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB") MODEL_ID = "google/medgemma-27b-text-it" ADAPTER_ID = "yashvshetty/clarke-medgemma-27b-lora" # Load prompt template template_text = Path("backend/prompts/document_generation.j2").read_text() TEMPLATE = Template(template_text) # Load gold standard references GOLD_DIR = Path("evaluation/gold_standards") REFERENCES = {} for ref_file in sorted(GOLD_DIR.glob("ref_*.txt")): key = ref_file.stem.replace("ref_", "") REFERENCES[key] = ref_file.read_text(encoding="utf-8").strip() print(f"Loaded {len(REFERENCES)} gold standard references: {list(REFERENCES.keys())}") # Load FHIR bundles for patient context FHIR_DIR = Path("data/fhir_bundles") PATIENTS = { "mrs_thompson": "pt-001", "mr_okafor": "pt-002", "ms_patel": "pt-003", "mr_williams": "pt-004", "mrs_khan": "pt-005", } # Load transcripts TRANSCRIPT_DIR = Path("data/demo") TRANSCRIPTS = {} for name, pt_id in PATIENTS.items(): # Try different naming patterns for pattern in [f"{pt_id}_transcript.txt", f"{name}_transcript.txt"]: t_path = TRANSCRIPT_DIR / pattern if t_path.exists(): TRANSCRIPTS[name] = t_path.read_text(encoding="utf-8").strip() break print(f"Loaded {len(TRANSCRIPTS)} transcripts") # Load FHIR contexts def load_fhir_context(pt_id): bundle_path = FHIR_DIR / f"{pt_id}.json" if not bundle_path.exists(): print(f"WARNING: No FHIR bundle for {pt_id}") return {} bundle = json.loads(bundle_path.read_text()) # Extract key info from FHIR bundle context = { "patient_id": pt_id, "demographics": {}, "problem_list": [], "medications": [], "allergies": [], "recent_labs": [], "recent_imaging": [], } if "entry" in bundle: for entry in bundle["entry"]: resource = entry.get("resource", {}) rtype = resource.get("resourceType", "") if rtype == "Patient": name_parts = resource.get("name", [{}])[0] given = " ".join(name_parts.get("given", [])) family = name_parts.get("family", "") prefix = name_parts.get("prefix", [""])[0] if name_parts.get("prefix") else "" context["demographics"]["name"] = f"{prefix} {given} {family}".strip() context["demographics"]["dob"] = resource.get("birthDate", "") nhs = "" for ident in resource.get("identifier", []): if "nhs" in ident.get("system", "").lower(): nhs = ident.get("value", "") context["demographics"]["nhs_number"] = nhs context["demographics"]["sex"] = resource.get("gender", "").capitalize() elif rtype == "Condition": code = resource.get("code", {}).get("text", "") if not code: codings = resource.get("code", {}).get("coding", []) code = codings[0].get("display", "") if codings else "" if code: context["problem_list"].append(code) elif rtype == "MedicationStatement" or rtype == "MedicationRequest": med_code = resource.get("medicationCodeableConcept", {}) med_name = med_code.get("text", "") if not med_name: codings = med_code.get("coding", []) med_name = codings[0].get("display", "") if codings else "" dosage = resource.get("dosage", [{}])[0] if resource.get("dosage") else {} dose_text = dosage.get("text", "") context["medications"].append({"name": med_name, "dose": dose_text}) elif rtype == "AllergyIntolerance": substance = resource.get("code", {}).get("text", "") if not substance: codings = resource.get("code", {}).get("coding", []) substance = codings[0].get("display", "") if codings else "" reaction_list = resource.get("reaction", []) reaction = "" if reaction_list: manifestations = reaction_list[0].get("manifestation", []) if manifestations: reaction = manifestations[0].get("coding", [{}])[0].get("display", "") context["allergies"].append({"substance": substance, "reaction": reaction}) elif rtype == "Observation": code = resource.get("code", {}) obs_name = code.get("text", "") if not obs_name: codings = code.get("coding", []) obs_name = codings[0].get("display", "") if codings else "" value = "" unit = "" if "valueQuantity" in resource: value = str(resource["valueQuantity"].get("value", "")) unit = resource["valueQuantity"].get("unit", "") elif "valueString" in resource: value = resource["valueString"] date = resource.get("effectiveDateTime", "") context["recent_labs"].append({"name": obs_name, "value": value, "unit": unit, "date": date}) elif rtype == "DiagnosticReport": code = resource.get("code", {}) report_name = code.get("text", "") if not report_name: codings = code.get("coding", []) report_name = codings[0].get("display", "") if codings else "" conclusion = resource.get("conclusion", "") date = resource.get("effectiveDateTime", resource.get("issued", "")) context["recent_imaging"].append({"type": report_name, "date": date, "summary": conclusion}) return context CONTEXTS = {} for name, pt_id in PATIENTS.items(): CONTEXTS[name] = load_fhir_context(pt_id) print(f"Loaded {len(CONTEXTS)} FHIR contexts") # Evaluation functions def tokenize_text(text): return re.findall(r'\b\w+\b', text.lower()) def ngrams(tokens, n): return [tuple(tokens[i:i+n]) for i in range(len(tokens)-n+1)] def bleu_score(reference, hypothesis, max_n=4): ref_tokens = tokenize_text(reference) hyp_tokens = tokenize_text(hypothesis) if not hyp_tokens: return {"bleu1": 0.0, "bleu4": 0.0} log_avg = 0.0 bleu1_val = 0.0 for n in range(1, max_n+1): ref_ng = Counter(ngrams(ref_tokens, n)) hyp_ng = Counter(ngrams(hyp_tokens, n)) clipped = sum(min(hyp_ng[ng], ref_ng[ng]) for ng in hyp_ng) total = sum(hyp_ng.values()) precision = clipped / total if total > 0 else 0.0 if n == 1: bleu1_val = round(precision, 4) log_avg += math.log(precision) if precision > 0 else float('-inf') bp = min(1.0, math.exp(1 - len(ref_tokens)/len(hyp_tokens))) if len(hyp_tokens) > 0 else 0.0 cumulative = bp * math.exp(log_avg / max_n) if log_avg > float('-inf') else 0.0 return {"bleu1": bleu1_val, "bleu4": round(cumulative, 4)} def rouge_l_f1(reference, hypothesis): ref_tokens = tokenize_text(reference) hyp_tokens = tokenize_text(hypothesis) if not ref_tokens or not hyp_tokens: return 0.0 m, n = len(ref_tokens), len(hyp_tokens) dp = [[0]*(n+1) for _ in range(m+1)] for i in range(1, m+1): for j in range(1, n+1): if ref_tokens[i-1] == hyp_tokens[j-1]: dp[i][j] = dp[i-1][j-1] + 1 else: dp[i][j] = max(dp[i-1][j], dp[i][j-1]) lcs = dp[m][n] precision = lcs / n recall = lcs / m if precision + recall == 0: return 0.0 return round(2 * precision * recall / (precision + recall), 4) # Load model print("\nLoading tokenizer...") tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) print("Loading base model in 4-bit...") bnb_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True, ) model = AutoModelForCausalLM.from_pretrained( MODEL_ID, quantization_config=bnb_config, device_map="auto", torch_dtype=torch.bfloat16, ) print("Loading LoRA adapter...") model = PeftModel.from_pretrained(model, ADAPTER_ID) model.eval() print(f"Model + adapter loaded. GPU memory: {torch.cuda.memory_allocated()/1e9:.1f} GB") # Generate letters generated_letters = {} for name in PATIENTS: if name not in TRANSCRIPTS: print(f"SKIP {name}: no transcript") continue if name not in CONTEXTS: print(f"SKIP {name}: no context") continue print(f"\nGenerating letter for: {name}") context = CONTEXTS[name] context_json = json.dumps(context, ensure_ascii=False, indent=2) demo = context.get("demographics", {}) prompt = TEMPLATE.render( letter_date=datetime.now(tz=timezone.utc).strftime("%d %b %Y"), clinician_name="Dr Sarah Chen", clinician_title="Consultant, General Practice", gp_name="Dr Andrew Wilson", gp_address="Riverside Medical Practice", patient_name=demo.get("name", ""), patient_dob=demo.get("dob", ""), patient_nhs=demo.get("nhs_number", ""), transcript=TRANSCRIPTS[name], context_json=context_json, ) inputs = tokenizer(prompt, return_tensors="pt").to(model.device) with torch.no_grad(): output_ids = model.generate( **inputs, max_new_tokens=2048, do_sample=False, repetition_penalty=1.1, ) full_output = tokenizer.decode(output_ids[0], skip_special_tokens=True) if full_output.startswith(prompt): letter = full_output[len(prompt):].strip() else: letter = full_output.strip() generated_letters[name] = letter word_count = len(tokenize_text(letter)) print(f" Generated {word_count} words") # Evaluate BASELINE = { "mrs_thompson": {"bleu1": 0.7970, "bleu4": 0.4882, "rouge_l": 0.6958}, "mr_okafor": {"bleu1": 0.7971, "bleu4": 0.6220, "rouge_l": 0.7247}, "ms_patel": {"bleu1": 0.8117, "bleu4": 0.5608, "rouge_l": 0.7119}, "mr_williams": {"bleu1": 0.8754, "bleu4": 0.7386, "rouge_l": 0.8139}, "mrs_khan": {"bleu1": 0.8244, "bleu4": 0.6425, "rouge_l": 0.7513}, } print("\n" + "="*80) print("EVALUATION RESULTS: LoRA Adapter vs Base Model (no adapter)") print("="*80) print(f"\n{'Patient':<20} {'Metric':<10} {'Base':<10} {'LoRA':<10} {'Delta':<10}") print("-"*60) lora_totals = {"bleu1": 0, "bleu4": 0, "rouge_l": 0} base_totals = {"bleu1": 0, "bleu4": 0, "rouge_l": 0} count = 0 for name in PATIENTS: if name not in generated_letters or name not in REFERENCES: continue ref = REFERENCES[name] hyp = generated_letters[name] bl = bleu_score(ref, hyp) rl = rouge_l_f1(ref, hyp) scores = {"bleu1": bl["bleu1"], "bleu4": bl["bleu4"], "rouge_l": rl} base = BASELINE.get(name, {"bleu1": 0, "bleu4": 0, "rouge_l": 0}) for metric in ["bleu1", "bleu4", "rouge_l"]: delta = scores[metric] - base[metric] sign = "+" if delta >= 0 else "" label = {"bleu1": "BLEU-1", "bleu4": "BLEU-4", "rouge_l": "ROUGE-L"}[metric] print(f"{name:<20} {label:<10} {base[metric]:<10.4f} {scores[metric]:<10.4f} {sign}{delta:.4f}") lora_totals[metric] += scores[metric] base_totals[metric] += base[metric] count += 1 print() if count > 0: print("-"*60) print(f"{'AVERAGE':<20} {'Metric':<10} {'Base':<10} {'LoRA':<10} {'Delta':<10}") print("-"*60) for metric in ["bleu1", "bleu4", "rouge_l"]: avg_base = base_totals[metric] / count avg_lora = lora_totals[metric] / count delta = avg_lora - avg_base sign = "+" if delta >= 0 else "" label = {"bleu1": "BLEU-1", "bleu4": "BLEU-4", "rouge_l": "ROUGE-L"}[metric] print(f"{'AVERAGE':<20} {label:<10} {avg_base:<10.4f} {avg_lora:<10.4f} {sign}{delta:.4f}") # Save generated letters for name, letter in generated_letters.items(): Path(f"/tmp/lora_{name}.txt").write_text(letter) print(f"Saved: /tmp/lora_{name}.txt") print("\nEVALUATION COMPLETE.") # Cleanup del model gc.collect() torch.cuda.empty_cache() print("Memory freed.")