# TrialPath/trialpath/tests/test_models.py
# yakilee's picture
# test: add tests for SearchAnchors new fields and enhanced MCP search
# 9b26c81
"""TDD tests for TrialPath data models (RED phase — write tests first)."""
from __future__ import annotations
from datetime import date
class TestPatientProfile:
    """Validation and helper-method tests for the PatientProfile v1 model."""

    def test_minimal_valid_profile(self):
        """Supplying nothing but a patient_id must produce a valid profile."""
        from trialpath.models.patient_profile import PatientProfile

        bare = PatientProfile(patient_id="P001")

        assert bare.patient_id == "P001"
        assert bare.unknowns == []

    def test_complete_nsclc_profile(self):
        """A fully populated NSCLC profile keeps its data across dump/validate."""
        from trialpath.models.patient_profile import (
            Biomarker,
            Demographics,
            Diagnosis,
            EvidencePointer,
            PatientProfile,
            PerformanceStatus,
            UnknownField,
        )

        who = Demographics(age=52, sex="female")
        cancer = Diagnosis(
            primary_condition="Non-Small Cell Lung Cancer",
            histology="adenocarcinoma",
            stage="IVa",
            diagnosis_date=date(2025, 11, 15),
        )
        ecog = PerformanceStatus(
            scale="ECOG",
            value=1,
            evidence=[EvidencePointer(doc_id="clinic_1", page=2, span_id="s_17")],
        )
        egfr = Biomarker(
            name="EGFR",
            result="Exon 19 deletion",
            date=date(2026, 1, 10),
            evidence=[EvidencePointer(doc_id="path_egfr", page=1, span_id="s_3")],
        )
        pd_l1_gap = UnknownField(
            field="PD-L1",
            reason="Not found in documents",
            importance="medium",
        )
        original = PatientProfile(
            patient_id="P001",
            demographics=who,
            diagnosis=cancer,
            performance_status=ecog,
            biomarkers=[egfr],
            unknowns=[pd_l1_gap],
        )

        # Round-trip via model_dump() / model_validate().
        rebuilt = PatientProfile.model_validate(original.model_dump())

        assert rebuilt.patient_id == "P001"
        assert rebuilt.diagnosis.stage == "IVa"
        assert len(rebuilt.biomarkers) == 1
        assert rebuilt.biomarkers[0].name == "EGFR"

    def test_has_minimum_prescreen_data_true(self):
        """Condition + stage + ECOG status together satisfy the prescreen check."""
        from trialpath.models.patient_profile import (
            Diagnosis,
            PatientProfile,
            PerformanceStatus,
        )

        staged_dx = Diagnosis(primary_condition="NSCLC", stage="IV")
        ecog = PerformanceStatus(scale="ECOG", value=1)
        candidate = PatientProfile(
            patient_id="P001",
            diagnosis=staged_dx,
            performance_status=ecog,
        )

        assert candidate.has_minimum_prescreen_data() is True

    def test_has_minimum_prescreen_data_false_no_stage(self):
        """A diagnosis lacking a stage must fail the prescreen check."""
        from trialpath.models.patient_profile import (
            Diagnosis,
            PatientProfile,
            PerformanceStatus,
        )

        unstaged_dx = Diagnosis(primary_condition="NSCLC")
        ecog = PerformanceStatus(scale="ECOG", value=1)
        candidate = PatientProfile(
            patient_id="P001",
            diagnosis=unstaged_dx,
            performance_status=ecog,
        )

        assert candidate.has_minimum_prescreen_data() is False

    def test_has_minimum_prescreen_data_false_no_ecog(self):
        """Omitting performance status must fail the prescreen check."""
        from trialpath.models.patient_profile import Diagnosis, PatientProfile

        staged_dx = Diagnosis(primary_condition="NSCLC", stage="IV")
        candidate = PatientProfile(patient_id="P001", diagnosis=staged_dx)

        assert candidate.has_minimum_prescreen_data() is False

    def test_json_roundtrip(self):
        """Dumping to JSON and validating it back yields an equal profile."""
        from trialpath.models.patient_profile import (
            Demographics,
            Diagnosis,
            PatientProfile,
        )

        who = Demographics(age=65, sex="male")
        cancer = Diagnosis(
            primary_condition="NSCLC",
            histology="squamous",
            stage="IIIb",
        )
        profile = PatientProfile(patient_id="P001", demographics=who, diagnosis=cancer)

        restored = PatientProfile.model_validate_json(profile.model_dump_json())

        assert restored == profile

    def test_source_docs_default_empty(self):
        """With no documents supplied, source_docs is an empty list."""
        from trialpath.models.patient_profile import PatientProfile

        bare = PatientProfile(patient_id="P001")

        assert bare.source_docs == []

    def test_source_doc_creation(self):
        """A SourceDocument given doc_id, type and meta is stored on the profile."""
        from trialpath.models.patient_profile import PatientProfile, SourceDocument

        pathology_doc = SourceDocument(doc_id="doc1", type="pathology", meta={"pages": 3})
        holder = PatientProfile(patient_id="P001", source_docs=[pathology_doc])

        assert len(holder.source_docs) == 1
        assert holder.source_docs[0].type == "pathology"

    def test_lab_result(self):
        """A LabResult carrying value, unit, date and evidence is retained."""
        from trialpath.models.patient_profile import (
            EvidencePointer,
            LabResult,
            PatientProfile,
        )

        anc = LabResult(
            name="ANC",
            value=1.8,
            unit="10^9/L",
            date=date(2026, 1, 28),
            evidence=[EvidencePointer(doc_id="labs_jan", page=1, span_id="tbl_anc")],
        )
        holder = PatientProfile(patient_id="P001", key_labs=[anc])

        assert holder.key_labs[0].value == 1.8
        assert holder.key_labs[0].unit == "10^9/L"

    def test_treatment(self):
        """A Treatment with drug name, start/end dates and line is retained."""
        from trialpath.models.patient_profile import PatientProfile, Treatment

        first_line = Treatment(
            drug_name="Pembrolizumab",
            start_date=date(2024, 6, 1),
            end_date=date(2024, 11, 30),
            line=1,
        )
        holder = PatientProfile(patient_id="P001", treatments=[first_line])

        assert holder.treatments[0].drug_name == "Pembrolizumab"
        assert holder.treatments[0].line == 1

    def test_comorbidity(self):
        """A Comorbidity given a name and grade is retained on the profile."""
        from trialpath.models.patient_profile import Comorbidity, PatientProfile

        kidney = Comorbidity(name="CKD", grade="Stage 3")
        holder = PatientProfile(patient_id="P001", comorbidities=[kidney])

        assert holder.comorbidities[0].name == "CKD"

    def test_imaging_summary(self):
        """An ImagingSummary keeps its modality and certainty values."""
        from trialpath.models.patient_profile import ImagingSummary, PatientProfile

        brain_mri = ImagingSummary(
            modality="MRI brain",
            date=date(2026, 1, 20),
            finding="Stable 3mm left frontal lesion",
            interpretation="likely inactive scar",
            certainty="low",
        )
        holder = PatientProfile(patient_id="P001", imaging_summary=[brain_mri])

        assert holder.imaging_summary[0].modality == "MRI brain"
        assert holder.imaging_summary[0].certainty == "low"
