# CTA / backend / a2a_workflow.py
# Author: TheQuantEd
# Initial deployment: ClinicalMatch AI v2.0 — FHIR R4 · MCP (9 tools) · A2A workflow · SHARP compliance · 100k synthetic patients · Neo4j graph · GraphRAG chatbot
# (commit 59abb4f)
"""A2A (Agent-to-Agent) orchestration workflow β€” state machine for the recruitment pipeline.
Every inter-agent message carries a SHARP Extension Spec context envelope:
sharp_version, patient_context (id, fhir_ref, fhir_base, tenant_id, session_id),
data_classification, baa_in_scope, consent_status
"""
import time
import uuid
from datetime import datetime, timezone
from enum import Enum
from typing import Any

import consent_agent
from clinicaltrials_api import get_trial_details_sync, search_trials_sync
from fhir_adapter import build_patient_profile, get_mock_fhir_patient, get_patient_profile
from fhir_server import build_sharp_context, get_live_patient_profile
from llm_client import generate_outreach_message, summarize_trial
from matching_engine import get_criteria_for_trial, match_patient_to_trials, score_patient_for_trial
class WorkflowState(str, Enum):
    """Lifecycle states of the A2A recruitment pipeline state machine.

    Mixes in ``str`` so state values serialize directly into JSON event
    payloads and API responses. States normally advance PENDING through
    RECRUITING to COMPLETED; any exception jumps the workflow to FAILED.
    """
    PENDING = "PENDING"                    # workflow created, not yet executed
    INGESTING = "INGESTING"                # agent 1: fetching/validating FHIR data
    PARSING_PROTOCOL = "PARSING_PROTOCOL"  # agent 2: trial protocol parsing
    MATCHING = "MATCHING"                  # agent 3: semantic patient-to-trial matching
    SCORING = "SCORING"                    # agent 4: predictive screening scoring
    RECRUITING = "RECRUITING"              # agent 5: outreach generation + consent handoff
    COMPLETED = "COMPLETED"                # terminal: pipeline finished successfully
    FAILED = "FAILED"                      # terminal: pipeline aborted with an error
# In-memory workflow store (production: use Redis or Neo4j)
# Maps workflow_id (uuid4 string) -> mutable workflow record dict
# (current_state, events, result, sharp_context, timestamps).
# NOTE(review): plain dict with no locking — confirm single-threaded access.
_workflows: dict[str, dict] = {}
def _emit_event(workflow_id: str, state: WorkflowState, message: str, data: Any = None):
    """Append a state-transition event to a workflow's audit trail.

    Records the event, advances ``current_state``, bumps ``updated_at``,
    and logs progress to stdout.

    Args:
        workflow_id: Key into the in-memory ``_workflows`` store; an unknown
            id raises ``KeyError`` (callers only pass ids they created).
        state: New ``WorkflowState``; also stored as ``current_state``.
        message: Human-readable progress description.
        data: Optional structured payload attached to the event.
    """
    workflow = _workflows[workflow_id]
    # One timezone-aware clock read shared by the event and the workflow
    # record; datetime.utcnow() is deprecated (naive) since Python 3.12.
    now = datetime.now(timezone.utc).isoformat()
    event = {
        "state": state,
        "message": message,
        "timestamp": now,
        "data": data,
        # SHARP envelope on every event so downstream agents have full context
        "sharp_context": workflow.get("sharp_context", {}),
    }
    workflow["events"].append(event)
    workflow["current_state"] = state
    workflow["updated_at"] = now
    # em dash was previously mojibake ("β€”") from a mis-decoded source file
    print(f"[A2A:{workflow_id[:8]}] {state} — {message}")
# ── Sub-agents ────────────────────────────────────────────────────────────────
def _agent_ingest_patient(workflow_id: str, patient_id: str) -> dict:
    """Sub-agent: ingest and validate the patient's FHIR R4 data.

    Returns the normalized patient profile dict consumed by downstream
    agents. Raises ValueError when the patient id is unknown.
    """
    _emit_event(workflow_id, WorkflowState.INGESTING,
                f"Ingesting FHIR R4 data for patient {patient_id}")
    time.sleep(0.3)  # simulate an async data fetch
    fhir_patient = get_mock_fhir_patient(patient_id)
    if not fhir_patient:
        raise ValueError(f"Patient {patient_id} not found in FHIR registry")
    profile = build_patient_profile(fhir_patient)
    n_conditions = len(fhir_patient.conditions)
    n_medications = len(fhir_patient.medications)
    _emit_event(
        workflow_id,
        WorkflowState.INGESTING,
        f"FHIR data loaded: {n_conditions} conditions, {n_medications} medications",
        {"profile": profile},
    )
    return profile
