File size: 9,014 Bytes
59abb4f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
"""
LLM client — provider-configurable, OpenAI-compatible interface.

Set LLM_PROVIDER in .env to switch between:
  groq, openai, azure, aimlapi, bedrock, custom

In HIPAA/production contexts use azure or bedrock — both offer BAAs.
Never use the Anthropic SDK directly; all calls go through the
OpenAI-compatible interface regardless of underlying model.
"""
import json
import logging
import os
import re

from dotenv import load_dotenv
from openai import OpenAI

load_dotenv()

# ── Provider registry ─────────────────────────────────────────────────────────

# Fallback endpoint and model for each supported LLM_PROVIDER value, used when
# OPENAI_BASE_URL / OPENAI_MODEL are not set in the environment.
# azure, bedrock and custom have no fixed endpoint: their base URL is whatever
# OPENAI_BASE_URL held when this module was imported ("" when unset).
_PROVIDER_DEFAULTS: dict[str, dict] = {
    "openai": {
        "base_url": "https://api.openai.com/v1",
        "model": "gpt-4o",
    },
    "groq": {
        "base_url": "https://api.groq.com/openai/v1",
        "model": "llama3-70b-8192",
    },
    "aimlapi": {
        "base_url": "https://ai.aimlapi.com/v1",
        "model": "claude-opus-4-7",
    },
    "azure": {
        "base_url": os.getenv("OPENAI_BASE_URL", ""),
        "model": "gpt-4o",
    },
    "bedrock": {
        "base_url": os.getenv("OPENAI_BASE_URL", ""),
        "model": "anthropic.claude-3-5-sonnet",
    },
    "custom": {
        "base_url": os.getenv("OPENAI_BASE_URL", ""),
        "model": os.getenv("OPENAI_MODEL", "gpt-4o"),
    },
}

# Providers that get_provider_status() reports as HIPAA BAA eligible.
_HIPAA_ELIGIBLE = {"azure", "bedrock"}

def _build_client() -> tuple[OpenAI, str]:
    """Create an OpenAI-compatible client and resolve the model name.

    Reads LLM_PROVIDER (default "custom"), OPENAI_BASE_URL, OPENAI_MODEL and
    OPENAI_API_KEY from the environment. Environment values take precedence
    over the provider's entry in _PROVIDER_DEFAULTS; an unrecognised provider
    uses the "custom" entry.

    Returns:
        A ``(client, model)`` tuple.

    Raises:
        RuntimeError: if no base URL can be resolved.
    """
    provider_name = os.getenv("LLM_PROVIDER", "custom").lower()
    fallback = _PROVIDER_DEFAULTS.get(provider_name, _PROVIDER_DEFAULTS["custom"])

    endpoint = os.getenv("OPENAI_BASE_URL") or fallback["base_url"]
    if not endpoint:
        raise RuntimeError(
            f"LLM_PROVIDER='{provider_name}' requires OPENAI_BASE_URL to be set. "
            "Check your .env file."
        )

    model_name = os.getenv("OPENAI_MODEL") or fallback["model"]
    # A missing key is replaced by a dummy value rather than rejected here.
    key = os.getenv("OPENAI_API_KEY", "placeholder")

    return OpenAI(api_key=key, base_url=endpoint), model_name


# Module-level cache for the lazily built client; filled by get_client().
_client: OpenAI | None = None
_model: str = ""


def get_client() -> tuple[OpenAI, str]:
    """Return the cached ``(client, model)`` pair, building it on first use."""
    global _client, _model
    if _client is not None:
        return _client, _model
    _client, _model = _build_client()
    return _client, _model


