Spaces:
Sleeping
Sleeping
| """FastAPI API surface for Clarke consultation and health endpoints.""" | |
| from __future__ import annotations | |
| import json | |
| import threading | |
| from datetime import datetime, timezone | |
| from pathlib import Path | |
| from shutil import copyfileobj | |
| from typing import Any | |
| from fastapi import BackgroundTasks, Body, FastAPI, File, Form, HTTPException, UploadFile | |
| from backend.audio import convert_to_wav_16k, validate_audio | |
| from backend.config import get_settings | |
| from backend.errors import AudioError, ModelExecutionError, get_component_logger | |
| from backend.orchestrator import PipelineOrchestrator | |
| from backend.schemas import ConsultationStatus, Patient | |
| app = FastAPI(title="Clarke API", version="0.1.0") | |
| settings = get_settings() | |
| logger = get_component_logger("api") | |
| orchestrator = PipelineOrchestrator() | |
| def _iso_timestamp() -> str: | |
| """Return UTC timestamp in ISO 8601 format with Z suffix. | |
| Args: | |
| None: No input parameters. | |
| Returns: | |
| str: UTC timestamp string. | |
| """ | |
| return datetime.now(tz=timezone.utc).replace(microsecond=0).isoformat().replace("+00:00", "Z") | |
| def _load_clinic_list_payload() -> dict[str, Any]: | |
| """Load clinic list fixture JSON from disk. | |
| Args: | |
| None: Reads static clinic list fixture file. | |
| Returns: | |
| dict[str, Any]: Parsed clinic list JSON payload. | |
| """ | |
| clinic_path = Path("data/clinic_list.json") | |
| if not clinic_path.exists(): | |
| raise HTTPException(status_code=500, detail="Clinic list file is missing") | |
| return json.loads(clinic_path.read_text(encoding="utf-8")) | |
| def _load_patient_resource(patient_id: str) -> dict[str, Any]: | |
| """Load patient resource for a patient id from local FHIR bundle fixture. | |
| Args: | |
| patient_id (str): Patient identifier. | |
| Returns: | |
| dict[str, Any]: FHIR Patient resource from bundle fixture. | |
| """ | |
| bundle_path = Path("data/fhir_bundles") / f"{patient_id}.json" | |
| if not bundle_path.exists(): | |
| return {} | |
| bundle = json.loads(bundle_path.read_text(encoding="utf-8")) | |
| for entry in bundle.get("entry", []): | |
| resource = entry.get("resource", {}) | |
| if resource.get("resourceType") == "Patient": | |
| return resource | |
| return {} | |
| def _patient_from_records(clinic_row: dict[str, Any], patient_resource: dict[str, Any]) -> Patient: | |
| """Build a Patient schema object from clinic list and FHIR records. | |
| Args: | |
| clinic_row (dict[str, Any]): Patient row from `clinic_list.json`. | |
| patient_resource (dict[str, Any]): FHIR Patient resource dictionary. | |
| Returns: | |
| Patient: Normalised patient object for API responses. | |
| """ | |
| nhs_number = "" | |
| for identifier in patient_resource.get("identifier", []): | |
| value = str(identifier.get("value", "")).strip() | |
| if value: | |
| nhs_number = value | |
| break | |
| date_of_birth = str(patient_resource.get("birthDate", "")) | |
| if date_of_birth and "-" in date_of_birth: | |
| yyyy, mm, dd = date_of_birth.split("-") | |
| date_of_birth = f"{dd}/{mm}/{yyyy}" | |
| return Patient( | |
| id=str(clinic_row.get("id", "")), | |
| nhs_number=nhs_number, | |
| name=str(clinic_row.get("name", "")), | |
| date_of_birth=date_of_birth, | |
| age=int(clinic_row.get("age", 0)), | |
| sex=str(clinic_row.get("sex", "")), | |
| appointment_time=str(clinic_row.get("time", "")), | |
| summary=str(clinic_row.get("summary", "")), | |
| ) | |
| def _get_patient(patient_id: str) -> Patient: | |
| """Resolve a patient from local fixture data. | |
| Args: | |
| patient_id (str): Patient identifier. | |
| Returns: | |
| Patient: Patient record for the requested id. | |
| """ | |
| payload = _load_clinic_list_payload() | |
| for row in payload.get("patients", []): | |
| if str(row.get("id")) == patient_id: | |
| resource = _load_patient_resource(patient_id) | |
| return _patient_from_records(row, resource) | |
| raise HTTPException(status_code=404, detail=f"Patient not found: {patient_id}") | |
| def _health_fhir_status() -> dict[str, Any]: | |
| """Compute lightweight FHIR service health metadata for health endpoint. | |
| Args: | |
| None: Reads settings and local data fixtures. | |
| Returns: | |
| dict[str, Any]: FHIR connectivity status and available patient count. | |
| """ | |
| clinic_list_path = Path("data/clinic_list.json") | |
| if settings.USE_MOCK_FHIR and clinic_list_path.exists(): | |
| payload = json.loads(clinic_list_path.read_text(encoding="utf-8")) | |
| patients = payload.get("patients", []) | |
| return {"status": "connected", "patient_count": len(patients)} | |
| return {"status": "unknown", "patient_count": 0} | |
| def _save_upload_temp(upload: UploadFile, destination: Path) -> Path: | |
| """Persist an uploaded file to disk for subsequent conversion/validation. | |
| Args: | |
| upload (UploadFile): Uploaded file object from multipart payload. | |
| destination (Path): Output path for storing file bytes. | |
| Returns: | |
| Path: Saved destination path. | |
| """ | |
| destination.parent.mkdir(parents=True, exist_ok=True) | |
| with destination.open("wb") as output_stream: | |
| copyfileobj(upload.file, output_stream) | |
| return destination | |
| def get_health() -> dict[str, Any]: | |
| """Return current service, model, and GPU health state. | |
| Args: | |
| None: Uses process-level singletons for status. | |
| Returns: | |
| dict[str, Any]: Health response payload. | |
| """ | |
| gpu_info = orchestrator._medasr_model.model_manager.check_gpu() | |
| models = { | |
| "medasr": { | |
| "loaded": orchestrator._medasr_model.model_manager.get_model("medasr") is not None, | |
| "device": "cuda:0" if gpu_info["vram_total_bytes"] > 0 else "cpu", | |
| }, | |
| "medgemma_4b": {"loaded": False, "device": "n/a", "quantised": "4bit"}, | |
| "medgemma_27b": {"loaded": False, "device": "n/a", "quantised": "4bit"}, | |
| } | |
| return { | |
| "status": "healthy", | |
| "models": models, | |
| "fhir": _health_fhir_status(), | |
| "gpu": { | |
| "name": gpu_info["gpu_name"], | |
| "vram_used_gb": round(float(gpu_info["vram_used_bytes"]) / 1_000_000_000, 2), | |
| "vram_total_gb": round(float(gpu_info["vram_total_bytes"]) / 1_000_000_000, 2), | |
| }, | |
| "timestamp": _iso_timestamp(), | |
| } | |
| def get_patients() -> dict[str, list[dict[str, Any]]]: | |
| """Return clinic patient list with normalised schema fields. | |
| Args: | |
| None: Reads local clinic list and FHIR bundle fixtures. | |
| Returns: | |
| dict[str, list[dict[str, Any]]]: Patient list payload. | |
| """ | |
| payload = _load_clinic_list_payload() | |
| patients: list[dict[str, Any]] = [] | |
| for row in payload.get("patients", []): | |
| patient_id = str(row.get("id", "")) | |
| if not patient_id: | |
| continue | |
| patient_model = _patient_from_records(row, _load_patient_resource(patient_id)) | |
| patients.append(patient_model.model_dump()) | |
| return {"patients": patients} | |
| def get_patient(patient_id: str) -> dict[str, Any]: | |
| """Return a single patient by id. | |
| Args: | |
| patient_id (str): Patient identifier. | |
| Returns: | |
| dict[str, Any]: Patient payload. | |
| """ | |
| patient = _get_patient(patient_id) | |
| return patient.model_dump() | |
| def generate_patient_context(patient_id: str) -> dict[str, Any]: | |
| """Generate patient context via EHR agent and return context JSON. | |
| Args: | |
| patient_id (str): Patient identifier. | |
| Returns: | |
| dict[str, Any]: Structured patient context payload. | |
| """ | |
| _get_patient(patient_id) | |
| context = orchestrator._ehr_agent.get_patient_context(patient_id) | |
| return context.model_dump(mode="json") | |
| def start_consultation(payload: dict[str, str], background_tasks: BackgroundTasks) -> dict[str, Any]: | |
| """Start a consultation session and trigger context prefetch in background. | |
| Args: | |
| payload (dict[str, str]): Request body containing patient_id. | |
| background_tasks (BackgroundTasks): FastAPI background task runner. | |
| Returns: | |
| dict[str, Any]: Consultation identifier and initial state. | |
| """ | |
| patient_id = payload.get("patient_id", "") | |
| if not patient_id: | |
| raise HTTPException(status_code=400, detail="patient_id is required") | |
| patient = _get_patient(patient_id) | |
| consultation = orchestrator.start_consultation(patient) | |
| background_tasks.add_task(orchestrator.prefetch_context, consultation.id) | |
| return { | |
| "consultation_id": consultation.id, | |
| "patient_id": patient_id, | |
| "status": ConsultationStatus.RECORDING.value, | |
| "started_at": consultation.started_at, | |
| } | |
| def upload_audio( | |
| consultation_id: str, | |
| audio_file: UploadFile = File(...), | |
| is_final: bool = Form(...), | |
| ) -> dict[str, Any]: | |
| """Accept consultation audio, convert to MedASR format, and store metadata. | |
| Args: | |
| consultation_id (str): Consultation identifier. | |
| audio_file (UploadFile): Uploaded WAV/WebM audio file. | |
| is_final (bool): Whether upload is the final complete recording. | |
| Returns: | |
| dict[str, Any]: Upload acknowledgement and validated duration seconds. | |
| """ | |
| try: | |
| consultation = orchestrator.get_consultation(consultation_id) | |
| except KeyError as exc: | |
| raise HTTPException(status_code=404, detail=str(exc)) from exc | |
| extension = Path(audio_file.filename or "").suffix.lower() | |
| if extension not in {".wav", ".webm"}: | |
| raise HTTPException(status_code=400, detail="Only WAV or WebM audio uploads are supported") | |
| uploads_root = Path("/tmp/uploads") / consultation_id | |
| raw_path = uploads_root / f"raw{extension}" | |
| original_stem = Path(audio_file.filename or "audio").stem or "audio" | |
| wav_path = uploads_root / f"{original_stem}_16k.wav" | |
| try: | |
| _save_upload_temp(audio_file, raw_path) | |
| processed_path = raw_path if extension == ".wav" else Path(convert_to_wav_16k(str(raw_path), str(wav_path))) | |
| if extension == ".wav": | |
| processed_path = Path(convert_to_wav_16k(str(raw_path), str(wav_path))) | |
| audio_metadata = validate_audio(str(processed_path)) | |
| except AudioError as exc: | |
| logger.error("Audio processing failed", consultation_id=consultation_id, error=str(exc)) | |
| raise HTTPException(status_code=400, detail=str(exc)) from exc | |
| consultation.audio_file_path = str(processed_path) | |
| logger.info( | |
| "Stored consultation audio", | |
| consultation_id=consultation_id, | |
| duration_s=audio_metadata["duration_s"], | |
| is_final=is_final, | |
| ) | |
| return { | |
| "consultation_id": consultation_id, | |
| "audio_received": True, | |
| "duration_s": audio_metadata["duration_s"], | |
| } | |
| def end_consultation(consultation_id: str, body: dict[str, Any] | None = Body(None)) -> dict[str, Any]: | |
| """End a consultation and execute full processing pipeline. | |
| Args: | |
| consultation_id (str): Consultation identifier. | |
| Returns: | |
| dict[str, Any]: Pipeline kickoff and status metadata. | |
| """ | |
| try: | |
| consultation = orchestrator.get_consultation(consultation_id) | |
| except KeyError as exc: | |
| raise HTTPException(status_code=404, detail=str(exc)) from exc | |
| # Accept audio path directly from frontend for demo/server-side audio files | |
| if body and body.get("audio_path"): | |
| audio_path = body["audio_path"] | |
| if Path(audio_path).exists(): | |
| consultation.audio_file_path = audio_path | |
| # Accept document type and letter preferences from frontend | |
| if body and body.get("doc_type"): | |
| consultation.doc_type = body["doc_type"] | |
| if body and body.get("letter_prefs"): | |
| consultation.letter_prefs = body["letter_prefs"] | |
| # Launch pipeline in background thread – avoids HF Spaces 60s gateway timeout | |
| def _run_pipeline_background() -> None: | |
| """Execute the orchestrator pipeline in a daemon thread. | |
| Allows the /end endpoint to return 202 immediately while processing | |
| continues. Frontend polls /progress for status updates. | |
| """ | |
| try: | |
| orchestrator.end_consultation(consultation_id) | |
| except Exception as exc: | |
| logger.error(f"Background pipeline error: {exc}", consultation_id=consultation_id) | |
| thread = threading.Thread(target=_run_pipeline_background, daemon=True) | |
| thread.start() | |
| return { | |
| "consultation_id": consultation.id, | |
| "status": "processing", | |
| "pipeline_stage": "transcribing", | |
| "message": "Pipeline started in background. Poll /progress for updates.", | |
| } | |
| def get_transcript(consultation_id: str) -> dict[str, Any]: | |
| """Return transcript for a consultation when available. | |
| Args: | |
| consultation_id (str): Consultation identifier. | |
| Returns: | |
| dict[str, Any]: Transcript payload. | |
| """ | |
| try: | |
| consultation = orchestrator.get_consultation(consultation_id) | |
| except KeyError as exc: | |
| raise HTTPException(status_code=404, detail=str(exc)) from exc | |
| if consultation.transcript is None: | |
| raise HTTPException(status_code=404, detail="Transcript not generated yet") | |
| return consultation.transcript.model_dump(mode="json") | |
| def get_document(consultation_id: str) -> dict[str, Any]: | |
| """Return generated document placeholder for a consultation. | |
| Args: | |
| consultation_id (str): Consultation identifier. | |
| Returns: | |
| dict[str, Any]: Consultation document object or null. | |
| """ | |
| try: | |
| consultation = orchestrator.get_consultation(consultation_id) | |
| except KeyError as exc: | |
| raise HTTPException(status_code=404, detail=str(exc)) from exc | |
| return { | |
| "consultation_id": consultation_id, | |
| "document": consultation.document.model_dump(mode="json") if consultation.document else None, | |
| } | |
| def get_progress(consultation_id: str) -> dict[str, Any]: | |
| """Return latest pipeline progress object for a consultation. | |
| Args: | |
| consultation_id (str): Consultation identifier. | |
| Returns: | |
| dict[str, Any]: Pipeline progress payload. | |
| """ | |
| try: | |
| progress = orchestrator.get_progress(consultation_id) | |
| except KeyError as exc: | |
| raise HTTPException(status_code=404, detail=str(exc)) from exc | |
| return progress.model_dump(mode="json") | |
| def sign_off_document(consultation_id: str, payload: dict[str, Any] | None = Body(default=None)) -> dict[str, Any]: | |
| """Sign off a generated document and set consultation status to signed off. | |
| Args: | |
| consultation_id (str): Consultation identifier. | |
| Returns: | |
| dict[str, Any]: Signed-off document payload with updated status. | |
| """ | |
| sections = (payload or {}).get("sections", []) | |
| try: | |
| if sections: | |
| orchestrator.update_document_sections(consultation_id, sections) | |
| document = orchestrator.sign_off_document(consultation_id) | |
| except KeyError as exc: | |
| raise HTTPException(status_code=404, detail=str(exc)) from exc | |
| except ModelExecutionError as exc: | |
| raise HTTPException(status_code=400, detail=str(exc)) from exc | |
| return { | |
| "consultation_id": consultation_id, | |
| "status": ConsultationStatus.SIGNED_OFF.value, | |
| "document": document.model_dump(mode="json"), | |
| } | |
| def regenerate_document_section( | |
| consultation_id: str, | |
| payload: dict[str, str] = Body(...), | |
| ) -> dict[str, Any]: | |
| """Regenerate one document section while keeping other sections unchanged. | |
| Args: | |
| consultation_id (str): Consultation identifier. | |
| payload (dict[str, str]): Request body containing `section_heading`. | |
| Returns: | |
| dict[str, Any]: Updated document payload. | |
| """ | |
| section_heading = payload.get("section_heading", "").strip() | |
| if not section_heading: | |
| raise HTTPException(status_code=400, detail="section_heading is required") | |
| try: | |
| document = orchestrator.regenerate_document_section(consultation_id, section_heading) | |
| except KeyError as exc: | |
| raise HTTPException(status_code=404, detail=str(exc)) from exc | |
| except ModelExecutionError as exc: | |
| raise HTTPException(status_code=400, detail=str(exc)) from exc | |
| return { | |
| "consultation_id": consultation_id, | |
| "status": ConsultationStatus.REVIEW.value, | |
| "document": document.model_dump(mode="json"), | |
| } | |