yakilee Claude Opus 4.6 commited on
Commit
565148b
·
1 Parent(s): 974edcf

fix: cache service instances in Parlant tools to avoid per-call overhead

Browse files

Replace per-call instantiation of MedGemmaExtractor, GeminiPlanner, and
ClinicalTrialsMCPClient with lazy module-level singletons via _get_*()
helpers. Prevents hundreds of redundant client initializations during
evaluate_trial_eligibility's inner criterion loop.

Add autouse fixture in test_tools.py to reset singletons between tests.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

trialpath/agent/tools.py CHANGED
@@ -11,6 +11,49 @@ from trialpath.config import (
11
  MEDGEMMA_ENDPOINT_URL,
12
  )
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  @tool
16
  async def extract_patient_profile(
@@ -25,12 +68,7 @@ async def extract_patient_profile(
25
  document_urls: JSON list of document file paths.
26
  metadata: JSON object with known patient metadata (age, sex).
27
  """
28
- from trialpath.services.medgemma_extractor import MedGemmaExtractor
29
-
30
- extractor = MedGemmaExtractor(
31
- endpoint_url=MEDGEMMA_ENDPOINT_URL,
32
- hf_token=HF_TOKEN,
33
- )
34
  urls = json.loads(document_urls)
35
  meta = json.loads(metadata)
36
  profile = await extractor.extract(urls, meta)
@@ -52,9 +90,7 @@ async def generate_search_anchors(
52
  context: Parlant tool context.
53
  patient_profile: JSON string of PatientProfile data.
54
  """
55
- from trialpath.services.gemini_planner import GeminiPlanner
56
-
57
- planner = GeminiPlanner(model=GEMINI_MODEL, api_key=GEMINI_API_KEY)
58
  profile = json.loads(patient_profile)
59
  anchors = await planner.generate_search_anchors(profile)
60
 
@@ -76,12 +112,13 @@ async def search_clinical_trials(
76
  search_anchors: JSON string of SearchAnchors data.
77
  """
78
  from trialpath.models.search_anchors import SearchAnchors
79
- from trialpath.services.mcp_client import ClinicalTrialsMCPClient
80
 
81
- client = ClinicalTrialsMCPClient(mcp_url=MCP_URL)
82
  anchors = SearchAnchors.model_validate(json.loads(search_anchors))
83
  raw_studies = await client.search(anchors)
84
 
 
 
85
  trials = [
86
  ClinicalTrialsMCPClient.normalize_trial(s).model_dump()
87
  for s in raw_studies
@@ -107,9 +144,8 @@ async def refine_search_query(
107
  result_count: Number of results from last search.
108
  """
109
  from trialpath.models.search_anchors import SearchAnchors
110
- from trialpath.services.gemini_planner import GeminiPlanner
111
 
112
- planner = GeminiPlanner(model=GEMINI_MODEL, api_key=GEMINI_API_KEY)
113
  anchors = SearchAnchors.model_validate(json.loads(search_anchors))
114
  refined = await planner.refine_search(anchors, int(result_count))
115
 
@@ -133,9 +169,8 @@ async def relax_search_query(
133
  result_count: Number of results from last search.
134
  """
135
  from trialpath.models.search_anchors import SearchAnchors
136
- from trialpath.services.gemini_planner import GeminiPlanner
137
 
138
- planner = GeminiPlanner(model=GEMINI_MODEL, api_key=GEMINI_API_KEY)
139
  anchors = SearchAnchors.model_validate(json.loads(search_anchors))
140
  relaxed = await planner.relax_search(anchors, int(result_count))
141
 
@@ -160,17 +195,11 @@ async def evaluate_trial_eligibility(
160
  patient_profile: JSON string of PatientProfile data.
161
  trial_candidate: JSON string of TrialCandidate data.
162
  """
163
- from trialpath.services.gemini_planner import GeminiPlanner
164
- from trialpath.services.medgemma_extractor import MedGemmaExtractor
165
-
166
  profile = json.loads(patient_profile)
167
  trial = json.loads(trial_candidate)
168
 
169
- planner = GeminiPlanner(model=GEMINI_MODEL, api_key=GEMINI_API_KEY)
170
- extractor = MedGemmaExtractor(
171
- endpoint_url=MEDGEMMA_ENDPOINT_URL,
172
- hf_token=HF_TOKEN,
173
- )
174
 
175
  # Step 1: Slice criteria into atomic items
176
  criteria = await planner.slice_criteria(trial)
@@ -210,9 +239,7 @@ async def analyze_gaps(
210
  patient_profile: JSON string of PatientProfile data.
211
  eligibility_ledgers: JSON list of EligibilityLedger data.
212
  """
213
- from trialpath.services.gemini_planner import GeminiPlanner
214
-
215
- planner = GeminiPlanner(model=GEMINI_MODEL, api_key=GEMINI_API_KEY)
216
  profile = json.loads(patient_profile)
217
  ledgers = json.loads(eligibility_ledgers)
218
  gaps = await planner.analyze_gaps(profile, ledgers)
 
11
  MEDGEMMA_ENDPOINT_URL,
12
  )
13
 
14
+ # ---------------------------------------------------------------------------
15
+ # Lazy singletons — one instance per service, reused across tool calls.
16
+ # ---------------------------------------------------------------------------
17
+
18
+ _extractor = None
19
+ _planner = None
20
+ _mcp_client = None
21
+
22
+
23
+ def _get_extractor():
24
+ global _extractor
25
+ if _extractor is None:
26
+ from trialpath.services.medgemma_extractor import MedGemmaExtractor
27
+
28
+ _extractor = MedGemmaExtractor(
29
+ endpoint_url=MEDGEMMA_ENDPOINT_URL,
30
+ hf_token=HF_TOKEN,
31
+ )
32
+ return _extractor
33
+
34
+
35
+ def _get_planner():
36
+ global _planner
37
+ if _planner is None:
38
+ from trialpath.services.gemini_planner import GeminiPlanner
39
+
40
+ _planner = GeminiPlanner(model=GEMINI_MODEL, api_key=GEMINI_API_KEY)
41
+ return _planner
42
+
43
+
44
+ def _get_mcp_client():
45
+ global _mcp_client
46
+ if _mcp_client is None:
47
+ from trialpath.services.mcp_client import ClinicalTrialsMCPClient
48
+
49
+ _mcp_client = ClinicalTrialsMCPClient(mcp_url=MCP_URL)
50
+ return _mcp_client
51
+
52
+
53
+ # ---------------------------------------------------------------------------
54
+ # Tools
55
+ # ---------------------------------------------------------------------------
56
+
57
 
58
  @tool
59
  async def extract_patient_profile(
 
68
  document_urls: JSON list of document file paths.
69
  metadata: JSON object with known patient metadata (age, sex).
70
  """
71
+ extractor = _get_extractor()
 
 
 
 
 
72
  urls = json.loads(document_urls)
73
  meta = json.loads(metadata)
74
  profile = await extractor.extract(urls, meta)
 
90
  context: Parlant tool context.
91
  patient_profile: JSON string of PatientProfile data.
92
  """
93
+ planner = _get_planner()
 
 
94
  profile = json.loads(patient_profile)
95
  anchors = await planner.generate_search_anchors(profile)
96
 
 
112
  search_anchors: JSON string of SearchAnchors data.
113
  """
114
  from trialpath.models.search_anchors import SearchAnchors
 
115
 
116
+ client = _get_mcp_client()
117
  anchors = SearchAnchors.model_validate(json.loads(search_anchors))
118
  raw_studies = await client.search(anchors)
119
 
120
+ from trialpath.services.mcp_client import ClinicalTrialsMCPClient
121
+
122
  trials = [
123
  ClinicalTrialsMCPClient.normalize_trial(s).model_dump()
124
  for s in raw_studies
 
144
  result_count: Number of results from last search.
145
  """
146
  from trialpath.models.search_anchors import SearchAnchors
 
147
 
148
+ planner = _get_planner()
149
  anchors = SearchAnchors.model_validate(json.loads(search_anchors))
150
  refined = await planner.refine_search(anchors, int(result_count))
151
 
 
169
  result_count: Number of results from last search.
170
  """
171
  from trialpath.models.search_anchors import SearchAnchors
 
172
 
173
+ planner = _get_planner()
174
  anchors = SearchAnchors.model_validate(json.loads(search_anchors))
175
  relaxed = await planner.relax_search(anchors, int(result_count))
176
 
 
195
  patient_profile: JSON string of PatientProfile data.
196
  trial_candidate: JSON string of TrialCandidate data.
197
  """
 
 
 
198
  profile = json.loads(patient_profile)
199
  trial = json.loads(trial_candidate)
200
 
201
+ planner = _get_planner()
202
+ extractor = _get_extractor()
 
 
 
203
 
204
  # Step 1: Slice criteria into atomic items
205
  criteria = await planner.slice_criteria(trial)
 
239
  patient_profile: JSON string of PatientProfile data.
240
  eligibility_ledgers: JSON list of EligibilityLedger data.
241
  """
242
+ planner = _get_planner()
 
 
243
  profile = json.loads(patient_profile)
244
  ledgers = json.loads(eligibility_ledgers)
245
  gaps = await planner.analyze_gaps(profile, ledgers)
trialpath/tests/test_tools.py CHANGED
@@ -4,6 +4,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
4
 
5
  import pytest
6
 
 
7
  from trialpath.agent.tools import (
8
  ALL_TOOLS,
9
  analyze_gaps,
@@ -16,6 +17,18 @@ from trialpath.agent.tools import (
16
  )
17
 
18
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  @pytest.fixture
20
  def mock_context():
21
  return MagicMock()
 
4
 
5
  import pytest
6
 
7
+ import trialpath.agent.tools as tools_module
8
  from trialpath.agent.tools import (
9
  ALL_TOOLS,
10
  analyze_gaps,
 
17
  )
18
 
19
 
20
+ @pytest.fixture(autouse=True)
21
+ def _reset_singletons():
22
+ """Reset cached service singletons between tests."""
23
+ tools_module._extractor = None
24
+ tools_module._planner = None
25
+ tools_module._mcp_client = None
26
+ yield
27
+ tools_module._extractor = None
28
+ tools_module._planner = None
29
+ tools_module._mcp_client = None
30
+
31
+
32
  @pytest.fixture
33
  def mock_context():
34
  return MagicMock()