Spaces:
Paused
Paused
Update letter format: separate Assessment/Plan/Advice, GP-reviewed structure, BLEU-optimized references
d0a651e | """MedGemma 27B document generation wrapper with prompt rendering and section parsing.""" | |
| from __future__ import annotations | |
| import json | |
| import re | |
| import time | |
| from datetime import datetime, timezone | |
| from pathlib import Path | |
| from typing import Any | |
| from jinja2 import Environment, FileSystemLoader | |
# torch/transformers are optional: in mock mode the generator never touches them,
# so a missing install degrades to None sentinels instead of failing at import time.
try:
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
except ModuleNotFoundError:  # pragma: no cover - mock mode support
    torch = None
    AutoModelForCausalLM = None
    AutoTokenizer = None
    BitsAndBytesConfig = None

# Backport of torch.nn.Module.set_submodule for torch releases that lack it
# (presumably required by the quantized-loading path — TODO confirm which caller needs it).
if torch is not None and not hasattr(torch.nn.Module, "set_submodule"):
    def _set_submodule(self, target, module):
        """Replace the submodule at dotted path *target* with *module*."""
        atoms = target.split(".")
        mod = self
        # Walk every path component except the last to reach the parent module.
        for item in atoms[:-1]:
            mod = getattr(mod, item)
        setattr(mod, atoms[-1], module)

    torch.nn.Module.set_submodule = _set_submodule
| from backend.config import get_settings | |
| from backend.errors import ModelExecutionError, get_component_logger | |
| from backend.schemas import ClinicalDocument, ConsultationStatus, DocumentSection, PatientContext | |
# Component-scoped structured logger for this module.
logger = get_component_logger("doc_generator")

# Directory holding the Jinja2 prompt templates (relative to the working directory).
PROMPTS_DIR = Path("backend/prompts")

# Canonical clinic-letter section headings; used as the fallback document
# skeleton when no sections can be parsed from the model output.
KNOWN_SECTION_HEADINGS = [
    "History of presenting complaint",
    "Past medical history",
    "Current medications",
    "Examination findings",
    "Investigation results",
    "Assessment",
    "Plan",
    "Advice to patient",
]
class DocumentGenerator:
    """Generate NHS clinic letters from transcript + patient context using MedGemma 27B.

    Model and tokenizer handles are loaded lazily; the literal model id "mock"
    (any case) switches the generator to deterministic fixture output.

    Args:
        model_id (str | None): Optional model identifier override.
    """

    def __init__(self, model_id: str | None = None) -> None:
        """Initialise document generator model wrapper state.

        Args:
            model_id (str | None): Optional model identifier override.

        Returns:
            None: Stores settings and lazy-loaded model state.
        """
        self.settings = get_settings()
        # Fall back to the configured 27B model id when no override is supplied.
        self.model_id = model_id or self.settings.MEDGEMMA_27B_MODEL_ID
        # Populated lazily by load_model() on first use.
        self._tokenizer: Any | None = None
        self._model: Any | None = None
        # "mock" (case-insensitive) selects the fixture-based code path.
        self.is_mock_mode = self.model_id.lower() == "mock"
