| | import os |
| | import re |
| | import json |
| | import logging |
| | import traceback |
| | from functools import lru_cache |
| | from typing import List, Dict, Any, Optional, TypedDict |
| |
|
| | import requests |
| | from langchain_groq import ChatGroq |
| | from langchain_community.tools.tavily_search import TavilySearchResults |
| | from langchain_core.messages import HumanMessage, SystemMessage, AIMessage, ToolMessage |
| | from langchain_core.pydantic_v1 import BaseModel, Field |
| | from langchain_core.tools import tool |
| | from langgraph.prebuilt import ToolExecutor |
| | from langgraph.graph import StateGraph, END |
| |
|
| | |
# Configure root logging at INFO level, then obtain this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
| |
|
| | |
# Credentials are read from the environment; import fails fast if any is absent.
UMLS_API_KEY = os.getenv("UMLS_API_KEY")
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
# Name the specific missing variables so the operator knows what to set.
_MISSING_API_KEYS = [
    name
    for name, value in (
        ("UMLS_API_KEY", UMLS_API_KEY),
        ("GROQ_API_KEY", GROQ_API_KEY),
        ("TAVILY_API_KEY", TAVILY_API_KEY),
    )
    if not value
]
if _MISSING_API_KEYS:
    logger.error("Missing required API keys: %s", ", ".join(_MISSING_API_KEYS))
    raise RuntimeError(f"Missing required API keys: {', '.join(_MISSING_API_KEYS)}")
| |
|
| | |
# Chat model passed to ChatGroq below, and its sampling temperature.
AGENT_MODEL_NAME = "llama3-70b-8192"
AGENT_TEMPERATURE = 0.1

# max_results passed to the Tavily search tool.
MAX_SEARCH_RESULTS = 3
| |
|
class ClinicalPrompts:
    """Namespace for prompt text used by the agent."""

    # Prepended by agent_node when the message list does not already start
    # with a SystemMessage.
    # NOTE(review): the body below looks like a placeholder
    # ("[SYSTEM PROMPT CONTENT HERE]") — confirm the real prompt is supplied.
    SYSTEM_PROMPT = """
You are SynapseAI, an expert AI clinical assistant engaged in an interactive consultation...
[SYSTEM PROMPT CONTENT HERE]
"""
| |
|
| | |
def wrap_message(msg: Any) -> AIMessage:
    """
    Ensure the given message is an AIMessage.

    - An AIMessage is returned unchanged.
    - A dict contributes its "content" value; when that key is missing or None,
      the whole dict is serialized to JSON instead.
    - Anything else is converted with str().
    """
    if isinstance(msg, AIMessage):
        return msg
    if isinstance(msg, dict):
        content = msg.get("content")
        if content is None:
            # Serialize lazily (only when needed); default=str keeps a
            # non-JSON-serializable value from crashing the normalization step.
            content = json.dumps(msg, default=str)
        return AIMessage(content=content)
    return AIMessage(content=str(msg))
| |
|
def normalize_messages(state: Dict[str, Any]) -> Dict[str, Any]:
    """
    Ensure every entry of state["messages"] is a LangChain message object.

    Existing message objects (Human/System/AI/Tool, ...) are kept as-is so that
    roles, the leading SystemMessage and ToolMessage.tool_call_id survive.
    Anything else (dicts, strings, ...) is wrapped via wrap_message. A missing or
    None "messages" entry becomes an empty list.

    The state dict is updated in place and also returned.
    """
    from langchain_core.messages import BaseMessage

    state["messages"] = [
        m if isinstance(m, BaseMessage) else wrap_message(m)
        for m in (state.get("messages") or [])
    ]
    return state
