clarke / backend /api.py
yashvshetty's picture
Restore all features from origin/fresh-main
874f913
"""FastAPI API surface for Clarke consultation and health endpoints."""
from __future__ import annotations
import json
import threading
from datetime import datetime, timezone
from pathlib import Path
from shutil import copyfileobj
from typing import Any
from fastapi import BackgroundTasks, Body, FastAPI, File, Form, HTTPException, UploadFile
from backend.audio import convert_to_wav_16k, validate_audio
from backend.config import get_settings
from backend.errors import AudioError, ModelExecutionError, get_component_logger
from backend.orchestrator import PipelineOrchestrator
from backend.schemas import ConsultationStatus, Patient
app = FastAPI(title="Clarke API", version="0.1.0")
settings = get_settings()
logger = get_component_logger("api")
orchestrator = PipelineOrchestrator()
def _iso_timestamp() -> str:
"""Return UTC timestamp in ISO 8601 format with Z suffix.
Args:
None: No input parameters.
Returns:
str: UTC timestamp string.
"""
return datetime.now(tz=timezone.utc).replace(microsecond=0).isoformat().replace("+00:00", "Z")
def _load_clinic_list_payload() -> dict[str, Any]:
"""Load clinic list fixture JSON from disk.
Args:
None: Reads static clinic list fixture file.
Returns:
dict[str, Any]: Parsed clinic list JSON payload.
"""
clinic_path = Path("data/clinic_list.json")
if not clinic_path.exists():
raise HTTPException(status_code=500, detail="Clinic list file is missing")
return json.loads(clinic_path.read_text(encoding="utf-8"))
def _load_patient_resource(patient_id: str) -> dict[str, Any]:
"""Load patient resource for a patient id from local FHIR bundle fixture.
Args:
patient_id (str): Patient identifier.
Returns:
dict[str, Any]: FHIR Patient resource from bundle fixture.
"""
bundle_path = Path("data/fhir_bundles") / f"{patient_id}.json"
if not bundle_path.exists():
return {}
bundle = json.loads(bundle_path.read_text(encoding="utf-8"))
for entry in bundle.get("entry", []):
resource = entry.get("resource", {})
if resource.get("resourceType") == "Patient":
return resource
return {}
def _patient_from_records(clinic_row: dict[str, Any], patient_resource: dict[str, Any]) -> Patient:
"""Build a Patient schema object from clinic list and FHIR records.
Args:
clinic_row (dict[str, Any]): Patient row from `clinic_list.json`.
patient_resource (dict[str, Any]): FHIR Patient resource dictionary.
Returns:
Patient: Normalised patient object for API responses.
"""
nhs_number = ""
for identifier in patient_resource.get("identifier", []):
value = str(identifier.get("value", "")).strip()
if value:
nhs_number = value
break
date_of_birth = str(patient_resource.get("birthDate", ""))
if date_of_birth and "-" in date_of_birth:
yyyy, mm, dd = date_of_birth.split("-")
date_of_birth = f"{dd}/{mm}/{yyyy}"
return Patient(
id=str(clinic_row.get("id", "")),
nhs_number=nhs_number,
name=str(clinic_row.get("name", "")),
date_of_birth=date_of_birth,
age=int(clinic_row.get("age", 0)),
sex=str(clinic_row.get("sex", "")),
appointment_time=str(clinic_row.get("time", "")),
summary=str(clinic_row.get("summary", "")),
)
def _get_patient(patient_id: str) -> Patient:
"""Resolve a patient from local fixture data.
Args:
patient_id (str): Patient identifier.
Returns:
Patient: Patient record for the requested id.
"""
payload = _load_clinic_list_payload()
for row in payload.get("patients", []):
if str(row.get("id")) == patient_id:
resource = _load_patient_resource(patient_id)
return _patient_from_records(row, resource)
raise HTTPException(status_code=404, detail=f"Patient not found: {patient_id}")
def _health_fhir_status() -> dict[str, Any]:
"""Compute lightweight FHIR service health metadata for health endpoint.
Args:
None: Reads settings and local data fixtures.
Returns:
dict[str, Any]: FHIR connectivity status and available patient count.
"""
clinic_list_path = Path("data/clinic_list.json")
if settings.USE_MOCK_FHIR and clinic_list_path.exists():
payload = json.loads(clinic_list_path.read_text(encoding="utf-8"))
patients = payload.get("patients", [])
return {"status": "connected", "patient_count": len(patients)}
return {"status": "unknown", "patient_count": 0}
def _save_upload_temp(upload: UploadFile, destination: Path) -> Path:
"""Persist an uploaded file to disk for subsequent conversion/validation.
Args:
upload (UploadFile): Uploaded file object from multipart payload.
destination (Path): Output path for storing file bytes.
Returns:
Path: Saved destination path.
"""
destination.parent.mkdir(parents=True, exist_ok=True)
with destination.open("wb") as output_stream:
copyfileobj(upload.file, output_stream)
return destination
@app.get("/api/v1/health")
def get_health() -> dict[str, Any]:
"""Return current service, model, and GPU health state.
Args:
None: Uses process-level singletons for status.
Returns:
dict[str, Any]: Health response payload.
"""
gpu_info = orchestrator._medasr_model.model_manager.check_gpu()
models = {
"medasr": {
"loaded": orchestrator._medasr_model.model_manager.get_model("medasr") is not None,
"device": "cuda:0" if gpu_info["vram_total_bytes"] > 0 else "cpu",
},
"medgemma_4b": {"loaded": False, "device": "n/a", "quantised": "4bit"},
"medgemma_27b": {"loaded": False, "device": "n/a", "quantised": "4bit"},
}
return {
"status": "healthy",
"models": models,
"fhir": _health_fhir_status(),
"gpu": {
"name": gpu_info["gpu_name"],
"vram_used_gb": round(float(gpu_info["vram_used_bytes"]) / 1_000_000_000, 2),
"vram_total_gb": round(float(gpu_info["vram_total_bytes"]) / 1_000_000_000, 2),
},
"timestamp": _iso_timestamp(),
}
@app.get("/api/v1/patients")
def get_patients() -> dict[str, list[dict[str, Any]]]:
"""Return clinic patient list with normalised schema fields.
Args:
None: Reads local clinic list and FHIR bundle fixtures.
Returns:
dict[str, list[dict[str, Any]]]: Patient list payload.
"""
payload = _load_clinic_list_payload()
patients: list[dict[str, Any]] = []
for row in payload.get("patients", []):
patient_id = str(row.get("id", ""))
if not patient_id:
continue
patient_model = _patient_from_records(row, _load_patient_resource(patient_id))
patients.append(patient_model.model_dump())
return {"patients": patients}
@app.get("/api/v1/patients/{patient_id}")
def get_patient(patient_id: str) -> dict[str, Any]:
"""Return a single patient by id.
Args:
patient_id (str): Patient identifier.
Returns:
dict[str, Any]: Patient payload.
"""
patient = _get_patient(patient_id)
return patient.model_dump()
@app.post("/api/v1/patients/{patient_id}/context")
def generate_patient_context(patient_id: str) -> dict[str, Any]:
"""Generate patient context via EHR agent and return context JSON.
Args:
patient_id (str): Patient identifier.
Returns:
dict[str, Any]: Structured patient context payload.
"""
_get_patient(patient_id)
context = orchestrator._ehr_agent.get_patient_context(patient_id)
return context.model_dump(mode="json")
@app.post("/api/v1/consultations/start", status_code=201)
def start_consultation(payload: dict[str, str], background_tasks: BackgroundTasks) -> dict[str, Any]:
"""Start a consultation session and trigger context prefetch in background.
Args:
payload (dict[str, str]): Request body containing patient_id.
background_tasks (BackgroundTasks): FastAPI background task runner.
Returns:
dict[str, Any]: Consultation identifier and initial state.
"""
patient_id = payload.get("patient_id", "")
if not patient_id:
raise HTTPException(status_code=400, detail="patient_id is required")
patient = _get_patient(patient_id)
consultation = orchestrator.start_consultation(patient)
background_tasks.add_task(orchestrator.prefetch_context, consultation.id)
return {
"consultation_id": consultation.id,
"patient_id": patient_id,
"status": ConsultationStatus.RECORDING.value,
"started_at": consultation.started_at,
}
@app.post("/api/v1/consultations/{consultation_id}/audio")
def upload_audio(
consultation_id: str,
audio_file: UploadFile = File(...),
is_final: bool = Form(...),
) -> dict[str, Any]:
"""Accept consultation audio, convert to MedASR format, and store metadata.
Args:
consultation_id (str): Consultation identifier.
audio_file (UploadFile): Uploaded WAV/WebM audio file.
is_final (bool): Whether upload is the final complete recording.
Returns:
dict[str, Any]: Upload acknowledgement and validated duration seconds.
"""
try:
consultation = orchestrator.get_consultation(consultation_id)
except KeyError as exc:
raise HTTPException(status_code=404, detail=str(exc)) from exc
extension = Path(audio_file.filename or "").suffix.lower()
if extension not in {".wav", ".webm"}:
raise HTTPException(status_code=400, detail="Only WAV or WebM audio uploads are supported")
uploads_root = Path("/tmp/uploads") / consultation_id
raw_path = uploads_root / f"raw{extension}"
original_stem = Path(audio_file.filename or "audio").stem or "audio"
wav_path = uploads_root / f"{original_stem}_16k.wav"
try:
_save_upload_temp(audio_file, raw_path)
processed_path = raw_path if extension == ".wav" else Path(convert_to_wav_16k(str(raw_path), str(wav_path)))
if extension == ".wav":
processed_path = Path(convert_to_wav_16k(str(raw_path), str(wav_path)))
audio_metadata = validate_audio(str(processed_path))
except AudioError as exc:
logger.error("Audio processing failed", consultation_id=consultation_id, error=str(exc))
raise HTTPException(status_code=400, detail=str(exc)) from exc
consultation.audio_file_path = str(processed_path)
logger.info(
"Stored consultation audio",
consultation_id=consultation_id,
duration_s=audio_metadata["duration_s"],
is_final=is_final,
)
return {
"consultation_id": consultation_id,
"audio_received": True,
"duration_s": audio_metadata["duration_s"],
}
@app.post("/api/v1/consultations/{consultation_id}/end", status_code=202)
def end_consultation(consultation_id: str, body: dict[str, Any] | None = Body(None)) -> dict[str, Any]:
"""End a consultation and execute full processing pipeline.
Args:
consultation_id (str): Consultation identifier.
Returns:
dict[str, Any]: Pipeline kickoff and status metadata.
"""
try:
consultation = orchestrator.get_consultation(consultation_id)
except KeyError as exc:
raise HTTPException(status_code=404, detail=str(exc)) from exc
# Accept audio path directly from frontend for demo/server-side audio files
if body and body.get("audio_path"):
audio_path = body["audio_path"]
if Path(audio_path).exists():
consultation.audio_file_path = audio_path
# Accept document type and letter preferences from frontend
if body and body.get("doc_type"):
consultation.doc_type = body["doc_type"]
if body and body.get("letter_prefs"):
consultation.letter_prefs = body["letter_prefs"]
# Launch pipeline in background thread – avoids HF Spaces 60s gateway timeout
def _run_pipeline_background() -> None:
"""Execute the orchestrator pipeline in a daemon thread.
Allows the /end endpoint to return 202 immediately while processing
continues. Frontend polls /progress for status updates.
"""
try:
orchestrator.end_consultation(consultation_id)
except Exception as exc:
logger.error(f"Background pipeline error: {exc}", consultation_id=consultation_id)
thread = threading.Thread(target=_run_pipeline_background, daemon=True)
thread.start()
return {
"consultation_id": consultation.id,
"status": "processing",
"pipeline_stage": "transcribing",
"message": "Pipeline started in background. Poll /progress for updates.",
}
@app.get("/api/v1/consultations/{consultation_id}/transcript")
def get_transcript(consultation_id: str) -> dict[str, Any]:
"""Return transcript for a consultation when available.
Args:
consultation_id (str): Consultation identifier.
Returns:
dict[str, Any]: Transcript payload.
"""
try:
consultation = orchestrator.get_consultation(consultation_id)
except KeyError as exc:
raise HTTPException(status_code=404, detail=str(exc)) from exc
if consultation.transcript is None:
raise HTTPException(status_code=404, detail="Transcript not generated yet")
return consultation.transcript.model_dump(mode="json")
@app.get("/api/v1/consultations/{consultation_id}/document")
def get_document(consultation_id: str) -> dict[str, Any]:
"""Return generated document placeholder for a consultation.
Args:
consultation_id (str): Consultation identifier.
Returns:
dict[str, Any]: Consultation document object or null.
"""
try:
consultation = orchestrator.get_consultation(consultation_id)
except KeyError as exc:
raise HTTPException(status_code=404, detail=str(exc)) from exc
return {
"consultation_id": consultation_id,
"document": consultation.document.model_dump(mode="json") if consultation.document else None,
}
@app.get("/api/v1/consultations/{consultation_id}/progress")
def get_progress(consultation_id: str) -> dict[str, Any]:
"""Return latest pipeline progress object for a consultation.
Args:
consultation_id (str): Consultation identifier.
Returns:
dict[str, Any]: Pipeline progress payload.
"""
try:
progress = orchestrator.get_progress(consultation_id)
except KeyError as exc:
raise HTTPException(status_code=404, detail=str(exc)) from exc
return progress.model_dump(mode="json")
@app.post("/api/v1/consultations/{consultation_id}/document/sign-off")
def sign_off_document(consultation_id: str, payload: dict[str, Any] | None = Body(default=None)) -> dict[str, Any]:
"""Sign off a generated document and set consultation status to signed off.
Args:
consultation_id (str): Consultation identifier.
Returns:
dict[str, Any]: Signed-off document payload with updated status.
"""
sections = (payload or {}).get("sections", [])
try:
if sections:
orchestrator.update_document_sections(consultation_id, sections)
document = orchestrator.sign_off_document(consultation_id)
except KeyError as exc:
raise HTTPException(status_code=404, detail=str(exc)) from exc
except ModelExecutionError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
return {
"consultation_id": consultation_id,
"status": ConsultationStatus.SIGNED_OFF.value,
"document": document.model_dump(mode="json"),
}
@app.post("/api/v1/consultations/{consultation_id}/document/regenerate-section")
def regenerate_document_section(
consultation_id: str,
payload: dict[str, str] = Body(...),
) -> dict[str, Any]:
"""Regenerate one document section while keeping other sections unchanged.
Args:
consultation_id (str): Consultation identifier.
payload (dict[str, str]): Request body containing `section_heading`.
Returns:
dict[str, Any]: Updated document payload.
"""
section_heading = payload.get("section_heading", "").strip()
if not section_heading:
raise HTTPException(status_code=400, detail="section_heading is required")
try:
document = orchestrator.regenerate_document_section(consultation_id, section_heading)
except KeyError as exc:
raise HTTPException(status_code=404, detail=str(exc)) from exc
except ModelExecutionError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
return {
"consultation_id": consultation_id,
"status": ConsultationStatus.REVIEW.value,
"document": document.model_dump(mode="json"),
}