raagustin committed on
Commit
b6b8c87
·
1 Parent(s): 5c410ee

Add vector db creation and querying methods

Browse files
Files changed (2) hide show
  1. llm_service.py +70 -1
  2. requirements.txt +3 -1
llm_service.py CHANGED
@@ -1,4 +1,8 @@
1
  from contextlib import contextmanager
 
 
 
 
2
 
3
  class LLMService(object):
4
  def __init__(self):
@@ -31,12 +35,77 @@ class LLMService(object):
31
  class DefaultLLMService(LLMService):
32
  def __init__(self, api_key: str):
33
  self._api_key = api_key
 
 
 
34
 
35
  def close(self):
36
  raise Exception("Not implemented")
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  def get_summary(self, patient: str) -> str:
39
- # TODO: Presumably querying the vector database for patient-related stuff
 
 
 
 
 
40
  # TODO: Add the found data to the context and asking OpenAI to summarize the docs provided?
41
  raise Exception("Not implemented")
42
 
 
1
  from contextlib import contextmanager
2
+ import chromadb
3
+ from chromadb.config import Settings
4
+ from sentence_transformers import SentenceTransformer
5
+ import pandas as pd
6
 
7
  class LLMService(object):
8
  def __init__(self):
 
class DefaultLLMService(LLMService):
    """LLM service that indexes clinical notes in a Chroma collection.

    On construction the service loads a sentence-transformers embedding model
    and (re)builds the vector collection from a CSV of patient encounter
    notes, so creating an instance performs a download and embeds every note.
    """

    # Name of the Chroma collection. The original code referenced an undefined
    # ``collection_name`` variable; this constant gives it a single definition.
    _COLLECTION_NAME = "patient_notes"
    _CHROMA_DIR = "./chroma_db"
    # TODO: replace with cleaned data url or move to service inputs
    _NOTES_CSV_URL = (
        "https://huggingface.co/datasets/patjs/patient1/raw/main/"
        "patient_encounters1_notes.csv"
    )
    # Per-result layout used when the caller does not supply a template.
    # Written with explicit escapes so re-indenting the file cannot change it.
    # TODO: refine template for what info to include
    _DEFAULT_RESULT_TEMPLATE = "\n#{rank}: {desc} {st_date}\n\n{note}\n\n"

    def __init__(self, api_key: str):
        """Store the API key, load the embedding model and build the index.

        Args:
            api_key: Key kept on the instance for later LLM calls.
        """
        self._api_key = api_key
        # TODO: decide on embedding model, using one provided in notebook
        self._embed_model = SentenceTransformer("all-MiniLM-L6-v2")
        self.build_chromadb()

    def close(self):
        """Release resources held by the service (not implemented yet)."""
        # NotImplementedError subclasses Exception, so existing handlers
        # that catch Exception keep working.
        raise NotImplementedError("Not implemented")

    def get_chromadb(self, clear=False):
        """Return the notes collection, optionally emptying it first.

        Args:
            clear: When truthy, drop the existing collection before
                returning a fresh, empty one.

        Returns:
            The Chroma collection named by ``_COLLECTION_NAME``.
        """
        # NOTE(review): depending on the installed chromadb version,
        # ``chromadb.Client`` may be an in-memory client that ignores
        # ``persist_directory`` -- confirm whether a persistent client
        # is intended here.
        client = chromadb.Client(Settings(persist_directory=self._CHROMA_DIR))
        name = self._COLLECTION_NAME
        if clear:
            # Ensure the collection exists so the delete cannot fail on a
            # first run, then drop it to discard any stale contents.
            client.get_or_create_collection(name=name)
            client.delete_collection(name)
        return client.get_or_create_collection(name=name)

    def build_chromadb(self):
        """Download the notes CSV, embed every note and store it in Chroma.

        The loaded frame is kept on ``self._df``. Each row's
        ``CLINICAL_NOTES`` text becomes a document whose id is its row
        position, with columns 3..10 attached as metadata.
        """
        self._df = pd.read_csv(self._NOTES_CSV_URL)

        collection = self.get_chromadb(clear=True)
        texts = self._df["CLINICAL_NOTES"].astype(str).tolist()
        embeddings = self._embed_model.encode(texts).tolist()

        # Store everything except clinical notes and ids as metadata.
        # The original read an undefined ``df``; the frame lives on self.
        # NOTE(review): assumes these columns hold no missing values that
        # Chroma would reject as metadata -- confirm against the dataset.
        metadatas = self._df.iloc[:, 3:11].to_dict(orient="records")

        collection.add(
            documents=texts,
            embeddings=embeddings,
            metadatas=metadatas,
            ids=[str(i) for i in range(len(texts))],
        )

    def query_chromadb(self, patient: str, query: str, result_template: str, top_n: int = 3) -> str:
        """Return the best-matching notes for one patient as formatted text.

        Args:
            patient: Value matched against the ``PATIENT_ID`` metadata field.
            query: Free-text query that is embedded and searched for.
            result_template: ``str.format`` template applied to each hit with
                the fields ``rank``, ``desc``, ``st_date`` and ``note``. An
                empty or ``None`` value selects the default template.
            top_n: Maximum number of hits to return.

        Returns:
            The formatted hits concatenated in rank order, or ``""`` when
            ``patient`` or ``query`` is empty/``None`` or nothing matches.
        """
        if not patient or not query:
            return ""

        collection = self.get_chromadb()
        query_embedding = self._embed_model.encode([query])[0].tolist()

        # Documents, metadatas and distances are default query outputs, so
        # no explicit ``include`` is needed.
        # NOTE(review): assumes PATIENT_ID is one of the metadata columns
        # stored by build_chromadb -- confirm against the dataset.
        results = collection.query(
            query_embeddings=[query_embedding],
            n_results=top_n,
            where={"PATIENT_ID": patient},  # restrict hits to this patient
        )

        template = result_template or self._DEFAULT_RESULT_TEMPLATE

        # Only one query embedding was sent, so the hits are in slot 0.
        hits = zip(results["metadatas"][0], results["documents"][0])
        return "".join(
            template.format(
                rank=rank,  # ranks are 1-based
                desc=metadata["DESCRIPTION"],
                st_date=metadata["START"],
                note=note,
            )
            for rank, (metadata, note) in enumerate(hits, start=1)
        )

    def get_summary(self, patient: str) -> str:
        """Summarize a patient's notes (not implemented yet).

        Currently gathers vector-search context and then raises; because the
        query string is still empty, the context lookup returns ``""``.
        """
        # TODO: get all visit notes or query the vector database with a
        # specific prompt for a patient.
        vector_query = ""
        vector_result_template = ""  # format for each result from vector search
        summary_context = self.query_chromadb(patient, vector_query, vector_result_template)
        # TODO: build the summary prompt, add ``summary_context`` to it and
        # ask OpenAI to summarize the docs provided.
        raise NotImplementedError("Not implemented")
111
 
requirements.txt CHANGED
@@ -2,4 +2,6 @@ gradio==5.32.1
2
  langchain
3
  openai
4
  python-dotenv
5
- langchain_openai
 
 
 
2
  langchain
3
  openai
4
  python-dotenv
5
+ langchain_openai
6
+ chromadb
7
+ sentence-transformers