| |
|
| | |
# UMLS ticket-granting endpoint (not referenced by the code in this section).
UMLS_AUTH_ENDPOINT = "https://utslogin.nlm.nih.gov/cas/v1/api-key"
# RxNav REST root, used by get_rxcui for name -> RxCUI lookups.
RXNORM_API_BASE = "https://rxnav.nlm.nih.gov/REST"
# openFDA drug-label search endpoint, used by get_openfda_label.
OPENFDA_API_BASE = "https://api.fda.gov/drug/label.json"
| |
|
@lru_cache(maxsize=256)
def _lookup_rxcui(drug_name: str) -> Optional[str]:
    """Query RxNav for an already-stripped, non-empty drug name.

    Network, HTTP and parsing errors propagate. lru_cache does not store
    exceptions, so a transient failure is retried on the next call instead of
    being remembered as "not found". A definitive "no match" returns None and
    is cached.
    """
    logger.info(f"Looking up RxCUI for '{drug_name}'")
    r = requests.get(f"{RXNORM_API_BASE}/rxcui.json", params={"name": drug_name, "search": 1}, timeout=10)
    r.raise_for_status()
    ids = (r.json().get("idGroup") or {}).get("rxnormId")
    if ids:
        logger.info(f"Found RxCUI {ids[0]} for '{drug_name}'")
        return ids[0]
    # Fall back to /drugs and take the first concept group that has properties.
    r = requests.get(f"{RXNORM_API_BASE}/drugs.json", params={"name": drug_name}, timeout=10)
    r.raise_for_status()
    for grp in (r.json().get("drugGroup") or {}).get("conceptGroup") or []:
        props = grp.get("conceptProperties")
        if props:
            rxcui = props[0]["rxcui"]
            logger.info(f"Found RxCUI {rxcui} via /drugs for '{drug_name}'")
            return rxcui
    return None


def get_rxcui(drug_name: str) -> Optional[str]:
    """Look up the RxNorm CUI for a drug name.

    Returns None for an empty/blank name, when RxNav has no match, or when the
    lookup fails (the failure is logged and not cached).
    """
    drug_name = (drug_name or "").strip()
    if not drug_name:
        return None
    try:
        return _lookup_rxcui(drug_name)
    except Exception:
        logger.exception(f"Error fetching RxCUI for '{drug_name}'")
        return None


# Expose the lru_cache management helpers on the public function, as the
# decorated original did.
get_rxcui.cache_info = _lookup_rxcui.cache_info
get_rxcui.cache_clear = _lookup_rxcui.cache_clear
| |
|
@lru_cache(maxsize=128)
def _fetch_openfda_label(query: str) -> Optional[Dict[str, Any]]:
    """Run one openFDA label search and return the first result, or None.

    Transport and HTTP errors (other than 404) propagate, so lru_cache does not
    memoize a transient failure. The returned dict is the cached object itself;
    treat it as read-only.
    """
    logger.info(f"Looking up OpenFDA label with query: {query}")
    r = requests.get(OPENFDA_API_BASE, params={"search": query, "limit": 1}, timeout=15)
    if r.status_code == 404:
        # openFDA answers a search with no matching documents with HTTP 404.
        # That is a definitive "not found", so it is returned (and cached)
        # rather than logged as an error.
        return None
    r.raise_for_status()
    results = r.json().get("results", [])
    return results[0] if results else None


def get_openfda_label(rxcui: Optional[str] = None, drug_name: Optional[str] = None) -> Optional[Dict[str, Any]]:
    """Fetch the OpenFDA label for a drug by RxCUI and/or name.

    Both identifiers are OR-ed into one search. Returns None when neither is
    given, when nothing matches, or when the request fails (logged, not cached).
    """
    if not (rxcui or drug_name):
        return None
    terms = []
    if rxcui:
        # NOTE(review): confirm `spl_rxnorm_code` is a searchable openFDA
        # drug-label field.
        terms.append(f'spl_rxnorm_code:"{rxcui}" OR openfda.rxcui:"{rxcui}"')
    if drug_name:
        # Drop embedded double quotes so they cannot end the quoted phrase early.
        dn = drug_name.lower().replace('"', "").strip()
        if dn:
            terms.append(f'(openfda.brand_name:"{dn}" OR openfda.generic_name:"{dn}")')
    if not terms:
        return None
    query = " OR ".join(terms)
    try:
        return _fetch_openfda_label(query)
    except Exception:
        logger.exception("Error fetching OpenFDA label")
        return None