def get_provider_status() -> dict:
    """Return current LLM provider config — exposed via /api/v1/config/llm.

    The model and base URL are resolved the same way _build_client() resolves
    them: environment variables first, then the provider's defaults, with an
    unrecognised provider falling back to the "custom" defaults. This keeps
    the reported values in line with what the client would actually use.

    Returns:
        A dict with keys ``provider``, ``model``, ``base_url``,
        ``api_key_set`` (bool; the key itself is never returned),
        ``hipaa_eligible`` (bool) and ``baa_note`` (human-readable string).
    """
    provider = os.getenv("LLM_PROVIDER", "custom").lower()
    # Same fallback as _build_client(); an empty dict here would report
    # "unknown" while the client silently used the "custom" defaults.
    defaults = _PROVIDER_DEFAULTS.get(provider, _PROVIDER_DEFAULTS["custom"])

    model = os.getenv("OPENAI_MODEL") or defaults.get("model", "unknown")
    base_url = os.getenv("OPENAI_BASE_URL") or defaults.get("base_url", "")
    hipaa_eligible = provider in _HIPAA_ELIGIBLE

    if hipaa_eligible:
        baa_note = "This provider offers a BAA — suitable for PHI in production."
    else:
        baa_note = "Not HIPAA BAA eligible. Use 'azure' or 'bedrock' for production PHI workloads."

    return {
        "provider": provider,
        "model": model,
        "base_url": base_url,
        "api_key_set": bool(os.getenv("OPENAI_API_KEY")),
        "hipaa_eligible": hipaa_eligible,
        "baa_note": baa_note,
    }


# ── Core chat wrapper ─────────────────────────────────────────────────────────

def chat(messages: list[dict], temperature: float = 0.3, max_tokens: int = 2048) -> str:
    """Send one chat-completion request and return the first choice's text.

    Args:
        messages: Chat messages passed through unchanged to the API.
        temperature: Sampling temperature forwarded to the API.
        max_tokens: Completion token limit forwarded to the API.

    Returns:
        The content of the first choice, or "" when that content is empty
        or None.
    """
    llm, model_name = get_client()
    request = {
        "model": model_name,
        "messages": messages,
        "temperature": temperature,
        "max_tokens": max_tokens,
    }
    completion = llm.chat.completions.create(**request)
    text = completion.choices[0].message.content
    return text if text else ""


def _parse_json_response(raw: str) -> dict:
    """Strip markdown fences and <think> blocks, then parse JSON.

    The cleaned text is parsed as-is first. If that fails because the model
    wrapped the JSON in prose ("Here is the result: {...}"), the span from
    the first "{" to the last "}" is parsed instead.

    Args:
        raw: Raw model output.

    Returns:
        The decoded JSON value. Callers expect an object (dict); a bare JSON
        array or scalar is returned unchanged if that is what the model sent.

    Raises:
        json.JSONDecodeError: if no parseable JSON is found.
    """
    cleaned = re.sub(
        r"<think(?:ing)?>.*?</think(?:ing)?>", "", raw, flags=re.DOTALL | re.IGNORECASE
    )
    # Removes both opening and closing fences; the language tag is matched
    # case-insensitively so "```JSON" does not leave a stray "JSON" behind.
    cleaned = re.sub(r"```(?:json)?", "", cleaned, flags=re.IGNORECASE).strip()

    try:
        return json.loads(cleaned)
    except json.JSONDecodeError:
        start = cleaned.find("{")
        end = cleaned.rfind("}")
        if start == -1 or end <= start:
            # Nothing that looks like an object: surface the original error.
            raise
        return json.loads(cleaned[start:end + 1])


# ── Clinical functions ────────────────────────────────────────────────────────

