"""A2A (Agent-to-Agent) orchestration workflow — state machine for the recruitment pipeline.

Every inter-agent message carries a SHARP Extension Spec context envelope:
  sharp_version, patient_context (id, fhir_ref, fhir_base, tenant_id, session_id),
  data_classification, baa_in_scope, consent_status
"""
import uuid
import time
from datetime import datetime, timezone
from enum import Enum
from typing import Any

from fhir_adapter import get_patient_profile, get_mock_fhir_patient, build_patient_profile
from clinicaltrials_api import search_trials_sync, get_trial_details_sync
from matching_engine import get_criteria_for_trial, score_patient_for_trial, match_patient_to_trials
from llm_client import generate_outreach_message, summarize_trial
from fhir_server import build_sharp_context, get_live_patient_profile
import consent_agent


class WorkflowState(str, Enum):
    """States a workflow moves through; values are plain strings."""

    PENDING = "PENDING"
    INGESTING = "INGESTING"
    PARSING_PROTOCOL = "PARSING_PROTOCOL"
    MATCHING = "MATCHING"
    SCORING = "SCORING"
    RECRUITING = "RECRUITING"
    COMPLETED = "COMPLETED"
    FAILED = "FAILED"


# In-memory workflow store (production: use Redis or Neo4j)
_workflows: dict[str, dict] = {}

# Number of trials requested when searching by condition.
_SEARCH_PAGE_SIZE = 8
# Only this many trials get their criteria parsed (limit to avoid timeout).
_MAX_PARSED_TRIALS = 5
# Outreach is generated for at most this many eligible trials.
_MAX_OUTREACH_TRIALS = 3


