Spaces:
Running
Running
Initial deployment: ClinicalMatch AI v2.0 — FHIR R4 · MCP (9 tools) · A2A workflow · SHARP compliance · 100k synthetic patients · Neo4j graph · GraphRAG chatbot
59abb4f | """A2A (Agent-to-Agent) orchestration workflow — state machine for the recruitment pipeline. | |
| Every inter-agent message carries a SHARP Extension Spec context envelope: | |
| sharp_version, patient_context (id, fhir_ref, fhir_base, tenant_id, session_id), | |
| data_classification, baa_in_scope, consent_status | |
| """ | |
| import uuid | |
| import time | |
| from datetime import datetime | |
| from enum import Enum | |
| from typing import Any | |
| from fhir_adapter import get_patient_profile, get_mock_fhir_patient, build_patient_profile | |
| from clinicaltrials_api import search_trials_sync, get_trial_details_sync | |
| from matching_engine import get_criteria_for_trial, score_patient_for_trial, match_patient_to_trials | |
| from llm_client import generate_outreach_message, summarize_trial | |
| from fhir_server import build_sharp_context, get_live_patient_profile | |
| import consent_agent | |
class WorkflowState(str, Enum):
    """Lifecycle states of a recruitment pipeline workflow.

    Mixes in ``str`` so state values serialize directly in JSON event
    payloads and f-strings without an explicit ``.value`` access.
    """
    PENDING = "PENDING"                      # created via start_pipeline, not yet executed
    INGESTING = "INGESTING"                  # fetching/validating patient FHIR data
    PARSING_PROTOCOL = "PARSING_PROTOCOL"    # extracting trial eligibility criteria
    MATCHING = "MATCHING"                    # semantic patient-to-trial matching
    SCORING = "SCORING"                      # predictive screening / risk flagging
    RECRUITING = "RECRUITING"                # generating outreach + consent handoff
    COMPLETED = "COMPLETED"                  # terminal success state
    FAILED = "FAILED"                        # terminal failure (see workflow["error"])
# In-memory workflow store (production: use Redis or Neo4j)
# Keyed by workflow_id (uuid4 string). Values are mutable workflow dicts
# created by start_pipeline and updated in place by _emit_event/run_pipeline.
_workflows: dict[str, dict] = {}
def _emit_event(workflow_id: str, state: WorkflowState, message: str, data: Any = None) -> None:
    """Append a state-transition event to a workflow and advance its state.

    Args:
        workflow_id: Key into the module-level ``_workflows`` store.
        state: New ``WorkflowState`` to record and set as current.
        message: Human-readable description of the transition.
        data: Optional structured payload attached to the event.

    Raises:
        KeyError: If ``workflow_id`` is not present in ``_workflows``.
    """
    workflow = _workflows[workflow_id]
    event = {
        "state": state,
        "message": message,
        "timestamp": datetime.utcnow().isoformat(),
        "data": data,
        # SHARP envelope on every event so downstream agents have full context
        "sharp_context": workflow.get("sharp_context", {}),
    }
    workflow["events"].append(event)
    workflow["current_state"] = state
    workflow["updated_at"] = datetime.utcnow().isoformat()
    # Fix: the arrow in the log line was a mojibake-garbled glyph in the original.
    print(f"[A2A:{workflow_id[:8]}] {state} -> {message}")
| # ── Sub-agents ──────────────────────────────────────────────────────────────── | |
def _agent_ingest_patient(workflow_id: str, patient_id: str) -> dict:
    """Sub-agent: ingest and validate a patient's FHIR R4 record.

    Fetches the mock FHIR patient, builds a normalized profile dict, and
    emits INGESTING events before and after the fetch.

    Raises:
        ValueError: If the patient is not found in the FHIR registry.
    """
    _emit_event(workflow_id, WorkflowState.INGESTING, f"Ingesting FHIR R4 data for patient {patient_id}")
    time.sleep(0.3)  # Simulate async data fetch
    record = get_mock_fhir_patient(patient_id)
    if not record:
        raise ValueError(f"Patient {patient_id} not found in FHIR registry")
    patient_profile = build_patient_profile(record)
    loaded_msg = (
        f"FHIR data loaded: {len(record.conditions)} conditions, "
        f"{len(record.medications)} medications"
    )
    _emit_event(workflow_id, WorkflowState.INGESTING, loaded_msg, {"profile": patient_profile})
    return patient_profile
