"""LLM service: retrieval-augmented answering and summarization over patient encounter notes."""
# Standard library
from contextlib import contextmanager

# Third-party
import chromadb
from chromadb.config import Settings
from openai import OpenAI
from sentence_transformers import SentenceTransformer
import pandas as pd

# Local
from data_service import DataService
class LLMService:
    """Builder and abstract base for LLM-backed patient-query services.

    Configure via the fluent ``with_key`` / ``with_data_service`` setters,
    then use ``build()`` as a context manager to obtain a concrete
    ``DefaultLLMService`` that is closed automatically on exit::

        with LLMService().with_key(key).with_data_service(ds).build() as svc:
            svc.get_summary(patient_id)
    """

    def __init__(self):
        self._openAIKey = None      # OpenAI API key, set via with_key()
        self._data_service = None   # DataService supplying patient documents

    @contextmanager
    def build(self):
        """Yield a configured DefaultLLMService, closing it on exit.

        Fix: the original defined a generator with try/yield/finally but never
        applied ``@contextmanager`` (imported at the top of the file and
        otherwise unused), so ``with ...build() as svc`` raised AttributeError
        on the returned generator.

        Raises:
            ValueError: if the API key or the data service was not provided.
        """
        if self._openAIKey is None:
            raise ValueError("OPEN AI key was not provided and there is no default value.")
        if self._data_service is None:
            raise ValueError("To get the patient documents, a data service must be provided before building the LLMService.")
        llm_service = None
        try:
            llm_service = DefaultLLMService(self._openAIKey, self._data_service)
            yield llm_service
        finally:
            # Close only if construction succeeded.
            if llm_service is not None:
                llm_service.close()

    def close(self):
        # NotImplementedError is the idiomatic exception for abstract stubs.
        raise NotImplementedError("Should not use the base class")

    def get_summary(self, patient: str) -> str:
        raise NotImplementedError("Should not use the base class")

    def answer_query(self, patient: str, query: str) -> str:
        raise NotImplementedError("Should not use the base class")

    def with_key(self, api_key: str):
        """Set the OpenAI API key; returns self for chaining."""
        self._openAIKey = api_key
        return self

    def with_data_service(self, data_service: DataService):
        """Set the data service used to fetch patient documents; returns self."""
        self._data_service = data_service
        return self
