cds-agent / VALIDATION_PIPELINE_PLAN.md (commit 28f1212 by bshepp: "Implement validation pipeline fixes (P1-P7) and experimental track system")

VALIDATION_PIPELINE_PLAN.md: Validation Pipeline Fix Plan

Purpose: Step-by-step implementation plan for fixing the validation/scoring pipeline so accuracy metrics actually reflect the system's capabilities.

Root cause: The pipeline forces every MedQA question through differential diagnosis generation, but only 7/50 sampled questions are diagnostic. The other 43 are treatment, mechanism, lab-finding, ethics and other question types, which produces near-zero accuracy on questions the pipeline was never designed to answer.

Expected impact: Fixes P5+P3+P6 alone should raise measured MedQA accuracy from ~36% to 60-70%+. Full implementation of all 7 fixes gives honest, stratified metrics and unlocks multi-mode pipeline expansion.

Implementation order: Bottom-up through the data flow. Each step locks down its interface before the next layer builds on it. No rewrites needed.


Step 1 (P5): Fix fuzzy_match() for Short Answers

File: src/backend/validation/base.py
Functions: fuzzy_match(), normalize_text()
Depends on: Nothing
Depended on by: P4 (type-aware scoring), P6 (MCQ selection comparison)

Problem

fuzzy_match() uses min(len(c_tokens), len(t_tokens)) as the denominator for token overlap. For a 1-word target like "Clopidogrel", min(1, 200) = 1, so a single token match gives 100% overlap. Multi-word targets are the fragile case: for a target like "Cross-linking of DNA", normalization and stop-word removal reshape the token set, and if the candidate doesn't contain those specific tokens, the match fails even when the concept is present in different phrasing.

The substring check (normalize_text(target) in normalize_text(candidate)) already handles containment: "clopidogrel" matches inside "clopidogrel 75mg daily". The real failure case is when the answer uses different phrasing than the pipeline output, e.g.:

  • Target: "Reassurance and continuous monitoring"
  • Pipeline says: "reassure the patient and monitor continuously"
  • Neither substring contains the other, and token overlap may be low

Changes

# In base.py: replace fuzzy_match() entirely (normalize_text() and the
# stopword helpers sit alongside it)
import re

def normalize_text(text: str) -> str:
    """Lowercase, strip punctuation, normalize whitespace."""
    text = text.lower().strip()
    text = re.sub(r'[^\w\s]', ' ', text)
    text = re.sub(r'\s+', ' ', text)
    return text.strip()


# Medical stopwords that don't carry diagnostic meaning
_MEDICAL_STOPWORDS = frozenset({
    "the", "a", "an", "of", "in", "to", "and", "or", "is", "are", "was",
    "were", "be", "been", "with", "for", "on", "at", "by", "from", "this",
    "that", "these", "those", "it", "its", "has", "have", "had", "do",
    "does", "did", "will", "would", "could", "should", "may", "might",
    "most", "likely", "following", "which", "what", "patient", "patients",
})


def _content_tokens(text: str) -> set[str]:
    """Extract meaningful content tokens, removing medical stopwords."""
    tokens = set(normalize_text(text).split())
    return tokens - _MEDICAL_STOPWORDS


def fuzzy_match(candidate: str, target: str, threshold: float = 0.6) -> bool:
    """
    Check if candidate text is a fuzzy match for target.

    Strategy (checked in order, first match wins):
      1. Normalized substring containment (either direction)
      2. All content tokens of target appear in candidate (recall=1.0)
      3. Token overlap ratio >= threshold (using content tokens)

    Args:
        candidate: Text from the pipeline output (may be long)
        target: Ground truth text (usually short)
        threshold: Minimum token overlap ratio (0.0-1.0)
    """
    c_norm = normalize_text(candidate)
    t_norm = normalize_text(target)

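    # Reject empty input up front: "" is a substring of every string, so an
    # empty candidate would otherwise "match" in the containment check below
    if not t_norm or not c_norm:
        return False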

    # 1. Substring containment (either direction)
    if t_norm in c_norm or c_norm in t_norm:
        return True

    # 2. All content tokens of target present in candidate, in any order.
    #    Catches e.g. "embolization, likely cholesterol" for "Cholesterol embolization"
    t_content = _content_tokens(target)
    c_content = _content_tokens(candidate)

    if t_content and t_content.issubset(c_content):
        return True

    # 3. Token overlap ratio
    if not t_content or not c_content:
        return False

    overlap = len(t_content & c_content)
    # Use the target's token count as the denominator: "what fraction of
    # the target's meaning is present in the candidate?"
    recall = overlap / len(t_content)

    return recall >= threshold

Key interface change

  • Signature stays the same: fuzzy_match(candidate, target, threshold) -> bool
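  • Behavior change: more permissive matching for short targets (all-token-subset check) and recall-based threshold semantics instead of the min-denominator ratio. In the usual case (long pipeline text, short answer) the two denominators are identical, so the change only widens matching; when the candidate has fewer tokens than the target, the recall denominator is larger and can reject matches the old ratio accepted. No call sites need to change.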
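The difference between the two denominators is easiest to see on plain token sets. In this sketch, min_denominator_ratio is a reconstruction of the current behavior from the description in the Problem section (not the actual code), stopword handling is left out, and the token sets are made up.

```python
def min_denominator_ratio(c_tokens: set[str], t_tokens: set[str]) -> float:
    """Overlap divided by the smaller set (old semantics, as described)."""
    if not c_tokens or not t_tokens:
        return 0.0
    return len(c_tokens & t_tokens) / min(len(c_tokens), len(t_tokens))


def recall_ratio(c_tokens: set[str], t_tokens: set[str]) -> float:
    """Overlap divided by the target size (new fuzzy_match() semantics)."""
    if not c_tokens or not t_tokens:
        return 0.0
    return len(c_tokens & t_tokens) / len(t_tokens)


target = {"cholesterol", "embolization", "syndrome"}
# A long report: two target tokens plus many unrelated ones.
long_candidate = {"cholesterol", "embolization", "livedo"} | {f"word{i}" for i in range(200)}
# A one-word candidate sharing a single token with the target.
short_candidate = {"cholesterol"}

for name, cand in [("long", long_candidate), ("short", short_candidate)]:
    print(name, round(min_denominator_ratio(cand, target), 2), round(recall_ratio(cand, target), 2))
```

For the long candidate both ratios are 2/3. For the one-word candidate the old ratio is 1.0 while recall is 1/3: a candidate shorter than the target is the case where the new semantics are stricter.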

Tests to write

# test_fuzzy_match.py
from validation.base import fuzzy_match


def test_short_target_substring():
    assert fuzzy_match("Start clopidogrel 75mg daily", "Clopidogrel")

def test_short_target_all_tokens():
    assert fuzzy_match("The diagnosis is cholesterol embolization syndrome", "Cholesterol embolization")

def test_multi_word_phrasing_variation():
    # Target content tokens: {reassurance, continuous, monitoring}. "reassure" !=
    # "reassurance", so only 2 of 3 match; this passes on the overlap ratio
    # (2/3 ≈ 0.67 >= 0.6), not on exact phrasing
    assert fuzzy_match(
        "reassure the patient and provide continuous cardiac monitoring",
        "Reassurance and continuous monitoring"
    )

def test_no_false_positive():
    assert not fuzzy_match("Acute myocardial infarction", "Pulmonary embolism")

def test_empty_target():
    assert not fuzzy_match("some text", "")

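Note: Without stemming, the phrasing from the Problem example ("reassure the patient and monitor continuously") still fails: "reassure"/"reassurance", "monitor"/"monitoring" and "continuously"/"continuous" never match, so no content tokens overlap. Add stemming as a future enhancement (e.g., via nltk.stem.PorterStemmer or a simple suffix-stripping function). For now, the all-token-subset check is the biggest improvement.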
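One possible shape for the simple suffix-stripping function, as a minimal sketch. The suffix list, the 4-character minimum stem length, the trimmed stopword set and the names stem_token / content_stems / stemmed_recall are all illustrative assumptions; this is a crude rule set, not the Porter algorithm that nltk.stem.PorterStemmer implements.

```python
import re

# Illustrative subset of _MEDICAL_STOPWORDS; a real version would reuse that set.
STOPWORDS = frozenset({"the", "a", "an", "of", "and", "or", "to", "in", "with",
                       "patient", "patients"})

# Hypothetical suffix list, tried in order (longer variants before shorter ones).
SUFFIXES = ("ations", "ation", "ance", "ence", "ingly", "ously", "ous",
            "ing", "ly", "ed", "es", "s", "e")


def stem_token(token: str) -> str:
    """Strip the first listed suffix that leaves a stem of 4+ characters."""
    for suffix in SUFFIXES:
        if token.endswith(suffix) and len(token) - len(suffix) >= 4:
            return token[: -len(suffix)]
    return token


def content_stems(text: str) -> set[str]:
    """Lowercase, turn punctuation into spaces, drop stopwords, then stem."""
    normalized = re.sub(r"[^\w\s]", " ", text.lower())
    return {stem_token(t) for t in normalized.split() if t not in STOPWORDS}


def stemmed_recall(candidate: str, target: str) -> float:
    """Fraction of the target's stems that also appear in the candidate."""
    target_stems = content_stems(target)
    if not target_stems:
        return 0.0
    return len(target_stems & content_stems(candidate)) / len(target_stems)
```

With these rules, "reassure the patient and monitor continuously" and "Reassurance and continuous monitoring" both reduce to the stems {reassur, continu, monitor}, so stemmed_recall returns 1.0. Crude suffix rules can also conflate unrelated words, so stemmed matches deserve a spot check on real run output.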

Validation

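Run the existing test suite. Tests that pair a long candidate with a short target should not break, since matching only becomes more permissive there; review any test that pairs a short candidate with a longer target, where the recall denominator is stricter than the old min(). Verify on a few known failure cases from the 50-case run results.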


Step 2 (P3): Preserve the Question Stem

File: src/backend/validation/harness_medqa.py
Functions: _extract_vignette(), fetch_medqa()
Depends on: Nothing (independent of P5, but listed second for logical flow)
Depended on by: P1 (classifier needs the stem), P6 (MCQ step needs the stem + options)

Problem

_extract_vignette() strips the question stem ("Which of the following is the most likely diagnosis?") from the MedQA question. This means:

  1. The pipeline doesn't know what's being asked, so it always defaults to "generate a differential"
  2. The question classifier (P1) can't classify without the stem
  3. The MCQ step (P6) can't present the original question

Changes

2a. Refactor _extract_vignette() → _split_question()

# In harness_medqa.py: replace _extract_vignette()

def _split_question(question: str) -> tuple[str, str]:
    """
    Split a USMLE question into (clinical_vignette, question_stem).

    The vignette is the clinical narrative. The stem is the actual question
    being asked ("Which of the following is the most likely diagnosis?").

    Returns:
        (vignette, stem). The stem may be empty if no recognizable stem is found.
        In that case, vignette contains the full question text.
    """
    stems = [
        r"which of the following",
        r"what is the most likely",
        r"what is the best next step",
        r"what is the most appropriate",
        r"what is the diagnosis",
        r"the most likely diagnosis is",
        r"this patient most likely has",
        r"what would be the next step",
        r"what is the next best step",
        r"what is the underlying",
        r"what is the mechanism",
        r"what is the pathophysiology",
    ]

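    text = question.strip()
    for stem_pattern in stems:
        # The stem is the final sentence containing the stem phrase: a run of
        # non-terminator characters, the phrase, the rest of the sentence,
        # then an optional "?" or "." at the very end of the text. This also
        # matches sentences that start with the phrase ("Which of the following ...")
        pattern = re.compile(
            rf'([^.?!]*\b{stem_pattern}[^.?!]*[?.]?)\s*$',
            re.IGNORECASE,
        )
        match = pattern.search(text)
        if match:
            vignette = text[:match.start()].strip()
            stem_text = match.group(1).strip()
            if len(vignette) > 50:  # Sanity check: a real vignette precedes the stem
                return vignette, stem_text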

    # Fallback: no recognizable stem, so return the full text as the vignette
    return text, ""
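A self-contained sketch of the last-sentence split, trimmed to three stem phrases, for sanity-checking the (vignette, stem) contract on a sample question. split_question and the sample vignette are illustrative only, not part of the harness.

```python
import re

# Hypothetical, trimmed subset of the stem phrases listed above.
STEM_PHRASES = [
    r"which of the following",
    r"what is the most likely",
    r"what is the most appropriate",
]


def split_question(question: str) -> tuple[str, str]:
    """Return (vignette, stem); stem is "" when no stem phrase is found."""
    text = question.strip()
    for phrase in STEM_PHRASES:
        # Last sentence containing the phrase: no sentence terminator before
        # it, an optional final "?" or ".", then end of string.
        match = re.search(rf"([^.?!]*\b{phrase}[^.?!]*[?.]?)\s*$", text, re.IGNORECASE)
        if match:
            vignette = text[: match.start()].strip()
            if len(vignette) > 50:  # a real vignette should precede the stem
                return vignette, match.group(1).strip()
    return text, ""


sample = (
    "A 62-year-old man comes to the emergency department with crushing chest pain "
    "that began 2 hours ago. His blood pressure is 150/90 mm Hg. "
    "Which of the following is the most likely diagnosis?"
)
vignette, stem = split_question(sample)
print(repr(stem))
```

A bare stem with no vignette in front of it fails the length check and falls back to (full text, ""), matching the documented contract.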

2b. Update fetch_medqa() to store stem + vignette separately

# In fetch_medqa(), replace the case-building loop body:

        vignette, question_stem = _split_question(question)

        cases.append(ValidationCase(
            case_id=f"medqa_{i:04d}",
            source_dataset="medqa",
            input_text=vignette,           # Pipeline still gets the vignette
            ground_truth={
                "correct_answer": answer_text,
                "answer_idx": answer_idx,
                "options": options,
                "full_question": question,
            },
            metadata={
                "question_stem": question_stem,      # NEW
                "clinical_vignette": vignette,       # NEW (same as input_text, explicit)
                "full_question_with_stem": question,  # NEW (redundant with ground_truth but cleaner access)
            },
        ))

Key interface change

  • ValidationCase.metadata now has 3 new keys: question_stem, clinical_vignette, full_question_with_stem
  • input_text is still just the vignette (pipeline input unchanged)
  • _extract_vignette() is renamed to _split_question() returning a tuple
  • Old callers of _extract_vignette(): only fetch_medqa(), which is updated in place

Backward compatibility

  • input_text stays the same → pipeline behavior unchanged
  • ground_truth keeps all existing keys → scoring unchanged
  • New data is in metadata only → nothing breaks

Step 3 (P1): Question-Type Classifier

New file: src/backend/validation/question_classifier.py
Depends on: P3 (needs metadata["question_stem"])
Depended on by: P4 (type-aware scoring), P6 (routing), P7 (stratified reporting)

Design

Two-tier classifier:

  1. Heuristic classifier (fast, no LLM call, used by default): regex on the question stem
  2. LLM classifier (optional, for ambiguous cases): ask MedGemma to classify

Start with the heuristic only. It already classified our 50-case sample in line with manual review: 7 diagnostic, 6 treatment, 1 mechanism, 2 lab, 34 other.

Question type enum

# In question_classifier.py

from enum import Enum

class QuestionType(str, Enum):
    DIAGNOSTIC = "diagnostic"           # "most likely diagnosis/cause/explanation"
    TREATMENT = "treatment"             # "most appropriate next step/management/treatment"
    MECHANISM = "mechanism"             # "mechanism of action", "pathophysiology"
    LAB_FINDING = "lab_finding"         # "expected finding", "characteristic on agar"
    PHARMACOLOGY = "pharmacology"       # "drug that targets...", "receptor..."
    EPIDEMIOLOGY = "epidemiology"       # "risk factor", "prevalence", "incidence"
    ETHICS = "ethics"                   # "most appropriate action" (ethical dilemmas)
    ANATOMY = "anatomy"                 # "structure most likely damaged"
    OTHER = "other"                     # Doesn't fit above categories

Heuristic classifier

import re
from typing import Optional
from validation.base import ValidationCase


# Pattern → QuestionType mapping (checked in order, first match wins)
_STEM_PATTERNS: list[tuple[str, QuestionType]] = [
    # Diagnostic
    (r"most likely diagnosis", QuestionType.DIAGNOSTIC),
    (r"most likely cause", QuestionType.DIAGNOSTIC),
    (r"most likely explanation", QuestionType.DIAGNOSTIC),
    (r"what is the diagnosis", QuestionType.DIAGNOSTIC),
    (r"diagnosis is", QuestionType.DIAGNOSTIC),
    (r"most likely condition", QuestionType.DIAGNOSTIC),
    (r"most likely has", QuestionType.DIAGNOSTIC),
    (r"most likely suffer", QuestionType.DIAGNOSTIC),

    # Treatment / Management
    (r"most appropriate (next step|management|treatment|intervention|therapy|pharmacotherapy)", QuestionType.TREATMENT),
    (r"best (next step|initial step|management|treatment)", QuestionType.TREATMENT),
    (r"most appropriate action", QuestionType.TREATMENT),  # Can be ethics; see the note below
    (r"recommended (treatment|management|therapy)", QuestionType.TREATMENT),

    # Mechanism
    (r"mechanism of action", QuestionType.MECHANISM),
    (r"pathophysiology", QuestionType.MECHANISM),
    (r"mediator.*(responsible|involved)", QuestionType.MECHANISM),
    (r"(inhibit|block|activate).*receptor", QuestionType.MECHANISM),
    (r"cross-link", QuestionType.MECHANISM),

    # Lab / Findings
    (r"most likely finding", QuestionType.LAB_FINDING),
    (r"expected (finding|result|value)", QuestionType.LAB_FINDING),
    (r"characteristic (finding|feature|appearance)", QuestionType.LAB_FINDING),
    (r"(agar|culture|stain|gram|biopsy).*show", QuestionType.LAB_FINDING),
    (r"(laboratory|lab).*(result|finding|value)", QuestionType.LAB_FINDING),

    # Pharmacology
    (r"drug.*(target|mechanism|receptor|inhibit)", QuestionType.PHARMACOLOGY),
    (r"(target|act on|bind).*(receptor|enzyme|channel)", QuestionType.PHARMACOLOGY),

    # Epidemiology
    (r"(risk factor|prevalence|incidence|odds ratio|relative risk)", QuestionType.EPIDEMIOLOGY),
    (r"most (common|frequent).*(cause|risk|complication)", QuestionType.EPIDEMIOLOGY),

    # Anatomy
    (r"(structure|nerve|artery|vein|muscle|ligament).*(damaged|injured|affected|involved)", QuestionType.ANATOMY),

    # Ethics (refine: "most appropriate action" in context of disclosure, consent, etc.)
    (r"(tell|inform|disclose|report|consent|refuse|autonomy|confidentiality)", QuestionType.ETHICS),
]


def classify_question(case: ValidationCase) -> QuestionType:
    """
    Classify a MedQA question by type using heuristics on the question stem.

    Looks at metadata["question_stem"] first, falls back to ground_truth["full_question"].

    Returns:
        QuestionType enum value
    """
    stem = case.metadata.get("question_stem", "")
    full_q = case.ground_truth.get("full_question", case.input_text)

    # Classify on stem first (more specific), then full question
    for text in [stem, full_q]:
        text_lower = text.lower()
        for pattern, qtype in _STEM_PATTERNS:
            if re.search(pattern, text_lower):
                return qtype

    return QuestionType.OTHER


def classify_question_from_text(question_text: str) -> QuestionType:
    """
    Classify a raw question string (no ValidationCase needed).
    Useful for ad-hoc classification.
    """
    text_lower = question_text.lower()
    for pattern, qtype in _STEM_PATTERNS:
        if re.search(pattern, text_lower):
            return qtype
    return QuestionType.OTHER


# Convenience: which types are "pipeline-appropriate"?
DIAGNOSTIC_TYPES = {QuestionType.DIAGNOSTIC}
PIPELINE_APPROPRIATE_TYPES = {
    QuestionType.DIAGNOSTIC,
    QuestionType.TREATMENT,
    QuestionType.LAB_FINDING,
}

Integration point

In fetch_medqa(), after building each case, classify it:

from validation.question_classifier import classify_question

# After creating the ValidationCase (bind it to a name first, e.g.
# case = ValidationCase(...); cases.append(case)):
case.metadata["question_type"] = classify_question(case).value

Tests

def test_diagnostic_classification():
    case = make_case(question="...What is the most likely diagnosis?")
    assert classify_question(case) == QuestionType.DIAGNOSTIC

def test_treatment_classification():
    case = make_case(question="...What is the most appropriate next step in management?")
    assert classify_question(case) == QuestionType.TREATMENT

def test_mechanism_classification():
    case = make_case(question="...mechanism of action...")
    assert classify_question(case) == QuestionType.MECHANISM

def test_ethics_override():
    # "most appropriate action" + disclosure keywords → ethics, not treatment
    case = make_case(question="...Tell the attending that he cannot fail to disclose this mistake. What is the most appropriate action?")
    assert classify_question(case) == QuestionType.ETHICS

Note on the ethics override: pattern order matters. "most appropriate action" matches TREATMENT first, so the ethics patterns never get a chance and the test above fails as written. Handling ethics requires checking for disclosure/consent keywords in the answer or full question context. Since patterns are checked in order, there are two options: put the ethics keyword patterns before the generic "most appropriate action" treatment pattern, or do a two-pass check (look for ethics keywords first, then fall through to treatment).

Decision: Use a two-pass approach. If the question contains ethics keywords AND a treatment-like stem, classify as ETHICS. Otherwise classify as TREATMENT. Implement this in classify_question() with a special-case check.
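The two-pass decision can be sketched as follows. This is a trimmed, self-contained illustration, not the final classify_question(): the pattern table is a small subset and classify_two_pass is a hypothetical name. It also adds \b word boundaries around the ethics keywords (an assumption, not in the table above) so that "reports", which appears in many vignettes, does not count as "report"; the trade-off is that inflected forms such as "refuses" or "informed" no longer match either.

```python
import re
from enum import Enum


class QuestionType(str, Enum):
    DIAGNOSTIC = "diagnostic"
    TREATMENT = "treatment"
    ETHICS = "ethics"
    OTHER = "other"


# Trimmed, hypothetical pattern table (first match wins).
STEM_PATTERNS = [
    (r"most likely diagnosis", QuestionType.DIAGNOSTIC),
    (r"most appropriate (next step|management|treatment)", QuestionType.TREATMENT),
    (r"most appropriate action", QuestionType.TREATMENT),
]

# Word boundaries stop "reports", common in vignettes, from matching "report".
ETHICS_KEYWORDS = re.compile(
    r"\b(tell|inform|disclose|report|consent|refuse|autonomy|confidentiality)\b",
    re.IGNORECASE,
)


def classify_two_pass(stem: str, full_question: str) -> QuestionType:
    # Pass 1: ordinary first-match-wins lookup on the stem.
    qtype = QuestionType.OTHER
    for pattern, candidate in STEM_PATTERNS:
        if re.search(pattern, stem, re.IGNORECASE):
            qtype = candidate
            break
    # Pass 2: a treatment-like stem asked in an ethics context becomes ETHICS.
    if qtype is QuestionType.TREATMENT and ETHICS_KEYWORDS.search(full_question):
        return QuestionType.ETHICS
    return qtype
```

Only TREATMENT results are re-checked, so a diagnostic stem stays diagnostic even when the vignette happens to contain an ethics keyword.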


Step 4 (P4): Question-Type-Aware Scoring

File: src/backend/validation/base.py (new function) + src/backend/validation/harness_medqa.py (refactor scoring block)
Depends on: P5 (correct fuzzy_match), P1 (question_type in metadata)
Depended on by: P7 (stratified reporting)

Problem

diagnosis_in_differential() always searches the same fields in the same order regardless of question type. Treatment answers get looked up in the differential (wrong place), and mechanism answers get looked up everywhere (unlikely to match).

Design: score_case() dispatcher

# In base.py: new function alongside diagnosis_in_differential()
from typing import Any, Optional

def score_case(
    target_answer: str,
    report: CDSReport,
    question_type: str = "diagnostic",
    reasoning_result: Optional[ClinicalReasoningResult] = None,
) -> dict[str, Any]:
    """
    Score a case based on its question type.

    Returns a dict of metric_name → value. Accuracy metrics are 0.0 or 1.0.
    Always includes "top1_accuracy", "top3_accuracy", "mentioned_accuracy",
    "match_location" (str) and "match_rank" (int, -1 when unranked).
    Diagnostic questions also include "differential_accuracy".
    """
    qt = question_type.lower()

    if qt == "diagnostic":
        return _score_diagnostic(target_answer, report)
    elif qt == "treatment":
        return _score_treatment(target_answer, report)
    elif qt == "mechanism":
        return _score_mechanism(target_answer, report, reasoning_result)
    elif qt == "lab_finding":
        return _score_lab_finding(target_answer, report, reasoning_result)
    else:
        return _score_generic(target_answer, report, reasoning_result)

Per-type scorers

def _score_diagnostic(target: str, report: CDSReport) -> dict:
    """Score a diagnostic question. Primary field: differential_diagnosis."""
    found_top1, r1, l1 = diagnosis_in_differential(target, report, top_n=1)
    found_top3, r3, l3 = diagnosis_in_differential(target, report, top_n=3)
    found_any, ra, la = diagnosis_in_differential(target, report)

    return {
        "top1_accuracy": 1.0 if found_top1 else 0.0,
        "top3_accuracy": 1.0 if found_top3 else 0.0,
        "mentioned_accuracy": 1.0 if found_any else 0.0,
        "differential_accuracy": 1.0 if (found_any and la == "differential") else 0.0,
        "match_location": la,
        "match_rank": ra,
    }


def _score_treatment(target: str, report: CDSReport) -> dict:
    """Score a treatment question. Primary fields: next_steps + recommendations."""
    # Check suggested_next_steps first (most specific)
    for i, action in enumerate(report.suggested_next_steps):
        if fuzzy_match(action.action, target):
            return {
                "top1_accuracy": 1.0 if i == 0 else 0.0,
                "top3_accuracy": 1.0 if i < 3 else 0.0,
                "mentioned_accuracy": 1.0,
                "match_location": "next_steps",
                "match_rank": i,
            }

    # Check guideline_recommendations
    for i, rec in enumerate(report.guideline_recommendations):
        if fuzzy_match(rec, target):
            return {
                "top1_accuracy": 0.0,  # Not in primary slot
                "top3_accuracy": 0.0,
                "mentioned_accuracy": 1.0,
                "match_location": "recommendations",
                "match_rank": i,
            }

    # Check differential reasoning text (treatment may appear in reasoning)
    for dx in report.differential_diagnosis:
        if fuzzy_match(dx.reasoning, target, threshold=0.3):
            return {
                "top1_accuracy": 0.0,
                "top3_accuracy": 0.0,
                "mentioned_accuracy": 1.0,
                "match_location": "reasoning_text",
                "match_rank": -1,
            }

    # Fulltext fallback
    full_text = _build_fulltext(report)
    if fuzzy_match(full_text, target, threshold=0.3):
        return {
            "top1_accuracy": 0.0,
            "top3_accuracy": 0.0,
            "mentioned_accuracy": 1.0,
            "match_location": "fulltext",
            "match_rank": -1,
        }

    return _not_found()


def _score_mechanism(
    target: str, report: CDSReport,
    reasoning_result: Optional[ClinicalReasoningResult] = None,
) -> dict:
    """Score a mechanism question. Primary field: reasoning_chain."""
    # Check reasoning chain from clinical reasoning step
    if reasoning_result and reasoning_result.reasoning_chain:
        if fuzzy_match(reasoning_result.reasoning_chain, target, threshold=0.3):
            return {
                "top1_accuracy": 0.0,
                "top3_accuracy": 0.0,
                "mentioned_accuracy": 1.0,
                "match_location": "reasoning_chain",
                "match_rank": -1,
            }

    # Check differential reasoning text
    for dx in report.differential_diagnosis:
        if fuzzy_match(dx.reasoning, target, threshold=0.3):
            return {
                "top1_accuracy": 0.0,
                "top3_accuracy": 0.0,
                "mentioned_accuracy": 1.0,
                "match_location": "differential_reasoning",
                "match_rank": -1,
            }

    # Fulltext fallback
    full_text = _build_fulltext(report)
    if fuzzy_match(full_text, target, threshold=0.3):
        return {
            "top1_accuracy": 0.0,
            "top3_accuracy": 0.0,
            "mentioned_accuracy": 1.0,
            "match_location": "fulltext",
            "match_rank": -1,
        }

    return _not_found()


def _score_lab_finding(
    target: str, report: CDSReport,
    reasoning_result: Optional[ClinicalReasoningResult] = None,
) -> dict:
    """Score a lab/finding question. Primary field: recommended_workup."""
    # Check recommended workup
    if reasoning_result:
        for i, action in enumerate(reasoning_result.recommended_workup):
            if fuzzy_match(action.action, target, threshold=0.4):
                return {
                    "top1_accuracy": 1.0 if i == 0 else 0.0,
                    "top3_accuracy": 1.0 if i < 3 else 0.0,
                    "mentioned_accuracy": 1.0,
                    "match_location": "recommended_workup",
                    "match_rank": i,
                }

    # Check next steps in final report
    for i, action in enumerate(report.suggested_next_steps):
        if fuzzy_match(action.action, target, threshold=0.4):
            return {
                "top1_accuracy": 0.0,
                "top3_accuracy": 0.0,
                "mentioned_accuracy": 1.0,
                "match_location": "next_steps",
                "match_rank": i,
            }

    # Fulltext fallback
    full_text = _build_fulltext(report)
    if fuzzy_match(full_text, target, threshold=0.3):
        return {
            "top1_accuracy": 0.0,
            "top3_accuracy": 0.0,
            "mentioned_accuracy": 1.0,
            "match_location": "fulltext",
            "match_rank": -1,
        }

    return _not_found()


def _score_generic(
    target: str, report: CDSReport,
    reasoning_result: Optional[ClinicalReasoningResult] = None,
) -> dict:
    """Score any question type by searching all fields broadly."""
    # Try all specific scorers, return first hit
    for scorer in [_score_diagnostic, _score_treatment]:
        result = scorer(target, report)
        if result.get("mentioned_accuracy", 0.0) > 0.0:
            return result

    if reasoning_result:
        result = _score_mechanism(target, report, reasoning_result)
        if result.get("mentioned_accuracy", 0.0) > 0.0:
            return result

    return _not_found()


def _build_fulltext(report: CDSReport) -> str:
    """Concatenate all report fields into a single searchable string."""
    return " ".join([
        report.patient_summary or "",
        " ".join(report.guideline_recommendations),
        " ".join(a.action for a in report.suggested_next_steps),
        " ".join(dx.diagnosis + " " + dx.reasoning for dx in report.differential_diagnosis),
        " ".join(report.sources_cited),
        " ".join(c.description for c in report.conflicts),
    ])


def _not_found() -> dict:
    return {
        "top1_accuracy": 0.0,
        "top3_accuracy": 0.0,
        "mentioned_accuracy": 0.0,
        "match_location": "not_found",
        "match_rank": -1,
    }
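Because the recall threshold is applied to content tokens, the low thresholds (0.3-0.4) used on long fields above are lenient. One way to see how lenient is to compute, for each target size, the smallest number of matching tokens that passes. min_matching_tokens is an illustrative helper that reuses fuzzy_match()'s overlap / n >= threshold comparison rather than math.ceil, so it agrees with it exactly.

```python
def min_matching_tokens(n_target_tokens: int, threshold: float) -> int:
    """Smallest overlap k with k / n >= threshold (same check as fuzzy_match())."""
    if n_target_tokens < 1:
        raise ValueError("target must have at least one content token")
    return next(
        k for k in range(n_target_tokens + 1) if k / n_target_tokens >= threshold
    )


for threshold in (0.3, 0.4, 0.6):
    needed = [min_matching_tokens(n, threshold) for n in range(1, 6)]
    print(f"threshold={threshold}: tokens needed for 1-5 token targets = {needed}")
```

At threshold=0.3, any single content token of a 1- to 3-token target appearing anywhere in the report counts as mentioned, so fulltext-fallback hits are probably best read as an upper bound on mentioned_accuracy.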

Integration in harness_medqa.py

Replace the scoring block (lines ~242-290) in validate_medqa():

# OLD:
#   found_top1, rank1, loc1 = diagnosis_in_differential(correct_answer, report, top_n=1)
#   ...etc...

# NEW:
question_type = case.metadata.get("question_type", "other")
scores = score_case(
    target_answer=correct_answer,
    report=report,
    question_type=question_type,
    reasoning_result=state.clinical_reasoning if state else None,
)
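scores["parse_success"] = 1.0
# Record the type on the result: the Step 6 aggregation groups on
# r.details["question_type"], not on case.metadata
details["question_type"] = question_type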

Key interface

  • score_case() returns a dict that always includes top1_accuracy, top3_accuracy and mentioned_accuracy (floats), plus match_location (str) and match_rank (int)
  • The harness doesn't need to know about question-type internals; it just passes the string through
  • diagnosis_in_differential() is NOT removed: it's still used internally by _score_diagnostic() and as a utility
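The per-type scorers repeat the same five-key dict literal with small variations. One way to keep them consistent is a small builder. match_scores is a hypothetical helper, not an existing function; it encodes the convention visible in the scorers above that only a match in the scorer's primary field earns top-1/top-3 credit.

```python
def match_scores(location: str, rank: int = -1, primary: bool = False) -> dict:
    """Build the score dict shared by the per-type scorers.

    location: field where the target was found, or "not_found".
    rank: 0-based position within that field, or -1 when unranked.
    primary: True only for the scorer's primary field; other fields
        (recommendations, reasoning text, fulltext) get mention credit only.
    """
    found = location != "not_found"
    return {
        "top1_accuracy": 1.0 if primary and rank == 0 else 0.0,
        "top3_accuracy": 1.0 if primary and 0 <= rank < 3 else 0.0,
        "mentioned_accuracy": 1.0 if found else 0.0,
        "match_location": location,
        "match_rank": rank,
    }
```

For example, the first branch of _score_treatment would become return match_scores("next_steps", i, primary=True), its recommendations branch return match_scores("recommendations", i), and _not_found() would be match_scores("not_found").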

Step 5 (P6): MCQ Answer-Selection Step

File: src/backend/validation/harness_medqa.py (new function + integration)
Depends on: P3 (question stem + options stored in metadata/ground_truth)
Depended on by: P7 (reporting), but can be integrated independently

Design

After the pipeline generates its report, present MedGemma with the original question + answer choices + the pipeline's analysis, and ask it to select the best answer choice.

# In harness_medqa.py: new function

from app.services.medgemma import MedGemmaService


MCQ_SELECTION_PROMPT = """You are a medical expert taking a USMLE-style exam.

You have already performed a thorough clinical analysis of this case.
Now, based on your analysis, select the single best answer from the choices below.

CLINICAL VIGNETTE:
{vignette}

QUESTION:
{question_stem}

YOUR CLINICAL ANALYSIS:
- Top diagnoses: {top_diagnoses}
- Key reasoning: {reasoning_summary}
- Recommended next steps: {next_steps}
- Guideline recommendations: {recommendations}

ANSWER CHOICES:
{formatted_options}

Based on your clinical analysis above, which answer choice (A, B, C, or D)
is BEST supported? Reply with ONLY the letter (A, B, C, or D) and a one-sentence justification.

Format: X) Justification"""


async def select_mcq_answer(
    case: ValidationCase,
    report: CDSReport,
    state: Optional[AgentState] = None,
) -> tuple[str, str]:
    """
    Use MedGemma to select the best MCQ answer given the pipeline's analysis.

    Args:
        case: The validation case (must have options in ground_truth)
        report: The CDS pipeline output
        state: Full agent state (for reasoning_chain access)

    Returns:
        (selected_letter, justification), e.g. ("B", "Consistent with...")
    """
    options = case.ground_truth.get("options", {})
    if not options:
        return "", "No options available"

    # Format options
    if isinstance(options, dict):
        formatted = "\n".join(f"{k}) {v}" for k, v in sorted(options.items()))
    else:
        formatted = "\n".join(
            f"{chr(65+i)}) {v}" for i, v in enumerate(options)
        )

    # Build context from report
    top_dx = [dx.diagnosis for dx in report.differential_diagnosis[:3]]
    reasoning = ""
    if state and state.clinical_reasoning:
        reasoning = state.clinical_reasoning.reasoning_chain[:500]
    next_steps = [a.action for a in report.suggested_next_steps[:3]]
    recommendations = report.guideline_recommendations[:3]

    vignette = case.metadata.get("clinical_vignette", case.input_text)
    stem = case.metadata.get("question_stem", "")

    prompt = MCQ_SELECTION_PROMPT.format(
        vignette=vignette[:1000],
        question_stem=stem or "Based on the clinical presentation, select the best answer.",
        top_diagnoses=", ".join(top_dx) if top_dx else "None generated",
        reasoning_summary=reasoning[:500] if reasoning else "Not available",
        next_steps=", ".join(next_steps) if next_steps else "None",
        recommendations=", ".join(recommendations) if recommendations else "None",
        formatted_options=formatted,
    )

    service = MedGemmaService()
    raw = await service.generate(
        prompt=prompt,
        system_prompt="You are a medical expert. Select the single best answer.",
        max_tokens=100,
        temperature=0.1,
    )

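    # Parse response: accept a leading "B", "B)", "(B)", "B." or "B:", else the
    # first "X)" anywhere, so a word like "Answer" is not misread as choice A
    match = re.match(r"\s*\(?([A-D])(?:[).:]|\s*$)", raw) or re.search(
        r"\b([A-D])\)", raw
    )
    selected = match.group(1) if match else ""
    justification = raw.strip()

    return selected, justification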


def score_mcq_selection(
    selected_letter: str,
    correct_idx: str,
) -> float:
    """Return 1.0 if selected matches correct, else 0.0."""
    return 1.0 if selected_letter.upper() == correct_idx.upper() else 0.0
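The prompt and parser above assume exactly four options (A-D), so if a 5-option MedQA variant were ever loaded, choice E could never be selected. A sketch of deriving the valid letters from the options themselves, assuming options is either a letter-keyed dict or a list; format_options and parse_choice are hypothetical names.

```python
import re
from typing import Union


def format_options(options: Union[dict, list]) -> tuple[str, list[str]]:
    """Return (prompt text, valid letters) for letter-keyed dict or list options."""
    if isinstance(options, dict):
        items = sorted(options.items())
    else:
        items = [(chr(ord("A") + i), v) for i, v in enumerate(options)]
    formatted = "\n".join(f"{k}) {v}" for k, v in items)
    return formatted, [k for k, _ in items]


def parse_choice(raw: str, letters: list[str]) -> str:
    """Return the chosen letter, or "" if none is found.

    Accepts a leading bare letter ("B", "B)", "(B)", "B.", "B:"), otherwise
    the first "X)" anywhere, so "Answer: ..." is not misread as choice A.
    """
    allowed = "".join(re.escape(letter) for letter in letters)
    if not allowed:
        return ""
    match = re.match(rf"\s*\(?([{allowed}])(?:[).:]|\s*$)", raw) or re.search(
        rf"\b([{allowed}])\)", raw
    )
    return match.group(1) if match else ""
```

In select_mcq_answer() the letters from format_options could replace the hardcoded A-D, and the prompt's "(A, B, C, or D)" wording would need to be built from the same list.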

Integration in validate_medqa()

After the existing scoring block, add:

# MCQ selection (optional additional scoring)
if report and case.ground_truth.get("options"):
    try:
        selected, justification = await select_mcq_answer(case, report, state)
        scores["mcq_accuracy"] = score_mcq_selection(
            selected, case.ground_truth["answer_idx"]
        )
        details["mcq_selected"] = selected
        details["mcq_justification"] = justification
        details["mcq_correct"] = case.ground_truth["answer_idx"]
    except Exception as e:
        logger.warning(f"MCQ selection failed: {e}")
        scores["mcq_accuracy"] = 0.0

Cost consideration

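This adds one extra MedGemma call per case: at most 100 output tokens (max_tokens=100) plus a prompt of several hundred tokens (the vignette is capped at 1,000 characters and the reasoning summary at 500). For 50 cases that is at most ~5,000 extra output tokens, a negligible cost (<$0.10).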

Key interface

  • select_mcq_answer() is self-contained: it can be called or skipped
  • Adds mcq_accuracy to the scores dict
  • Does NOT change any existing score calculations

Step 6 (P7): Stratified Reporting

File: src/backend/validation/base.py (modify print_summary, save_results) + src/backend/validation/harness_medqa.py (modify aggregation block)
Depends on: P1 (question types), P4 (per-type scores)
Depended on by: Nothing (terminal node)

Changes to summary aggregation in validate_medqa()

# In validate_medqa(): replace the aggregation block at the end

    # Aggregate β€” overall
    total = len(results)
    successful = sum(1 for r in results if r.success)

    metric_names = [
        "top1_accuracy", "top3_accuracy", "mentioned_accuracy",
        "differential_accuracy", "parse_success", "mcq_accuracy",
    ]
    metrics = {}
    for m in metric_names:
        values = [r.scores.get(m, 0.0) for r in results if m in r.scores]
        metrics[m] = sum(values) / len(values) if values else 0.0

    # Average pipeline time
    times = [r.pipeline_time_ms for r in results if r.success]
    metrics["avg_pipeline_time_ms"] = sum(times) / len(times) if times else 0

    # ── Stratified metrics ──
    from validation.question_classifier import QuestionType, PIPELINE_APPROPRIATE_TYPES

    # Group results by question type
    by_type: dict[str, list[ValidationResult]] = {}
    for r in results:
        qt = r.details.get("question_type", "other")
        by_type.setdefault(qt, []).append(r)

    # Per-type metrics
    for qt, type_results in by_type.items():
        n = len(type_results)
        metrics[f"count_{qt}"] = n
        for m in ["top1_accuracy", "top3_accuracy", "mentioned_accuracy", "mcq_accuracy"]:
            values = [r.scores.get(m, 0.0) for r in type_results if m in r.scores]
            if values:
                metrics[f"{m}_{qt}"] = sum(values) / len(values)

    # Pipeline-appropriate subset
    appropriate_values = {t.value for t in PIPELINE_APPROPRIATE_TYPES}
    appropriate_results = [
        r for r in results
        if r.details.get("question_type", "other") in appropriate_values
    ]
    if appropriate_results:
        for m in ["top1_accuracy", "top3_accuracy", "mentioned_accuracy"]:
            # Same convention as the per-type loop: average only results that report the metric
            values = [r.scores[m] for r in appropriate_results if m in r.scores]
            if values:
                metrics[f"{m}_pipeline_appropriate"] = sum(values) / len(values)
        metrics["count_pipeline_appropriate"] = len(appropriate_results)

Changes to print_summary()

# In base.py β€” enhanced print_summary()

def print_summary(summary: ValidationSummary):
    """Pretty-print validation results to console."""
    print(f"\n{'='*60}")
    print(f"  Validation Results: {summary.dataset.upper()}")
    print(f"{'='*60}")
    print(f"  Total cases:      {summary.total_cases}")
    print(f"  Successful:       {summary.successful_cases}")
    print(f"  Failed:           {summary.failed_cases}")
    print(f"  Duration:         {summary.run_duration_sec:.1f}s")

    # Overall metrics. Skip count_* keys and keys with a stratum suffix
    # (e.g. _diagnostic, _lab_finding, _pipeline_appropriate); those are printed below.
    # Stratum names come from the count_* keys, so no hardcoded type list is needed.
    stratum_suffixes = tuple(
        "_" + k[len("count_"):] for k in summary.metrics if k.startswith("count_")
    )
    print(f"\n  Overall Metrics:")
    for metric, value in sorted(summary.metrics.items()):
        if metric.startswith("count_") or metric.endswith(stratum_suffixes):
            continue  # Printed in the stratified sections
        if "time" in metric and isinstance(value, (int, float)):
            print(f"    {metric:35s} {value:.0f}ms")
        elif isinstance(value, float):
            print(f"    {metric:35s} {value:.1%}")
        else:
            print(f"    {metric:35s} {value}")

    # Stratified metrics
    type_keys = sorted(set(
        k.rsplit("_", 1)[-1] for k in summary.metrics
        if k.startswith("count_") and k != "count_pipeline_appropriate"
    ))
    if type_keys:
        print(f"\n  By Question Type:")
        print(f"    {'Type':15s} {'Count':>6s} {'Top-1':>7s} {'Top-3':>7s} {'Mentioned':>10s} {'MCQ':>7s}")
        print(f"    {'-'*15} {'-'*6} {'-'*7} {'-'*7} {'-'*10} {'-'*7}")
        def _cell(value, width):
            # Right-aligned percentage, or "-" when this type has no value for the metric
            text = f"{value:.0%}" if value is not None else "-"
            return f"{text:>{width}}"

        for qt in type_keys:
            count = summary.metrics.get(f"count_{qt}", 0)
            t1 = summary.metrics.get(f"top1_accuracy_{qt}")
            t3 = summary.metrics.get(f"top3_accuracy_{qt}")
            ma = summary.metrics.get(f"mentioned_accuracy_{qt}")
            mcq = summary.metrics.get(f"mcq_accuracy_{qt}")
            print(f"    {qt:15s} {int(count):6d} "
                  f"{_cell(t1, 7)} {_cell(t3, 7)} {_cell(ma, 10)} {_cell(mcq, 7)}")

    # Pipeline-appropriate subset
    pa_count = summary.metrics.get("count_pipeline_appropriate", 0)
    if pa_count > 0:
        print(f"\n  Pipeline-Appropriate Subset ({int(pa_count)} cases):")
        for m in ["top1_accuracy", "top3_accuracy", "mentioned_accuracy"]:
            v = summary.metrics.get(f"{m}_pipeline_appropriate")
            if v is not None:
                print(f"    {m:35s} {v:.1%}")

    print(f"{'='*60}\n")

Key interface

  • ValidationSummary.metrics gains count_{question_type} and {metric}_{question_type} keys, plus the *_pipeline_appropriate subset keys
  • save_results() doesn't need changes; it serializes metrics as-is
  • Console output is richer; the saved JSON stays backward-compatible (scripts that parse it still see all the original keys)
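Downstream scripts that want the stratified numbers can regroup the flat metrics dict. The sketch below uses a hypothetical split_metrics() helper. It derives stratum names from the count_* keys, so multi-word types like lab_finding come back intact.

```python
from typing import Dict, Tuple


def split_metrics(
    metrics: Dict[str, float],
) -> Tuple[Dict[str, float], Dict[str, Dict[str, float]]]:
    """Separate overall metrics from stratified ones.

    Returns (overall, strata). strata maps a stratum name (a question type or
    "pipeline_appropriate") to {"count": n, metric_name: value, ...}.
    """
    names = [k[len("count_"):] for k in metrics if k.startswith("count_")]
    # Try longer names first so "_finding" could never shadow "_lab_finding"
    names.sort(key=len, reverse=True)
    overall: Dict[str, float] = {}
    strata: Dict[str, Dict[str, float]] = {n: {} for n in names}
    for key, value in metrics.items():
        if key.startswith("count_"):
            strata[key[len("count_"):]]["count"] = value
            continue
        for n in names:
            if key.endswith("_" + n):
                strata[n][key[: -len(n) - 1]] = value  # drop the "_<stratum>" suffix
                break
        else:
            overall[key] = value  # no stratum suffix: an overall metric
    return overall, strata
```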

Step 7: P2 - Multi-Mode Pipeline (Large, Future)

Files: src/backend/app/agent/orchestrator.py, src/backend/app/tools/clinical_reasoning.py, src/backend/app/models/schemas.py
Depends on: P1 (question type routing into the pipeline), P3 (question stem passed to pipeline)
Depended on by: Nothing (this is the final architectural evolution)

Overview

This is the biggest change and should be done LAST. It modifies the production pipeline, not just the validation framework.

7a. Add question_context and question_type to CaseSubmission

# In schemas.py: extend CaseSubmission (imports shown for completeness)
from typing import Optional

from pydantic import BaseModel, Field

class CaseSubmission(BaseModel):
    patient_text: str = Field(..., min_length=10)
    include_drug_check: bool = Field(True)
    include_guidelines: bool = Field(True)
    question_context: Optional[str] = Field(
        None,
        description="The clinical question being asked (e.g., 'What is the most likely diagnosis?'). "
                    "If provided, the pipeline adapts its reasoning mode.",
    )
    question_type: Optional[str] = Field(
        None,
        description="Pre-classified question type: diagnostic, treatment, mechanism, etc.",
    )

7b. Mode-specific system prompts in clinical_reasoning.py

# Replace single SYSTEM_PROMPT with a dict:

SYSTEM_PROMPTS = {
    "diagnostic": """You are an expert clinical reasoning assistant...
    [existing diagnostic prompt, mostly unchanged]""",

    "treatment": """You are an expert clinical management assistant...
    Given a structured patient profile and clinical question, recommend the
    most appropriate treatment or next step in management.
    Focus on: evidence-based treatment guidelines, patient-specific factors,
    contraindications, and prioritized management steps.
    Generate a ranked list of management options (not diagnoses)...""",

    "mechanism": """You are an expert in medical pathophysiology...
    Given a clinical scenario, explain the underlying mechanism,
    pathophysiology, or pharmacological principle being tested.
    Focus on: molecular/cellular mechanism, physiological pathways,
    drug mechanisms of action...""",

    "default": """[existing SYSTEM_PROMPT as fallback]""",
}
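Question types without a dedicated prompt (ethics, lab_finding, anatomy, and so on) should resolve to the "default" entry rather than raise a KeyError. The hypothetical lookup helper below is shown with placeholder prompt strings.

```python
from typing import Dict, Optional

# Placeholder prompt texts; the real prompts would live in clinical_reasoning.py
SYSTEM_PROMPTS: Dict[str, str] = {
    "diagnostic": "diagnostic prompt ...",
    "treatment": "treatment prompt ...",
    "mechanism": "mechanism prompt ...",
    "default": "default prompt ...",
}


def get_system_prompt(mode: Optional[str], prompts: Dict[str, str] = SYSTEM_PROMPTS) -> str:
    """Return the prompt for `mode`, falling back to "default" for None or unmapped modes."""
    key = (mode or "default").strip().lower()
    return prompts.get(key, prompts["default"])
```

Normalizing case and whitespace keeps the lookup tolerant of classifier output like "Mechanism ".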

7c. Extend clinical reasoning output model

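# In schemas.py: new model for non-diagnostic reasoning
# (assumes DiagnosisCandidate and RecommendedAction are already defined in schemas.py)
from typing import List, Optional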

class ClinicalAnalysisResult(BaseModel):
    """Flexible clinical analysis output that adapts to question type."""
    analysis_mode: str = Field("diagnostic", description="What type of analysis was performed")
    differential_diagnosis: List[DiagnosisCandidate] = Field(default_factory=list)
    management_options: List[RecommendedAction] = Field(default_factory=list)
    mechanism_explanation: str = Field("", description="Pathophysiology/mechanism explanation")
    recommended_workup: List[RecommendedAction] = Field(default_factory=list)
    reasoning_chain: str = Field("")
    risk_assessment: Optional[str] = None
    direct_answer: Optional[str] = Field(
        None,
        description="Direct answer to the clinical question (when applicable)",
    )
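ClinicalAnalysisResult populates different fields in each mode, so a type-aware scorer has to decide which field to compare against the ground-truth answer. The sketch below shows one possible precedence using dataclass stand-ins. The `name` field on candidates and the primary_answer() helper are assumptions, not the real schema.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class Candidate:
    """Stand-in for DiagnosisCandidate / RecommendedAction (field name assumed)."""
    name: str


@dataclass
class Analysis:
    """Stand-in for ClinicalAnalysisResult with only the fields used here."""
    analysis_mode: str = "diagnostic"
    differential_diagnosis: List[Candidate] = field(default_factory=list)
    management_options: List[Candidate] = field(default_factory=list)
    mechanism_explanation: str = ""
    direct_answer: Optional[str] = None


def primary_answer(result: Analysis) -> str:
    """Pick the text to fuzzy-match against the ground truth (hypothetical precedence)."""
    if result.direct_answer:
        return result.direct_answer
    if result.analysis_mode == "treatment" and result.management_options:
        return result.management_options[0].name
    if result.analysis_mode == "mechanism" and result.mechanism_explanation:
        return result.mechanism_explanation
    if result.differential_diagnosis:
        return result.differential_diagnosis[0].name
    return ""
```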

7d. Orchestrator routing

# In orchestrator.py: _step_reason() adapts based on question type

async def _step_reason(self):
    question_type = self._case.question_type or "diagnostic"
    result = await self.clinical_reasoning.run(
        self._state.patient_profile,
        mode=question_type,
    )
    ...

Scope warning

This is a multi-file, multi-model refactor. Do it only after Steps 1-6 are working and validated. The validation improvements (Steps 1-6) will already give us honest metrics; Step 7 is about actually improving the pipeline's ability to handle non-diagnostic questions.


Testing Strategy

Unit tests (no LLM calls needed)

| Test file | What it tests |
|---|---|
| test_fuzzy_match.py | P5: fuzzy_match() with short/long targets, edge cases |
| test_question_classifier.py | P1: classification accuracy on known questions |
| test_split_question.py | P3: vignette/stem separation on real MedQA samples |
| test_score_case.py | P4: type-aware scoring with mock CDSReport objects |

Integration tests (need LLM endpoint)

| Test | What it tests | Cost |
|---|---|---|
| 3-case smoke test with MCQ | P6: MCQ selection works | ~$0.50 |
| 10-case run with stratified reporting | P7: reporting output is correct | ~$2.00 |
| 50-case full run with all fixes | All: end-to-end accuracy comparison | ~$5.00 |

Comparison protocol

Run 50-case MedQA (seed=42) twice:

  1. Before: Current code (baseline: 36% top-1, 38% mentioned)
  2. After: All fixes applied

Compare:

  • Overall accuracy (should be similar or slightly higher)
  • Diagnostic-only accuracy (should be similar: same pipeline, better matching)
  • MCQ accuracy (expected 60-70%+; this is the big win)
  • Pipeline-appropriate accuracy (expected higher than overall)
  • Stratified breakdown by question type
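Here is a small sketch for the before/after comparison. It assumes each run's metrics dict has been loaded from the saved results JSON, and the compare_runs() helper is hypothetical. A metric missing from either run, such as mcq_accuracy in the baseline, comes back as None rather than as a misleading delta.

```python
from typing import Dict, Iterable, Optional


def compare_runs(
    before: Dict[str, float],
    after: Dict[str, float],
    keys: Iterable[str],
) -> Dict[str, Optional[float]]:
    """Return after-minus-before for each metric, or None if either run lacks it."""
    deltas: Dict[str, Optional[float]] = {}
    for key in keys:
        if key in before and key in after:
            deltas[key] = after[key] - before[key]
        else:
            deltas[key] = None
    return deltas
```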

File Change Summary

| File | Changes | Step |
|---|---|---|
| validation/base.py | Rewrite fuzzy_match(), add _content_tokens(), _MEDICAL_STOPWORDS. Add score_case() and per-type scorers. Modify print_summary(). | P5, P4, P7 |
| validation/harness_medqa.py | Replace _extract_vignette() with _split_question(). Update fetch_medqa() metadata. Refactor scoring block to use score_case(). Add select_mcq_answer(). Update aggregation. | P3, P4, P6, P7 |
| validation/question_classifier.py | NEW FILE. QuestionType enum, classify_question(), _STEM_PATTERNS. | P1 |
| app/models/schemas.py | Add question_context, question_type to CaseSubmission. Add ClinicalAnalysisResult. | P2 (Step 7 only) |
| app/tools/clinical_reasoning.py | Add mode-specific system prompts. Accept mode param. | P2 (Step 7 only) |
| app/agent/orchestrator.py | Route reasoning step based on question type. | P2 (Step 7 only) |

Steps 1-6 touch only validation code. The production pipeline is unchanged until Step 7.