# CTA/backend/llm_client.py
# Initial deployment: ClinicalMatch AI v2.0 — FHIR R4 · MCP (9 tools) · A2A workflow ·
# SHARP compliance · 100k synthetic patients · Neo4j graph · GraphRAG chatbot
# (commit 59abb4f)
"""
LLM client β€” provider-configurable, OpenAI-compatible interface.
Set LLM_PROVIDER in .env to switch between:
groq, openai, azure, aimlapi, bedrock, custom
In HIPAA/production contexts use azure or bedrock β€” both offer BAAs.
Never use the Anthropic SDK directly; all calls go through the
OpenAI-compatible interface regardless of underlying model.
"""
import os
import json
import re
from openai import OpenAI
from dotenv import load_dotenv
load_dotenv()
# ── Provider registry ─────────────────────────────────────────────────────────
# Per-provider defaults for the OpenAI-compatible endpoint and model name.
# Explicit OPENAI_BASE_URL / OPENAI_MODEL env vars (read in _build_client /
# get_provider_status) always override these entries.
_PROVIDER_DEFAULTS: dict[str, dict] = {
    # Hosted APIs with fixed, public OpenAI-compatible endpoints.
    "openai": {"base_url": "https://api.openai.com/v1", "model": "gpt-4o"},
    "groq": {"base_url": "https://api.groq.com/openai/v1", "model": "llama3-70b-8192"},
    "aimlapi": {"base_url": "https://ai.aimlapi.com/v1", "model": "claude-opus-4-7"},
    # BAA-eligible clouds: the gateway URL is deployment-specific, so it must
    # come from the environment (empty string if unset; _build_client raises).
    "azure": {"base_url": os.getenv("OPENAI_BASE_URL", ""), "model": "gpt-4o"},
    "bedrock": {"base_url": os.getenv("OPENAI_BASE_URL", ""), "model": "anthropic.claude-3-5-sonnet"},
    # Fully environment-driven escape hatch for any OpenAI-compatible server.
    "custom": {"base_url": os.getenv("OPENAI_BASE_URL", ""), "model": os.getenv("OPENAI_MODEL", "gpt-4o")},
}
# Providers that offer a HIPAA Business Associate Agreement (BAA).
_HIPAA_ELIGIBLE = {"azure", "bedrock"}
def _build_client() -> tuple[OpenAI, str]:
    """Construct the OpenAI-compatible client for the configured provider.

    Resolution order for both base URL and model: explicit env var first,
    then the provider's entry in _PROVIDER_DEFAULTS (unknown providers fall
    back to the "custom" entry).

    Returns:
        A (client, model_name) pair.

    Raises:
        RuntimeError: if no base URL can be resolved for the provider.
    """
    provider = os.getenv("LLM_PROVIDER", "custom").lower()
    cfg = _PROVIDER_DEFAULTS.get(provider, _PROVIDER_DEFAULTS["custom"])
    base_url = os.getenv("OPENAI_BASE_URL") or cfg["base_url"]
    if not base_url:
        raise RuntimeError(
            f"LLM_PROVIDER='{provider}' requires OPENAI_BASE_URL to be set. "
            "Check your .env file."
        )
    model = os.getenv("OPENAI_MODEL") or cfg["model"]
    # Some self-hosted gateways ignore the key; a placeholder keeps the SDK happy.
    api_key = os.getenv("OPENAI_API_KEY", "placeholder")
    return OpenAI(api_key=api_key, base_url=base_url), model
# Lazily-initialized module-level singletons, populated by get_client().
_client: OpenAI | None = None
_model: str = ""
def get_client() -> tuple[OpenAI, str]:
    """Return the process-wide (client, model_name) pair.

    Built lazily on first call via _build_client() and cached in module
    globals; every later call reuses the same client instance.
    """
    global _client, _model
    if _client is None:
        _client, _model = _build_client()
    return _client, _model
