File size: 8,260 Bytes
9007595
 
25081d1
 
48b4b07
9007595
96bf754
25081d1
802d399
48b4b07
34c6b6d
32c02d7
c006289
25081d1
c006289
 
 
45a92c5
c006289
 
45a92c5
34c6b6d
32c02d7
25081d1
34c6b6d
 
 
 
 
 
 
 
 
 
 
 
af26cad
 
34c6b6d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
af26cad
 
34c6b6d
 
2cef258
aeebc49
2cef258
 
 
 
 
 
 
 
682c1f1
 
2cef258
 
 
 
 
34c6b6d
 
 
 
 
 
 
 
2cef258
34c6b6d
 
 
af26cad
 
34c6b6d
f76e4e9
 
 
45a92c5
f76e4e9
802d399
34c6b6d
 
 
 
c4234e2
 
 
 
 
 
 
 
 
 
 
 
34c6b6d
 
 
c4234e2
34c6b6d
 
 
 
01e42bc
 
 
c4234e2
34c6b6d
 
 
af26cad
 
34c6b6d
 
aeebc49
34c6b6d
 
 
 
808a650
34c6b6d
 
02bd2ec
 
aeebc49
682c1f1
02bd2ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34c6b6d
 
8b6f1af
 
af26cad
32c02d7
af26cad
8189892
 
 
 
 
 
 
 
 
 
af26cad
 
 
8189892
af26cad
8189892
 
48b4b07
 
 
 
8189892
 
af26cad
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
import sys
import os
import spacy
from groq import Groq
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), 'pyspur/backend/')))

from embedding import discharge_collection, trials_collection, get_embedding

from serpapi import GoogleSearch
from pyspur.backend.pyspur.nodes.decorator import tool_function

# Credentials come from environment variables (the deployment presumably
# injects them from Hugging Face secrets). Fail fast at import time when a
# required key is unset or empty.


def _require_env(var_name: str) -> str:
    """Return the value of environment variable *var_name*.

    Raises:
        ValueError: if the variable is unset or set to an empty string.
    """
    value = os.environ.get(var_name)
    if not value:
        raise ValueError(f"Missing {var_name} in environment variables.")
    return value


groq_api_key = _require_env("GROQ_API_KEY")
serp_api_key = _require_env("SERP_API_KEY")

# Shared, import-time singletons reused by the tools below: a Groq chat
# client and the small English spaCy pipeline.
client = Groq(api_key=groq_api_key)
nlp = spacy.load("en_core_web_sm")

# Lower-case, cardiac-focused phrases that chat_memory_tool searches for in
# a patient's free-text history.
SYMPTOM_HINTS = [
    "chest pain", "shortness of breath", "fatigue", "dizziness",
    "nausea", "vomiting", "palpitations", "sweating", "jaw pain",
    "arm pain", "back pain", "tightness", "pressure in chest",
    "arrhythmia", "tachycardia", "bradycardia", "angina",
    "edema", "dyspnea", "syncope", "lightheadedness",
    "ejection fraction", "myocardial infarction", "heart failure",
    "cardiomyopathy", "cardiac arrest",
]

@tool_function(name="chat_memory_tool")
def chat_memory_tool(memory: str, model: str = "llama-3.3-70b-versatile") -> str:
    """Summarize symptoms the patient reported earlier and suggest next steps.

    Every phrase in ``SYMPTOM_HINTS`` that occurs in *memory*
    (case-insensitive substring match) is listed in the prompt, and the Groq
    chat model is asked for a concise summary plus next steps.

    Args:
        memory: Free-text record of what the patient previously said.
        model: Groq model identifier for the chat completion.

    Returns:
        The content of the model's first completion choice.
    """
    # Search the whole text rather than individual spaCy noun chunks: a noun
    # chunk ends at its head noun, so multi-word hints containing a
    # preposition ("shortness of breath", "pressure in chest") effectively
    # never matched. Every noun chunk is a substring of the text, so this
    # still finds everything a chunk-level check would. Iterating the hint
    # list (instead of building a set) keeps the reported order deterministic.
    lowered = memory.lower()
    found_symptoms = [hint for hint in SYMPTOM_HINTS if hint in lowered]
    symptom_context = (
        f"Previously mentioned symptoms include: {', '.join(found_symptoms)}."
        if found_symptoms else "No clear symptoms found in memory."
    )
    response = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": "You are a medical assistant summarizing prior symptoms from memory."},
            # The patient's text is sent once, as user input. Replaying it as
            # an "assistant" turn would duplicate it and attribute the
            # patient's words to the model.
            {"role": "user", "content": (
                f"The patient previously reported: {memory}\n\n"
                f"Symptoms extracted: {symptom_context}\n"
                "Please provide a clear, concise, and helpful summary of these symptoms and suggest next steps."
            )}
        ]
    )
    return response.choices[0].message.content

