bshepp
committed on
Commit
·
28e85d4
1
Parent(s):
58e2bd0
fix: add retry logic, error handling, and startup diagnostics for deployed pipeline
Browse files

- Add startup config logging in main.py (masked secrets, warnings if empty)
- Add retry with exponential backoff in MedGemma API (3 retries, handles 503 cold-start)
- Broaden patient_parser exception handler (catch all, not just ValueError)
- Stop pipeline on critical step failure (skip remaining instead of cascading errors)
- Add /api/health/config diagnostic endpoint
src/backend/app/agent/orchestrator.py
CHANGED
|
@@ -122,6 +122,9 @@ class Orchestrator:
|
|
| 122 |
This is the main entry point. Each step is executed sequentially,
|
| 123 |
with state flowing from one step to the next. Steps that don't
|
| 124 |
depend on each other (drug check + guidelines) run in parallel.
|
|
|
|
|
|
|
|
|
|
| 125 |
"""
|
| 126 |
case_id = str(uuid.uuid4())[:8]
|
| 127 |
steps = self._create_steps(case)
|
|
@@ -134,10 +137,23 @@ class Orchestrator:
|
|
| 134 |
|
| 135 |
try:
|
| 136 |
# ββ Step 1: Parse patient data ββ
|
| 137 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
|
| 139 |
# ββ Step 2: Clinical reasoning ββ
|
| 140 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
|
| 142 |
# ββ Step 3 & 4: Drug check + Guidelines (parallel) ββ
|
| 143 |
parallel_tasks = []
|
|
@@ -175,6 +191,20 @@ class Orchestrator:
|
|
| 175 |
step.error = f"Pipeline aborted: {str(e)}"
|
| 176 |
raise
|
| 177 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 178 |
async def _run_step(self, step_id: str, fn, *args) -> AgentStep:
|
| 179 |
"""Execute a single step, tracking status and timing."""
|
| 180 |
step = self._get_step(step_id)
|
|
|
|
| 122 |
This is the main entry point. Each step is executed sequentially,
|
| 123 |
with state flowing from one step to the next. Steps that don't
|
| 124 |
depend on each other (drug check + guidelines) run in parallel.
|
| 125 |
+
|
| 126 |
+
If a critical step (parse, reason) fails, subsequent dependent
|
| 127 |
+
steps are marked as SKIPPED to avoid cascading errors.
|
| 128 |
"""
|
| 129 |
case_id = str(uuid.uuid4())[:8]
|
| 130 |
steps = self._create_steps(case)
|
|
|
|
| 137 |
|
| 138 |
try:
|
| 139 |
# ββ Step 1: Parse patient data ββ
|
| 140 |
+
step = await self._run_step("parse", self._step_parse, case.patient_text)
|
| 141 |
+
yield step
|
| 142 |
+
|
| 143 |
+
if step.status == AgentStepStatus.FAILED:
|
| 144 |
+
# Can't continue without patient profile β skip remaining steps
|
| 145 |
+
yield from self._skip_remaining_steps("parse")
|
| 146 |
+
self._state.completed_at = datetime.utcnow()
|
| 147 |
+
return
|
| 148 |
|
| 149 |
# ββ Step 2: Clinical reasoning ββ
|
| 150 |
+
step = await self._run_step("reason", self._step_reason)
|
| 151 |
+
yield step
|
| 152 |
+
|
| 153 |
+
if step.status == AgentStepStatus.FAILED:
|
| 154 |
+
yield from self._skip_remaining_steps("reason")
|
| 155 |
+
self._state.completed_at = datetime.utcnow()
|
| 156 |
+
return
|
| 157 |
|
| 158 |
# ββ Step 3 & 4: Drug check + Guidelines (parallel) ββ
|
| 159 |
parallel_tasks = []
|
|
|
|
| 191 |
step.error = f"Pipeline aborted: {str(e)}"
|
| 192 |
raise
|
| 193 |
|
| 194 |
+
def _skip_remaining_steps(self, after_step_id: str) -> list[AgentStep]:
    """Mark every still-pending step that follows *after_step_id* as SKIPPED.

    Used when a prerequisite step fails: the mutated steps are returned so
    the caller can yield them to clients instead of executing them.
    Returns an empty list if *after_step_id* is not found.
    """
    steps = self._state.steps
    # Locate the failed prerequisite; everything after it is a candidate.
    try:
        start = next(i for i, s in enumerate(steps) if s.step_id == after_step_id) + 1
    except StopIteration:
        return []

    skipped: list[AgentStep] = []
    for later in steps[start:]:
        if later.status != AgentStepStatus.PENDING:
            continue
        later.status = AgentStepStatus.SKIPPED
        later.error = f"Skipped: prerequisite step '{after_step_id}' failed"
        skipped.append(later)
    return skipped
