shibbir24 commited on
Commit
65562f6
·
verified ·
1 Parent(s): f4126dc

Upload 6 files

Browse files
Files changed (6) hide show
  1. Dockerfile +28 -0
  2. app.py +213 -0
  3. embedding.py +47 -0
  4. metrics_tracker.py +41 -0
  5. requirements.txt +28 -0
  6. tool_handler.py +200 -0
Dockerfile ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
FROM python:3.10.13-slim

WORKDIR /app

# System dependencies required for blis / thinc / spacy
RUN apt-get update && apt-get install -y \
    build-essential \
    gcc \
    g++ \
    git \
    libopenblas-dev \
    libomp-dev \
    && rm -rf /var/lib/apt/lists/*

# Upgrade pip first
RUN pip install --upgrade pip setuptools wheel

# Install Python deps (copied first so the layer is cached until requirements change)
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy app
COPY . .

# Port 7860 is the Hugging Face Spaces convention (Streamlit's own default is 8501)
EXPOSE 7860

CMD ["streamlit", "run", "app.py", "--server.port=7860", "--server.address=0.0.0.0"]
app.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import sys
import os
import spacy
from groq import Groq
from dotenv import load_dotenv

# Load configuration twice on purpose: first any default .env in the CWD,
# then a repo-local API_key.env next to this file.
load_dotenv()
load_dotenv(os.path.join(os.path.dirname(__file__), 'API_key.env'))

# Make the vendored PySpur backend importable before its modules are used.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), 'pyspur/backend/')))

from embedding import discharge_collection, trials_collection, get_embedding

from serpapi import GoogleSearch
from pyspur.backend.pyspur.nodes.decorator import tool_function

# Load API key from Hugging Face secret

groq_api_key = os.getenv("GROQ_API_KEY")
if not groq_api_key:
    raise ValueError("Missing GROQ_API_KEY in environment variables.")

serp_api_key = os.environ.get("SERP_API_KEY")
if not serp_api_key:
    raise ValueError("Missing SERP_API_KEY in environment variables.")

# Initialize LLM client and spaCy

client = Groq(api_key=groq_api_key)

# Fall back to downloading the spaCy model on first run instead of crashing.
try:
    nlp = spacy.load("en_core_web_sm")
except OSError:
    from spacy.cli import download
    download("en_core_web_sm")
    nlp = spacy.load("en_core_web_sm")

# Cardiology-leaning symptom keywords matched (case-insensitively) in chat memory.
SYMPTOM_HINTS = [
    "chest pain", "shortness of breath", "fatigue", "dizziness",
    "nausea", "vomiting", "palpitations", "sweating", "jaw pain",
    "arm pain", "back pain", "tightness", "pressure in chest",
    "arrhythmia", "tachycardia", "bradycardia", "angina",
    "edema", "dyspnea", "syncope", "lightheadedness",
    "ejection fraction", "myocardial infarction", "heart failure",
    "cardiomyopathy", "cardiac arrest"
]
49
+
50
@tool_function(name="chat_memory_tool")
def chat_memory_tool(memory: str, model: str = "llama-3.3-70b-versatile") -> str:
    """Summarize symptoms previously reported in the chat memory.

    Scans *memory* for known symptom keywords, then asks the LLM for a
    concise summary of those symptoms plus suggested next steps.

    Returns the LLM's response text.
    """
    # Scan the full lowered text rather than per noun chunk: multi-word
    # symptoms such as "shortness of breath" can be split across noun
    # chunks, so the chunk-based scan missed them. Every chunk is a
    # substring of the memory, so this finds a strict superset of the
    # previous matches. Sorting makes the prompt deterministic (a raw
    # set join had arbitrary order).
    text_lower = memory.lower()
    found_symptoms = sorted(kw for kw in SYMPTOM_HINTS if kw in text_lower)
    symptom_context = (
        f"Previously mentioned symptoms include: {', '.join(found_symptoms)}."
        if found_symptoms else "No clear symptoms found in memory."
    )
    response = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": "You are a medical assistant summarizing prior symptoms from memory."},
            {"role": "assistant", "content": memory},
            {"role": "user", "content": (
                f"The patient previously reported: {memory}\n\n"
                f"Symptoms extracted: {symptom_context}\n"
                "Please provide a clear, concise, and helpful summary of these symptoms and suggest next steps."
            )}
        ]
    )
    return response.choices[0].message.content
73
+
74
@tool_function(name="treatment_tool")
def treatment_tool(query: str, model: str = "llama-3.3-70b-versatile", use_rag: bool = True) -> str:
    """Recommend treatment for *query*, optionally grounded in retrieved notes.

    When *use_rag* is True, retrieves the 5 most similar discharge notes
    from ChromaDB and includes them (truncated) in the prompt; otherwise
    asks the LLM directly. Returns the LLM response, or an error string.
    """
    try:
        if use_rag:
            # Compute the embedding only on the RAG path — the original
            # paid for it even when use_rag was False.
            query_embedding = get_embedding(query)
            results = discharge_collection.query(
                query_embeddings=[query_embedding],
                n_results=5,
                include=["documents"]
            )
            top_docs = results['documents'][0] if results and results['documents'] else []
            # Truncate each note to keep the prompt within context limits.
            top_docs = [doc[:1500] for doc in top_docs]
            combined_context = "\n\n".join(top_docs)
            prompt = (
                "You are a helpful medical assistant. Based on the following discharge notes, "
                "recommend essential treatment.\n\n"
                f"### Notes:\n{combined_context}\n\n### Condition:\n{query}"
            )
        else:
            prompt = f"Patient condition: {query}. What treatment is recommended?"

        response = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": "You are a medically accurate and safety-focused clinical assistant."},
                {"role": "user", "content": prompt}
            ]
        )
        return response.choices[0].message.content

    except Exception as e:
        return f"Error: {str(e)}"
106
+
107
@tool_function(name="symptom_search_tool")
def symptom_search_tool(symptom_description: str, model: str = "llama-3.3-70b-versatile") -> str:
    """Explain likely causes of a symptom using snippets from trusted medical sites."""

    def perform_search(query):
        # Restrict the web search to a few reputable medical domains.
        search = GoogleSearch({
            "engine": "google",
            "q": f"{query} possible causes site:mayoclinic.org OR site:webmd.com OR site:nih.gov",
            "api_key": serp_api_key
        })
        return search.get_dict().get("organic_results", [])

    try:
        hits = perform_search(symptom_description)
        if not hits:
            return "No reliable medical source found."

        sources = []
        snippets_with_citations = []
        for hit in hits[:3]:
            if 'snippet' not in hit or 'link' not in hit:
                continue
            url = hit['link']
            # Reduce the URL to a bare domain for the inline citation.
            domain = url.split("//")[-1].split("/")[0].replace("www.", "")
            snippets_with_citations.append(f"{hit['snippet']} (Source: {domain})")
            sources.append(url)

        search_context = "\n\n".join(snippets_with_citations)
        response = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": "You are a medical assistant using trusted web sources to explain symptom causes."},
                {"role": "assistant", "content": search_context},
                {"role": "user", "content": f"What could be the cause of: {symptom_description}?"}
            ]
        )

        bulleted_sources = "\n".join(f"- {url}" for url in sources)
        return response.choices[0].message.content + "\n\n**Sources:**\n" + bulleted_sources

    except Exception as e:
        return f"Search error: {str(e)}"
146
+
147
@tool_function(name="trial_matcher_tool")
def trial_matcher_tool(discharge_note: str, model: str = "llama-3.3-70b-versatile", use_rag: bool = True) -> str:
    """Find clinical trials similar to a discharge note and summarize each match."""
    try:
        note_embedding = get_embedding(discharge_note)
        results = trials_collection.query(
            query_embeddings=[note_embedding],
            n_results=3,
            include=["documents", "metadatas"]
        )
        documents = results.get('documents')
        if not documents or not documents[0]:
            return "No matching clinical trials were found for the provided note."

        summaries = []
        for i, (doc, meta) in enumerate(zip(results['documents'][0], results['metadatas'][0])):
            nct_id = meta.get("NCT ID") or "Unknown ID"
            truncated_doc = doc.strip()[:1500]  # cap prompt size per trial

            if not use_rag:
                # Raw mode: hand back the stored trial text without LLM summarization.
                summaries.append(f"### Trial {i+1}:\nNCT ID: {nct_id}\n\n{truncated_doc}")
                continue

            summary_prompt = (
                f"You are a clinical assistant reviewing a matched clinical trial.\n"
                f"Summarize the trial using **bullet points only** for the following fields:\n"
                f"- NCT ID\n- Study Title\n- Conditions\n- Inclusion Criteria\n- Exclusion Criteria\n\n"
                f"Use bullets under each field. Maintain a clean format. Respond only with the summary.\n\n"
                f"Trial Description:\nNCT ID: {nct_id}\n{truncated_doc}"
            )
            response = client.chat.completions.create(
                model=model,
                messages=[
                    {"role": "system", "content": "You are a medically precise clinical research assistant."},
                    {"role": "user", "content": summary_prompt}
                ]
            )
            summaries.append(f"### Trial {i+1}:\n{response.choices[0].message.content}")

        return "\n\n---\n\n".join(summaries)

    except Exception as e:
        return f"Error during trial matching: {str(e)}"
186
+
187
+ # Tool routing via keyword logic
188
+
189
# Maps a routing keyword -> (tool name, whether the tool accepts use_rag).
# Insertion order matters: run_tool picks the FIRST keyword found in the
# query, so "symptom" outranks "treatment", which outranks "trial".
TOOL_ROUTER = {
    "symptom": ("symptom_search_tool", False),
    "treatment": ("treatment_tool", True),
    "trial": ("trial_matcher_tool", True)
}

# Maps tool name -> callable implementing it (chat_memory_tool is the
# fallback when no router keyword matches).
TOOL_FUNCTIONS = {
    "chat_memory_tool": chat_memory_tool,
    "treatment_tool": treatment_tool,
    "symptom_search_tool": symptom_search_tool,
    "trial_matcher_tool": trial_matcher_tool
}
201
+
202
def run_tool(query: str, model: str, use_rag: bool) -> str:
    """Dispatch *query* to the first routed tool whose keyword it contains.

    Falls back to chat_memory_tool when no router keyword matches.
    """
    lowered = query.lower()
    for keyword, (tool_name, supports_rag) in TOOL_ROUTER.items():
        if keyword not in lowered:
            continue
        print(f"Tool selected by PySpur: {tool_name}")
        selected = TOOL_FUNCTIONS[tool_name]
        kwargs = {"model": model}
        if supports_rag:
            # Only RAG-capable tools accept the use_rag flag.
            kwargs["use_rag"] = use_rag
        return selected(query, **kwargs)

    print("Tool selected by PySpur: chat_memory_tool")
    return chat_memory_tool(query, model=model)
embedding.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

import os
import zipfile
import torch
from transformers import AutoModel, AutoTokenizer
import chromadb

# Constants
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_NAME = "dmis-lab/biobert-base-cased-v1.1"
# Anchor all paths to this file's directory so behavior does not depend
# on the process working directory.
_BASE_DIR = os.path.dirname(os.path.abspath(__file__))
DB_DIR = os.path.join(_BASE_DIR, "chromadb_store")
ZIP_PATH = os.path.join(_BASE_DIR, "chromadb_store.zip")

# Step 1: Unzip the vector store if not already present
if not os.path.exists(os.path.join(DB_DIR, "chroma.sqlite3")):
    print("🔓 Unzipping prebuilt ChromaDB store...")
    with zipfile.ZipFile(ZIP_PATH, 'r') as zip_ref:
        # Extract next to this file, not into the CWD: extracting to "."
        # only populated DB_DIR when the process happened to be started
        # from the script directory.
        zip_ref.extractall(_BASE_DIR)
    print("Vector store unzipped and ready.")
else:
    print("Vector store already present. Skipping unzip.")

# Step 2: Connect to persistent ChromaDB
client = chromadb.PersistentClient(path=DB_DIR)
discharge_collection = client.get_or_create_collection("discharge_notes")
trials_collection = client.get_or_create_collection("clinical_trials")

# Step 3: Load BioBERT for embedding (eval mode: inference only, no dropout)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME).to(DEVICE)
model.eval()
35
+
36
# Step 4: Embedding function
def get_embedding(text: str):
    """Return the BioBERT [CLS]-token embedding of *text* as a list of floats."""
    encoded = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
    encoded = {name: tensor.to(DEVICE) for name, tensor in encoded.items()}
    with torch.no_grad():
        outputs = model(**encoded)
    # Take the first ([CLS]) token's hidden state as the sentence vector.
    cls_vector = outputs.last_hidden_state[:, 0, :]
    return cls_vector.squeeze().cpu().numpy().tolist()
43
+
44
# Final check: report how many records each collection holds.
print("📦 ChromaDB Status:")
print(f" - Discharge Notes Loaded: {discharge_collection.count()}")
print(f" - Clinical Trials Loaded: {trials_collection.count()}")
metrics_tracker.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import json
3
+ import os
4
+
5
class MetricsTracker:
    """Accumulates tool-routing outcomes and response times for reporting."""

    def __init__(self):
        # Counters for routed queries and the per-query latencies (seconds).
        self.total_queries = 0
        self.successful_routings = 0
        self.failed_routings = 0
        self.response_times = []

    def record_query(self, routed_correctly: bool, response_time: float):
        """Record one query outcome and its response time in seconds."""
        self.total_queries += 1
        if routed_correctly:
            self.successful_routings += 1
        else:
            self.failed_routings += 1
        self.response_times.append(response_time)

    def get_metrics_summary(self):
        """Return a dict of counts, routing accuracy (%), and mean latency."""
        if self.total_queries == 0:
            accuracy = 0.0
            avg_response_time = 0.0
        else:
            accuracy = (self.successful_routings / self.total_queries) * 100
            # Average over the recorded times themselves; len(response_times)
            # equals total_queries when record_query is the only mutator, but
            # dividing by the list length keeps the mean correct (and avoids
            # ZeroDivisionError) if the two ever diverge.
            avg_response_time = (
                sum(self.response_times) / len(self.response_times)
                if self.response_times else 0.0
            )

        return {
            "Total Queries": self.total_queries,
            "Successful Routings": self.successful_routings,
            "Failed Routings": self.failed_routings,
            "Routing Accuracy (%)": round(accuracy, 2),
            "Average Response Time (sec)": round(avg_response_time, 2)
        }

    def print_metrics_summary(self):
        """Print the summary dict in a human-readable banner."""
        summary = self.get_metrics_summary()
        print("\n=== Metrics Summary ===")
        for k, v in summary.items():
            print(f"{k}: {v}")
        print("==========================\n")
requirements.txt ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ pip>=23.2.1
2
+
3
+ # NLP
4
+ spacy==3.7.2
5
+ en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.2/en_core_web_sm-3.7.2-py3-none-any.whl
6
+
7
+ # ML / DL
8
+ torch==2.2.0
9
+ transformers==4.36.2
10
+ sentence-transformers==2.2.2
11
+ scikit-learn==1.3.2
12
+ numpy
13
+ pandas
14
+
15
+ # PDF / File handling
16
+ PyPDF2==3.0.1
17
+ pdfplumber==0.10.3
18
+
19
+ # Database / Vector DB
20
+ chromadb==0.6.2
21
+ pysqlite3-binary
22
+ sqlalchemy>=1.4.0
23
+ groq==0.15.0
24
+
25
+ # Web / API
26
+ streamlit==1.32.0
27
+ google-search-results==2.4.2
28
+ httpx==0.27.0
tool_handler.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import sys
import os
import spacy
from groq import Groq

# Make the vendored PySpur backend importable before its modules are used.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), 'pyspur/backend/')))

from embedding import discharge_collection, trials_collection, get_embedding

from serpapi import GoogleSearch
from pyspur.backend.pyspur.nodes.decorator import tool_function

# Load API key from Hugging Face secret

groq_api_key = os.getenv("GROQ_API_KEY")
if not groq_api_key:
    raise ValueError("Missing GROQ_API_KEY in environment variables.")

serp_api_key = os.environ.get("SERP_API_KEY")
if not serp_api_key:
    raise ValueError("Missing SERP_API_KEY in environment variables.")

# Initialize LLM client and spaCy

client = Groq(api_key=groq_api_key)

# Match app.py: fall back to downloading the spaCy model on first run
# instead of crashing with OSError when it is not preinstalled.
try:
    nlp = spacy.load("en_core_web_sm")
except OSError:
    from spacy.cli import download
    download("en_core_web_sm")
    nlp = spacy.load("en_core_web_sm")

# Cardiology-leaning symptom keywords matched (case-insensitively) in chat memory.
SYMPTOM_HINTS = [
    "chest pain", "shortness of breath", "fatigue", "dizziness",
    "nausea", "vomiting", "palpitations", "sweating", "jaw pain",
    "arm pain", "back pain", "tightness", "pressure in chest",
    "arrhythmia", "tachycardia", "bradycardia", "angina",
    "edema", "dyspnea", "syncope", "lightheadedness",
    "ejection fraction", "myocardial infarction", "heart failure",
    "cardiomyopathy", "cardiac arrest"
]
36
+
37
@tool_function(name="chat_memory_tool")
def chat_memory_tool(memory: str, model: str = "llama-3.3-70b-versatile") -> str:
    """Summarize symptoms previously reported in the chat memory.

    Scans *memory* for known symptom keywords, then asks the LLM for a
    concise summary of those symptoms plus suggested next steps.

    Returns the LLM's response text.
    """
    # Scan the full lowered text rather than per noun chunk: multi-word
    # symptoms such as "shortness of breath" can be split across noun
    # chunks, so the chunk-based scan missed them. Every chunk is a
    # substring of the memory, so this finds a strict superset of the
    # previous matches. Sorting makes the prompt deterministic (a raw
    # set join had arbitrary order).
    text_lower = memory.lower()
    found_symptoms = sorted(kw for kw in SYMPTOM_HINTS if kw in text_lower)
    symptom_context = (
        f"Previously mentioned symptoms include: {', '.join(found_symptoms)}."
        if found_symptoms else "No clear symptoms found in memory."
    )
    response = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": "You are a medical assistant summarizing prior symptoms from memory."},
            {"role": "assistant", "content": memory},
            {"role": "user", "content": (
                f"The patient previously reported: {memory}\n\n"
                f"Symptoms extracted: {symptom_context}\n"
                "Please provide a clear, concise, and helpful summary of these symptoms and suggest next steps."
            )}
        ]
    )
    return response.choices[0].message.content
60
+
61
@tool_function(name="treatment_tool")
def treatment_tool(query: str, model: str = "llama-3.3-70b-versatile", use_rag: bool = True) -> str:
    """Recommend treatment for *query*, optionally grounded in retrieved notes.

    When *use_rag* is True, retrieves the 5 most similar discharge notes
    from ChromaDB and includes them (truncated) in the prompt; otherwise
    asks the LLM directly. Returns the LLM response, or an error string.
    """
    try:
        if use_rag:
            # Compute the embedding only on the RAG path — the original
            # paid for it even when use_rag was False.
            query_embedding = get_embedding(query)
            results = discharge_collection.query(
                query_embeddings=[query_embedding],
                n_results=5,
                include=["documents"]
            )
            top_docs = results['documents'][0] if results and results['documents'] else []
            # Truncate each note to keep the prompt within context limits.
            top_docs = [doc[:1500] for doc in top_docs]
            combined_context = "\n\n".join(top_docs)
            prompt = (
                "You are a helpful medical assistant. Based on the following discharge notes, "
                "recommend essential treatment.\n\n"
                f"### Notes:\n{combined_context}\n\n### Condition:\n{query}"
            )
        else:
            prompt = f"Patient condition: {query}. What treatment is recommended?"

        response = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": "You are a medically accurate and safety-focused clinical assistant."},
                {"role": "user", "content": prompt}
            ]
        )
        return response.choices[0].message.content

    except Exception as e:
        return f"Error: {str(e)}"
93
+
94
@tool_function(name="symptom_search_tool")
def symptom_search_tool(symptom_description: str, model: str = "llama-3.3-70b-versatile") -> str:
    """Explain likely causes of a symptom using snippets from trusted medical sites."""

    def perform_search(query):
        # Restrict the web search to a few reputable medical domains.
        search = GoogleSearch({
            "engine": "google",
            "q": f"{query} possible causes site:mayoclinic.org OR site:webmd.com OR site:nih.gov",
            "api_key": serp_api_key
        })
        return search.get_dict().get("organic_results", [])

    try:
        hits = perform_search(symptom_description)
        if not hits:
            return "No reliable medical source found."

        sources = []
        snippets_with_citations = []
        for hit in hits[:3]:
            if 'snippet' not in hit or 'link' not in hit:
                continue
            url = hit['link']
            # Reduce the URL to a bare domain for the inline citation.
            domain = url.split("//")[-1].split("/")[0].replace("www.", "")
            snippets_with_citations.append(f"{hit['snippet']} (Source: {domain})")
            sources.append(url)

        search_context = "\n\n".join(snippets_with_citations)
        response = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": "You are a medical assistant using trusted web sources to explain symptom causes."},
                {"role": "assistant", "content": search_context},
                {"role": "user", "content": f"What could be the cause of: {symptom_description}?"}
            ]
        )

        bulleted_sources = "\n".join(f"- {url}" for url in sources)
        return response.choices[0].message.content + "\n\n**Sources:**\n" + bulleted_sources

    except Exception as e:
        return f"Search error: {str(e)}"
133
+
134
@tool_function(name="trial_matcher_tool")
def trial_matcher_tool(discharge_note: str, model: str = "llama-3.3-70b-versatile", use_rag: bool = True) -> str:
    """Find clinical trials similar to a discharge note and summarize each match."""
    try:
        note_embedding = get_embedding(discharge_note)
        results = trials_collection.query(
            query_embeddings=[note_embedding],
            n_results=3,
            include=["documents", "metadatas"]
        )
        documents = results.get('documents')
        if not documents or not documents[0]:
            return "No matching clinical trials were found for the provided note."

        summaries = []
        for i, (doc, meta) in enumerate(zip(results['documents'][0], results['metadatas'][0])):
            nct_id = meta.get("NCT ID") or "Unknown ID"
            truncated_doc = doc.strip()[:1500]  # cap prompt size per trial

            if not use_rag:
                # Raw mode: hand back the stored trial text without LLM summarization.
                summaries.append(f"### Trial {i+1}:\nNCT ID: {nct_id}\n\n{truncated_doc}")
                continue

            summary_prompt = (
                f"You are a clinical assistant reviewing a matched clinical trial.\n"
                f"Summarize the trial using **bullet points only** for the following fields:\n"
                f"- NCT ID\n- Study Title\n- Conditions\n- Inclusion Criteria\n- Exclusion Criteria\n\n"
                f"Use bullets under each field. Maintain a clean format. Respond only with the summary.\n\n"
                f"Trial Description:\nNCT ID: {nct_id}\n{truncated_doc}"
            )
            response = client.chat.completions.create(
                model=model,
                messages=[
                    {"role": "system", "content": "You are a medically precise clinical research assistant."},
                    {"role": "user", "content": summary_prompt}
                ]
            )
            summaries.append(f"### Trial {i+1}:\n{response.choices[0].message.content}")

        return "\n\n---\n\n".join(summaries)

    except Exception as e:
        return f"Error during trial matching: {str(e)}"
173
+
174
+ # Tool routing via keyword logic
175
+
176
# Maps a routing keyword -> (tool name, whether the tool accepts use_rag).
# Insertion order matters: run_tool picks the FIRST keyword found in the
# query, so "symptom" outranks "treatment", which outranks "trial".
TOOL_ROUTER = {
    "symptom": ("symptom_search_tool", False),
    "treatment": ("treatment_tool", True),
    "trial": ("trial_matcher_tool", True)
}

# Maps tool name -> callable implementing it (chat_memory_tool is the
# fallback when no router keyword matches).
TOOL_FUNCTIONS = {
    "chat_memory_tool": chat_memory_tool,
    "treatment_tool": treatment_tool,
    "symptom_search_tool": symptom_search_tool,
    "trial_matcher_tool": trial_matcher_tool
}
188
+
189
def run_tool(query: str, model: str, use_rag: bool) -> str:
    """Dispatch *query* to the first routed tool whose keyword it contains.

    Falls back to chat_memory_tool when no router keyword matches.
    """
    lowered = query.lower()
    for keyword, (tool_name, supports_rag) in TOOL_ROUTER.items():
        if keyword not in lowered:
            continue
        print(f"Tool selected by PySpur: {tool_name}")
        selected = TOOL_FUNCTIONS[tool_name]
        kwargs = {"model": model}
        if supports_rag:
            # Only RAG-capable tools accept the use_rag flag.
            kwargs["use_rag"] = use_rag
        return selected(query, **kwargs)

    print("Tool selected by PySpur: chat_memory_tool")
    return chat_memory_tool(query, model=model)