Spaces:
Running
Running
File size: 9,014 Bytes
59abb4f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 | """
LLM client — provider-configurable, OpenAI-compatible interface.
Set LLM_PROVIDER in .env to switch between:
groq, openai, azure, aimlapi, bedrock, custom
In HIPAA/production contexts use azure or bedrock — both offer BAAs.
Never use the Anthropic SDK directly; all calls go through the
OpenAI-compatible interface regardless of the underlying model.
"""
import json
import logging
import os
import re

from dotenv import load_dotenv
from openai import OpenAI
load_dotenv()
# ββ Provider registry βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
_PROVIDER_DEFAULTS: dict[str, dict] = {
"openai": {"base_url": "https://api.openai.com/v1", "model": "gpt-4o"},
"groq": {"base_url": "https://api.groq.com/openai/v1", "model": "llama3-70b-8192"},
"aimlapi": {"base_url": "https://ai.aimlapi.com/v1", "model": "claude-opus-4-7"},
"azure": {"base_url": os.getenv("OPENAI_BASE_URL", ""), "model": "gpt-4o"},
"bedrock": {"base_url": os.getenv("OPENAI_BASE_URL", ""), "model": "anthropic.claude-3-5-sonnet"},
"custom": {"base_url": os.getenv("OPENAI_BASE_URL", ""), "model": os.getenv("OPENAI_MODEL", "gpt-4o")},
}
_HIPAA_ELIGIBLE = {"azure", "bedrock"}
def _build_client() -> tuple[OpenAI, str]:
    """Create the OpenAI-compatible client selected by the environment.

    ``LLM_PROVIDER`` (default ``"custom"``, lower-cased) picks a registry
    entry; a name missing from ``_PROVIDER_DEFAULTS`` falls back to the
    ``"custom"`` entry. Non-empty ``OPENAI_BASE_URL`` / ``OPENAI_MODEL``
    override the entry's defaults. ``OPENAI_API_KEY`` defaults to
    ``"placeholder"`` when unset.

    Returns:
        ``(client, model)`` — the configured ``OpenAI`` client and the model name.

    Raises:
        RuntimeError: If neither ``OPENAI_BASE_URL`` nor the registry entry
            supplies a base URL.
    """
    provider = os.getenv("LLM_PROVIDER", "custom").lower()
    custom_entry = _PROVIDER_DEFAULTS["custom"]
    entry = _PROVIDER_DEFAULTS.get(provider, custom_entry)

    endpoint = os.getenv("OPENAI_BASE_URL") or entry["base_url"]
    if not endpoint:
        raise RuntimeError(
            f"LLM_PROVIDER='{provider}' requires OPENAI_BASE_URL to be set. "
            "Check your .env file."
        )

    model_name = os.getenv("OPENAI_MODEL") or entry["model"]
    # NOTE(review): "placeholder" presumably lets keyless local endpoints work
    # without a real key being configured — confirm this is intended.
    key = os.getenv("OPENAI_API_KEY", "placeholder")
    return OpenAI(api_key=key, base_url=endpoint), model_name
# Module-level cache for the client, filled in by the first get_client() call.
_client: OpenAI | None = None
_model: str = ""
def get_client() -> tuple[OpenAI, str]:
    """Return the shared ``(client, model)`` pair, building it on first use.

    The pair is created once via ``_build_client`` and reused afterwards, so
    environment changes made after the first call are not picked up.
    NOTE(review): there is no locking; concurrent first calls could each build
    a client — confirm whether that matters for the deployment.
    """
    global _client, _model
    if _client is not None:
        return _client, _model
    _client, _model = _build_client()
    return _client, _model
