John Graham Reynolds commited on
Commit
2eb582d
·
1 Parent(s): 27c1b21

create GlossaryChain for orchestrating chain functionality

Browse files
Files changed (1) hide show
  1. src/chain.py +54 -0
src/chain.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import streamlit as st
3
+ from langchain_community.vectorstores import FAISS
4
+ from langchain_mistralai.embeddings import MistralAIEmbeddings
5
+ from langchain_mistralai.chat_models import ChatMistralAI
6
+ from langchain_core.prompts import PromptTemplate
7
+ from langchain_core.runnables import RunnablePassthrough
8
+ from langchain_core.output_parsers import StrOutputParser
9
+ from langchain_core.documents import Document
10
+
11
+ MISTRAL_API_KEY = os.environ.get("MISTRAL_API_KEY")
12
+
13
+ class GlossaryChain:
14
+ def __init__(self):
15
+ self.vector_store = load_vector_store()
16
+ self.retriever = self.vector_store.as_retriever()
17
+ self.llm = ChatMistralAI(
18
+ model="mistral-large-latest",
19
+ mistral_api_key=MISTRAL_API_KEY,
20
+ temperature=0.2
21
+ )
22
+ self.prompt = PromptTemplate.from_template(
23
+ "Answer the question based on the following context about Vanderbilt University Medical Center: \
24
+ \
25
+ {context}\n\nQuestion: \
26
+ \
27
+ {question}\n\nAnswer:"
28
+ )
29
+ self.chain = (
30
+ {"context": self.retriever | format_docs, "question": RunnablePassthrough()}
31
+ | self.prompt
32
+ | self.llm
33
+ | StrOutputParser()
34
+ )
35
+
36
+ def stream(self, input: str) -> str:
37
+ return self.chain.stream(input=input)
38
+
39
+ def invoke(self, input: str) -> str:
40
+ return self.chain.invoke(input=input)
41
+
42
+
43
+ def format_docs(docs: list[Document]) -> str:
44
+ """Format retrieved documents into a readable string"""
45
+ return "\n\n".join(doc.page_content for doc in docs)
46
+
47
+ @st.cache_data
48
+ def load_vector_store() -> FAISS:
49
+ return FAISS.load_local(
50
+ folder_path="faiss_index",
51
+ embeddings = MistralAIEmbeddings(model="mistral-embed", mistral_api_key=MISTRAL_API_KEY),
52
+ allow_dangerous_deserialization=True
53
+ )
54
+