class DefaultLLMService(LLMService):
    """RAG service over patient encounter notes.

    Indexes per-encounter documents into a persistent Chroma collection and
    answers questions / writes summaries with an OpenAI chat model, grounding
    the prompt in the most relevant notes for the requested patient.
    """

    def __init__(self, api_key: str, data_service: DataService):
        self._api_key = api_key
        self._data_service = data_service
        self._chatclient = OpenAI(api_key=self._api_key)
        # TODO decide on embedding model, using one provided in notebook
        self._embed_model = SentenceTransformer("all-MiniLM-L6-v2")
        self.build_chromadb()

    def close(self):
        # Nothing to release yet; kept so build()'s context manager can
        # always call close() on exit.
        pass

    def get_chromadb(self):
        """Return the persistent ``patient_data`` collection (created if missing)."""
        client = chromadb.Client(Settings(
            persist_directory="./chroma_db"
        ))
        return client.get_or_create_collection(name="patient_data")

    def build_chromadb(self):
        """Index any encounters not already present in the collection.

        Each document is the string ``"<DESCRIPTION>: <CLINICAL_NOTES>"`` keyed
        by ENCOUNTER_ID; the remaining columns are stored as metadata.
        """
        collection = self.get_chromadb()
        df = self._data_service.get_data()
        all_ids = df["ENCOUNTER_ID"].astype(str).tolist()
        existing_ids = collection.get(ids=all_ids)["ids"]
        # Keep only rows whose encounter id is not yet indexed.
        new_df = df[~df["ENCOUNTER_ID"].isin(existing_ids)]
        if new_df.empty:
            # Fix: chromadb's collection.add() rejects empty id/document
            # lists, so skip the add when everything is already indexed.
            return
        new_ids = new_df["ENCOUNTER_ID"].astype(str).tolist()
        # TODO what other info should be vectorized (e.g. other occupied columns)
        vector_data = new_df["DESCRIPTION"] + ': ' + new_df["CLINICAL_NOTES"]
        new_texts = vector_data.astype(str).tolist()
        new_embeddings = self._embed_model.encode(new_texts).tolist()
        # Store everything except ids/notes as metadata. NOTE(review):
        # positional slice assumes metadata columns start at index 3 —
        # confirm against the data service's column order.
        new_metadatas = [new_df.iloc[i, 3:].to_dict() for i in range(len(new_ids))]
        collection.add(
            documents=new_texts,
            embeddings=new_embeddings,
            metadatas=new_metadatas,
            ids=new_ids
        )

    def query_chromadb(self, patient: str, query: str, result_template: str = "", top_n=3) -> str:
        """Return the ``top_n`` most relevant notes for a patient as one string.

        Each hit is rendered through ``result_template`` (format keys: rank,
        desc, st_date, note); a default template is used when none is given.
        Returns "" when patient or query is empty/None.
        """
        if not patient or not query:
            return ""
        collection = self.get_chromadb()
        query_embedding = self._embed_model.encode([query])[0].tolist()
        results = collection.query(
            query_embeddings=[query_embedding],
            n_results=top_n,
            # documents/metadatas/distances are default query outputs, so no
            # explicit include= is needed.
            where={"PATIENT_ID": patient}  # restrict hits to this patient
        )
        # TODO refine template for what info to include
        if result_template == "":
            result_template = """
#{rank}: {desc} {st_date}\n
{note}\n
"""
        # Query results are nested one level per query embedding; we sent
        # exactly one, hence the [0] indexing.
        metadatas = results["metadatas"][0]
        documents = results["documents"][0]
        rendered = []
        for rank, (meta, note) in enumerate(zip(metadatas, documents), start=1):
            rendered.append(result_template.format(
                rank=rank,
                desc=meta["DESCRIPTION"],
                st_date=meta["START"],
                note=note))
        return "".join(rendered)

    def get_summary(self, patient: str) -> str:
        """Summarize a patient from their 10 most symptom/diagnosis-relevant notes."""
        vector_query = "Get visit notes for this patient related to previous symptoms and diagnoses."
        summary_context = self.query_chromadb(patient, vector_query, "", 10)
        summary_query_template = """Given several visit notes for a single patient, write a summary for the patient.\n
You will not be provided infromation from all patient visits. The visits listed are most related to previous symptoms and diagnoses the patient has experienced.\n
Visit Notes:\n
{context}
Summary:"""
        summary_query = summary_query_template.format(context=summary_context)
        # Single-turn chat call; temperature 0 for deterministic output.
        response = self._chatclient.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": summary_query}],
            temperature=0
        )
        return response.choices[0].message.content

    def answer_query(self, patient: str, query: str) -> str:
        """Answer a free-text question about a patient using retrieved notes.

        Example queries: "Tell me about incident ...", "Has this patient ...?",
        "Does this patient have a history of ...?"
        """
        # Retrieve the notes most related to both the patient and the query.
        rag = self.query_chromadb(patient, query)
        # TODO: Figure out how to utilize other columns.
        prompt_template = """
You are an AI Assistant answering questions about a patient based on the relevant patient information provided.\n
Patient Information:\n
{RAG}
Questions:\n
{Query}
Answer:"""
        filled_prompt = prompt_template.format(RAG=rag, Query=query)
        # Get model output (fix: removed leftover debug print of the full API
        # response, which leaked patient-derived text to stdout).
        response = self._chatclient.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": filled_prompt}],
            temperature=0
        )
        # TODO: Error handling for 0 choices
        return response.choices[0].message.content