File size: 7,138 Bytes
edd2137
 
 
 
 
0bd536c
7836133
edd2137
 
 
 
7836133
edd2137
 
 
 
 
 
7836133
 
edd2137
7836133
edd2137
 
0bd536c
 
edd2137
 
 
 
 
 
 
 
 
 
 
 
 
7836133
 
 
 
edd2137
 
7836133
edd2137
7836133
edd2137
 
 
 
 
 
0bd536c
 
edd2137
102159c
edd2137
 
 
102159c
edd2137
 
7836133
edd2137
102159c
 
 
7836133
102159c
 
 
 
 
 
 
9ce4875
102159c
 
 
 
 
 
 
 
edd2137
 
102159c
 
 
 
edd2137
 
7836133
edd2137
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102159c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
edd2137
 
 
 
 
 
7836133
edd2137
 
 
 
 
 
 
 
7836133
edd2137
 
7836133
edd2137
 
 
 
 
 
0bd536c
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
from contextlib import contextmanager
import logging

import chromadb
from chromadb.config import Settings
from sentence_transformers import SentenceTransformer
import pandas as pd
from openai import OpenAI

from data_service import DataService

class LLMService(object):
  """Builder and interface for services that answer questions about patients.

  Configure an instance with with_key() and with_data_service(), then use
  build() as a context manager to obtain a ready DefaultLLMService. The query
  methods on this base class are abstract and raise NotImplementedError.
  """

  def __init__(self):
    self._openAIKey = None
    self._data_service = None

  @contextmanager
  def build(self):
    """Yield a DefaultLLMService configured from this builder; close it on exit.

    Raises:
      ValueError: if no (non-empty) OpenAI key or no data service was provided.
    """
    # An empty key is as unusable as a missing one; fail here rather than at
    # the first API request.
    if not self._openAIKey:
      raise ValueError("OPEN AI key was not provided and there is no default value.")
    if self._data_service is None:
      raise ValueError("To get the patient documents, a data service must be provided before building the LLMService.")
    # Constructed outside the try: if construction fails there is nothing to close.
    llm_service = DefaultLLMService(self._openAIKey, self._data_service)
    try:
      yield llm_service
    finally:
      llm_service.close()

  def close(self):
    """Release resources held by the service (implemented by subclasses)."""
    raise NotImplementedError("Should not use the base class")

  def get_summary(self, patient: str) -> str:
    """Return a summary of ``patient`` (implemented by subclasses)."""
    raise NotImplementedError("Should not use the base class")

  def answer_query(self, patient: str, query: str) -> str:
    """Answer ``query`` about ``patient`` (implemented by subclasses)."""
    raise NotImplementedError("Should not use the base class")

  def with_key(self, api_key: str) -> "LLMService":
    """Set the OpenAI API key used by build(); returns self for chaining."""
    self._openAIKey = api_key
    return self

  def with_data_service(self, data_service: DataService) -> "LLMService":
    """Set the data service that supplies patient records; returns self for chaining."""
    self._data_service = data_service
    return self

