Spaces:
Sleeping
Sleeping
Add vector db creation and querying methods
Browse files- llm_service.py +70 -1
- requirements.txt +3 -1
llm_service.py
CHANGED
|
@@ -1,4 +1,8 @@
|
|
| 1 |
from contextlib import contextmanager
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
class LLMService(object):
|
| 4 |
def __init__(self):
|
|
@@ -31,12 +35,77 @@ class LLMService(object):
|
|
| 31 |
class DefaultLLMService(LLMService):
|
| 32 |
def __init__(self, api_key: str):
|
| 33 |
self._api_key = api_key
|
|
|
|
|
|
|
|
|
|
| 34 |
|
| 35 |
def close(self):
|
| 36 |
raise Exception("Not implemented")
|
| 37 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
def get_summary(self, patient: str) -> str:
|
| 39 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
# TODO: Add the found data to the context and ask OpenAI to summarize the provided docs.
|
| 41 |
raise Exception("Not implemented")
|
| 42 |
|
|
|
|
| 1 |
from contextlib import contextmanager
|
| 2 |
+
import chromadb
|
| 3 |
+
from chromadb.config import Settings
|
| 4 |
+
from sentence_transformers import SentenceTransformer
|
| 5 |
+
import pandas as pd
|
| 6 |
|
| 7 |
class LLMService(object):
|
| 8 |
def __init__(self):
|
|
|
|
| 35 |
class DefaultLLMService(LLMService):
    """LLM service backed by a Chroma vector store of patient clinical notes.

    On construction it downloads the patient-encounters CSV, embeds the
    clinical notes with a SentenceTransformer model, and indexes them into a
    persistent Chroma collection so that `query_chromadb` can retrieve the
    most relevant notes for a patient.
    """

    # Name of the Chroma collection holding the embedded clinical notes.
    # Bug fix: the original code referenced an undefined `collection_name`,
    # which raised NameError on every call to get_chromadb().
    COLLECTION_NAME = "patient_notes"

    def __init__(self, api_key: str):
        self._api_key = api_key
        # TODO decide on embedding model; currently using the one provided in the notebook.
        self._embed_model = SentenceTransformer("all-MiniLM-L6-v2")
        self.build_chromadb()

    def close(self):
        # NotImplementedError is the idiomatic marker for unfinished methods;
        # message kept identical for any caller matching on it.
        raise NotImplementedError("Not implemented")

    def get_chromadb(self, clear=0, collection_name=COLLECTION_NAME):
        """Open the persistent Chroma collection, optionally clearing it first.

        Args:
            clear: Truthy to drop any existing collection before reopening it
                (used when rebuilding the index from scratch).
            collection_name: Name of the collection to open or create.

        Returns:
            A chromadb Collection object.
        """
        client = chromadb.Client(Settings(
            persist_directory="./chroma_db"
        ))
        if clear:
            try:
                client.delete_collection(collection_name)
            except Exception:
                # Nothing to delete on the very first build — that's fine.
                pass
        return client.get_or_create_collection(name=collection_name)

    def build_chromadb(self):
        """Download the patient-notes CSV, embed the notes, and (re)build the index."""
        # TODO replace with cleaned data URL or move to service inputs.
        self._df = pd.read_csv(
            "https://huggingface.co/datasets/patjs/patient1/raw/main/patient_encounters1_notes.csv"
        )

        collection = self.get_chromadb(clear=1)
        texts = self._df["CLINICAL_NOTES"].astype(str).tolist()
        embeddings = self._embed_model.encode(texts).tolist()

        collection.add(
            documents=texts,
            embeddings=embeddings,
            # Bug fix: the original comprehension used the undefined name `df`
            # instead of `self._df`. Columns 3:11 are everything except the
            # clinical notes and ids, stored as per-document metadata.
            metadatas=[self._df.iloc[i, 3:11].to_dict() for i in range(len(texts))],
            ids=[str(i) for i in range(len(texts))],
        )

    def query_chromadb(self, patient: str, query: str, result_template: str, top_n=3) -> str:
        """Semantic-search one patient's notes and render the top matches as text.

        Args:
            patient: PATIENT_ID used to restrict the search; empty/None → "".
            query: Free-text search query; empty/None → "".
            result_template: str.format template with {rank}, {desc},
                {st_date}, {note} placeholders; falsy values use the default.
            top_n: Maximum number of results to include.

        Returns:
            The formatted results concatenated into one context string.
        """
        # Guard clauses also cover None now, not just the empty string.
        if not patient:
            return ""
        if not query:
            return ""

        collection = self.get_chromadb()
        query_embedding = self._embed_model.encode([query])[0].tolist()

        results = collection.query(
            query_embeddings=[query_embedding],
            n_results=top_n,
            # include=["documents","metadatas","distances"] — these are the
            # default query outputs, so there is no need to specify them.
            where={"PATIENT_ID": patient},  # restrict to this patient's notes
        )

        # TODO refine template for what info to include.
        if not result_template:
            result_template = """
#{rank}: {desc} {st_date}\n
{note}\n
"""

        # Build with join instead of repeated string concatenation.
        parts = []
        for i in range(len(results["ids"][0])):
            parts.append(result_template.format(
                rank=i + 1,  # results are 0-indexed; ranks are 1-based
                desc=results["metadatas"][0][i]["DESCRIPTION"],
                st_date=results["metadatas"][0][i]["START"],
                note=results["documents"][0][i],
            ))
        return "".join(parts)

    def get_summary(self, patient: str) -> str:
        """Summarize a patient's notes (retrieval wired up; LLM call still TODO)."""
        # TODO get all visit notes, or query the vector database with a
        # specific prompt for the patient.
        # all_visits = self._df.loc[self._df["PATIENT_ID"] == patient]
        vector_query = ""
        vector_result_template = ""  # format for each result from the vector search
        summary_context = self.query_chromadb(patient, vector_query, vector_result_template)
        summary_query = ""
        # TODO: Add the found data to the context and ask OpenAI to summarize the provided docs.
        raise NotImplementedError("Not implemented")
|
| 111 |
|
requirements.txt
CHANGED
|
@@ -2,4 +2,6 @@ gradio==5.32.1
|
|
| 2 |
langchain
|
| 3 |
openai
|
| 4 |
python-dotenv
|
| 5 |
-
langchain_openai
|
|
|
|
|
|
|
|
|
| 2 |
langchain
|
| 3 |
openai
|
| 4 |
python-dotenv
|
| 5 |
+
langchain_openai
|
| 6 |
+
chromadb
|
| 7 |
+
sentence-transformers
|