# ehragent/server.py
# Last commit (suyc21, 5a648bb, verified):
#   fix: bracket MemoryBank() with temp env to honor BYOK
"""FastAPI bridge over ehragent.Agent — strict BYOK, with optional augmentations.
Endpoints
---------
GET /api/health → liveness
GET /api/config → list of databases + per-DB augmentation availability
POST /api/ask → run the agent. See ``AskRequest`` for fields.
BYOK
----
Each request carries its own OpenAI API key in the ``X-OpenAI-Api-Key``
header. The server never reads ``OPENAI_API_KEY`` from its environment.
For requests using a custom (local) model where the key is irrelevant,
the header may be empty — the agent will then receive ``api_key="EMPTY"``,
which OpenAI-compatible servers (vLLM, SGLang, …) accept.
Memory bank retrieval, however, ALWAYS goes to OpenAI for the embedding
call (the prebuilt index uses ``text-embedding-3-small``). When
``use_memory_bank`` is true the request must therefore include a real
OpenAI key.
Multi-DB
--------
``EHRAGENT_DATABASES`` is a JSON object mapping ``db_id`` to a config:
{
"mimic_iii": {
"path": "/data/mimic_iii.sqlite",
"label": "MIMIC-III",
"handbook": "/data/handbook/mimic_iii_handbook.md",
"memory_bank": "/data/memory_bank/mimic_iii"
}
}
``handbook`` and ``memory_bank`` are optional. Per request the caller
opts in to either via ``use_handbook`` / ``use_memory_bank``.
"""
from __future__ import annotations
import json
import os
from pathlib import Path
from dotenv import load_dotenv
from fastapi import FastAPI, Header, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from openai import OpenAI
from pydantic import BaseModel, Field
from ehragent import Agent
from ehragent.memory_bank import MemoryBank
# Local-dev convenience: load .env so DB paths and CORS origins can sit there.
# Note: OPENAI_API_KEY is intentionally NOT consumed below.
load_dotenv()
# Additionally look for a ``.env`` two directories above this file and load it
# without overriding variables that are already set. Only the ``parents[2]``
# lookup can raise IndexError (file located too close to the filesystem root),
# so it is the only statement guarded by the ``try``.
try:
    _project_env = Path(__file__).resolve().parents[2] / ".env"
except IndexError:
    pass
else:
    if _project_env.exists():
        load_dotenv(dotenv_path=_project_env, override=False)
# Server-wide defaults, each overridable through an EHRAGENT_* environment
# variable. For the optional settings an empty string counts as "unset"
# (``or None``).
DEFAULT_MODEL = os.environ.get("EHRAGENT_MODEL", "gpt-5-mini")
DEFAULT_BASE_URL = os.environ.get("EHRAGENT_BASE_URL") or None
MAX_TURNS = int(os.environ.get("EHRAGENT_MAX_TURNS", "20"))
LOG_DIR = os.environ.get("EHRAGENT_LOG_DIR") or None
REASONING_EFFORT = os.environ.get("EHRAGENT_REASONING_EFFORT") or None
# Comma-separated list of allowed origins; blanks are dropped, default is "*".
CORS_ORIGINS = [
    origin
    for origin in map(
        str.strip, os.environ.get("EHRAGENT_CORS_ORIGINS", "*").split(",")
    )
    if origin
]
def _parse_databases() -> dict[str, dict]:
raw = os.getenv("EHRAGENT_DATABASES")
if not raw:
single = os.getenv("EHRAGENT_DB_PATH")
if single:
return {"default": {"path": single, "label": Path(single).stem}}
return {}
parsed = json.loads(raw)
if not isinstance(parsed, dict):
raise RuntimeError("EHRAGENT_DATABASES must be a JSON object.")
return parsed
DATABASES_CONFIG = _parse_databases()


def _db_path_error(cfg: dict) -> str | None:
    """Return why ``cfg``'s database file is unusable, or ``None`` if it exists."""
    raw_path = cfg.get("path")
    if not raw_path:
        # ``Path("")`` is the current directory, which always exists, so a
        # missing/empty path must be rejected explicitly instead of being
        # reported as a ready database.
        return "no 'path' configured"
    p = Path(raw_path)
    return None if p.exists() else f"file not found: {p}"