def _agent_parse_protocol(workflow_id: str, nct_id: str | None, condition: str) -> tuple[list[dict], dict]:
    """Sub-agent: parse trial protocols and extract eligibility criteria.

    Fetches a single trial when ``nct_id`` is given, otherwise searches by
    ``condition``. Returns up to 5 trials annotated with ``parsed_criteria``
    and a metadata dict holding an LLM summary of the first trial.

    Raises:
        ValueError: If the requested NCT id cannot be fetched, or the
            condition search returns no trials.
    """
    _emit_event(workflow_id, WorkflowState.PARSING_PROTOCOL,
                f"Parsing trial protocols for condition: {condition}")
    time.sleep(0.5)
    if nct_id:
        trials = [t for t in [get_trial_details_sync(nct_id)] if t]
        if not trials:
            # BUG FIX: previously raised "No trials found for condition: ..."
            # even when the failure was a specific NCT-id lookup.
            raise ValueError(f"Trial {nct_id} not found")
    else:
        trials = search_trials_sync(condition, page_size=8)
        if not trials:
            raise ValueError(f"No trials found for condition: {condition}")
    # Parse criteria for each trial using LLM (cap at 5 to avoid timeout)
    parsed_trials = [
        {**trial, "parsed_criteria": get_criteria_for_trial(trial)}
        for trial in trials[:5]
    ]
    # trials is guaranteed non-empty here (both branches raise otherwise),
    # so the old "if trials else ''" guard was dead code.
    summary = summarize_trial(trials[0])
    _emit_event(workflow_id, WorkflowState.PARSING_PROTOCOL,
                f"Parsed {len(parsed_trials)} trial protocols",
                {"trial_count": len(parsed_trials), "protocol_summary": summary})
    return parsed_trials, {"summary": summary}
def _agent_match(workflow_id: str, patient_profile: dict, trials: list[dict]) -> list[dict]:
    """Sub-agent: score the patient against every trial and rank candidates.

    Returns all trials (not just eligible ones) annotated with scoring
    results, sorted by ``match_score`` descending.
    """
    pid = patient_profile["patient_id"]
    _emit_event(workflow_id, WorkflowState.MATCHING,
                f"Running semantic matching for patient {pid} against {len(trials)} trials")
    time.sleep(0.3)

    def _annotate(trial: dict) -> dict:
        # Merge the raw trial with its scoring results so each candidate is
        # self-contained for the downstream scoring/recruiting agents.
        result = score_patient_for_trial(pid, trial)
        return {
            **trial,
            "match_score": result.get("overall_score", 0.0),
            "eligible": result.get("eligible", False),
            "inclusion_results": result.get("inclusion_results", []),
            "exclusion_results": result.get("exclusion_results", []),
            "match_summary": result.get("summary", ""),
            "risk_flags": result.get("risk_flags", []),
        }

    candidates = sorted((_annotate(t) for t in trials),
                        key=lambda c: c["match_score"], reverse=True)
    eligible_count = sum(1 for c in candidates if c["eligible"])
    top_score = candidates[0]["match_score"] if candidates else 0
    _emit_event(workflow_id, WorkflowState.MATCHING,
                f"Matching complete: {eligible_count}/{len(candidates)} trials eligible",
                {"eligible_count": eligible_count, "top_score": top_score})
    return candidates