# Expose the lru_cache management helpers on the public function, as the
# decorated original did.
get_openfda_label.cache_info = _fetch_openfda_label.cache_info
get_openfda_label.cache_clear = _fetch_openfda_label.cache_clear
| |
|
def search_text_list(texts: List[str], terms: List[str]) -> List[str]:
    """Collect at most one highlighted excerpt per text that mentions a search term.

    Terms are matched case-insensitively and tried in the order given. The first
    term found in a text determines its excerpt, which spans from 50 characters
    before the match to 100 characters after it. Within the excerpt, every
    case-insensitive occurrence of that term is wrapped in ``**...**``, and the
    excerpt is framed with ``...``. Falsy terms are skipped; ``texts`` may be None.
    """
    needles = [term.lower() for term in terms if term]
    excerpts: List[str] = []
    for text in texts or []:
        haystack = text.lower()
        hit = next((needle for needle in needles if needle in haystack), None)
        if hit is None:
            continue
        pos = haystack.find(hit)
        window = text[max(0, pos - 50):min(len(text), pos + len(hit) + 100)]
        marked = re.sub(f"({re.escape(hit)})", r"**\1**", window, flags=re.IGNORECASE)
        excerpts.append(f"...{marked}...")
    return excerpts
| |
|
def parse_bp(bp: str) -> Optional[tuple[int, int]]:
    """Split a "SYS/DIA" blood-pressure reading into a (systolic, diastolic) tuple.

    After stripping, the string must begin with 1-3 digits, an optional-space
    slash, then 1-3 digits; any trailing text is ignored. Returns None for
    None, empty, or non-matching input.
    """
    match = re.match(r"(\d{1,3})\s*/\s*(\d{1,3})", (bp or "").strip())
    if match is None:
        return None
    systolic, diastolic = match.groups()
    return int(systolic), int(diastolic)
| |
|
def _vital_number(value: Any) -> Optional[float]:
    """Return a vital-sign reading as float, or None if missing/non-numeric."""
    if value is None or isinstance(value, bool):
        return None
    try:
        return float(value)
    except (TypeError, ValueError):
        return None


def check_red_flags(patient_data: Dict[str, Any]) -> List[str]:
    """Identify immediate red flags from patient_data.

    Reads ``patient_data["hpi"]["symptoms"]`` (strings compared as whole entries,
    case-insensitively) and ``patient_data["vitals"]`` keys ``temp_c``,
    ``hr_bpm``, ``rr_rpm``, ``spo2_percent`` and ``bp_mmhg`` (a "SYS/DIA"
    string). Vitals may be numbers or numeric strings; other values are
    ignored. Missing or None sections are treated as empty.

    Returns de-duplicated flag strings in the order they were raised.
    """
    patient_data = patient_data or {}
    hpi = patient_data.get("hpi") or {}
    vitals = patient_data.get("vitals") or {}

    symptoms = {s.strip().lower() for s in (hpi.get("symptoms") or []) if isinstance(s, str)}
    symptom_flags = {
        "chest pain": "Chest pain reported",
        "shortness of breath": "Shortness of breath reported",
        "severe headache": "Severe headache reported",
        "syncope": "Syncope reported",
        "hemoptysis": "Hemoptysis reported",
    }
    flags: List[str] = [f"Red Flag: {desc}." for term, desc in symptom_flags.items() if term in symptoms]

    # Raw readings are kept for display; their numeric form is used for comparisons.
    temp_raw = vitals.get("temp_c")
    hr_raw = vitals.get("hr_bpm")
    rr_raw = vitals.get("rr_rpm")
    spo2_raw = vitals.get("spo2_percent")
    temp, hr, rr, spo2 = map(_vital_number, (temp_raw, hr_raw, rr_raw, spo2_raw))

    if temp is not None and temp >= 38.5:
        flags.append(f"Red Flag: Fever ({temp_raw}°C).")
    if hr is not None:
        if hr >= 120:
            flags.append(f"Red Flag: Tachycardia ({hr_raw} bpm).")
        if hr <= 50:
            flags.append(f"Red Flag: Bradycardia ({hr_raw} bpm).")
    if rr is not None and rr >= 24:
        flags.append(f"Red Flag: Tachypnea ({rr_raw} rpm).")
    if spo2 is not None and spo2 <= 92:
        flags.append(f"Red Flag: Hypoxia ({spo2_raw}%).")

    bp_raw = vitals.get("bp_mmhg")
    bp = parse_bp(bp_raw) if isinstance(bp_raw, str) else None
    if bp:
        systolic, diastolic = bp
        if systolic >= 180 or diastolic >= 110:
            flags.append(f"Red Flag: Hypertensive urgency/emergency ({systolic}/{diastolic} mmHg).")
        if systolic <= 90 or diastolic <= 60:
            flags.append(f"Red Flag: Hypotension ({systolic}/{diastolic} mmHg).")
    return list(dict.fromkeys(flags))