def get_provider_status() -> dict:
    """Return current LLM provider config — exposed via /api/v1/config/llm.

    Reads the same env vars as _build_client() but never constructs a client,
    so it is safe to call even when no API key is configured.

    Returns:
        Dict with provider, model, base_url, api_key_set, hipaa_eligible,
        and a human-readable BAA note.
    """
    provider = os.getenv("LLM_PROVIDER", "custom").lower()
    defaults = _PROVIDER_DEFAULTS.get(provider, {})  # single lookup, reused below
    hipaa_eligible = provider in _HIPAA_ELIGIBLE
    return {
        "provider": provider,
        "model": os.getenv("OPENAI_MODEL") or defaults.get("model", "unknown"),
        "base_url": os.getenv("OPENAI_BASE_URL") or defaults.get("base_url", ""),
        # Expose only whether a key is set — never the key itself.
        "api_key_set": bool(os.getenv("OPENAI_API_KEY")),
        "hipaa_eligible": hipaa_eligible,
        "baa_note": (
            # Mojibake fix: the dash previously rendered as "β€”".
            "This provider offers a BAA — suitable for PHI in production."
            if hipaa_eligible
            else "Not HIPAA BAA eligible. Use 'azure' or 'bedrock' for production PHI workloads."
        ),
    }
# ── Core chat wrapper ─────────────────────────────────────────────────────────
def chat(messages: list[dict], temperature: float = 0.3, max_tokens: int = 2048) -> str:
    """Send a chat-completion request and return the assistant's text.

    Args:
        messages: OpenAI-style message dicts ({"role": ..., "content": ...}).
        temperature: sampling temperature.
        max_tokens: cap on completion tokens.

    Returns:
        The first choice's message content, or "" when the model returns none.
    """
    client, model_name = get_client()
    response = client.chat.completions.create(
        model=model_name,
        messages=messages,
        temperature=temperature,
        max_tokens=max_tokens,
    )
    content = response.choices[0].message.content
    return content if content else ""
def _parse_json_response(raw: str) -> dict:
"""Strip markdown fences and <think> blocks, then parse JSON."""
raw = re.sub(r"<think(?:ing)?>.*?</think(?:ing)?>", "", raw, flags=re.DOTALL | re.IGNORECASE)
raw = re.sub(r"```(?:json)?", "", raw).replace("```", "").strip()
return json.loads(raw)
# ── Clinical functions ────────────────────────────────────────────────────────
def parse_trial_protocol(protocol_text: str) -> dict:
    """Extract structured eligibility criteria from unstructured protocol text.

    Returns a dict with inclusion/exclusion criteria, age range, required
    diagnoses/biomarkers, excluded medications and performance status. On any
    LLM or parse failure a conservative empty skeleton is returned instead of
    raising, so the ingestion pipeline keeps moving.
    """
    excerpt = protocol_text[:4000]  # cap prompt size
    prompt = f"""You are a clinical research expert. Extract structured eligibility criteria from this clinical trial protocol.
Return a JSON object with exactly these keys:
- inclusion_criteria: list of strings
- exclusion_criteria: list of strings
- age_range: {{"min": int_or_null, "max": int_or_null}}
- required_diagnoses: list of strings
- required_biomarkers: list of strings (e.g. "HER2+", "EGFR mutation")
- excluded_medications: list of strings
- performance_status: string or null (e.g. "ECOG 0-2")
Protocol text:
{excerpt}
Return ONLY valid JSON, no markdown, no explanation."""
    try:
        raw = chat([{"role": "user", "content": prompt}], temperature=0)
        return _parse_json_response(raw)
    except Exception:
        # Best-effort fallback: empty criteria with the common adult-age floor.
        return {
            "inclusion_criteria": [],
            "exclusion_criteria": [],
            "age_range": {"min": 18, "max": None},
            "required_diagnoses": [],
            "required_biomarkers": [],
            "excluded_medications": [],
            "performance_status": None,
        }