def _agent_score(workflow_id: str, candidates: list[dict], patient_profile: dict) -> list[dict]:
    """Sub-agent: predictive screening — augment risk flags, assign priority.

    Mutates each candidate in place, adding data-completeness risk flags and
    a ``screening_priority`` tier derived from the match score.
    """
    _emit_event(workflow_id, WorkflowState.SCORING, "Running predictive screening analysis")
    time.sleep(0.2)
    # Patient-level check is loop-invariant: hoist it once.
    biomarkers_missing = not patient_profile.get("biomarkers")
    for candidate in candidates:
        flags = candidate.get("risk_flags", [])
        if not candidate.get("locations", []):
            # no nearby-site information available for this trial
            flags.append("No site location data available")
        if biomarkers_missing:
            flags.append("Biomarker data incomplete β€” may affect screening")
        candidate["risk_flags"] = flags
        score = candidate["match_score"]
        if score >= 0.8:
            tier = "HIGH"
        elif score >= 0.5:
            tier = "MEDIUM"
        else:
            tier = "LOW"
        candidate["screening_priority"] = tier
    high_count = sum(1 for c in candidates if c.get("screening_priority") == "HIGH")
    _emit_event(workflow_id, WorkflowState.SCORING,
                "Screening scoring complete",
                {"high_priority": high_count})
    return candidates
def _agent_recruit(workflow_id: str, candidates: list[dict], patient_profile: dict) -> list[dict]:
    """Sub-agent: generate recruitment outreach for the top eligible trials.

    For at most 3 eligible candidates, drafts a patient email and a PCP
    letter via the LLM, then hands off a CONSENT_REQUEST task (carrying the
    SHARP envelope) to the consent agent. Per-trial failures are recorded as
    ERROR records rather than aborting the batch (deliberate best-effort).
    """
    _emit_event(workflow_id, WorkflowState.RECRUITING, "Generating personalized recruitment communications")
    eligible = [c for c in candidates if c.get("eligible")][:3]
    recruitment_records = []
    # Loop-invariant lookup hoisted out of the per-trial loop.
    sharp_context = _workflows[workflow_id].get("sharp_context", {})
    for trial in eligible:
        nct = trial.get("nct_id", "")
        title = trial.get("title", "")
        try:
            outreach = generate_outreach_message(patient_profile, trial, "patient_email")
            pcp_letter = generate_outreach_message(patient_profile, trial, "pcp_letter")
            # A2A handoff → consent agent (SHARP envelope attached)
            consent_task = {
                "task_id": f"consent_{workflow_id}_{nct}",
                "type": "CONSENT_REQUEST",
                "payload": {
                    "patient_id": patient_profile.get("patient_id", ""),
                    "nct_id": nct,
                    "trial_title": title,
                    "match_score": trial.get("match_score", 0.0),
                },
                "sharp_context": sharp_context,
            }
            consent_result = consent_agent.receive_a2a_task(consent_task)
            recruitment_records.append({
                "nct_id": nct,
                "trial_title": title,
                "match_score": trial.get("match_score", 0.0),
                "patient_email": outreach,
                "pcp_letter": pcp_letter,
                "status": "PENDING",
                "consent_id": consent_result.get("consent_id"),
                "consent_status": consent_result.get("status", "PENDING"),
                # timezone-aware UTC; datetime.utcnow() is deprecated in 3.12+
                "created_at": datetime.now(timezone.utc).isoformat(),
            })
        except Exception as e:
            # best-effort: keep going, surface the failure in the record
            recruitment_records.append({
                "nct_id": nct,
                "trial_title": title,
                "error": str(e),
                "status": "ERROR",
            })
    _emit_event(workflow_id, WorkflowState.RECRUITING,
                f"Generated outreach for {len(recruitment_records)} trials",
                {"record_count": len(recruitment_records)})
    return recruitment_records