| |
|
def format_patient_data_for_prompt(data: Dict[str, Any]) -> str:
    """Format a patient_data dict into markdown-like prompt lines.

    - Dict sections become a bold header plus one "- Key: value" bullet per
      truthy entry. A dict section with no truthy entries is omitted entirely.
    - Non-empty list sections become "**Title:** a, b".
    - Other truthy values become "**Title:** value".
    - Keys are humanized ("chief_complaint" -> "Chief Complaint").
    - List values nested inside a dict section are comma-joined as well.

    Returns "No patient data provided." for empty/None input.
    """
    if not data:
        return "No patient data provided."

    def humanize(key: Any) -> str:
        return str(key).replace("_", " ").title()

    def render(value: Any) -> str:
        return ", ".join(map(str, value)) if isinstance(value, list) else str(value)

    lines: List[str] = []
    for section, value in data.items():
        title = humanize(section)
        if isinstance(value, dict):
            filled = [(k, v) for k, v in value.items() if v]
            # Previously an all-empty dict fell through to the generic branch
            # and was printed as a raw dict repr.
            if filled:
                lines.append(f"**{title}:**")
                lines.extend(f"- {humanize(k)}: {render(v)}" for k, v in filled)
        elif isinstance(value, list):
            if value:
                lines.append(f"**{title}:** {render(value)}")
        elif value:
            lines.append(f"**{title}:** {value}")
    return "\n".join(lines)
| |
|
| | |
class LabOrderInput(BaseModel):
    # Argument schema for the order_lab_test tool.
    test_name: str = Field(default=...)  # required
    reason: str = Field(default=...)  # required
    priority: str = Field(default="Routine")
| |
|
class PrescriptionInput(BaseModel):
    # Argument schema for the prescribe_medication tool; all fields except
    # `duration` are required.
    medication_name: str = Field(default=...)
    dosage: str = Field(default=...)
    route: str = Field(default=...)
    frequency: str = Field(default=...)
    duration: str = Field(default="As directed")
    reason: str = Field(default=...)
| |
|
class InteractionCheckInput(BaseModel):
    # Argument schema for the check_drug_interactions tool.
    potential_prescription: str
    # Optional lists; tool_node may fill them from patient_data.
    current_medications: Optional[List[str]] = Field(default=None)
    allergies: Optional[List[str]] = Field(default=None)
| |
|
class FlagRiskInput(BaseModel):
    # Argument schema for the flag_risk tool.
    risk_description: str = Field(default=...)  # required
    urgency: str = Field(default="High")
| |
|
| | |
@tool("order_lab_test", args_schema=LabOrderInput)
def order_lab_test(test_name: str, reason: str, priority: str = "Routine") -> str:
    """
    Place an order for a laboratory test.
    """
    # The docstring above is the tool description shown to the model.
    # This function only logs and returns a JSON acknowledgement; it calls no
    # external ordering system.
    logger.info(f"Ordering lab test: {test_name}, reason: {reason}, priority: {priority}")
    receipt = {
        "status": "success",
        "message": f"Lab Ordered: {test_name} ({priority})",
        "details": f"Reason: {reason}",
    }
    return json.dumps(receipt)
