yakilee Claude Opus 4.6 commited on
Commit
ec9e535
·
1 Parent(s): 1abff4e

feat: implement 5 Pydantic v2 data models with 37 TDD tests

Browse files

- PatientProfile with 10 sub-models + has_minimum_prescreen_data()
- SearchAnchors with GeographyFilter, TrialFilters, relaxation_order
- TrialCandidate with TrialLocation, AgeRange, EligibilityText
- EligibilityLedger with CriterionAssessment, traffic_light, gap analysis
- SearchLog with SearchStep, RefinementAction, transparency_summary

All 37 tests pass. Ruff clean.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

trialpath/models/__init__.py CHANGED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """TrialPath data models -- Pydantic v2 data contracts."""
2
+ from trialpath.models.eligibility_ledger import (
3
+ CriterionAssessment,
4
+ CriterionDecision,
5
+ EligibilityLedger,
6
+ GapItem,
7
+ OverallAssessment,
8
+ TemporalCheck,
9
+ TrialEvidencePointer,
10
+ )
11
+ from trialpath.models.patient_profile import (
12
+ Biomarker,
13
+ Comorbidity,
14
+ Demographics,
15
+ Diagnosis,
16
+ EvidencePointer,
17
+ ImagingSummary,
18
+ LabResult,
19
+ PatientProfile,
20
+ PerformanceStatus,
21
+ SourceDocument,
22
+ Treatment,
23
+ UnknownField,
24
+ )
25
+ from trialpath.models.search_anchors import (
26
+ GeographyFilter,
27
+ SearchAnchors,
28
+ TrialFilters,
29
+ )
30
+ from trialpath.models.search_log import (
31
+ RefinementAction,
32
+ SearchLog,
33
+ SearchStep,
34
+ )
35
+ from trialpath.models.trial_candidate import (
36
+ AgeRange,
37
+ EligibilityText,
38
+ TrialCandidate,
39
+ TrialLocation,
40
+ )
41
+
42
+ __all__ = [
43
+ "AgeRange",
44
+ "Biomarker",
45
+ "Comorbidity",
46
+ "CriterionAssessment",
47
+ "CriterionDecision",
48
+ "Demographics",
49
+ "Diagnosis",
50
+ "EligibilityLedger",
51
+ "EligibilityText",
52
+ "EvidencePointer",
53
+ "GapItem",
54
+ "GeographyFilter",
55
+ "ImagingSummary",
56
+ "LabResult",
57
+ "OverallAssessment",
58
+ "PatientProfile",
59
+ "PerformanceStatus",
60
+ "RefinementAction",
61
+ "SearchAnchors",
62
+ "SearchLog",
63
+ "SearchStep",
64
+ "SourceDocument",
65
+ "TemporalCheck",
66
+ "Treatment",
67
+ "TrialCandidate",
68
+ "TrialEvidencePointer",
69
+ "TrialFilters",
70
+ "TrialLocation",
71
+ "UnknownField",
72
+ ]
trialpath/models/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (1.39 kB). View file
 
trialpath/models/__pycache__/eligibility_ledger.cpython-313.pyc ADDED
Binary file (6.11 kB). View file
 
trialpath/models/__pycache__/patient_profile.cpython-313.pyc ADDED
Binary file (6.9 kB). View file
 
trialpath/models/__pycache__/search_anchors.cpython-313.pyc ADDED
Binary file (2.16 kB). View file
 
trialpath/models/__pycache__/search_log.cpython-313.pyc ADDED
Binary file (4 kB). View file
 
trialpath/models/__pycache__/trial_candidate.cpython-313.pyc ADDED
Binary file (2.06 kB). View file
 
