Spaces:
Running
Running
Initial deployment: ClinicalMatch AI v2.0 — FHIR R4 · MCP (9 tools) · A2A workflow · SHARP compliance · 100k synthetic patients · Neo4j graph · GraphRAG chatbot
# Standard library
import asyncio
import json
import os
import threading
import time
from typing import Optional

# Third party
import httpx
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, BackgroundTasks, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

# Environment must be loaded before project modules read it at import time.
load_dotenv()

# Project modules
from neo4j_setup import neo4j_conn, setup_schema
from graphrag import retrieve_patient_trial_matches, rag_query, get_graph_stats
from data_ingestion import ingest_sample_data
from fhir_adapter import get_patient_profile, get_mock_fhir_patient, get_all_patient_ids, MOCK_FHIR_PATIENTS
from clinicaltrials_api import search_trials_sync, get_trial_details_sync, get_trial_details
from matching_engine import match_patient_to_trials, score_patient_for_trial, find_eligible_patients_for_trial
from a2a_workflow import start_pipeline, run_pipeline, get_workflow_status, list_workflows, _workflows
from analytics import get_kpi_summary, get_enrollment_funnel, get_site_performance, get_patient_demographics, get_recruitment_timeline, get_map_data
from recruitment_pipeline import get_kanban_board, get_all_records, create_record, update_status, generate_and_store_outreach, RecruitmentStatus
from llm_client import summarize_trial, get_provider_status
from graph_seeder import run_seeder, seed_sync
from trial_enrichment import enrich_trials_from_search, get_eligible_patient_counts, get_graph_intelligence
from intake_matching import match_intake_to_trials, save_intake_as_patient, BIOMARKER_REGISTRY
from fhir_server import (
    get_fhir_server_status, get_live_patient_profile,
    search_fhir_patients, build_sharp_context,
)
import consent_agent
| app = FastAPI( | |
| title="Precision Clinical Trial Matching & Recruitment Agent", | |
| version="2.0.0", | |
| description="A2A-powered agent for precision clinical trial matching using FHIR R4 standards and GraphRAG", | |
| ) | |
| app.add_middleware( | |
| CORSMiddleware, | |
| allow_origins=["*"], | |
| allow_credentials=True, | |
| allow_methods=["*"], | |
| allow_headers=["*"], | |
| ) | |
| # ββ Request Models βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| class PatientIngestRequest(BaseModel): | |
| id: str | |
| age: int | |
| gender: str | |
| diagnosis_code: str | |
| class WorkflowRequest(BaseModel): | |
| patient_id: str | |
| nct_id: Optional[str] = None | |
| condition: Optional[str] = None | |
| # SHARP / SMART on FHIR fields | |
| fhir_token: Optional[str] = None # Bearer token for FHIR server access | |
| fhir_base_url: Optional[str] = None # Override FHIR base for this session | |
| session_id: Optional[str] = None # Caller-supplied session ID for tracing | |
| class OutreachRequest(BaseModel): | |
| patient_id: str | |
| nct_id: str | |
| trial_title: str | |
| channel: str = "patient_email" | |
| class StatusUpdateRequest(BaseModel): | |
| status: RecruitmentStatus | |
| class RAGRequest(BaseModel): | |
| question: str | |
| class IntakeLabs(BaseModel): | |
| hemoglobin: Optional[float] = None # g/dL | |
| wbc: Optional[float] = None # Γ10βΉ/L | |
| anc: Optional[float] = None # Γ10βΉ/L | |
| platelets: Optional[float] = None # Γ10βΉ/L | |
| creatinine: Optional[float] = None # ΞΌmol/L | |
| egfr: Optional[float] = None # mL/min/1.73mΒ² | |
| bilirubin: Optional[float] = None # ΞΌmol/L | |
| alt: Optional[float] = None # U/L | |
| ast: Optional[float] = None # U/L | |
| albumin: Optional[float] = None # g/dL | |
| class IntakeRequest(BaseModel): | |
| condition: str # free text: "breast cancer" | |
| age: Optional[int] = None # years | |
| sex: Optional[str] = None # MALE / FEMALE | |
| ecog: Optional[int] = None # 0β4 | |
| stage: Optional[str] = None # I / II / III / IV | |
| biomarkers: list[str] = [] # list of BIOMARKER_REGISTRY keys | |
| labs: Optional[IntakeLabs] = None | |
| prior_chemo: bool = False | |
| prior_radiation: bool = False | |
| prior_surgery: bool = False | |
| medications: list[str] = [] | |
| save_to_graph: bool = False # persist as Patient node | |
| class ConsentStatusRequest(BaseModel): | |
| status: str # SIGNED | DECLINED | EXPIRED | |
| notes: Optional[str] = None | |
| class A2ATaskRequest(BaseModel): | |
| task_id: Optional[str] = None | |
| type: str | |
| payload: dict | |
| class RecruitmentRecordRequest(BaseModel): | |
| patient_id: str | |
| nct_id: str | |
| trial_title: str | |
| match_score: float = 0.75 | |
| # ββ Core / Health ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| async def root(): | |
| return { | |
| "name": "Precision Clinical Trial Matching Agent", | |
| "version": "2.0.0", | |
| "status": "operational", | |
| "standards": ["FHIR R4", "MCP", "A2A"], | |
| } | |
| # ββ Configuration & Provider Status ββββββββββββββββββββββββββββββββββββββββββ | |
| async def llm_config(): | |
| """Current LLM provider configuration and HIPAA BAA eligibility status.""" | |
| return get_provider_status() | |
| async def fhir_config(): | |
| """Current FHIR server connection status and SMART token configuration.""" | |
| return get_fhir_server_status() | |
| async def full_config(): | |
| """Full system configuration β LLM provider + FHIR server status.""" | |
| return { | |
| "llm": get_provider_status(), | |
| "fhir": get_fhir_server_status(), | |
| } | |
| # ββ Live FHIR Patient Endpoints βββββββββββββββββββββββββββββββββββββββββββββββ | |
| async def list_live_fhir_patients(count: int = 10): | |
| """Fetch real Patient resources from the configured FHIR R4 server.""" | |
| patients = search_fhir_patients(count=min(count, 50)) | |
| return {"patients": patients, "total": len(patients), "source": "fhir_server"} | |
| async def get_live_fhir_patient(fhir_id: str, fhir_token: Optional[str] = None): | |
| """ | |
| Fetch a patient from the live FHIR server, build a matching profile, | |
| and attach a SHARP context envelope. | |
| """ | |
| sharp_ctx = build_sharp_context( | |
| patient_id=fhir_id, | |
| fhir_ref=f"Patient/{fhir_id}", | |
| ) | |
| profile = get_live_patient_profile(fhir_id, sharp_context=sharp_ctx) | |
| if not profile: | |
| raise HTTPException(status_code=404, detail=f"FHIR Patient {fhir_id} not found on server") | |
| return profile | |
| async def match_live_fhir_patient(fhir_id: str, fhir_token: Optional[str] = None, top_n: int = 5): | |
| """ | |
| Full pipeline: fetch patient from live FHIR server β match against trials. | |
| SHARP context envelope included in response. | |
| """ | |
| sharp_ctx = build_sharp_context(patient_id=fhir_id, fhir_ref=f"Patient/{fhir_id}") | |
| profile = get_live_patient_profile(fhir_id, sharp_context=sharp_ctx) | |
| if not profile: | |
| raise HTTPException(status_code=404, detail=f"FHIR Patient {fhir_id} not found") | |
| from matching_engine import match_patient_to_trials as _match | |
| condition = profile.get("diagnosis_names", ["cancer"])[0] if profile.get("diagnosis_names") else "cancer" | |
| matches = _match(fhir_id, condition, top_n) | |
| return { | |
| "fhir_id": fhir_id, | |
| "profile": profile, | |
| "matches": matches, | |
| "total": len(matches), | |
| "sharp_context": sharp_ctx, | |
| } | |
| async def health(): | |
| stats = get_graph_stats() | |
| # Neo4j connectivity check | |
| neo4j_ok = False | |
| try: | |
| neo4j_conn.run_query("RETURN 1") | |
| neo4j_ok = True | |
| except Exception: | |
| pass | |
| # CT.gov reachability | |
| ctgov_ok = False | |
| try: | |
| async with httpx.AsyncClient(timeout=4) as client: | |
| r = await client.get( | |
| "https://clinicaltrials.gov/api/v2/studies", | |
| params={"query.term": "cancer", "pageSize": 1}, | |
| ) | |
| ctgov_ok = r.status_code == 200 | |
| except Exception: | |
| pass | |
| patient_count = stats.get("patients", 0) | |
| trial_count = stats.get("trials", 0) | |
| edge_count = stats.get("eligible_for_relationships", 0) | |
| seeded = patient_count >= 100 and trial_count >= 50 | |
| llm_status = get_provider_status() | |
| fhir_status = get_fhir_server_status() | |
| overall = "healthy" if (neo4j_ok and ctgov_ok and seeded) else ("degraded" if neo4j_ok else "unhealthy") | |
| return { | |
| "status": overall, | |
| "neo4j": "connected" if neo4j_ok else "unavailable", | |
| "ctgov_api": "reachable" if ctgov_ok else "unreachable", | |
| "fhir_server": "reachable" if fhir_status.get("reachable") else "unreachable", | |
| "fhir_base_url": fhir_status.get("base_url"), | |
| "smart_auth": fhir_status.get("auth_method"), | |
| "graph_seeded": seeded, | |
| "graph_stats": stats, | |
| "patient_count": patient_count, | |
| "trial_count": trial_count, | |
| "eligible_edges": edge_count, | |
| "llm_provider": llm_status.get("provider"), | |
| "llm_model": llm_status.get("model"), | |
| "llm_hipaa_eligible": llm_status.get("hipaa_eligible"), | |
| "version": "2.0.0", | |
| "standards": ["FHIR R4", "MCP", "A2A", "SHARP"], | |
| } | |
| # ββ FHIR Patient Endpoints βββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| async def list_patients(): | |
| patients = [] | |
| for pid in get_all_patient_ids(): | |
| profile = get_patient_profile(pid) | |
| if profile: | |
| patients.append(profile) | |
| return {"patients": patients, "total": len(patients)} | |
| async def get_patient(patient_id: str): | |
| profile = get_patient_profile(patient_id) | |
| if not profile: | |
| raise HTTPException(status_code=404, detail=f"Patient {patient_id} not found") | |
| fhir = get_mock_fhir_patient(patient_id) | |
| return {"profile": profile, "fhir_bundle": fhir.model_dump() if fhir else None} | |
| async def get_patient_fhir(patient_id: str): | |
| fhir = get_mock_fhir_patient(patient_id) | |
| if not fhir: | |
| raise HTTPException(status_code=404, detail="Patient not found") | |
| return fhir.model_dump() | |
| # Legacy endpoint | |
| async def ingest_patient(patient: PatientIngestRequest): | |
| query = """ | |
| MERGE (p:Patient {id: $id}) | |
| SET p += {age: $age, gender: $gender} | |
| MERGE (d:Diagnosis {code: $code}) | |
| MERGE (p)-[:HAS_DIAGNOSIS]->(d) | |
| """ | |
| try: | |
| neo4j_conn.run_query(query, {"id": patient.id, "age": patient.age, "gender": patient.gender, "code": patient.diagnosis_code}) | |
| return {"status": "Patient data ingested"} | |
| except Exception as e: | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| # ββ Trial Search & Details βββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| async def search_trials_endpoint( | |
| condition: str, | |
| phase: Optional[str] = None, | |
| status: str = "RECRUITING", | |
| page_size: int = 20, | |
| background_tasks: BackgroundTasks = None, | |
| ): | |
| trials = search_trials_sync(condition, phase, status, page_size) | |
| # Passive graph enrichment β fire-and-forget in background | |
| if background_tasks and trials: | |
| background_tasks.add_task(enrich_trials_from_search, trials, condition) | |
| # Attach graph-derived eligible patient counts | |
| nct_ids = [t["nct_id"] for t in trials if t.get("nct_id")] | |
| counts = get_eligible_patient_counts(nct_ids) | |
| for t in trials: | |
| t["eligible_patients_in_graph"] = counts.get(t.get("nct_id", ""), 0) | |
| return {"trials": trials, "total": len(trials), "condition": condition, "sorted_by": "last_updated"} | |
| async def get_trial(nct_id: str): | |
| trial = get_trial_details_sync(nct_id) | |
| if not trial: | |
| raise HTTPException(status_code=404, detail=f"Trial {nct_id} not found") | |
| summary = summarize_trial(trial) | |
| return {**trial, "ai_summary": summary} | |
| async def get_eligible_patients(nct_id: str): | |
| results = find_eligible_patients_for_trial(nct_id) | |
| return {"nct_id": nct_id, "eligible_patients": results, "total": len(results)} | |
| async def trial_graph_intelligence(nct_id: str): | |
| """Graph-derived intelligence: eligible count, similar trials, biomarker distribution, sites.""" | |
| return get_graph_intelligence(nct_id) | |
| # ββ Clinical Data Intake βββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| async def intake_match(request: IntakeRequest): | |
| """ | |
| Accept raw clinical data (SI units) and return ranked trial matches. | |
| No patient ID required β useful for individuals, clinicians, and researchers. | |
| """ | |
| intake = { | |
| "condition": request.condition, | |
| "age": request.age, | |
| "sex": (request.sex or "").upper() or None, | |
| "ecog": request.ecog, | |
| "stage": request.stage, | |
| "biomarkers": request.biomarkers, | |
| "labs": request.labs.model_dump(exclude_none=True) if request.labs else {}, | |
| "prior_chemo": request.prior_chemo, | |
| "prior_radiation": request.prior_radiation, | |
| "prior_surgery": request.prior_surgery, | |
| "medications": request.medications, | |
| } | |
| matches = match_intake_to_trials(intake, request.condition, limit=10) | |
| patient_id = None | |
| if request.save_to_graph: | |
| patient_id = save_intake_as_patient(intake) | |
| return { | |
| "condition": request.condition, | |
| "matches": matches, | |
| "total": len(matches), | |
| "patient_id": patient_id, | |
| } | |
| async def list_biomarkers(): | |
| """Return the full biomarker registry for populating the intake form.""" | |
| return { | |
| "biomarkers": [ | |
| {"id": bid, "label": info[0]} | |
| for bid, info in BIOMARKER_REGISTRY.items() | |
| ] | |
| } | |
| # Legacy endpoint | |
| async def match_trials_legacy(patient_id: str): | |
| matches = retrieve_patient_trial_matches(patient_id) | |
| return {"matches": matches} | |
| # ββ Matching Engine ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| async def match_patient_trials(patient_id: str, condition: Optional[str] = None, top_n: int = 5): | |
| matches = match_patient_to_trials(patient_id, condition, top_n) | |
| return {"patient_id": patient_id, "matches": matches, "total": len(matches)} | |
| async def screen_patient_for_trial(patient_id: str, nct_id: str): | |
| trial = await get_trial_details(nct_id) | |
| if not trial: | |
| raise HTTPException(status_code=404, detail=f"Trial {nct_id} not found") | |
| result = score_patient_for_trial(patient_id, trial) | |
| if "error" in result: | |
| raise HTTPException(status_code=404, detail=result["error"]) | |
| return result | |
| # ββ A2A Workflow βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| async def run_workflow(request: WorkflowRequest, background_tasks: BackgroundTasks): | |
| workflow_id = start_pipeline(request.patient_id, request.nct_id, request.condition) | |
| result = run_pipeline(workflow_id) | |
| return { | |
| "workflow_id": workflow_id, | |
| "status": result["current_state"], | |
| "result": result.get("result"), | |
| "events": result.get("events", []), | |
| } | |
| async def start_workflow(request: WorkflowRequest, background_tasks: BackgroundTasks): | |
| """Start a pipeline and return workflow_id immediately; stream progress via /workflow/{id}/stream.""" | |
| workflow_id = start_pipeline( | |
| request.patient_id, request.nct_id, request.condition, | |
| fhir_token=request.fhir_token, | |
| fhir_base_url=request.fhir_base_url, | |
| session_id=request.session_id, | |
| ) | |
| background_tasks.add_task(_run_pipeline_background, workflow_id) | |
| sharp_ctx = _workflows[workflow_id].get("sharp_context", {}) | |
| return { | |
| "workflow_id": workflow_id, | |
| "status": "PENDING", | |
| "stream_url": f"/api/v1/workflow/{workflow_id}/stream", | |
| "sharp_context": sharp_ctx, | |
| } | |
| def _run_pipeline_background(workflow_id: str): | |
| run_pipeline(workflow_id) | |
| async def stream_workflow(workflow_id: str, request: Request): | |
| """SSE endpoint β streams A2A state transitions as they happen.""" | |
| async def event_generator(): | |
| seen = 0 | |
| timeout = 120 # max seconds to stream | |
| deadline = time.time() + timeout | |
| while time.time() < deadline: | |
| if await request.is_disconnected(): | |
| break | |
| wf = _workflows.get(workflow_id) | |
| if not wf: | |
| yield f"data: {json.dumps({'error': 'workflow_not_found'})}\n\n" | |
| break | |
| events = wf.get("events", []) | |
| # Emit any new events since last check | |
| for evt in events[seen:]: | |
| payload = { | |
| "state": evt["state"], | |
| "message": evt["message"], | |
| "timestamp": evt["timestamp"], | |
| } | |
| if evt.get("data") and not evt["data"].__class__.__name__ == "dict" or evt.get("data"): | |
| try: | |
| # Only include lightweight summary data, not full result blobs | |
| d = evt.get("data") or {} | |
| if isinstance(d, dict): | |
| safe = {k: v for k, v in d.items() if k not in ("matched_trials", "recruitment_records", "patient_profile")} | |
| if safe: | |
| payload["data"] = safe | |
| except Exception: | |
| pass | |
| yield f"data: {json.dumps(payload)}\n\n" | |
| seen += 1 | |
| current = wf.get("current_state", "") | |
| if current in ("COMPLETED", "FAILED"): | |
| # Send final event with result summary | |
| result = wf.get("result") or {} | |
| final = { | |
| "state": current, | |
| "eligible_trials": result.get("eligible_trials", 0), | |
| "total_evaluated": result.get("total_trials_evaluated", 0), | |
| "recruitment_records": len(result.get("recruitment_records", [])), | |
| "error": wf.get("error"), | |
| } | |
| yield f"data: {json.dumps(final)}\n\n" | |
| yield "data: [DONE]\n\n" | |
| break | |
| await asyncio.sleep(0.5) | |
| return StreamingResponse( | |
| event_generator(), | |
| media_type="text/event-stream", | |
| headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}, | |
| ) | |
| async def workflow_status(workflow_id: str): | |
| status = get_workflow_status(workflow_id) | |
| if "error" in status: | |
| raise HTTPException(status_code=404, detail=status["error"]) | |
| return status | |
| async def list_all_workflows(): | |
| return {"workflows": list_workflows()} | |
| # ββ Consent & Scheduling Agent ββββββββββββββββββββββββββββββββββββββββββββββββ | |
| async def a2a_task(request: A2ATaskRequest): | |
| """A2A inter-agent task endpoint β routes CONSENT_REQUEST and SCHEDULE_REQUEST tasks.""" | |
| result = consent_agent.receive_a2a_task(request.model_dump()) | |
| return result | |
| async def list_consents(patient_id: Optional[str] = None): | |
| return {"consents": consent_agent.list_consent_records(patient_id)} | |
| async def consent_stats(): | |
| return consent_agent.get_consent_stats() | |
| async def get_consent(consent_id: str): | |
| record = consent_agent.get_consent_record(consent_id) | |
| if not record: | |
| raise HTTPException(status_code=404, detail="Consent record not found") | |
| return record | |
| async def update_consent(consent_id: str, request: ConsentStatusRequest): | |
| valid = {"SIGNED", "DECLINED", "EXPIRED"} | |
| if request.status not in valid: | |
| raise HTTPException(status_code=400, detail=f"status must be one of {valid}") | |
| result = consent_agent.update_consent_status(consent_id, request.status, request.notes or "") | |
| if "error" in result: | |
| raise HTTPException(status_code=404, detail=result["error"]) | |
| return result | |
| async def list_appointments(patient_id: Optional[str] = None): | |
| return {"appointments": consent_agent.list_appointments(patient_id)} | |
| async def confirm_appointment(appt_id: str): | |
| result = consent_agent.confirm_appointment(appt_id) | |
| if "error" in result: | |
| raise HTTPException(status_code=404, detail=result["error"]) | |
| return result | |
| # ββ Recruitment Pipeline βββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| async def kanban_board(): | |
| return get_kanban_board() | |
| async def all_recruitment_records(): | |
| return {"records": get_all_records()} | |
| async def create_recruitment_record(request: RecruitmentRecordRequest): | |
| record = create_record(request.patient_id, request.nct_id, request.trial_title, request.match_score) | |
| return record | |
| async def update_record_status(record_id: str, request: StatusUpdateRequest): | |
| try: | |
| return update_status(record_id, request.status) | |
| except ValueError as e: | |
| raise HTTPException(status_code=404, detail=str(e)) | |
| async def generate_outreach(request: OutreachRequest): | |
| trial = get_trial_details_sync(request.nct_id) or { | |
| "nct_id": request.nct_id, | |
| "title": request.trial_title, | |
| "brief_summary": "", | |
| "phase": "N/A", | |
| "sponsor": "N/A", | |
| "locations": [], | |
| } | |
| try: | |
| result = generate_and_store_outreach( | |
| request.patient_id, request.nct_id, request.trial_title, trial, request.channel | |
| ) | |
| return result | |
| except ValueError as e: | |
| raise HTTPException(status_code=404, detail=str(e)) | |
| # ββ Analytics & Dashboard ββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| async def kpi_summary(): | |
| return get_kpi_summary() | |
| async def enrollment_funnel(trial_id: Optional[str] = None): | |
| return {"funnel": get_enrollment_funnel(trial_id)} | |
| async def site_performance(): | |
| return {"sites": get_site_performance()} | |
| async def patient_demographics(trial_id: Optional[str] = None): | |
| return get_patient_demographics(trial_id) | |
| async def recruitment_timeline(days: int = 30): | |
| return {"timeline": get_recruitment_timeline(days)} | |
| async def map_data(): | |
| return get_map_data() | |
| # ββ GraphRAG βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| async def graph_query(question: str): | |
| response = rag_query(question) | |
| return {"response": response} | |
| async def graph_query_post(request: RAGRequest): | |
| response = rag_query(request.question) | |
| return {"response": response} | |
| async def graph_stats(): | |
| return get_graph_stats() | |
| async def list_graph_patients(condition: Optional[str] = None, limit: int = 200): | |
| """Query Neo4j for seeded patient records.""" | |
| if condition: | |
| rows = neo4j_conn.run_query( | |
| "MATCH (p:Patient) WHERE toLower(p.condition) CONTAINS toLower($cond) " | |
| "RETURN p.id AS id, p.name AS name, p.age AS age, p.condition AS condition, " | |
| "p.city AS city, p.state AS state ORDER BY p.id LIMIT $limit", | |
| {"cond": condition, "limit": limit}, | |
| ) | |
| else: | |
| rows = neo4j_conn.run_query( | |
| "MATCH (p:Patient) RETURN p.id AS id, p.name AS name, p.age AS age, " | |
| "p.condition AS condition, p.city AS city, p.state AS state " | |
| "ORDER BY p.id LIMIT $limit", | |
| {"limit": limit}, | |
| ) | |
| return {"patients": rows, "total": len(rows)} | |
| # Legacy | |
| async def rag_query_legacy(question: str): | |
| return {"response": rag_query(question)} | |
| async def enrich_legacy(): | |
| return {"reward": 0.75, "message": "Graph enrichment via RL (see rl_enrichment.py)"} | |
| # ββ Setup ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| async def full_setup(background_tasks: BackgroundTasks): | |
| setup_schema() | |
| ingest_sample_data() | |
| # Seed real data from live APIs in the background | |
| background_tasks.add_task(_run_seeder_thread) | |
| return {"status": "Setup started β schema initialized, sample data ingested, real-data seeding running in background"} | |
| async def setup_sample(): | |
| ingest_sample_data() | |
| return {"status": "Sample data ingested"} | |
| async def seed_graph(background_tasks: BackgroundTasks, conditions: list[str] | None = None): | |
| """Trigger real-data seeding from ClinicalTrials.gov, RxNorm, ICD-10, PubMed.""" | |
| background_tasks.add_task(_run_seeder_thread, conditions) | |
| return { | |
| "status": "Seeding started in background", | |
| "sources": ["clinicaltrials.gov", "rxnorm.nlm.nih.gov", "icd10cm nlm", "pubmed ncbi"], | |
| "conditions": conditions or "all default oncology conditions", | |
| } | |
| async def seed_status(): | |
| stats = get_graph_stats() | |
| return {"graph_stats": stats, "note": "Check /api/v1/graph/stats for node counts"} | |
| def _run_seeder_thread(conditions: list[str] | None = None): | |
| """Run the async seeder in a new thread (avoids event loop conflict with FastAPI).""" | |
| try: | |
| asyncio.run(run_seeder(conditions)) | |
| except Exception as e: | |
| print(f"[seeder] error: {e}") | |
| if __name__ == "__main__": | |
| import uvicorn | |
| uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True) | |