| |
|
@tool("prescribe_medication", args_schema=PrescriptionInput)
def prescribe_medication(
    medication_name: str,
    dosage: str,
    route: str,
    frequency: str,
    duration: str = "As directed",
    reason: str = ""
) -> str:
    """
    Prepare a medication prescription.
    """
    # The docstring above is the tool description shown to the model.
    # `duration` mirrors the schema default: LangChain forwards only the keys the
    # model actually supplied, so without a function default an omitted
    # `duration` raised TypeError. `reason` needs a default only because it
    # follows a defaulted parameter; the schema still requires it.
    # This function only logs and returns a JSON acknowledgement.
    logger.info(f"Preparing prescription: {medication_name} {dosage}, route: {route}, freq: {frequency}")
    return json.dumps({
        "status": "success",
        "message": f"Prescription Prepared: {medication_name} {dosage} {route} {frequency}",
        "details": f"Duration: {duration}. Reason: {reason}"
    })
| |
|
@tool("check_drug_interactions", args_schema=InteractionCheckInput)
def check_drug_interactions(
    potential_prescription: str,
    current_medications: Optional[List[str]] = None,
    allergies: Optional[List[str]] = None
) -> str:
    """
    Check for drug-drug interactions and allergy risks.
    """
    # The docstring above is the tool description shown to the model (it
    # previously contained a mis-encoded dash).
    # Returns JSON {"status": "warning"|"clear", "message": str, "warnings": [str]}.
    logger.info(f"Checking interactions for: {potential_prescription}")
    warnings: List[str] = []
    current = [m.lower().strip() for m in (current_medications or []) if m]
    allergens = [a.lower().strip() for a in (allergies or []) if a]

    # 1. The drug itself is listed as an allergy.
    if potential_prescription.lower().strip() in allergens:
        warnings.append(f"CRITICAL ALLERGY: Patient allergic to '{potential_prescription}'.")

    rxcui = get_rxcui(potential_prescription)
    label = get_openfda_label(rxcui=rxcui, drug_name=potential_prescription)
    if not (rxcui or label):
        warnings.append(f"INFO: Could not identify '{potential_prescription}'. Checks may be incomplete.")

    # 2. Allergy terms mentioned in the drug label's safety sections.
    if label:
        for section in ("contraindications", "warnings_and_cautions", "warnings"):
            items = label.get(section)
            if isinstance(items, list):
                snippets = search_text_list(items, allergens)
                if snippets:
                    warnings.append(f"ALLERGY RISK ({section}): {'; '.join(snippets)}")

    # 3. For each current medication, search both labels' drug_interactions
    #    section for the other drug's name.
    for med in current:
        med_label = get_openfda_label(rxcui=get_rxcui(med), drug_name=med)
        for src_label, src_name, other in (
            (label, potential_prescription, med),
            (med_label, med, potential_prescription),
        ):
            items = src_label.get("drug_interactions") if src_label else None
            if isinstance(items, list):
                snippets = search_text_list(items, [other])
                if snippets:
                    warnings.append(f"Interaction ({src_name} label): {'; '.join(snippets)}")

    status = "warning" if warnings else "clear"
    message = (
        f"{len(warnings)} issue(s) found for '{potential_prescription}'."
        if warnings else
        f"No major interactions or allergy issues identified for '{potential_prescription}'."
    )
    return json.dumps({"status": status, "message": message, "warnings": warnings})
| |
|
@tool("flag_risk", args_schema=FlagRiskInput)
def flag_risk(risk_description: str, urgency: str = "High") -> str:
    """
    Flag a clinical risk with given urgency.
    """
    # The docstring above is the tool description shown to the model.
    # Acknowledgement only: the risk is logged and echoed back as JSON.
    logger.info(f"Flagging risk: {risk_description} (urgency={urgency})")
    acknowledgement = {
        "status": "flagged",
        "message": f"Risk '{risk_description}' flagged with {urgency} urgency.",
    }
    return json.dumps(acknowledgement)
| |
|
| | |
# Web search tool, capped at MAX_SEARCH_RESULTS results per query.
search_tool = TavilySearchResults(name="tavily_search_results", max_results=MAX_SEARCH_RESULTS)