|
| 207 |
+
|
| 208 |
async def _run_step(self, step_id: str, fn, *args) -> AgentStep:
|
| 209 |
"""Execute a single step, tracking status and timing."""
|
| 210 |
step = self._get_step(step_id)
|
src/backend/app/api/health.py
CHANGED
|
@@ -1,9 +1,26 @@
|
|
| 1 |
"""Health check endpoint."""
|
|
|
|
|
|
|
| 2 |
from fastapi import APIRouter
|
| 3 |
|
|
|
|
|
|
|
|
|
|
| 4 |
router = APIRouter()
|
| 5 |
|
| 6 |
|
| 7 |
@router.get("/health")
|
| 8 |
async def health_check():
|
| 9 |
return {"status": "ok", "service": "CDS Agent"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
"""Health check endpoint."""
|
| 2 |
+
import logging
|
| 3 |
+
|
| 4 |
from fastapi import APIRouter
|
| 5 |
|
| 6 |
+
from app.config import settings
|
| 7 |
+
|
| 8 |
+
logger = logging.getLogger(__name__)
|
| 9 |
router = APIRouter()
|
| 10 |
|
| 11 |
|
| 12 |
@router.get("/health")
async def health_check():
    """Liveness probe: always reports the service as up."""
    payload = {"status": "ok", "service": "CDS Agent"}
    return payload
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@router.get("/api/health/config")
async def config_check():
    """Diagnostic endpoint: shows whether critical env vars are configured (no secrets)."""
    # Secrets are reported only as present/absent booleans; non-sensitive
    # settings are echoed verbatim. Key order matches the original response.
    report: dict = {}
    report["medgemma_base_url_set"] = bool(settings.medgemma_base_url)
    report["medgemma_api_key_set"] = bool(settings.medgemma_api_key)
    report["medgemma_model_id"] = settings.medgemma_model_id
    report["hf_token_set"] = bool(settings.hf_token)
    report["medgemma_max_tokens"] = settings.medgemma_max_tokens
    return report
|
src/backend/app/main.py
CHANGED
|
@@ -1,12 +1,17 @@
|
|
| 1 |
"""
|
| 2 |
Clinical Decision Support Agent β FastAPI Backend
|
| 3 |
"""
|
|
|
|
|
|
|
| 4 |
from fastapi import FastAPI
|
| 5 |
from fastapi.middleware.cors import CORSMiddleware
|
| 6 |
|
| 7 |
from app.api import cases, health, ws
|
| 8 |
from app.config import settings
|
| 9 |
|
|
|
|
|
|
|
|
|
|
| 10 |
app = FastAPI(
|
| 11 |
title="Clinical Decision Support Agent",
|
| 12 |
description="Agentic clinical decision support powered by MedGemma (HAI-DEF)",
|
|
@@ -31,6 +36,24 @@ app.include_router(ws.router, prefix="/ws", tags=["websocket"])
|
|
| 31 |
@app.on_event("startup")
|
| 32 |
async def startup():
|
| 33 |
"""Initialize services on startup."""
|
| 34 |
-
#
|
| 35 |
-
|
| 36 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
"""
|
| 2 |
Clinical Decision Support Agent β FastAPI Backend
|
| 3 |
"""
|
| 4 |
+
import logging
|
| 5 |
+
|
| 6 |
from fastapi import FastAPI
|
| 7 |
from fastapi.middleware.cors import CORSMiddleware
|
| 8 |
|
| 9 |
from app.api import cases, health, ws
|
| 10 |
from app.config import settings
|
| 11 |
|
| 12 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(levelname)s %(message)s")
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
+
|
| 15 |
app = FastAPI(
|
| 16 |
title="Clinical Decision Support Agent",
|
| 17 |
description="Agentic clinical decision support powered by MedGemma (HAI-DEF)",
|
|
|
|
| 36 |
# NOTE(review): @app.on_event is deprecated in newer FastAPI releases in
# favor of lifespan handlers — confirm the installed version before migrating.
@app.on_event("startup")
async def startup():
    """Log the effective configuration at boot so deployment issues are visible.

    Secrets (API key, HF token) are masked before logging; empty critical
    settings produce explicit warnings rather than failing silently later.
    """

    def _mask(val: str) -> str:
        """Return a redacted preview of *val* that is safe to log."""
        if not val:
            return "(empty)"
        if len(val) <= 8:
            return "***"
        return val[:4] + "..." + val[-4:]

    # Lazy %-style logging args: formatting work is skipped when INFO is
    # disabled, and exceptions in formatting can't break startup (G004).
    logger.info("=== CDS Agent Backend Starting ===")
    logger.info(" medgemma_base_url : %s", settings.medgemma_base_url or "(empty)")
    logger.info(" medgemma_model_id : %s", settings.medgemma_model_id)
    logger.info(" medgemma_api_key : %s", _mask(settings.medgemma_api_key))
    logger.info(" hf_token : %s", _mask(settings.hf_token))
    logger.info(" medgemma_max_tokens: %s", settings.medgemma_max_tokens)
    logger.info(" cors_origins : %s", settings.cors_origins)
    logger.info(" chroma_persist_dir: %s", settings.chroma_persist_dir)

    if not settings.medgemma_base_url:
        logger.warning("MEDGEMMA_BASE_URL is empty -- MedGemma API calls will fail!")
    if not settings.medgemma_api_key:
        logger.warning("MEDGEMMA_API_KEY is empty -- MedGemma API calls will fail!")
|
src/backend/app/services/medgemma.py
CHANGED
|
@@ -10,6 +10,7 @@ All tools that need MedGemma go through this service.
|
|
| 10 |
"""
|
| 11 |
from __future__ import annotations
|
| 12 |
|
|
|
|
| 13 |
import json
|
| 14 |
import logging
|
| 15 |
from typing import Any, Optional, Type, TypeVar
|
|
@@ -22,6 +23,10 @@ logger = logging.getLogger(__name__)
|
|
| 22 |
|
| 23 |
T = TypeVar("T", bound=BaseModel)
|
| 24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
class MedGemmaService:
|
| 27 |
"""
|
|
@@ -146,6 +151,9 @@ class MedGemmaService:
|
|
| 146 |
happens to be plain Gemma on Google AI Studio (which rejects the system
|
| 147 |
role), we automatically fall back to folding the system prompt into the
|
| 148 |
user message.
|
|
|
|
|
|
|
|
|
|
| 149 |
"""
|
| 150 |
client = await self._get_client()
|
| 151 |
|
|
@@ -154,29 +162,59 @@ class MedGemmaService:
|
|
| 154 |
messages.append({"role": "system", "content": system_prompt})
|
| 155 |
messages.append({"role": "user", "content": prompt})
|
| 156 |
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
max_tokens=max_tokens,
|
| 162 |
-
temperature=temperature,
|
| 163 |
-
)
|
| 164 |
-
return response.choices[0].message.content
|
| 165 |
-
except Exception as e:
|
| 166 |
-
# Fallback: fold system prompt into user message (Google AI Studio compat)
|
| 167 |
-
if system_prompt and "system" in str(e).lower():
|
| 168 |
-
logger.warning("Backend rejected system role β folding into user message.")
|
| 169 |
-
fallback_messages = [
|
| 170 |
-
{"role": "user", "content": f"{system_prompt}\n\n{prompt}"}
|
| 171 |
-
]
|
| 172 |
response = await client.chat.completions.create(
|
| 173 |
model=settings.medgemma_model_id,
|
| 174 |
-
messages=
|
| 175 |
max_tokens=max_tokens,
|
| 176 |
temperature=temperature,
|
| 177 |
)
|
| 178 |
return response.choices[0].message.content
|
| 179 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 180 |
|
| 181 |
async def _generate_local(
|
| 182 |
self, prompt: str, system_prompt: Optional[str], max_tokens: int, temperature: float
|
|
|
|
| 10 |
"""
|
| 11 |
from __future__ import annotations
|
| 12 |
|
| 13 |
+
import asyncio
|
| 14 |
import json
|
| 15 |
import logging
|
| 16 |
from typing import Any, Optional, Type, TypeVar
|
|
|
|
| 23 |
|
| 24 |
T = TypeVar("T", bound=BaseModel)
|
| 25 |
|
| 26 |
+
# Retry configuration for transient API errors (cold-start / 503)
|
| 27 |
+
MAX_API_RETRIES = 3
|
| 28 |
+
RETRY_BASE_DELAY = 5.0 # seconds, doubles on each retry
|
| 29 |
+
|
| 30 |
|
| 31 |
class MedGemmaService:
|
| 32 |
"""
|
|
|
|
| 151 |
happens to be plain Gemma on Google AI Studio (which rejects the system
|
| 152 |
role), we automatically fall back to folding the system prompt into the
|
| 153 |
user message.
|
| 154 |
+
|
| 155 |
+
Includes retry with exponential backoff for transient errors (503 cold
|
| 156 |
+
start, connection errors, timeouts).
|
| 157 |
"""
|
| 158 |
client = await self._get_client()
|
| 159 |
|
|
|
|
| 162 |
messages.append({"role": "system", "content": system_prompt})
|
| 163 |
messages.append({"role": "user", "content": prompt})
|
| 164 |
|
| 165 |
+
last_error: Optional[Exception] = None
|
| 166 |
+
|
| 167 |
+
for attempt in range(MAX_API_RETRIES):
|
| 168 |
+
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
response = await client.chat.completions.create(
|
| 170 |
model=settings.medgemma_model_id,
|
| 171 |
+
messages=messages,
|
| 172 |
max_tokens=max_tokens,
|
| 173 |
temperature=temperature,
|
| 174 |
)
|
| 175 |
return response.choices[0].message.content
|
| 176 |
+
except Exception as e:
|
| 177 |
+
error_str = str(e).lower()
|
| 178 |
+
last_error = e
|
| 179 |
+
|
| 180 |
+
# Detect system-role rejection (Google AI Studio) β immediate fallback, no retry
|
| 181 |
+
if system_prompt and "system" in error_str:
|
| 182 |
+
logger.warning("Backend rejected system role -- folding into user message.")
|
| 183 |
+
fallback_messages = [
|
| 184 |
+
{"role": "user", "content": f"{system_prompt}\n\n{prompt}"}
|
| 185 |
+
]
|
| 186 |
+
try:
|
| 187 |
+
response = await client.chat.completions.create(
|
| 188 |
+
model=settings.medgemma_model_id,
|
| 189 |
+
messages=fallback_messages,
|
| 190 |
+
max_tokens=max_tokens,
|
| 191 |
+
temperature=temperature,
|
| 192 |
+
)
|
| 193 |
+
return response.choices[0].message.content
|
| 194 |
+
except Exception as e2:
|
| 195 |
+
last_error = e2
|
| 196 |
+
error_str = str(e2).lower()
|
| 197 |
+
|
| 198 |
+
# Retry on transient errors (503, 502, 429, connection, timeout)
|
| 199 |
+
is_transient = any(
|
| 200 |
+
keyword in error_str
|
| 201 |
+
for keyword in ["503", "502", "429", "service unavailable", "overloaded",
|
| 202 |
+
"connection", "timeout", "timed out", "temporarily"]
|
| 203 |
+
)
|
| 204 |
+
if is_transient and attempt < MAX_API_RETRIES - 1:
|
| 205 |
+
delay = RETRY_BASE_DELAY * (2 ** attempt)
|
| 206 |
+
logger.warning(
|
| 207 |
+
f"MedGemma API transient error (attempt {attempt + 1}/{MAX_API_RETRIES}): "
|
| 208 |
+
f"{e}. Retrying in {delay:.0f}s..."
|
| 209 |
+
)
|
| 210 |
+
await asyncio.sleep(delay)
|
| 211 |
+
continue
|
| 212 |
+
|
| 213 |
+
# Non-transient or final attempt β log and raise
|
| 214 |
+
logger.error(f"MedGemma API error (attempt {attempt + 1}/{MAX_API_RETRIES}): {e}")
|
| 215 |
+
break
|
| 216 |
+
|
| 217 |
+
raise last_error
|
| 218 |
|
| 219 |
async def _generate_local(
|
| 220 |
self, prompt: str, system_prompt: Optional[str], max_tokens: int, temperature: float
|
src/backend/app/tools/patient_parser.py
CHANGED
|
@@ -67,11 +67,11 @@ class PatientParserTool:
|
|
| 67 |
logger.info(f"Parsed patient profile: {profile.chief_complaint}")
|
| 68 |
return profile
|
| 69 |
|
| 70 |
-
except
|
| 71 |
-
# Fallback: If
|
| 72 |
-
logger.warning("
|
| 73 |
return PatientProfile(
|
| 74 |
chief_complaint=patient_text[:200],
|
| 75 |
history_of_present_illness=patient_text,
|
| 76 |
-
additional_notes="Auto-extracted from raw text (structured parsing failed)",
|
| 77 |
)
|
|
|
|
| 67 |
logger.info(f"Parsed patient profile: {profile.chief_complaint}")
|
| 68 |
return profile
|
| 69 |
|
| 70 |
+
except Exception as e:
|
| 71 |
+
# Fallback: If any error occurs (API, parsing, etc.), do basic extraction
|
| 72 |
+
logger.warning(f"Patient parsing failed ({type(e).__name__}: {e}), using basic extraction")
|
| 73 |
return PatientProfile(
|
| 74 |
chief_complaint=patient_text[:200],
|
| 75 |
history_of_present_illness=patient_text,
|
| 76 |
+
additional_notes=f"Auto-extracted from raw text (structured parsing failed: {type(e).__name__})",
|
| 77 |
)
|