# Validate paths at startup; a missing file doesn't crash the server, but the
# database is reported as not ready (``DB_STATUS[db_id]`` holds the reason,
# ``None`` means OK).
DB_STATUS: dict[str, str | None] = {
    db_id: _db_path_error(cfg) for db_id, cfg in DATABASES_CONFIG.items()
}
def _has_handbook(cfg: dict) -> bool:
hb = cfg.get("handbook")
return bool(hb) and Path(hb).exists()
def _has_memory_bank(cfg: dict) -> bool:
mb = cfg.get("memory_bank")
return bool(mb) and Path(mb).exists()
def _build_memory_bank(path: str, *, openai_api_key: str) -> MemoryBank:
    """Open the prebuilt bank at ``path`` with embeddings bound to the caller's key.

    ``MemoryBank.__init__`` creates a bare ``OpenAI()`` client, which raises
    on the Space because ``OPENAI_API_KEY`` is not in the environment. The
    variable is therefore set to the caller's key only for the duration of
    the constructor and then restored to its previous state (or removed if it
    was unset). Afterwards the bank's client is replaced by one built
    explicitly from the same key.

    The temporary mutation relies on no coroutine yielding between the set
    and the constructor (single-threaded asyncio).
    """
    env_name = "OPENAI_API_KEY"
    saved_value = os.environ.get(env_name)
    os.environ[env_name] = openai_api_key
    try:
        bank = MemoryBank(store_dir=path)
    finally:
        # Put the environment back exactly as it was found.
        if saved_value is not None:
            os.environ[env_name] = saved_value
        else:
            os.environ.pop(env_name, None)
    bank._client = OpenAI(api_key=openai_api_key)  # noqa: SLF001
    return bank
def _build_agent(
    cfg: dict,
    *,
    api_key: str,
    base_url_override: str | None,
    model_override: str | None,
    use_handbook: bool,
    use_memory_bank: bool,
    openai_key_for_embeddings: str | None,
) -> Agent:
    """Assemble an ``Agent`` for one request from a database config.

    Request-level overrides win over the per-database config, which in turn
    wins over the server-wide defaults. The handbook and memory bank are only
    attached when the caller asked for them AND the configured path exists;
    otherwise they are silently left out.

    Raises:
        HTTPException: 400 when the memory bank is requested and available
            but no OpenAI key was supplied for the embedding call.
    """
    db_path = str(Path(cfg["path"]))

    handbook_path = None
    if use_handbook and _has_handbook(cfg):
        handbook_path = cfg.get("handbook")

    memory_bank: MemoryBank | None = None
    if use_memory_bank and _has_memory_bank(cfg):
        if not openai_key_for_embeddings:
            raise HTTPException(
                status_code=400,
                detail=(
                    "Memory bank retrieval needs an OpenAI key for embeddings. "
                    "Either provide one in X-OpenAI-Api-Key or disable use_memory_bank."
                ),
            )
        memory_bank = _build_memory_bank(
            cfg["memory_bank"], openai_api_key=openai_key_for_embeddings
        )

    # Precedence: request override > per-DB config > server default.
    model = model_override or cfg.get("model", DEFAULT_MODEL)
    base_url = base_url_override or cfg.get("base_url", DEFAULT_BASE_URL)

    return Agent(
        db_path=db_path,
        model=model,
        base_url=base_url,
        api_key=api_key,
        handbook=handbook_path,
        memory_bank=memory_bank,
        max_turns=cfg.get("max_turns", MAX_TURNS),
        reasoning_effort=cfg.get("reasoning_effort", REASONING_EFFORT),
        log_dir=LOG_DIR,
    )
# ASGI application object served by the web server.
app = FastAPI(title="EHRAgent WebUI Backend")
# CORS: origins come from EHRAGENT_CORS_ORIGINS; any request header is
# allowed so browsers can send the custom X-OpenAI-Api-Key header.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["GET", "POST", "OPTIONS"],
    allow_origins=CORS_ORIGINS,
)
class AskRequest(BaseModel):
    # JSON body accepted by ``POST /api/ask``.

    # Question to hand to the agent; must contain at least one character.
    question: str = Field(..., min_length=1)
    # Id of the configured database to query; when omitted, the first
    # configured database is used.
    db: str | None = None
    # Optional caller-supplied identifier, passed through to ``agent.arun``.
    question_id: str | None = None
    # Optional per-request base URL override (e.g. a local OpenAI-compatible
    # server). A non-blank value also makes the API-key header optional.
    base_url: str | None = None
    # Optional per-request model override.
    model: str | None = None
    # Opt in to the database's handbook, if one is configured and present.
    use_handbook: bool = False
    # Opt in to memory-bank retrieval, if configured and present; this needs
    # an OpenAI key for the embedding call.
    use_memory_bank: bool = False