# Every tool exposed to the model.
all_tools = [
    order_lab_test,
    prescribe_medication,
    check_drug_interactions,
    flag_risk,
    search_tool,
]
| |
|
| | |
# Base chat model (reflection_node calls it directly).
llm = ChatGroq(model=AGENT_MODEL_NAME, temperature=AGENT_TEMPERATURE)
# Tool-aware variant invoked by agent_node.
model_with_tools = llm.bind_tools(all_tools)
# Executes the tools requested in an AIMessage's tool_calls (used by tool_node).
tool_executor = ToolExecutor(all_tools)
| |
|
| | |
class AgentState(TypedDict):
    """State dictionary shared by the workflow nodes."""

    # Conversation history. Declared without a reducer.
    # NOTE(review): LangGraph likely replaces this list with each node's output
    # rather than appending — confirm.
    messages: List[Any]
    # Patient record; tool_node reads medications/allergies from it.
    patient_data: Optional[Dict[str, Any]]
    summary: Optional[str]
    # Warnings gathered from check_drug_interactions in the last tool step.
    interaction_warnings: Optional[List[str]]
    # Stop flag checked at the top of each node.
    done: Optional[bool]
    # Incremented by should_continue, which ends the turn once it reaches 4.
    iterations: Optional[int]
| |
|
| | |
def propagate_state(new: Dict[str, Any], old: Dict[str, Any]) -> Dict[str, Any]:
    """Copy bookkeeping keys from *old* into *new* wherever *new* lacks them.

    Only iterations, done, patient_data, summary and interaction_warnings are
    carried; "messages" is never copied. *new* is updated in place and returned.
    """
    carried = ("iterations", "done", "patient_data", "summary", "interaction_warnings")
    new.update({key: old[key] for key in carried if key in old and key not in new})
    return new
| |
|
| | |
def agent_node(state: AgentState) -> Dict[str, Any]:
    """Run one LLM step (with tools bound) over the conversation history.

    The system prompt is prepended for the call only when the history does not
    already start with a SystemMessage; it is not stored in the history.

    Returns a state update whose "messages" is the full history plus the model
    reply (or an "Error: ..." AIMessage if the call fails). AgentState.messages
    has no reducer, so returning only the reply would overwrite the stored
    history. Returns the normalized state unchanged when state["done"] is set.
    """
    state = normalize_messages(state)
    if state.get("done", False):
        return state
    history = list(state.get("messages", []))
    prompt_msgs = history
    if not history or not isinstance(history[0], SystemMessage):
        prompt_msgs = [SystemMessage(content=ClinicalPrompts.SYSTEM_PROMPT)] + history
    logger.info(f"Invoking LLM with {len(prompt_msgs)} messages")
    try:
        reply = wrap_message(model_with_tools.invoke(prompt_msgs))
    except Exception as e:
        logger.exception("Error in agent_node")
        reply = AIMessage(content=f"Error: {e}")
    return propagate_state({"messages": history + [reply]}, state)
| |
|
def _blocked_prescription_ids(calls: List[Dict[str, Any]]) -> set:
    """Return ids of prescribe_medication calls with no same-batch interaction check for that drug."""
    checked = {
        ((c.get("args") or {}).get("potential_prescription") or "").lower()
        for c in calls
        if c["name"] == "check_drug_interactions"
    }
    return {
        c["id"]
        for c in calls
        if c["name"] == "prescribe_medication"
        and ((c.get("args") or {}).get("medication_name") or "").lower() not in checked
    }


def _patient_meds_and_allergies(patient_data: Optional[Dict[str, Any]]) -> tuple:
    """Extract (current medications, allergies) lists from patient_data, tolerating None/missing sections."""
    data = patient_data or {}
    meds = data.get("medications")
    if isinstance(meds, dict):
        meds = meds.get("current")
    allergies = data.get("allergies")
    return (
        meds if isinstance(meds, list) else [],
        allergies if isinstance(allergies, list) else [],
    )


def _warnings_from_interaction_result(content: str) -> List[str]:
    """Return the warnings list of a check_drug_interactions JSON result, or [] if absent/unparseable."""
    try:
        data = json.loads(content)
    except ValueError:
        return []
    if isinstance(data, dict) and data.get("status") == "warning":
        return list(data.get("warnings") or [])
    return []