class DefaultLLMService(LLMService):
  """LLMService backed by OpenAI chat completions and a Chroma vector store.

  On construction, every encounter from ``data_service.get_data()`` that is not
  already in the ``patient_data`` Chroma collection is embedded with a
  SentenceTransformer model and added to it. Summaries and answers are produced
  by retrieving a patient's most similar encounters and sending them to the
  chat model.
  """

  def __init__(self, api_key: str, data_service: DataService):
    super().__init__()
    # Mirror the key into the base-class field so inherited builder state is consistent.
    self._openAIKey = api_key
    self._api_key = api_key
    self._data_service = data_service
    self._chatclient = OpenAI(api_key=self._api_key)
    # TODO decide on embedding model, using one provided in notebook
    self._embed_model = SentenceTransformer("all-MiniLM-L6-v2")
    # Created lazily by get_chromadb() and reused by build and query.
    self._collection = None
    self.build_chromadb()

  def close(self):
    """Close the underlying OpenAI HTTP client.

    Called by LLMService.build() when its context exits. The client is not
    usable after this.
    """
    chatclient = getattr(self, "_chatclient", None)
    if chatclient is not None:
      chatclient.close()

  def get_chromadb(self):
    """Return the ``patient_data`` Chroma collection, creating it on first use.

    The collection is cached on the instance so building and querying share a
    single client instead of constructing a new one on every call.
    """
    if getattr(self, "_collection", None) is None:
      # NOTE(review): depending on the chromadb version, persist_directory alone
      # may not make this client write to disk (newer releases use
      # chromadb.PersistentClient / is_persistent=True) — confirm whether
      # on-disk persistence is intended.
      client = chromadb.Client(Settings(
        persist_directory="./chroma_db"
      ))
      collection_name = "patient_data"
      self._collection = client.get_or_create_collection(name=collection_name)
    return self._collection

  # TODO: It probably makes no difference, but the reason I chose to use a decorator
  # @initialize in data_service is that I learned somewhere that it was better to put
  # as little logic in a constructor as possible because (at least in other languages),
  # errors in constructors made everything complicated, and potentially slow
  # initialization logic made things difficult to ... I don't remember. Debug or
  # parallelize or something. Maybe all of them. The trade-off is, if you forgot to
  # decorate the method with @initialize, bad things.
  def build_chromadb(self):
    """Embed and store every encounter from the data service not yet indexed.

    Each row of ``self._data_service.get_data()`` becomes one Chroma record:
      * id: ``ENCOUNTER_ID`` converted to str,
      * document (also the embedding input): ``"<DESCRIPTION>: <CLINICAL_NOTES>"``,
      * metadata: every column from position 3 onward.
    Rows whose id is already in the collection are skipped, so repeated calls
    only add new encounters.
    """
    collection = self.get_chromadb()
    # texts = self._data_service.get_documents()
    # metadatas = self._data_service.get_document_metadatas()
    # TODO: I'm looking at the docs and I'm wondering 1. if this is the default embedding
    # function anyway (I think it's good to bring it out and see we can adjust it; I'm
    # just pointing this out), and 2. whether it would be equivalent to provide
    # self._embed_model.encode as the embedding function of the collection at
    # configuration time.
    # embeddings = self._embed_model.encode(texts).tolist()

    df = self._data_service.get_data()
    if df.empty:
      return

    # Chroma ids are strings, so compare against the *stringified* ENCOUNTER_ID.
    # Testing the raw column (e.g. ints) against string ids never matches, and
    # every row would be treated as new and re-added.
    encounter_ids = df["ENCOUNTER_ID"].astype(str)
    existing_ids = set(collection.get(ids=encounter_ids.tolist())["ids"])
    is_new = ~encounter_ids.isin(existing_ids)
    new_df = df[is_new]
    if new_df.empty:
      # Everything is already indexed, and Chroma's add() rejects an empty id list.
      return
    new_ids = encounter_ids[is_new].tolist()

    # TODO what other info should be vectorized
    # 1. add as a string the occupied columns?
    # Missing values become "" so a row without notes keeps its description.
    # Otherwise NaN + str yields NaN and the whole text becomes "nan".
    new_texts = (
      new_df["DESCRIPTION"].fillna("").astype(str)
      + ': '
      + new_df["CLINICAL_NOTES"].fillna("").astype(str)
    ).tolist()
    new_embeddings = self._embed_model.encode(new_texts).tolist()
    # NOTE(review): positional slice. It assumes the first three columns are the
    # ones to keep out of metadata (intended: ids and clinical notes); confirm
    # against DataService's column order. query_chromadb reads PATIENT_ID,
    # DESCRIPTION and START from this metadata. "records" keeps each column's own
    # type instead of upcasting the whole row.
    new_metadatas = new_df.iloc[:, 3:].to_dict(orient="records")

    collection.add(
      documents=new_texts,
      embeddings=new_embeddings,
      metadatas=new_metadatas,
      ids=new_ids,
    )

  def query_chromadb(self, patient: str, query: str, result_template: str = "", top_n=3) -> str:
    """Return formatted text for the patient's encounters most similar to ``query``.

    Args:
      patient: value matched exactly against the PATIENT_ID metadata field.
      query: free text, embedded with the same model used to index documents.
      result_template: str.format template using {rank}, {desc}, {st_date} and
        {note}. A built-in template is used when empty.
      top_n: number of results requested from Chroma.

    Returns:
      The formatted results concatenated in rank order, or "" if ``patient`` or
      ``query`` is empty/None.
    """
    if not patient or not query:
      return ""

    collection = self.get_chromadb()
    query_embedding = self._embed_model.encode([query])[0].tolist()

    results = collection.query(
      query_embeddings=[query_embedding],
      n_results=top_n,
      # documents and metadatas are included in query output by default.
      where={"PATIENT_ID": patient},  # restrict to this patient's encounters
    )

    #TODO refine template for what info to include
    if not result_template:
      result_template="""

      #{rank}: {desc} {st_date}\n

      {note}\n

      """

    # Chroma returns one result list per query embedding; there is exactly one here.
    metadatas = results["metadatas"][0]
    documents = results["documents"][0]
    return "".join(
      result_template.format(
        rank=rank,  # 1-based rank
        desc=metadata["DESCRIPTION"],
        st_date=metadata["START"],
        note=document,
      )
      for rank, (metadata, document) in enumerate(zip(metadatas, documents), start=1)
    )

  def _complete(self, prompt: str) -> str:
    """Send ``prompt`` as one user message (temperature 0) and return the reply.

    Returns "" when the response has no choices or the reply has no text content.
    """
    response = self._chatclient.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": prompt}],
        temperature=0
    )
    # Debug level only: the reply is generated from patient notes.
    logging.getLogger(__name__).debug("Chat completion response: %s", response)
    if not response.choices:
      return ""
    return response.choices[0].message.content or ""

  def get_summary(self, patient: str) -> str:
    """Summarize ``patient`` from up to 10 encounters about symptoms and diagnoses."""
    vector_query = "Get visit notes for this patient related to previous symptoms and diagnoses."
    summary_context = self.query_chromadb(patient, vector_query, "", 10)
    summary_query_template="""Given several visit notes for a single patient, write a summary for the patient.\n

    You will not be provided infromation from all patient visits. The visits listed are most related to previous symptoms and diagnoses the patient has experienced.\n

    Visit Notes:\n

    {context}

    Summary:"""
    summary_query = summary_query_template.format(context=summary_context)
    return self._complete(summary_query)

  def answer_query(self, patient: str, query: str) -> str:
    """Answer ``query`` using the patient's encounters most similar to it as context."""
    # TODO: Example queries: tell me about incident ....
    #     Has this patient ...?
    #     Does this patient have a history of ...?
    # Retrieve the encounters most related to both the patient and the query.
    rag = self.query_chromadb(patient, query)
    # TODO: Figure out how to utilize other columns.
    prompt_template="""

    You are an AI Assistant answering questions about a patient based on the relevant patient information provided.\n

    

    Patient Information:\n

    {RAG}



    Questions:\n

    {Query}



    Answer:"""
    filled_prompt = prompt_template.format(RAG=rag, Query=query)
    return self._complete(filled_prompt)