def get_provider_status() -> dict:
    """Return the current LLM provider config — exposed via /api/v1/config/llm.

    Provider, model and base URL are resolved with the same precedence as
    ``_build_client``:
    - non-empty environment variables first, then the provider's registry entry;
    - an unrecognised ``LLM_PROVIDER`` falls back to the ``"custom"`` entry.

    The reported values therefore match what the client will actually use.
    The API key itself is never returned, only whether one is set.

    Returns:
        dict with keys ``provider``, ``model``, ``base_url``, ``api_key_set``,
        ``hipaa_eligible``, ``baa_note`` and ``provider_known``.
    """
    provider = os.getenv("LLM_PROVIDER", "custom").lower()
    # Mirror _build_client's fallback so an unknown provider reports the
    # "custom" defaults it will really run with, not "unknown"/"".
    defaults = _PROVIDER_DEFAULTS.get(provider, _PROVIDER_DEFAULTS["custom"])
    hipaa_eligible = provider in _HIPAA_ELIGIBLE
    return {
        "provider": provider,
        "model": os.getenv("OPENAI_MODEL") or defaults["model"],
        "base_url": os.getenv("OPENAI_BASE_URL") or defaults["base_url"],
        "api_key_set": bool(os.getenv("OPENAI_API_KEY")),
        "hipaa_eligible": hipaa_eligible,
        "baa_note": (
            "This provider offers a BAA — suitable for PHI in production."
            if hipaa_eligible
            else "Not HIPAA BAA eligible. Use 'azure' or 'bedrock' for production PHI workloads."
        ),
        # False when LLM_PROVIDER is not a registry key (the "custom" defaults are in use).
        "provider_known": provider in _PROVIDER_DEFAULTS,
    }
# ── Core chat wrapper ─────────────────────────────────────────────────────────
def chat(messages: list[dict], temperature: float = 0.3, max_tokens: int = 2048) -> str:
    """Send one chat-completion request and return the reply text.

    Args:
        messages: Chat messages passed through unchanged to
            ``client.chat.completions.create``.
        temperature: Sampling temperature forwarded to the API.
        max_tokens: Generation limit forwarded to the API.

    Returns:
        The first choice's message content, or ``""`` when that content is empty/None.
    """
    client, model_name = get_client()
    request_kwargs = {
        "model": model_name,
        "messages": messages,
        "temperature": temperature,
        "max_tokens": max_tokens,
    }
    completion = client.chat.completions.create(**request_kwargs)
    first_message = completion.choices[0].message
    return first_message.content or ""
def _parse_json_response(raw: str) -> dict:
"""Strip markdown fences and <think> blocks, then parse JSON."""
raw = re.sub(r"<think(?:ing)?>.*?</think(?:ing)?>", "", raw, flags=re.DOTALL | re.IGNORECASE)
raw = re.sub(r"```(?:json)?", "", raw).replace("```", "").strip()
return json.loads(raw)
# ── Clinical functions ────────────────────────────────────────────────────────
def parse_trial_protocol(protocol_text: str) -> dict:
    """Extract structured eligibility criteria from unstructured protocol text.

    Only the first 4,000 characters of ``protocol_text`` are sent to the model.

    Args:
        protocol_text: Free-text trial protocol.

    Returns:
        dict with keys ``inclusion_criteria``, ``exclusion_criteria``,
        ``age_range``, ``required_diagnoses``, ``required_biomarkers``,
        ``excluded_medications`` and ``performance_status``.
        - Keys the model omits are filled from the fallback template below.
        - If the LLM call or JSON parsing fails, the template itself is
          returned and the failure is logged.
    """
    fallback = {
        "inclusion_criteria": [], "exclusion_criteria": [],
        # NOTE(review): min age 18 is a pre-existing assumption, not derived
        # from the protocol — confirm it is an acceptable default.
        "age_range": {"min": 18, "max": None}, "required_diagnoses": [],
        "required_biomarkers": [], "excluded_medications": [],
        "performance_status": None,
    }
    prompt = f"""You are a clinical research expert. Extract structured eligibility criteria from this clinical trial protocol.
Return a JSON object with exactly these keys:
- inclusion_criteria: list of strings
- exclusion_criteria: list of strings
- age_range: {{"min": int_or_null, "max": int_or_null}}
- required_diagnoses: list of strings
- required_biomarkers: list of strings (e.g. "HER2+", "EGFR mutation")
- excluded_medications: list of strings
- performance_status: string or null (e.g. "ECOG 0-2")
Protocol text:
{(protocol_text or "")[:4000]}
Return ONLY valid JSON, no markdown, no explanation."""
    logger = logging.getLogger(__name__)
    try:
        parsed = _parse_json_response(chat([{"role": "user", "content": prompt}], temperature=0))
    except Exception:
        # Best-effort: keep returning the empty template, but leave a trace instead of failing silently.
        logger.exception("LLM protocol parsing failed; returning empty criteria template")
        return fallback
    if not isinstance(parsed, dict):
        logger.warning("LLM protocol parsing returned %s, not an object; using template", type(parsed).__name__)
        return fallback
    # The model is asked for exactly these keys but may omit some; fill gaps so callers see a stable shape.
    return {**fallback, **parsed}