def tool_node(state: AgentState) -> Dict[str, Any]:
    """Execute the tool calls requested by the latest AI message.

    Safety and bookkeeping:
    - A prescribe_medication call is blocked (and answered with a "blocked"
      ToolMessage) unless the same batch contains a check_drug_interactions
      call for that drug.
    - Missing or None current_medications/allergies on an interaction check
      are filled from patient_data, without mutating the stored tool calls.

    Every tool call gets exactly one ToolMessage, in call order. Returns a state
    update whose "messages" is the full history plus those ToolMessages, and
    whose "interaction_warnings" collects warnings from interaction checks
    (None if there are none).
    """
    # ToolExecutor expects ToolInvocation objects, not the raw tool-call dicts
    # carried on the AIMessage.
    from langgraph.prebuilt import ToolInvocation

    state = normalize_messages(state)
    if state.get("done", False):
        return state
    history = list(state.get("messages", []))
    if not history:
        logger.warning("tool_node invoked with no messages")
        return propagate_state({"messages": history}, state)
    calls = getattr(wrap_message(history[-1]), "tool_calls", None)
    if not calls:
        logger.warning("tool_node invoked without pending tool_calls")
        return propagate_state({"messages": history}, state)

    blocked_ids = _blocked_prescription_ids(calls)
    runnable = [c for c in calls if c["id"] not in blocked_ids]
    current_meds, allergies = _patient_meds_and_allergies(state.get("patient_data"))

    invocations = []
    for call in runnable:
        # Copy so filled-in defaults are not written back into the stored AIMessage.
        args = dict(call.get("args") or {})
        if call["name"] == "check_drug_interactions":
            if args.get("current_medications") is None:
                args["current_medications"] = current_meds
            if args.get("allergies") is None:
                args["allergies"] = allergies
        invocations.append(ToolInvocation(tool=call["name"], tool_input=args))

    try:
        responses = tool_executor.batch(invocations, return_exceptions=True) if invocations else []
    except Exception as e:
        logger.exception("Critical error in tool_node")
        responses = [e] * len(invocations)
    pending = iter(responses)

    tool_messages: List[ToolMessage] = []
    warnings: List[str] = []
    for call in calls:
        if call["id"] in blocked_ids:
            med = ((call.get("args") or {}).get("medication_name") or "").lower()
            logger.warning(f"Blocking prescribe_medication for '{med}' without interaction check")
            content = json.dumps({
                "status": "blocked",
                "message": f"Prescription for '{med}' was not prepared: call check_drug_interactions for it first.",
            })
        else:
            resp = next(pending)
            if isinstance(resp, Exception):
                logger.error(f"Error executing tool {call['name']}", exc_info=resp)
                content = json.dumps({"status": "error", "message": str(resp)})
            else:
                content = resp if isinstance(resp, str) else json.dumps(resp, default=str)
                if call["name"] == "check_drug_interactions":
                    warnings.extend(_warnings_from_interaction_result(content))
        tool_messages.append(ToolMessage(content=content, tool_call_id=call["id"], name=call["name"]))

    new_state = {"messages": history + tool_messages, "interaction_warnings": warnings or None}
    return propagate_state(new_state, state)
| |
|
def reflection_node(state: AgentState) -> Dict[str, Any]:
    """Ask the LLM for a focused safety review when the tool step produced warnings.

    The most recent AIMessage carrying tool_calls is treated as the plan under
    review. Both its text and its tool calls (name plus JSON arguments) go into
    the prompt, together with state["interaction_warnings"].

    Returns a state update whose "messages" is the full history plus the review
    (or an error AIMessage). With no warnings, the history is returned unchanged.
    """
    state = normalize_messages(state)
    if state.get("done", False):
        return state
    history = list(state.get("messages", []))
    warns = state.get("interaction_warnings")
    if not warns:
        logger.warning("reflection_node called without warnings")
        return propagate_state({"messages": history}, state)

    triggering = None
    for msg in reversed(history):
        wrapped = wrap_message(msg)
        if getattr(wrapped, "tool_calls", None):
            triggering = wrapped
            break
    if triggering is None:
        error_msg = AIMessage(content="Internal Error: reflection context missing.")
        return propagate_state({"messages": history + [error_msg]}, state)

    # A tool-calling AIMessage may have empty `content`; the concrete plan is in
    # its tool_calls, so include both.
    plan_parts = [str(triggering.content)] if triggering.content else []
    plan_parts.extend(
        f"- {call.get('name')}: {json.dumps(call.get('args', {}), default=str)}"
        for call in triggering.tool_calls
    )
    plan = "\n".join(plan_parts)
    prompt = (
        "You are SynapseAI, performing a focused safety review of the following plan:\n\n"
        f"{plan}\n\n"
        "Highlight any issues based on these warnings:\n" +
        "\n".join(f"- {w}" for w in warns)
    )
    try:
        reply = wrap_message(llm.invoke([SystemMessage(content="Safety reflection"), HumanMessage(content=prompt)]))
    except Exception as e:
        logger.exception("Error during reflection")
        reply = AIMessage(content=f"Error during reflection: {e}")
    return propagate_state({"messages": history + [reply]}, state)