class TestSearchAnchors:
    """Validation, default-value and serialization tests for SearchAnchors v1."""

    def test_minimal_anchors(self):
        """A condition alone is accepted; recruitment_status takes its default."""
        from trialpath.models.search_anchors import SearchAnchors

        minimal = SearchAnchors(condition="NSCLC")

        assert minimal.condition == "NSCLC"
        assert minimal.trial_filters.recruitment_status == [
            "Recruiting",
            "Not yet recruiting",
        ]

    def test_full_anchors(self):
        """Anchors built with every field keep the biomarkers and phase filter."""
        from trialpath.models.search_anchors import SearchAnchors, TrialFilters

        strict_filters = TrialFilters(
            recruitment_status=["Recruiting"],
            phase=["Phase 3"],
        )
        full = SearchAnchors(
            condition="Non-Small Cell Lung Cancer",
            subtype="adenocarcinoma",
            biomarkers=["EGFR exon 19 deletion"],
            stage="IV",
            age=52,
            performance_status_max=1,
            trial_filters=strict_filters,
            relaxation_order=["phase", "distance"],
        )

        assert len(full.biomarkers) == 1
        assert full.trial_filters.phase == ["Phase 3"]

    def test_default_relaxation_order(self):
        """Unspecified relaxation_order is phase, distance, biomarker_strictness."""
        from trialpath.models.search_anchors import SearchAnchors

        minimal = SearchAnchors(condition="NSCLC")

        assert minimal.relaxation_order == ["phase", "distance", "biomarker_strictness"]

    def test_default_trial_filters(self):
        """Unspecified trial_filters restrict phase to Phase 2 and Phase 3."""
        from trialpath.models.search_anchors import SearchAnchors

        minimal = SearchAnchors(condition="NSCLC")

        assert minimal.trial_filters.phase == ["Phase 2", "Phase 3"]

    def test_geography_filter(self):
        """A GeographyFilter keeps its country and max_distance_km values."""
        from trialpath.models.search_anchors import GeographyFilter, SearchAnchors

        germany_200km = GeographyFilter(country="DE", max_distance_km=200)
        located = SearchAnchors(condition="NSCLC", geography=germany_200km)

        assert located.geography.country == "DE"
        assert located.geography.max_distance_km == 200

    def test_search_anchors_with_interventions(self):
        """interventions is readable as an attribute and present in model_dump()."""
        from trialpath.models.search_anchors import SearchAnchors

        expected_drugs = ["osimertinib", "erlotinib"]
        anchors = SearchAnchors(
            condition="NSCLC",
            biomarkers=["EGFR exon 19 deletion"],
            # Pass a copy so the expected list is independent of the model.
            interventions=list(expected_drugs),
        )

        assert anchors.interventions == expected_drugs
        dumped = anchors.model_dump()
        assert dumped["interventions"] == expected_drugs

    def test_search_anchors_with_eligibility_keywords(self):
        """eligibility_keywords is readable and present in model_dump()."""
        from trialpath.models.search_anchors import SearchAnchors

        expected_keywords = ["ECOG 0-1", "stage IV", "EGFR mutation"]
        anchors = SearchAnchors(
            condition="NSCLC",
            # Pass a copy so the expected list is independent of the model.
            eligibility_keywords=list(expected_keywords),
        )

        assert anchors.eligibility_keywords == expected_keywords
        dumped = anchors.model_dump()
        assert dumped["eligibility_keywords"] == expected_keywords

    def test_search_anchors_defaults_empty_lists(self):
        """interventions and eligibility_keywords default to empty lists."""
        from trialpath.models.search_anchors import SearchAnchors

        minimal = SearchAnchors(condition="NSCLC")

        assert minimal.interventions == []
        assert minimal.eligibility_keywords == []

    def test_json_roundtrip(self):
        """Dumping to JSON and validating it back yields equal anchors."""
        from trialpath.models.search_anchors import SearchAnchors

        anchors = SearchAnchors(condition="NSCLC", stage="IV", age=55)

        restored = SearchAnchors.model_validate_json(anchors.model_dump_json())

        assert restored == anchors

    def test_json_roundtrip_with_new_fields(self):
        """A JSON round-trip keeps interventions and eligibility_keywords."""
        from trialpath.models.search_anchors import SearchAnchors

        source = SearchAnchors(
            condition="NSCLC",
            interventions=["osimertinib"],
            eligibility_keywords=["ECOG 0-1", "stage IV"],
        )

        reloaded = SearchAnchors.model_validate_json(source.model_dump_json())

        assert reloaded.interventions == ["osimertinib"]
        assert reloaded.eligibility_keywords == ["ECOG 0-1", "stage IV"]
