bshepp committed on
Commit
28e85d4
·
1 Parent(s): 58e2bd0

fix: add retry logic, error handling, and startup diagnostics for deployed pipeline

Browse files

- Add startup config logging in main.py (masked secrets, warnings if empty)
- Add retry with exponential backoff in MedGemma API (3 retries, handles 503 cold-start)
- Broaden patient_parser exception handler (catch all, not just ValueError)
- Stop pipeline on critical step failure (skip remaining instead of cascading errors)
- Add /api/health/config diagnostic endpoint

src/backend/app/agent/orchestrator.py CHANGED
@@ -122,6 +122,9 @@ class Orchestrator:
122
  This is the main entry point. Each step is executed sequentially,
123
  with state flowing from one step to the next. Steps that don't
124
  depend on each other (drug check + guidelines) run in parallel.
 
 
 
125
  """
126
  case_id = str(uuid.uuid4())[:8]
127
  steps = self._create_steps(case)
@@ -134,10 +137,23 @@ class Orchestrator:
134
 
135
  try:
136
  # ── Step 1: Parse patient data ──
137
- yield await self._run_step("parse", self._step_parse, case.patient_text)
 
 
 
 
 
 
 
138
 
139
  # ── Step 2: Clinical reasoning ──
140
- yield await self._run_step("reason", self._step_reason)
 
 
 
 
 
 
141
 
142
  # ── Step 3 & 4: Drug check + Guidelines (parallel) ──
143
  parallel_tasks = []
@@ -175,6 +191,20 @@ class Orchestrator:
175
  step.error = f"Pipeline aborted: {str(e)}"
176
  raise
177
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
  async def _run_step(self, step_id: str, fn, *args) -> AgentStep:
179
  """Execute a single step, tracking status and timing."""
180
  step = self._get_step(step_id)
 
122
  This is the main entry point. Each step is executed sequentially,
123
  with state flowing from one step to the next. Steps that don't
124
  depend on each other (drug check + guidelines) run in parallel.
125
+
126
+ If a critical step (parse, reason) fails, subsequent dependent
127
+ steps are marked as SKIPPED to avoid cascading errors.
128
  """
129
  case_id = str(uuid.uuid4())[:8]
130
  steps = self._create_steps(case)
 
137
 
138
  try:
139
  # ── Step 1: Parse patient data ──
140
+ step = await self._run_step("parse", self._step_parse, case.patient_text)
141
+ yield step
142
+
143
+ if step.status == AgentStepStatus.FAILED:
144
+ # Can't continue without patient profile β€” skip remaining steps
145
+ yield from self._skip_remaining_steps("parse")
146
+ self._state.completed_at = datetime.utcnow()
147
+ return
148
 
149
  # ── Step 2: Clinical reasoning ──
150
+ step = await self._run_step("reason", self._step_reason)
151
+ yield step
152
+
153
+ if step.status == AgentStepStatus.FAILED:
154
+ yield from self._skip_remaining_steps("reason")
155
+ self._state.completed_at = datetime.utcnow()
156
+ return
157
 
158
  # ── Step 3 & 4: Drug check + Guidelines (parallel) ──
159
  parallel_tasks = []
 
191
  step.error = f"Pipeline aborted: {str(e)}"
192
  raise
193
 
194
def _skip_remaining_steps(self, after_step_id: str) -> list[AgentStep]:
    """Mark all steps after after_step_id as skipped. Returns them for yielding."""
    skipped: list[AgentStep] = []
    past_failed_step = False
    for candidate in self._state.steps:
        if candidate.step_id == after_step_id:
            # Everything from here on depends on the failed step.
            past_failed_step = True
        elif past_failed_step and candidate.status == AgentStepStatus.PENDING:
            candidate.status = AgentStepStatus.SKIPPED
            candidate.error = f"Skipped: prerequisite step '{after_step_id}' failed"
            skipped.append(candidate)
    return skipped
207
+
208
  async def _run_step(self, step_id: str, fn, *args) -> AgentStep:
209
  """Execute a single step, tracking status and timing."""
210
  step = self._get_step(step_id)
src/backend/app/api/health.py CHANGED
@@ -1,9 +1,26 @@
1
  """Health check endpoint."""
 
 
2
  from fastapi import APIRouter
3
 
 
 
 
4
  router = APIRouter()
5
 
6
 
7
  @router.get("/health")
8
  async def health_check():
9
  return {"status": "ok", "service": "CDS Agent"}
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """Health check endpoint."""
2
+ import logging
3
+
4
  from fastapi import APIRouter
5
 
6
+ from app.config import settings
7
+
8
+ logger = logging.getLogger(__name__)
9
  router = APIRouter()
10
 
11
 
12
@router.get("/health")
async def health_check():
    """Liveness probe: report that the service is up."""
    payload = {"status": "ok", "service": "CDS Agent"}
    return payload