def score_patient_against_criteria(patient_profile: dict, criteria: dict, trial_title: str) -> dict:
    """Semantically score a patient against trial criteria using LLM.

    Args:
        patient_profile: patient facts (age, gender, diagnoses, labs, ...).
        criteria: parsed criteria dict (same shape as parse_trial_protocol output).
        trial_title: human-readable trial title for the prompt.

    Returns:
        Structured assessment dict; on any LLM/parse failure, a conservative
        fallback explicitly flagged for manual review.
    """
    inclusion_lines = "\n".join(f"- {c}" for c in criteria.get("inclusion_criteria", []))
    exclusion_lines = "\n".join(f"- {c}" for c in criteria.get("exclusion_criteria", []))
    diagnoses = ", ".join(patient_profile.get("diagnosis_names", []))
    medications = ", ".join(patient_profile.get("medications", []))
    comorbidities = ", ".join(patient_profile.get("comorbidities", []))
    prompt = f"""You are a clinical trial eligibility expert. Assess this patient's eligibility.
TRIAL: {trial_title}
INCLUSION CRITERIA:
{inclusion_lines}
EXCLUSION CRITERIA:
{exclusion_lines}
PATIENT PROFILE:
- Age: {patient_profile.get("age")}
- Gender: {patient_profile.get("gender")}
- Diagnoses: {diagnoses}
- Medications: {medications}
- Biomarkers: {patient_profile.get("biomarkers", {})}
- Lab Values: {patient_profile.get("lab_values", {})}
- Comorbidities: {comorbidities}
- Prior therapy lines: {patient_profile.get("prior_lines_of_therapy", "unknown")}
Return a JSON object with:
- overall_score: float 0.0-1.0
- eligible: boolean
- inclusion_results: list of {{"criterion": str, "met": bool, "confidence": "high"|"medium"|"low", "note": str}}
- exclusion_results: list of {{"criterion": str, "triggered": bool, "confidence": "high"|"medium"|"low", "note": str}}
- summary: string (2-3 sentence clinical reasoning)
- risk_flags: list of strings
Return ONLY valid JSON."""
    try:
        raw = chat([{"role": "user", "content": prompt}], temperature=0, max_tokens=1500)
        return _parse_json_response(raw)
    except Exception:
        # Best-effort fallback: never crash the matching pipeline, but make the
        # uncertainty explicit so a human reviews the match.
        return {
            "overall_score": 0.7,
            "eligible": True,
            "inclusion_results": [],
            "exclusion_results": [],
            "summary": "Automated assessment pending. Patient profile partially matches trial criteria.",
            "risk_flags": ["Manual review recommended"],
        }
def generate_outreach_message(patient_profile: dict, trial: dict, channel: str) -> str:
    """Draft a recruitment message for the given channel.

    channel is one of "pcp_letter", "patient_email", "social_post"; anything
    else falls back to the patient-email style. Only non-identifying patient
    context (age, diagnosis names) goes into the prompt.
    """
    channel_instructions = {
        "pcp_letter": "Write a formal referral letter from a clinical research coordinator to the patient's PCP. Include trial name, NCT number, eligibility criteria met, and next steps.",
        "patient_email": "Write a warm, empathetic email to the patient in plain language (8th grade reading level). Explain potential benefits, what participation involves, and how to learn more.",
        "social_post": "Write a concise social media post (max 280 characters for Twitter, 500 for Facebook) for patient recruitment. No personal identifiers.",
    }
    instruction = channel_instructions.get(channel, channel_instructions["patient_email"])
    # At most three sites, formatted "City, ST".
    sites = ", ".join(f"{loc['city']}, {loc['state']}" for loc in trial.get("locations", [])[:3])
    diagnosis = ", ".join(patient_profile.get("diagnosis_names", ["the relevant condition"]))
    prompt = f"""{instruction}
Trial: {trial.get("title")} ({trial.get("nct_id")})
Phase: {trial.get("phase")} | Sponsor: {trial.get("sponsor")}
Summary: {trial.get("brief_summary", "")[:500]}
Locations: {sites}
Patient context (no identifying details):
- Age range: {patient_profile.get("age")} years
- Diagnosis: {diagnosis}
Write the message now:"""
    return chat([{"role": "user", "content": prompt}], temperature=0.7, max_tokens=800)
def summarize_trial(trial: dict) -> str:
    """Condense a trial record into coordinator-facing bullet points."""
    # Truncate the free-text fields so the prompt stays small.
    brief = trial.get("brief_summary", "")[:1000]
    eligibility = trial.get("eligibility_criteria", "")[:800]
    prompt = f"""Summarize this clinical trial in 3-4 bullet points for a clinical coordinator:
what's tested, who qualifies, what patients do, potential benefit.
Trial: {trial.get("title")}
Summary: {brief}
Eligibility: {eligibility}
Phase: {trial.get("phase")} | Enrollment: {trial.get("enrollment")}
Bullet points only:"""
    return chat([{"role": "user", "content": prompt}], temperature=0.3, max_tokens=500)