class TestTrialCandidate:
    """Construction, default-value and serialization tests for TrialCandidate v1."""

    def test_trial_with_eligibility_text(self):
        """A trial built with EligibilityText exposes its inclusion text."""
        from trialpath.models.trial_candidate import EligibilityText, TrialCandidate

        criteria_text = EligibilityText(
            inclusion="Histologically confirmed NSCLC stage IV",
            exclusion="Prior EGFR TKI therapy",
        )
        candidate = TrialCandidate(
            nct_id="NCT01234567",
            title="Phase 3 Study of Osimertinib",
            conditions=["NSCLC"],
            phase="Phase 3",
            status="Recruiting",
            fingerprint_text="Osimertinib EGFR+ NSCLC Phase 3",
            eligibility_text=criteria_text,
        )

        assert candidate.nct_id == "NCT01234567"
        assert candidate.eligibility_text.inclusion.startswith("Histologically")

    def test_minimal_trial(self):
        """With only nct_id, title and fingerprint_text the rest is empty/None."""
        from trialpath.models.trial_candidate import TrialCandidate

        candidate = TrialCandidate(
            nct_id="NCT99999999",
            title="Test Trial",
            fingerprint_text="test",
        )

        assert candidate.conditions == []
        assert candidate.locations == []
        assert candidate.eligibility_text is None

    def test_trial_with_locations(self):
        """Two TrialLocation entries are stored in the order they were given."""
        from trialpath.models.trial_candidate import TrialCandidate, TrialLocation

        sites = [
            TrialLocation(country="DE", city="Berlin"),
            TrialLocation(country="DE", city="Hamburg"),
        ]
        candidate = TrialCandidate(
            nct_id="NCT01234567",
            title="Test Trial",
            fingerprint_text="test",
            locations=sites,
        )

        assert len(candidate.locations) == 2
        assert candidate.locations[0].city == "Berlin"

    def test_trial_with_age_range(self):
        """An AgeRange keeps the min and max values it was built with."""
        from trialpath.models.trial_candidate import AgeRange, TrialCandidate

        adults = AgeRange(min=18, max=75)
        candidate = TrialCandidate(
            nct_id="NCT01234567",
            title="Test Trial",
            fingerprint_text="test",
            age_range=adults,
        )

        assert candidate.age_range.min == 18
        assert candidate.age_range.max == 75

    def test_json_roundtrip(self):
        """Dumping to JSON and validating it back yields an equal trial."""
        from trialpath.models.trial_candidate import TrialCandidate

        trial = TrialCandidate(
            nct_id="NCT01234567",
            title="Test",
            fingerprint_text="test fp",
            phase="Phase 2",
        )

        restored = TrialCandidate.model_validate_json(trial.model_dump_json())

        assert restored == trial
