shibbir24 committed on
Commit
dca131e
·
verified ·
1 Parent(s): 260420a

Update tool_handler.py

Browse files
Files changed (1) hide show
  1. tool_handler.py +210 -200
tool_handler.py CHANGED
@@ -1,200 +1,210 @@
1
- import sys
2
- import os
3
- import spacy
4
- from groq import Groq
5
- sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), 'pyspur/backend/')))
6
-
7
- from embedding import discharge_collection, trials_collection, get_embedding
8
-
9
- from serpapi import GoogleSearch
10
- from pyspur.backend.pyspur.nodes.decorator import tool_function
11
-
12
# Credentials come from environment variables (Hugging Face secrets).
# Importing this module fails fast when either key is unset or empty.


def _require_env(name):
    """Return environment variable *name*; raise ValueError if unset or empty."""
    value = os.environ.get(name)
    if not value:
        raise ValueError(f"Missing {name} in environment variables.")
    return value


groq_api_key = _require_env("GROQ_API_KEY")
serp_api_key = _require_env("SERP_API_KEY")

# Module-wide Groq chat client and spaCy English pipeline.
client = Groq(api_key=groq_api_key)
nlp = spacy.load("en_core_web_sm")

# Lower-case symptom/condition phrases searched for in conversation memory.
SYMPTOM_HINTS = [
    "chest pain",
    "shortness of breath",
    "fatigue",
    "dizziness",
    "nausea",
    "vomiting",
    "palpitations",
    "sweating",
    "jaw pain",
    "arm pain",
    "back pain",
    "tightness",
    "pressure in chest",
    "arrhythmia",
    "tachycardia",
    "bradycardia",
    "angina",
    "edema",
    "dyspnea",
    "syncope",
    "lightheadedness",
    "ejection fraction",
    "myocardial infarction",
    "heart failure",
    "cardiomyopathy",
    "cardiac arrest",
]
36
-
37
@tool_function(name="chat_memory_tool")
def chat_memory_tool(memory: str, model: str = "llama-3.3-70b-versatile") -> str:
    """Summarize symptoms found in prior conversation text and suggest next steps.

    Args:
        memory: Free-text conversation memory to scan and summarize.
        model: Groq chat model name passed to the completion call.

    Returns:
        The message content of the first completion choice.
    """
    # Scan the whole lower-cased text instead of individual spaCy noun chunks.
    # Noun chunks are base noun phrases, so multi-word hints containing a
    # preposition ("shortness of breath", "pressure in chest") are split across
    # chunks and never matched. Every chunk is a substring of the text, so any
    # hint the per-chunk scan found is still found here.
    memory_lower = memory.lower()
    # A list in SYMPTOM_HINTS order (rather than a set) keeps the prompt
    # deterministic between runs.
    found_symptoms = [hint for hint in SYMPTOM_HINTS if hint in memory_lower]
    if found_symptoms:
        symptom_context = f"Previously mentioned symptoms include: {', '.join(found_symptoms)}."
    else:
        symptom_context = "No clear symptoms found in memory."

    response = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": "You are a medical assistant summarizing prior symptoms from memory."},
            {"role": "assistant", "content": memory},
            {"role": "user", "content": (
                f"The patient previously reported: {memory}\n\n"
                f"Symptoms extracted: {symptom_context}\n"
                "Please provide a clear, concise, and helpful summary of these symptoms and suggest next steps."
            )},
        ],
    )
    return response.choices[0].message.content
60
-
61
@tool_function(name="treatment_tool")
def treatment_tool(query: str, model: str = "llama-3.3-70b-versatile", use_rag: bool = True) -> str:
    """Ask the LLM for a treatment recommendation, optionally grounded in discharge notes.

    Args:
        query: Description of the patient's condition.
        model: Groq chat model name passed to the completion call.
        use_rag: When true, query ``discharge_collection`` for up to five
            notes and include them (each cut to 1500 characters) in the prompt.

    Returns:
        The model's answer, or ``"Error: <message>"`` if any step raises.
    """
    try:
        top_docs = []
        if use_rag:
            # Embed only when retrieval is requested, so the plain path
            # neither pays for nor can fail on the embedding call.
            results = discharge_collection.query(
                query_embeddings=[get_embedding(query)],
                n_results=5,
                include=["documents"],
            )
            if results and results['documents']:
                top_docs = [doc[:1500] for doc in results['documents'][0]]

        if top_docs:
            combined_context = "\n\n".join(top_docs)
            prompt = (
                "You are a helpful medical assistant. Based on the following discharge notes, "
                "recommend essential treatment.\n\n"
                f"### Notes:\n{combined_context}\n\n### Condition:\n{query}"
            )
        else:
            # Also used when retrieval returned nothing: avoids asking the
            # model to reason from an empty "Notes" section.
            prompt = f"Patient condition: {query}. What treatment is recommended?"

        response = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": "You are a medically accurate and safety-focused clinical assistant."},
                {"role": "user", "content": prompt},
            ],
        )
        return response.choices[0].message.content

    except Exception as e:
        # Tools report failures as text instead of raising.
        return f"Error: {str(e)}"