def _utc_now_iso() -> str:
    """Return the current UTC time as a naive ISO-8601 string.

    Produces the same format as the deprecated ``datetime.utcnow().isoformat()``
    (no UTC offset suffix), so stored timestamps keep their shape.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()


def _emit_event(workflow_id: str, state: WorkflowState, message: str, data: Any = None) -> None:
    """Append an event to the workflow and advance its current state.

    Args:
        workflow_id: Key of an existing entry in ``_workflows``.
        state: State recorded on the event and set as the workflow's current state.
        message: Human-readable description; also printed to stdout.
        data: Optional payload stored on the event as-is (not copied).

    Raises:
        KeyError: If ``workflow_id`` is not in ``_workflows``.
    """
    workflow = _workflows[workflow_id]
    # Read the clock once so the event timestamp and updated_at always agree.
    timestamp = _utc_now_iso()
    event = {
        "state": state,
        "message": message,
        "timestamp": timestamp,
        "data": data,
        # SHARP envelope on every event so downstream agents have full context
        "sharp_context": workflow.get("sharp_context", {}),
    }
    workflow["events"].append(event)
    workflow["current_state"] = state
    workflow["updated_at"] = timestamp
    print(f"[A2A:{workflow_id[:8]}] {state} — {message}")


# ── Sub-agents ────────────────────────────────────────────────────────────────

def _agent_ingest_patient(workflow_id: str, patient_id: str) -> dict:
    """Sub-agent: Ingest and validate patient FHIR data.

    Loads the patient through ``get_mock_fhir_patient`` and converts it with
    ``build_patient_profile``.

    Returns:
        The profile produced by ``build_patient_profile``.

    Raises:
        ValueError: If no patient is returned for ``patient_id``.
    """
    _emit_event(workflow_id, WorkflowState.INGESTING,
                f"Ingesting FHIR R4 data for patient {patient_id}")
    time.sleep(0.3)  # Simulate async data fetch

    fhir_patient = get_mock_fhir_patient(patient_id)
    if not fhir_patient:
        raise ValueError(f"Patient {patient_id} not found in FHIR registry")

    profile = build_patient_profile(fhir_patient)
    _emit_event(
        workflow_id,
        WorkflowState.INGESTING,
        f"FHIR data loaded: {len(fhir_patient.conditions)} conditions, "
        f"{len(fhir_patient.medications)} medications",
        {"profile": profile},
    )
    return profile


def _agent_parse_protocol(workflow_id: str, nct_id: str | None, condition: str) -> tuple[list[dict], dict]:
    """Sub-agent: Parse trial protocol and extract criteria.

    When ``nct_id`` is given, only that trial is looked up; otherwise trials
    are searched by ``condition``. At most ``_MAX_PARSED_TRIALS`` trials are
    parsed.

    Returns:
        ``(parsed_trials, meta)`` where each parsed trial is the trial dict
        plus a ``"parsed_criteria"`` key, and ``meta`` is
        ``{"summary": <summary of the first trial>}``.

    Raises:
        ValueError: If the NCT lookup or the condition search yields no trials.
    """
    _emit_event(workflow_id, WorkflowState.PARSING_PROTOCOL,
                f"Parsing trial protocols for condition: {condition}")
    time.sleep(0.5)

    if nct_id:
        details = get_trial_details_sync(nct_id)
        trials = [details] if details else []
    else:
        trials = search_trials_sync(condition, page_size=_SEARCH_PAGE_SIZE)

    if not trials:
        # Name the lookup that actually failed instead of always blaming the condition.
        if nct_id:
            raise ValueError(f"No trial found for NCT ID: {nct_id}")
        raise ValueError(f"No trials found for condition: {condition}")

    # Parse criteria for each trial using LLM
    parsed_trials = [
        {**trial, "parsed_criteria": get_criteria_for_trial(trial)}
        for trial in trials[:_MAX_PARSED_TRIALS]  # Limit to avoid timeout
    ]

    # trials is guaranteed non-empty here, so the first entry always exists.
    summary = summarize_trial(trials[0])

    _emit_event(workflow_id, WorkflowState.PARSING_PROTOCOL,
                f"Parsed {len(parsed_trials)} trial protocols",
                {"trial_count": len(parsed_trials), "protocol_summary": summary})
    return parsed_trials, {"summary": summary}


def _agent_match(workflow_id: str, patient_profile: dict, trials: list[dict]) -> list[dict]:
    """Sub-agent: Semantic matching of patient to trials.

    Scores the patient against every trial via ``score_patient_for_trial`` and
    merges the score fields into a copy of each trial dict.

    Returns:
        All candidates (eligible or not), sorted by ``match_score`` descending.

    Raises:
        KeyError: If ``patient_profile`` has no ``"patient_id"``.
    """
    patient_id = patient_profile["patient_id"]
    _emit_event(workflow_id, WorkflowState.MATCHING,
                f"Running semantic matching for patient {patient_id} against {len(trials)} trials")
    time.sleep(0.3)

    candidates = []
    for trial in trials:
        score_result = score_patient_for_trial(patient_id, trial)
        candidates.append({
            **trial,
            "match_score": score_result.get("overall_score", 0.0),
            "eligible": score_result.get("eligible", False),
            "inclusion_results": score_result.get("inclusion_results", []),
            "exclusion_results": score_result.get("exclusion_results", []),
            "match_summary": score_result.get("summary", ""),
            "risk_flags": score_result.get("risk_flags", []),
        })

    candidates.sort(key=lambda c: c["match_score"], reverse=True)
    eligible = [c for c in candidates if c["eligible"]]

    _emit_event(workflow_id, WorkflowState.MATCHING,
                f"Matching complete: {len(eligible)}/{len(candidates)} trials eligible",
                {"eligible_count": len(eligible),
                 "top_score": candidates[0]["match_score"] if candidates else 0})
    return candidates


def _agent_score(workflow_id: str, candidates: list[dict], patient_profile: dict) -> list[dict]:
    """Sub-agent: Predictive screening scoring with risk flags.

    Adds data-quality risk flags and a ``screening_priority`` of HIGH
    (score >= 0.8), MEDIUM (score >= 0.5) or LOW to each candidate dict.

    Returns:
        The same ``candidates`` list, with each dict updated in place.
    """
    _emit_event(workflow_id, WorkflowState.SCORING, "Running predictive screening analysis")
    time.sleep(0.2)

    # Same for every candidate, so evaluate it once outside the loop.
    biomarkers_missing = not patient_profile.get("biomarkers")

    for candidate in candidates:
        # Work on a copy: the stored list is the object returned by the scorer,
        # and appending to it in place would leak flags back into that object.
        flags = list(candidate.get("risk_flags") or [])

        # Add distance risk flag if no nearby sites
        if not candidate.get("locations", []):
            flags.append("No site location data available")

        # Add data completeness flag
        if biomarkers_missing:
            flags.append("Biomarker data incomplete — may affect screening")

        candidate["risk_flags"] = flags

        score = candidate["match_score"]
        if score >= 0.8:
            candidate["screening_priority"] = "HIGH"
        elif score >= 0.5:
            candidate["screening_priority"] = "MEDIUM"
        else:
            candidate["screening_priority"] = "LOW"

    _emit_event(workflow_id, WorkflowState.SCORING, "Screening scoring complete",
                {"high_priority": sum(1 for c in candidates
                                      if c.get("screening_priority") == "HIGH")})
    return candidates


def _agent_recruit(workflow_id: str, candidates: list[dict], patient_profile: dict) -> list[dict]:
    """Sub-agent: Generate recruitment outreach for eligible candidates.

    For up to ``_MAX_OUTREACH_TRIALS`` eligible candidates (in the order given),
    generates a patient email and a PCP letter and hands a consent request to
    ``consent_agent.receive_a2a_task``.

    Returns:
        One record per processed trial. A trial whose outreach or consent
        handoff raised gets a record with ``"status": "ERROR"`` and the error
        text instead of aborting the remaining trials.
    """
    _emit_event(workflow_id, WorkflowState.RECRUITING,
                "Generating personalized recruitment communications")

    eligible = [c for c in candidates if c.get("eligible")][:_MAX_OUTREACH_TRIALS]
    recruitment_records = []

    for trial in eligible:
        try:
            outreach = generate_outreach_message(patient_profile, trial, "patient_email")
            pcp_letter = generate_outreach_message(patient_profile, trial, "pcp_letter")

            # A2A handoff → consent agent (SHARP envelope attached)
            consent_task = {
                "task_id": f"consent_{workflow_id}_{trial.get('nct_id','')}",
                "type": "CONSENT_REQUEST",
                "payload": {
                    "patient_id": patient_profile.get("patient_id", ""),
                    "nct_id": trial.get("nct_id", ""),
                    "trial_title": trial.get("title", ""),
                    "match_score": trial.get("match_score", 0.0),
                },
                "sharp_context": _workflows[workflow_id].get("sharp_context", {}),
            }
            consent_result = consent_agent.receive_a2a_task(consent_task)

            recruitment_records.append({
                "nct_id": trial.get("nct_id", ""),
                "trial_title": trial.get("title", ""),
                "match_score": trial.get("match_score", 0.0),
                "patient_email": outreach,
                "pcp_letter": pcp_letter,
                "status": "PENDING",
                "consent_id": consent_result.get("consent_id"),
                "consent_status": consent_result.get("status", "PENDING"),
                "created_at": _utc_now_iso(),
            })
        except Exception as e:
            # Deliberate best-effort: one failing trial must not block the others.
            recruitment_records.append({
                "nct_id": trial.get("nct_id", ""),
                "trial_title": trial.get("title", ""),
                "error": str(e),
                "status": "ERROR",
            })

    # Note: the count below includes ERROR records as well as successful ones.
    _emit_event(workflow_id, WorkflowState.RECRUITING,
                f"Generated outreach for {len(recruitment_records)} trials",
                {"record_count": len(recruitment_records)})
    return recruitment_records


# ── Public API ─────────────────────────────────────────────────────────────────

def start_pipeline(
    patient_id: str,
    nct_id: str | None = None,
    condition: str | None = None,
    fhir_token: str | None = None,
    fhir_base_url: str | None = None,
    session_id: str | None = None,
) -> str:
    """Start the A2A pipeline and return a workflow_id.

    Registers a new PENDING workflow in the in-memory store; it does not run
    any agent. Call ``run_pipeline`` with the returned id to execute it.

    Args:
        patient_id: Patient identifier; also used to build the ``Patient/<id>`` FHIR ref.
        nct_id: Optional trial id; when set, only that trial is evaluated.
        condition: Optional condition to search trials for.
        fhir_token: Optional token stored on the SHARP envelope as ``fhir_token``.
        fhir_base_url: Optional override for ``patient_context.fhir_base``.
        session_id: Optional session id; defaults to the new workflow id.

    Returns:
        The generated workflow id (a UUID4 string).
    """
    workflow_id = str(uuid.uuid4())

    sharp_ctx = build_sharp_context(
        patient_id=patient_id,
        fhir_ref=f"Patient/{patient_id}",
        session_id=session_id or workflow_id,
    )
    if fhir_token:
        sharp_ctx["fhir_token"] = fhir_token
    if fhir_base_url:
        # setdefault guards against an envelope without a patient_context section.
        sharp_ctx.setdefault("patient_context", {})["fhir_base"] = fhir_base_url

    # Single clock read so created_at and updated_at start out identical.
    created_at = _utc_now_iso()
    _workflows[workflow_id] = {
        "workflow_id": workflow_id,
        "patient_id": patient_id,
        "nct_id": nct_id,
        "condition": condition,
        "current_state": WorkflowState.PENDING,
        "events": [],
        "result": None,
        "sharp_context": sharp_ctx,
        "created_at": created_at,
        "updated_at": created_at,
    }
    return workflow_id


def run_pipeline(workflow_id: str) -> dict:
    """Execute the full A2A pipeline synchronously.

    Runs the ingest, protocol-parsing, matching, scoring and recruitment
    sub-agents in order. Any exception raised by a sub-agent is caught: the
    workflow is moved to FAILED and the error text stored under ``"error"``.

    Returns:
        The workflow dict from the in-memory store (with ``"result"`` set on
        success or ``"error"`` set on failure).

    Raises:
        ValueError: If ``workflow_id`` is unknown.
    """
    workflow = _workflows.get(workflow_id)
    if not workflow:
        raise ValueError(f"Workflow {workflow_id} not found")

    patient_id = workflow["patient_id"]
    nct_id = workflow.get("nct_id")
    condition = workflow.get("condition")

    try:
        # Agent 1: Ingest FHIR patient data
        patient_profile = _agent_ingest_patient(workflow_id, patient_id)

        # Infer condition: first diagnosis if available, otherwise a generic fallback
        if not condition:
            diagnosis_names = patient_profile.get("diagnosis_names")
            condition = diagnosis_names[0] if diagnosis_names else "cancer"

        # Agent 2: Parse trial protocols
        trials, protocol_meta = _agent_parse_protocol(workflow_id, nct_id, condition)

        # Agent 3: Semantic matching
        candidates = _agent_match(workflow_id, patient_profile, trials)

        # Agent 4: Predictive scoring
        candidates = _agent_score(workflow_id, candidates, patient_profile)

        # Agent 5: Recruitment communication
        recruitment_records = _agent_recruit(workflow_id, candidates, patient_profile)

        result = {
            "patient_profile": patient_profile,
            "matched_trials": candidates,
            "recruitment_records": recruitment_records,
            "protocol_summary": protocol_meta.get("summary", ""),
            "total_trials_evaluated": len(trials),
            "eligible_trials": sum(1 for c in candidates if c.get("eligible")),
        }
        workflow["result"] = result
        _emit_event(workflow_id, WorkflowState.COMPLETED,
                    f"Pipeline complete: {result['eligible_trials']} eligible trials found",
                    result)

    except Exception as e:
        # Top-level boundary: record the failure on the workflow instead of raising.
        _emit_event(workflow_id, WorkflowState.FAILED, f"Pipeline failed: {str(e)}")
        workflow["error"] = str(e)

    return workflow


def get_workflow_status(workflow_id: str) -> dict:
    """Return a status snapshot for a workflow.

    Returns:
        ``{"error": "Workflow not found"}`` for an unknown id; otherwise the
        current state, the last 10 events, the result and error (if any) and
        the created/updated timestamps.
    """
    workflow = _workflows.get(workflow_id)
    if not workflow:
        return {"error": "Workflow not found"}
    # NOTE(review): each event embeds the workflow's sharp_context, which holds
    # "fhir_token" when one was passed to start_pipeline — so the token is
    # returned to status callers. Confirm this is intended, or redact it here.
    return {
        "workflow_id": workflow_id,
        "current_state": workflow["current_state"],
        "events": workflow["events"][-10:],  # Last 10 events
        "result": workflow.get("result"),
        "error": workflow.get("error"),
        "created_at": workflow["created_at"],
        "updated_at": workflow["updated_at"],
    }


def list_workflows() -> list[dict]:
    """Return a summary (id, patient, state, creation time) of every stored workflow."""
    return [
        {
            "workflow_id": wf["workflow_id"],
            "patient_id": wf["patient_id"],
            "current_state": wf["current_state"],
            "created_at": wf["created_at"],
        }
        for wf in _workflows.values()
    ]