15
+
16
+
17
@router.get("/api/health/config")
async def config_check():
    """Diagnostic endpoint: shows whether critical env vars are configured (no secrets)."""
    # Secrets are reported only as present/absent booleans, never their values.
    report = {
        "medgemma_base_url_set": bool(settings.medgemma_base_url),
        "medgemma_api_key_set": bool(settings.medgemma_api_key),
        "medgemma_model_id": settings.medgemma_model_id,
        "hf_token_set": bool(settings.hf_token),
        "medgemma_max_tokens": settings.medgemma_max_tokens,
    }
    return report
src/backend/app/main.py CHANGED
@@ -1,12 +1,17 @@
1
  """
2
  Clinical Decision Support Agent β€” FastAPI Backend
3
  """
 
 
4
  from fastapi import FastAPI
5
  from fastapi.middleware.cors import CORSMiddleware
6
 
7
  from app.api import cases, health, ws
8
  from app.config import settings
9
 
 
 
 
10
  app = FastAPI(
11
  title="Clinical Decision Support Agent",
12
  description="Agentic clinical decision support powered by MedGemma (HAI-DEF)",
@@ -31,6 +36,24 @@ app.include_router(ws.router, prefix="/ws", tags=["websocket"])
31
  @app.on_event("startup")
32
  async def startup():
33
  """Initialize services on startup."""
34
- # TODO: Initialize MedGemma model / connection
35
- # TODO: Initialize RAG vector store
36
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """
2
  Clinical Decision Support Agent β€” FastAPI Backend
3
  """
4
+ import logging
5
+
6
  from fastapi import FastAPI
7
  from fastapi.middleware.cors import CORSMiddleware
8
 
9
  from app.api import cases, health, ws
10
  from app.config import settings
11
 
12
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(levelname)s %(message)s")
13
+ logger = logging.getLogger(__name__)
14
+
15
  app = FastAPI(
16
  title="Clinical Decision Support Agent",
17
  description="Agentic clinical decision support powered by MedGemma (HAI-DEF)",
 
36
@app.on_event("startup")
async def startup():
    """Initialize services on startup."""

    def _redact(secret: str) -> str:
        # Show just enough of a credential to verify it is set without leaking it.
        if not secret:
            return "(empty)"
        return "***" if len(secret) <= 8 else f"{secret[:4]}...{secret[-4:]}"

    # Log the effective configuration once at boot (secrets masked) so that a
    # misconfigured deployment is diagnosable from the logs alone.
    logger.info("=== CDS Agent Backend Starting ===")
    logger.info(" medgemma_base_url : %s", settings.medgemma_base_url or "(empty)")
    logger.info(" medgemma_model_id : %s", settings.medgemma_model_id)
    logger.info(" medgemma_api_key : %s", _redact(settings.medgemma_api_key))
    logger.info(" hf_token : %s", _redact(settings.hf_token))
    logger.info(" medgemma_max_tokens: %s", settings.medgemma_max_tokens)
    logger.info(" cors_origins : %s", settings.cors_origins)
    logger.info(" chroma_persist_dir: %s", settings.chroma_persist_dir)

    if not settings.medgemma_base_url:
        logger.warning("MEDGEMMA_BASE_URL is empty -- MedGemma API calls will fail!")
    if not settings.medgemma_api_key:
        logger.warning("MEDGEMMA_API_KEY is empty -- MedGemma API calls will fail!")
src/backend/app/services/medgemma.py CHANGED
@@ -10,6 +10,7 @@ All tools that need MedGemma go through this service.
10
  """
11
  from __future__ import annotations
12
 
 
13
  import json
14
  import logging
15
  from typing import Any, Optional, Type, TypeVar
@@ -22,6 +23,10 @@ logger = logging.getLogger(__name__)
22
 
23
  T = TypeVar("T", bound=BaseModel)
24
 
 
 
 
 
25
 
26
  class MedGemmaService:
27
  """
@@ -146,6 +151,9 @@ class MedGemmaService:
146
  happens to be plain Gemma on Google AI Studio (which rejects the system
147
  role), we automatically fall back to folding the system prompt into the
148
  user message.
 
 
 
149
  """
150
  client = await self._get_client()
151
 
@@ -154,29 +162,59 @@ class MedGemmaService:
154
  messages.append({"role": "system", "content": system_prompt})
155
  messages.append({"role": "user", "content": prompt})
156
 
157
- try:
158
- response = await client.chat.completions.create(
159
- model=settings.medgemma_model_id,
160
- messages=messages,
161
- max_tokens=max_tokens,
162
- temperature=temperature,
163
- )
164
- return response.choices[0].message.content
165
- except Exception as e:
166
- # Fallback: fold system prompt into user message (Google AI Studio compat)
167
- if system_prompt and "system" in str(e).lower():
168
- logger.warning("Backend rejected system role β€” folding into user message.")
169
- fallback_messages = [
170
- {"role": "user", "content": f"{system_prompt}\n\n{prompt}"}
171
- ]
172
  response = await client.chat.completions.create(
173
  model=settings.medgemma_model_id,
174
- messages=fallback_messages,
175
  max_tokens=max_tokens,
176
  temperature=temperature,
177
  )
178
  return response.choices[0].message.content
179
- raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
 
181
  async def _generate_local(
182
  self, prompt: str, system_prompt: Optional[str], max_tokens: int, temperature: float
 
10
  """
11
  from __future__ import annotations
12
 
13
+ import asyncio
14
  import json
15
  import logging
16
  from typing import Any, Optional, Type, TypeVar
 
23
 
24
  T = TypeVar("T", bound=BaseModel)
25
 
26
+ # Retry configuration for transient API errors (cold-start / 503)
27
+ MAX_API_RETRIES = 3
28
+ RETRY_BASE_DELAY = 5.0 # seconds, doubles on each retry
29
+
30
 
31
  class MedGemmaService:
32
  """
 
151
  happens to be plain Gemma on Google AI Studio (which rejects the system
152
  role), we automatically fall back to folding the system prompt into the
153
  user message.
154
+
155
+ Includes retry with exponential backoff for transient errors (503 cold
156
+ start, connection errors, timeouts).
157
  """
158
  client = await self._get_client()
159
 
 
162
  messages.append({"role": "system", "content": system_prompt})
163
  messages.append({"role": "user", "content": prompt})
164
 
165
+ last_error: Optional[Exception] = None
166
+
167
+ for attempt in range(MAX_API_RETRIES):
168
+ try:
 
 
 
 
 
 
 
 
 
 
 
169
  response = await client.chat.completions.create(
170
  model=settings.medgemma_model_id,
171
+ messages=messages,
172
  max_tokens=max_tokens,
173
  temperature=temperature,
174
  )
175
  return response.choices[0].message.content
176
+ except Exception as e:
177
+ error_str = str(e).lower()
178
+ last_error = e
179
+
180
+ # Detect system-role rejection (Google AI Studio) β€” immediate fallback, no retry
181
+ if system_prompt and "system" in error_str:
182
+ logger.warning("Backend rejected system role -- folding into user message.")
183
+ fallback_messages = [
184
+ {"role": "user", "content": f"{system_prompt}\n\n{prompt}"}
185
+ ]
186
+ try:
187
+ response = await client.chat.completions.create(
188
+ model=settings.medgemma_model_id,
189
+ messages=fallback_messages,
190
+ max_tokens=max_tokens,
191
+ temperature=temperature,
192
+ )
193
+ return response.choices[0].message.content
194
+ except Exception as e2:
195
+ last_error = e2
196
+ error_str = str(e2).lower()
197
+
198
+ # Retry on transient errors (503, 502, 429, connection, timeout)
199
+ is_transient = any(
200
+ keyword in error_str
201
+ for keyword in ["503", "502", "429", "service unavailable", "overloaded",
202
+ "connection", "timeout", "timed out", "temporarily"]
203
+ )
204
+ if is_transient and attempt < MAX_API_RETRIES - 1:
205
+ delay = RETRY_BASE_DELAY * (2 ** attempt)
206
+ logger.warning(
207
+ f"MedGemma API transient error (attempt {attempt + 1}/{MAX_API_RETRIES}): "
208
+ f"{e}. Retrying in {delay:.0f}s..."
209
+ )
210
+ await asyncio.sleep(delay)
211
+ continue
212
+
213
+ # Non-transient or final attempt β€” log and raise
214
+ logger.error(f"MedGemma API error (attempt {attempt + 1}/{MAX_API_RETRIES}): {e}")
215
+ break
216
+
217
+ raise last_error
218
 
219
  async def _generate_local(
220
  self, prompt: str, system_prompt: Optional[str], max_tokens: int, temperature: float
src/backend/app/tools/patient_parser.py CHANGED
@@ -67,11 +67,11 @@ class PatientParserTool:
67
  logger.info(f"Parsed patient profile: {profile.chief_complaint}")
68
  return profile
69
 
70
- except ValueError:
71
- # Fallback: If structured parsing fails, do basic extraction
72
- logger.warning("Structured parsing failed, attempting basic extraction")
73
  return PatientProfile(
74
  chief_complaint=patient_text[:200],
75
  history_of_present_illness=patient_text,
76
- additional_notes="Auto-extracted from raw text (structured parsing failed)",
77
  )
 
67
  logger.info(f"Parsed patient profile: {profile.chief_complaint}")
68
  return profile
69
 
70
+ except Exception as e:
71
+ # Fallback: If any error occurs (API, parsing, etc.), do basic extraction
72
+ logger.warning(f"Patient parsing failed ({type(e).__name__}: {e}), using basic extraction")
73
  return PatientProfile(
74
  chief_complaint=patient_text[:200],
75
  history_of_present_illness=patient_text,
76
+ additional_notes=f"Auto-extracted from raw text (structured parsing failed: {type(e).__name__})",
77
  )