| |
|
| | |
def should_continue(state: AgentState) -> str:
    """Route after the agent node.

    Returns "continue_tools" when the latest message carries tool calls and the
    iteration cap has not been reached; otherwise "end_conversation_turn".
    Side effects on the received dict:
    - increments "iterations" (a missing or None value counts as 0);
    - sets "done" when ending;
    - appends a closing AIMessage when the cap of 4 is hit.

    NOTE(review): mutations made inside a routing function may not be persisted
    by LangGraph — confirm before relying on them.
    """
    state = normalize_messages(state)
    # `setdefault` + `+=` crashed when iterations was present but None.
    state["iterations"] = (state.get("iterations") or 0) + 1
    logger.info(f"Iteration count: {state['iterations']}")
    messages = state["messages"]
    if state["iterations"] >= 4:
        messages.append(AIMessage(content="Final output: consultation complete."))
        state["done"] = True
        return "end_conversation_turn"
    if messages and getattr(wrap_message(messages[-1]), "tool_calls", None):
        return "continue_tools"
    # Empty history, plain replies and "consultation complete" replies all end the turn.
    state["done"] = True
    return "end_conversation_turn"
| |
|
def after_tools_router(state: AgentState) -> str:
    """Route after tool execution: to reflection if interaction warnings exist, else end the turn."""
    has_warnings = bool(state.get("interaction_warnings"))
    return "reflection" if has_warnings else "end_conversation_turn"
| |
|
| | |
class ClinicalAgent:
    """Compiled LangGraph workflow: start -> (tools -> optional reflection) -> END."""

    def __init__(self):
        logger.info("Building ClinicalAgent workflow")
        wf = StateGraph(AgentState)
        wf.add_node("start", agent_node)
        wf.add_node("tools", tool_node)
        wf.add_node("reflection", reflection_node)
        wf.set_entry_point("start")
        wf.add_conditional_edges("start", should_continue, {
            "continue_tools": "tools",
            "end_conversation_turn": END
        })
        wf.add_conditional_edges("tools", after_tools_router, {
            "reflection": "reflection",
            "end_conversation_turn": END
        })
        # Explicit exit for the reflection node; previously it had no outgoing edge.
        wf.add_edge("reflection", END)
        self.graph_app = wf.compile()
        logger.info("ClinicalAgent ready")

    def invoke_turn(self, state: Dict[str, Any]) -> Dict[str, Any]:
        """Run one conversation turn through the graph.

        On success, returns the graph's final state with "summary" and
        "interaction_warnings" keys guaranteed. On failure, returns the input
        history plus an "Error: ..." AIMessage, together with the input
        patient_data and summary.
        """
        try:
            result = self.graph_app.invoke(state, {"recursion_limit": 100})
        except Exception as e:
            logger.exception("Error during graph invocation")
            return {
                # `or []` guards a present-but-None "messages" entry.
                "messages": list(state.get("messages") or []) + [AIMessage(content=f"Error: {e}")],
                "patient_data": state.get("patient_data"),
                "summary": state.get("summary"),
                "interaction_warnings": None
            }
        result.setdefault("summary", state.get("summary"))
        result.setdefault("interaction_warnings", None)
        return result
| |
|