93
-
94
@tool_function(name="symptom_search_tool")
def symptom_search_tool(symptom_description: str, model: str = "llama-3.3-70b-versatile") -> str:
    """Search selected medical sites for a symptom and have the LLM explain likely causes.

    Args:
        symptom_description: Free-text symptom description; used in the
            search query and in the question put to the model.
        model: Groq chat model name passed to the completion call.

    Returns:
        The model's answer followed by a bulleted list of source URLs,
        ``"No reliable medical source found."`` when no usable result exists,
        or ``"Search error: <message>"`` if any step raises.
    """
    try:
        params = {
            "engine": "google",
            "q": f"{symptom_description} possible causes site:mayoclinic.org OR site:webmd.com OR site:nih.gov",
            "api_key": serp_api_key,
        }
        results = GoogleSearch(params).get_dict().get("organic_results", [])

        sources = []
        snippets_with_citations = []
        # Only the first three hits are considered; hits lacking a snippet or
        # a link are skipped.
        for res in (results or [])[:3]:
            if 'snippet' in res and 'link' in res:
                source_url = res['link']
                domain = source_url.split("//")[-1].split("/")[0]
                # Strip only a leading "www." so hosts that merely contain
                # that text elsewhere are left intact.
                if domain.startswith("www."):
                    domain = domain[len("www."):]
                snippets_with_citations.append(f"{res['snippet']} (Source: {domain})")
                sources.append(source_url)

        # Covers both "no results" and "results without usable snippets";
        # the latter would otherwise call the model with empty context and
        # emit an empty source list.
        if not snippets_with_citations:
            return "No reliable medical source found."

        search_context = "\n\n".join(snippets_with_citations)
        response = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": "You are a medical assistant using trusted web sources to explain symptom causes."},
                {"role": "assistant", "content": search_context},
                {"role": "user", "content": f"What could be the cause of: {symptom_description}?"},
            ],
        )

        bulleted_sources = "\n".join(f"- {url}" for url in sources)
        return response.choices[0].message.content + "\n\n**Sources:**\n" + bulleted_sources

    except Exception as e:
        # Tools report failures as text instead of raising.
        return f"Search error: {str(e)}"
133
-
134
@tool_function(name="trial_matcher_tool")
def trial_matcher_tool(discharge_note: str, model: str = "llama-3.3-70b-versatile", use_rag: bool = True) -> str:
    """Find up to three clinical trials similar to a discharge note and describe them.

    Args:
        discharge_note: Text embedded and used to query ``trials_collection``.
        model: Groq chat model name used for per-trial summaries.
        use_rag: When true, each matched trial is summarized by the LLM;
            otherwise the trial text (cut to 1500 characters) is returned as is.

    Returns:
        Per-trial sections joined by ``---`` separators, a "no match" message
        when nothing is retrieved, or ``"Error during trial matching: <message>"``
        if any step raises.
    """
    try:
        query_embedding = get_embedding(discharge_note)
        results = trials_collection.query(
            query_embeddings=[query_embedding],
            n_results=3,
            include=["documents", "metadatas"],
        )
        if not results.get('documents') or not results['documents'][0]:
            return "No matching clinical trials were found for the provided note."

        system_message = {"role": "system", "content": "You are a medically precise clinical research assistant."}
        summaries = []
        for number, (doc, meta) in enumerate(zip(results['documents'][0], results['metadatas'][0]), start=1):
            # NOTE(review): a metadata entry may be None when a document was
            # stored without metadata; treat that as "no ID" instead of
            # failing the whole match. Confirm against the collection backend.
            nct_id = (meta or {}).get("NCT ID") or "Unknown ID"
            truncated_doc = doc.strip()[:1500]
            if use_rag:
                summary_prompt = (
                    "You are a clinical assistant reviewing a matched clinical trial.\n"
                    "Summarize the trial using **bullet points only** for the following fields:\n"
                    "- NCT ID\n- Study Title\n- Conditions\n- Inclusion Criteria\n- Exclusion Criteria\n\n"
                    "Use bullets under each field. Maintain a clean format. Respond only with the summary.\n\n"
                    f"Trial Description:\nNCT ID: {nct_id}\n{truncated_doc}"
                )
                response = client.chat.completions.create(
                    model=model,
                    messages=[system_message, {"role": "user", "content": summary_prompt}],
                )
                summaries.append(f"### Trial {number}:\n{response.choices[0].message.content}")
            else:
                summaries.append(f"### Trial {number}:\nNCT ID: {nct_id}\n\n{truncated_doc}")

        return "\n\n---\n\n".join(summaries)

    except Exception as e:
        # Tools report failures as text instead of raising.
        return f"Error during trial matching: {str(e)}"