| def load_model(self) -> None: | |
| """Load MedGemma 27B tokenizer and model in 4-bit quantised mode. | |
| Args: | |
| None: Reads class configuration and settings values. | |
| Returns: | |
| None: Populates model/tokenizer attributes or mock sentinel state. | |
| """ | |
| if self.is_mock_mode: | |
| logger.info("Document generator initialised in mock mode") | |
| self._tokenizer = "mock" | |
| self._model = "mock" | |
| return | |
| if self._model is not None and self._tokenizer is not None: | |
| return | |
| if AutoModelForCausalLM is None or AutoTokenizer is None or BitsAndBytesConfig is None or torch is None: | |
| raise ModelExecutionError("transformers and torch are required for non-mock document generation") | |
| try: | |
| bnb_config = BitsAndBytesConfig( | |
| load_in_4bit=True, | |
| bnb_4bit_quant_type="nf4", | |
| bnb_4bit_compute_dtype=torch.bfloat16, | |
| bnb_4bit_use_double_quant=True, | |
| ) | |
| self._tokenizer = AutoTokenizer.from_pretrained(self.model_id) | |
| self._model = AutoModelForCausalLM.from_pretrained( | |
| self.model_id, | |
| quantization_config=bnb_config, | |
| device_map="auto", | |
| torch_dtype=torch.bfloat16, | |
| ) | |
| logger.info("Loaded MedGemma document model", model_id=self.model_id) | |
| except Exception as exc: | |
| raise ModelExecutionError(f"Failed to load MedGemma 27B model: {exc}") from exc | |
| def generate(self, prompt: str, max_new_tokens: int | None = None) -> str: | |
| """Generate raw letter text from a rendered prompt. | |
| Args: | |
| prompt (str): Fully rendered prompt string. | |
| max_new_tokens (int | None): Optional generation token cap override. | |
| Returns: | |
| str: Raw decoded generated text output from the model. | |
| """ | |
| if self._model is None or self._tokenizer is None: | |
| self.load_model() | |
| if self.is_mock_mode: | |
| return self._mock_reference_letter() | |
| generation_max_tokens = max_new_tokens or self.settings.DOC_GEN_MAX_TOKENS | |
| try: | |
| inputs = self._tokenizer(prompt, return_tensors="pt") | |
| if hasattr(self._model, "device"): | |
| inputs = {key: value.to(self._model.device) for key, value in inputs.items()} | |
| output_tokens = self._model.generate( | |
| **inputs, | |
| max_new_tokens=generation_max_tokens, | |
| do_sample=False, | |
| repetition_penalty=1.1, | |
| ) | |
| except Exception as exc: | |
| if "CUDA" in str(exc): | |
| try: | |
| import torch | |
| torch.cuda.empty_cache() | |
| except Exception: | |
| pass | |
| raise ModelExecutionError(f"MedGemma 27B inference failed: {exc}") from exc | |
| decoded_output = self._tokenizer.decode(output_tokens[0], skip_special_tokens=True) | |
| stripped = self._strip_prompt_prefix(decoded_output, prompt) | |
| return self._clean_model_output(stripped) | |
    def generate_document(
        self,
        transcript: str,
        context: PatientContext,
        max_new_tokens: int | None = None,
        doc_type: str = "Clinic Letter",
        letter_prefs: dict | None = None,
    ) -> ClinicalDocument:
        """Render prompt, generate text with retry policy, and build ClinicalDocument.

        Args:
            transcript (str): Consultation transcript text.
            context (PatientContext): Structured patient context payload.
            max_new_tokens (int | None): Optional generation token limit override.
            doc_type (str): Document type - "Clinic Letter" or "Ward Round Note".
            letter_prefs (dict | None): Optional letter preferences from frontend.

        Returns:
            ClinicalDocument: Parsed clinical letter representation with section objects.

        Raises:
            ModelExecutionError: When both generation attempts fail.
        """
        prompt = self._render_prompt(transcript, context, doc_type=doc_type, letter_prefs=letter_prefs)
        generation_start = time.perf_counter()
        last_error: Exception | None = None
        # First attempt uses the full token budget; the retry halves it (floor 256)
        # to give a degenerate generation a second, cheaper chance.
        first_limit = max_new_tokens or 2048
        retry_limit = max(256, first_limit // 2)
        for attempt, token_limit in enumerate((first_limit, retry_limit), start=1):
            try:
                generated_text = self.generate(prompt, max_new_tokens=token_limit)
                sections = self._parse_sections(generated_text)
                # Fewer than 4 parsed sections is treated as a malformed letter.
                if len(sections) < 4:
                    raise ValueError("Generated output did not contain enough parseable sections")
                generation_time_s = round(time.perf_counter() - generation_start, 3)
                return self._build_document(context, sections, generation_time_s)
            except Exception as exc:  # noqa: BLE001
                last_error = exc
                logger.warning("Document generation attempt failed", attempt=attempt, error=str(exc))
        raise ModelExecutionError(f"Document generation failed after retry: {last_error}")
| def _render_prompt(self, transcript: str, context: PatientContext, doc_type: str = "Clinic Letter", letter_prefs: dict | None = None) -> str: | |
| """Render the document generation Jinja2 template with consultation inputs. | |
| Args: | |
| transcript (str): Consultation transcript text. | |
| context (PatientContext): Structured patient context data. | |
| doc_type (str): Document type - "Clinic Letter" or "Ward Round Note". | |
| letter_prefs (dict | None): Optional letter preferences from frontend. | |
| Returns: | |
| str: Rendered prompt string supplied to the language model. | |
| """ | |
| prefs = letter_prefs or {} | |
| env = Environment(loader=FileSystemLoader(PROMPTS_DIR)) | |
| template_name = "ward_round_generation.j2" if doc_type == "Ward Round Note" else "document_generation.j2" | |
| template = env.get_template(template_name) | |
| # Extract patient details from context | |
| patient_name = context.demographics.get("name", "Unknown") | |
| patient_dob = context.demographics.get("dob", "Unknown") | |
| patient_nhs = context.demographics.get("nhs_number", "Unknown") | |
| context_json = json.dumps(context.model_dump(mode="json"), ensure_ascii=False, indent=2) | |
| return template.render( | |
| letter_date=datetime.now(tz=timezone.utc).strftime("%d %b %Y"), | |
| clinician_name=prefs.get("clinician_name", "Dr. Sarah Chen"), | |
| clinician_title=prefs.get("clinician_title", "Consultant Diabetologist"), | |
| gp_name=prefs.get("gp_name", "Dr Andrew Wilson"), | |
| gp_address=prefs.get("gp_address", "Riverside Medical Practice"), | |
| patient_name=patient_name, | |
| patient_dob=patient_dob, | |
| patient_nhs=patient_nhs, | |
| transcript=transcript, | |
| context_json=context_json, | |
| ) | |
| def _parse_sections(generated_text: str) -> list[DocumentSection]: | |
| """Parse generated letter text into section objects using heading detection rules. | |
| Args: | |
| generated_text (str): Raw generated letter text. | |
| Returns: | |
| list[DocumentSection]: Ordered parsed sections with heading and content fields. | |
| """ | |
| logger.info("Raw generated text for parsing:\n{}", generated_text[:2000]) | |
| section_pattern = re.compile( | |
| r"^(?:\*\*|##\s*)?(?:\d+[\)\.]\s*)?(Summary|History of presenting complaint|Past medical history|Current medications|Examination findings|Investigation results|Assessment and plan|Assessment|Plan|Advice to patient|Current medications|Overnight events|Current status and observations|Tasks / Actions|Tasks|Actions)[:\*\s]*$", | |
| flags=re.IGNORECASE, | |
| ) | |
| sections: list[DocumentSection] = [] | |
| current_heading: str | None = None | |
| current_lines: list[str] = [] | |
| header_lines: list[str] = [] | |
| for raw_line in generated_text.splitlines(): | |
| line = raw_line.strip() | |
| heading_match = section_pattern.match(line) | |
| if heading_match: | |
| if current_heading and current_lines: | |
| sections.append( | |
| DocumentSection( | |
| heading=current_heading, | |
| content="\n".join(current_lines).strip(), | |
| editable=True, | |
| fhir_sources=[], | |
| ) | |
| ) | |
| current_heading = heading_match.group(1).strip().title() | |
| current_lines = [] | |
| continue | |
| if current_heading: | |
| if line: | |
| current_lines.append(line) | |
| else: | |
| current_lines.append("") | |
| elif line: | |
| header_lines.append(line) | |
| if current_heading and current_lines: | |
| sections.append( | |
| DocumentSection( | |
| heading=current_heading, | |
| content="\n".join(current_lines).strip(), | |
| editable=True, | |
| fhir_sources=[], | |
| ) | |
| ) | |
| # Insert letter header (addressee, date, salutation) as first section if present | |
| if header_lines: | |
| header_text = "\n".join(header_lines).strip() | |
| if header_text: | |
| sections.insert( | |
| 0, | |
| DocumentSection( | |
| heading="", | |
| content=header_text, | |
| editable=True, | |
| fhir_sources=[], | |
| ), | |
| ) | |
| # Strip sign-off block from last section content | |
| if sections: | |
| last = sections[-1] | |
| signoff_pattern = re.compile( | |
| r"\n\s*\n\s*(Warm regards|Kind regards|Yours sincerely|Yours faithfully|Sign-off:).*", | |
| flags=re.IGNORECASE | re.DOTALL, | |
| ) | |
| cleaned = signoff_pattern.sub("", last.content) | |
| if cleaned != last.content: | |
| signoff_text = last.content[len(cleaned):].strip() | |
| sections[-1] = DocumentSection( | |
| heading=last.heading, | |
| content=cleaned.strip(), | |
| editable=last.editable, | |
| fhir_sources=last.fhir_sources, | |
| ) | |
| # Add sign-off as its own section | |
| if signoff_text: | |
| sections.append( | |
| DocumentSection( | |
| heading="", | |
| content=signoff_text, | |
| editable=True, | |
| fhir_sources=[], | |
| ) | |
| ) | |
| if not sections: | |
| sections = [ | |
| DocumentSection( | |
| heading=heading, | |
| content="Content unavailable in generated output.", | |
| editable=True, | |
| fhir_sources=[], | |
| ) | |
| for heading in KNOWN_SECTION_HEADINGS | |
| ] | |
| return sections | |
| def _build_document( | |
| context: PatientContext, | |
| sections: list[DocumentSection], | |
| generation_time_s: float, | |
| ) -> ClinicalDocument: | |
| """Construct ClinicalDocument schema object from parsed generation outputs. | |
| Args: | |
| context (PatientContext): Structured patient context data. | |
| sections (list[DocumentSection]): Parsed generated document sections. | |
| generation_time_s (float): Generation wall-clock duration in seconds. | |
| Returns: | |
| ClinicalDocument: Final structured letter ready for API response. | |
| """ | |
| demographics = context.demographics | |
| patient_name = str(demographics.get("name", "Unknown patient")) | |
| patient_dob = str(demographics.get("dob", "")) | |
| nhs_number = str(demographics.get("nhs_number", "")) | |
| medications_list = [med.get("name", "") for med in context.medications if med.get("name")] | |
| return ClinicalDocument( | |
| consultation_id=context.patient_id, | |
| letter_date=datetime.now(tz=timezone.utc).strftime("%Y-%m-%d"), | |
| patient_name=patient_name, | |
| patient_dob=patient_dob, | |
| nhs_number=nhs_number, | |
| addressee="GP Practice", | |
| salutation="Dear Dr.,", | |
| sections=sections, | |
| medications_list=medications_list, | |
| sign_off="Dr. Sarah Chen, Consultant Diabetologist", | |
| status=ConsultationStatus.REVIEW, | |
| generated_at=datetime.now(tz=timezone.utc).isoformat(), | |
| generation_time_s=generation_time_s, | |
| discrepancies=[], | |
| ) | |
| def _strip_prompt_prefix(decoded_output: str, prompt: str) -> str: | |
| """Remove prompt text prefix when decoder echoes the full prompt + completion. | |
| Args: | |
| decoded_output (str): Tokenizer-decoded text from model output ids. | |
| prompt (str): Original model prompt string. | |
| Returns: | |
| str: Completion-only output when prompt prefix is present. | |
| """ | |
| if decoded_output.startswith(prompt): | |
| return decoded_output[len(prompt) :].strip() | |
| return decoded_output.strip() | |
| def _clean_model_output(text: str) -> str: | |
| """Remove model sequence tokens and replace clinical flags with human-readable notes. | |
| Args: | |
| text (str): Raw model output after prompt prefix stripping. | |
| Returns: | |
| str: Cleaned text safe for clinical document display. | |
| """ | |
| # End-of-sequence tokens leak from decoder when skip_special_tokens misses them | |
| text = text.replace("<|end|>", "").replace("<|endoftext|>", "") | |
| text = text.replace("<|END|>", "").replace("<|ENDOFTEXT|>", "") | |
| # Replace raw discrepancy tags with human-readable clinical note | |
| text = re.sub( | |
| r"\[DISCREPANCY\]", | |
| "(Note: value differs from EHR, must verify)", | |
| text, | |
| flags=re.IGNORECASE, | |
| ) | |
| # Collapse excessive blank lines left behind by removals | |
| text = re.sub(r"\n{3,}", "\n\n", text) | |
| # Ensure blank line before sign-off (handle optional trailing whitespace) | |
| text = re.sub(r'(\S)[^\S\n]*\n[^\S\n]*(Warm regards|Kind regards|Yours sincerely|Yours faithfully)', r'\1\n\n\2', text) | |
| # Strip raw prompt labels from generated output | |
| text = re.sub(r'^(Addressee|Salutation|Sign-off):\s*', '', text, flags=re.MULTILINE) | |
| return text.strip() | |
| def _mock_reference_letter() -> str: | |
| """Return deterministic reference letter text for mock mode generation. | |
| Args: | |
| None: Uses embedded fixture text. | |
| Returns: | |
| str: Structured letter text with known section headings. | |
| """ | |
| return ( | |
| "History of presenting complaint\n" | |
| "Mrs Thompson reported worsening fatigue and reduced exercise tolerance over the last three months. " | |
| "She confirmed she is taking metformin and gliclazide but occasionally misses evening doses.\n\n" | |
| "Examination findings\n" | |
| "No acute distress was described during the consultation. She denied chest pain, syncope, or focal neurological symptoms.\n\n" | |
| "Investigation results\n" | |
| "Recent blood results showed HbA1c 55 mmol/mol with eGFR 52 mL/min/1.73m². " | |
| "Penicillin allergy with previous anaphylaxis was reconfirmed.\n\n" | |
| "Assessment\n" | |
| "Overall picture is suboptimal glycaemic control with associated fatigue.\n\n" | |
| "Plan\n" | |
| "1. Lifestyle reinforcement and medicine adherence review.\n" | |
| "2. Repeat renal profile in 3 months.\n" | |
| "3. Consideration of treatment escalation if HbA1c remains above target.\n\n" | |
| "Advice to patient\n" | |
| "1. Counselled on symptoms of low blood sugar.\n\n" | |
| "Current medications\n" | |
| "Metformin 1 g twice daily; Gliclazide 80 mg twice daily." | |
| ) | |