# ── Public API ─────────────────────────────────────────────────────────────────
def start_pipeline(
    patient_id: str,
    nct_id: str | None = None,
    condition: str | None = None,
    fhir_token: str | None = None,
    fhir_base_url: str | None = None,
    session_id: str | None = None,
) -> str:
    """Create a workflow record in state PENDING and return its workflow_id.

    Builds the SHARP context envelope up front so every subsequent event and
    A2A handoff carries patient/tenant context. Execution happens later via
    ``run_pipeline``.

    Args:
        patient_id: FHIR patient identifier.
        nct_id: Optional specific trial to evaluate; otherwise search by condition.
        condition: Optional condition to search trials for (inferred from the
            patient profile during the run if omitted).
        fhir_token: Optional bearer token stashed in the SHARP context.
        fhir_base_url: Optional override for the FHIR base URL in the envelope.
        session_id: Optional session id; defaults to the workflow id.
    """
    workflow_id = str(uuid.uuid4())
    sharp_ctx = build_sharp_context(
        patient_id=patient_id,
        fhir_ref=f"Patient/{patient_id}",
        session_id=session_id or workflow_id,
    )
    if fhir_token:
        sharp_ctx["fhir_token"] = fhir_token
    if fhir_base_url:
        sharp_ctx["patient_context"]["fhir_base"] = fhir_base_url
    # One timezone-aware clock read for both timestamps; datetime.utcnow()
    # is deprecated (naive) since Python 3.12.
    now = datetime.now(timezone.utc).isoformat()
    _workflows[workflow_id] = {
        "workflow_id": workflow_id,
        "patient_id": patient_id,
        "nct_id": nct_id,
        "condition": condition,
        "current_state": WorkflowState.PENDING,
        "events": [],
        "result": None,
        "sharp_context": sharp_ctx,
        "created_at": now,
        "updated_at": now,
    }
    return workflow_id
def run_pipeline(workflow_id: str) -> dict:
    """Execute the full A2A pipeline synchronously; return the workflow record.

    Chains the five sub-agents (ingest, parse, match, score, recruit). On any
    failure the workflow transitions to FAILED and the error message is
    stored; the workflow dict is returned either way.

    Raises:
        ValueError: If ``workflow_id`` was never created via ``start_pipeline``.
    """
    workflow = _workflows.get(workflow_id)
    if workflow is None:
        raise ValueError(f"Workflow {workflow_id} not found")
    try:
        # Agent 1: ingest FHIR patient data
        patient_profile = _agent_ingest_patient(workflow_id, workflow["patient_id"])
        # Infer the condition: explicit value, else the patient's first
        # diagnosis, else a generic fallback term.
        condition = workflow.get("condition")
        if not condition:
            diagnoses = patient_profile.get("diagnosis_names")
            condition = diagnoses[0] if diagnoses else "cancer"
        # Agent 2: parse trial protocols
        trials, protocol_meta = _agent_parse_protocol(workflow_id, workflow.get("nct_id"), condition)
        # Agent 3: semantic matching
        candidates = _agent_match(workflow_id, patient_profile, trials)
        # Agent 4: predictive scoring
        candidates = _agent_score(workflow_id, candidates, patient_profile)
        # Agent 5: recruitment communication
        recruitment_records = _agent_recruit(workflow_id, candidates, patient_profile)
        eligible_total = sum(1 for c in candidates if c.get("eligible"))
        result = {
            "patient_profile": patient_profile,
            "matched_trials": candidates,
            "recruitment_records": recruitment_records,
            "protocol_summary": protocol_meta.get("summary", ""),
            "total_trials_evaluated": len(trials),
            "eligible_trials": eligible_total,
        }
        workflow["result"] = result
        _emit_event(workflow_id, WorkflowState.COMPLETED,
                    f"Pipeline complete: {eligible_total} eligible trials found", result)
    except Exception as e:
        _emit_event(workflow_id, WorkflowState.FAILED, f"Pipeline failed: {str(e)}")
        workflow["error"] = str(e)
    return _workflows[workflow_id]
def get_workflow_status(workflow_id: str) -> dict:
    """Return a status snapshot for a workflow, or an error dict if unknown.

    The snapshot includes only the 10 newest events to bound payload size.
    """
    workflow = _workflows.get(workflow_id)
    if workflow is None:
        return {"error": "Workflow not found"}
    snapshot = {
        "workflow_id": workflow_id,
        "current_state": workflow["current_state"],
        "events": workflow["events"][-10:],  # cap at the 10 most recent events
        "result": workflow.get("result"),
        "error": workflow.get("error"),
        "created_at": workflow["created_at"],
        "updated_at": workflow["updated_at"],
    }
    return snapshot
def list_workflows() -> list[dict]:
    """Return a lightweight summary of every workflow in the in-memory store."""
    summaries = []
    for wf in _workflows.values():
        # Only surface identity + state fields; events/results stay internal.
        summaries.append({
            "workflow_id": wf["workflow_id"],
            "patient_id": wf["patient_id"],
            "current_state": wf["current_state"],
            "created_at": wf["created_at"],
        })
    return summaries