def parse_trial_protocol(protocol_text: str) -> dict:
    """Extract structured eligibility criteria from unstructured protocol text.

    Only the first 4000 characters of ``protocol_text`` are sent to the model.

    Args:
        protocol_text: Free-text clinical trial protocol.

    Returns:
        The JSON object produced by the model. The prompt asks for the keys
        ``inclusion_criteria``, ``exclusion_criteria``, ``age_range``,
        ``required_diagnoses``, ``required_biomarkers``,
        ``excluded_medications`` and ``performance_status``, but the reply is
        not validated beyond being a JSON object. On any failure (LLM call,
        JSON parsing, non-object reply) a default structure with empty lists
        and ``age_range`` ``{"min": 18, "max": None}`` is returned instead of
        raising.
    """
    prompt = f"""You are a clinical research expert. Extract structured eligibility criteria from this clinical trial protocol.

Return a JSON object with exactly these keys:
- inclusion_criteria: list of strings
- exclusion_criteria: list of strings
- age_range: {{"min": int_or_null, "max": int_or_null}}
- required_diagnoses: list of strings
- required_biomarkers: list of strings (e.g. "HER2+", "EGFR mutation")
- excluded_medications: list of strings
- performance_status: string or null (e.g. "ECOG 0-2")

Protocol text:
{protocol_text[:4000]}

Return ONLY valid JSON, no markdown, no explanation."""

    try:
        parsed = _parse_json_response(chat([{"role": "user", "content": prompt}], temperature=0))
        if not isinstance(parsed, dict):
            # A JSON array or scalar would break callers that use .get().
            raise ValueError(f"expected a JSON object, got {type(parsed).__name__}")
        return parsed
    except Exception:
        # Deliberate best-effort: callers always get a dict. The failure is
        # logged so an empty result can be told apart from a real extraction.
        logging.getLogger(__name__).warning(
            "parse_trial_protocol: LLM extraction failed; returning default criteria",
            exc_info=True,
        )
        return {
            "inclusion_criteria": [], "exclusion_criteria": [],
            "age_range": {"min": 18, "max": None}, "required_diagnoses": [],
            "required_biomarkers": [], "excluded_medications": [],
            "performance_status": None,
        }


def score_patient_against_criteria(patient_profile: dict, criteria: dict, trial_title: str) -> dict:
    """Semantically score a patient against trial criteria using LLM.

    Args:
        patient_profile: Patient data. Keys read: ``age``, ``gender``,
            ``diagnosis_names``, ``medications``, ``biomarkers``,
            ``lab_values``, ``comorbidities``, ``prior_lines_of_therapy``.
            List-valued keys may be missing or None.
        criteria: Trial criteria. Keys read: ``inclusion_criteria`` and
            ``exclusion_criteria`` (lists; may be missing or None).
        trial_title: Title shown to the model.

    Returns:
        The JSON object produced by the model. The prompt asks for
        ``overall_score``, ``eligible``, ``inclusion_results``,
        ``exclusion_results``, ``summary`` and ``risk_flags``. On any failure
        a placeholder result is returned instead of raising.
    """
    # `or []` covers keys that are present but None; dict.get's default only
    # applies when the key is missing, and joining None raises TypeError.
    inclusion = "\n".join(f"- {c}" for c in criteria.get("inclusion_criteria") or [])
    exclusion = "\n".join(f"- {c}" for c in criteria.get("exclusion_criteria") or [])
    diagnoses = ", ".join(patient_profile.get("diagnosis_names") or [])
    medications = ", ".join(patient_profile.get("medications") or [])
    comorbidities = ", ".join(patient_profile.get("comorbidities") or [])

    prompt = f"""You are a clinical trial eligibility expert. Assess this patient's eligibility.

TRIAL: {trial_title}

INCLUSION CRITERIA:
{inclusion}

EXCLUSION CRITERIA:
{exclusion}

PATIENT PROFILE:
- Age: {patient_profile.get("age")}
- Gender: {patient_profile.get("gender")}
- Diagnoses: {diagnoses}
- Medications: {medications}
- Biomarkers: {patient_profile.get("biomarkers", {})}
- Lab Values: {patient_profile.get("lab_values", {})}
- Comorbidities: {comorbidities}
- Prior therapy lines: {patient_profile.get("prior_lines_of_therapy", "unknown")}

Return a JSON object with:
- overall_score: float 0.0-1.0
- eligible: boolean
- inclusion_results: list of {{"criterion": str, "met": bool, "confidence": "high"|"medium"|"low", "note": str}}
- exclusion_results: list of {{"criterion": str, "triggered": bool, "confidence": "high"|"medium"|"low", "note": str}}
- summary: string (2-3 sentence clinical reasoning)
- risk_flags: list of strings

Return ONLY valid JSON."""

    try:
        parsed = _parse_json_response(
            chat([{"role": "user", "content": prompt}], temperature=0, max_tokens=1500)
        )
        if not isinstance(parsed, dict):
            # A JSON array or scalar would break callers that use .get().
            raise ValueError(f"expected a JSON object, got {type(parsed).__name__}")
        return parsed
    except Exception as exc:
        # Only the exception type is logged: the prompt holds patient data and
        # some client errors may echo request details in their message.
        logging.getLogger(__name__).warning(
            "score_patient_against_criteria: LLM scoring failed (%s); returning placeholder result",
            type(exc).__name__,
        )
        # NOTE(review): this placeholder reports eligible=True with score 0.7
        # although no assessment ran. Values are kept for backward
        # compatibility; callers should treat the "Manual review recommended"
        # risk flag as the signal that the score is not real.
        return {
            "overall_score": 0.7, "eligible": True,
            "inclusion_results": [], "exclusion_results": [],
            "summary": "Automated assessment pending. Patient profile partially matches trial criteria.",
            "risk_flags": ["Manual review recommended"],
        }


