"""
Track C — Core refinement loop.

On each iteration:
1. Feed the current differential to a self-critique prompt
2. Ask the model to identify weaknesses, missing diagnoses, and re-rank
3. Compare the new differential to the previous one
4. If converged (similarity > threshold) or max iters reached, stop
5. Otherwise loop

Each LLM call is recorded in the CostLedger for cost/benefit charting.
"""
| | from __future__ import annotations |
| |
|
| | import logging |
| | import time |
| | import uuid |
| | from typing import List, Optional, Tuple |
| |
|
| | from app.models.schemas import ( |
| | ClinicalReasoningResult, |
| | DiagnosisCandidate, |
| | PatientProfile, |
| | ) |
| | from app.services.medgemma import MedGemmaService |
| | from tracks.iterative.config import IterativeConfig |
| | from tracks.shared.cost_tracker import ( |
| | CostLedger, |
| | estimate_cost, |
| | estimate_tokens, |
| | LLMCallRecord, |
| | ) |
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
| | |
| | |
| | |
| |
|
# System prompt for the critique step: frames the model as a reviewing senior
# physician so it challenges — rather than rubber-stamps — the current
# differential. Sent verbatim as the system message on every iteration.
CRITIQUE_SYSTEM = """You are a senior physician reviewing a colleague's differential diagnosis.
Your goal is to find weaknesses, missing diagnoses, and ranking errors.
Be constructive but rigorous. You may add, remove, reorder, or refine diagnoses.
Always justify changes with clinical evidence from the patient data."""

