# clarke/backend/models/doc_generator.py
# Source: Hugging Face Space upload by yashvshetty, commit d0a651e
# "Update letter format: separate Assessment/Plan/Advice, GP-reviewed structure, BLEU-optimized references"
"""MedGemma 27B document generation wrapper with prompt rendering and section parsing."""
from __future__ import annotations
import json
import re
import time
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
from jinja2 import Environment, FileSystemLoader
try:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
except ModuleNotFoundError: # pragma: no cover - mock mode support
torch = None
AutoModelForCausalLM = None
AutoTokenizer = None
BitsAndBytesConfig = None
if torch is not None and not hasattr(torch.nn.Module, "set_submodule"):
    # Backport: older torch releases lack Module.set_submodule, which the
    # quantised (device_map="auto") loading path may call on this model.
    def _set_submodule(self, target, module):
        """Assign *module* at the dotted *target* path below *self*."""
        *parent_names, leaf_name = target.split(".")
        owner = self
        for name in parent_names:
            owner = getattr(owner, name)
        setattr(owner, leaf_name, module)

    torch.nn.Module.set_submodule = _set_submodule
from backend.config import get_settings
from backend.errors import ModelExecutionError, get_component_logger
from backend.schemas import ClinicalDocument, ConsultationStatus, DocumentSection, PatientContext
# Component-scoped structured logger for this module.
logger = get_component_logger("doc_generator")
# Jinja2 prompt templates live here, relative to the process working directory.
PROMPTS_DIR: Path = Path("backend/prompts")
# Canonical clinic-letter headings. Used only as a fallback skeleton when the
# generated output contains no parseable sections (see _parse_sections).
KNOWN_SECTION_HEADINGS: list[str] = [
"History of presenting complaint",
"Past medical history",
"Current medications",
"Examination findings",
"Investigation results",
"Assessment",
"Plan",
"Advice to patient",
]
class DocumentGenerator:
    """Generate NHS clinic letters from transcript + patient context using MedGemma 27B.

    The tokenizer/model pair is lazy-loaded on first use. Setting the model id
    to the literal string ``"mock"`` enables a deterministic fixture mode that
    requires neither torch nor transformers.
    """

    def __init__(self, model_id: str | None = None) -> None:
        """Initialise document generator model wrapper state.

        Args:
            model_id (str | None): Optional model identifier override; falls
                back to ``MEDGEMMA_27B_MODEL_ID`` from application settings.

        Returns:
            None: Stores settings and lazy-loaded model state.
        """
        self.settings = get_settings()
        self.model_id = model_id or self.settings.MEDGEMMA_27B_MODEL_ID
        self._tokenizer: Any | None = None
        self._model: Any | None = None
        self.is_mock_mode = self.model_id.lower() == "mock"

    def load_model(self) -> None:
        """Load MedGemma 27B tokenizer and model in 4-bit quantised mode.

        Idempotent: returns immediately when handles are already populated.
        In mock mode the handles are set to the sentinel string ``"mock"``.

        Raises:
            ModelExecutionError: If ML dependencies are missing or loading fails.
        """
        if self.is_mock_mode:
            logger.info("Document generator initialised in mock mode")
            self._tokenizer = "mock"
            self._model = "mock"
            return
        if self._model is not None and self._tokenizer is not None:
            return
        if AutoModelForCausalLM is None or AutoTokenizer is None or BitsAndBytesConfig is None or torch is None:
            raise ModelExecutionError("transformers and torch are required for non-mock document generation")
        try:
            # NF4 + double quantisation keeps the 27B model within a single-GPU budget.
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_use_double_quant=True,
            )
            self._tokenizer = AutoTokenizer.from_pretrained(self.model_id)
            self._model = AutoModelForCausalLM.from_pretrained(
                self.model_id,
                quantization_config=bnb_config,
                device_map="auto",
                torch_dtype=torch.bfloat16,
            )
            logger.info("Loaded MedGemma document model", model_id=self.model_id)
        except Exception as exc:
            raise ModelExecutionError(f"Failed to load MedGemma 27B model: {exc}") from exc

    def generate(self, prompt: str, max_new_tokens: int | None = None) -> str:
        """Generate raw letter text from a rendered prompt.

        Args:
            prompt (str): Fully rendered prompt string.
            max_new_tokens (int | None): Optional generation token cap override;
                defaults to ``DOC_GEN_MAX_TOKENS`` from settings.

        Returns:
            str: Cleaned completion text (prompt echo and control tokens removed).

        Raises:
            ModelExecutionError: If inference fails.
        """
        if self._model is None or self._tokenizer is None:
            self.load_model()
        if self.is_mock_mode:
            return self._mock_reference_letter()
        generation_max_tokens = max_new_tokens or self.settings.DOC_GEN_MAX_TOKENS
        try:
            inputs = self._tokenizer(prompt, return_tensors="pt")
            if hasattr(self._model, "device"):
                inputs = {key: value.to(self._model.device) for key, value in inputs.items()}
            # Deterministic decoding: clinical letters must be reproducible.
            output_tokens = self._model.generate(
                **inputs,
                max_new_tokens=generation_max_tokens,
                do_sample=False,
                repetition_penalty=1.1,
            )
        except Exception as exc:
            # Best-effort VRAM cleanup on CUDA faults so a later retry can
            # succeed. Uses the module-level torch import; the previous local
            # re-import was redundant and ignored the mock-mode import guard.
            if "CUDA" in str(exc) and torch is not None:
                try:
                    torch.cuda.empty_cache()
                except Exception:
                    pass
            raise ModelExecutionError(f"MedGemma 27B inference failed: {exc}") from exc
        decoded_output = self._tokenizer.decode(output_tokens[0], skip_special_tokens=True)
        stripped = self._strip_prompt_prefix(decoded_output, prompt)
        return self._clean_model_output(stripped)

    def generate_document(
        self,
        transcript: str,
        context: PatientContext,
        max_new_tokens: int | None = None,
        doc_type: str = "Clinic Letter",
        letter_prefs: dict | None = None,
    ) -> ClinicalDocument:
        """Render prompt, generate text with retry policy, and build ClinicalDocument.

        Makes up to two attempts: the second retries with half the token budget
        (useful when a truncated/degenerate output fails section parsing).

        Args:
            transcript (str): Consultation transcript text.
            context (PatientContext): Structured patient context payload.
            max_new_tokens (int | None): Optional generation token limit override.
            doc_type (str): "Clinic Letter" or "Ward Round Note"; selects template.
            letter_prefs (dict | None): Optional clinician/GP letter preferences.

        Returns:
            ClinicalDocument: Parsed clinical letter representation with section objects.

        Raises:
            ModelExecutionError: If both generation attempts fail.
        """
        prompt = self._render_prompt(transcript, context, doc_type=doc_type, letter_prefs=letter_prefs)
        generation_start = time.perf_counter()
        last_error: Exception | None = None
        # NOTE(review): falls back to a hard-coded 2048 rather than
        # settings.DOC_GEN_MAX_TOKENS (which generate() would use when passed
        # None) — confirm this asymmetry is intentional.
        first_limit = max_new_tokens or 2048
        retry_limit = max(256, first_limit // 2)
        for attempt, token_limit in enumerate((first_limit, retry_limit), start=1):
            try:
                generated_text = self.generate(prompt, max_new_tokens=token_limit)
                sections = self._parse_sections(generated_text)
                if len(sections) < 4:
                    raise ValueError("Generated output did not contain enough parseable sections")
                generation_time_s = round(time.perf_counter() - generation_start, 3)
                return self._build_document(context, sections, generation_time_s)
            except Exception as exc:  # noqa: BLE001
                last_error = exc
                logger.warning("Document generation attempt failed", attempt=attempt, error=str(exc))
        raise ModelExecutionError(f"Document generation failed after retry: {last_error}")

    def _render_prompt(self, transcript: str, context: PatientContext, doc_type: str = "Clinic Letter", letter_prefs: dict | None = None) -> str:
        """Render the document generation Jinja2 template with consultation inputs.

        Args:
            transcript (str): Consultation transcript text.
            context (PatientContext): Structured patient context data.
            doc_type (str): "Clinic Letter" or "Ward Round Note"; selects template.
            letter_prefs (dict | None): Optional letter preferences from frontend.

        Returns:
            str: Rendered prompt string supplied to the language model.
        """
        prefs = letter_prefs or {}
        env = Environment(loader=FileSystemLoader(PROMPTS_DIR))
        template_name = "ward_round_generation.j2" if doc_type == "Ward Round Note" else "document_generation.j2"
        template = env.get_template(template_name)
        # Extract patient details from context demographics for the letterhead.
        patient_name = context.demographics.get("name", "Unknown")
        patient_dob = context.demographics.get("dob", "Unknown")
        patient_nhs = context.demographics.get("nhs_number", "Unknown")
        context_json = json.dumps(context.model_dump(mode="json"), ensure_ascii=False, indent=2)
        return template.render(
            letter_date=datetime.now(tz=timezone.utc).strftime("%d %b %Y"),
            clinician_name=prefs.get("clinician_name", "Dr. Sarah Chen"),
            clinician_title=prefs.get("clinician_title", "Consultant Diabetologist"),
            gp_name=prefs.get("gp_name", "Dr Andrew Wilson"),
            gp_address=prefs.get("gp_address", "Riverside Medical Practice"),
            patient_name=patient_name,
            patient_dob=patient_dob,
            patient_nhs=patient_nhs,
            transcript=transcript,
            context_json=context_json,
        )

    @staticmethod
    def _parse_sections(generated_text: str) -> list[DocumentSection]:
        """Parse generated letter text into section objects using heading detection rules.

        Lines matching a known heading start a new section; text before the
        first heading becomes an untitled header section, and a trailing
        sign-off is split out into its own untitled section.

        Args:
            generated_text (str): Raw generated letter text.

        Returns:
            list[DocumentSection]: Ordered parsed sections with heading and content fields.
        """
        # Structured-kwargs logging, consistent with the rest of this module;
        # the previous "{}" placeholder was never interpolated.
        logger.info("Raw generated text for parsing", preview=generated_text[:2000])
        # Headings may be bare, bold (**), markdown (##) or numbered ("1." / "1)").
        # (Duplicate "Current medications" alternative removed from the pattern.)
        section_pattern = re.compile(
            r"^(?:\*\*|##\s*)?(?:\d+[\)\.]\s*)?(Summary|History of presenting complaint|Past medical history|Current medications|Examination findings|Investigation results|Assessment and plan|Assessment|Plan|Advice to patient|Overnight events|Current status and observations|Tasks / Actions|Tasks|Actions)[:\*\s]*$",
            flags=re.IGNORECASE,
        )
        sections: list[DocumentSection] = []
        current_heading: str | None = None
        current_lines: list[str] = []
        header_lines: list[str] = []
        for raw_line in generated_text.splitlines():
            line = raw_line.strip()
            heading_match = section_pattern.match(line)
            if heading_match:
                # Flush the previous section before starting a new one.
                if current_heading and current_lines:
                    sections.append(
                        DocumentSection(
                            heading=current_heading,
                            content="\n".join(current_lines).strip(),
                            editable=True,
                            fhir_sources=[],
                        )
                    )
                current_heading = heading_match.group(1).strip().title()
                current_lines = []
                continue
            if current_heading:
                # Preserve blank lines inside a section as paragraph breaks.
                if line:
                    current_lines.append(line)
                else:
                    current_lines.append("")
            elif line:
                header_lines.append(line)
        if current_heading and current_lines:
            sections.append(
                DocumentSection(
                    heading=current_heading,
                    content="\n".join(current_lines).strip(),
                    editable=True,
                    fhir_sources=[],
                )
            )
        # Insert letter header (addressee, date, salutation) as first section if present
        if header_lines:
            header_text = "\n".join(header_lines).strip()
            if header_text:
                sections.insert(
                    0,
                    DocumentSection(
                        heading="",
                        content=header_text,
                        editable=True,
                        fhir_sources=[],
                    ),
                )
        # Strip sign-off block from last section content
        if sections:
            last = sections[-1]
            signoff_pattern = re.compile(
                r"\n\s*\n\s*(Warm regards|Kind regards|Yours sincerely|Yours faithfully|Sign-off:).*",
                flags=re.IGNORECASE | re.DOTALL,
            )
            cleaned = signoff_pattern.sub("", last.content)
            if cleaned != last.content:
                # The match is always a suffix (DOTALL ".*"), so the removed
                # text is exactly content[len(cleaned):].
                signoff_text = last.content[len(cleaned):].strip()
                sections[-1] = DocumentSection(
                    heading=last.heading,
                    content=cleaned.strip(),
                    editable=last.editable,
                    fhir_sources=last.fhir_sources,
                )
                # Add sign-off as its own section
                if signoff_text:
                    sections.append(
                        DocumentSection(
                            heading="",
                            content=signoff_text,
                            editable=True,
                            fhir_sources=[],
                        )
                    )
        if not sections:
            # Fallback skeleton so downstream rendering always has structure.
            sections = [
                DocumentSection(
                    heading=heading,
                    content="Content unavailable in generated output.",
                    editable=True,
                    fhir_sources=[],
                )
                for heading in KNOWN_SECTION_HEADINGS
            ]
        return sections

    @staticmethod
    def _build_document(
        context: PatientContext,
        sections: list[DocumentSection],
        generation_time_s: float,
    ) -> ClinicalDocument:
        """Construct ClinicalDocument schema object from parsed generation outputs.

        Args:
            context (PatientContext): Structured patient context data.
            sections (list[DocumentSection]): Parsed generated document sections.
            generation_time_s (float): Generation wall-clock duration in seconds.

        Returns:
            ClinicalDocument: Final structured letter ready for API response.
        """
        demographics = context.demographics
        patient_name = str(demographics.get("name", "Unknown patient"))
        patient_dob = str(demographics.get("dob", ""))
        nhs_number = str(demographics.get("nhs_number", ""))
        medications_list = [med.get("name", "") for med in context.medications if med.get("name")]
        return ClinicalDocument(
            consultation_id=context.patient_id,
            letter_date=datetime.now(tz=timezone.utc).strftime("%Y-%m-%d"),
            patient_name=patient_name,
            patient_dob=patient_dob,
            nhs_number=nhs_number,
            addressee="GP Practice",
            salutation="Dear Dr.,",
            sections=sections,
            medications_list=medications_list,
            sign_off="Dr. Sarah Chen, Consultant Diabetologist",
            status=ConsultationStatus.REVIEW,
            generated_at=datetime.now(tz=timezone.utc).isoformat(),
            generation_time_s=generation_time_s,
            discrepancies=[],
        )

    @staticmethod
    def _strip_prompt_prefix(decoded_output: str, prompt: str) -> str:
        """Remove prompt text prefix when decoder echoes the full prompt + completion.

        Args:
            decoded_output (str): Tokenizer-decoded text from model output ids.
            prompt (str): Original model prompt string.

        Returns:
            str: Completion-only output when prompt prefix is present.
        """
        if decoded_output.startswith(prompt):
            return decoded_output[len(prompt) :].strip()
        return decoded_output.strip()

    @staticmethod
    def _clean_model_output(text: str) -> str:
        """Remove model sequence tokens and replace clinical flags with human-readable notes.

        Args:
            text (str): Raw model output after prompt prefix stripping.

        Returns:
            str: Cleaned text safe for clinical document display.
        """
        # End-of-sequence tokens leak from decoder when skip_special_tokens misses them
        text = text.replace("<|end|>", "").replace("<|endoftext|>", "")
        text = text.replace("<|END|>", "").replace("<|ENDOFTEXT|>", "")
        # Replace raw discrepancy tags with human-readable clinical note
        text = re.sub(
            r"\[DISCREPANCY\]",
            "(Note: value differs from EHR, must verify)",
            text,
            flags=re.IGNORECASE,
        )
        # Collapse excessive blank lines left behind by removals
        text = re.sub(r"\n{3,}", "\n\n", text)
        # Ensure blank line before sign-off (handle optional trailing whitespace)
        text = re.sub(r'(\S)[^\S\n]*\n[^\S\n]*(Warm regards|Kind regards|Yours sincerely|Yours faithfully)', r'\1\n\n\2', text)
        # Strip raw prompt labels from generated output
        text = re.sub(r'^(Addressee|Salutation|Sign-off):\s*', '', text, flags=re.MULTILINE)
        return text.strip()

    @staticmethod
    def _mock_reference_letter() -> str:
        """Return deterministic reference letter text for mock mode generation.

        Returns:
            str: Structured letter text with known section headings (fixture).
        """
        return (
            "History of presenting complaint\n"
            "Mrs Thompson reported worsening fatigue and reduced exercise tolerance over the last three months. "
            "She confirmed she is taking metformin and gliclazide but occasionally misses evening doses.\n\n"
            "Examination findings\n"
            "No acute distress was described during the consultation. She denied chest pain, syncope, or focal neurological symptoms.\n\n"
            "Investigation results\n"
            "Recent blood results showed HbA1c 55 mmol/mol with eGFR 52 mL/min/1.73m². "
            "Penicillin allergy with previous anaphylaxis was reconfirmed.\n\n"
            "Assessment\n"
            "Overall picture is suboptimal glycaemic control with associated fatigue.\n\n"
            "Plan\n"
            "1. Lifestyle reinforcement and medicine adherence review.\n"
            "2. Repeat renal profile in 3 months.\n"
            "3. Consideration of treatment escalation if HbA1c remains above target.\n\n"
            "Advice to patient\n"
            "1. Counselled on symptoms of low blood sugar.\n\n"
            "Current medications\n"
            "Metformin 1 g twice daily; Gliclazide 80 mg twice daily."
        )