# NOTE(review): this span was Hugging Face Space page residue captured by the
# scrape (Space status banner, file-size line, git commit hashes, and the
# line-number gutter) — it is not part of the module source and has been
# replaced with this note so the file parses.
import sys
import os
import spacy
from groq import Groq
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), 'pyspur/backend/')))
from embedding import discharge_collection, trials_collection, get_embedding
from serpapi import GoogleSearch
from pyspur.backend.pyspur.nodes.decorator import tool_function
# --- Import-time setup: credentials, LLM client, NLP pipeline ---
# Load API key from Hugging Face secret
groq_api_key = os.getenv("GROQ_API_KEY")
if not groq_api_key:
    # Fail fast at import time rather than on the first LLM call.
    raise ValueError("Missing GROQ_API_KEY in environment variables.")
serp_api_key = os.environ.get("SERP_API_KEY")  # used by symptom_search_tool
if not serp_api_key:
    raise ValueError("Missing SERP_API_KEY in environment variables.")
# Initialize LLM client and spaCy
client = Groq(api_key=groq_api_key)   # shared Groq chat-completions client
nlp = spacy.load("en_core_web_sm")    # small English pipeline; noun_chunks used below
# Lowercase cardiac-symptom keywords/phrases matched (as substrings) against
# chat memory in chat_memory_tool. Entries must stay lowercase because the
# lookup compares against lowercased text.
SYMPTOM_HINTS = [
    "chest pain", "shortness of breath", "fatigue", "dizziness",
    "nausea", "vomiting", "palpitations", "sweating", "jaw pain",
    "arm pain", "back pain", "tightness", "pressure in chest",
    "arrhythmia", "tachycardia", "bradycardia", "angina",
    "edema", "dyspnea", "syncope", "lightheadedness",
    "ejection fraction", "myocardial infarction", "heart failure",
    "cardiomyopathy", "cardiac arrest"
]
@tool_function(name="chat_memory_tool")
def chat_memory_tool(memory: str, model: str = "llama-3.3-70b-versatile") -> str:
    """Summarize symptoms previously mentioned in chat memory and suggest next steps.

    Args:
        memory: Prior conversation text to mine for known symptom keywords.
        model: Groq chat model identifier.

    Returns:
        The LLM-generated summary text.
    """
    # Match hints against the whole lowercased memory rather than per spaCy
    # noun chunk: multi-word hints such as "pressure in chest" or "shortness
    # of breath" can span noun-chunk boundaries and were silently missed by
    # the chunk-by-chunk scan.
    memory_lower = memory.lower()
    # sorted() makes the prompt deterministic; joining a set is not.
    found_symptoms = sorted(k for k in SYMPTOM_HINTS if k in memory_lower)
    symptom_context = (
        f"Previously mentioned symptoms include: {', '.join(found_symptoms)}."
        if found_symptoms else "No clear symptoms found in memory."
    )
    response = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": "You are a medical assistant summarizing prior symptoms from memory."},
            {"role": "assistant", "content": memory},
            {"role": "user", "content": (
                f"The patient previously reported: {memory}\n\n"
                f"Symptoms extracted: {symptom_context}\n"
                "Please provide a clear, concise, and helpful summary of these symptoms and suggest next steps."
            )}
        ]
    )
    return response.choices[0].message.content
@tool_function(name="treatment_tool")
def treatment_tool(query: str, model: str = "llama-3.3-70b-versatile", use_rag: bool = True) -> str:
    """Recommend treatment for a condition, optionally grounded in discharge notes.

    Args:
        query: Patient condition / free-text query.
        model: Groq chat model identifier.
        use_rag: When True, retrieve the top discharge notes from the vector
            store and include them in the prompt as context.

    Returns:
        The LLM-generated recommendation, or an "Error: ..." string on failure.
    """
    try:
        if use_rag:
            # Embed only on the RAG path — previously the embedding was
            # computed unconditionally and discarded when use_rag was False.
            query_embedding = get_embedding(query)
            results = discharge_collection.query(
                query_embeddings=[query_embedding],
                n_results=5,
                include=["documents"]
            )
            top_docs = results['documents'][0] if results and results['documents'] else []
            # Cap each note to keep the prompt within context limits.
            top_docs = [doc[:1500] for doc in top_docs]
            combined_context = "\n\n".join(top_docs)
            prompt = (
                "You are a helpful medical assistant. Based on the following discharge notes, "
                "recommend essential treatment.\n\n"
                f"### Notes:\n{combined_context}\n\n### Condition:\n{query}"
            )
        else:
            prompt = f"Patient condition: {query}. What treatment is recommended?"
        response = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": "You are a medically accurate and safety-focused clinical assistant."},
                {"role": "user", "content": prompt}
            ]
        )
        return response.choices[0].message.content
    except Exception as e:
        # Best-effort tool: surface the failure as text instead of raising.
        return f"Error: {str(e)}"
