shibbir24 commited on
Commit
af26cad
·
verified ·
1 Parent(s): daa7a96

Update tool_handler.py

Browse files
Files changed (1) hide show
  1. tool_handler.py +20 -30
tool_handler.py CHANGED
@@ -3,18 +3,15 @@ import os
3
  sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), 'pyspur/backend/')))
4
 
5
  from embedding import discharge_collection, trials_collection, get_embedding
6
-
7
  from groq import Client
8
  from serpapi import GoogleSearch
9
  import spacy
10
- from pyspur import SpurAgent
11
  from pyspur.nodes.decorator import tool_function
12
 
13
  # Initialize LLM client and spaCy
14
  client = Client(api_key="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")
15
  nlp = spacy.load("en_core_web_sm")
16
 
17
- # Symptom keywords
18
  SYMPTOM_HINTS = [
19
  "chest pain", "shortness of breath", "fatigue", "dizziness",
20
  "nausea", "vomiting", "palpitations", "sweating", "jaw pain",
@@ -25,9 +22,8 @@ SYMPTOM_HINTS = [
25
  "cardiomyopathy", "cardiac arrest"
26
  ]
27
 
28
- # Tool 1: Chat Memory Symptom Reasoner
29
- @tool_function
30
- def chat_memory_tool(memory: str, model="llama-3.3-70b-versatile") -> str:
31
  doc = nlp(memory)
32
  found_symptoms = set(
33
  keyword for chunk in doc.noun_chunks for keyword in SYMPTOM_HINTS if keyword in chunk.text.lower()
@@ -50,12 +46,10 @@ def chat_memory_tool(memory: str, model="llama-3.3-70b-versatile") -> str:
50
  )
51
  return response.choices[0].message.content
52
 
53
- # Tool 2: Treatment Recommender
54
- @tool_function
55
- def treatment_tool(query: str, model="llama-3.3-70b-versatile", use_rag=True) -> str:
56
  try:
57
  query_embedding = get_embedding(query)
58
-
59
  if use_rag:
60
  results = discharge_collection.query(
61
  query_embeddings=[query_embedding],
@@ -65,7 +59,6 @@ def treatment_tool(query: str, model="llama-3.3-70b-versatile", use_rag=True) ->
65
  top_docs = results['documents'][0] if results and results['documents'] else []
66
  top_docs = [doc[:1500] for doc in top_docs]
67
  combined_context = "\n\n".join(top_docs)
68
-
69
  prompt = (
70
  "You are a helpful medical assistant. Based on the following discharge notes, "
71
  "recommend essential treatment.\n\n"
@@ -86,9 +79,8 @@ def treatment_tool(query: str, model="llama-3.3-70b-versatile", use_rag=True) ->
86
  except Exception as e:
87
  return f"Error: {str(e)}"
88
 
89
- # Tool 3: Symptom Cause Analyzer
90
- @tool_function
91
- def symptom_search_tool(symptom_description: str, model="llama-3.3-70b-versatile") -> str:
92
  def perform_search(query):
93
  params = {
94
  "engine": "google",
@@ -104,7 +96,6 @@ def symptom_search_tool(symptom_description: str, model="llama-3.3-70b-versatile
104
 
105
  sources = []
106
  snippets_with_citations = []
107
-
108
  for res in results[:3]:
109
  if 'snippet' in res and 'link' in res:
110
  source_url = res['link']
@@ -113,7 +104,6 @@ def symptom_search_tool(symptom_description: str, model="llama-3.3-70b-versatile
113
  sources.append(source_url)
114
 
115
  search_context = "\n\n".join(snippets_with_citations)
116
-
117
  response = client.chat.completions.create(
118
  model=model,
119
  messages=[
@@ -129,18 +119,15 @@ def symptom_search_tool(symptom_description: str, model="llama-3.3-70b-versatile
129
  except Exception as e:
130
  return f"Search error: {str(e)}"
131
 
132
- # Tool 4: Clinical Trial Matcher
133
- @tool_function
134
- def trial_matcher_tool(discharge_note: str, model="llama-3.3-70b-versatile", use_rag=True) -> str:
135
  try:
136
  query_embedding = get_embedding(discharge_note)
137
-
138
  results = trials_collection.query(
139
  query_embeddings=[query_embedding],
140
  n_results=3,
141
  include=["documents", "metadatas"]
142
  )
143
-
144
  if not results.get('documents') or not results['documents'][0]:
145
  return "No matching clinical trials were found for the provided note."
146
 
@@ -148,7 +135,6 @@ def trial_matcher_tool(discharge_note: str, model="llama-3.3-70b-versatile", use
148
  for i, (doc, meta) in enumerate(zip(results['documents'][0], results['metadatas'][0])):
149
  nct_id = meta.get("NCT ID") or "Unknown ID"
150
  truncated_doc = doc.strip()[:1500]
151
-
152
  if use_rag:
153
  summary_prompt = (
154
  f"You are a clinical assistant reviewing a matched clinical trial.\n"
@@ -157,7 +143,6 @@ def trial_matcher_tool(discharge_note: str, model="llama-3.3-70b-versatile", use
157
  f"Use bullets under each field. Maintain a clean format. Respond only with the summary.\n\n"
158
  f"Trial Description:\nNCT ID: {nct_id}\n{truncated_doc}"
159
  )
160
-
161
  response = client.chat.completions.create(
162
  model=model,
163
  messages=[
@@ -174,10 +159,15 @@ def trial_matcher_tool(discharge_note: str, model="llama-3.3-70b-versatile", use
174
  except Exception as e:
175
  return f"Error during trial matching: {str(e)}"
176
 
177
- # Register tools in PySpur Agent
178
- agent = SpurAgent(tools=[
179
- chat_memory_tool,
180
- treatment_tool,
181
- symptom_search_tool,
182
- trial_matcher_tool
183
- ])
 
 
 
 
 
 
3
  sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), 'pyspur/backend/')))
4
 
5
  from embedding import discharge_collection, trials_collection, get_embedding
 
6
  from groq import Client
7
  from serpapi import GoogleSearch
8
  import spacy
 
9
  from pyspur.nodes.decorator import tool_function
10
 
11
  # Initialize LLM client and spaCy
12
  client = Client(api_key="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")
13
  nlp = spacy.load("en_core_web_sm")
14
 
 
15
  SYMPTOM_HINTS = [
16
  "chest pain", "shortness of breath", "fatigue", "dizziness",
17
  "nausea", "vomiting", "palpitations", "sweating", "jaw pain",
 
22
  "cardiomyopathy", "cardiac arrest"
23
  ]
24
 
25
+ @tool_function(name="chat_memory_tool")
26
+ def chat_memory_tool(memory: str, model: str = "llama-3.3-70b-versatile") -> str:
 
27
  doc = nlp(memory)
28
  found_symptoms = set(
29
  keyword for chunk in doc.noun_chunks for keyword in SYMPTOM_HINTS if keyword in chunk.text.lower()
 
46
  )
47
  return response.choices[0].message.content
48
 
49
+ @tool_function(name="treatment_tool")
50
+ def treatment_tool(query: str, model: str = "llama-3.3-70b-versatile", use_rag: bool = True) -> str:
 
51
  try:
52
  query_embedding = get_embedding(query)
 
53
  if use_rag:
54
  results = discharge_collection.query(
55
  query_embeddings=[query_embedding],
 
59
  top_docs = results['documents'][0] if results and results['documents'] else []
60
  top_docs = [doc[:1500] for doc in top_docs]
61
  combined_context = "\n\n".join(top_docs)
 
62
  prompt = (
63
  "You are a helpful medical assistant. Based on the following discharge notes, "
64
  "recommend essential treatment.\n\n"
 
79
  except Exception as e:
80
  return f"Error: {str(e)}"
81
 
82
+ @tool_function(name="symptom_search_tool")
83
+ def symptom_search_tool(symptom_description: str, model: str = "llama-3.3-70b-versatile") -> str:
 
84
  def perform_search(query):
85
  params = {
86
  "engine": "google",
 
96
 
97
  sources = []
98
  snippets_with_citations = []
 
99
  for res in results[:3]:
100
  if 'snippet' in res and 'link' in res:
101
  source_url = res['link']
 
104
  sources.append(source_url)
105
 
106
  search_context = "\n\n".join(snippets_with_citations)
 
107
  response = client.chat.completions.create(
108
  model=model,
109
  messages=[
 
119
  except Exception as e:
120
  return f"Search error: {str(e)}"
121
 
122
+ @tool_function(name="trial_matcher_tool")
123
+ def trial_matcher_tool(discharge_note: str, model: str = "llama-3.3-70b-versatile", use_rag: bool = True) -> str:
 
124
  try:
125
  query_embedding = get_embedding(discharge_note)
 
126
  results = trials_collection.query(
127
  query_embeddings=[query_embedding],
128
  n_results=3,
129
  include=["documents", "metadatas"]
130
  )
 
131
  if not results.get('documents') or not results['documents'][0]:
132
  return "No matching clinical trials were found for the provided note."
133
 
 
135
  for i, (doc, meta) in enumerate(zip(results['documents'][0], results['metadatas'][0])):
136
  nct_id = meta.get("NCT ID") or "Unknown ID"
137
  truncated_doc = doc.strip()[:1500]
 
138
  if use_rag:
139
  summary_prompt = (
140
  f"You are a clinical assistant reviewing a matched clinical trial.\n"
 
143
  f"Use bullets under each field. Maintain a clean format. Respond only with the summary.\n\n"
144
  f"Trial Description:\nNCT ID: {nct_id}\n{truncated_doc}"
145
  )
 
146
  response = client.chat.completions.create(
147
  model=model,
148
  messages=[
 
159
  except Exception as e:
160
  return f"Error during trial matching: {str(e)}"
161
 
162
+ # Tool routing via keyword logic
163
+ TOOL_ROUTER = {
164
+ "symptom": symptom_search_tool,
165
+ "treatment": treatment_tool,
166
+ "trial": trial_matcher_tool
167
+ }
168
+
169
+ def run_tool(query: str, model: str, use_rag: bool) -> str:
170
+ for keyword, tool_func in TOOL_ROUTER.items():
171
+ if keyword in query.lower():
172
+ return tool_func(query, model=model, use_rag=use_rag)
173
+ return chat_memory_tool(query, model=model)