def score_patient_against_criteria(patient_profile: dict, criteria: dict, trial_title: str) -> dict:
    """Semantically score a patient against trial criteria using the LLM.

    Args:
        patient_profile: Patient data. Fields read: ``age``, ``gender``,
            ``diagnosis_names``, ``medications``, ``biomarkers``,
            ``lab_values``, ``comorbidities``, ``prior_lines_of_therapy``.
            Missing or null list fields are treated as empty.
        criteria: Criteria dict (e.g. from ``parse_trial_protocol``);
            ``inclusion_criteria`` and ``exclusion_criteria`` are read.
        trial_title: Trial title shown to the model.

    Returns:
        The model's JSON assessment (``overall_score``, ``eligible``,
        ``inclusion_results``, ``exclusion_results``, ``summary``,
        ``risk_flags``). If the call or parsing fails, a fail-closed result is
        returned (score 0.0, not eligible, flagged for manual review) and the
        failure is logged. A failed assessment must never look like a match.
    """

    def bullet_lines(items) -> str:
        # One "- item" line per criterion; a missing/null list yields "".
        return "\n".join(f"- {item}" for item in (items or []))

    def comma_list(items) -> str:
        # Comma-separated values; tolerates null lists and non-string items.
        return ", ".join(str(item) for item in (items or []))

    inclusion_lines = bullet_lines(criteria.get("inclusion_criteria"))
    exclusion_lines = bullet_lines(criteria.get("exclusion_criteria"))
    prompt = f"""You are a clinical trial eligibility expert. Assess this patient's eligibility.
TRIAL: {trial_title}
INCLUSION CRITERIA:
{inclusion_lines}
EXCLUSION CRITERIA:
{exclusion_lines}
PATIENT PROFILE:
- Age: {patient_profile.get("age")}
- Gender: {patient_profile.get("gender")}
- Diagnoses: {comma_list(patient_profile.get("diagnosis_names"))}
- Medications: {comma_list(patient_profile.get("medications"))}
- Biomarkers: {patient_profile.get("biomarkers", {})}
- Lab Values: {patient_profile.get("lab_values", {})}
- Comorbidities: {comma_list(patient_profile.get("comorbidities"))}
- Prior therapy lines: {patient_profile.get("prior_lines_of_therapy", "unknown")}
Return a JSON object with:
- overall_score: float 0.0-1.0
- eligible: boolean
- inclusion_results: list of {{"criterion": str, "met": bool, "confidence": "high"|"medium"|"low", "note": str}}
- exclusion_results: list of {{"criterion": str, "triggered": bool, "confidence": "high"|"medium"|"low", "note": str}}
- summary: string (2-3 sentence clinical reasoning)
- risk_flags: list of strings
Return ONLY valid JSON."""
    try:
        result = _parse_json_response(
            chat([{"role": "user", "content": prompt}], temperature=0, max_tokens=1500)
        )
        if not isinstance(result, dict):
            raise ValueError(f"expected a JSON object, got {type(result).__name__}")
        return result
    except Exception:
        # Only the (public) trial title is logged — never patient data.
        logging.getLogger(__name__).exception(
            "LLM eligibility scoring failed for trial %r", trial_title
        )
        # Fail closed: previously this reported eligible=True / 0.7, silently
        # presenting an unassessed patient as a likely match.
        return {
            "overall_score": 0.0, "eligible": False,
            "inclusion_results": [], "exclusion_results": [],
            "summary": "Automated eligibility assessment failed; eligibility has not been determined.",
            "risk_flags": ["Automated assessment failed", "Manual review recommended"],
        }