@tool_function(name="symptom_search_tool")
def symptom_search_tool(symptom_description: str, model: str = "llama-3.3-70b-versatile") -> str:
    """Explain possible causes of a symptom using trusted web sources via SerpAPI.

    Args:
        symptom_description: Free-text symptom description to search for.
        model: Groq chat model identifier.

    Returns:
        LLM answer followed by a bulleted "**Sources:**" list, or an error /
        no-source message string.
    """
    def perform_search(query):
        # Restrict results to vetted medical domains.
        params = {
            "engine": "google",
            "q": f"{query} possible causes site:mayoclinic.org OR site:webmd.com OR site:nih.gov",
            "api_key": serp_api_key
        }
        return GoogleSearch(params).get_dict().get("organic_results", [])

    try:
        results = perform_search(symptom_description)
        if not results:
            return "No reliable medical source found."
        sources = []
        snippets_with_citations = []
        for res in results[:3]:
            # Only use results that carry both a snippet and a link.
            if 'snippet' in res and 'link' in res:
                source_url = res['link']
                # Crude domain extraction for the inline citation label.
                domain = source_url.split("//")[-1].split("/")[0].replace("www.", "")
                snippets_with_citations.append(f"{res['snippet']} (Source: {domain})")
                sources.append(source_url)
        if not snippets_with_citations:
            # Previously this fell through and queried the LLM with an empty
            # context, producing an answer with an empty Sources section.
            return "No reliable medical source found."
        search_context = "\n\n".join(snippets_with_citations)
        response = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": "You are a medical assistant using trusted web sources to explain symptom causes."},
                {"role": "assistant", "content": search_context},
                {"role": "user", "content": f"What could be the cause of: {symptom_description}?"}
            ]
        )
        bulleted_sources = "\n".join(f"- {url}" for url in sources)
        return response.choices[0].message.content + "\n\n**Sources:**\n" + bulleted_sources
    except Exception as e:
        return f"Search error: {str(e)}"
@tool_function(name="trial_matcher_tool")
def trial_matcher_tool(discharge_note: str, model: str = "llama-3.3-70b-versatile", use_rag: bool = True) -> str:
    """Match a discharge note against stored clinical trials and summarize the hits.

    Args:
        discharge_note: Patient discharge note used as the retrieval query.
        model: Groq chat model identifier.
        use_rag: When True, each matched trial is summarized by the LLM;
            otherwise the raw (truncated) trial text is returned.

    Returns:
        Matched-trial summaries separated by horizontal rules, a no-match
        message, or an error string on failure.
    """
    try:
        embedding = get_embedding(discharge_note)
        results = trials_collection.query(
            query_embeddings=[embedding],
            n_results=3,
            include=["documents", "metadatas"]
        )
        docs = results.get('documents')
        if not docs or not docs[0]:
            return "No matching clinical trials were found for the provided note."
        summaries = []
        for idx, (doc, meta) in enumerate(zip(docs[0], results['metadatas'][0]), start=1):
            nct_id = meta.get("NCT ID") or "Unknown ID"
            # Cap trial text to keep prompts within context limits.
            truncated_doc = doc.strip()[:1500]
            if not use_rag:
                summaries.append(f"### Trial {idx}:\nNCT ID: {nct_id}\n\n{truncated_doc}")
                continue
            summary_prompt = (
                f"You are a clinical assistant reviewing a matched clinical trial.\n"
                f"Summarize the trial using **bullet points only** for the following fields:\n"
                f"- NCT ID\n- Study Title\n- Conditions\n- Inclusion Criteria\n- Exclusion Criteria\n\n"
                f"Use bullets under each field. Maintain a clean format. Respond only with the summary.\n\n"
                f"Trial Description:\nNCT ID: {nct_id}\n{truncated_doc}"
            )
            response = client.chat.completions.create(
                model=model,
                messages=[
                    {"role": "system", "content": "You are a medically precise clinical research assistant."},
                    {"role": "user", "content": summary_prompt}
                ]
            )
            summaries.append(f"### Trial {idx}:\n{response.choices[0].message.content}")
        return "\n\n---\n\n".join(summaries)
    except Exception as e:
        return f"Error during trial matching: {str(e)}"
# Tool routing via keyword logic
# Maps a trigger keyword (substring-matched in the lowercased query) to
# (tool name, whether that tool accepts a use_rag keyword argument).
TOOL_ROUTER = {
    "symptom": ("symptom_search_tool", False),
    "treatment": ("treatment_tool", True),
    "trial": ("trial_matcher_tool", True)
}
# Lookup from tool name to the callable implementing it.
TOOL_FUNCTIONS = {
    "chat_memory_tool": chat_memory_tool,
    "treatment_tool": treatment_tool,
    "symptom_search_tool": symptom_search_tool,
    "trial_matcher_tool": trial_matcher_tool
}
def run_tool(query: str, model: str, use_rag: bool) -> str:
    """Dispatch a query to the first tool whose router keyword appears in it.

    Falls back to chat_memory_tool when no TOOL_ROUTER keyword matches.
    (Also removes a stray trailing "|" scrape artifact on the final line.)

    Args:
        query: User query text.
        model: Groq chat model identifier passed through to the tool.
        use_rag: Forwarded only to tools that support RAG.

    Returns:
        The selected tool's response string.
    """
    query_lower = query.lower()  # lowercase once instead of per keyword
    # NOTE(review): plain substring matching — "trial" also matches e.g.
    # "industrial"; word-boundary matching may be safer. Confirm intent.
    for keyword, (tool_name, supports_rag) in TOOL_ROUTER.items():
        if keyword in query_lower:
            print(f"Tool selected by PySpur: {tool_name}")
            tool_func = TOOL_FUNCTIONS[tool_name]
            if supports_rag:
                return tool_func(query, model=model, use_rag=use_rag)
            return tool_func(query, model=model)
    print("Tool selected by PySpur: chat_memory_tool")
    return chat_memory_tool(query, model=model)