class TestEligibilityLedger:
    """Traffic-light, criterion-count and serialization tests for EligibilityLedger v1."""

    def test_traffic_light_green(self):
        """LIKELY_ELIGIBLE is reported as a "green" traffic light."""
        from trialpath.models.eligibility_ledger import (
            EligibilityLedger,
            OverallAssessment,
        )

        verdict = OverallAssessment.LIKELY_ELIGIBLE
        ledger = EligibilityLedger(
            patient_id="P001",
            nct_id="NCT01234567",
            overall_assessment=verdict,
        )

        assert ledger.traffic_light == "green"

    def test_traffic_light_yellow(self):
        """UNCERTAIN is reported as a "yellow" traffic light."""
        from trialpath.models.eligibility_ledger import (
            EligibilityLedger,
            OverallAssessment,
        )

        verdict = OverallAssessment.UNCERTAIN
        ledger = EligibilityLedger(
            patient_id="P001",
            nct_id="NCT01234567",
            overall_assessment=verdict,
        )

        assert ledger.traffic_light == "yellow"

    def test_traffic_light_red(self):
        """LIKELY_INELIGIBLE is reported as a "red" traffic light."""
        from trialpath.models.eligibility_ledger import (
            EligibilityLedger,
            OverallAssessment,
        )

        verdict = OverallAssessment.LIKELY_INELIGIBLE
        ledger = EligibilityLedger(
            patient_id="P001",
            nct_id="NCT01234567",
            overall_assessment=verdict,
        )

        assert ledger.traffic_light == "red"

    def test_criterion_counts(self):
        """met/not_met/unknown counts reflect the decisions of the criteria."""
        from trialpath.models.eligibility_ledger import (
            CriterionAssessment,
            CriterionDecision,
            EligibilityLedger,
            GapItem,
            OverallAssessment,
        )

        verdict = OverallAssessment.UNCERTAIN
        # (criterion_id, type, text, decision): two MET, one NOT_MET, one UNKNOWN.
        criterion_specs = [
            ("inc_1", "inclusion", "Stage IV NSCLC", CriterionDecision.MET),
            ("inc_2", "inclusion", "ECOG 0-1", CriterionDecision.MET),
            ("exc_1", "exclusion", "No prior immunotherapy", CriterionDecision.NOT_MET),
            ("inc_3", "inclusion", "EGFR mutation", CriterionDecision.UNKNOWN),
        ]
        assessed = [
            CriterionAssessment(
                criterion_id=criterion_id,
                type=kind,
                text=wording,
                decision=outcome,
            )
            for criterion_id, kind, wording, outcome in criterion_specs
        ]
        egfr_gap = GapItem(
            description="EGFR mutation status unknown",
            recommended_action="Order EGFR mutation test",
            clinical_importance="high",
        )
        ledger = EligibilityLedger(
            patient_id="P001",
            nct_id="NCT01234567",
            overall_assessment=verdict,
            criteria=assessed,
            gaps=[egfr_gap],
        )

        assert ledger.met_count == 2
        assert ledger.not_met_count == 1
        assert ledger.unknown_count == 1
        assert len(ledger.gaps) == 1

    def test_empty_criteria_counts(self):
        """A ledger built without criteria reports zero for every count."""
        from trialpath.models.eligibility_ledger import (
            EligibilityLedger,
            OverallAssessment,
        )

        verdict = OverallAssessment.UNCERTAIN
        empty_ledger = EligibilityLedger(
            patient_id="P001",
            nct_id="NCT01234567",
            overall_assessment=verdict,
        )

        assert empty_ledger.met_count == 0
        assert empty_ledger.not_met_count == 0
        assert empty_ledger.unknown_count == 0

    def test_json_roundtrip(self):
        """A JSON round-trip keeps patient_id and overall_assessment."""
        from trialpath.models.eligibility_ledger import (
            EligibilityLedger,
            OverallAssessment,
        )

        source = EligibilityLedger(
            patient_id="P001",
            nct_id="NCT01234567",
            overall_assessment=OverallAssessment.LIKELY_ELIGIBLE,
        )

        reloaded = EligibilityLedger.model_validate_json(source.model_dump_json())

        assert reloaded.patient_id == "P001"
        assert reloaded.overall_assessment == OverallAssessment.LIKELY_ELIGIBLE