173
-
174
# Keyword routing table. Each entry maps a keyword looked for in the query to
# (tool name, whether that tool is called with a use_rag argument).
# Insertion order is the order in which keywords are tried.
TOOL_ROUTER = dict(
    symptom=("symptom_search_tool", False),
    treatment=("treatment_tool", True),
    trial=("trial_matcher_tool", True),
)

# Tool name -> callable implementing it.
TOOL_FUNCTIONS = dict(
    chat_memory_tool=chat_memory_tool,
    treatment_tool=treatment_tool,
    symptom_search_tool=symptom_search_tool,
    trial_matcher_tool=trial_matcher_tool,
)
188
-
189
def run_tool(query: str, model: str, use_rag: bool) -> str:
    """Dispatch *query* to the first tool whose routing keyword it contains.

    Keywords from ``TOOL_ROUTER`` are tried in table order against the
    lower-cased query. ``use_rag`` is forwarded only to tools flagged as
    supporting it. When no keyword matches, ``chat_memory_tool`` handles the
    query. The selected tool's name is printed before it runs.

    Returns:
        Whatever the selected tool returns.
    """
    lowered = query.lower()
    match = next(
        (entry for keyword, entry in TOOL_ROUTER.items() if keyword in lowered),
        None,
    )

    if match is None:
        # No routing keyword present: default to the memory summarizer.
        print("Tool selected by PySpur: chat_memory_tool")
        return chat_memory_tool(query, model=model)

    tool_name, supports_rag = match
    print(f"Tool selected by PySpur: {tool_name}")
    handler = TOOL_FUNCTIONS[tool_name]
    extra_kwargs = {"use_rag": use_rag} if supports_rag else {}
    return handler(query, model=model, **extra_kwargs)
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+ import spacy
4
+ from groq import Groq
5
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), 'pyspur/backend/')))
6
+
7
+ from embedding import discharge_collection, trials_collection, get_embedding
8
+
9
+ from serpapi import GoogleSearch
10
+ from pyspur.backend.pyspur.nodes.decorator import tool_function
11
+
12
# Credentials come from environment variables (Hugging Face secrets).
# Importing this module fails fast when either key is unset or empty.


def _require_env(name):
    """Return environment variable *name*; raise ValueError if unset or empty."""
    value = os.environ.get(name)
    if not value:
        raise ValueError(f"Missing {name} in environment variables.")
    return value


groq_api_key = _require_env("GROQ_API_KEY")
serp_api_key = _require_env("SERP_API_KEY")

# Module-wide Groq chat client and spaCy English pipeline.
client = Groq(api_key=groq_api_key)
nlp = spacy.load("en_core_web_sm")

# Lower-case symptom/condition phrases searched for in conversation memory.
SYMPTOM_HINTS = [
    "chest pain",
    "shortness of breath",
    "fatigue",
    "dizziness",
    "nausea",
    "vomiting",
    "palpitations",
    "sweating",
    "jaw pain",
    "arm pain",
    "back pain",
    "tightness",
    "pressure in chest",
    "arrhythmia",
    "tachycardia",
    "bradycardia",
    "angina",
    "edema",
    "dyspnea",
    "syncope",
    "lightheadedness",
    "ejection fraction",
    "myocardial infarction",
    "heart failure",
    "cardiomyopathy",
    "cardiac arrest",
]
36
+
37
@tool_function(name="chat_memory_tool")
def chat_memory_tool(memory: str, model: str = "llama-3.3-70b-versatile") -> str:
    """Summarize symptoms found in prior conversation text and suggest next steps.

    Args:
        memory: Free-text conversation memory to scan and summarize.
        model: Groq chat model name passed to the completion call.

    Returns:
        The message content of the first completion choice.
    """
    # Scan the whole lower-cased text instead of individual spaCy noun chunks.
    # Noun chunks are base noun phrases, so multi-word hints containing a
    # preposition ("shortness of breath", "pressure in chest") are split across
    # chunks and never matched. Every chunk is a substring of the text, so any
    # hint the per-chunk scan found is still found here.
    memory_lower = memory.lower()
    # A list in SYMPTOM_HINTS order (rather than a set) keeps the prompt
    # deterministic between runs.
    found_symptoms = [hint for hint in SYMPTOM_HINTS if hint in memory_lower]
    if found_symptoms:
        symptom_context = f"Previously mentioned symptoms include: {', '.join(found_symptoms)}."
    else:
        symptom_context = "No clear symptoms found in memory."

    response = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": "You are a medical assistant summarizing prior symptoms from memory."},
            {"role": "assistant", "content": memory},
            {"role": "user", "content": (
                f"The patient previously reported: {memory}\n\n"
                f"Symptoms extracted: {symptom_context}\n"
                "Please provide a clear, concise, and helpful summary of these symptoms and suggest next steps."
            )},
        ],
    )
    return response.choices[0].message.content
