# Commit 28f1212 (bshepp): Implement validation pipeline fixes (P1-P7)
# and experimental track system
# [Track C: Iterative Refinement]
"""
Track C — Core refinement loop.
On each iteration:
1. Feed the current differential to a self-critique prompt
2. Ask the model to identify weaknesses, missing diagnoses, and re-rank
3. Compare the new differential to the previous one
4. If converged (similarity > threshold) or max iters reached, stop
5. Otherwise loop
Each LLM call is recorded in the CostLedger for cost/benefit charting.
"""
from __future__ import annotations
import logging
import time
import uuid
from typing import List, Optional, Tuple
from app.models.schemas import (
ClinicalReasoningResult,
DiagnosisCandidate,
PatientProfile,
)
from app.services.medgemma import MedGemmaService
from tracks.iterative.config import IterativeConfig
from tracks.shared.cost_tracker import (
CostLedger,
estimate_cost,
estimate_tokens,
LLMCallRecord,
)
logger = logging.getLogger(__name__)
# ──────────────────────────────────────────────
# Self-critique prompt
# ──────────────────────────────────────────────
# System prompt: frames the model as a peer reviewer rather than the original
# author, priming it to challenge the existing differential.
CRITIQUE_SYSTEM = """You are a senior physician reviewing a colleague's differential diagnosis.
Your goal is to find weaknesses, missing diagnoses, and ranking errors.
Be constructive but rigorous. You may add, remove, reorder, or refine diagnoses.
Always justify changes with clinical evidence from the patient data."""