trialpath/models/eligibility_ledger.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """EligibilityLedger v1 -- Per-trial criterion-level eligibility assessment."""
2
+ from datetime import date
3
+ from enum import Enum
4
+ from typing import Optional
5
+
6
+ from pydantic import BaseModel, Field
7
+
8
+ from trialpath.models.patient_profile import EvidencePointer
9
+
10
+
11
+ class CriterionDecision(str, Enum):
12
+ MET = "met"
13
+ NOT_MET = "not_met"
14
+ UNKNOWN = "unknown"
15
+
16
+
17
+ class OverallAssessment(str, Enum):
18
+ LIKELY_ELIGIBLE = "likely_eligible"
19
+ LIKELY_INELIGIBLE = "likely_ineligible"
20
+ UNCERTAIN = "uncertain"
21
+
22
+
23
+ class TrialEvidencePointer(BaseModel):
24
+ field: str = Field(description="e.g. 'eligibility_text.inclusion'")
25
+ offset_start: int
26
+ offset_end: int
27
+
28
+
29
+ class TemporalCheck(BaseModel):
30
+ """Validates whether patient evidence falls within a required time window."""
31
+ required_window_days: Optional[int] = Field(
32
+ None, description="e.g. 14 for 'within 14 days'"
33
+ )
34
+ reference_date: Optional[date] = Field(
35
+ None, description="Date of the patient evidence"
36
+ )
37
+ evaluation_date: Optional[date] = Field(default_factory=date.today)
38
+ is_within_window: Optional[bool] = None
39
+
40
+ @property
41
+ def days_elapsed(self) -> Optional[int]:
42
+ if self.reference_date and self.evaluation_date:
43
+ return (self.evaluation_date - self.reference_date).days
44
+ return None
45
+
46
+
47
+ class CriterionAssessment(BaseModel):
48
+ criterion_id: str = Field(description="e.g. 'inc_1', 'exc_3'")
49
+ type: str = Field(description="'inclusion' or 'exclusion'")
50
+ text: str = Field(description="Original criterion text from trial")
51
+ decision: CriterionDecision
52
+ patient_evidence: list[EvidencePointer] = Field(default_factory=list)
53
+ trial_evidence: list[TrialEvidencePointer] = Field(default_factory=list)
54
+ temporal_check: Optional[TemporalCheck] = None
55
+
56
+
57
+ class GapItem(BaseModel):
58
+ description: str
59
+ recommended_action: str
60
+ clinical_importance: str = Field(description="high|medium|low")
61
+
62
+
63
+ class EligibilityLedger(BaseModel):
64
+ patient_id: str
65
+ nct_id: str
66
+ overall_assessment: OverallAssessment
67
+ criteria: list[CriterionAssessment] = Field(default_factory=list)
68
+ gaps: list[GapItem] = Field(default_factory=list)
69
+
70
+ @property
71
+ def met_count(self) -> int:
72
+ return sum(1 for c in self.criteria if c.decision == CriterionDecision.MET)
73
+
74
+ @property
75
+ def not_met_count(self) -> int:
76
+ return sum(1 for c in self.criteria if c.decision == CriterionDecision.NOT_MET)
77
+
78
+ @property
79
+ def unknown_count(self) -> int:
80
+ return sum(1 for c in self.criteria if c.decision == CriterionDecision.UNKNOWN)
81
+
82
+ @property
83
+ def traffic_light(self) -> str:
84
+ """Return traffic light color for UI display."""
85
+ if self.overall_assessment == OverallAssessment.LIKELY_ELIGIBLE:
86
+ return "green"
87
+ elif self.overall_assessment == OverallAssessment.UNCERTAIN:
88
+ return "yellow"
89
+ return "red"
trialpath/models/patient_profile.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """PatientProfile v1 -- MedGemma extraction output for NSCLC patients."""
2
+ import datetime
3
+ from typing import Optional
4
+
5
+ from pydantic import BaseModel, Field
6
+
7
+
8
+ class EvidencePointer(BaseModel):
9
+ doc_id: str = Field(description="Source document identifier")
10
+ page: Optional[int] = Field(default=None, description="Page number")
11
+ span_id: Optional[str] = Field(default=None, description="Text span identifier")
12
+
13
+
14
+ class SourceDocument(BaseModel):
15
+ doc_id: str
16
+ type: str = Field(description="clinic_letter|pathology|lab|imaging")
17
+ meta: dict = Field(default_factory=dict)
18
+
19
+
20
+ class Demographics(BaseModel):
21
+ age: Optional[int] = None
22
+ sex: Optional[str] = None
23
+
24
+
25
+ class Diagnosis(BaseModel):
26
+ primary_condition: str = Field(description="e.g. 'Non-Small Cell Lung Cancer'")
27
+ histology: Optional[str] = Field(default=None, description="e.g. 'adenocarcinoma'")
28
+ stage: Optional[str] = Field(default=None, description="e.g. 'IVa'")
29
+ diagnosis_date: Optional[datetime.date] = None
30
+
31
+
32
+ class PerformanceStatus(BaseModel):
33
+ scale: str = Field(description="'ECOG' or 'KPS'")
34
+ value: int
35
+ evidence: list[EvidencePointer] = Field(default_factory=list)
36
+
37
+
38
+ class Biomarker(BaseModel):
39
+ name: str = Field(description="e.g. 'EGFR', 'ALK', 'PD-L1'")
40
+ result: str = Field(description="e.g. 'Exon 19 deletion', 'Positive 80%'")
41
+ date: Optional[datetime.date] = None
42
+ evidence: list[EvidencePointer] = Field(default_factory=list)
43
+
44
+
45
+ class LabResult(BaseModel):
46
+ name: str = Field(description="e.g. 'ANC', 'Creatinine'")
47
+ value: float
48
+ unit: str
49
+ date: Optional[datetime.date] = None
50
+ evidence: list[EvidencePointer] = Field(default_factory=list)
51
+
52
+
53
+ class Treatment(BaseModel):
54
+ drug_name: str
55
+ start_date: Optional[datetime.date] = None
56
+ end_date: Optional[datetime.date] = None
57
+ line: Optional[int] = Field(default=None, description="Line of therapy (1, 2, 3...)")
58
+ evidence: list[EvidencePointer] = Field(default_factory=list)
59
+
60
+
61
+ class Comorbidity(BaseModel):
62
+ name: str
63
+ grade: Optional[str] = None
64
+ evidence: list[EvidencePointer] = Field(default_factory=list)
65
+
66
+
67
+ class ImagingSummary(BaseModel):
68
+ modality: str = Field(description="e.g. 'MRI brain', 'CT chest'")
69
+ date: Optional[datetime.date] = None
70
+ finding: str
71
+ interpretation: Optional[str] = None
72
+ certainty: Optional[str] = Field(default=None, description="low|medium|high")
73
+ evidence: list[EvidencePointer] = Field(default_factory=list)
74
+
75
+
76
+ class UnknownField(BaseModel):
77
+ field: str = Field(description="Name of missing field")
78
+ reason: str = Field(description="Why it is unknown")
79
+ importance: str = Field(description="high|medium|low")
80
+
81
+
82
+ class PatientProfile(BaseModel):
83
+ patient_id: str
84
+ source_docs: list[SourceDocument] = Field(default_factory=list)
85
+ demographics: Demographics = Field(default_factory=Demographics)
86
+ diagnosis: Optional[Diagnosis] = None
87
+ performance_status: Optional[PerformanceStatus] = None
88
+ biomarkers: list[Biomarker] = Field(default_factory=list)
89
+ key_labs: list[LabResult] = Field(default_factory=list)
90
+ treatments: list[Treatment] = Field(default_factory=list)
91
+ comorbidities: list[Comorbidity] = Field(default_factory=list)
92
+ imaging_summary: list[ImagingSummary] = Field(default_factory=list)
93
+ unknowns: list[UnknownField] = Field(default_factory=list)
94
+
95
+ def has_minimum_prescreen_data(self) -> bool:
96
+ """Check if profile has enough data for prescreening."""
97
+ return (
98
+ self.diagnosis is not None
99
+ and self.diagnosis.stage is not None
100
+ and self.performance_status is not None
101
+ )
trialpath/models/search_anchors.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SearchAnchors v1 -- Gemini-generated query parameters for ClinicalTrials MCP search."""
2
+ from __future__ import annotations
3
+
4
+ from typing import Optional
5
+
6
+ from pydantic import BaseModel, Field
7
+
8
+
9
+ class GeographyFilter(BaseModel):
10
+ country: str = Field(description="ISO country code or full name")
11
+ max_distance_km: Optional[int] = None
12
+
13
+
14
+ class TrialFilters(BaseModel):
15
+ recruitment_status: list[str] = Field(
16
+ default=["Recruiting", "Not yet recruiting"]
17
+ )
18
+ phase: list[str] = Field(default=["Phase 2", "Phase 3"])
19
+
20
+
21
+ class SearchAnchors(BaseModel):
22
+ condition: str = Field(description="Primary condition for search")
23
+ subtype: Optional[str] = Field(default=None, description="Cancer subtype")
24
+ biomarkers: list[str] = Field(default_factory=list)
25
+ stage: Optional[str] = None
26
+ geography: Optional[GeographyFilter] = None
27
+ age: Optional[int] = None
28
+ performance_status_max: Optional[int] = None
29
+ trial_filters: TrialFilters = Field(default_factory=TrialFilters)
30
+ relaxation_order: list[str] = Field(
31
+ default=["phase", "distance", "biomarker_strictness"],
32
+ description="Order in which to relax criteria if too few results",
33
+ )
trialpath/models/search_log.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SearchLog v1 -- Iterative query refinement tracking."""
2
+ from datetime import datetime, timezone
3
+ from enum import Enum
4
+
5
+ from pydantic import BaseModel, Field
6
+
7
+
8
+ class RefinementAction(str, Enum):
9
+ INITIAL = "initial"
10
+ REFINE = "refine"
11
+ RELAX = "relax"
12
+ SHORTLIST = "shortlist"
13
+ ABORT = "abort"
14
+
15
+
16
+ class SearchStep(BaseModel):
17
+ step_number: int
18
+ query_params: dict = Field(description="SearchAnchors snapshot used for this query")
19
+ result_count: int
20
+ action_taken: RefinementAction
21
+ action_reason: str = Field(description="Human-readable why this action was chosen")
22
+ timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
23
+ nct_ids_sample: list[str] = Field(
24
+ default_factory=list,
25
+ description="Sample of NCT IDs returned (up to 10 for transparency)",
26
+ )
27
+
28
+
29
+ class SearchLog(BaseModel):
30
+ session_id: str
31
+ patient_id: str
32
+ steps: list[SearchStep] = Field(default_factory=list)
33
+ final_shortlist_nct_ids: list[str] = Field(default_factory=list)
34
+ total_refinement_rounds: int = 0
35
+ max_refinement_rounds: int = Field(
36
+ default=5, description="Safety cap to prevent infinite loops"
37
+ )
38
+
39
+ @property
40
+ def is_refinement_exhausted(self) -> bool:
41
+ return self.total_refinement_rounds >= self.max_refinement_rounds
42
+
43
+ def add_step(
44
+ self,
45
+ query_params: dict,
46
+ result_count: int,
47
+ action: RefinementAction,
48
+ reason: str,
49
+ nct_ids_sample: list[str] | None = None,
50
+ ) -> None:
51
+ step = SearchStep(
52
+ step_number=len(self.steps) + 1,
53
+ query_params=query_params,
54
+ result_count=result_count,
55
+ action_taken=action,
56
+ action_reason=reason,
57
+ nct_ids_sample=nct_ids_sample or [],
58
+ )
59
+ self.steps.append(step)
60
+ if action in (RefinementAction.REFINE, RefinementAction.RELAX):
61
+ self.total_refinement_rounds += 1
62
+
63
+ def to_transparency_summary(self) -> list[dict]:
64
+ """Generate human-readable search process for FE display."""
65
+ return [
66
+ {
67
+ "step": s.step_number,
68
+ "query": s.query_params,
69
+ "found": s.result_count,
70
+ "action": s.action_taken.value,
71
+ "reason": s.action_reason,
72
+ }
73
+ for s in self.steps
74
+ ]
trialpath/models/trial_candidate.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """TrialCandidate v1 -- Normalized ClinicalTrials MCP search results."""
2
+ from __future__ import annotations
3
+
4
+ from typing import Optional
5
+
6
+ from pydantic import BaseModel, Field
7
+
8
+
9
+ class TrialLocation(BaseModel):
10
+ country: str
11
+ city: Optional[str] = None
12
+
13
+
14
+ class AgeRange(BaseModel):
15
+ min: Optional[int] = None
16
+ max: Optional[int] = None
17
+
18
+
19
+ class EligibilityText(BaseModel):
20
+ inclusion: str
21
+ exclusion: str
22
+
23
+
24
+ class TrialCandidate(BaseModel):
25
+ nct_id: str = Field(description="NCT identifier e.g. 'NCT01234567'")
26
+ title: str
27
+ conditions: list[str] = Field(default_factory=list)
28
+ phase: Optional[str] = None
29
+ status: Optional[str] = None
30
+ locations: list[TrialLocation] = Field(default_factory=list)
31
+ age_range: Optional[AgeRange] = None
32
+ fingerprint_text: str = Field(description="Short text for Gemini reranking")
33
+ eligibility_text: Optional[EligibilityText] = None
trialpath/tests/test_models.py ADDED
@@ -0,0 +1,616 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """TDD tests for TrialPath data models (RED phase — write tests first)."""
2
+ from __future__ import annotations
3
+
4
+ from datetime import date
5
+
6
+
7
+ class TestPatientProfile:
8
+ """PatientProfile v1 validation and helper tests."""
9
+
10
+ def test_minimal_valid_profile(self):
11
+ """A profile with only patient_id should be valid."""
12
+ from trialpath.models.patient_profile import PatientProfile
13
+
14
+ profile = PatientProfile(patient_id="P001")
15
+ assert profile.patient_id == "P001"
16
+ assert profile.unknowns == []
17
+
18
+ def test_complete_nsclc_profile(self):
19
+ """Full NSCLC patient profile should serialize/deserialize correctly."""
20
+ from trialpath.models.patient_profile import (
21
+ Biomarker,
22
+ Demographics,
23
+ Diagnosis,
24
+ EvidencePointer,
25
+ PatientProfile,
26
+ PerformanceStatus,
27
+ UnknownField,
28
+ )
29
+
30
+ profile = PatientProfile(
31
+ patient_id="P001",
32
+ demographics=Demographics(age=52, sex="female"),
33
+ diagnosis=Diagnosis(
34
+ primary_condition="Non-Small Cell Lung Cancer",
35
+ histology="adenocarcinoma",
36
+ stage="IVa",
37
+ diagnosis_date=date(2025, 11, 15),
38
+ ),
39
+ performance_status=PerformanceStatus(
40
+ scale="ECOG", value=1,
41
+ evidence=[EvidencePointer(doc_id="clinic_1", page=2, span_id="s_17")],
42
+ ),
43
+ biomarkers=[
44
+ Biomarker(
45
+ name="EGFR", result="Exon 19 deletion",
46
+ date=date(2026, 1, 10),
47
+ evidence=[EvidencePointer(doc_id="path_egfr", page=1, span_id="s_3")],
48
+ ),
49
+ ],
50
+ unknowns=[
51
+ UnknownField(field="PD-L1", reason="Not found in documents", importance="medium"),
52
+ ],
53
+ )
54
+
55
+ data = profile.model_dump()
56
+ restored = PatientProfile.model_validate(data)
57
+ assert restored.patient_id == "P001"
58
+ assert restored.diagnosis.stage == "IVa"
59
+ assert len(restored.biomarkers) == 1
60
+ assert restored.biomarkers[0].name == "EGFR"
61
+
62
+ def test_has_minimum_prescreen_data_true(self):
63
+ """Profile with diagnosis + stage + ECOG satisfies prescreen requirements."""
64
+ from trialpath.models.patient_profile import (
65
+ Diagnosis,
66
+ PatientProfile,
67
+ PerformanceStatus,
68
+ )
69
+
70
+ profile = PatientProfile(
71
+ patient_id="P001",
72
+ diagnosis=Diagnosis(
73
+ primary_condition="NSCLC", stage="IV",
74
+ ),
75
+ performance_status=PerformanceStatus(scale="ECOG", value=1),
76
+ )
77
+ assert profile.has_minimum_prescreen_data() is True
78
+
79
+ def test_has_minimum_prescreen_data_false_no_stage(self):
80
+ """Profile without stage should fail prescreen check."""
81
+ from trialpath.models.patient_profile import (
82
+ Diagnosis,
83
+ PatientProfile,
84
+ PerformanceStatus,
85
+ )
86
+
87
+ profile = PatientProfile(
88
+ patient_id="P001",
89
+ diagnosis=Diagnosis(primary_condition="NSCLC"),
90
+ performance_status=PerformanceStatus(scale="ECOG", value=1),
91
+ )
92
+ assert profile.has_minimum_prescreen_data() is False
93
+
94
+ def test_has_minimum_prescreen_data_false_no_ecog(self):
95
+ """Profile without performance status should fail prescreen check."""
96
+ from trialpath.models.patient_profile import (
97
+ Diagnosis,
98
+ PatientProfile,
99
+ )
100
+
101
+ profile = PatientProfile(
102
+ patient_id="P001",
103
+ diagnosis=Diagnosis(primary_condition="NSCLC", stage="IV"),
104
+ )
105
+ assert profile.has_minimum_prescreen_data() is False
106
+
107
+ def test_json_roundtrip(self):
108
+ """Profile should survive JSON serialization roundtrip."""
109
+ from trialpath.models.patient_profile import (
110
+ Demographics,
111
+ Diagnosis,
112
+ PatientProfile,
113
+ )
114
+
115
+ profile = PatientProfile(
116
+ patient_id="P001",
117
+ demographics=Demographics(age=65, sex="male"),
118
+ diagnosis=Diagnosis(
119
+ primary_condition="NSCLC",
120
+ histology="squamous",
121
+ stage="IIIb",
122
+ ),
123
+ )
124
+ json_str = profile.model_dump_json()
125
+ restored = PatientProfile.model_validate_json(json_str)
126
+ assert restored == profile
127
+
128
+ def test_source_docs_default_empty(self):
129
+ """source_docs should default to empty list."""
130
+ from trialpath.models.patient_profile import PatientProfile
131
+
132
+ profile = PatientProfile(patient_id="P001")
133
+ assert profile.source_docs == []
134
+
135
+ def test_source_doc_creation(self):
136
+ """SourceDocument with all fields."""
137
+ from trialpath.models.patient_profile import PatientProfile, SourceDocument
138
+
139
+ profile = PatientProfile(
140
+ patient_id="P001",
141
+ source_docs=[
142
+ SourceDocument(doc_id="doc1", type="pathology", meta={"pages": 3}),
143
+ ],
144
+ )
145
+ assert len(profile.source_docs) == 1
146
+ assert profile.source_docs[0].type == "pathology"
147
+
148
+ def test_lab_result(self):
149
+ """LabResult with value, unit, date, and evidence."""
150
+ from trialpath.models.patient_profile import (
151
+ EvidencePointer,
152
+ LabResult,
153
+ PatientProfile,
154
+ )
155
+
156
+ profile = PatientProfile(
157
+ patient_id="P001",
158
+ key_labs=[
159
+ LabResult(
160
+ name="ANC", value=1.8, unit="10^9/L",
161
+ date=date(2026, 1, 28),
162
+ evidence=[EvidencePointer(doc_id="labs_jan", page=1, span_id="tbl_anc")],
163
+ ),
164
+ ],
165
+ )
166
+ assert profile.key_labs[0].value == 1.8
167
+ assert profile.key_labs[0].unit == "10^9/L"
168
+
169
+ def test_treatment(self):
170
+ """Treatment with drug_name, dates, and line of therapy."""
171
+ from trialpath.models.patient_profile import PatientProfile, Treatment
172
+
173
+ profile = PatientProfile(
174
+ patient_id="P001",
175
+ treatments=[
176
+ Treatment(
177
+ drug_name="Pembrolizumab",
178
+ start_date=date(2024, 6, 1),
179
+ end_date=date(2024, 11, 30),
180
+ line=1,
181
+ ),
182
+ ],
183
+ )
184
+ assert profile.treatments[0].drug_name == "Pembrolizumab"
185
+ assert profile.treatments[0].line == 1
186
+
187
+ def test_comorbidity(self):
188
+ """Comorbidity with name and grade."""
189
+ from trialpath.models.patient_profile import Comorbidity, PatientProfile
190
+
191
+ profile = PatientProfile(
192
+ patient_id="P001",
193
+ comorbidities=[
194
+ Comorbidity(name="CKD", grade="Stage 3"),
195
+ ],
196
+ )
197
+ assert profile.comorbidities[0].name == "CKD"
198
+
199
+ def test_imaging_summary(self):
200
+ """ImagingSummary with modality, finding, interpretation, certainty."""
201
+ from trialpath.models.patient_profile import ImagingSummary, PatientProfile
202
+
203
+ profile = PatientProfile(
204
+ patient_id="P001",
205
+ imaging_summary=[
206
+ ImagingSummary(
207
+ modality="MRI brain",
208
+ date=date(2026, 1, 20),
209
+ finding="Stable 3mm left frontal lesion",
210
+ interpretation="likely inactive scar",
211
+ certainty="low",
212
+ ),
213
+ ],
214
+ )
215
+ assert profile.imaging_summary[0].modality == "MRI brain"
216
+ assert profile.imaging_summary[0].certainty == "low"
217
+
218
+
219
+ class TestSearchAnchors:
220
+ """SearchAnchors v1 validation tests."""
221
+
222
+ def test_minimal_anchors(self):
223
+ from trialpath.models.search_anchors import SearchAnchors
224
+
225
+ anchors = SearchAnchors(condition="NSCLC")
226
+ assert anchors.condition == "NSCLC"
227
+ assert anchors.trial_filters.recruitment_status == ["Recruiting", "Not yet recruiting"]
228
+
229
+ def test_full_anchors(self):
230
+ from trialpath.models.search_anchors import SearchAnchors, TrialFilters
231
+
232
+ anchors = SearchAnchors(
233
+ condition="Non-Small Cell Lung Cancer",
234
+ subtype="adenocarcinoma",
235
+ biomarkers=["EGFR exon 19 deletion"],
236
+ stage="IV",
237
+ age=52,
238
+ performance_status_max=1,
239
+ trial_filters=TrialFilters(
240
+ recruitment_status=["Recruiting"],
241
+ phase=["Phase 3"],
242
+ ),
243
+ relaxation_order=["phase", "distance"],
244
+ )
245
+ assert len(anchors.biomarkers) == 1
246
+ assert anchors.trial_filters.phase == ["Phase 3"]
247
+
248
+ def test_default_relaxation_order(self):
249
+ from trialpath.models.search_anchors import SearchAnchors
250
+
251
+ anchors = SearchAnchors(condition="NSCLC")
252
+ assert anchors.relaxation_order == ["phase", "distance", "biomarker_strictness"]
253
+
254
+ def test_default_trial_filters(self):
255
+ from trialpath.models.search_anchors import SearchAnchors
256
+
257
+ anchors = SearchAnchors(condition="NSCLC")
258
+ assert anchors.trial_filters.phase == ["Phase 2", "Phase 3"]
259
+
260
+ def test_geography_filter(self):
261
+ from trialpath.models.search_anchors import GeographyFilter, SearchAnchors
262
+
263
+ anchors = SearchAnchors(
264
+ condition="NSCLC",
265
+ geography=GeographyFilter(country="DE", max_distance_km=200),
266
+ )
267
+ assert anchors.geography.country == "DE"
268
+ assert anchors.geography.max_distance_km == 200
269
+
270
+ def test_json_roundtrip(self):
271
+ from trialpath.models.search_anchors import SearchAnchors
272
+
273
+ anchors = SearchAnchors(
274
+ condition="NSCLC", stage="IV", age=55,
275
+ )
276
+ json_str = anchors.model_dump_json()
277
+ restored = SearchAnchors.model_validate_json(json_str)
278
+ assert restored == anchors
279
+
280
+
281
+ class TestTrialCandidate:
282
+ """TrialCandidate v1 tests."""
283
+
284
+ def test_trial_with_eligibility_text(self):
285
+ from trialpath.models.trial_candidate import EligibilityText, TrialCandidate
286
+
287
+ trial = TrialCandidate(
288
+ nct_id="NCT01234567",
289
+ title="Phase 3 Study of Osimertinib",
290
+ conditions=["NSCLC"],
291
+ phase="Phase 3",
292
+ status="Recruiting",
293
+ fingerprint_text="Osimertinib EGFR+ NSCLC Phase 3",
294
+ eligibility_text=EligibilityText(
295
+ inclusion="Histologically confirmed NSCLC stage IV",
296
+ exclusion="Prior EGFR TKI therapy",
297
+ ),
298
+ )
299
+ assert trial.nct_id == "NCT01234567"
300
+ assert trial.eligibility_text.inclusion.startswith("Histologically")
301
+
302
+ def test_minimal_trial(self):
303
+ from trialpath.models.trial_candidate import TrialCandidate
304
+
305
+ trial = TrialCandidate(
306
+ nct_id="NCT99999999",
307
+ title="Test Trial",
308
+ fingerprint_text="test",
309
+ )
310
+ assert trial.conditions == []
311
+ assert trial.locations == []
312
+ assert trial.eligibility_text is None
313
+
314
+ def test_trial_with_locations(self):
315
+ from trialpath.models.trial_candidate import TrialCandidate, TrialLocation
316
+
317
+ trial = TrialCandidate(
318
+ nct_id="NCT01234567",
319
+ title="Test Trial",
320
+ fingerprint_text="test",
321
+ locations=[
322
+ TrialLocation(country="DE", city="Berlin"),
323
+ TrialLocation(country="DE", city="Hamburg"),
324
+ ],
325
+ )
326
+ assert len(trial.locations) == 2
327
+ assert trial.locations[0].city == "Berlin"
328
+
329
+ def test_trial_with_age_range(self):
330
+ from trialpath.models.trial_candidate import AgeRange, TrialCandidate
331
+
332
+ trial = TrialCandidate(
333
+ nct_id="NCT01234567",
334
+ title="Test Trial",
335
+ fingerprint_text="test",
336
+ age_range=AgeRange(min=18, max=75),
337
+ )
338
+ assert trial.age_range.min == 18
339
+ assert trial.age_range.max == 75
340
+
341
+ def test_json_roundtrip(self):
342
+ from trialpath.models.trial_candidate import TrialCandidate
343
+
344
+ trial = TrialCandidate(
345
+ nct_id="NCT01234567",
346
+ title="Test",
347
+ fingerprint_text="test fp",
348
+ phase="Phase 2",
349
+ )
350
+ json_str = trial.model_dump_json()
351
+ restored = TrialCandidate.model_validate_json(json_str)
352
+ assert restored == trial
353
+
354
+
355
+ class TestEligibilityLedger:
356
+ """EligibilityLedger v1 tests."""
357
+
358
+ def test_traffic_light_green(self):
359
+ from trialpath.models.eligibility_ledger import (
360
+ EligibilityLedger,
361
+ OverallAssessment,
362
+ )
363
+
364
+ ledger = EligibilityLedger(
365
+ patient_id="P001",
366
+ nct_id="NCT01234567",
367
+ overall_assessment=OverallAssessment.LIKELY_ELIGIBLE,
368
+ )
369
+ assert ledger.traffic_light == "green"
370
+
371
+ def test_traffic_light_yellow(self):
372
+ from trialpath.models.eligibility_ledger import (
373
+ EligibilityLedger,
374
+ OverallAssessment,
375
+ )
376
+
377
+ ledger = EligibilityLedger(
378
+ patient_id="P001",
379
+ nct_id="NCT01234567",
380
+ overall_assessment=OverallAssessment.UNCERTAIN,
381
+ )
382
+ assert ledger.traffic_light == "yellow"
383
+
384
+ def test_traffic_light_red(self):
385
+ from trialpath.models.eligibility_ledger import (
386
+ EligibilityLedger,
387
+ OverallAssessment,
388
+ )
389
+
390
+ ledger = EligibilityLedger(
391
+ patient_id="P001",
392
+ nct_id="NCT01234567",
393
+ overall_assessment=OverallAssessment.LIKELY_INELIGIBLE,
394
+ )
395
+ assert ledger.traffic_light == "red"
396
+
397
+ def test_criterion_counts(self):
398
+ from trialpath.models.eligibility_ledger import (
399
+ CriterionAssessment,
400
+ CriterionDecision,
401
+ EligibilityLedger,
402
+ GapItem,
403
+ OverallAssessment,
404
+ )
405
+
406
+ ledger = EligibilityLedger(
407
+ patient_id="P001",
408
+ nct_id="NCT01234567",
409
+ overall_assessment=OverallAssessment.UNCERTAIN,
410
+ criteria=[
411
+ CriterionAssessment(
412
+ criterion_id="inc_1", type="inclusion",
413
+ text="Stage IV NSCLC", decision=CriterionDecision.MET,
414
+ ),
415
+ CriterionAssessment(
416
+ criterion_id="inc_2", type="inclusion",
417
+ text="ECOG 0-1", decision=CriterionDecision.MET,
418
+ ),
419
+ CriterionAssessment(
420
+ criterion_id="exc_1", type="exclusion",
421
+ text="No prior immunotherapy", decision=CriterionDecision.NOT_MET,
422
+ ),
423
+ CriterionAssessment(
424
+ criterion_id="inc_3", type="inclusion",
425
+ text="EGFR mutation", decision=CriterionDecision.UNKNOWN,
426
+ ),
427
+ ],
428
+ gaps=[
429
+ GapItem(
430
+ description="EGFR mutation status unknown",
431
+ recommended_action="Order EGFR mutation test",
432
+ clinical_importance="high",
433
+ ),
434
+ ],
435
+ )
436
+ assert ledger.met_count == 2
437
+ assert ledger.not_met_count == 1
438
+ assert ledger.unknown_count == 1
439
+ assert len(ledger.gaps) == 1
440
+
441
+ def test_empty_criteria_counts(self):
442
+ from trialpath.models.eligibility_ledger import (
443
+ EligibilityLedger,
444
+ OverallAssessment,
445
+ )
446
+
447
+ ledger = EligibilityLedger(
448
+ patient_id="P001",
449
+ nct_id="NCT01234567",
450
+ overall_assessment=OverallAssessment.UNCERTAIN,
451
+ )
452
+ assert ledger.met_count == 0
453
+ assert ledger.not_met_count == 0
454
+ assert ledger.unknown_count == 0
455
+
456
+ def test_json_roundtrip(self):
457
+ from trialpath.models.eligibility_ledger import (
458
+ EligibilityLedger,
459
+ OverallAssessment,
460
+ )
461
+
462
+ ledger = EligibilityLedger(
463
+ patient_id="P001",
464
+ nct_id="NCT01234567",
465
+ overall_assessment=OverallAssessment.LIKELY_ELIGIBLE,
466
+ )
467
+ json_str = ledger.model_dump_json()
468
+ restored = EligibilityLedger.model_validate_json(json_str)
469
+ assert restored.patient_id == "P001"
470
+ assert restored.overall_assessment == OverallAssessment.LIKELY_ELIGIBLE
471
+
472
+
473
+ class TestTemporalCheck:
474
+ """TemporalCheck validation for time-windowed criteria."""
475
+
476
+ def test_within_window(self):
477
+ """Evidence 7 days old should be within a 14-day window."""
478
+ from trialpath.models.eligibility_ledger import TemporalCheck
479
+
480
+ check = TemporalCheck(
481
+ required_window_days=14,
482
+ reference_date=date(2026, 1, 20),
483
+ evaluation_date=date(2026, 1, 27),
484
+ is_within_window=True,
485
+ )
486
+ assert check.days_elapsed == 7
487
+ assert check.is_within_window is True
488
+
489
+ def test_outside_window(self):
490
+ """Evidence 21 days old should be outside a 14-day window."""
491
+ from trialpath.models.eligibility_ledger import TemporalCheck
492
+
493
+ check = TemporalCheck(
494
+ required_window_days=14,
495
+ reference_date=date(2026, 1, 1),
496
+ evaluation_date=date(2026, 1, 22),
497
+ is_within_window=False,
498
+ )
499
+ assert check.days_elapsed == 21
500
+ assert check.is_within_window is False
501
+
502
+ def test_no_reference_date(self):
503
+ """Missing reference date should yield None for days_elapsed."""
504
+ from trialpath.models.eligibility_ledger import TemporalCheck
505
+
506
+ check = TemporalCheck(
507
+ required_window_days=14,
508
+ reference_date=None,
509
+ )
510
+ assert check.days_elapsed is None
511
+ assert check.is_within_window is None
512
+
513
+ def test_criterion_with_temporal_check(self):
514
+ """CriterionAssessment should accept an optional temporal_check."""
515
+ from trialpath.models.eligibility_ledger import (
516
+ CriterionAssessment,
517
+ CriterionDecision,
518
+ TemporalCheck,
519
+ )
520
+
521
+ assessment = CriterionAssessment(
522
+ criterion_id="inc_5",
523
+ type="inclusion",
524
+ text="ANC >= 1.5 x 10^9/L within 14 days of enrollment",
525
+ decision=CriterionDecision.MET,
526
+ temporal_check=TemporalCheck(
527
+ required_window_days=14,
528
+ reference_date=date(2026, 1, 20),
529
+ evaluation_date=date(2026, 1, 27),
530
+ is_within_window=True,
531
+ ),
532
+ )
533
+ assert assessment.temporal_check is not None
534
+ assert assessment.temporal_check.days_elapsed == 7
535
+ assert assessment.temporal_check.is_within_window is True
536
+
537
+
538
+ class TestSearchLog:
539
+ """SearchLog v1 -- iterative query refinement tracking tests."""
540
+
541
+ def test_add_step_increments_count(self):
542
+ """Adding a refinement step should increment total_refinement_rounds."""
543
+ from trialpath.models.search_log import RefinementAction, SearchLog
544
+
545
+ log = SearchLog(session_id="S001", patient_id="P001")
546
+ assert log.total_refinement_rounds == 0
547
+
548
+ log.add_step(
549
+ query_params={"condition": "NSCLC"},
550
+ result_count=75,
551
+ action=RefinementAction.REFINE,
552
+ reason="Too many results, adding phase filter",
553
+ )
554
+ assert log.total_refinement_rounds == 1
555
+ assert len(log.steps) == 1
556
+
557
+ def test_refinement_exhausted_at_max(self):
558
+ """After 5 refinement rounds, is_refinement_exhausted should be True."""
559
+ from trialpath.models.search_log import RefinementAction, SearchLog
560
+
561
+ log = SearchLog(session_id="S001", patient_id="P001")
562
+
563
+ for i in range(5):
564
+ log.add_step(
565
+ query_params={"condition": "NSCLC", "round": i},
566
+ result_count=0,
567
+ action=RefinementAction.RELAX,
568
+ reason=f"Relaxation round {i + 1}",
569
+ )
570
+
571
+ assert log.total_refinement_rounds == 5
572
+ assert log.is_refinement_exhausted is True
573
+
574
+ def test_transparency_summary_format(self):
575
+ """to_transparency_summary should return list of dicts with expected keys."""
576
+ from trialpath.models.search_log import RefinementAction, SearchLog
577
+
578
+ log = SearchLog(session_id="S001", patient_id="P001")
579
+
580
+ log.add_step(
581
+ query_params={"condition": "NSCLC"},
582
+ result_count=100,
583
+ action=RefinementAction.REFINE,
584
+ reason="Too many results",
585
+ )
586
+ log.add_step(
587
+ query_params={"condition": "NSCLC", "phase": "Phase 3"},
588
+ result_count=25,
589
+ action=RefinementAction.SHORTLIST,
590
+ reason="Right-sized result set",
591
+ )
592
+
593
+ summary = log.to_transparency_summary()
594
+ assert len(summary) == 2
595
+ assert summary[0]["step"] == 1
596
+ assert summary[0]["found"] == 100
597
+ assert summary[0]["action"] == "refine"
598
+ assert summary[1]["step"] == 2
599
+ assert summary[1]["found"] == 25
600
+ assert summary[1]["action"] == "shortlist"
601
+
602
+ def test_initial_search_no_refinement_count(self):
603
+ """An INITIAL action should not increment the refinement counter."""
604
+ from trialpath.models.search_log import RefinementAction, SearchLog
605
+
606
+ log = SearchLog(session_id="S001", patient_id="P001")
607
+
608
+ log.add_step(
609
+ query_params={"condition": "NSCLC"},
610
+ result_count=30,
611
+ action=RefinementAction.INITIAL,
612
+ reason="First search",
613
+ )
614
+
615
+ assert log.total_refinement_rounds == 0
616
+ assert len(log.steps) == 1