@app.get("/api/health")
async def health() -> dict:
    """Liveness probe: always answers with a fixed ``status: ok`` payload."""
    payload = {"status": "ok"}
    return payload
@app.get("/api/config")
async def config() -> dict:
    """Describe the server defaults and every configured database.

    Each database entry reports its id, display label (falling back to the
    id), whether its file was found at startup together with the error text
    if not, and whether a handbook / memory bank is currently available.
    """
    databases = []
    for db_id, cfg in DATABASES_CONFIG.items():
        startup_error = DB_STATUS[db_id]
        entry = {
            "id": db_id,
            "label": cfg.get("label", db_id),
            "ready": startup_error is None,
            "error": startup_error,
            "has_handbook": _has_handbook(cfg),
            "has_memory_bank": _has_memory_bank(cfg),
        }
        databases.append(entry)

    return {
        "model": DEFAULT_MODEL,
        "base_url": DEFAULT_BASE_URL,
        "max_turns": MAX_TURNS,
        "byok_required": True,
        "databases": databases,
    }
@app.post("/api/ask")
async def ask(
    req: AskRequest,
    x_openai_api_key: str | None = Header(default=None, alias="X-OpenAI-Api-Key"),
) -> dict:
    """Run the agent on one question against one configured database.

    The caller's key arrives in the ``X-OpenAI-Api-Key`` header. It may be
    blank only when the request supplies its own ``base_url``; in that case
    the agent receives the placeholder key ``"EMPTY"``.

    Returns:
        Whatever ``agent.arun`` returns for the question.

    Raises:
        HTTPException: 500 if no databases are configured; 400 for an unknown
            ``db``; 503 if the database file was missing at startup; 401 if
            no key was sent and no ``base_url`` override is present; any
            ``HTTPException`` from ``_build_agent`` is passed through; other
            errors while building the agent become a 500 carrying the
            error's message.
    """
    if not DATABASES_CONFIG:
        raise HTTPException(
            status_code=500,
            detail="No databases configured. Set EHRAGENT_DATABASES on the server.",
        )
    # Default to the first configured database when none is named.
    db_id = req.db or next(iter(DATABASES_CONFIG))
    cfg = DATABASES_CONFIG.get(db_id)
    if cfg is None:
        raise HTTPException(
            status_code=400,
            detail=f"Unknown db: {db_id!r}. Available: {list(DATABASES_CONFIG)}",
        )
    if DB_STATUS[db_id] is not None:
        raise HTTPException(
            status_code=503,
            detail=f"db {db_id!r}: {DB_STATUS[db_id]}",
        )

    # Normalise the optional overrides once: blank / whitespace-only -> None.
    base_url_override = (req.base_url or "").strip() or None
    model_override = (req.model or "").strip() or None
    using_local = base_url_override is not None
    raw_key = (x_openai_api_key or "").strip()
    if not raw_key and not using_local:
        raise HTTPException(
            status_code=401,
            detail=(
                "OpenAI API key required. Send it via the X-OpenAI-Api-Key "
                "header. The server does not store keys."
            ),
        )
    # Inference key: dummy when only a local model is in play.
    inference_key = raw_key or "EMPTY"
    # Embedding key (memory bank): only OpenAI works, must be a real key.
    embedding_key = raw_key or None
    try:
        agent = _build_agent(
            cfg,
            api_key=inference_key,
            base_url_override=base_url_override,
            model_override=model_override,
            use_handbook=req.use_handbook,
            use_memory_bank=req.use_memory_bank,
            openai_key_for_embeddings=embedding_key,
        )
    except HTTPException:
        raise
    except Exception as e:  # noqa: BLE001
        # Keep the original exception as the cause so server logs show the
        # real traceback behind the generic 500.
        raise HTTPException(status_code=500, detail=str(e)) from e
    return await agent.arun(req.question, question_id=req.question_id)