@tool_function(name="treatment_tool")
def treatment_tool(query: str, model: str = "llama-3.3-70b-versatile", use_rag: bool = True) -> str:
    """Ask the chat model for a treatment recommendation for *query*.

    Args:
        query: Description of the patient's condition.
        model: Groq model identifier for the chat completion.
        use_rag: If True, the five nearest documents in
            ``discharge_collection`` (each cut to 1500 characters) are added
            to the prompt as context; if False the model is asked directly.

    Returns:
        The model's answer, or ``"Error: <message>"`` if any step raises.
    """
    try:
        if use_rag:
            # Embed only when retrieval is requested; without RAG the
            # embedding would be unused work and an extra way to fail.
            query_embedding = get_embedding(query)
            results = discharge_collection.query(
                query_embeddings=[query_embedding],
                n_results=5,
                include=["documents"]
            )
            documents = results.get("documents") if results else None
            top_docs = documents[0] if documents else []
            # Skip empty/missing (None) documents, which would otherwise
            # raise when sliced.
            combined_context = "\n\n".join(doc[:1500] for doc in top_docs if doc)
            prompt = (
                "You are a helpful medical assistant. Based on the following discharge notes, "
                "recommend essential treatment.\n\n"
                f"### Notes:\n{combined_context}\n\n### Condition:\n{query}"
            )
        else:
            prompt = f"Patient condition: {query}. What treatment is recommended?"

        response = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": "You are a medically accurate and safety-focused clinical assistant."},
                {"role": "user", "content": prompt}
            ]
        )
        return response.choices[0].message.content

    except Exception as e:
        # Tools report failures as text rather than raising to the caller.
        return f"Error: {e}"

@tool_function(name="symptom_search_tool")
def symptom_search_tool(symptom_description: str, model: str = "llama-3.3-70b-versatile") -> str:
    """Explain possible causes of a symptom using web search results as context.

    Runs a Google search through SerpAPI restricted to mayoclinic.org,
    webmd.com and nih.gov. The snippets of those of the top three organic
    results that carry both a snippet and a link are given to the chat model,
    and the links are appended to its answer.

    Args:
        symptom_description: Free-text description of the symptom(s).
        model: Groq model identifier for the chat completion.

    Returns:
        The model's answer followed by a bulleted ``**Sources:**`` list;
        ``"No reliable medical source found."`` if no usable result came
        back; or ``"Search error: <message>"`` if any step raises.
    """
    def perform_search(query):
        params = {
            "engine": "google",
            "q": f"{query} possible causes site:mayoclinic.org OR site:webmd.com OR site:nih.gov",
            "api_key": serp_api_key
        }
        return GoogleSearch(params).get_dict().get("organic_results", [])

    def source_domain(url):
        # Host part of the URL with a leading "www." removed. Only the prefix
        # is stripped: str.replace would also mangle hosts such as "lowww.org".
        host = url.split("//")[-1].split("/")[0]
        return host[4:] if host.startswith("www.") else host

    try:
        results = perform_search(symptom_description)

        sources = []
        snippets_with_citations = []
        for res in results[:3]:
            if 'snippet' in res and 'link' in res:
                source_url = res['link']
                snippets_with_citations.append(f"{res['snippet']} (Source: {source_domain(source_url)})")
                sources.append(source_url)

        # Check after filtering: if none of the top hits had a snippet, the
        # model would otherwise answer ungrounded with an empty source list.
        if not snippets_with_citations:
            return "No reliable medical source found."

        search_context = "\n\n".join(snippets_with_citations)
        response = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": "You are a medical assistant using trusted web sources to explain symptom causes."},
                # Provide the snippets as user-supplied context; sending them as
                # an "assistant" turn would present them as the model's own words.
                {"role": "user", "content": (
                    f"Search results:\n{search_context}\n\n"
                    f"What could be the cause of: {symptom_description}?"
                )}
            ]
        )

        bulleted_sources = "\n".join(f"- {url}" for url in sources)
        return response.choices[0].message.content + "\n\n**Sources:**\n" + bulleted_sources

    except Exception as e:
        # Tools report failures as text rather than raising to the caller.
        return f"Search error: {e}"