class TestTemporalCheck:
    """Tests for TemporalCheck, used by criteria that carry a time window."""

    def test_within_window(self):
        """Dates 7 days apart with a 14-day window: days_elapsed is 7."""
        from trialpath.models.eligibility_ledger import TemporalCheck

        measured_on = date(2026, 1, 20)
        evaluated_on = date(2026, 1, 27)
        recent = TemporalCheck(
            required_window_days=14,
            reference_date=measured_on,
            evaluation_date=evaluated_on,
            is_within_window=True,
        )

        assert recent.days_elapsed == 7
        assert recent.is_within_window is True

    def test_outside_window(self):
        """Dates 21 days apart with a 14-day window: days_elapsed is 21."""
        from trialpath.models.eligibility_ledger import TemporalCheck

        measured_on = date(2026, 1, 1)
        evaluated_on = date(2026, 1, 22)
        stale = TemporalCheck(
            required_window_days=14,
            reference_date=measured_on,
            evaluation_date=evaluated_on,
            is_within_window=False,
        )

        assert stale.days_elapsed == 21
        assert stale.is_within_window is False

    def test_no_reference_date(self):
        """With reference_date=None, days_elapsed and is_within_window are None."""
        from trialpath.models.eligibility_ledger import TemporalCheck

        undated = TemporalCheck(required_window_days=14, reference_date=None)

        assert undated.days_elapsed is None
        assert undated.is_within_window is None

    def test_criterion_with_temporal_check(self):
        """CriterionAssessment accepts a temporal_check and exposes it unchanged."""
        from trialpath.models.eligibility_ledger import (
            CriterionAssessment,
            CriterionDecision,
            TemporalCheck,
        )

        verdict = CriterionDecision.MET
        window = TemporalCheck(
            required_window_days=14,
            reference_date=date(2026, 1, 20),
            evaluation_date=date(2026, 1, 27),
            is_within_window=True,
        )
        anc_criterion = CriterionAssessment(
            criterion_id="inc_5",
            type="inclusion",
            text="ANC >= 1.5 x 10^9/L within 14 days of enrollment",
            decision=verdict,
            temporal_check=window,
        )

        assert anc_criterion.temporal_check is not None
        assert anc_criterion.temporal_check.days_elapsed == 7
        assert anc_criterion.temporal_check.is_within_window is True
