# finalproject/llm_service.py
# Source: Hugging Face Space "finalproject" by raagustin, commit 9ce4875 ("Fix typo").
from contextlib import contextmanager
import chromadb
from chromadb.config import Settings
from sentence_transformers import SentenceTransformer
import pandas as pd
from openai import OpenAI
from data_service import DataService
class LLMService:
    """Builder and abstract base for LLM-backed patient-record services.

    Usage: configure via the fluent setters (``with_key``,
    ``with_data_service``), then enter ``build()`` as a context manager to
    obtain a concrete :class:`DefaultLLMService` that is closed on exit.
    """

    def __init__(self):
        # Both must be supplied through the fluent setters before build().
        self._openAIKey = None
        self._data_service = None

    @contextmanager
    def build(self):
        """Yield a ready-to-use service, guaranteeing ``close()`` on exit.

        Raises:
            ValueError: if the OpenAI key or the data service was not set.
        """
        llm_service = None
        if self._openAIKey is None:
            raise ValueError("OPEN AI key was not provided and there is no default value.")
        if self._data_service is None:
            raise ValueError("To get the patient documents, a data service must be provided before building the LLMService.")
        try:
            llm_service = DefaultLLMService(self._openAIKey, self._data_service)
            yield llm_service
        finally:
            # Close even if the caller's with-body raised.
            if llm_service is not None:
                llm_service.close()

    def close(self):
        """Release resources held by the service. Subclasses must override."""
        raise NotImplementedError("Should not use the base class")

    def get_summary(self, patient: str) -> str:
        """Return an LLM-written summary for *patient*. Subclasses must override."""
        raise NotImplementedError("Should not use the base class")

    def answer_query(self, patient: str, query: str) -> str:
        """Answer *query* about *patient*. Subclasses must override."""
        raise NotImplementedError("Should not use the base class")

    def with_key(self, api_key: str):
        """Set the OpenAI API key; returns ``self`` for chaining."""
        self._openAIKey = api_key
        return self

    def with_data_service(self, data_service):
        """Set the DataService that supplies patient documents; returns ``self``."""
        self._data_service = data_service
        return self
class DefaultLLMService(LLMService):
    """LLM service backed by OpenAI chat completions and a Chroma vector store.

    On construction, patient visit notes from the DataService are embedded
    with a SentenceTransformer model and indexed (incrementally, keyed by
    encounter id) in a persistent Chroma collection. Queries retrieve the
    most relevant notes for a patient and feed them to the chat model.
    """

    def __init__(self, api_key: str, data_service):
        self._api_key = api_key
        self._data_service = data_service
        self._chatclient = OpenAI(api_key=self._api_key)
        # TODO decide on embedding model, using one provided in notebook
        self._embed_model = SentenceTransformer("all-MiniLM-L6-v2")
        # Index any not-yet-seen encounters up front so queries are ready.
        self.build_chromadb()

    def close(self):
        """No resources to release yet; present to satisfy the LLMService contract."""
        pass

    def get_chromadb(self):
        """Return (creating if absent) the persistent patient-data collection.

        NOTE(review): a fresh chromadb Client is built on every call;
        consider caching it on the instance if this shows up in profiling.
        """
        client = chromadb.Client(Settings(
            persist_directory="./chroma_db"
        ))
        return client.get_or_create_collection(name="patient_data")

    def build_chromadb(self):
        """Embed and index encounters that are not yet in the collection."""
        collection = self.get_chromadb()
        df = self._data_service.get_data()
        all_ids = df["ENCOUNTER_ID"].astype(str).tolist()
        existing_ids = collection.get(ids=all_ids)["ids"]
        # Keep only rows whose encounter id is not already indexed.
        new_df = df[~df["ENCOUNTER_ID"].isin(existing_ids)]
        new_ids = new_df["ENCOUNTER_ID"].astype(str).tolist()
        if not new_ids:
            # chromadb rejects empty add() payloads; nothing new to index.
            return
        # TODO what other info should be vectorized
        # 1. add as a string the occupied columns?
        vector_data = new_df["DESCRIPTION"] + ': ' + new_df["CLINICAL_NOTES"]
        new_texts = vector_data.astype(str).tolist()
        new_embeddings = self._embed_model.encode(new_texts).tolist()
        # Store everything from the 4th column onward as metadata
        # (assumes ids and notes occupy the first columns — TODO confirm
        # against the DataService column order).
        new_metadatas = [new_df.iloc[i, 3:].to_dict() for i in range(len(new_ids))]
        collection.add(
            documents=new_texts,
            embeddings=new_embeddings,
            metadatas=new_metadatas,
            ids=new_ids
        )

    def query_chromadb(self, patient: str, query: str, result_template: str = "", top_n=3) -> str:
        """Return the *top_n* notes for *patient* most similar to *query*.

        Results are rendered through *result_template* (placeholders:
        ``rank``, ``desc``, ``st_date``, ``note``) and concatenated.
        Returns "" when patient or query is empty/None.
        """
        if not patient:
            return ""
        if not query:
            return ""
        collection = self.get_chromadb()
        query_embedding = self._embed_model.encode([query])[0].tolist()
        results = collection.query(
            query_embeddings=[query_embedding],
            n_results=top_n,
            # documents/metadatas/distances are default query outputs, so
            # no include= is needed.
            where={"PATIENT_ID": patient}  # restrict to this patient's notes
        )
        # TODO refine template for what info to include
        if result_template == "":
            result_template = """
#{rank}: {desc} {st_date}\n
{note}\n
"""
        context = ""
        for i in range(len(results["ids"][0])):
            result_txt = result_template.format(
                rank=(i + 1),  # range is 0 indexed, increment for rank
                desc=results["metadatas"][0][i]["DESCRIPTION"],
                st_date=results["metadatas"][0][i]["START"],
                note=results["documents"][0][i])
            context = context + result_txt
        return context

    def get_summary(self, patient: str) -> str:
        """Generate a chat-model summary of the patient's most relevant visits."""
        vector_query = "Get visit notes for this patient related to previous symptoms and diagnoses."
        summary_context = self.query_chromadb(patient, vector_query, "", 10)
        summary_query_template = """Given several visit notes for a single patient, write a summary for the patient.\n
You will not be provided information from all patient visits. The visits listed are most related to previous symptoms and diagnoses the patient has experienced.\n
Visit Notes:\n
{context}
Summary:"""
        summary_query = summary_query_template.format(context=summary_context)
        # Single input/output to the chat model; temperature 0 for determinism.
        response = self._chatclient.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": summary_query}],
            temperature=0
        )
        if not response.choices:
            raise RuntimeError("Chat completion returned no choices.")
        return response.choices[0].message.content

    def answer_query(self, patient: str, query: str) -> str:
        """Answer *query* about *patient* using RAG context from the vector store."""
        # TODO: Example queries: tell me about incident ....
        # Has this patient ...?
        # Does this patient have a history of ...?
        # Find in the vector database the docs most related to both the
        # patient and the query.
        rag = self.query_chromadb(patient, query)
        # TODO: Figure out how to utilize other columns.
        prompt_template = """
You are an AI Assistant answering questions about a patient based on the relevant patient information provided.\n
Patient Information:\n
{RAG}
Questions:\n
{Query}
Answer:"""
        filled_prompt = prompt_template.format(RAG=rag, Query=query)
        # Get model output; temperature 0 for determinism.
        response = self._chatclient.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": filled_prompt}],
            temperature=0
        )
        if not response.choices:
            raise RuntimeError("Chat completion returned no choices.")
        return response.choices[0].message.content