60
+
61
@tool_function(name="treatment_tool")
def treatment_tool(query: str, model: str = "llama-3.3-70b-versatile", use_rag: bool = True) -> str:
    """Ask the LLM for a treatment recommendation, optionally grounded in discharge notes.

    Args:
        query: Description of the patient's condition.
        model: Groq chat model name passed to the completion call.
        use_rag: When true, query ``discharge_collection`` for up to five
            notes and include them (each cut to 1500 characters) in the prompt.

    Returns:
        The model's answer, or ``"Error: <message>"`` if any step raises.
    """
    try:
        top_docs = []
        if use_rag:
            # Embed only when retrieval is requested, so the plain path
            # neither pays for nor can fail on the embedding call.
            results = discharge_collection.query(
                query_embeddings=[get_embedding(query)],
                n_results=5,
                include=["documents"],
            )
            if results and results['documents']:
                top_docs = [doc[:1500] for doc in results['documents'][0]]

        if top_docs:
            combined_context = "\n\n".join(top_docs)
            prompt = (
                "You are a helpful medical assistant. Based on the following discharge notes, "
                "recommend essential treatment.\n\n"
                f"### Notes:\n{combined_context}\n\n### Condition:\n{query}"
            )
        else:
            # Also used when retrieval returned nothing: avoids asking the
            # model to reason from an empty "Notes" section.
            prompt = f"Patient condition: {query}. What treatment is recommended?"

        response = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": "You are a medically accurate and safety-focused clinical assistant."},
                {"role": "user", "content": prompt},
            ],
        )
        return response.choices[0].message.content

    except Exception as e:
        # Tools report failures as text instead of raising.
        return f"Error: {str(e)}"
93
+
94
@tool_function(name="symptom_search_tool")
def symptom_search_tool(symptom_description: str, model: str = "llama-3.3-70b-versatile") -> str:
    """Search selected medical sites for a symptom and have the LLM explain likely causes.

    Args:
        symptom_description: Free-text symptom description; used in the
            search query and in the question put to the model.
        model: Groq chat model name passed to the completion call.

    Returns:
        The model's answer followed by a bulleted list of source URLs,
        ``"No reliable medical source found."`` when no usable result exists,
        or ``"Search error: <message>"`` if any step raises.
    """
    try:
        params = {
            "engine": "google",
            "q": f"{symptom_description} possible causes site:mayoclinic.org OR site:webmd.com OR site:nih.gov",
            "api_key": serp_api_key,
        }
        results = GoogleSearch(params).get_dict().get("organic_results", [])

        sources = []
        snippets_with_citations = []
        # Only the first three hits are considered; hits lacking a snippet or
        # a link are skipped.
        for res in (results or [])[:3]:
            if 'snippet' in res and 'link' in res:
                source_url = res['link']
                domain = source_url.split("//")[-1].split("/")[0]
                # Strip only a leading "www." so hosts that merely contain
                # that text elsewhere are left intact.
                if domain.startswith("www."):
                    domain = domain[len("www."):]
                snippets_with_citations.append(f"{res['snippet']} (Source: {domain})")
                sources.append(source_url)

        # Covers both "no results" and "results without usable snippets";
        # the latter would otherwise call the model with empty context and
        # emit an empty source list.
        if not snippets_with_citations:
            return "No reliable medical source found."

        search_context = "\n\n".join(snippets_with_citations)
        response = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": "You are a medical assistant using trusted web sources to explain symptom causes."},
                {"role": "assistant", "content": search_context},
                {"role": "user", "content": f"What could be the cause of: {symptom_description}?"},
            ],
        )

        bulleted_sources = "\n".join(f"- {url}" for url in sources)
        return response.choices[0].message.content + "\n\n**Sources:**\n" + bulleted_sources

    except Exception as e:
        # Tools report failures as text instead of raising.
        return f"Search error: {str(e)}"
