# trialpath/services/gemini_planner.py
# Last change: feat(gemini): update search anchors prompt for interventions
# and eligibility keywords (commit 65a898a, author: yakilee)
"""Gemini structured output for SearchAnchors generation and eligibility evaluation."""
import asyncio
import json
import time
import structlog
from google import genai
logger = structlog.get_logger("trialpath.gemini")
from trialpath.config import GEMINI_API_KEY, GEMINI_MODEL
from trialpath.models.eligibility_ledger import EligibilityLedger
from trialpath.models.search_anchors import SearchAnchors
class GeminiPlanner:
"""Orchestration layer using Gemini for structured output."""
def __init__(
self,
model: str | None = None,
api_key: str | None = None,
max_concurrent: int = 5,
):
self.model = model or GEMINI_MODEL
self.client = genai.Client(api_key=api_key or GEMINI_API_KEY)
self._semaphore = asyncio.Semaphore(max_concurrent)
def _generate_sync(self, prompt: str, schema: dict) -> str:
"""Blocking Gemini call (runs on a worker thread via _generate)."""
start = time.monotonic()
response = self.client.models.generate_content(
model=self.model,
contents=prompt,
config={
"response_mime_type": "application/json",
"response_json_schema": schema,
},
)
elapsed = time.monotonic() - start
text = response.text
logger.info(
"gemini_call",
model=self.model,
prompt_chars=len(prompt),
response_chars=len(text or ""),
duration_s=round(elapsed, 2),
schema_type=schema.get("title", "unknown"),
)
if text is None:
raise ValueError("Gemini returned empty response")
# Guard against truncated JSON from hitting token limits
try:
json.loads(text)
except json.JSONDecodeError as exc:
raise ValueError(f"Gemini returned truncated JSON ({len(text)} chars): {exc}") from exc
return text
async def _generate(self, prompt: str, schema: dict) -> str:
"""Call Gemini with structured JSON output, rate-limited by semaphore."""
async with self._semaphore:
return await asyncio.to_thread(self._generate_sync, prompt, schema)
async def generate_search_anchors(self, patient_profile: dict) -> SearchAnchors:
"""Use Gemini structured output to generate SearchAnchors from PatientProfile."""
prompt = f"""
Given the following patient profile, generate search parameters
for finding relevant NSCLC clinical trials on ClinicalTrials.gov.
Patient Profile:
{json.dumps(patient_profile, indent=2)}
Generate SearchAnchors that:
1. Focus on the patient's specific cancer type, stage, and biomarkers
2. Include appropriate geographic filters
3. Consider the patient's age and performance status
4. Set a relaxation_order for broadening search if too few results
5. Generate a list of target interventions (drug names) based on:
- Known effective drugs for the patient's biomarkers (e.g., EGFR β†’ osimertinib, erlotinib; ALK β†’ alectinib, crizotinib; KRAS G12C β†’ sotorasib, adagrasib; ROS1 β†’ crizotinib, entrectinib; BRAF V600E β†’ dabrafenib + trametinib)
- Next-line therapies after the patient's current treatments
- Immunotherapy drugs if no targetable mutations (e.g., pembrolizumab, nivolumab, atezolizumab)
6. Generate eligibility_keywords from the patient's key clinical features:
- Performance status (e.g., "ECOG 0-1")
- Stage (e.g., "stage IV", "stage IIIB")
- Key biomarker terms as they appear in trial criteria (e.g., "EGFR mutation", "PD-L1 positive")
- Prior therapy line (e.g., "first-line", "second-line")
"""
raw = await self._generate(prompt, SearchAnchors.model_json_schema())
return SearchAnchors.model_validate_json(raw)
async def evaluate_eligibility(
self,
patient_profile: dict,
trial_candidate: dict,
search_log: object | None = None,
) -> EligibilityLedger:
"""Use Gemini to evaluate eligibility for a single trial."""
prompt = f"""
Evaluate this patient's eligibility for the clinical trial below.
For each inclusion/exclusion criterion:
1. Assign a criterion_id (inc_1, inc_2, ... or exc_1, exc_2, ...)
2. Determine if the criterion is met, not_met, or unknown
3. Provide reasoning and evidence pointers
Patient Profile:
{json.dumps(patient_profile, indent=2, default=str)}
Trial:
{json.dumps(trial_candidate, indent=2, default=str)}
Also identify gaps: criteria that are 'unknown' where additional data
could change the assessment.
"""
raw = await self._generate(prompt, EligibilityLedger.model_json_schema())
return EligibilityLedger.model_validate_json(raw)
async def refine_search(
self,
anchors: SearchAnchors,
result_count: int,
search_log: object | None = None,
) -> SearchAnchors:
"""Tighten search filters when too many results returned."""
prompt = f"""
The current search returned {result_count} results, which is too many.
Tighten the search parameters to reduce the result set.
Current SearchAnchors:
{anchors.model_dump_json(indent=2)}
Strategies:
1. Add more specific biomarker terms
2. Restrict phases to Phase 3 only
3. Narrow geographic scope
4. Tighten recruitment status to "Recruiting" only
Return refined SearchAnchors with stricter filters.
"""
raw = await self._generate(prompt, SearchAnchors.model_json_schema())
return SearchAnchors.model_validate_json(raw)
async def relax_search(
self,
anchors: SearchAnchors,
result_count: int,
search_log: object | None = None,
) -> SearchAnchors:
"""Loosen search filters when too few results returned."""
prompt = f"""
The current search returned {result_count} results, which is too few.
Relax the search parameters following the relaxation_order.
Current SearchAnchors:
{anchors.model_dump_json(indent=2)}
Relaxation order: {anchors.relaxation_order}
Strategies (in order of relaxation_order):
1. Broaden phases (add Phase 1, Phase 2)
2. Remove geographic restrictions
3. Remove specific biomarker filters
4. Broaden recruitment status
Return relaxed SearchAnchors with looser filters.
"""
raw = await self._generate(prompt, SearchAnchors.model_json_schema())
return SearchAnchors.model_validate_json(raw)
async def slice_criteria(self, trial: dict) -> list[dict]:
"""Split eligibility text into atomic criteria with type classification."""
schema = {
"type": "object",
"properties": {
"criteria": {
"type": "array",
"items": {
"type": "object",
"properties": {
"criterion_id": {"type": "string"},
"type": {"type": "string", "enum": ["inclusion", "exclusion"]},
"text": {"type": "string"},
"category": {
"type": "string",
"enum": ["medical", "structural"],
},
},
"required": ["criterion_id", "type", "text", "category"],
},
},
},
"required": ["criteria"],
}
prompt = f"""
Split the following clinical trial eligibility criteria into atomic,
individually evaluatable criteria. Classify each as 'medical'
(requires clinical knowledge) or 'structural' (age, geography, consent).
Trial:
{json.dumps(trial, indent=2, default=str)}
For each criterion:
- Assign a criterion_id (inc_1, inc_2, ... or exc_1, exc_2, ...)
- Set type to "inclusion" or "exclusion"
- Extract the exact text
- Classify as "medical" or "structural"
"""
raw = await self._generate(prompt, schema)
parsed = json.loads(raw)
return parsed["criteria"]
async def evaluate_structural_criterion(
self,
criterion_text: str,
patient_profile: dict,
) -> dict:
"""Evaluate non-medical criteria (age, geography, consent)."""
schema = {
"type": "object",
"properties": {
"decision": {
"type": "string",
"enum": ["met", "not_met", "unknown"],
},
"reasoning": {"type": "string"},
"confidence": {"type": "number"},
},
"required": ["decision", "reasoning", "confidence"],
}
prompt = f"""
Evaluate whether the patient meets this structural (non-medical)
clinical trial criterion.
CRITERION: {criterion_text}
Patient Profile:
{json.dumps(patient_profile, indent=2, default=str)}
Return your assessment with decision, reasoning, and confidence score.
"""
raw = await self._generate(prompt, schema)
return json.loads(raw)
async def aggregate_assessments(
self,
profile: dict,
trial: dict,
assessments: list[dict],
) -> EligibilityLedger:
"""Aggregate per-criterion assessments into overall trial eligibility."""
prompt = f"""
Given individual criterion assessments for a clinical trial,
determine the overall eligibility assessment.
Patient Profile:
{json.dumps(profile, indent=2, default=str)}
Trial:
{json.dumps(trial, indent=2, default=str)}
Individual Assessments:
{json.dumps(assessments, indent=2, default=str)}
Rules:
- If ANY exclusion criterion is met β†’ likely_ineligible
- If ANY inclusion criterion is not_met β†’ likely_ineligible
- If all criteria are met β†’ likely_eligible
- Otherwise β†’ uncertain
Also identify gaps where unknown criteria could change the outcome.
"""
raw = await self._generate(prompt, EligibilityLedger.model_json_schema())
return EligibilityLedger.model_validate_json(raw)
async def analyze_gaps(
self,
profile: dict,
ledgers: list[dict],
) -> list[dict]:
"""Identify minimal actionable set of missing data across all trials."""
schema = {
"type": "object",
"properties": {
"gaps": {
"type": "array",
"items": {
"type": "object",
"properties": {
"description": {"type": "string"},
"recommended_action": {"type": "string"},
"clinical_importance": {
"type": "string",
"enum": ["high", "medium", "low"],
},
"affected_trial_count": {"type": "integer"},
},
"required": [
"description",
"recommended_action",
"clinical_importance",
"affected_trial_count",
],
},
},
},
"required": ["gaps"],
}
prompt = f"""
Analyze the eligibility ledgers across all trials and identify
the minimal set of actionable gaps that would most improve the
patient's trial matching.
Patient Profile:
{json.dumps(profile, indent=2, default=str)}
Eligibility Ledgers:
{json.dumps(ledgers, indent=2, default=str)}
Prioritize gaps by:
1. Number of trials affected
2. Clinical importance (high > medium > low)
3. Feasibility of obtaining the missing data
Return a deduplicated list of gaps with recommended actions.
"""
raw = await self._generate(prompt, schema)
parsed = json.loads(raw)
return parsed["gaps"]