def _agent_parse_protocol(workflow_id: str, nct_id: str | None, condition: str) -> tuple[list[dict], dict]:
    """Sub-agent: parse trial protocols and extract eligibility criteria.

    When ``nct_id`` is provided, looks up that single trial; otherwise
    searches by ``condition``. Returns the parsed trial list and a
    metadata dict carrying a summary of the first trial.

    Raises:
        ValueError: If no trials are found for the condition.
    """
    _emit_event(workflow_id, WorkflowState.PARSING_PROTOCOL,
                f"Parsing trial protocols for condition: {condition}")
    time.sleep(0.5)
    if nct_id:
        single = get_trial_details_sync(nct_id)
        trials = [single] if single else []
    else:
        trials = search_trials_sync(condition, page_size=8)
    if not trials:
        raise ValueError(f"No trials found for condition: {condition}")
    # Parse criteria for each trial using LLM; cap at 5 to avoid timeout.
    parsed_trials = [
        {**trial, "parsed_criteria": get_criteria_for_trial(trial)}
        for trial in trials[:5]
    ]
    summary = summarize_trial(trials[0]) if trials else ""
    _emit_event(workflow_id, WorkflowState.PARSING_PROTOCOL,
                f"Parsed {len(parsed_trials)} trial protocols",
                {"trial_count": len(parsed_trials), "protocol_summary": summary})
    return parsed_trials, {"summary": summary}
def _agent_match(workflow_id: str, patient_profile: dict, trials: list[dict]) -> list[dict]:
    """Sub-agent: semantically match a patient against candidate trials.

    Scores every trial for the patient, enriches each trial dict with the
    scoring output, and returns the candidates sorted by descending
    ``match_score``. Emits MATCHING events before and after.
    """
    _emit_event(workflow_id, WorkflowState.MATCHING,
                f"Running semantic matching for patient {patient_profile['patient_id']} against {len(trials)} trials")
    time.sleep(0.3)
    pid = patient_profile["patient_id"]
    candidates = []
    for trial in trials:
        scored = score_patient_for_trial(pid, trial)
        enriched = dict(trial)
        enriched.update({
            "match_score": scored.get("overall_score", 0.0),
            "eligible": scored.get("eligible", False),
            "inclusion_results": scored.get("inclusion_results", []),
            "exclusion_results": scored.get("exclusion_results", []),
            "match_summary": scored.get("summary", ""),
            "risk_flags": scored.get("risk_flags", []),
        })
        candidates.append(enriched)
    candidates.sort(key=lambda c: c["match_score"], reverse=True)
    eligible_count = sum(1 for c in candidates if c["eligible"])
    top_score = candidates[0]["match_score"] if candidates else 0
    _emit_event(workflow_id, WorkflowState.MATCHING,
                f"Matching complete: {eligible_count}/{len(candidates)} trials eligible",
                {"eligible_count": eligible_count, "top_score": top_score})
    return candidates
