
TrialPath Backend Service TDD-Ready Implementation Guide

Architecture Decisions:

  • Parlant runs as an independent service (parlant-server --gemini); the frontend communicates with it via REST API
  • MedGemma: HF Inference Endpoint (MedGemmaCloudExtractor); no local GPU required for the PoC
  • Doctor packet export: JSON + Markdown; no PDF generation module needed
  • Cost budget: $0.50/session enforced by GeminiCostTracker
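GeminiCostTracker is named above but not specified elsewhere in this guide. A minimal sketch of a per-session budget guard, assuming the class name from the decision above; the token prices, field names, and `record` method are illustrative, not the actual implementation:

```python
from dataclasses import dataclass, field


class BudgetExceededError(RuntimeError):
    """Raised when a call would push the session past its cost budget."""


@dataclass
class GeminiCostTracker:
    """Tracks estimated Gemini spend per session against a hard budget."""
    budget_usd: float = 0.50
    # Illustrative prices in USD per 1M tokens; check current Gemini pricing.
    input_price: float = 1.25
    output_price: float = 5.00
    spent_usd: float = field(default=0.0, init=False)

    def record(self, input_tokens: int, output_tokens: int) -> float:
        """Add one call's estimated cost; raise if the budget would be exceeded."""
        cost = (input_tokens * self.input_price
                + output_tokens * self.output_price) / 1_000_000
        if self.spent_usd + cost > self.budget_usd:
            raise BudgetExceededError(
                f"Session budget ${self.budget_usd:.2f} exceeded"
            )
        self.spent_usd += cost
        return cost
```

Each Gemini call site would call `record()` with the token counts from the API response, so a runaway session fails fast instead of silently overspending.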

1. Architecture Overview

1.1 Component Relationship Diagram

┌──────────────────────────────────────────────────────────────┐
│            UI & Orchestrator (Streamlit/FastAPI)             │
│            ┌───────────────────────────────┐                 │
│            │ Parlant Engine (Embedded)     │                 │
│            │ Agent: patient_trial_copilot  │                 │
│            │ Journey: 5-state workflow     │                 │
│            └─────┬────┬────┬───────────────┘                 │
│                  │    │    │                                 │
└──────────────────┼────┼────┼─────────────────────────────────┘
                   │    │    │
         ┌─────────┘    │    └──────────────────────┐
         ▼              ▼                           ▼
┌─────────────────┐ ┌──────────────────┐ ┌──────────────────────┐
│  MedGemma 4B    │ │  Gemini 3 Pro    │ │  ClinicalTrials MCP  │
│  (HF Endpoint)  │ │  (Google AI)     │ │  Server (Existing)   │
│                 │ │                  │ │                      │
│ - PDF/Image     │ │ - SearchAnchors  │ │ - search_studies     │
│   extraction    │ │   generation     │ │ - get_study          │
│ - Criterion     │ │ - Reranking      │ │ - find_eligible      │
│   evaluation    │ │ - Gap analysis   │ │ - analyze_trends     │
│                 │ │ - Structured     │ │ - compare_studies    │
│ Output:         │ │   output (JSON)  │ │                      │
│ PatientProfile  │ │                  │ │ Output:              │
│ + evidence spans│ │ Output:          │ │ TrialCandidate[]     │
└─────────────────┘ │ SearchAnchors    │ └──────────────────────┘
                    │ EligibilityLedger│
                    └──────────────────┘

1.2 Module Structure

trialpath/
├── __init__.py
├── config.py                  # Environment & API key configuration
├── models/
│   ├── __init__.py
│   ├── patient_profile.py     # PatientProfile v1 Pydantic model
│   ├── search_anchors.py      # SearchAnchors v1 Pydantic model
│   ├── trial_candidate.py     # TrialCandidate v1 Pydantic model
│   ├── eligibility_ledger.py  # EligibilityLedger v1 Pydantic model
│   └── search_log.py          # SearchLog v1 - iterative query refinement tracking
├── services/
│   ├── __init__.py
│   ├── medgemma_extractor.py  # MedGemma HF endpoint integration
│   ├── gemini_planner.py      # Gemini structured output + function calling
│   └── mcp_client.py          # ClinicalTrials MCP tool wrapper
├── agent/
│   ├── __init__.py
│   ├── setup.py               # Parlant agent + journey definition
│   ├── tools.py               # Parlant tool registrations
│   └── guidelines.py          # Guideline definitions per state
├── orchestrator.py            # Main app entry point
└── tests/
    ├── __init__.py
    ├── test_models.py
    ├── test_medgemma.py
    ├── test_gemini.py
    ├── test_mcp.py
    ├── test_parlant_agent.py
    ├── test_journey.py
    └── test_e2e.py

1.3 Dependency Map

parlant[gemini]          # Parlant framework + Gemini NLP provider
google-genai             # Google GenAI Python SDK (Gemini 3 Pro)
transformers>=4.50.0     # HuggingFace Transformers (MedGemma)
accelerate               # Model loading for MedGemma
torch                    # PyTorch backend
Pillow                   # Image processing
pydantic>=2.0            # Data contracts / validation
httpx                    # HTTP client for MCP server
fastapi                  # API server (optional)
streamlit                # UI layer
pytest                   # Testing
pytest-asyncio           # Async test support
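The module tree above lists config.py but this guide never shows it. A minimal sketch that loads the required keys from the environment; GEMINI_API_KEY matches Section 3.1, while the other variable names and defaults are assumptions for illustration:

```python
import os
from dataclasses import dataclass


@dataclass(frozen=True)
class Settings:
    """Central configuration loaded once from environment variables."""
    gemini_api_key: str
    hf_endpoint_url: str
    mcp_server_url: str
    session_budget_usd: float = 0.50


def load_settings() -> Settings:
    """Read settings from the environment; fail fast if a required key is missing."""
    return Settings(
        gemini_api_key=os.environ["GEMINI_API_KEY"],
        hf_endpoint_url=os.environ["HF_ENDPOINT_URL"],
        mcp_server_url=os.environ.get("MCP_SERVER_URL", "http://localhost:8080"),
        session_budget_usd=float(os.environ.get("SESSION_BUDGET_USD", "0.50")),
    )
```

Using `os.environ[...]` (rather than `.get`) for the API keys makes a misconfigured deployment fail at startup instead of mid-session.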

2. Parlant Workflow Guide

2.1 Agent Definition

Parlant uses a Python SDK to define agents. The core entry point is p.Server, which manages the Parlant engine lifecycle.

Agent Creation Code:

import parlant.sdk as p
from parlant.sdk import NLPServices
import asyncio

async def setup_agent(server: p.Server):
    """Create and configure the patient_trial_copilot agent."""
    return await server.create_agent(
        id="patient-trial-copilot",
        name="Patient Trial Copilot",
        description=(
            "An AI copilot that helps NSCLC patients understand which "
            "clinical trials they may qualify for. Transforms rejection "
            "into actionable next steps via gap analysis."
        ),
        max_engine_iterations=10,
    )

async def main():
    # The server shuts down when the async-with block exits, so agent
    # configuration and session handling must all happen inside it.
    async with p.Server(
        nlp_service=NLPServices.gemini,    # Use Gemini as NLP backend
        session_store="local",             # Persist sessions to disk
    ) as server:
        agent = await setup_agent(server)
        # ... register tools, guidelines, and the journey here ...

if __name__ == "__main__":
    asyncio.run(main())

Key Agent Attributes:

  • id -- custom string identifier (auto-generated if omitted)
  • name -- display name
  • description -- agent purpose (influences LLM behavior)
  • max_engine_iterations -- processing limit per request
  • composition_mode -- response style: FLUID (default, natural), CANNED_FLUID, CANNED_COMPOSITED, CANNED_STRICT

Agent Management API:

# List all agents
agents = await server.list_agents()

# Find a specific agent
agent = await server.find_agent(id="patient-trial-copilot")

# REST API equivalent
# POST /agents  -- create
# GET  /agents  -- list
# PATCH /agents/{id}  -- update
# DELETE /agents/{id} -- delete

2.2 Tool Registration

Parlant supports four types of tool services: sdk (Python plugins), openapi, local, and mcp. For TrialPath, we primarily use the sdk type.

Tool Definition Pattern:

from parlant.sdk import tool, ToolContext, ToolResult
from trialpath.models.patient_profile import PatientProfile
from trialpath.models.search_anchors import SearchAnchors
from trialpath.models.trial_candidate import TrialCandidate
from trialpath.models.eligibility_ledger import EligibilityLedger
import json

@tool
async def extract_patient_profile(
    context: ToolContext,
    document_urls: list[str],
    metadata: str,
) -> ToolResult:
    """
    Calls MedGemma to extract a PatientProfile from uploaded documents.

    Args:
        context: Parlant tool context with session info.
        document_urls: List of URLs/paths to uploaded PDFs/images.
        metadata: JSON string with basic patient metadata (age, sex, location).

    Returns:
        ToolResult containing the extracted PatientProfile as JSON.
    """
    from trialpath.services.medgemma_extractor import MedGemmaExtractor

    extractor = MedGemmaExtractor()
    meta = json.loads(metadata)
    profile = await extractor.extract(document_urls, meta)

    return ToolResult(
        data=profile.model_dump(),
        metadata={"source": "medgemma-4b-it", "doc_count": len(document_urls)},
    )


@tool
async def search_clinical_trials(
    context: ToolContext,
    search_anchors_json: str,
) -> ToolResult:
    """
    Searches ClinicalTrials.gov via MCP server using SearchAnchors.

    Args:
        context: Parlant tool context.
        search_anchors_json: JSON string of SearchAnchors v1.

    Returns:
        ToolResult with list of TrialCandidate objects.
    """
    from trialpath.services.mcp_client import ClinicalTrialsMCPClient

    anchors = SearchAnchors.model_validate_json(search_anchors_json)
    client = ClinicalTrialsMCPClient()
    candidates = await client.search(anchors)

    return ToolResult(
        data=[c.model_dump() for c in candidates],
        metadata={"total_results": len(candidates)},
    )


@tool
async def refine_search_query(
    context: ToolContext,
    search_anchors_json: str,
    result_count: int,
    search_log_json: str,
) -> ToolResult:
    """
    Refines search parameters when too many results are returned (>50).
    Adds more specific filters such as phase restrictions or additional keywords
    to narrow the result set.

    Args:
        context: Parlant tool context.
        search_anchors_json: JSON string of current SearchAnchors v1.
        result_count: Number of results from the previous search.
        search_log_json: JSON string of SearchLog tracking refinement history.

    Returns:
        ToolResult with updated SearchAnchors JSON (more restrictive filters).
    """
    from trialpath.services.gemini_planner import GeminiPlanner
    from trialpath.models.search_log import SearchLog, RefinementAction

    anchors = SearchAnchors.model_validate_json(search_anchors_json)
    search_log = SearchLog.model_validate_json(search_log_json)
    planner = GeminiPlanner()

    refined_anchors = await planner.refine_search(anchors, result_count, search_log)

    search_log.add_step(
        query_params=refined_anchors.model_dump(),
        result_count=result_count,
        action=RefinementAction.REFINE,
        reason=f"Too many results ({result_count} > 50), adding filters",
    )

    return ToolResult(
        data={
            "search_anchors": refined_anchors.model_dump(),
            "search_log": search_log.model_dump(),
        },
        metadata={"refinement_round": search_log.total_refinement_rounds},
    )


@tool
async def relax_search_query(
    context: ToolContext,
    search_anchors_json: str,
    result_count: int,
    search_log_json: str,
) -> ToolResult:
    """
    Relaxes search parameters when zero results are returned.
    Removes restrictive filters such as location or phase constraints
    to broaden the result set.

    Args:
        context: Parlant tool context.
        search_anchors_json: JSON string of current SearchAnchors v1.
        result_count: Number of results from the previous search (expected 0).
        search_log_json: JSON string of SearchLog tracking refinement history.

    Returns:
        ToolResult with updated SearchAnchors JSON (fewer filters applied).
    """
    from trialpath.services.gemini_planner import GeminiPlanner
    from trialpath.models.search_log import SearchLog, RefinementAction

    anchors = SearchAnchors.model_validate_json(search_anchors_json)
    search_log = SearchLog.model_validate_json(search_log_json)
    planner = GeminiPlanner()

    relaxed_anchors = await planner.relax_search(anchors, result_count, search_log)

    search_log.add_step(
        query_params=relaxed_anchors.model_dump(),
        result_count=result_count,
        action=RefinementAction.RELAX,
        reason=f"No results ({result_count} == 0), removing filters",
    )

    return ToolResult(
        data={
            "search_anchors": relaxed_anchors.model_dump(),
            "search_log": search_log.model_dump(),
        },
        metadata={"refinement_round": search_log.total_refinement_rounds},
    )


@tool
async def generate_search_anchors(
    context: ToolContext,
    patient_profile_json: str,
) -> ToolResult:
    """
    Uses Gemini to generate SearchAnchors from a PatientProfile.

    Args:
        context: Parlant tool context.
        patient_profile_json: JSON string of PatientProfile v1.

    Returns:
        ToolResult with SearchAnchors JSON.
    """
    from trialpath.services.gemini_planner import GeminiPlanner

    profile = PatientProfile.model_validate_json(patient_profile_json)
    planner = GeminiPlanner()
    anchors = await planner.generate_search_anchors(profile)

    return ToolResult(data=anchors.model_dump())


@tool
async def evaluate_trial_eligibility(
    context: ToolContext,
    patient_profile_json: str,
    trial_candidate_json: str,
) -> ToolResult:
    """
    Evaluates criterion-by-criterion eligibility for a single trial using
    a dual-model approach:
      - MedGemma evaluates MEDICAL criteria: imaging interpretation, lab value
        extraction with units, biomarker matching, treatment history verification.
      - Gemini handles STRUCTURAL tasks: criterion text slicing, aggregation
        into overall assessment, gap identification, and non-medical criteria.

    Args:
        context: Parlant tool context.
        patient_profile_json: JSON string of PatientProfile v1.
        trial_candidate_json: JSON string of TrialCandidate v1.

    Returns:
        ToolResult with EligibilityLedger JSON.
    """
    from trialpath.services.gemini_planner import GeminiPlanner
    from trialpath.services.medgemma_extractor import MedGemmaExtractor

    MEDICAL_CRITERIA_TYPES = ["lab_value", "imaging", "biomarker", "treatment_history"]

    profile = PatientProfile.model_validate_json(patient_profile_json)
    trial = TrialCandidate.model_validate_json(trial_candidate_json)

    planner = GeminiPlanner()
    extractor = MedGemmaExtractor()

    # Step 1: Gemini slices the eligibility text into individual criteria
    criteria_list = await planner.slice_criteria(trial)

    # Step 2: Route each criterion to the appropriate model
    assessments = []
    for criterion in criteria_list:
        if criterion.get("type") in MEDICAL_CRITERIA_TYPES:
            # MedGemma evaluates medical criteria (labs, imaging, biomarkers)
            assessment = await extractor.evaluate_medical_criterion(
                criterion_text=criterion["text"],
                patient_profile=profile,
                evidence_docs=profile.source_docs,
            )
        else:
            # Gemini evaluates structural/non-medical criteria
            assessment = await planner.evaluate_structural_criterion(
                criterion_text=criterion["text"],
                patient_profile=profile,
            )
        assessments.append(assessment)

    # Step 3: Gemini aggregates into overall assessment and identifies gaps
    ledger = await planner.aggregate_assessments(profile, trial, assessments)

    return ToolResult(data=ledger.model_dump())


@tool
async def analyze_gaps(
    context: ToolContext,
    patient_profile_json: str,
    ledgers_json: str,
) -> ToolResult:
    """
    Synthesizes the minimal actionable set of missing data across all ledgers.

    Args:
        context: Parlant tool context.
        patient_profile_json: PatientProfile JSON.
        ledgers_json: JSON array of EligibilityLedger objects.

    Returns:
        ToolResult with gap analysis summary.
    """
    from trialpath.services.gemini_planner import GeminiPlanner

    profile = PatientProfile.model_validate_json(patient_profile_json)
    ledgers = [
        EligibilityLedger.model_validate(ledger_data)
        for ledger_data in json.loads(ledgers_json)
    ]

    planner = GeminiPlanner()
    gaps = await planner.analyze_gaps(profile, ledgers)

    return ToolResult(data=gaps)
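The refine and relax tools above rely on SearchLog's `add_step` and `total_refinement_rounds`; the model lives in models/search_log.py but is not shown in this guide. A minimal Pydantic sketch consistent with that usage; field names beyond those the tools reference (e.g. `max_rounds`, `rounds_exhausted`) are assumptions:

```python
from enum import Enum
from typing import Any
from pydantic import BaseModel, Field


class RefinementAction(str, Enum):
    REFINE = "refine"  # narrow the query: too many results (>50)
    RELAX = "relax"    # broaden the query: zero results


class SearchStep(BaseModel):
    """One round of the iterative search loop."""
    query_params: dict[str, Any]
    result_count: int
    action: RefinementAction
    reason: str


class SearchLog(BaseModel):
    """SearchLog v1 - tracks each round of iterative query refinement."""
    steps: list[SearchStep] = Field(default_factory=list)
    max_rounds: int = 5  # refinement cap referenced by the journey

    @property
    def total_refinement_rounds(self) -> int:
        return len(self.steps)

    @property
    def rounds_exhausted(self) -> bool:
        return self.total_refinement_rounds >= self.max_rounds

    def add_step(
        self,
        query_params: dict[str, Any],
        result_count: int,
        action: RefinementAction,
        reason: str,
    ) -> None:
        self.steps.append(SearchStep(
            query_params=query_params,
            result_count=result_count,
            action=action,
            reason=reason,
        ))
```

Because the log round-trips through JSON between tool calls, the journey can enforce the 5-round cap without any server-side mutable state.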

Registering Tools with Plugin Server:

from parlant.core.services.tools.plugins import PluginServer

async def register_tools(service_registry):
    """Register all TrialPath tools with Parlant."""
    tools = [
        extract_patient_profile,
        search_clinical_trials,
        refine_search_query,
        relax_search_query,
        generate_search_anchors,
        evaluate_trial_eligibility,
        analyze_gaps,
    ]

    async with PluginServer(tools) as server:
        await service_registry.update_tool_service(
            name="trialpath_tools",
            kind="sdk",
            url=server.url,
        )
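Several tools above exchange EligibilityLedger v1 objects without showing the model itself. A minimal sketch consistent with the traffic-light presentation used later in the journey; everything beyond the EligibilityLedger name (statuses, field names, the aggregation rule) is an assumption for illustration:

```python
from enum import Enum
from pydantic import BaseModel, Field


class EligibilityStatus(str, Enum):
    GREEN = "green"    # likely eligible
    YELLOW = "yellow"  # uncertain / unknown
    RED = "red"        # likely ineligible


class CriterionAssessment(BaseModel):
    """Verdict on a single inclusion/exclusion criterion."""
    criterion_text: str
    status: EligibilityStatus
    rationale: str
    evidence_spans: list[str] = Field(default_factory=list)


class EligibilityLedger(BaseModel):
    """EligibilityLedger v1 - criterion-level eligibility for one trial."""
    nct_id: str
    assessments: list[CriterionAssessment] = Field(default_factory=list)

    @property
    def overall(self) -> EligibilityStatus:
        # Worst-case aggregation: any red criterion makes the trial red,
        # otherwise any yellow makes it yellow, else green.
        statuses = {a.status for a in self.assessments}
        if EligibilityStatus.RED in statuses:
            return EligibilityStatus.RED
        if EligibilityStatus.YELLOW in statuses:
            return EligibilityStatus.YELLOW
        return EligibilityStatus.GREEN
```

The `evidence_spans` list is where MedGemma's document pointers would land, so the clinician packet can cite the exact source text behind each verdict.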

2.3 Guideline Configuration

Guidelines are "if-then" rules controlling agent behavior. They are matched dynamically to conversation context.

State-Specific Guidelines:

async def configure_guidelines(agent):
    """Define behavioral guidelines for each journey state."""

    # --- INGEST State Guidelines ---
    await agent.create_guideline(
        condition="User has uploaded documents or provided patient information",
        action=(
            "Call extract_patient_profile with the uploaded documents and metadata. "
            "Then summarize what was understood and what is missing. "
            "Explicitly list any items in the 'unknowns' field."
        ),
        tools=[extract_patient_profile],
    )

    await agent.create_guideline(
        condition="PatientProfile has been extracted but is missing critical fields (diagnosis, stage, or ECOG)",
        action=(
            "Ask the patient to upload additional documents such as a pathology report, "
            "discharge summary, or clinic letter that might contain the missing information."
        ),
    )

    # --- PRESCREEN State Guidelines ---
    await agent.create_guideline(
        condition="PatientProfile has sufficient data for prescreening (diagnosis + stage + ECOG present)",
        action=(
            "Call generate_search_anchors to create search parameters, then "
            "call search_clinical_trials to find matching trials. "
            "Summarize how many trials were found."
        ),
        tools=[generate_search_anchors, search_clinical_trials],
    )

    await agent.create_guideline(
        condition="Search returned more than 50 results",
        action=(
            "Call refine_search_query to add filters like phase or more specific keywords. "
            "Log each refinement step."
        ),
        tools=[refine_search_query],
    )

    await agent.create_guideline(
        condition="Search returned 0 results",
        action=(
            "Call relax_search_query to broaden criteria such as removing location or "
            "phase filters. Log the relaxation step."
        ),
        tools=[relax_search_query],
    )

    # --- VALIDATE_TRIALS State Guidelines ---
    await agent.create_guideline(
        condition="Clinical trials have been found and need eligibility validation",
        action=(
            "For each trial in the shortlist, call evaluate_trial_eligibility "
            "to generate a criterion-level eligibility ledger. "
            "Present results with traffic-light labels (green/yellow/red)."
        ),
        tools=[evaluate_trial_eligibility],
    )

    # --- GAP_FOLLOWUP State Guidelines ---
    await agent.create_guideline(
        condition="Eligibility ledgers show unknown criteria or gaps that could be resolved with additional data",
        action=(
            "Call analyze_gaps to identify the minimal set of missing data. "
            "Explain to the patient in simple language what additional tests or "
            "documents would improve their trial matching."
        ),
        tools=[analyze_gaps],
    )

    # --- SUMMARY State Guidelines ---
    await agent.create_guideline(
        condition="All validation is complete or user declines to provide additional data",
        action=(
            "Generate a patient-friendly summary with 3-5 bullet points about matches. "
            "Also generate a clinician packet with criterion-level evidence pointers. "
            "Offer to export results as JSON or Markdown."
        ),
    )

    # --- Global Guidelines ---
    await agent.create_guideline(
        condition="User asks about a specific trial by NCT ID",
        action=(
            "Look up the trial details and explain its key eligibility criteria "
            "in patient-friendly language."
        ),
    )

    await agent.create_guideline(
        condition="User seems confused or asks for help",
        action=(
            "Explain the current step in the matching process and what "
            "the patient can do next. Keep language simple and empathetic."
        ),
    )

    # --- Global: Medical Disclaimer (PRD Section 9) ---
    await agent.create_guideline(
        condition="Always, in every response to the patient",
        action=(
            "Include a reminder that this tool provides information only and is not "
            "medical advice. Do not recommend specific treatments. Always suggest "
            "discussing results with their healthcare provider. Use language like: "
            "'This information is for educational purposes only and should not replace "
            "professional medical advice.'"
        ),
    )

Guideline Priority and Coherence:

# If two guidelines conflict, set explicit priority
specific_guideline = await agent.create_guideline(
    condition="User asks about EGFR mutation specifically",
    action="Explain EGFR testing and its importance for NSCLC trial matching",
)
general_guideline = await agent.create_guideline(
    condition="User asks about biomarkers",
    action="Explain biomarkers generally",
)

# Specific overrides general
await specific_guideline.prioritize_over(general_guideline)

2.4 Journey (State Machine) Definition

Parlant Journeys implement multi-state workflows with explicit states, transitions, and conditions.

Complete 5-State Journey for TrialPath:

import parlant.sdk as p
from parlant.sdk import tool, ToolContext, ToolResult

async def create_clinical_trial_journey(agent):
    """
    Create the main clinical trial matching journey with 5 states:
    INGEST -> PRESCREEN -> VALIDATE_TRIALS -> GAP_FOLLOWUP -> SUMMARY
    """

    journey = await agent.create_journey(
        title="Clinical Trial Matching Workflow",
        conditions=[
            "Patient wants to find matching clinical trials",
            "User has uploaded medical documents for trial matching",
            "User wants to check clinical trial eligibility",
        ],
        description=(
            "Guides an NSCLC patient through the clinical trial matching process: "
            "document ingestion, prescreening, detailed validation, gap analysis, "
            "and final summary with actionable next steps."
        ),
    )

    # ── State 1: INGEST ──
    # initial_state is automatically created; configure transition to tool state
    ingest_tool_transition = await journey.initial_state.transition_to(
        tool_state=extract_patient_profile,
        tool_instruction=(
            "Extract patient profile from uploaded documents. "
            "Include all available clinical data and explicitly note unknowns."
        ),
        description="Ingest patient documents and extract structured profile",
    )

    # After tool execution, present results to patient
    ingest_chat_transition = await ingest_tool_transition.target.transition_to(
        chat_state=(
            "Present the extracted PatientProfile to the patient. "
            "Confirm what was understood and list missing information. "
            "Ask if the patient has additional documents to upload."
        ),
        description="Confirm extracted profile with patient",
    )

    # ── Fork: Check if enough data for prescreen ──
    ingest_fork = await ingest_chat_transition.target.fork()

    # Branch A: Sufficient data -> PRESCREEN
    prescreen_anchor_transition = await ingest_fork.transition_to(
        condition="Patient profile has diagnosis, stage, and ECOG performance status",
        tool_state=generate_search_anchors,
        tool_instruction="Generate search parameters from the patient profile for trial search",
        description="Generate search anchors for trial discovery",
    )

    # Branch B: Insufficient data -> Ask for more docs (loop back)
    await ingest_fork.transition_to(
        condition="Patient profile is missing critical fields like diagnosis, stage, or ECOG",
        existing_state=journey.initial_state,
        description="Request additional documents from patient",
    )

    # ── State 2: PRESCREEN (with iterative refinement loop) ──
    # Max 5 refinement rounds to prevent infinite loops
    search_transition = await prescreen_anchor_transition.target.transition_to(
        tool_state=search_clinical_trials,
        tool_instruction="Search ClinicalTrials.gov using the generated search anchors",
        description="Search for matching clinical trials",
    )

    prescreen_chat_transition = await search_transition.target.transition_to(
        chat_state=(
            "Present the search results to the patient. "
            "Summarize how many trials were found and their key characteristics. "
            "Explain that detailed eligibility checking will follow."
        ),
        description="Present trial search results",
    )

    # ── Fork: Check search result count for iterative refinement ──
    prescreen_fork = await prescreen_chat_transition.target.fork()

    # Branch A: result_count > 50 β†’ refine search (add filters), loop back
    refine_transition = await prescreen_fork.transition_to(
        condition="Search returned more than 50 results and refinement rounds have not been exhausted",
        tool_state=refine_search_query,
        tool_instruction=(
            "Too many results. Call refine_search_query to add more specific filters "
            "(phase, additional keywords) and then re-run search_clinical_trials."
        ),
        description="Refine search to narrow results (>50 found)",
    )

    # After refining, loop back to search_clinical_trials
    await refine_transition.target.transition_to(
        tool_state=search_clinical_trials,
        tool_instruction="Re-search ClinicalTrials.gov with refined search anchors",
        description="Re-search with refined parameters",
    )

    # Branch B: result_count == 0 β†’ relax search (remove filters), loop back or go to GAP
    relax_transition = await prescreen_fork.transition_to(
        condition="Search returned 0 results and refinement rounds have not been exhausted",
        tool_state=relax_search_query,
        tool_instruction=(
            "No results found. Call relax_search_query to remove restrictive filters "
            "(location, phase) and then re-run search_clinical_trials."
        ),
        description="Relax search to broaden results (0 found)",
    )

    # After relaxing, loop back to search_clinical_trials
    await relax_transition.target.transition_to(
        tool_state=search_clinical_trials,
        tool_instruction="Re-search ClinicalTrials.gov with relaxed search anchors",
        description="Re-search with relaxed parameters",
    )

    # Branch C: result_count 1-50 (right size) β†’ proceed to VALIDATE_TRIALS
    validate_transition = await prescreen_fork.transition_to(
        condition="Search returned between 1 and 50 results (right-sized for validation)",
        tool_state=evaluate_trial_eligibility,
        tool_instruction=(
            "Evaluate eligibility for each trial in the shortlist. "
            "Generate criterion-level assessment with evidence pointers."
        ),
        description="Begin criterion-level trial validation",
    )

    # Branch D: No trials and max iterations reached β†’ GAP_FOLLOWUP
    no_trials_gap = await prescreen_fork.transition_to(
        condition="No clinical trials were found and refinement rounds are exhausted (max 5 reached)",
        tool_state=analyze_gaps,
        tool_instruction=(
            "Analyze why no trials were found after exhausting refinement attempts. "
            "Suggest criteria relaxation (wider geography, more phases) or additional data."
        ),
        description="Analyze gaps when no trials found after max refinement rounds",
    )

    # ── State 3: VALIDATE_TRIALS ──
    validation_chat = await validate_transition.target.transition_to(
        chat_state=(
            "Present eligibility results with traffic-light indicators. "
            "For each trial show: green (likely eligible), yellow (uncertain), "
            "or red (likely ineligible). Explain key criteria decisions."
        ),
        description="Present eligibility validation results",
    )

    # ── Fork: Post-validation decision ──
    validation_fork = await validation_chat.target.fork()

    # Branch A: Some eligible trials -> SUMMARY
    summary_state_ref = await validation_fork.transition_to(
        condition="At least one trial shows green or yellow overall eligibility",
        chat_state=(
            "Generate a patient-friendly summary with top matching trials. "
            "Also generate a clinician packet with full criterion-level detail. "
            "Offer export options (JSON/Markdown)."
        ),
        description="Generate final summary for eligible trials",
    )

    # Branch B: All ineligible but gaps exist -> GAP_FOLLOWUP
    gap_from_validation = await validation_fork.transition_to(
        condition="All trials show red eligibility but there are unknown criteria that could change results",
        tool_state=analyze_gaps,
        tool_instruction=(
            "Identify the minimal set of additional tests or documents that "
            "could change eligibility outcomes."
        ),
        description="Analyze gaps for ineligible trials",
    )

    # ── State 4: GAP_FOLLOWUP ──
    gap_chat = await gap_from_validation.target.transition_to(
        chat_state=(
            "Explain to the patient what additional data could improve matching. "
            "Use simple language: 'If your doctor orders X test, you might qualify for Y trial.' "
            "Ask if the patient can provide additional documents."
        ),
        description="Present gap analysis and request additional data",
    )

    # Also connect the no-trials gap path
    no_trials_gap_chat = await no_trials_gap.target.transition_to(
        chat_state=(
            "Explain that no matching trials were found with current data. "
            "Suggest what could be done: broader search, additional data, or geographic flexibility."
        ),
        description="Explain no trials found and suggest alternatives",
    )

    # ── Fork: After gap analysis ──
    gap_fork = await gap_chat.target.fork()

    # Branch A: User provides new docs -> loop back to INGEST
    await gap_fork.transition_to(
        condition="User uploads new documents or provides additional information",
        existing_state=journey.initial_state,
        description="Re-ingest with new documents",
    )

    # Branch B: User declines -> SUMMARY
    await gap_fork.transition_to(
        condition="User declines to provide additional data or no further data is available",
        existing_state=summary_state_ref.target,
        description="Proceed to summary without additional data",
    )

    # Same fork for no-trials path
    no_trials_fork = await no_trials_gap_chat.target.fork()

    await no_trials_fork.transition_to(
        condition="User provides new documents or wants to adjust search criteria",
        existing_state=journey.initial_state,
        description="Re-ingest and re-search",
    )

    await no_trials_fork.transition_to(
        condition="User accepts current results or wants to end",
        existing_state=summary_state_ref.target,
        description="Proceed to final summary",
    )

    # ── State 5: SUMMARY (End) ──
    await summary_state_ref.target.transition_to(state=p.END_JOURNEY)

    return journey

2.5 Session Management

async def run_patient_session(server, agent):
    """Create and manage a patient matching session."""

    # Create a session for the patient
    session = await server.create_session(
        agent_id=agent.id,
        title="NSCLC Trial Matching - Patient Anna",
    )

    # Send a message as the patient
    event = await session.create_event(
        kind="message",
        source="customer",
        message="I have stage IV NSCLC and want to find clinical trials. Here are my documents.",
    )

    # Poll for agent response (long-polling with timeout)
    response_events = await session.list_events(
        min_offset=event.offset,
        wait_for_data=60,  # Wait up to 60 seconds
    )

    for evt in response_events:
        if evt.kind == "message" and evt.source == "ai_agent":
            print(f"Agent: {evt.message}")
        elif evt.kind == "tool":
            print(f"Tool called: {evt.data}")
        elif evt.kind == "status":
            print(f"Status: {evt.data}")

2.6 Context Variables

from datetime import timedelta

from parlant.sdk import ToolContext, ToolResult, tool

@tool
async def get_session_state(context: ToolContext) -> ToolResult:
    """Provides current journey state info to the agent."""
    return ToolResult(data={
        "current_state": context.metadata.get("journey_state", "INGEST"),
        "profile_complete": context.metadata.get("profile_complete", False),
        "trials_found": context.metadata.get("trials_found", 0),
    })

# Register as auto-updating context variable
await agent.create_variable(
    name="session-state",
    tool=get_session_state,
    update_interval=timedelta(minutes=0),  # Update on every interaction
)

3. Gemini API Integration Guide

3.1 SDK Setup

from google import genai
from google.genai import types
import os

# Initialize client
client = genai.Client(api_key=os.environ["GEMINI_API_KEY"])

# Model selection
MODEL = "gemini-3-pro"  # Or "gemini-3-flash" for lower cost

3.2 Structured Output with Pydantic

Gemini supports schema-constrained JSON output using Pydantic models.

SearchAnchors Generation:

import json

from pydantic import BaseModel, Field
from typing import Optional
from google import genai

class GeographyFilter(BaseModel):
    country: str = Field(description="ISO country code")
    max_distance_km: Optional[int] = Field(default=None)

class TrialFilters(BaseModel):
    recruitment_status: list[str] = Field(
        description="e.g. ['Recruiting', 'Not yet recruiting']"
    )
    phase: list[str] = Field(description="e.g. ['Phase 2', 'Phase 3']")

class SearchAnchors(BaseModel):
    condition: str
    subtype: Optional[str] = None
    biomarkers: list[str] = Field(default_factory=list)
    stage: Optional[str] = None
    geography: Optional[GeographyFilter] = None
    age: Optional[int] = None
    performance_status_max: Optional[int] = None
    trial_filters: TrialFilters
    relaxation_order: list[str] = Field(
        description="Order in which to relax search criteria if too few results"
    )


async def generate_search_anchors(patient_profile: dict) -> SearchAnchors:
    """Use Gemini structured output to generate SearchAnchors from PatientProfile."""

    client = genai.Client()

    prompt = f"""
    Given the following patient profile, generate search parameters
    for finding relevant NSCLC clinical trials on ClinicalTrials.gov.

    Patient Profile:
    {json.dumps(patient_profile, indent=2)}

    Generate SearchAnchors that:
    1. Focus on the patient's specific cancer type, stage, and biomarkers
    2. Include appropriate geographic filters
    3. Consider the patient's age and performance status
    4. Set a relaxation_order for broadening search if too few results
    """

    response = await client.aio.models.generate_content(  # async client; avoids blocking the event loop
        model="gemini-3-pro",
        contents=prompt,
        config={
            "response_mime_type": "application/json",
            "response_json_schema": SearchAnchors.model_json_schema(),
        },
    )

    return SearchAnchors.model_validate_json(response.text)
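
A concrete SearchAnchors instance (hand-written for illustration, not model-generated) for a stage IV EGFR-mutant NSCLC patient might look like:

```json
{
  "condition": "Non-Small Cell Lung Cancer",
  "subtype": "adenocarcinoma",
  "biomarkers": ["EGFR exon 19 deletion"],
  "stage": "IV",
  "geography": {"country": "US", "max_distance_km": 500},
  "age": 62,
  "performance_status_max": 1,
  "trial_filters": {
    "recruitment_status": ["Recruiting"],
    "phase": ["Phase 2", "Phase 3"]
  },
  "relaxation_order": ["geography", "phase", "performance_status_max"]
}
```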

EligibilityLedger Generation:

from enum import Enum

class CriterionDecision(str, Enum):
    MET = "met"
    NOT_MET = "not_met"
    UNKNOWN = "unknown"

class EvidencePointer(BaseModel):
    doc_id: str
    page: Optional[int] = None
    span_id: Optional[str] = None

class TrialEvidencePointer(BaseModel):
    field: str
    offset_start: int
    offset_end: int

class CriterionAssessment(BaseModel):
    criterion_id: str
    type: str  # "inclusion" or "exclusion"
    text: str
    decision: CriterionDecision
    reasoning: str
    patient_evidence: list[EvidencePointer]
    trial_evidence: list[TrialEvidencePointer]

class GapItem(BaseModel):
    description: str
    recommended_action: str
    clinical_importance: str  # "high", "medium", "low"

class OverallAssessment(str, Enum):
    LIKELY_ELIGIBLE = "likely_eligible"
    LIKELY_INELIGIBLE = "likely_ineligible"
    UNCERTAIN = "uncertain"

class EligibilityLedger(BaseModel):
    patient_id: str
    nct_id: str
    overall_assessment: OverallAssessment
    criteria: list[CriterionAssessment]
    gaps: list[GapItem]


async def evaluate_single_trial(
    patient_profile: dict,
    trial_candidate: dict,
) -> EligibilityLedger:
    """Use Gemini to evaluate eligibility for a single trial."""

    client = genai.Client()

    prompt = f"""
    Evaluate this patient's eligibility for the clinical trial below.

    For each inclusion/exclusion criterion:
    1. Assign a criterion_id (inc_1, inc_2, ... or exc_1, exc_2, ...)
    2. Determine if the criterion is met, not_met, or unknown based on patient data
    3. Provide reasoning and evidence pointers

    Patient Profile:
    {json.dumps(patient_profile, indent=2)}

    Trial:
    NCT ID: {trial_candidate['nct_id']}
    Title: {trial_candidate['title']}

    Inclusion Criteria:
    {trial_candidate['eligibility_text']['inclusion']}

    Exclusion Criteria:
    {trial_candidate['eligibility_text']['exclusion']}

    Also identify gaps: criteria that are 'unknown' where additional patient
    data could change the assessment. For each gap, specify what test or
    document would help.
    """

    response = await client.aio.models.generate_content(  # async client; avoids blocking the event loop
        model="gemini-3-pro",
        contents=prompt,
        config={
            "response_mime_type": "application/json",
            "response_json_schema": EligibilityLedger.model_json_schema(),
        },
    )

    return EligibilityLedger.model_validate_json(response.text)

3.3 Function Calling

Gemini supports declaring functions as tools that the LLM can invoke.

Automatic Function Calling (Python-only):

from google import genai
from google.genai import types


def search_trials_on_clinicaltrials_gov(
    condition: str,
    phase: str = "",
    status: str = "Recruiting",
    country: str = "",
    max_results: int = 20,
) -> dict:
    """Search ClinicalTrials.gov for matching trials.

    Args:
        condition: Medical condition to search for (e.g. 'NSCLC')
        phase: Trial phase filter (e.g. 'Phase 3')
        status: Recruitment status filter
        country: Country filter
        max_results: Maximum number of results

    Returns:
        Dictionary with list of matching trials.
    """
    # Implementation calls MCP server
    from trialpath.services.mcp_client import ClinicalTrialsMCPClient
    client = ClinicalTrialsMCPClient()
    return client.search_sync(condition, phase, status, country, max_results)


# Pass Python function directly -- SDK handles declaration automatically
config = types.GenerateContentConfig(
    tools=[search_trials_on_clinicaltrials_gov]
)

response = client.models.generate_content(
    model="gemini-3-pro",
    contents="Find Phase 3 NSCLC trials recruiting in the United States",
    config=config,
)
# SDK automatically executes the function and returns final response
print(response.text)

Manual Function Declaration:

search_function = {
    "name": "search_clinical_trials",
    "description": "Search ClinicalTrials.gov for clinical trials matching patient criteria",
    "parameters": {
        "type": "object",
        "properties": {
            "condition": {
                "type": "string",
                "description": "Primary medical condition (e.g., 'Non-Small Cell Lung Cancer')",
            },
            "biomarkers": {
                "type": "array",
                "items": {"type": "string"},
                "description": "List of biomarker results (e.g., ['EGFR exon 19 deletion'])",
            },
            "phase": {
                "type": "array",
                "items": {"type": "string"},
                "description": "Trial phases to include (e.g., ['Phase 2', 'Phase 3'])",
            },
            "status": {
                "type": "string",
                "description": "Recruitment status filter",
                "enum": ["Recruiting", "Not yet recruiting", "Active, not recruiting"],
            },
        },
        "required": ["condition"],
    },
}

tools = types.Tool(function_declarations=[search_function])
config = types.GenerateContentConfig(tools=[tools])

response = client.models.generate_content(
    model="gemini-3-pro",
    contents="Find trials for a patient with EGFR+ NSCLC",
    config=config,
)

# Handle the function call response
if response.candidates[0].content.parts[0].function_call:
    fc = response.candidates[0].content.parts[0].function_call
    print(f"Function: {fc.name}, Args: {fc.args}")
    # Execute the function and send result back

3.4 Token Management and Cost Control

async def count_tokens_and_estimate_cost(prompt: str, model: str = "gemini-3-pro"):
    """Estimate tokens and cost before making an API call."""
    client = genai.Client()

    token_count = await client.aio.models.count_tokens(
        model=model,
        contents=prompt,
    )

    # Approximate pricing (check current rates)
    input_cost_per_1m = 1.25   # USD per 1M input tokens (Gemini 3 Pro)
    output_cost_per_1m = 5.00  # USD per 1M output tokens

    estimated_input_cost = (token_count.total_tokens / 1_000_000) * input_cost_per_1m

    return {
        "input_tokens": token_count.total_tokens,
        "estimated_input_cost_usd": estimated_input_cost,
    }


class GeminiCostTracker:
    """Track cumulative cost across a patient session."""

    def __init__(self, budget_usd: float = 0.50):
        self.budget = budget_usd
        self.total_input_tokens = 0
        self.total_output_tokens = 0
        self.total_cost = 0.0

    def record(self, input_tokens: int, output_tokens: int):
        self.total_input_tokens += input_tokens
        self.total_output_tokens += output_tokens
        cost = (input_tokens / 1e6) * 1.25 + (output_tokens / 1e6) * 5.0
        self.total_cost += cost

    @property
    def remaining_budget(self) -> float:
        return self.budget - self.total_cost

    @property
    def over_budget(self) -> bool:
        return self.total_cost > self.budget
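
A minimal sketch of how the tracker can gate each call. `record_usage_and_enforce` and `BudgetExceeded` are illustrative names, not part of any SDK; the `usage_metadata.prompt_token_count` / `candidates_token_count` attribute names follow the google-genai response object:

```python
class BudgetExceeded(RuntimeError):
    """Raised when a Gemini call has pushed the session over its cost budget."""


def record_usage_and_enforce(tracker, response) -> None:
    """Record a response's token usage on the tracker, then enforce the cap.

    `tracker` is a GeminiCostTracker as defined above.
    """
    usage = response.usage_metadata
    tracker.record(usage.prompt_token_count, usage.candidates_token_count)
    if tracker.over_budget:
        raise BudgetExceeded(
            f"Session cost ${tracker.total_cost:.4f} exceeds budget ${tracker.budget:.2f}"
        )
```

Calling this after every generate_content invocation keeps the $0.50/session cap enforceable at the orchestrator level rather than inside individual tools.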

4. MedGemma Integration Guide

4.1 HuggingFace Endpoint Setup

The MedGemma 4B instruction-tuned model (google/medgemma-4b-it) is multimodal: it accepts both text and medical images.

Pipeline Approach (Recommended for PoC):

from transformers import pipeline
from PIL import Image
import torch
import json

class MedGemmaExtractor:
    """Extract patient profiles from medical documents using MedGemma."""

    def __init__(self, model_id: str = "google/medgemma-4b-it"):
        self.pipe = pipeline(
            "image-text-to-text",
            model=model_id,
            torch_dtype=torch.bfloat16,
            device="cuda",
        )

    async def extract(
        self,
        document_urls: list[str],
        metadata: dict,
    ) -> dict:
        """
        Extract PatientProfile from documents.

        Args:
            document_urls: Paths to PDF pages rendered as images.
            metadata: Basic patient metadata (age, sex).

        Returns:
            PatientProfile dictionary.
        """
        # Load images from document pages
        images = [Image.open(url) for url in document_urls]

        # Build multimodal message
        content = [
            {"type": "text", "text": self._build_extraction_prompt(metadata)},
        ]
        for img in images:
            content.append({"type": "image", "image": img})

        messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": self._system_prompt()}],
            },
            {
                "role": "user",
                "content": content,
            },
        ]

        output = self.pipe(text=messages, max_new_tokens=2048)
        raw_text = output[0]["generated_text"][-1]["content"]

        return self._parse_profile(raw_text, metadata)

    def _system_prompt(self) -> str:
        return (
            "You are an expert medical data extractor specializing in oncology. "
            "Extract structured patient information from medical documents. "
            "Always cite the source document and location for each extracted fact. "
            "If information is unclear or missing, explicitly note it as unknown."
        )

    def _build_extraction_prompt(self, metadata: dict) -> str:
        return f"""
        Extract a structured patient profile from the following medical documents.

        Known metadata: age={metadata.get('age', 'unknown')}, sex={metadata.get('sex', 'unknown')}

        Extract the following fields in JSON format:
        - diagnosis (primary_condition, histology, stage, diagnosis_date)
        - performance_status (scale, value, evidence)
        - biomarkers (name, result, date, evidence for each)
        - key_labs (name, value, unit, date, evidence for each)
        - treatments (drug_name, start_date, end_date, line, evidence)
        - comorbidities (name, grade, evidence)
        - imaging_summary (modality, date, finding, interpretation, certainty, evidence)
        - unknowns (field, reason, importance for each missing critical field)

        For each evidence reference, include: doc_id (filename), page number, span_id.

        Return ONLY valid JSON matching the PatientProfile schema.
        """

    def _parse_profile(self, raw_text: str, metadata: dict) -> dict:
        """Parse MedGemma output into PatientProfile structure."""
        try:
            # Try direct JSON parsing
            profile = json.loads(raw_text)
        except json.JSONDecodeError:
            # Extract JSON from markdown code blocks
            import re
            json_match = re.search(r'```(?:json)?\s*(.*?)\s*```', raw_text, re.DOTALL)
            if json_match:
                profile = json.loads(json_match.group(1))
            else:
                raise ValueError(f"Could not parse MedGemma output as JSON: {raw_text[:200]}")

        # Merge with metadata
        if "demographics" not in profile:
            profile["demographics"] = {}
        profile["demographics"].update(metadata)

        return profile

Direct Transformers API (For finer control):

from transformers import AutoProcessor, AutoModelForImageTextToText
from PIL import Image
import torch
import json

class MedGemmaDirectExtractor:
    """Direct transformers API for MedGemma with more control."""

    def __init__(self, model_id: str = "google/medgemma-4b-it"):
        self.model = AutoModelForImageTextToText.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
        self.processor = AutoProcessor.from_pretrained(model_id)

    def evaluate_criterion(
        self,
        patient_evidence_image: Image.Image | None,
        patient_text: str,
        criterion_text: str,
    ) -> dict:
        """
        Evaluate a single trial criterion against patient evidence.
        Used in VALIDATE_TRIALS state for criterion-level assessment.
        """
        content = [
            {
                "type": "text",
                "text": (
                    f"Evaluate whether this patient meets the following trial criterion.\n\n"
                    f"Patient Information:\n{patient_text}\n\n"
                    f"Trial Criterion:\n{criterion_text}\n\n"
                    f"Respond with JSON: "
                    f'{{"decision": "met|not_met|unknown", "reasoning": "...", '
                    f'"confidence": 0.0-1.0}}'
                ),
            },
        ]

        if patient_evidence_image:
            content.append({"type": "image", "image": patient_evidence_image})

        messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": "You are an expert oncologist evaluating clinical trial eligibility criteria."}],
            },
            {"role": "user", "content": content},
        ]

        inputs = self.processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(self.model.device, dtype=torch.bfloat16)

        input_len = inputs["input_ids"].shape[-1]

        with torch.inference_mode():
            generation = self.model.generate(
                **inputs, max_new_tokens=512, do_sample=False,
            )
            generation = generation[0][input_len:]

        decoded = self.processor.decode(generation, skip_special_tokens=True)
        return json.loads(decoded)
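
One caveat: `evaluate_criterion` calls `json.loads` on the raw decode, which fails if the model wraps its JSON in markdown fences (the case `_parse_profile` in 4.1 already anticipates). A small stdlib helper, sketched here as an assumption rather than part of the published class, can be shared by both paths:

```python
import json
import re


def parse_model_json(raw_text: str) -> dict:
    """Parse model output as JSON, tolerating markdown code fences."""
    try:
        return json.loads(raw_text)
    except json.JSONDecodeError:
        match = re.search(r"```(?:json)?\s*(.*?)\s*```", raw_text, re.DOTALL)
        if match:
            return json.loads(match.group(1))
        raise ValueError(f"Could not parse model output as JSON: {raw_text[:200]}")
```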

4.2 HuggingFace Inference Endpoint (Cloud Deployment)

For production, deploy MedGemma as a HuggingFace Inference Endpoint:

from huggingface_hub import InferenceClient
import base64

class MedGemmaCloudExtractor:
    """Use HuggingFace Inference Endpoint for MedGemma in production."""

    def __init__(self, endpoint_url: str, hf_token: str):
        self.client = InferenceClient(
            model=endpoint_url,
            token=hf_token,
        )

    async def extract_from_image(
        self, image_bytes: bytes, prompt: str,
    ) -> str:
        """Send image + prompt to HF Inference Endpoint."""
        image_b64 = base64.b64encode(image_bytes).decode("utf-8")

        response = self.client.chat_completion(
            messages=[
                {
                    "role": "system",
                    "content": "You are an expert medical data extractor.",
                },
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt},
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/png;base64,{image_b64}"
                            },
                        },
                    ],
                },
            ],
            max_tokens=2048,
        )

        return response.choices[0].message.content

4.3 Output to PatientProfile Mapping

MedGemma Output Field         β†’  PatientProfile Field
─────────────────────────────────────────────────────
diagnosis info                β†’  diagnosis.primary_condition, .histology, .stage
ECOG/KPS score                β†’  performance_status.scale, .value
EGFR/ALK/PD-L1 results       β†’  biomarkers[].name, .result, .date
CBC/CMP lab values            β†’  key_labs[].name, .value, .unit
medication/treatment history  β†’  treatments[].drug_name, .start_date, .line
comorbid conditions           β†’  comorbidities[].name, .grade
radiology findings            β†’  imaging_summary[].modality, .finding
missing information           β†’  unknowns[].field, .reason, .importance
source references             β†’  *.evidence[].doc_id, .page, .span_id

4.4 Criterion-Level Medical Evaluation

MedGemma is used not only for initial profile extraction but also for evaluating medical criteria during the VALIDATE_TRIALS state. This ensures that clinical judgments about labs, imaging, biomarkers, and treatment history are made by a medically specialized model rather than a general-purpose LLM.

Criteria Type Classification:

# Criteria types that require MedGemma evaluation
MEDICAL_CRITERIA_TYPES = ["lab_value", "imaging", "biomarker", "treatment_history"]

# All other criteria types are handled by Gemini (structural evaluation)
# Examples: age requirements, geographic restrictions, consent requirements,
# performance status thresholds, administrative criteria

Routing Logic:

def route_criterion_to_model(criterion_type: str) -> str:
    """Determine which model should evaluate a criterion.

    Args:
        criterion_type: The classified type of the criterion.

    Returns:
        "medgemma" for medical criteria, "gemini" for structural criteria.
    """
    if criterion_type in MEDICAL_CRITERIA_TYPES:
        return "medgemma"
    return "gemini"
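
Wiring the router into criterion evaluation can be sketched as a simple dispatcher. The constants and router are restated for self-containment; the evaluator callables are stand-ins for the MedGemma and Gemini evaluation methods:

```python
from typing import Callable

# Restated from above for self-containment
MEDICAL_CRITERIA_TYPES = ["lab_value", "imaging", "biomarker", "treatment_history"]


def route_criterion_to_model(criterion_type: str) -> str:
    return "medgemma" if criterion_type in MEDICAL_CRITERIA_TYPES else "gemini"


def evaluate_criteria(
    criteria: list[dict],
    evaluators: dict[str, Callable[[dict], dict]],
) -> list[dict]:
    """Evaluate each criterion with whichever model its type routes to,
    tagging each result so the ledger records which model decided."""
    results = []
    for criterion in criteria:
        model = route_criterion_to_model(criterion["type"])
        assessment = evaluators[model](criterion)
        assessment["evaluated_by"] = model
        results.append(assessment)
    return results
```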

MedGemma Criterion Evaluation Method:

class MedGemmaExtractor:
    # ... existing __init__, extract, etc. ...

    async def evaluate_medical_criterion(
        self,
        criterion_text: str,
        patient_profile: PatientProfile,
        evidence_docs: list[SourceDocument],
    ) -> dict:
        """
        Evaluate a single medical criterion against patient evidence using MedGemma.

        This method is specifically designed for criterion-level evaluation
        (as opposed to full profile extraction). It uses a different prompt
        template optimized for binary/ternary decision-making on clinical criteria.

        Args:
            criterion_text: The raw eligibility criterion text from the trial
                (e.g., "ANC >= 1.5 x 10^9/L within 14 days of enrollment").
            patient_profile: The full PatientProfile for context.
            evidence_docs: Source documents that may contain relevant evidence.

        Returns:
            Dictionary with keys: decision, reasoning, confidence, evidence_pointers.
        """
        # Build patient context from relevant profile fields
        patient_context = self._build_criterion_context(criterion_text, patient_profile)

        prompt = f"""
        You are evaluating whether a patient meets a specific clinical trial criterion.

        CRITERION:
        {criterion_text}

        PATIENT DATA:
        {patient_context}

        INSTRUCTIONS:
        1. Determine if this criterion is MET, NOT_MET, or UNKNOWN based on the patient data.
        2. For lab values: check the value, unit, and whether the test date is within any
           required time window.
        3. For imaging: interpret the finding in the context of the criterion requirement.
        4. For biomarkers: match the biomarker name and result against the criterion.
        5. For treatment history: verify drug names, lines of therapy, and timing.

        Respond with JSON:
        {{
            "decision": "met" | "not_met" | "unknown",
            "reasoning": "Detailed clinical reasoning for the decision",
            "confidence": 0.0 to 1.0,
            "evidence_pointers": [
                {{"doc_id": "...", "page": ..., "span_id": "..."}}
            ],
            "criterion_type": "lab_value" | "imaging" | "biomarker" | "treatment_history"
        }}
        """

        messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": (
                    "You are an expert oncologist evaluating clinical trial eligibility "
                    "criteria. Focus on precise medical evaluation of labs, imaging, "
                    "biomarkers, and treatment history."
                )}],
            },
            {"role": "user", "content": [{"type": "text", "text": prompt}]},
        ]

        output = self.pipe(text=messages, max_new_tokens=1024)
        raw_text = output[0]["generated_text"][-1]["content"]

        return json.loads(raw_text)

    def _build_criterion_context(
        self, criterion_text: str, profile: PatientProfile,
    ) -> str:
        """Build relevant patient context for a specific criterion."""
        context_parts = []

        # Include labs if criterion mentions lab-related terms
        lab_keywords = ["anc", "creatinine", "hemoglobin", "platelet", "wbc", "alt", "ast"]
        if any(kw in criterion_text.lower() for kw in lab_keywords):
            for lab in profile.key_labs:
                context_parts.append(
                    f"Lab: {lab.name} = {lab.value} {lab.unit} (date: {lab.date})"
                )

        # Include biomarkers if criterion mentions biomarker terms
        biomarker_keywords = ["egfr", "alk", "ros1", "pd-l1", "braf", "kras", "her2", "mutation"]
        if any(kw in criterion_text.lower() for kw in biomarker_keywords):
            for bm in profile.biomarkers:
                context_parts.append(
                    f"Biomarker: {bm.name} = {bm.result} (date: {bm.date})"
                )

        # Include imaging if criterion mentions imaging terms
        imaging_keywords = ["mri", "ct", "pet", "scan", "imaging", "brain metast"]
        if any(kw in criterion_text.lower() for kw in imaging_keywords):
            for img in profile.imaging_summary:
                context_parts.append(
                    f"Imaging: {img.modality} on {img.date} β€” {img.finding}"
                )

        # Include treatments if criterion mentions treatment terms
        treatment_keywords = ["prior", "therapy", "treatment", "line", "chemotherapy", "tki"]
        if any(kw in criterion_text.lower() for kw in treatment_keywords):
            for tx in profile.treatments:
                context_parts.append(
                    f"Treatment: {tx.drug_name}, line {tx.line}, "
                    f"{tx.start_date} to {tx.end_date}"
                )

        # Always include diagnosis context
        if profile.diagnosis:
            context_parts.append(
                f"Diagnosis: {profile.diagnosis.primary_condition}, "
                f"stage {profile.diagnosis.stage}, "
                f"histology {profile.diagnosis.histology}"
            )

        return "\n".join(context_parts) if context_parts else "No relevant patient data found."

5. ClinicalTrials MCP Integration Guide

5.1 Available Tools

The cyanheads/clinicaltrialsgov-mcp-server provides 5 tools:

Tool                                    Description                               Key Use Case
──────────────────────────────────────────────────────────────────────────────────────────────
clinicaltrials_search_studies           Search by query, filters, geography       PRESCREEN state
clinicaltrials_get_study                Fetch full study by NCT ID(s)             VALIDATE_TRIALS state
clinicaltrials_find_eligible_studies    Match patient demographics to trials      PRESCREEN state (alternative)
clinicaltrials_analyze_trends           Statistical analysis of trial data        Optional analytics
clinicaltrials_compare_studies          Side-by-side comparison of 2-5 studies    Patient comparison view

5.2 SearchAnchors to MCP Query Mapping

import httpx

class ClinicalTrialsMCPClient:
    """Client for ClinicalTrials MCP Server."""

    def __init__(self, mcp_url: str = "http://localhost:3000"):
        self.mcp_url = mcp_url

    async def search(self, anchors: SearchAnchors) -> list[dict]:
        """Convert SearchAnchors to MCP search_studies call."""

        # Build query string from anchors
        query_parts = [anchors.condition]
        if anchors.subtype:
            query_parts.append(anchors.subtype)
        if anchors.biomarkers:
            query_parts.extend(anchors.biomarkers)

        query = " ".join(query_parts)

        # Build filter expression for ClinicalTrials.gov API v2
        filters = []
        if anchors.trial_filters.recruitment_status:
            status_filter = " OR ".join(
                f"AREA[OverallStatus]{s}"
                for s in anchors.trial_filters.recruitment_status
            )
            filters.append(f"({status_filter})")

        if anchors.trial_filters.phase:
            phase_filter = " OR ".join(
                f"AREA[Phase]{p}" for p in anchors.trial_filters.phase
            )
            filters.append(f"({phase_filter})")

        if anchors.age is not None:
            filters.append(f"AREA[MinimumAge]RANGE[MIN, {anchors.age}]")
            filters.append(f"AREA[MaximumAge]RANGE[{anchors.age}, MAX]")

        filter_str = " AND ".join(filters) if filters else None

        # Call MCP tool
        params = {
            "query": query,
            "pageSize": 50,
            "sort": "LastUpdateDate:desc",
        }
        if filter_str:
            params["filter"] = filter_str
        if anchors.geography:
            params["country"] = anchors.geography.country

        result = await self._call_tool("clinicaltrials_search_studies", params)
        return result.get("studies", [])

    async def get_study(self, nct_id: str) -> dict:
        """Fetch full study details by NCT ID."""
        result = await self._call_tool("clinicaltrials_get_study", {
            "nctIds": [nct_id],
            "summaryOnly": False,
        })
        studies = result.get("studies", [])
        return studies[0] if studies else {}

    async def find_eligible(
        self,
        age: int,
        sex: str,
        conditions: list[str],
        country: str,
        max_results: int = 20,
    ) -> dict:
        """Use find_eligible_studies for demographic-based matching."""
        return await self._call_tool("clinicaltrials_find_eligible_studies", {
            "age": age,
            "sex": sex,
            "conditions": conditions,
            "location": {"country": country},
            "recruitingOnly": True,
            "maxResults": max_results,
        })

    async def compare_studies(self, nct_ids: list[str]) -> dict:
        """Compare 2-5 studies side by side."""
        return await self._call_tool("clinicaltrials_compare_studies", {
            "nctIds": nct_ids,
            "compareFields": "all",
        })

    async def _call_tool(self, tool_name: str, params: dict) -> dict:
        """Call an MCP tool via JSON-RPC."""
        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.post(
                f"{self.mcp_url}/mcp/v1/tools/call",
                json={
                    "jsonrpc": "2.0",
                    "method": "tools/call",
                    "params": {
                        "name": tool_name,
                        "arguments": params,
                    },
                    "id": 1,
                },
            )
            response.raise_for_status()
            data = response.json()

            if "error" in data:
                raise MCPError(
                    code=data["error"].get("code", -1),
                    message=data["error"].get("message", "Unknown MCP error"),
                )

            return data.get("result", {})


class MCPError(Exception):
    def __init__(self, code: int, message: str):
        self.code = code
        self.message = message
        super().__init__(f"MCP Error {code}: {message}")
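
The `relaxation_order` field of SearchAnchors (section 3.2) is intended to be consumed at this layer when too few trials come back. A hypothetical sketch of that loop, with `MIN_RESULTS`, `drop_anchor`, and the plain-dict params all assumptions for illustration:

```python
MIN_RESULTS = 5  # assumed threshold below which criteria are relaxed


def drop_anchor(params: dict, anchor_name: str) -> dict:
    """Return a copy of the search params with one criterion removed."""
    relaxed = dict(params)
    relaxed.pop(anchor_name, None)
    return relaxed


def search_with_relaxation(search_fn, params: dict, relaxation_order: list[str]) -> list:
    """Re-run search_fn, dropping criteria in relaxation_order until enough
    results come back or there is nothing left to relax."""
    results = search_fn(params)
    for anchor_name in relaxation_order:
        if len(results) >= MIN_RESULTS:
            break
        params = drop_anchor(params, anchor_name)
        results = search_fn(params)
    return results
```

In the real client, `search_fn` would wrap `ClinicalTrialsMCPClient.search()` with anchors rebuilt minus the dropped criterion.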

5.3 MCP Server Search Parameters

clinicaltrials_search_studies Parameters:

Parameter    Type           Description
──────────────────────────────────────────────────────────────────────
query        string         General search (conditions, interventions, sponsors)
filter       string         Advanced filter using ClinicalTrials.gov syntax
pageSize     int (1-200)    Results per page (default: 10)
pageToken    string         Pagination token
sort         string         Sort order (e.g., "LastUpdateDate:desc")
fields       string[]       Specific fields to return
country      string         Geographic filter
state        string         State/province filter
city         string         City filter

clinicaltrials_find_eligible_studies Parameters:

Parameter          Type           Description
─────────────────────────────────────────────────────────
age                int (0-120)    Patient age
sex                enum           'All', 'Female', 'Male'
conditions         string[]       Medical conditions (min 1)
location.country   string         Country
location.state     string         State (optional)
location.city      string         City (optional)
healthyVolunteer   bool           Default false
maxResults         int (1-50)     Max results (default: 10)
recruitingOnly     bool           Default true

5.4 ClinicalTrials.gov API v2 Mapping

The MCP server wraps https://clinicaltrials.gov/api/v2:

SearchAnchors Field            β†’ API Parameter
────────────────────────────────────────────────
condition + subtype + biomarkers β†’ query.term
trial_filters.recruitment_status β†’ filter.advanced (AREA[OverallStatus])
trial_filters.phase             β†’ filter.advanced (AREA[Phase])
age                             β†’ filter.advanced (AREA[MinimumAge/MaximumAge])
geography.country               β†’ filter.advanced (AREA[LocationCountry])

## 6. Data Contracts (Complete JSON Schemas)

### 6.1 PatientProfile v1

from pydantic import BaseModel, Field
from typing import Optional
from datetime import date

class EvidencePointer(BaseModel):
    doc_id: str = Field(description="Source document identifier")
    page: Optional[int] = Field(default=None, description="Page number")
    span_id: Optional[str] = Field(default=None, description="Text span identifier")

class SourceDocument(BaseModel):
    doc_id: str
    type: str = Field(description="clinic_letter|pathology|lab|imaging")
    meta: dict = Field(default_factory=dict)

class Demographics(BaseModel):
    age: Optional[int] = None
    sex: Optional[str] = None

class Diagnosis(BaseModel):
    primary_condition: str = Field(description="e.g. 'Non-Small Cell Lung Cancer'")
    histology: Optional[str] = Field(default=None, description="e.g. 'adenocarcinoma'")
    stage: Optional[str] = Field(default=None, description="e.g. 'IVa'")
    diagnosis_date: Optional[date] = None

class PerformanceStatus(BaseModel):
    scale: str = Field(description="'ECOG' or 'KPS'")
    value: int
    evidence: list[EvidencePointer] = Field(default_factory=list)

class Biomarker(BaseModel):
    name: str = Field(description="e.g. 'EGFR', 'ALK', 'PD-L1'")
    result: str = Field(description="e.g. 'Exon 19 deletion', 'Positive 80%'")
    date: Optional[date] = None
    evidence: list[EvidencePointer] = Field(default_factory=list)

class LabResult(BaseModel):
    name: str = Field(description="e.g. 'ANC', 'Creatinine'")
    value: float
    unit: str
    date: Optional[date] = None
    evidence: list[EvidencePointer] = Field(default_factory=list)

class Treatment(BaseModel):
    drug_name: str
    start_date: Optional[date] = None
    end_date: Optional[date] = None
    line: Optional[int] = Field(default=None, description="Line of therapy (1, 2, 3...)")
    evidence: list[EvidencePointer] = Field(default_factory=list)

class Comorbidity(BaseModel):
    name: str
    grade: Optional[str] = None
    evidence: list[EvidencePointer] = Field(default_factory=list)

class ImagingSummary(BaseModel):
    modality: str = Field(description="e.g. 'MRI brain', 'CT chest'")
    date: Optional[date] = None
    finding: str
    interpretation: Optional[str] = None
    certainty: Optional[str] = Field(default=None, description="low|medium|high")
    evidence: list[EvidencePointer] = Field(default_factory=list)

class UnknownField(BaseModel):
    field: str = Field(description="Name of missing field")
    reason: str = Field(description="Why it is unknown")
    importance: str = Field(description="high|medium|low")

class PatientProfile(BaseModel):
    patient_id: str
    source_docs: list[SourceDocument] = Field(default_factory=list)
    demographics: Demographics = Field(default_factory=Demographics)
    diagnosis: Optional[Diagnosis] = None
    performance_status: Optional[PerformanceStatus] = None
    biomarkers: list[Biomarker] = Field(default_factory=list)
    key_labs: list[LabResult] = Field(default_factory=list)
    treatments: list[Treatment] = Field(default_factory=list)
    comorbidities: list[Comorbidity] = Field(default_factory=list)
    imaging_summary: list[ImagingSummary] = Field(default_factory=list)
    unknowns: list[UnknownField] = Field(default_factory=list)

    def has_minimum_prescreen_data(self) -> bool:
        """Check if profile has enough data for prescreening."""
        return (
            self.diagnosis is not None
            and self.diagnosis.stage is not None
            and self.performance_status is not None
        )
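The prescreen gate is deliberately minimal: primary condition, stage, and performance status are the anchors the search planner cannot work without. A dependency-free restatement of the same predicate over a `model_dump()`-style dict (a hypothetical helper, useful on the FE side of the REST boundary where the Pydantic model may not be available):

```python
def has_minimum_prescreen_data(profile: dict) -> bool:
    # Same gate as PatientProfile.has_minimum_prescreen_data(),
    # applied to a model_dump()-style dict after JSON transport.
    diagnosis = profile.get("diagnosis")
    return (
        diagnosis is not None
        and diagnosis.get("stage") is not None
        and profile.get("performance_status") is not None
    )
```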

### 6.2 SearchAnchors v1

class GeographyFilter(BaseModel):
    country: str = Field(description="ISO country code or full name")
    max_distance_km: Optional[int] = None

class TrialFilters(BaseModel):
    recruitment_status: list[str] = Field(
        default=["Recruiting", "Not yet recruiting"]
    )
    phase: list[str] = Field(default=["Phase 2", "Phase 3"])

class SearchAnchors(BaseModel):
    condition: str = Field(description="Primary condition for search")
    subtype: Optional[str] = Field(default=None, description="Cancer subtype")
    biomarkers: list[str] = Field(default_factory=list)
    stage: Optional[str] = None
    geography: Optional[GeographyFilter] = None
    age: Optional[int] = None
    performance_status_max: Optional[int] = None
    trial_filters: TrialFilters = Field(default_factory=TrialFilters)
    relaxation_order: list[str] = Field(
        default=["phase", "distance", "biomarker_strictness"],
        description="Order in which to relax criteria if too few results",
    )

### 6.3 TrialCandidate v1

class TrialLocation(BaseModel):
    country: str
    city: Optional[str] = None

class AgeRange(BaseModel):
    min: Optional[int] = None
    max: Optional[int] = None

class EligibilityText(BaseModel):
    inclusion: str
    exclusion: str

class TrialCandidate(BaseModel):
    nct_id: str = Field(description="NCT identifier e.g. 'NCT01234567'")
    title: str
    conditions: list[str] = Field(default_factory=list)
    phase: Optional[str] = None
    status: Optional[str] = None
    locations: list[TrialLocation] = Field(default_factory=list)
    age_range: Optional[AgeRange] = None
    fingerprint_text: str = Field(
        description="Short text for Gemini reranking"
    )
    eligibility_text: Optional[EligibilityText] = None

### 6.4 EligibilityLedger v1

from enum import Enum

class CriterionDecision(str, Enum):
    MET = "met"
    NOT_MET = "not_met"
    UNKNOWN = "unknown"

class OverallAssessment(str, Enum):
    LIKELY_ELIGIBLE = "likely_eligible"
    LIKELY_INELIGIBLE = "likely_ineligible"
    UNCERTAIN = "uncertain"

class TrialEvidencePointer(BaseModel):
    field: str = Field(description="e.g. 'eligibility_text.inclusion'")
    offset_start: int
    offset_end: int

class TemporalCheck(BaseModel):
    """Validates whether patient evidence falls within a required time window.

    Example: "ANC >= 1.5 x 10^9/L within 14 days of enrollment" requires
    that the lab result date is no more than 14 days before the evaluation date.
    """
    required_window_days: int | None = Field(None, description="e.g. 14 for 'within 14 days'")
    reference_date: date | None = Field(None, description="Date of the patient evidence")
    evaluation_date: date = Field(default_factory=date.today)
    is_within_window: bool | None = None

    @property
    def days_elapsed(self) -> int | None:
        if self.reference_date:
            return (self.evaluation_date - self.reference_date).days
        return None


class CriterionAssessment(BaseModel):
    criterion_id: str = Field(description="e.g. 'inc_1', 'exc_3'")
    type: str = Field(description="'inclusion' or 'exclusion'")
    text: str = Field(description="Original criterion text from trial")
    decision: CriterionDecision
    patient_evidence: list[EvidencePointer] = Field(default_factory=list)
    trial_evidence: list[TrialEvidencePointer] = Field(default_factory=list)
    temporal_check: TemporalCheck | None = None

class GapItem(BaseModel):
    description: str
    recommended_action: str
    clinical_importance: str = Field(description="high|medium|low")

class EligibilityLedger(BaseModel):
    patient_id: str
    nct_id: str
    overall_assessment: OverallAssessment
    criteria: list[CriterionAssessment] = Field(default_factory=list)
    gaps: list[GapItem] = Field(default_factory=list)

    @property
    def met_count(self) -> int:
        return sum(1 for c in self.criteria if c.decision == CriterionDecision.MET)

    @property
    def not_met_count(self) -> int:
        return sum(1 for c in self.criteria if c.decision == CriterionDecision.NOT_MET)

    @property
    def unknown_count(self) -> int:
        return sum(1 for c in self.criteria if c.decision == CriterionDecision.UNKNOWN)

    @property
    def traffic_light(self) -> str:
        """Return traffic light color for UI display."""
        if self.overall_assessment == OverallAssessment.LIKELY_ELIGIBLE:
            return "green"
        elif self.overall_assessment == OverallAssessment.UNCERTAIN:
            return "yellow"
        return "red"

### 6.5 SearchLog v1

Tracks every step of the iterative query refinement loop (PRD Section 4: "Every search step is logged and explainable").

from datetime import datetime
from enum import Enum

class RefinementAction(str, Enum):
    INITIAL = "initial"
    REFINE = "refine"        # >50 results -> add filters (phase, keywords)
    RELAX = "relax"          # 0 results -> remove filters (location, phase)
    SHORTLIST = "shortlist"  # 10-30 results -> proceed to deep verification
    ABORT = "abort"          # max iterations reached

class SearchStep(BaseModel):
    step_number: int
    query_params: dict = Field(description="SearchAnchors snapshot used for this query")
    result_count: int
    action_taken: RefinementAction
    action_reason: str = Field(description="Human-readable why this action was chosen")
    timestamp: datetime = Field(default_factory=datetime.utcnow)
    nct_ids_sample: list[str] = Field(
        default_factory=list,
        description="Sample of NCT IDs returned (up to 10 for transparency)",
    )

class SearchLog(BaseModel):
    session_id: str
    patient_id: str
    steps: list[SearchStep] = Field(default_factory=list)
    final_shortlist_nct_ids: list[str] = Field(default_factory=list)
    total_refinement_rounds: int = 0
    max_refinement_rounds: int = Field(default=5, description="Safety cap to prevent infinite loops")

    @property
    def is_refinement_exhausted(self) -> bool:
        return self.total_refinement_rounds >= self.max_refinement_rounds

    def add_step(
        self,
        query_params: dict,
        result_count: int,
        action: RefinementAction,
        reason: str,
        nct_ids_sample: list[str] | None = None,
    ) -> None:
        step = SearchStep(
            step_number=len(self.steps) + 1,
            query_params=query_params,
            result_count=result_count,
            action_taken=action,
            action_reason=reason,
            nct_ids_sample=nct_ids_sample or [],
        )
        self.steps.append(step)
        if action in (RefinementAction.REFINE, RefinementAction.RELAX):
            self.total_refinement_rounds += 1

    def to_transparency_summary(self) -> list[dict]:
        """Generate human-readable search process for FE display."""
        return [
            {
                "step": s.step_number,
                "query": s.query_params,
                "found": s.result_count,
                "action": s.action_taken.value,
                "reason": s.action_reason,
            }
            for s in self.steps
        ]
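The decision rule behind `RefinementAction` can be stated as a small pure function. The thresholds (>50 refine, 0 relax, otherwise shortlist) follow the enum comments above; treating any non-zero result count as shortlistable once refinement is exhausted, and the handling of the 31-50 band, are assumptions sketched with a hypothetical `decide_action` helper:

```python
def decide_action(result_count: int, rounds_used: int, max_rounds: int = 5) -> str:
    """Map a result count to a RefinementAction value.

    Sketch of the loop-control rule implied by the enum comments above.
    """
    if rounds_used >= max_rounds:
        # Safety cap reached: take what we have, or give up on zero results.
        return "abort" if result_count == 0 else "shortlist"
    if result_count == 0:
        return "relax"      # too narrow: remove filters
    if result_count > 50:
        return "refine"     # too broad: add filters
    return "shortlist"      # right-sized: proceed to deep verification
```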

## 7. TDD Test Cases

### 7.1 Data Model Tests

# tests/test_models.py
import pytest
from datetime import date
from trialpath.models.patient_profile import (
    PatientProfile, Diagnosis, PerformanceStatus, Demographics,
    Biomarker, EvidencePointer, UnknownField,
)
from trialpath.models.search_anchors import SearchAnchors, TrialFilters
from trialpath.models.trial_candidate import TrialCandidate, EligibilityText
from trialpath.models.eligibility_ledger import (
    EligibilityLedger, CriterionAssessment, CriterionDecision,
    OverallAssessment, GapItem,
)


class TestPatientProfile:
    """PatientProfile v1 validation and helper tests."""

    def test_minimal_valid_profile(self):
        """A profile with only patient_id should be valid."""
        profile = PatientProfile(patient_id="P001")
        assert profile.patient_id == "P001"
        assert profile.unknowns == []

    def test_complete_nsclc_profile(self):
        """Full NSCLC patient profile should serialize/deserialize correctly."""
        profile = PatientProfile(
            patient_id="P001",
            demographics=Demographics(age=52, sex="female"),
            diagnosis=Diagnosis(
                primary_condition="Non-Small Cell Lung Cancer",
                histology="adenocarcinoma",
                stage="IVa",
                diagnosis_date=date(2025, 11, 15),
            ),
            performance_status=PerformanceStatus(
                scale="ECOG", value=1,
                evidence=[EvidencePointer(doc_id="clinic_1", page=2, span_id="s_17")],
            ),
            biomarkers=[
                Biomarker(
                    name="EGFR", result="Exon 19 deletion",
                    date=date(2026, 1, 10),
                    evidence=[EvidencePointer(doc_id="path_egfr", page=1, span_id="s_3")],
                ),
            ],
            unknowns=[
                UnknownField(field="PD-L1", reason="Not found in documents", importance="medium"),
            ],
        )

        data = profile.model_dump()
        restored = PatientProfile.model_validate(data)
        assert restored.patient_id == "P001"
        assert restored.diagnosis.stage == "IVa"
        assert len(restored.biomarkers) == 1
        assert restored.biomarkers[0].name == "EGFR"

    def test_has_minimum_prescreen_data_true(self):
        """Profile with diagnosis + stage + ECOG satisfies prescreen requirements."""
        profile = PatientProfile(
            patient_id="P001",
            diagnosis=Diagnosis(
                primary_condition="NSCLC", stage="IV",
            ),
            performance_status=PerformanceStatus(scale="ECOG", value=1),
        )
        assert profile.has_minimum_prescreen_data() is True

    def test_has_minimum_prescreen_data_false_no_stage(self):
        """Profile without stage should fail prescreen check."""
        profile = PatientProfile(
            patient_id="P001",
            diagnosis=Diagnosis(primary_condition="NSCLC"),
            performance_status=PerformanceStatus(scale="ECOG", value=1),
        )
        assert profile.has_minimum_prescreen_data() is False

    def test_has_minimum_prescreen_data_false_no_ecog(self):
        """Profile without performance status should fail prescreen check."""
        profile = PatientProfile(
            patient_id="P001",
            diagnosis=Diagnosis(primary_condition="NSCLC", stage="IV"),
        )
        assert profile.has_minimum_prescreen_data() is False

    def test_json_roundtrip(self):
        """Profile should survive JSON serialization roundtrip."""
        profile = PatientProfile(
            patient_id="P001",
            demographics=Demographics(age=65, sex="male"),
            diagnosis=Diagnosis(
                primary_condition="NSCLC",
                histology="squamous",
                stage="IIIb",
            ),
        )
        json_str = profile.model_dump_json()
        restored = PatientProfile.model_validate_json(json_str)
        assert restored == profile


class TestSearchAnchors:
    """SearchAnchors v1 validation tests."""

    def test_minimal_anchors(self):
        anchors = SearchAnchors(condition="NSCLC")
        assert anchors.condition == "NSCLC"
        assert anchors.trial_filters.recruitment_status == ["Recruiting", "Not yet recruiting"]

    def test_full_anchors(self):
        anchors = SearchAnchors(
            condition="Non-Small Cell Lung Cancer",
            subtype="adenocarcinoma",
            biomarkers=["EGFR exon 19 deletion"],
            stage="IV",
            age=52,
            performance_status_max=1,
            trial_filters=TrialFilters(
                recruitment_status=["Recruiting"],
                phase=["Phase 3"],
            ),
            relaxation_order=["phase", "distance"],
        )
        assert len(anchors.biomarkers) == 1
        assert anchors.trial_filters.phase == ["Phase 3"]


class TestTrialCandidate:
    """TrialCandidate v1 tests."""

    def test_trial_with_eligibility_text(self):
        trial = TrialCandidate(
            nct_id="NCT01234567",
            title="Phase 3 Study of Osimertinib",
            conditions=["NSCLC"],
            phase="Phase 3",
            status="Recruiting",
            fingerprint_text="Osimertinib EGFR+ NSCLC Phase 3",
            eligibility_text=EligibilityText(
                inclusion="Histologically confirmed NSCLC stage IV",
                exclusion="Prior EGFR TKI therapy",
            ),
        )
        assert trial.nct_id == "NCT01234567"
        assert trial.eligibility_text.inclusion.startswith("Histologically")


class TestEligibilityLedger:
    """EligibilityLedger v1 tests."""

    def test_traffic_light_green(self):
        ledger = EligibilityLedger(
            patient_id="P001",
            nct_id="NCT01234567",
            overall_assessment=OverallAssessment.LIKELY_ELIGIBLE,
        )
        assert ledger.traffic_light == "green"

    def test_traffic_light_yellow(self):
        ledger = EligibilityLedger(
            patient_id="P001",
            nct_id="NCT01234567",
            overall_assessment=OverallAssessment.UNCERTAIN,
        )
        assert ledger.traffic_light == "yellow"

    def test_traffic_light_red(self):
        ledger = EligibilityLedger(
            patient_id="P001",
            nct_id="NCT01234567",
            overall_assessment=OverallAssessment.LIKELY_INELIGIBLE,
        )
        assert ledger.traffic_light == "red"

    def test_criterion_counts(self):
        ledger = EligibilityLedger(
            patient_id="P001",
            nct_id="NCT01234567",
            overall_assessment=OverallAssessment.UNCERTAIN,
            criteria=[
                CriterionAssessment(
                    criterion_id="inc_1", type="inclusion",
                    text="Stage IV NSCLC", decision=CriterionDecision.MET,
                ),
                CriterionAssessment(
                    criterion_id="inc_2", type="inclusion",
                    text="ECOG 0-1", decision=CriterionDecision.MET,
                ),
                CriterionAssessment(
                    criterion_id="exc_1", type="exclusion",
                    text="No prior immunotherapy", decision=CriterionDecision.NOT_MET,
                ),
                CriterionAssessment(
                    criterion_id="inc_3", type="inclusion",
                    text="EGFR mutation", decision=CriterionDecision.UNKNOWN,
                ),
            ],
            gaps=[
                GapItem(
                    description="EGFR mutation status unknown",
                    recommended_action="Order EGFR mutation test",
                    clinical_importance="high",
                ),
            ],
        )
        assert ledger.met_count == 2
        assert ledger.not_met_count == 1
        assert ledger.unknown_count == 1
        assert len(ledger.gaps) == 1


class TestTemporalCheck:
    """TemporalCheck validation for time-windowed criteria (e.g., 'within 14 days')."""

    def test_within_window(self):
        """Evidence 7 days old should be within a 14-day window."""
        from trialpath.models.eligibility_ledger import TemporalCheck
        check = TemporalCheck(
            required_window_days=14,
            reference_date=date(2026, 1, 20),
            evaluation_date=date(2026, 1, 27),
            is_within_window=True,
        )
        assert check.days_elapsed == 7
        assert check.is_within_window is True

    def test_outside_window(self):
        """Evidence 21 days old should be outside a 14-day window."""
        from trialpath.models.eligibility_ledger import TemporalCheck
        check = TemporalCheck(
            required_window_days=14,
            reference_date=date(2026, 1, 1),
            evaluation_date=date(2026, 1, 22),
            is_within_window=False,
        )
        assert check.days_elapsed == 21
        assert check.is_within_window is False

    def test_no_reference_date(self):
        """Missing reference date should yield None for days_elapsed."""
        from trialpath.models.eligibility_ledger import TemporalCheck
        check = TemporalCheck(
            required_window_days=14,
            reference_date=None,
        )
        assert check.days_elapsed is None
        assert check.is_within_window is None

    def test_criterion_with_temporal_check(self):
        """CriterionAssessment should accept an optional temporal_check."""
        from trialpath.models.eligibility_ledger import TemporalCheck
        assessment = CriterionAssessment(
            criterion_id="inc_5",
            type="inclusion",
            text="ANC >= 1.5 x 10^9/L within 14 days of enrollment",
            decision=CriterionDecision.MET,
            temporal_check=TemporalCheck(
                required_window_days=14,
                reference_date=date(2026, 1, 20),
                evaluation_date=date(2026, 1, 27),
                is_within_window=True,
            ),
        )
        assert assessment.temporal_check is not None
        assert assessment.temporal_check.days_elapsed == 7
        assert assessment.temporal_check.is_within_window is True


class TestSearchLog:
    """SearchLog v1: iterative query refinement tracking tests."""

    def test_add_step_increments_count(self):
        """Adding a refinement step should increment total_refinement_rounds."""
        from trialpath.models.search_log import SearchLog, RefinementAction
        log = SearchLog(session_id="S001", patient_id="P001")
        assert log.total_refinement_rounds == 0

        log.add_step(
            query_params={"condition": "NSCLC"},
            result_count=75,
            action=RefinementAction.REFINE,
            reason="Too many results, adding phase filter",
        )
        assert log.total_refinement_rounds == 1
        assert len(log.steps) == 1

    def test_refinement_exhausted_at_max(self):
        """After 5 refinement rounds, is_refinement_exhausted should be True."""
        from trialpath.models.search_log import SearchLog, RefinementAction
        log = SearchLog(session_id="S001", patient_id="P001")

        for i in range(5):
            log.add_step(
                query_params={"condition": "NSCLC", "round": i},
                result_count=0,
                action=RefinementAction.RELAX,
                reason=f"Relaxation round {i + 1}",
            )

        assert log.total_refinement_rounds == 5
        assert log.is_refinement_exhausted is True

    def test_transparency_summary_format(self):
        """to_transparency_summary should return list of dicts with expected keys."""
        from trialpath.models.search_log import SearchLog, RefinementAction
        log = SearchLog(session_id="S001", patient_id="P001")

        log.add_step(
            query_params={"condition": "NSCLC"},
            result_count=100,
            action=RefinementAction.REFINE,
            reason="Too many results",
        )
        log.add_step(
            query_params={"condition": "NSCLC", "phase": "Phase 3"},
            result_count=25,
            action=RefinementAction.SHORTLIST,
            reason="Right-sized result set",
        )

        summary = log.to_transparency_summary()
        assert len(summary) == 2
        assert summary[0]["step"] == 1
        assert summary[0]["found"] == 100
        assert summary[0]["action"] == "refine"
        assert summary[1]["step"] == 2
        assert summary[1]["found"] == 25
        assert summary[1]["action"] == "shortlist"

    def test_initial_search_no_refinement_count(self):
        """An INITIAL action should not increment the refinement counter."""
        from trialpath.models.search_log import SearchLog, RefinementAction
        log = SearchLog(session_id="S001", patient_id="P001")

        log.add_step(
            query_params={"condition": "NSCLC"},
            result_count=30,
            action=RefinementAction.INITIAL,
            reason="First search",
        )

        assert log.total_refinement_rounds == 0
        assert len(log.steps) == 1

### 7.2 Gemini Structured Output Tests

# tests/test_gemini.py
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from trialpath.services.gemini_planner import GeminiPlanner
from trialpath.models.search_anchors import SearchAnchors
from trialpath.models.eligibility_ledger import (
    EligibilityLedger, OverallAssessment,
)


class TestGeminiSearchAnchorsGeneration:
    """Test Gemini structured output for SearchAnchors generation."""

    @pytest.fixture
    def sample_profile(self):
        return {
            "patient_id": "P001",
            "demographics": {"age": 52, "sex": "female"},
            "diagnosis": {
                "primary_condition": "Non-Small Cell Lung Cancer",
                "histology": "adenocarcinoma",
                "stage": "IVa",
            },
            "biomarkers": [
                {"name": "EGFR", "result": "Exon 19 deletion"},
            ],
            "performance_status": {"scale": "ECOG", "value": 1},
        }

    @pytest.mark.asyncio
    async def test_search_anchors_has_correct_condition(self, sample_profile):
        """Generated SearchAnchors should reference NSCLC."""
        with patch("google.genai.Client") as MockClient:
            mock_response = MagicMock()
            mock_response.text = SearchAnchors(
                condition="Non-Small Cell Lung Cancer",
                subtype="adenocarcinoma",
                biomarkers=["EGFR exon 19 deletion"],
                stage="IV",
                age=52,
                performance_status_max=1,
            ).model_dump_json()

            MockClient.return_value.models.generate_content = MagicMock(
                return_value=mock_response
            )

            planner = GeminiPlanner()
            anchors = await planner.generate_search_anchors(sample_profile)

            assert "lung" in anchors.condition.lower() or "nsclc" in anchors.condition.lower()
            assert anchors.age == 52

    @pytest.mark.asyncio
    async def test_search_anchors_includes_biomarkers(self, sample_profile):
        """SearchAnchors should include patient biomarkers."""
        with patch("google.genai.Client") as MockClient:
            mock_response = MagicMock()
            mock_response.text = SearchAnchors(
                condition="NSCLC",
                biomarkers=["EGFR exon 19 deletion"],
            ).model_dump_json()

            MockClient.return_value.models.generate_content = MagicMock(
                return_value=mock_response
            )

            planner = GeminiPlanner()
            anchors = await planner.generate_search_anchors(sample_profile)

            assert len(anchors.biomarkers) > 0
            assert any("EGFR" in b for b in anchors.biomarkers)

    @pytest.mark.asyncio
    async def test_search_anchors_json_schema_passed(self, sample_profile):
        """Verify that Gemini is called with response_json_schema."""
        with patch("google.genai.Client") as MockClient:
            mock_response = MagicMock()
            mock_response.text = SearchAnchors(condition="NSCLC").model_dump_json()

            mock_generate = MagicMock(return_value=mock_response)
            MockClient.return_value.models.generate_content = mock_generate

            planner = GeminiPlanner()
            await planner.generate_search_anchors(sample_profile)

            call_args = mock_generate.call_args
            config = call_args.kwargs.get("config", call_args[1].get("config", {}))
            assert config.get("response_mime_type") == "application/json"
            assert "response_json_schema" in config


class TestGeminiEligibilityEvaluation:
    """Test Gemini eligibility evaluation output."""

    @pytest.mark.asyncio
    async def test_ledger_has_all_required_fields(self):
        """EligibilityLedger from Gemini should have patient_id, nct_id, assessment."""
        mock_ledger = EligibilityLedger(
            patient_id="P001",
            nct_id="NCT01234567",
            overall_assessment=OverallAssessment.UNCERTAIN,
            criteria=[],
            gaps=[],
        )

        assert mock_ledger.patient_id == "P001"
        assert mock_ledger.nct_id == "NCT01234567"
        assert mock_ledger.overall_assessment in OverallAssessment

    @pytest.mark.asyncio
    async def test_error_handling_invalid_json(self):
        """Should raise error on invalid Gemini JSON response."""
        with patch("google.genai.Client") as MockClient:
            mock_response = MagicMock()
            mock_response.text = "not valid json"

            MockClient.return_value.models.generate_content = MagicMock(
                return_value=mock_response
            )

            planner = GeminiPlanner()
            with pytest.raises(Exception):
                await planner.evaluate_eligibility({}, {}, None)

### 7.3 MedGemma Extraction Tests

# tests/test_medgemma.py
import pytest
from unittest.mock import MagicMock, patch, AsyncMock
from trialpath.services.medgemma_extractor import MedGemmaExtractor
from trialpath.models.patient_profile import PatientProfile


class TestMedGemmaExtraction:
    """Test MedGemma extraction pipeline."""

    def test_parse_valid_json_output(self):
        """Should parse well-formed JSON from MedGemma."""
        extractor = MedGemmaExtractor.__new__(MedGemmaExtractor)

        raw_output = '''
        {
            "patient_id": "P001",
            "diagnosis": {
                "primary_condition": "Non-Small Cell Lung Cancer",
                "histology": "adenocarcinoma",
                "stage": "IVa"
            },
            "performance_status": {
                "scale": "ECOG",
                "value": 1,
                "evidence": [{"doc_id": "clinic_1", "page": 2, "span_id": "s_17"}]
            },
            "biomarkers": [],
            "unknowns": [
                {"field": "EGFR", "reason": "No clear mention", "importance": "high"}
            ]
        }
        '''

        result = extractor._parse_profile(raw_output, {"age": 52, "sex": "female"})
        assert result["diagnosis"]["primary_condition"] == "Non-Small Cell Lung Cancer"
        assert result["demographics"]["age"] == 52

    def test_parse_json_in_code_block(self):
        """Should extract JSON from markdown code blocks."""
        extractor = MedGemmaExtractor.__new__(MedGemmaExtractor)

        raw_output = '''Here is the extracted data:
        ```json
        {"patient_id": "P001", "diagnosis": {"primary_condition": "NSCLC", "stage": "IV"}}
        ```
        '''

        result = extractor._parse_profile(raw_output, {})
        assert result["diagnosis"]["primary_condition"] == "NSCLC"

    def test_parse_invalid_output_raises(self):
        """Should raise ValueError on unparseable output."""
        extractor = MedGemmaExtractor.__new__(MedGemmaExtractor)

        with pytest.raises(ValueError, match="Could not parse"):
            extractor._parse_profile("This is not JSON at all.", {})

    def test_system_prompt_mentions_oncology(self):
        """System prompt should reference oncology expertise."""
        extractor = MedGemmaExtractor.__new__(MedGemmaExtractor)
        prompt = extractor._system_prompt()
        assert "oncology" in prompt.lower()

    def test_extraction_prompt_includes_all_fields(self):
        """Extraction prompt should request all PatientProfile fields."""
        extractor = MedGemmaExtractor.__new__(MedGemmaExtractor)
        prompt = extractor._build_extraction_prompt({"age": 52, "sex": "female"})

        required_fields = [
            "diagnosis", "performance_status", "biomarkers",
            "key_labs", "treatments", "comorbidities",
            "imaging_summary", "unknowns",
        ]
        for field in required_fields:
            assert field in prompt

    def test_extraction_prompt_includes_metadata(self):
        """Extraction prompt should include provided metadata."""
        extractor = MedGemmaExtractor.__new__(MedGemmaExtractor)
        prompt = extractor._build_extraction_prompt({"age": 65, "sex": "male"})
        assert "65" in prompt
        assert "male" in prompt

### 7.4 MCP Client Tests

# tests/test_mcp.py
import pytest
from unittest.mock import AsyncMock, patch, MagicMock
from trialpath.services.mcp_client import ClinicalTrialsMCPClient, MCPError
from trialpath.models.search_anchors import SearchAnchors, TrialFilters, GeographyFilter


class TestMCPClient:
    """Test ClinicalTrials MCP client."""

    @pytest.fixture
    def client(self):
        return ClinicalTrialsMCPClient(mcp_url="http://localhost:3000")

    @pytest.fixture
    def sample_anchors(self):
        return SearchAnchors(
            condition="Non-Small Cell Lung Cancer",
            subtype="adenocarcinoma",
            biomarkers=["EGFR exon 19 deletion"],
            stage="IV",
            age=52,
            geography=GeographyFilter(country="United States"),
            trial_filters=TrialFilters(
                recruitment_status=["Recruiting"],
                phase=["Phase 3"],
            ),
        )

    @pytest.mark.asyncio
    async def test_search_builds_correct_query(self, client, sample_anchors):
        """Search should combine condition, subtype, and biomarkers into query."""
        with patch("httpx.AsyncClient") as MockHTTP:
            mock_response = MagicMock()
            mock_response.json.return_value = {
                "result": {"studies": []}
            }
            mock_response.raise_for_status = MagicMock()

            mock_client = AsyncMock()
            mock_client.post.return_value = mock_response
            mock_client.__aenter__ = AsyncMock(return_value=mock_client)
            mock_client.__aexit__ = AsyncMock()
            MockHTTP.return_value = mock_client

            await client.search(sample_anchors)

            body = mock_client.post.call_args.kwargs["json"]
            query = body["params"]["arguments"]["query"]

            assert "Non-Small Cell Lung Cancer" in query
            assert "adenocarcinoma" in query

    @pytest.mark.asyncio
    async def test_search_includes_country_filter(self, client, sample_anchors):
        """Search should pass country as a parameter."""
        with patch("httpx.AsyncClient") as MockHTTP:
            mock_response = MagicMock()
            mock_response.json.return_value = {"result": {"studies": []}}
            mock_response.raise_for_status = MagicMock()

            mock_client = AsyncMock()
            mock_client.post.return_value = mock_response
            mock_client.__aenter__ = AsyncMock(return_value=mock_client)
            mock_client.__aexit__ = AsyncMock()
            MockHTTP.return_value = mock_client

            await client.search(sample_anchors)

            body = mock_client.post.call_args.kwargs["json"]
            args = body["params"]["arguments"]

            assert args.get("country") == "United States"

    @pytest.mark.asyncio
    async def test_search_includes_recruitment_status_filter(self, client, sample_anchors):
        """Search should include recruitment status in filter expression."""
        with patch("httpx.AsyncClient") as MockHTTP:
            mock_response = MagicMock()
            mock_response.json.return_value = {"result": {"studies": []}}
            mock_response.raise_for_status = MagicMock()

            mock_client = AsyncMock()
            mock_client.post.return_value = mock_response
            mock_client.__aenter__ = AsyncMock(return_value=mock_client)
            mock_client.__aexit__ = AsyncMock()
            MockHTTP.return_value = mock_client

            await client.search(sample_anchors)

            body = mock_client.post.call_args.kwargs["json"]
            filter_str = body["params"]["arguments"].get("filter", "")

            assert "OverallStatus" in filter_str
            assert "Recruiting" in filter_str

    @pytest.mark.asyncio
    async def test_get_study_by_nct_id(self, client):
        """Should call get_study tool with correct NCT ID."""
        with patch("httpx.AsyncClient") as MockHTTP:
            mock_response = MagicMock()
            mock_response.json.return_value = {
                "result": {
                    "studies": [{"nctId": "NCT01234567", "title": "Test Trial"}]
                }
            }
            mock_response.raise_for_status = MagicMock()

            mock_client = AsyncMock()
            mock_client.post.return_value = mock_response
            mock_client.__aenter__ = AsyncMock(return_value=mock_client)
            mock_client.__aexit__ = AsyncMock()
            MockHTTP.return_value = mock_client

            result = await client.get_study("NCT01234567")
            assert result["nctId"] == "NCT01234567"

    @pytest.mark.asyncio
    async def test_mcp_error_handling(self, client):
        """Should raise MCPError on MCP server error response."""
        with patch("httpx.AsyncClient") as MockHTTP:
            mock_response = MagicMock()
            mock_response.json.return_value = {
                "error": {"code": -32600, "message": "Invalid request"}
            }
            mock_response.raise_for_status = MagicMock()

            mock_client = AsyncMock()
            mock_client.post.return_value = mock_response
            mock_client.__aenter__ = AsyncMock(return_value=mock_client)
            mock_client.__aexit__ = AsyncMock()
            MockHTTP.return_value = mock_client

            with pytest.raises(MCPError, match="Invalid request"):
                await client.get_study("NCT00000000")

    @pytest.mark.asyncio
    async def test_find_eligible_passes_demographics(self, client):
        """find_eligible should pass patient demographics correctly."""
        with patch("httpx.AsyncClient") as MockHTTP:
            mock_response = MagicMock()
            mock_response.json.return_value = {
                "result": {"eligibleStudies": [], "totalMatches": 0}
            }
            mock_response.raise_for_status = MagicMock()

            mock_client = AsyncMock()
            mock_client.post.return_value = mock_response
            mock_client.__aenter__ = AsyncMock(return_value=mock_client)
            mock_client.__aexit__ = AsyncMock()
            MockHTTP.return_value = mock_client

            await client.find_eligible(
                age=52, sex="Female",
                conditions=["NSCLC"],
                country="United States",
            )

            body = mock_client.post.call_args.kwargs["json"]
            args = body["params"]["arguments"]

            assert args["age"] == 52
            assert args["sex"] == "Female"
            assert args["conditions"] == ["NSCLC"]

7.5 Parlant Agent Behavior Tests

# tests/test_parlant_agent.py
import pytest
from unittest.mock import AsyncMock, patch, MagicMock


class TestParlantAgentSetup:
    """Test Parlant agent creation and configuration."""

    @pytest.mark.asyncio
    async def test_agent_creation(self):
        """Agent should be created with correct attributes."""
        with patch("parlant.sdk.Server") as MockServer:
            mock_server = AsyncMock()
            mock_agent = MagicMock()
            mock_agent.id = "patient-trial-copilot"
            mock_agent.name = "Patient Trial Copilot"
            mock_server.create_agent = AsyncMock(return_value=mock_agent)
            MockServer.return_value.__aenter__ = AsyncMock(return_value=mock_server)
            MockServer.return_value.__aexit__ = AsyncMock()

            from trialpath.agent.setup import setup_agent
            server, agent = await setup_agent()

            mock_server.create_agent.assert_called_once()
            call_kwargs = mock_server.create_agent.call_args.kwargs
            assert call_kwargs["name"] == "Patient Trial Copilot"
            assert "NSCLC" in call_kwargs["description"]

    @pytest.mark.asyncio
    async def test_guidelines_include_all_states(self):
        """Guidelines should cover INGEST, PRESCREEN, VALIDATE, GAP, SUMMARY."""
        mock_agent = AsyncMock()
        guidelines_created = []

        async def track_guideline(**kwargs):
            guidelines_created.append(kwargs)
            return MagicMock()

        mock_agent.create_guideline = track_guideline

        from trialpath.agent.guidelines import configure_guidelines
        await configure_guidelines(mock_agent)

        all_conditions = " ".join(g["condition"] for g in guidelines_created)

        assert "uploaded" in all_conditions.lower() or "document" in all_conditions.lower()
        assert "prescreen" in all_conditions.lower() or "sufficient" in all_conditions.lower()
        assert "validation" in all_conditions.lower() or "eligibility" in all_conditions.lower()
        assert "gap" in all_conditions.lower() or "missing" in all_conditions.lower() or "unknown" in all_conditions.lower()
        assert "summary" in all_conditions.lower() or "complete" in all_conditions.lower()

    @pytest.mark.asyncio
    async def test_guidelines_include_medical_disclaimer(self):
        """Guidelines must include an always-on medical disclaimer."""
        mock_agent = AsyncMock()
        guidelines_created = []
        async def track_guideline(**kwargs):
            guidelines_created.append(kwargs)
            return MagicMock()
        mock_agent.create_guideline = track_guideline

        from trialpath.agent.guidelines import configure_guidelines
        await configure_guidelines(mock_agent)

        disclaimer_found = any(
            "information only" in g.get("action", "").lower() or
            "not medical advice" in g.get("action", "").lower()
            for g in guidelines_created
        )
        assert disclaimer_found, "Must include medical disclaimer guideline"


class TestParlantJourney:
    """Test Parlant journey state machine."""

    @pytest.mark.asyncio
    async def test_journey_has_five_states(self):
        """Journey should define states for all 5 workflow phases."""
        # This test validates the journey structure definition
        # In integration, we verify via Parlant's mermaid endpoint
        mock_agent = AsyncMock()
        mock_journey = MagicMock()
        mock_journey.initial_state = MagicMock()

        # Create chain of mock transitions
        mock_state = MagicMock()
        mock_transition = MagicMock()
        mock_transition.target = mock_state
        mock_state.transition_to = AsyncMock(return_value=mock_transition)
        mock_state.fork = AsyncMock(return_value=mock_state)
        mock_journey.initial_state.transition_to = AsyncMock(return_value=mock_transition)

        mock_agent.create_journey = AsyncMock(return_value=mock_journey)

        from trialpath.agent.setup import create_clinical_trial_journey
        journey = await create_clinical_trial_journey(mock_agent)

        mock_agent.create_journey.assert_called_once()
        call_kwargs = mock_agent.create_journey.call_args.kwargs
        assert "Clinical Trial" in call_kwargs["title"]
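The guideline tests in this section fix the contract for `configure_guidelines`: one guideline per workflow state plus an always-on medical disclaimer. A sketch of the shape those tests expect, using only the `create_guideline(condition, action, tools?)` signature listed in Appendix 8.1; the exact wording of each condition and action is illustrative:

```python
async def configure_guidelines(agent) -> None:
    """Register one guideline per workflow state plus an always-on disclaimer."""
    state_guidelines = [
        ("The patient has uploaded a medical document",
         "Acknowledge receipt and extract the patient profile before anything else"),
        ("The profile has sufficient data to prescreen",
         "Generate search anchors and run the trial search"),
        ("Trial candidates are ready for eligibility validation",
         "Walk through each criterion and record met, not met, or unknown"),
        ("Eligibility is blocked by missing or unknown information",
         "List the gaps and ask for the specific documents that would close them"),
        ("Validation is complete and a summary is requested",
         "Produce the doctor packet with trial matches and open questions"),
    ]
    for condition, action in state_guidelines:
        await agent.create_guideline(condition=condition, action=action)
    # Always-on safety guideline: the copilot informs, it does not advise.
    await agent.create_guideline(
        condition="The agent is responding to any patient message",
        action="State that this is information only and not medical advice",
    )
```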

7.6 End-to-End Flow Tests

# tests/test_e2e.py
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from trialpath.models.patient_profile import PatientProfile, Diagnosis, PerformanceStatus, Demographics
from trialpath.models.search_anchors import SearchAnchors
from trialpath.models.trial_candidate import TrialCandidate, EligibilityText
from trialpath.models.eligibility_ledger import (
    EligibilityLedger, CriterionAssessment, CriterionDecision,
    OverallAssessment, GapItem,
)


class TestEndToEndFlow:
    """End-to-end workflow tests simulating the full patient journey."""

    @pytest.fixture
    def anna_profile(self):
        """Synthetic patient 'Anna' -- NSCLC Stage IV with EGFR mutation."""
        return PatientProfile(
            patient_id="ANNA_001",
            demographics=Demographics(age=52, sex="female"),
            diagnosis=Diagnosis(
                primary_condition="Non-Small Cell Lung Cancer",
                histology="adenocarcinoma",
                stage="IVa",
            ),
            performance_status=PerformanceStatus(scale="ECOG", value=1),
        )

    def test_ingest_to_prescreen_transition(self, anna_profile):
        """Profile with minimum data should transition from INGEST to PRESCREEN."""
        assert anna_profile.has_minimum_prescreen_data() is True

    def test_incomplete_profile_stays_in_ingest(self):
        """Profile without stage should stay in INGEST."""
        incomplete = PatientProfile(
            patient_id="INC_001",
            diagnosis=Diagnosis(primary_condition="NSCLC"),
        )
        assert incomplete.has_minimum_prescreen_data() is False

    def test_prescreen_to_validate_with_results(self):
        """Finding trials should trigger transition to VALIDATE_TRIALS."""
        candidates = [
            TrialCandidate(
                nct_id="NCT01234567",
                title="Test Trial",
                conditions=["NSCLC"],
                fingerprint_text="Test NSCLC trial",
            ),
        ]
        assert len(candidates) > 0  # Transition condition met

    def test_validate_to_gap_followup_with_unknowns(self):
        """A likely-ineligible ledger with unknowns should trigger GAP_FOLLOWUP."""
        ledger = EligibilityLedger(
            patient_id="ANNA_001",
            nct_id="NCT01234567",
            overall_assessment=OverallAssessment.LIKELY_INELIGIBLE,
            criteria=[
                CriterionAssessment(
                    criterion_id="inc_1", type="inclusion",
                    text="EGFR mutation required",
                    decision=CriterionDecision.UNKNOWN,
                ),
            ],
            gaps=[
                GapItem(
                    description="EGFR mutation status not provided",
                    recommended_action="Provide EGFR mutation test results",
                    clinical_importance="high",
                ),
            ],
        )

        # Transition condition: all ineligible but gaps exist
        assert ledger.overall_assessment == OverallAssessment.LIKELY_INELIGIBLE
        assert ledger.unknown_count > 0
        assert len(ledger.gaps) > 0

    def test_validate_to_summary_with_eligible(self):
        """At least one eligible trial should allow SUMMARY transition."""
        ledger = EligibilityLedger(
            patient_id="ANNA_001",
            nct_id="NCT01234567",
            overall_assessment=OverallAssessment.LIKELY_ELIGIBLE,
            criteria=[
                CriterionAssessment(
                    criterion_id="inc_1", type="inclusion",
                    text="Stage IV NSCLC", decision=CriterionDecision.MET,
                ),
                CriterionAssessment(
                    criterion_id="inc_2", type="inclusion",
                    text="ECOG 0-1", decision=CriterionDecision.MET,
                ),
            ],
        )

        assert ledger.traffic_light == "green"
        assert ledger.met_count == 2
        assert ledger.not_met_count == 0

    @pytest.mark.asyncio
    async def test_full_pipeline_mock(self, anna_profile):
        """Test full pipeline: profile -> anchors -> search -> validate."""
        # Step 1: Profile is ready
        assert anna_profile.has_minimum_prescreen_data()

        # Step 2: Generate anchors (mocked)
        anchors = SearchAnchors(
            condition="NSCLC",
            subtype="adenocarcinoma",
            stage="IV",
            age=52,
        )
        assert anchors.condition == "NSCLC"

        # Step 3: Search returns candidates (mocked)
        candidates = [
            TrialCandidate(
                nct_id="NCT11111111",
                title="EGFR+ NSCLC Phase 3",
                conditions=["NSCLC"],
                phase="Phase 3",
                status="Recruiting",
                fingerprint_text="EGFR NSCLC Phase 3 osimertinib",
                eligibility_text=EligibilityText(
                    inclusion="Stage IV NSCLC with EGFR mutation",
                    exclusion="Prior TKI therapy",
                ),
            ),
        ]
        assert len(candidates) == 1

        # Step 4: Validate eligibility (mocked)
        ledger = EligibilityLedger(
            patient_id=anna_profile.patient_id,
            nct_id=candidates[0].nct_id,
            overall_assessment=OverallAssessment.UNCERTAIN,
            criteria=[
                CriterionAssessment(
                    criterion_id="inc_1", type="inclusion",
                    text="Stage IV NSCLC", decision=CriterionDecision.MET,
                ),
                CriterionAssessment(
                    criterion_id="inc_2", type="inclusion",
                    text="EGFR mutation", decision=CriterionDecision.UNKNOWN,
                ),
            ],
            gaps=[
                GapItem(
                    description="EGFR mutation status unknown",
                    recommended_action="Submit EGFR test results",
                    clinical_importance="high",
                ),
            ],
        )

        assert ledger.traffic_light == "yellow"
        assert ledger.unknown_count == 1
        assert len(ledger.gaps) == 1


class TestErrorHandling:
    """Test error handling and edge cases."""

    def test_empty_search_results(self):
        """System should handle zero search results gracefully."""
        candidates = []
        # Should trigger GAP_FOLLOWUP with criteria relaxation
        assert len(candidates) == 0

    def test_ledger_with_all_met_criteria(self):
        """All-met criteria should yield likely_eligible assessment."""
        ledger = EligibilityLedger(
            patient_id="P001",
            nct_id="NCT99999999",
            overall_assessment=OverallAssessment.LIKELY_ELIGIBLE,
            criteria=[
                CriterionAssessment(
                    criterion_id=f"inc_{i}", type="inclusion",
                    text=f"Criterion {i}", decision=CriterionDecision.MET,
                )
                for i in range(5)
            ],
        )
        assert ledger.met_count == 5
        assert ledger.not_met_count == 0
        assert ledger.traffic_light == "green"

    def test_ledger_with_single_exclusion_not_met(self):
        """A single unmet exclusion criterion should still be flagged."""
        ledger = EligibilityLedger(
            patient_id="P001",
            nct_id="NCT99999999",
            overall_assessment=OverallAssessment.LIKELY_INELIGIBLE,
            criteria=[
                CriterionAssessment(
                    criterion_id="exc_1", type="exclusion",
                    text="No prior immunotherapy",
                    decision=CriterionDecision.NOT_MET,
                ),
            ],
        )
        assert ledger.not_met_count == 1
        assert ledger.traffic_light == "red"

    @pytest.mark.asyncio
    async def test_mcp_timeout_handling(self):
        """MCP client should surface network timeouts."""
        import httpx
        from trialpath.services.mcp_client import ClinicalTrialsMCPClient

        client = ClinicalTrialsMCPClient(mcp_url="http://localhost:3000")

        with patch("httpx.AsyncClient") as MockHTTP:
            mock_client = AsyncMock()
            mock_client.post.side_effect = httpx.TimeoutException("Connection timed out")
            mock_client.__aenter__ = AsyncMock(return_value=mock_client)
            mock_client.__aexit__ = AsyncMock()
            MockHTTP.return_value = mock_client

            with pytest.raises(httpx.TimeoutException):
                await client.get_study("NCT01234567")

    def test_cost_tracker_budget_enforcement(self):
        """Cost tracker should detect over-budget sessions."""
        from trialpath.services.gemini_planner import GeminiCostTracker

        tracker = GeminiCostTracker(budget_usd=0.50)
        assert not tracker.over_budget

        # Simulate heavy usage
        tracker.record(input_tokens=200_000, output_tokens=50_000)
        # Cost: 200k/1M * 1.25 + 50k/1M * 5.0 = 0.25 + 0.25 = 0.50
        assert tracker.total_cost == pytest.approx(0.50, abs=0.01)

        tracker.record(input_tokens=10_000, output_tokens=1_000)
        assert tracker.over_budget
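The test above fixes the tracker's public surface: a `budget_usd` constructor argument, `record(input_tokens, output_tokens)`, `total_cost`, and `over_budget`. A minimal sketch under the pricing the inline comment assumes ($1.25 per 1M input tokens, $5.00 per 1M output tokens; verify current Gemini rates before relying on these constants):

```python
class GeminiCostTracker:
    """Accumulates per-session Gemini spend against a fixed budget."""

    INPUT_USD_PER_M = 1.25   # assumed input price per 1M tokens
    OUTPUT_USD_PER_M = 5.00  # assumed output price per 1M tokens

    def __init__(self, budget_usd: float) -> None:
        self.budget_usd = budget_usd
        self.total_cost = 0.0

    def record(self, input_tokens: int, output_tokens: int) -> None:
        """Add one API call's token usage to the running total."""
        self.total_cost += (
            input_tokens / 1_000_000 * self.INPUT_USD_PER_M
            + output_tokens / 1_000_000 * self.OUTPUT_USD_PER_M
        )

    @property
    def over_budget(self) -> bool:
        # Strictly greater: landing exactly on budget is still allowed,
        # which is why the test's second record() call is needed to trip it.
        return self.total_cost > self.budget_usd
```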

8. Appendix: API Reference

8.1 Parlant Core API

| Endpoint | Method | Description |
|----------|--------|-------------|
| /agents | POST | Create agent |
| /agents | GET | List agents |
| /agents/{id} | GET | Get agent |
| /agents/{id} | PATCH | Update agent |
| /agents/{id} | DELETE | Delete agent |
| /sessions | POST | Create session |
| /sessions/{id}/events | POST | Send event (message) |
| /sessions/{id}/events | GET | List events (long-polling) |
| /guidelines | POST | Create guideline |
| /guidelines/{id} | PATCH | Update guideline |
| /guidelines | GET | List guidelines |

Parlant Python SDK Core:

parlant.sdk.Server(nlp_service, session_store)
  .create_agent(name, description, id?, max_engine_iterations?)
  .list_agents()
  .find_agent(id)

parlant.sdk.Agent
  .create_guideline(condition, action, tools?)
  .create_journey(title, conditions, description)
  .create_variable(name, tool, update_interval)

parlant.sdk.Journey
  .initial_state
  .title

parlant.sdk.JourneyState
  .transition_to(chat_state? | tool_state?, condition?, description?)
  .fork()

parlant.sdk.NLPServices
  .gemini  -- Google Gemini
  .openai  -- OpenAI
  .anthropic  -- Anthropic
  .together  -- Together.ai

8.2 Google GenAI SDK Reference

from google import genai
from google.genai import types

# Client
client = genai.Client(api_key="...")

# Structured output
client.models.generate_content(
    model="gemini-3-pro",
    contents="...",
    config={
        "response_mime_type": "application/json",
        "response_json_schema": MyModel.model_json_schema(),
    },
)

# Function calling (manual)
tools = types.Tool(function_declarations=[{
    "name": "...",
    "description": "...",
    "parameters": {...},
}])
config = types.GenerateContentConfig(tools=[tools])

# Function calling (automatic)
config = types.GenerateContentConfig(tools=[my_python_function])

# Token counting
client.models.count_tokens(model="...", contents="...")

8.3 MedGemma Model Reference

Model ID:        google/medgemma-4b-it (instruction-tuned)
                 google/medgemma-1.5-4b-it (v1.5, Jan 2026)
Architecture:    Gemma 3 + SigLIP medical image encoder
Input:           Text + medical images (radiology, pathology, ophthalmology, dermatology)
Output:          Text only
Max Output:      8192 tokens
Context Length:  128K+ tokens
Image Resolution: 896 x 896 (normalized), 256 tokens per image
Requirements:    transformers >= 4.50.0, torch, accelerate
Pipeline Task:   "image-text-to-text"

8.4 ClinicalTrials MCP Server Reference

Base API URL:     https://clinicaltrials.gov/api/v2
MCP Server:       cyanheads/clinicaltrialsgov-mcp-server

Tools:
  clinicaltrials_search_studies    -- Search with query + filters
  clinicaltrials_get_study         -- Get study by NCT ID(s)
  clinicaltrials_find_eligible_studies  -- Match patient to trials
  clinicaltrials_analyze_trends    -- Statistical analysis
  clinicaltrials_compare_studies   -- Compare 2-5 studies

Filter Syntax:
  AREA[OverallStatus]Recruiting
  AREA[Phase]Phase 3
  AREA[MinimumAge]RANGE[MIN, 52]
  AREA[LocationCountry]United States

  Combine with AND/OR:
  (AREA[OverallStatus]Recruiting) AND (AREA[Phase]Phase 3)

Rate Limiting:
  Internal 250ms delay between API calls (API_CALL_DELAY_MS)

Error Handling:
  McpError with JsonRpcErrorCode
  "Logic Throws, Handlers Catch" pattern

8.5 Environment Variables

# Parlant
PARLANT_HOME=./parlant-data     # Data directory
PARLANT_HOST=0.0.0.0            # Server host
PARLANT_PORT=8800               # Server port

# Gemini
GEMINI_API_KEY=your-key-here    # Google AI API key

# MedGemma (HuggingFace)
HF_TOKEN=your-hf-token         # HuggingFace access token
MEDGEMMA_MODEL_ID=google/medgemma-4b-it
MEDGEMMA_DEVICE=cuda            # or cpu for testing

# ClinicalTrials MCP
MCP_SERVER_URL=http://localhost:3000  # MCP server URL

# Cost Control
MAX_SESSION_COST_USD=0.50       # Budget per patient session
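These variables can be centralized in a small settings object so each service reads configuration from one place; a stdlib-only sketch whose defaults mirror the values above (the `Settings` name and field selection are illustrative):

```python
import os
from dataclasses import dataclass


@dataclass(frozen=True)
class Settings:
    """Environment-driven configuration with the defaults documented above."""
    parlant_home: str
    parlant_port: int
    gemini_api_key: str
    medgemma_model_id: str
    mcp_server_url: str
    max_session_cost_usd: float

    @classmethod
    def from_env(cls) -> "Settings":
        env = os.environ.get
        return cls(
            parlant_home=env("PARLANT_HOME", "./parlant-data"),
            parlant_port=int(env("PARLANT_PORT", "8800")),
            gemini_api_key=env("GEMINI_API_KEY", ""),
            medgemma_model_id=env("MEDGEMMA_MODEL_ID", "google/medgemma-4b-it"),
            mcp_server_url=env("MCP_SERVER_URL", "http://localhost:3000"),
            max_session_cost_usd=float(env("MAX_SESSION_COST_USD", "0.50")),
        )
```

Freezing the dataclass keeps configuration immutable for the lifetime of a session, so the cost budget cannot drift mid-run.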