# Per-iteration user prompt. Placeholders filled by IterativeRefiner.refine():
#   {patient_summary}      — output of IterativeRefiner._format_profile()
#   {iteration}            — 1-based iteration counter
#   {current_differential} — output of IterativeRefiner._format_differential()
CRITIQUE_PROMPT = """
PATIENT PROFILE:
{patient_summary}
CURRENT DIFFERENTIAL (iteration {iteration}):
{current_differential}
INSTRUCTIONS:
1. For each diagnosis, state whether you AGREE, DISAGREE, or want to MODIFY it.
2. Identify any MISSING diagnoses that should be on the list.
3. Propose a REVISED differential diagnosis, ranked by likelihood.
4. Explain your chain-of-thought reasoning for every change.
5. Rate each diagnosis as "low", "moderate", or "high" likelihood.
Return ONLY the revised differential as a JSON object matching
ClinicalReasoningResult schema (differential_diagnosis, risk_assessment,
recommended_workup, reasoning_chain).
"""
class IterativeRefiner:
    """
    Runs the iterative self-critique loop for a single patient case.

    Each call to :meth:`refine` repeatedly asks MedGemma to critique the
    current differential, records every LLM call in the shared CostLedger,
    and stops on convergence, the iteration cap, or a generation failure.

    Usage:
        refiner = IterativeRefiner(config, ledger)
        final, history = await refiner.refine(profile, initial_reasoning)
    """

    def __init__(self, config: IterativeConfig, ledger: CostLedger):
        """
        Args:
            config: Iteration cap, critique temperature/token budget, and
                convergence threshold.
            ledger: Shared cost ledger; every critique call is appended to
                ``ledger.calls`` for cost/benefit charting.
        """
        self.config = config
        self.ledger = ledger
        self.medgemma = MedGemmaService()

    async def refine(
        self,
        profile: PatientProfile,
        initial_reasoning: ClinicalReasoningResult,
    ) -> Tuple[ClinicalReasoningResult, List[ClinicalReasoningResult]]:
        """
        Iteratively refine the differential diagnosis.

        Args:
            profile: Structured patient data
            initial_reasoning: Track A's first-pass reasoning result

        Returns:
            (final_reasoning, iteration_history) — history[0] is the initial,
            history[-1] is the final.
        """
        history: List[ClinicalReasoningResult] = [initial_reasoning]
        current = initial_reasoning
        # Loop-invariant: format the patient summary once, not per iteration.
        patient_summary = self._format_profile(profile)

        for iteration in range(1, self.config.max_iterations + 1):
            logger.info(
                f" [Iteration {iteration}/{self.config.max_iterations}] "
                f"Current top dx: {self._top_dx(current)}"
            )

            # Build the critique prompt from the latest differential.
            diff_text = self._format_differential(current)
            prompt = CRITIQUE_PROMPT.format(
                patient_summary=patient_summary,
                iteration=iteration,
                current_differential=diff_text,
            )

            # Call MedGemma for the self-critique, timing it for the ledger.
            input_tokens = estimate_tokens(CRITIQUE_SYSTEM + prompt)
            t0 = time.monotonic()
            try:
                revised = await self.medgemma.generate_structured(
                    prompt=prompt,
                    response_model=ClinicalReasoningResult,
                    system_prompt=CRITIQUE_SYSTEM,
                    temperature=self.config.critique_temperature,
                    max_tokens=self.config.max_tokens_critique,
                )
            except Exception as e:
                # FIX: was `except (ValueError, Exception)` — ValueError is
                # already a subclass of Exception, so the tuple was redundant.
                # The broad catch is intentional: any generation/parse failure
                # ends refinement gracefully with the best result so far
                # rather than crashing the whole pipeline.
                logger.warning(
                    f" [Iteration {iteration}] Structured parse failed, "
                    f"stopping refinement early: {e}"
                )
                # Return best result so far instead of crashing
                return current, history
            elapsed_ms = int((time.monotonic() - t0) * 1000)
            # FIX: model_dump_json() already returns a str; the previous
            # str(...) wrapper was a no-op.
            output_tokens = estimate_tokens(revised.model_dump_json())

            # Record the call's cost and latency for cost/benefit charting.
            self.ledger.calls.append(LLMCallRecord(
                call_id=str(uuid.uuid4())[:8],  # short id is enough for charts
                track_id=self.ledger.track_id,
                step_name="iterative_critique",
                iteration=iteration,
                input_tokens=input_tokens,
                output_tokens=output_tokens,
                total_tokens=input_tokens + output_tokens,
                latency_ms=elapsed_ms,
                temperature=self.config.critique_temperature,
                max_tokens_requested=self.config.max_tokens_critique,
                estimated_cost_usd=estimate_cost(input_tokens, output_tokens),
                timestamp=time.time(),
            ))
            history.append(revised)

            # Stop as soon as the differential stabilizes.
            if self._has_converged(current, revised):
                logger.info(
                    f" Converged at iteration {iteration} "
                    f"(threshold {self.config.convergence_threshold})"
                )
                return revised, history
            current = revised

        logger.info(f" Reached max iterations ({self.config.max_iterations})")
        return current, history

    def _has_converged(
        self,
        prev: ClinicalReasoningResult,
        curr: ClinicalReasoningResult,
    ) -> bool:
        """
        Check if the differential is stable between two iterations.

        Uses Jaccard overlap of the top-5 diagnosis names (case- and
        whitespace-insensitive) as a proxy for convergence: if the top
        diagnoses haven't changed, the model is repeating itself.
        """
        prev_names = {dx.diagnosis.lower().strip() for dx in prev.differential_diagnosis[:5]}
        curr_names = {dx.diagnosis.lower().strip() for dx in curr.differential_diagnosis[:5]}
        # An empty differential on either side carries no convergence signal.
        if not prev_names or not curr_names:
            return False
        overlap = len(prev_names & curr_names)
        # Both sets are non-empty here, so the union is never zero (the old
        # `if union > 0 else 0.0` guard was unreachable and has been removed).
        union = len(prev_names | curr_names)
        jaccard = overlap / union
        # NOTE(review): convergence fires when similarity >= 1 - threshold,
        # i.e. config.convergence_threshold acts as a *dissimilarity
        # tolerance*, whereas the module docstring says "similarity >
        # threshold". Confirm the intended semantics in IterativeConfig.
        return jaccard >= (1 - self.config.convergence_threshold)

    @staticmethod
    def _top_dx(reasoning: ClinicalReasoningResult) -> str:
        """Return the highest-ranked diagnosis name, or "(empty)" if none."""
        if reasoning.differential_diagnosis:
            return reasoning.differential_diagnosis[0].diagnosis
        return "(empty)"

    @staticmethod
    def _format_profile(profile: PatientProfile) -> str:
        """Render the patient profile as the plain-text prompt summary."""
        # NOTE(review): assumes profile.gender is always a set enum value
        # (age/HPI may be None, gender may not) — confirm against the schema.
        parts = [
            f"Age: {profile.age or 'Unknown'}, Gender: {profile.gender.value}",
            f"Chief Complaint: {profile.chief_complaint}",
            f"HPI: {profile.history_of_present_illness or 'N/A'}",
        ]
        # Optional sections are emitted only when present.
        if profile.past_medical_history:
            parts.append(f"PMH: {', '.join(profile.past_medical_history)}")
        if profile.current_medications:
            meds = "; ".join(f"{m.name} {m.dose or ''}" for m in profile.current_medications)
            parts.append(f"Medications: {meds}")
        return "\n".join(parts)

    @staticmethod
    def _format_differential(reasoning: ClinicalReasoningResult) -> str:
        """Render a numbered differential (plus risk assessment) for the prompt."""
        lines = []
        for i, dx in enumerate(reasoning.differential_diagnosis, 1):
            lines.append(
                f"{i}. {dx.diagnosis} (likelihood: {dx.likelihood.value}) — {dx.reasoning}"
            )
        if reasoning.risk_assessment:
            lines.append(f"\nRisk Assessment: {reasoning.risk_assessment}")
        return "\n".join(lines) if lines else "(no differential generated)"