133
+
134
@tool_function(name="trial_matcher_tool")
def trial_matcher_tool(discharge_note: str, model: str = "llama-3.3-70b-versatile", use_rag: bool = True) -> str:
    """Find up to three clinical trials similar to a discharge note and describe them.

    Args:
        discharge_note: Text embedded and used to query ``trials_collection``.
        model: Groq chat model name used for per-trial summaries.
        use_rag: When true, each matched trial is summarized by the LLM;
            otherwise the trial text (cut to 1500 characters) is returned as is.

    Returns:
        Per-trial sections joined by ``---`` separators, a "no match" message
        when nothing is retrieved, or ``"Error during trial matching: <message>"``
        if any step raises.
    """
    try:
        query_embedding = get_embedding(discharge_note)
        results = trials_collection.query(
            query_embeddings=[query_embedding],
            n_results=3,
            include=["documents", "metadatas"],
        )
        if not results.get('documents') or not results['documents'][0]:
            return "No matching clinical trials were found for the provided note."

        system_message = {"role": "system", "content": "You are a medically precise clinical research assistant."}
        summaries = []
        for number, (doc, meta) in enumerate(zip(results['documents'][0], results['metadatas'][0]), start=1):
            # NOTE(review): a metadata entry may be None when a document was
            # stored without metadata; treat that as "no ID" instead of
            # failing the whole match. Confirm against the collection backend.
            nct_id = (meta or {}).get("NCT ID") or "Unknown ID"
            truncated_doc = doc.strip()[:1500]
            if use_rag:
                summary_prompt = (
                    "You are a clinical assistant reviewing a matched clinical trial.\n"
                    "Summarize the trial using **bullet points only** for the following fields:\n"
                    "- NCT ID\n- Study Title\n- Conditions\n- Inclusion Criteria\n- Exclusion Criteria\n\n"
                    "Use bullets under each field. Maintain a clean format. Respond only with the summary.\n\n"
                    f"Trial Description:\nNCT ID: {nct_id}\n{truncated_doc}"
                )
                response = client.chat.completions.create(
                    model=model,
                    messages=[system_message, {"role": "user", "content": summary_prompt}],
                )
                summaries.append(f"### Trial {number}:\n{response.choices[0].message.content}")
            else:
                summaries.append(f"### Trial {number}:\nNCT ID: {nct_id}\n\n{truncated_doc}")

        return "\n\n---\n\n".join(summaries)

    except Exception as e:
        # Tools report failures as text instead of raising.
        return f"Error during trial matching: {str(e)}"
173
+
174
# Keyword routing table. Each entry maps a keyword looked for in the query to
# (tool name, whether that tool is called with a use_rag argument).
# Insertion order is the order in which keywords are tried.
TOOL_ROUTER = dict(
    symptom=("symptom_search_tool", False),
    treatment=("treatment_tool", True),
    trial=("trial_matcher_tool", True),
)

# Tool name -> callable implementing it.
TOOL_FUNCTIONS = dict(
    chat_memory_tool=chat_memory_tool,
    treatment_tool=treatment_tool,
    symptom_search_tool=symptom_search_tool,
    trial_matcher_tool=trial_matcher_tool,
)
188
+
189
def run_tool(query: str, model: str, use_rag: bool):
    """Dispatch *query* to the first tool whose routing keyword it contains.

    Keywords from ``TOOL_ROUTER`` are tried in table order against the
    lower-cased query. ``use_rag`` is forwarded only to tools flagged as
    supporting it. When no keyword matches, ``chat_memory_tool`` handles the
    query. The chosen tool, model and RAG flag are printed before the tool runs.

    Returns:
        A ``(response, tool_name)`` tuple: the selected tool's return value
        and the name of the tool that produced it.
    """
    lowered = query.lower()
    match = next(
        (entry for keyword, entry in TOOL_ROUTER.items() if keyword in lowered),
        None,
    )

    if match is None:
        # Default fallback when no routing keyword is present.
        tool_name = "chat_memory_tool"
    else:
        tool_name, supports_rag = match

    print(f"[ROUTER] Tool selected: {tool_name}")
    print(f"[ROUTER] Model: {model} | RAG: {use_rag}")

    if match is None:
        return chat_memory_tool(query, model=model), tool_name

    handler = TOOL_FUNCTIONS[tool_name]
    extra_kwargs = {"use_rag": use_rag} if supports_rag else {}
    return handler(query, model=model, **extra_kwargs), tool_name