def _agent_score(workflow_id: str, candidates: list[dict], patient_profile: dict) -> list[dict]:
    """Sub-agent: predictive screening scoring with risk flags.

    Mutates each candidate in place: appends data-quality risk flags
    (missing site locations, missing biomarkers) and assigns a
    ``screening_priority`` bucket derived from ``match_score``
    (HIGH >= 0.8, MEDIUM >= 0.5, else LOW). Returns the same list.
    """
    _emit_event(workflow_id, WorkflowState.SCORING, "Running predictive screening analysis")
    time.sleep(0.2)
    for candidate in candidates:
        flags = candidate.get("risk_flags", [])
        # Add distance risk flag if no nearby sites
        locs = candidate.get("locations", [])
        if not locs:
            flags.append("No site location data available")
        # Add data completeness flag
        if not patient_profile.get("biomarkers"):
            # Fix: the dash in this flag text was a mojibake-garbled glyph.
            flags.append("Biomarker data incomplete — may affect screening")
        candidate["risk_flags"] = flags
        candidate["screening_priority"] = (
            "HIGH" if candidate["match_score"] >= 0.8
            else "MEDIUM" if candidate["match_score"] >= 0.5
            else "LOW"
        )
    _emit_event(workflow_id, WorkflowState.SCORING,
                "Screening scoring complete",
                {"high_priority": sum(1 for c in candidates if c.get("screening_priority") == "HIGH")})
    return candidates
def _agent_recruit(workflow_id: str, candidates: list[dict], patient_profile: dict) -> list[dict]:
    """Sub-agent: generate recruitment outreach for eligible candidates.

    Takes the top three eligible trials, generates patient-facing and
    PCP-facing messages via the LLM, and hands off a consent task to the
    consent agent (A2A) with the workflow's SHARP envelope attached.
    Per-trial failures become ERROR records instead of aborting the loop.
    """
    _emit_event(workflow_id, WorkflowState.RECRUITING, "Generating personalized recruitment communications")
    top_eligible = [c for c in candidates if c.get("eligible")][:3]
    sharp_ctx = _workflows[workflow_id].get("sharp_context", {})
    records: list[dict] = []
    for trial in top_eligible:
        nct = trial.get("nct_id", "")
        title = trial.get("title", "")
        score = trial.get("match_score", 0.0)
        try:
            patient_email = generate_outreach_message(patient_profile, trial, "patient_email")
            pcp_letter = generate_outreach_message(patient_profile, trial, "pcp_letter")
            # A2A handoff to the consent agent (SHARP envelope attached)
            consent_result = consent_agent.receive_a2a_task({
                "task_id": f"consent_{workflow_id}_{nct}",
                "type": "CONSENT_REQUEST",
                "payload": {
                    "patient_id": patient_profile.get("patient_id", ""),
                    "nct_id": nct,
                    "trial_title": title,
                    "match_score": score,
                },
                "sharp_context": sharp_ctx,
            })
            records.append({
                "nct_id": nct,
                "trial_title": title,
                "match_score": score,
                "patient_email": patient_email,
                "pcp_letter": pcp_letter,
                "status": "PENDING",
                "consent_id": consent_result.get("consent_id"),
                "consent_status": consent_result.get("status", "PENDING"),
                "created_at": datetime.utcnow().isoformat(),
            })
        except Exception as e:
            # Best-effort: record the failure for this trial and keep going.
            records.append({
                "nct_id": nct,
                "trial_title": title,
                "error": str(e),
                "status": "ERROR",
            })
    _emit_event(workflow_id, WorkflowState.RECRUITING,
                f"Generated outreach for {len(records)} trials",
                {"record_count": len(records)})
    return records
| # ── Public API ──────────────────────────────────────────────────────────────── | |
def start_pipeline(
    patient_id: str,
    nct_id: str | None = None,
    condition: str | None = None,
    fhir_token: str | None = None,
    fhir_base_url: str | None = None,
    session_id: str | None = None,
) -> str:
    """Create a new A2A workflow record and return its workflow_id.

    Builds the SHARP context envelope (optionally carrying a FHIR bearer
    token and an overriding FHIR base URL) and registers a PENDING
    workflow in the in-memory store. Execution happens in ``run_pipeline``.
    """
    workflow_id = str(uuid.uuid4())
    sharp_ctx = build_sharp_context(
        patient_id=patient_id,
        fhir_ref=f"Patient/{patient_id}",
        session_id=session_id or workflow_id,
    )
    # Optional live-EHR credentials ride along in the envelope.
    if fhir_token:
        sharp_ctx["fhir_token"] = fhir_token
    if fhir_base_url:
        sharp_ctx["patient_context"]["fhir_base"] = fhir_base_url
    record = {
        "workflow_id": workflow_id,
        "patient_id": patient_id,
        "nct_id": nct_id,
        "condition": condition,
        "current_state": WorkflowState.PENDING,
        "events": [],
        "result": None,
        "sharp_context": sharp_ctx,
        "created_at": datetime.utcnow().isoformat(),
        "updated_at": datetime.utcnow().isoformat(),
    }
    _workflows[workflow_id] = record
    return workflow_id