@tool_function(name="trial_matcher_tool")
def trial_matcher_tool(discharge_note: str, model: str = "llama-3.3-70b-versatile", use_rag: bool = True) -> str:
    """Find the three clinical trials nearest to a discharge note.

    Args:
        discharge_note: Free-text discharge note used as the search query.
        model: Groq model identifier used when summarizing trials.
        use_rag: If True, each matched trial (cut to 1500 characters) is
            summarized by the chat model into bullet points; if False the
            truncated trial text is returned as-is.

    Returns:
        One ``### Trial N:`` section per match joined by ``---`` separators,
        a not-found message when nothing matched, or
        ``"Error during trial matching: <message>"`` if any step raises.
    """
    try:
        query_embedding = get_embedding(discharge_note)
        results = trials_collection.query(
            query_embeddings=[query_embedding],
            n_results=3,
            include=["documents", "metadatas"]
        )
        documents = (results.get("documents") or [[]])[0] or []
        if not documents:
            return "No matching clinical trials were found for the provided note."

        # Metadata may be absent altogether or None for individual records.
        # Index it per document instead of zip()-ing, so no document is
        # silently dropped when the metadata list is shorter.
        metadatas = (results.get("metadatas") or [[]])[0] or []

        summaries = []
        for i, doc in enumerate(documents):
            meta = (metadatas[i] if i < len(metadatas) else None) or {}
            nct_id = meta.get("NCT ID") or "Unknown ID"
            truncated_doc = (doc or "").strip()[:1500]
            if use_rag:
                summary_prompt = (
                    "You are a clinical assistant reviewing a matched clinical trial.\n"
                    "Summarize the trial using **bullet points only** for the following fields:\n"
                    "- NCT ID\n- Study Title\n- Conditions\n- Inclusion Criteria\n- Exclusion Criteria\n\n"
                    "Use bullets under each field. Maintain a clean format. Respond only with the summary.\n\n"
                    f"Trial Description:\nNCT ID: {nct_id}\n{truncated_doc}"
                )
                response = client.chat.completions.create(
                    model=model,
                    messages=[
                        {"role": "system", "content": "You are a medically precise clinical research assistant."},
                        {"role": "user", "content": summary_prompt}
                    ]
                )
                summaries.append(f"### Trial {i+1}:\n{response.choices[0].message.content}")
            else:
                summaries.append(f"### Trial {i+1}:\nNCT ID: {nct_id}\n\n{truncated_doc}")

        return "\n\n---\n\n".join(summaries)

    except Exception as e:
        # Tools report failures as text rather than raising to the caller.
        return f"Error during trial matching: {e}"

# Keyword routing table read by run_tool. Each lower-case keyword maps to
# (tool name, whether that tool accepts a ``use_rag`` argument). Order
# matters: run_tool picks the first keyword, in this order, found in the query.
TOOL_ROUTER = dict(
    symptom=("symptom_search_tool", False),
    treatment=("treatment_tool", True),
    trial=("trial_matcher_tool", True),
)

# Name -> callable lookup for every tool, keyed by the names that
# TOOL_ROUTER refers to.
TOOL_FUNCTIONS = dict(
    chat_memory_tool=chat_memory_tool,
    treatment_tool=treatment_tool,
    symptom_search_tool=symptom_search_tool,
    trial_matcher_tool=trial_matcher_tool,
)

def run_tool(query: str, model: str, use_rag: bool) -> str:
    """Route *query* to a tool by keyword and return that tool's answer.

    The first keyword in ``TOOL_ROUTER`` (insertion order) that occurs in the
    lower-cased query selects the tool. ``use_rag`` is forwarded only to tools
    flagged as supporting it. Queries that match no keyword go to
    ``chat_memory_tool``. The chosen tool's name is printed to stdout.
    """
    lowered = query.lower()
    selection = next(
        ((name, accepts_rag) for keyword, (name, accepts_rag) in TOOL_ROUTER.items() if keyword in lowered),
        None,
    )

    if selection is None:
        print("Tool selected by PySpur: chat_memory_tool")
        return chat_memory_tool(query, model=model)

    tool_name, supports_rag = selection
    print(f"Tool selected by PySpur: {tool_name}")
    extra_kwargs = {"use_rag": use_rag} if supports_rag else {}
    return TOOL_FUNCTIONS[tool_name](query, model=model, **extra_kwargs)