class TestSearchLog:
    """Tests for SearchLog v1, which records iterative query-refinement steps."""

    def test_add_step_increments_count(self):
        """One REFINE step takes total_refinement_rounds from 0 to 1."""
        from trialpath.models.search_log import RefinementAction, SearchLog

        search_log = SearchLog(session_id="S001", patient_id="P001")
        assert search_log.total_refinement_rounds == 0

        search_log.add_step(
            query_params={"condition": "NSCLC"},
            result_count=75,
            action=RefinementAction.REFINE,
            reason="Too many results, adding phase filter",
        )

        assert search_log.total_refinement_rounds == 1
        assert len(search_log.steps) == 1

    def test_refinement_exhausted_at_max(self):
        """Five RELAX steps leave is_refinement_exhausted set to True."""
        from trialpath.models.search_log import RefinementAction, SearchLog

        search_log = SearchLog(session_id="S001", patient_id="P001")

        for round_index in range(5):
            params = {"condition": "NSCLC", "round": round_index}
            search_log.add_step(
                query_params=params,
                result_count=0,
                action=RefinementAction.RELAX,
                reason=f"Relaxation round {round_index + 1}",
            )

        assert search_log.total_refinement_rounds == 5
        assert search_log.is_refinement_exhausted is True

    def test_transparency_summary_format(self):
        """to_transparency_summary gives one dict per step with step/found/action."""
        from trialpath.models.search_log import RefinementAction, SearchLog

        search_log = SearchLog(session_id="S001", patient_id="P001")
        search_log.add_step(
            query_params={"condition": "NSCLC"},
            result_count=100,
            action=RefinementAction.REFINE,
            reason="Too many results",
        )
        search_log.add_step(
            query_params={"condition": "NSCLC", "phase": "Phase 3"},
            result_count=25,
            action=RefinementAction.SHORTLIST,
            reason="Right-sized result set",
        )

        rows = search_log.to_transparency_summary()

        # Expected (step, found, action) for each summary entry, in order.
        expected_rows = [(1, 100, "refine"), (2, 25, "shortlist")]
        assert len(rows) == 2
        for row, (step_no, found, action_name) in zip(rows, expected_rows):
            assert row["step"] == step_no
            assert row["found"] == found
            assert row["action"] == action_name

    def test_initial_search_no_refinement_count(self):
        """An INITIAL step is recorded but leaves the refinement counter at 0."""
        from trialpath.models.search_log import RefinementAction, SearchLog

        search_log = SearchLog(session_id="S001", patient_id="P001")

        search_log.add_step(
            query_params={"condition": "NSCLC"},
            result_count=30,
            action=RefinementAction.INITIAL,
            reason="First search",
        )

        assert search_log.total_refinement_rounds == 0
        assert len(search_log.steps) == 1