Initial PoC codebase
Browse filesCo-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- CLAUDE.md +32 -40
- architecture/overview.md +119 -0
- trialpath/services/parlant_client.py +94 -0
- trialpath/tests/test_parlant.py +230 -0
CLAUDE.md
CHANGED
|
@@ -4,58 +4,50 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co
|
|
| 4 |
|
| 5 |
## Project Overview
|
| 6 |
|
| 7 |
-
TrialPath is an AI-powered clinical trial matching system for NSCLC (Non-Small Cell Lung Cancer) patients.
|
| 8 |
|
| 9 |
**Core idea:** Help patients understand which clinical trials they may qualify for, transform "rejection" into "actionable next steps" via gap analysis.
|
| 10 |
|
| 11 |
-
##
|
| 12 |
|
| 13 |
-
|
| 14 |
-
- `TrialPath AI Synergy in Digital Health Trials.md` β Technical architecture, data contracts, Parlant workflow design
|
| 15 |
|
| 16 |
-
|
| 17 |
|
| 18 |
-
|
| 19 |
-
2. **Parlant Agent + Journey** β Single agent (`patient_trial_copilot`) with 5 states: `INGEST` β `PRESCREEN` β `VALIDATE_TRIALS` β `GAP_FOLLOWUP` β `SUMMARY`
|
| 20 |
-
3. **MedGemma 4B** (HF endpoint) β Multimodal extraction from PDFs/images β `PatientProfile` + evidence spans
|
| 21 |
-
4. **Gemini 3 Pro** β LLM planner: generates `SearchAnchors` from profile, reranks trials, orchestrates criterion evaluation
|
| 22 |
-
5. **ClinicalTrials MCP Server** (existing, not custom) β Wraps ClinicalTrials.gov REST API v2
|
| 23 |
|
| 24 |
-
##
|
| 25 |
|
| 26 |
-
- **No vector DB / RAG** β Uses agentic search via ClinicalTrials.gov API with iterative query refinement
|
| 27 |
-
- **Reuse existing MCP** β Don't build custom trial search; use off-the-shelf ClinicalTrials MCP servers
|
| 28 |
-
- **Two-stage clinical screening** β Mirrors real-world: prescreen (minimal dataset) β validation (full criterion-by-criterion)
|
| 29 |
-
- **Evidence-linked** β Every decision must cite source doc/page/span
|
| 30 |
-
- **Gap analysis as core differentiator** β "You'd qualify IF you had X" rather than just "No match"
|
| 31 |
-
|
| 32 |
-
## Data Contracts (JSON Schemas)
|
| 33 |
-
|
| 34 |
-
Four core contracts defined in the tech design doc (section 4):
|
| 35 |
-
- **PatientProfile v1** β MedGemma output with demographics, diagnosis, biomarkers, labs, treatments, unknowns
|
| 36 |
-
- **SearchAnchors v1** β Gemini-generated query params for MCP search
|
| 37 |
-
- **TrialCandidate v1** β Normalized MCP search results
|
| 38 |
-
- **EligibilityLedger v1** β Per-trial criterion-level assessment with evidence pointers and gaps
|
| 39 |
-
|
| 40 |
-
## Planned Code Structure
|
| 41 |
-
|
| 42 |
-
From PRD deliverables section:
|
| 43 |
```
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
```
|
| 50 |
|
| 51 |
-
##
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
|
| 53 |
-
- Python (Streamlit
|
| 54 |
-
- Google Gemini 3 Pro (orchestration)
|
| 55 |
-
- MedGemma 4B via Hugging Face endpoint (multimodal extraction)
|
| 56 |
-
- Parlant (agentic workflow engine)
|
| 57 |
-
-
|
| 58 |
-
- TREC Clinical Trials Track 2021/2022 (benchmarking)
|
| 59 |
|
| 60 |
## Success Targets
|
| 61 |
|
|
|
|
| 4 |
|
| 5 |
## Project Overview
|
| 6 |
|
| 7 |
+
TrialPath is an AI-powered clinical trial matching system for NSCLC (Non-Small Cell Lung Cancer) patients. Currently in **PoC phase** β models, service stubs, and UI with mock data are implemented; live AI integrations are pending.
|
| 8 |
|
| 9 |
**Core idea:** Help patients understand which clinical trials they may qualify for, transform "rejection" into "actionable next steps" via gap analysis.
|
| 10 |
|
| 11 |
+
## Architecture
|
| 12 |
|
| 13 |
+
See `architecture/overview.md` for full architecture diagram, data flow, component details, and implementation status.
|
|
|
|
| 14 |
|
| 15 |
+
**5 Components**: Streamlit UI β Parlant Orchestrator β MedGemma 4B (extraction) + Gemini 3 Pro (planning) + ClinicalTrials MCP Server (search)
|
| 16 |
|
| 17 |
+
**5 Data Contracts** (Pydantic v2 in `trialpath/models/`): `PatientProfile`, `SearchAnchors`, `TrialCandidate`, `EligibilityLedger`, `SearchLog`
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
+
## Project Structure
|
| 20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
```
|
| 22 |
+
trialpath/ # Backend module
|
| 23 |
+
models/ # 5 Pydantic v2 data contracts (implemented)
|
| 24 |
+
services/ # 4 service stubs: medgemma, gemini, mcp, parlant
|
| 25 |
+
agent/ # Parlant journey logic (not yet implemented)
|
| 26 |
+
tests/ # Backend TDD tests (37+ model, 33 service)
|
| 27 |
+
app/ # Streamlit frontend
|
| 28 |
+
pages/ # 5-page journey (upload β profile β matching β gaps β summary)
|
| 29 |
+
components/ # 6 reusable widgets
|
| 30 |
+
services/ # State manager, parlant client, mock data
|
| 31 |
+
tests/ # Frontend TDD tests (30+ component, 5 page)
|
| 32 |
+
tests/ # Integration tests (18 tests)
|
| 33 |
+
architecture/ # Architecture documentation
|
| 34 |
+
docs/ # Design docs and TDD guides
|
| 35 |
```
|
| 36 |
|
| 37 |
+
## Documents
|
| 38 |
+
|
| 39 |
+
- `docs/Trialpath PRD.md` β Product requirements, success metrics, HAI-DEF submission plan
|
| 40 |
+
- `docs/TrialPath AI technical design.md` β Technical architecture, data contracts, Parlant workflow
|
| 41 |
+
- `docs/tdd-guide-*.md` β TDD implementation guides (backend, frontend, data/eval)
|
| 42 |
+
- `architecture/overview.md` β Architecture overview, data flow, component status
|
| 43 |
+
|
| 44 |
+
## Tech Stack
|
| 45 |
|
| 46 |
+
- Python 3.11+ (Streamlit + Pydantic v2)
|
| 47 |
+
- Google Gemini 3 Pro (orchestration) β stubbed
|
| 48 |
+
- MedGemma 4B via Hugging Face endpoint (multimodal extraction) β stubbed
|
| 49 |
+
- Parlant (agentic workflow engine) β client ready, agent pending
|
| 50 |
+
- ClinicalTrials MCP Server (ClinicalTrials.gov API v2) β client ready
|
|
|
|
| 51 |
|
| 52 |
## Success Targets
|
| 53 |
|
architecture/overview.md
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# TrialPath Architecture Overview
|
| 2 |
+
|
| 3 |
+
## System Architecture
|
| 4 |
+
|
| 5 |
+
TrialPath is composed of 5 core components connected via async HTTP/JSON-RPC:
|
| 6 |
+
|
| 7 |
+
```
|
| 8 |
+
+---------------------+
|
| 9 |
+
| Streamlit UI |
|
| 10 |
+
| (5-page journey) |
|
| 11 |
+
+----------+----------+
|
| 12 |
+
|
|
| 13 |
+
+----------v----------+
|
| 14 |
+
| Parlant Engine |
|
| 15 |
+
| (orchestrator) |
|
| 16 |
+
| patient_trial_copilot|
|
| 17 |
+
+----+-----+-----+----+
|
| 18 |
+
| | |
|
| 19 |
+
+------------+ +--+--+ +------------+
|
| 20 |
+
| | | |
|
| 21 |
+
+--------v-------+ +----v----+ +------v--------+
|
| 22 |
+
| MedGemma 4B | | Gemini | | ClinicalTrials|
|
| 23 |
+
| (HF endpoint) | | 3 Pro | | MCP Server |
|
| 24 |
+
| extraction | | planner | | (API v2) |
|
| 25 |
+
+----------------+ +---------+ +---------------+
|
| 26 |
+
```
|
| 27 |
+
|
| 28 |
+
## Component Details
|
| 29 |
+
|
| 30 |
+
### 1. UI & Orchestrator (Streamlit)
|
| 31 |
+
|
| 32 |
+
- **Location**: `app/`, `streamlit_app.py`
|
| 33 |
+
- **Role**: 5-page patient journey matching Parlant states
|
| 34 |
+
- **Pages**: Upload -> Profile Review -> Trial Matching -> Gap Analysis -> Summary
|
| 35 |
+
- **Components**: `app/components/` (6 reusable widgets)
|
| 36 |
+
- **State**: `app/services/state_manager.py` (session-based journey state)
|
| 37 |
+
|
| 38 |
+
### 2. Parlant Agent + Journey
|
| 39 |
+
|
| 40 |
+
- **Location**: `trialpath/agent/` (not yet implemented)
|
| 41 |
+
- **Role**: Single agent `patient_trial_copilot` with 5 states
|
| 42 |
+
- **States**: `INGEST` -> `PRESCREEN` -> `VALIDATE_TRIALS` -> `GAP_FOLLOWUP` -> `SUMMARY`
|
| 43 |
+
- **Client**: `trialpath/services/parlant_client.py` (async REST wrapper)
|
| 44 |
+
|
| 45 |
+
### 3. MedGemma 4B (Multimodal Extraction)
|
| 46 |
+
|
| 47 |
+
- **Location**: `trialpath/services/medgemma_extractor.py`
|
| 48 |
+
- **Role**: Extract structured `PatientProfile` from PDFs/images
|
| 49 |
+
- **Output**: Demographics, diagnosis, biomarkers, labs, treatments, unknowns + evidence spans
|
| 50 |
+
- **Status**: Prompt templates ready, HF endpoint integration pending
|
| 51 |
+
|
| 52 |
+
### 4. Gemini 3 Pro (LLM Planner)
|
| 53 |
+
|
| 54 |
+
- **Location**: `trialpath/services/gemini_planner.py`
|
| 55 |
+
- **Role**: Generate `SearchAnchors` from profile, evaluate eligibility criteria, produce `EligibilityLedger`
|
| 56 |
+
- **Output**: Structured JSON via google-genai client
|
| 57 |
+
- **Status**: Prompt templates and structured output calls stubbed
|
| 58 |
+
|
| 59 |
+
### 5. ClinicalTrials MCP Server
|
| 60 |
+
|
| 61 |
+
- **Location**: `trialpath/services/mcp_client.py`
|
| 62 |
+
- **Role**: JSON-RPC client wrapping ClinicalTrials.gov REST API v2
|
| 63 |
+
- **Tools**: `search_studies`, `get_study`, `find_eligible`, `compare_studies`
|
| 64 |
+
- **Status**: Client wrapper implemented, needs running MCP server
|
| 65 |
+
|
| 66 |
+
## Data Flow
|
| 67 |
+
|
| 68 |
+
```
|
| 69 |
+
Patient Document (PDF/image)
|
| 70 |
+
|
|
| 71 |
+
v
|
| 72 |
+
[MedGemma 4B] -- extracts --> PatientProfile v1
|
| 73 |
+
|
|
| 74 |
+
v
|
| 75 |
+
[Gemini 3 Pro] -- generates --> SearchAnchors v1
|
| 76 |
+
|
|
| 77 |
+
v
|
| 78 |
+
[MCP Client] -- queries --> ClinicalTrials.gov API
|
| 79 |
+
|
|
| 80 |
+
v
|
| 81 |
+
[MCP Client] -- returns --> TrialCandidate v1 (list)
|
| 82 |
+
|
|
| 83 |
+
v
|
| 84 |
+
[Gemini 3 Pro] -- evaluates --> EligibilityLedger v1 (per trial)
|
| 85 |
+
| (criterion-level verdicts + gaps)
|
| 86 |
+
v
|
| 87 |
+
[UI Summary] -- renders --> Doctor Packet (JSON/Markdown)
|
| 88 |
+
```
|
| 89 |
+
|
| 90 |
+
## Data Contracts
|
| 91 |
+
|
| 92 |
+
All defined as Pydantic v2 models in `trialpath/models/`:
|
| 93 |
+
|
| 94 |
+
| Contract | File | Purpose |
|
| 95 |
+
|----------|------|---------|
|
| 96 |
+
| `PatientProfile` | `patient_profile.py` | MedGemma output: 11 fields (demographics, diagnosis, biomarkers, labs, treatments, unknowns) |
|
| 97 |
+
| `SearchAnchors` | `search_anchors.py` | Gemini-generated query params for MCP search with relaxation order |
|
| 98 |
+
| `TrialCandidate` | `trial_candidate.py` | Normalized MCP search results (NCT ID, phase, locations, eligibility text) |
|
| 99 |
+
| `EligibilityLedger` | `eligibility_ledger.py` | Per-trial criterion assessment with traffic-light status + gap tracking |
|
| 100 |
+
| `SearchLog` | `search_log.py` | Iterative query refinement tracking (max 5 rounds) |
|
| 101 |
+
|
| 102 |
+
## Current Implementation Status
|
| 103 |
+
|
| 104 |
+
| Component | Models | Service Stub | UI | Tests | Live Integration |
|
| 105 |
+
|-----------|--------|-------------|-----|-------|-----------------|
|
| 106 |
+
| Data Models | Done | - | - | 37 tests | - |
|
| 107 |
+
| MedGemma | Done | Prompts ready | Mock | 5 tests | Pending |
|
| 108 |
+
| Gemini | Done | Prompts ready | Mock | 7 tests | Pending |
|
| 109 |
+
| MCP Client | Done | Wrapper done | Mock | 6 tests | Pending |
|
| 110 |
+
| Parlant | Done | Client done | Mock | 15 tests | Pending |
|
| 111 |
+
| Streamlit UI | - | - | 5 pages, 6 components | 30+ tests | Mock data |
|
| 112 |
+
|
| 113 |
+
## Key Design Decisions
|
| 114 |
+
|
| 115 |
+
- **No vector DB / RAG**: Agentic search via ClinicalTrials.gov API with iterative query refinement
|
| 116 |
+
- **Reuse existing MCP**: Off-the-shelf ClinicalTrials MCP server, no custom search
|
| 117 |
+
- **Two-stage screening**: Prescreen (minimal dataset) -> Validation (full criterion-by-criterion)
|
| 118 |
+
- **Evidence-linked**: Every decision cites source doc/page/span
|
| 119 |
+
- **Gap analysis**: "You'd qualify IF you had X" rather than just "No match"
|
trialpath/services/parlant_client.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Async Parlant REST API client for TrialPath backend services."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import Optional
|
| 6 |
+
|
| 7 |
+
import httpx
|
| 8 |
+
|
| 9 |
+
_DEFAULT_BASE_URL = "http://localhost:8800"
|
| 10 |
+
_DEFAULT_TIMEOUT = 65.0 # > long-poll wait_for_data default
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class ParlantClient:
|
| 14 |
+
"""Async wrapper around the Parlant REST API."""
|
| 15 |
+
|
| 16 |
+
def __init__(
|
| 17 |
+
self,
|
| 18 |
+
base_url: str = _DEFAULT_BASE_URL,
|
| 19 |
+
*,
|
| 20 |
+
transport: Optional[httpx.AsyncBaseTransport] = None,
|
| 21 |
+
timeout: float = _DEFAULT_TIMEOUT,
|
| 22 |
+
) -> None:
|
| 23 |
+
self.base_url = base_url
|
| 24 |
+
self._http = httpx.AsyncClient(
|
| 25 |
+
base_url=base_url,
|
| 26 |
+
timeout=timeout,
|
| 27 |
+
transport=transport,
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
# ------ context manager ------
|
| 31 |
+
|
| 32 |
+
async def __aenter__(self) -> ParlantClient:
|
| 33 |
+
return self
|
| 34 |
+
|
| 35 |
+
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
|
| 36 |
+
await self.close()
|
| 37 |
+
|
| 38 |
+
# ------ sessions ------
|
| 39 |
+
|
| 40 |
+
async def create_session(
|
| 41 |
+
self, agent_id: str, customer_id: Optional[str] = None
|
| 42 |
+
) -> str:
|
| 43 |
+
"""Create a Parlant session and return the session_id."""
|
| 44 |
+
payload: dict = {"agent_id": agent_id}
|
| 45 |
+
if customer_id:
|
| 46 |
+
payload["customer_id"] = customer_id
|
| 47 |
+
resp = await self._http.post("/sessions", json=payload)
|
| 48 |
+
resp.raise_for_status()
|
| 49 |
+
return resp.json()["session_id"]
|
| 50 |
+
|
| 51 |
+
async def get_session_status(self, session_id: str) -> dict:
|
| 52 |
+
"""Fetch session metadata."""
|
| 53 |
+
resp = await self._http.get(f"/sessions/{session_id}")
|
| 54 |
+
resp.raise_for_status()
|
| 55 |
+
return resp.json()
|
| 56 |
+
|
| 57 |
+
# ------ events ------
|
| 58 |
+
|
| 59 |
+
async def send_message(self, session_id: str, message: str) -> dict:
|
| 60 |
+
"""Send a customer message event to a session."""
|
| 61 |
+
resp = await self._http.post(
|
| 62 |
+
f"/sessions/{session_id}/events",
|
| 63 |
+
json={"kind": "message", "source": "customer", "message": message},
|
| 64 |
+
)
|
| 65 |
+
resp.raise_for_status()
|
| 66 |
+
return resp.json()
|
| 67 |
+
|
| 68 |
+
async def poll_events(
|
| 69 |
+
self,
|
| 70 |
+
session_id: str,
|
| 71 |
+
min_offset: int = 0,
|
| 72 |
+
wait_seconds: int = 60,
|
| 73 |
+
) -> list[dict]:
|
| 74 |
+
"""Poll for new events starting from *min_offset*."""
|
| 75 |
+
resp = await self._http.get(
|
| 76 |
+
f"/sessions/{session_id}/events",
|
| 77 |
+
params={"min_offset": min_offset, "wait_for_data": wait_seconds},
|
| 78 |
+
)
|
| 79 |
+
resp.raise_for_status()
|
| 80 |
+
return resp.json()
|
| 81 |
+
|
| 82 |
+
# ------ agents ------
|
| 83 |
+
|
| 84 |
+
async def list_agents(self) -> list[dict]:
|
| 85 |
+
"""List available agents."""
|
| 86 |
+
resp = await self._http.get("/agents")
|
| 87 |
+
resp.raise_for_status()
|
| 88 |
+
return resp.json()
|
| 89 |
+
|
| 90 |
+
# ------ lifecycle ------
|
| 91 |
+
|
| 92 |
+
async def close(self) -> None:
|
| 93 |
+
"""Close the underlying HTTP client."""
|
| 94 |
+
await self._http.aclose()
|
trialpath/tests/test_parlant.py
ADDED
|
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for trialpath.services.parlant_client β async Parlant REST wrapper."""
|
| 2 |
+
|
| 3 |
+
import httpx
|
| 4 |
+
import pytest
|
| 5 |
+
|
| 6 |
+
from trialpath.services.parlant_client import ParlantClient
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def _mock_transport(handler):
|
| 10 |
+
"""Create an httpx.MockTransport from a handler function."""
|
| 11 |
+
return httpx.MockTransport(handler)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# ---------- create_session ----------
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class TestCreateSession:
|
| 18 |
+
@pytest.mark.asyncio
|
| 19 |
+
async def test_sends_post_and_returns_session_id(self):
|
| 20 |
+
async def handler(request: httpx.Request) -> httpx.Response:
|
| 21 |
+
assert request.method == "POST"
|
| 22 |
+
assert request.url.path == "/sessions"
|
| 23 |
+
import json
|
| 24 |
+
|
| 25 |
+
body = json.loads(request.content)
|
| 26 |
+
assert body["agent_id"] == "agent-123"
|
| 27 |
+
return httpx.Response(200, json={"session_id": "sess-abc"})
|
| 28 |
+
|
| 29 |
+
client = ParlantClient(
|
| 30 |
+
base_url="http://test", transport=_mock_transport(handler)
|
| 31 |
+
)
|
| 32 |
+
result = await client.create_session("agent-123")
|
| 33 |
+
assert result == "sess-abc"
|
| 34 |
+
await client.close()
|
| 35 |
+
|
| 36 |
+
@pytest.mark.asyncio
|
| 37 |
+
async def test_passes_customer_id_when_provided(self):
|
| 38 |
+
async def handler(request: httpx.Request) -> httpx.Response:
|
| 39 |
+
import json
|
| 40 |
+
|
| 41 |
+
body = json.loads(request.content)
|
| 42 |
+
assert body["customer_id"] == "cust-456"
|
| 43 |
+
return httpx.Response(200, json={"session_id": "sess-def"})
|
| 44 |
+
|
| 45 |
+
client = ParlantClient(
|
| 46 |
+
base_url="http://test", transport=_mock_transport(handler)
|
| 47 |
+
)
|
| 48 |
+
result = await client.create_session("agent-123", customer_id="cust-456")
|
| 49 |
+
assert result == "sess-def"
|
| 50 |
+
await client.close()
|
| 51 |
+
|
| 52 |
+
@pytest.mark.asyncio
|
| 53 |
+
async def test_raises_on_server_error(self):
|
| 54 |
+
async def handler(request: httpx.Request) -> httpx.Response:
|
| 55 |
+
return httpx.Response(500, json={"error": "internal"})
|
| 56 |
+
|
| 57 |
+
client = ParlantClient(
|
| 58 |
+
base_url="http://test", transport=_mock_transport(handler)
|
| 59 |
+
)
|
| 60 |
+
with pytest.raises(httpx.HTTPStatusError):
|
| 61 |
+
await client.create_session("agent-123")
|
| 62 |
+
await client.close()
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
# ---------- send_message ----------
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class TestSendMessage:
|
| 69 |
+
@pytest.mark.asyncio
|
| 70 |
+
async def test_sends_correct_payload(self):
|
| 71 |
+
async def handler(request: httpx.Request) -> httpx.Response:
|
| 72 |
+
assert request.method == "POST"
|
| 73 |
+
assert "/sessions/sess-1/events" in request.url.path
|
| 74 |
+
import json
|
| 75 |
+
|
| 76 |
+
body = json.loads(request.content)
|
| 77 |
+
assert body["kind"] == "message"
|
| 78 |
+
assert body["source"] == "customer"
|
| 79 |
+
assert body["message"] == "Hello"
|
| 80 |
+
return httpx.Response(200, json={"event_id": "evt-1", "offset": 0})
|
| 81 |
+
|
| 82 |
+
client = ParlantClient(
|
| 83 |
+
base_url="http://test", transport=_mock_transport(handler)
|
| 84 |
+
)
|
| 85 |
+
result = await client.send_message("sess-1", "Hello")
|
| 86 |
+
assert result["event_id"] == "evt-1"
|
| 87 |
+
await client.close()
|
| 88 |
+
|
| 89 |
+
@pytest.mark.asyncio
|
| 90 |
+
async def test_raises_on_not_found(self):
|
| 91 |
+
async def handler(request: httpx.Request) -> httpx.Response:
|
| 92 |
+
return httpx.Response(404, json={"error": "session not found"})
|
| 93 |
+
|
| 94 |
+
client = ParlantClient(
|
| 95 |
+
base_url="http://test", transport=_mock_transport(handler)
|
| 96 |
+
)
|
| 97 |
+
with pytest.raises(httpx.HTTPStatusError):
|
| 98 |
+
await client.send_message("bad-sess", "Hello")
|
| 99 |
+
await client.close()
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
# ---------- poll_events ----------
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
class TestPollEvents:
|
| 106 |
+
@pytest.mark.asyncio
|
| 107 |
+
async def test_fetches_with_offset_param(self):
|
| 108 |
+
async def handler(request: httpx.Request) -> httpx.Response:
|
| 109 |
+
assert request.method == "GET"
|
| 110 |
+
assert "/sessions/sess-1/events" in request.url.path
|
| 111 |
+
assert request.url.params["min_offset"] == "5"
|
| 112 |
+
return httpx.Response(
|
| 113 |
+
200,
|
| 114 |
+
json=[
|
| 115 |
+
{"offset": 5, "kind": "message", "source": "ai_agent", "message": "Hi"},
|
| 116 |
+
{"offset": 6, "kind": "status", "data": "ready"},
|
| 117 |
+
],
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
client = ParlantClient(
|
| 121 |
+
base_url="http://test", transport=_mock_transport(handler)
|
| 122 |
+
)
|
| 123 |
+
events = await client.poll_events("sess-1", min_offset=5)
|
| 124 |
+
assert len(events) == 2
|
| 125 |
+
assert events[0]["offset"] == 5
|
| 126 |
+
await client.close()
|
| 127 |
+
|
| 128 |
+
@pytest.mark.asyncio
|
| 129 |
+
async def test_default_offset_is_zero(self):
|
| 130 |
+
async def handler(request: httpx.Request) -> httpx.Response:
|
| 131 |
+
assert request.url.params["min_offset"] == "0"
|
| 132 |
+
return httpx.Response(200, json=[])
|
| 133 |
+
|
| 134 |
+
client = ParlantClient(
|
| 135 |
+
base_url="http://test", transport=_mock_transport(handler)
|
| 136 |
+
)
|
| 137 |
+
events = await client.poll_events("sess-1")
|
| 138 |
+
assert events == []
|
| 139 |
+
await client.close()
|
| 140 |
+
|
| 141 |
+
@pytest.mark.asyncio
|
| 142 |
+
async def test_passes_wait_seconds(self):
|
| 143 |
+
async def handler(request: httpx.Request) -> httpx.Response:
|
| 144 |
+
assert request.url.params["wait_for_data"] == "30"
|
| 145 |
+
return httpx.Response(200, json=[])
|
| 146 |
+
|
| 147 |
+
client = ParlantClient(
|
| 148 |
+
base_url="http://test", transport=_mock_transport(handler)
|
| 149 |
+
)
|
| 150 |
+
await client.poll_events("sess-1", wait_seconds=30)
|
| 151 |
+
await client.close()
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
# ---------- get_session_status ----------
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
class TestGetSessionStatus:
|
| 158 |
+
@pytest.mark.asyncio
|
| 159 |
+
async def test_returns_session_info(self):
|
| 160 |
+
async def handler(request: httpx.Request) -> httpx.Response:
|
| 161 |
+
assert request.method == "GET"
|
| 162 |
+
assert request.url.path == "/sessions/sess-1"
|
| 163 |
+
return httpx.Response(
|
| 164 |
+
200,
|
| 165 |
+
json={
|
| 166 |
+
"session_id": "sess-1",
|
| 167 |
+
"agent_id": "agent-1",
|
| 168 |
+
"status": "active",
|
| 169 |
+
},
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
+
client = ParlantClient(
|
| 173 |
+
base_url="http://test", transport=_mock_transport(handler)
|
| 174 |
+
)
|
| 175 |
+
status = await client.get_session_status("sess-1")
|
| 176 |
+
assert status["session_id"] == "sess-1"
|
| 177 |
+
assert status["status"] == "active"
|
| 178 |
+
await client.close()
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
# ---------- list_agents ----------
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
class TestListAgents:
|
| 185 |
+
@pytest.mark.asyncio
|
| 186 |
+
async def test_returns_agents_list(self):
|
| 187 |
+
async def handler(request: httpx.Request) -> httpx.Response:
|
| 188 |
+
assert request.method == "GET"
|
| 189 |
+
assert request.url.path == "/agents"
|
| 190 |
+
return httpx.Response(
|
| 191 |
+
200,
|
| 192 |
+
json=[
|
| 193 |
+
{"agent_id": "agent-1", "name": "patient_trial_copilot"},
|
| 194 |
+
],
|
| 195 |
+
)
|
| 196 |
+
|
| 197 |
+
client = ParlantClient(
|
| 198 |
+
base_url="http://test", transport=_mock_transport(handler)
|
| 199 |
+
)
|
| 200 |
+
agents = await client.list_agents()
|
| 201 |
+
assert len(agents) == 1
|
| 202 |
+
assert agents[0]["agent_id"] == "agent-1"
|
| 203 |
+
await client.close()
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
# ---------- lifecycle ----------
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
class TestLifecycle:
|
| 210 |
+
@pytest.mark.asyncio
|
| 211 |
+
async def test_client_close(self):
|
| 212 |
+
async def handler(request: httpx.Request) -> httpx.Response:
|
| 213 |
+
return httpx.Response(200, json={})
|
| 214 |
+
|
| 215 |
+
client = ParlantClient(
|
| 216 |
+
base_url="http://test", transport=_mock_transport(handler)
|
| 217 |
+
)
|
| 218 |
+
await client.close()
|
| 219 |
+
assert client._http.is_closed
|
| 220 |
+
|
| 221 |
+
@pytest.mark.asyncio
|
| 222 |
+
async def test_async_context_manager(self):
|
| 223 |
+
async def handler(request: httpx.Request) -> httpx.Response:
|
| 224 |
+
return httpx.Response(200, json={"session_id": "sess-ctx"})
|
| 225 |
+
|
| 226 |
+
transport = _mock_transport(handler)
|
| 227 |
+
async with ParlantClient(base_url="http://test", transport=transport) as client:
|
| 228 |
+
result = await client.create_session("agent-1")
|
| 229 |
+
assert result == "sess-ctx"
|
| 230 |
+
assert client._http.is_closed
|