def generate_outreach_message(patient_profile: dict, trial: dict, channel: str) -> str:
    """Draft a recruitment message for one outreach channel using the LLM.

    Args:
        patient_profile: Patient data. Keys read: ``age`` and
            ``diagnosis_names``.
        trial: Trial data. Keys read: ``title``, ``nct_id``, ``phase``,
            ``sponsor``, ``brief_summary`` (truncated to 500 characters) and
            ``locations`` (dicts with ``city`` / ``state``; only the first
            three are used).
        channel: One of ``"pcp_letter"``, ``"patient_email"`` or
            ``"social_post"``. Any other value uses the ``"patient_email"``
            instructions.

    Returns:
        The generated message text.
    """
    channel_instructions = {
        "pcp_letter": "Write a formal referral letter from a clinical research coordinator to the patient's PCP. Include trial name, NCT number, eligibility criteria met, and next steps.",
        "patient_email": "Write a warm, empathetic email to the patient in plain language (8th grade reading level). Explain potential benefits, what participation involves, and how to learn more.",
        "social_post": "Write a concise social media post (max 280 characters for Twitter, 500 for Facebook) for patient recruitment. No personal identifiers.",
    }
    instruction = channel_instructions.get(channel, channel_instructions["patient_email"])

    # `or` fallbacks cover keys that are present but None (or empty), which
    # dict.get's default does not handle and which would raise TypeError.
    summary = (trial.get("brief_summary") or "")[:500]
    site_labels = []
    for site in (trial.get("locations") or [])[:3]:
        # Tolerate locations missing city or state instead of raising KeyError.
        label = ", ".join(part for part in (site.get("city"), site.get("state")) if part)
        if label:
            site_labels.append(label)
    locations = ", ".join(site_labels)
    diagnoses = ", ".join(patient_profile.get("diagnosis_names") or ["the relevant condition"])

    prompt = f"""{instruction}

Trial: {trial.get("title")} ({trial.get("nct_id")})
Phase: {trial.get("phase")} | Sponsor: {trial.get("sponsor")}
Summary: {summary}
Locations: {locations}

Patient context (no identifying details):
- Age range: {patient_profile.get("age")} years
- Diagnosis: {diagnoses}

Write the message now:"""
    return chat([{"role": "user", "content": prompt}], temperature=0.7, max_tokens=800)


def summarize_trial(trial: dict) -> str:
    """Ask the LLM for a 3-4 bullet summary of a trial for a coordinator.

    Args:
        trial: Trial data. Keys read: ``title``, ``brief_summary`` (truncated
            to 1000 characters), ``eligibility_criteria`` (truncated to 800
            characters), ``phase`` and ``enrollment``. Text fields may be
            missing or None.

    Returns:
        The generated bullet-point text.
    """
    # `or ""` covers keys that are present but None; slicing None would raise.
    summary = (trial.get("brief_summary") or "")[:1000]
    eligibility = (trial.get("eligibility_criteria") or "")[:800]

    prompt = f"""Summarize this clinical trial in 3-4 bullet points for a clinical coordinator:
what's tested, who qualifies, what patients do, potential benefit.

Trial: {trial.get("title")}
Summary: {summary}
Eligibility: {eligibility}
Phase: {trial.get("phase")} | Enrollment: {trial.get("enrollment")}

Bullet points only:"""
    return chat([{"role": "user", "content": prompt}], temperature=0.3, max_tokens=500)