def run_pipeline(workflow_id: str) -> dict:
    """Execute the full A2A pipeline synchronously; return the workflow dict.

    Runs the five sub-agents in order (ingest -> parse -> match -> score ->
    recruit). Any exception during the run is recorded as a FAILED event
    plus an ``error`` key on the workflow instead of propagating.

    Raises:
        ValueError: If ``workflow_id`` is unknown.
    """
    workflow = _workflows.get(workflow_id)
    if not workflow:
        raise ValueError(f"Workflow {workflow_id} not found")
    patient_id = workflow["patient_id"]
    nct_id = workflow.get("nct_id")
    condition = workflow.get("condition")
    try:
        # Agent 1: Ingest FHIR patient data
        patient_profile = _agent_ingest_patient(workflow_id, patient_id)
        # Infer condition from the primary diagnosis when none was given.
        if not condition:
            diagnoses = patient_profile.get("diagnosis_names")
            condition = diagnoses[0] if diagnoses else "cancer"
        # Agent 2: Parse trial protocols
        trials, protocol_meta = _agent_parse_protocol(workflow_id, nct_id, condition)
        # Agent 3: Semantic matching
        candidates = _agent_match(workflow_id, patient_profile, trials)
        # Agent 4: Predictive scoring
        candidates = _agent_score(workflow_id, candidates, patient_profile)
        # Agent 5: Recruitment communication
        recruitment_records = _agent_recruit(workflow_id, candidates, patient_profile)
        result = {
            "patient_profile": patient_profile,
            "matched_trials": candidates,
            "recruitment_records": recruitment_records,
            "protocol_summary": protocol_meta.get("summary", ""),
            "total_trials_evaluated": len(trials),
            "eligible_trials": sum(1 for c in candidates if c.get("eligible")),
        }
        workflow["result"] = result
        _emit_event(workflow_id, WorkflowState.COMPLETED,
                    f"Pipeline complete: {result['eligible_trials']} eligible trials found", result)
    except Exception as e:
        _emit_event(workflow_id, WorkflowState.FAILED, f"Pipeline failed: {str(e)}")
        workflow["error"] = str(e)
    return _workflows[workflow_id]
def get_workflow_status(workflow_id: str) -> dict:
    """Return a compact status view of a workflow.

    Includes only the last 10 events to keep payloads small. Unknown ids
    yield ``{"error": "Workflow not found"}`` rather than raising.
    """
    workflow = _workflows.get(workflow_id)
    if workflow is None:
        return {"error": "Workflow not found"}
    status = {
        "workflow_id": workflow_id,
        "current_state": workflow["current_state"],
        "events": workflow["events"][-10:],  # Last 10 events
        "result": workflow.get("result"),
        "error": workflow.get("error"),
        "created_at": workflow["created_at"],
        "updated_at": workflow["updated_at"],
    }
    return status
def list_workflows() -> list[dict]:
    """Return lightweight summaries (id, patient, state, created_at) of all workflows."""
    summaries = []
    for wf in _workflows.values():
        summaries.append({
            "workflow_id": wf["workflow_id"],
            "patient_id": wf["patient_id"],
            "current_state": wf["current_state"],
            "created_at": wf["created_at"],
        })
    return summaries