# Per-iteration user prompt. Placeholders are filled in IterativeRefiner.refine:
#   {patient_summary}      — output of IterativeRefiner._format_profile
#   {iteration}            — 1-based iteration counter
#   {current_differential} — output of IterativeRefiner._format_differential
# The model is asked to return JSON matching the ClinicalReasoningResult schema,
# which generate_structured parses back into the Pydantic model.
CRITIQUE_PROMPT = """
PATIENT PROFILE:
{patient_summary}

CURRENT DIFFERENTIAL (iteration {iteration}):
{current_differential}

INSTRUCTIONS:
1. For each diagnosis, state whether you AGREE, DISAGREE, or want to MODIFY it.
2. Identify any MISSING diagnoses that should be on the list.
3. Propose a REVISED differential diagnosis, ranked by likelihood.
4. Explain your chain-of-thought reasoning for every change.
5. Rate each diagnosis as "low", "moderate", or "high" likelihood.

Return ONLY the revised differential as a JSON object matching
ClinicalReasoningResult schema (differential_diagnosis, risk_assessment,
recommended_workup, reasoning_chain).
"""
| |
|
| |
|
class IterativeRefiner:
    """
    Runs the iterative self-critique loop for a single patient case.

    Usage:
        refiner = IterativeRefiner(config, ledger)
        final, history = await refiner.refine(profile, initial_reasoning)
    """

    def __init__(self, config: IterativeConfig, ledger: CostLedger):
        """
        Args:
            config: Iteration limit, sampling settings, and convergence
                threshold for the refinement loop.
            ledger: Shared cost ledger; one LLMCallRecord is appended per
                critique call so cost/benefit can be charted later.
        """
        self.config = config
        self.ledger = ledger
        self.medgemma = MedGemmaService()

    async def refine(
        self,
        profile: PatientProfile,
        initial_reasoning: ClinicalReasoningResult,
    ) -> Tuple[ClinicalReasoningResult, List[ClinicalReasoningResult]]:
        """
        Iteratively refine the differential diagnosis.

        Args:
            profile: Structured patient data
            initial_reasoning: Track A's first-pass reasoning result

        Returns:
            (final_reasoning, iteration_history) — history[0] is the initial,
            history[-1] is the final.
        """
        history: List[ClinicalReasoningResult] = [initial_reasoning]
        current = initial_reasoning

        # The patient summary is loop-invariant; format it once.
        patient_summary = self._format_profile(profile)

        for iteration in range(1, self.config.max_iterations + 1):
            logger.info(
                "  [Iteration %d/%d] Current top dx: %s",
                iteration,
                self.config.max_iterations,
                self._top_dx(current),
            )

            prompt = CRITIQUE_PROMPT.format(
                patient_summary=patient_summary,
                iteration=iteration,
                current_differential=self._format_differential(current),
            )

            # Token estimate covers both the system and user prompt.
            input_tokens = estimate_tokens(CRITIQUE_SYSTEM + prompt)
            t0 = time.monotonic()

            try:
                revised = await self.medgemma.generate_structured(
                    prompt=prompt,
                    response_model=ClinicalReasoningResult,
                    system_prompt=CRITIQUE_SYSTEM,
                    temperature=self.config.critique_temperature,
                    max_tokens=self.config.max_tokens_critique,
                )
            except Exception as e:
                # Best-effort: any failure (parse error, transport error)
                # ends refinement with the last good result rather than
                # crashing the whole case. (The old `(ValueError, Exception)`
                # tuple was redundant — Exception already subsumes ValueError.)
                logger.warning(
                    "  [Iteration %d] Structured parse failed, "
                    "stopping refinement early: %s",
                    iteration,
                    e,
                )
                return current, history

            elapsed_ms = int((time.monotonic() - t0) * 1000)
            # model_dump_json() already returns str; no str() wrapper needed.
            output_tokens = estimate_tokens(revised.model_dump_json())

            # Record this call so per-iteration cost/benefit can be charted.
            self.ledger.calls.append(LLMCallRecord(
                call_id=str(uuid.uuid4())[:8],
                track_id=self.ledger.track_id,
                step_name="iterative_critique",
                iteration=iteration,
                input_tokens=input_tokens,
                output_tokens=output_tokens,
                total_tokens=input_tokens + output_tokens,
                latency_ms=elapsed_ms,
                temperature=self.config.critique_temperature,
                max_tokens_requested=self.config.max_tokens_critique,
                estimated_cost_usd=estimate_cost(input_tokens, output_tokens),
                timestamp=time.time(),
            ))

            history.append(revised)

            if self._has_converged(current, revised):
                logger.info(
                    "  Converged at iteration %d (threshold %s)",
                    iteration,
                    self.config.convergence_threshold,
                )
                return revised, history

            current = revised

        logger.info("  Reached max iterations (%d)", self.config.max_iterations)
        return current, history

    def _has_converged(
        self,
        prev: ClinicalReasoningResult,
        curr: ClinicalReasoningResult,
    ) -> bool:
        """
        Check if the differential is stable between two iterations.

        Uses Jaccard similarity over the (case-normalized) top-5 diagnosis
        names as a proxy for convergence: if the top diagnoses haven't
        changed, the model is repeating itself.
        """
        prev_names = {dx.diagnosis.lower().strip() for dx in prev.differential_diagnosis[:5]}
        curr_names = {dx.diagnosis.lower().strip() for dx in curr.differential_diagnosis[:5]}

        # An empty differential on either side gives no evidence of stability.
        if not prev_names or not curr_names:
            return False

        overlap = len(prev_names & curr_names)
        union = len(prev_names | curr_names)  # non-zero: both sets non-empty
        jaccard = overlap / union

        # FIX: convergence_threshold is a *similarity* threshold (module
        # docstring: "converged (similarity > threshold)"). The previous
        # comparison `jaccard >= 1 - threshold` was inverted — with a typical
        # threshold of ~0.8 it declared convergence at only 20% overlap.
        return jaccard >= self.config.convergence_threshold

    @staticmethod
    def _top_dx(reasoning: ClinicalReasoningResult) -> str:
        """Return the name of the top-ranked diagnosis (for log lines)."""
        if reasoning.differential_diagnosis:
            return reasoning.differential_diagnosis[0].diagnosis
        return "(empty)"

    @staticmethod
    def _format_profile(profile: PatientProfile) -> str:
        """Render the patient profile as a compact plain-text summary."""
        # Compare against None (not truthiness) so age 0 — a newborn — is
        # not reported as "Unknown".
        age = profile.age if profile.age is not None else "Unknown"
        parts = [
            f"Age: {age}, Gender: {profile.gender.value}",
            f"Chief Complaint: {profile.chief_complaint}",
            f"HPI: {profile.history_of_present_illness or 'N/A'}",
        ]
        if profile.past_medical_history:
            parts.append(f"PMH: {', '.join(profile.past_medical_history)}")
        if profile.current_medications:
            # Omit the dose entirely when absent (avoids a dangling space).
            meds = "; ".join(
                f"{m.name} {m.dose}" if m.dose else m.name
                for m in profile.current_medications
            )
            parts.append(f"Medications: {meds}")
        return "\n".join(parts)

    @staticmethod
    def _format_differential(reasoning: ClinicalReasoningResult) -> str:
        """Render the differential as a numbered list for the critique prompt."""
        lines = [
            f"{i}. {dx.diagnosis} (likelihood: {dx.likelihood.value}) — {dx.reasoning}"
            for i, dx in enumerate(reasoning.differential_diagnosis, 1)
        ]
        if reasoning.risk_assessment:
            lines.append(f"\nRisk Assessment: {reasoning.risk_assessment}")
        return "\n".join(lines) if lines else "(no differential generated)"
| |
|