def generate_outreach_message(patient_profile: dict, trial: dict, channel: str) -> str:
    """Draft a recruitment or referral message for a trial via the LLM.

    Args:
        patient_profile: Only ``age`` and ``diagnosis_names`` are read.
        trial: Fields read: ``title``, ``nct_id``, ``phase``, ``sponsor``,
            ``brief_summary`` (first 500 chars) and ``locations`` (first 3
            entries, each a dict with optional ``city``/``state``). Missing or
            null fields are tolerated.
        channel: ``"pcp_letter"``, ``"patient_email"`` or ``"social_post"``;
            any other value uses the ``patient_email`` instructions.

    Returns:
        The generated message text (``""`` if the model returns no content).
    """
    channel_instructions = {
        "pcp_letter": "Write a formal referral letter from a clinical research coordinator to the patient's PCP. Include trial name, NCT number, eligibility criteria met, and next steps.",
        "patient_email": "Write a warm, empathetic email to the patient in plain language (8th grade reading level). Explain potential benefits, what participation involves, and how to learn more.",
        "social_post": "Write a concise social media post (max 280 characters for Twitter, 500 for Facebook) for patient recruitment. No personal identifiers.",
    }
    instruction = channel_instructions.get(channel, channel_instructions["patient_email"])

    summary = (trial.get("brief_summary") or "")[:500]
    place_names = []
    for site in (trial.get("locations") or [])[:3]:
        # A site may lack a city or state (presumably e.g. non-US sites); the
        # old l['state'] lookup raised KeyError there. Show whatever is present.
        parts = [str(site[field]) for field in ("city", "state") if site.get(field)]
        if parts:
            place_names.append(", ".join(parts))
    # An empty list previously produced a blank "Diagnosis:" line; use the generic phrase instead.
    diagnoses = patient_profile.get("diagnosis_names") or ["the relevant condition"]

    # NOTE(review): the line below is labelled "Age range" but passes the exact
    # age; HIPAA Safe Harbor treats ages over 89 as identifiers — confirm
    # whether this should be bucketed.
    prompt = f"""{instruction}
Trial: {trial.get("title")} ({trial.get("nct_id")})
Phase: {trial.get("phase")} | Sponsor: {trial.get("sponsor")}
Summary: {summary}
Locations: {", ".join(place_names)}
Patient context (no identifying details):
- Age range: {patient_profile.get("age")} years
- Diagnosis: {", ".join(str(name) for name in diagnoses)}
Write the message now:"""
    return chat([{"role": "user", "content": prompt}], temperature=0.7, max_tokens=800)
def summarize_trial(trial: dict) -> str:
    """Summarize a trial as 3-4 bullet points for a clinical coordinator.

    Args:
        trial: Fields read:
            - ``title``, ``phase``, ``enrollment``;
            - ``brief_summary`` (first 1,000 chars);
            - ``eligibility_criteria`` (first 800 chars).
            Missing or null text fields are treated as empty instead of
            raising on the slice.

    Returns:
        The model's bullet-point text (``""`` if the model returns no content).
    """
    summary = (trial.get("brief_summary") or "")[:1000]
    eligibility = (trial.get("eligibility_criteria") or "")[:800]
    prompt = f"""Summarize this clinical trial in 3-4 bullet points for a clinical coordinator:
what's tested, who qualifies, what patients do, potential benefit.
Trial: {trial.get("title")}
Summary: {summary}
Eligibility: {eligibility}
Phase: {trial.get("phase")} | Enrollment: {trial.get("enrollment")}
Bullet points only:"""
    return chat([{"role": "user", "content": prompt}], temperature=0.3, max_tokens=500)
|