File size: 2,268 Bytes
2eb582d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102a4d9
 
2eb582d
 
 
 
1898c5b
2eb582d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102a4d9
2eb582d
 
 
 
 
102a4d9
2eb582d
102a4d9
 
 
 
2eb582d
102a4d9
2eb582d
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import os
from collections.abc import Iterator

import streamlit as st
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_mistralai.chat_models import ChatMistralAI
from langchain_mistralai.embeddings import MistralAIEmbeddings

# Mistral API key read from the environment; None when the variable is unset
# (ChatMistralAI / MistralAIEmbeddings will fail at call time in that case).
MISTRAL_API_KEY = os.environ.get("MISTRAL_API_KEY")

class GlossaryChain:
    """RAG chain answering questions against the VUMC glossary FAISS index.

    Wires a FAISS retriever, a prompt template, and a Mistral chat model into
    one LangChain runnable: question -> retrieved context -> LLM -> str.
    """

    def __init__(self):
        self.vector_store = load_vector_store()
        self.retriever = self.vector_store.as_retriever()
        self.llm = ChatMistralAI(
            # "mistral-large-latest" previously failed with
            # service_tier_capacity_exceeded; the open-weight model is also a
            # project requirement.
            model="open-mistral-7b",
            mistral_api_key=MISTRAL_API_KEY,
            temperature=0.2,  # low temperature: favor faithful glossary answers
        )
        # NOTE: the previous template used backslash line-continuations inside
        # the string literal, which embedded the source indentation (long runs
        # of spaces) into the prompt. Implicit concatenation keeps it clean.
        self.prompt = PromptTemplate.from_template(
            "Answer the question based on the following context given by the "
            "Vanderbilt University Medical Center Glossary:\n\n"
            "{context}\n\nQuestion:\n\n{question}\n\nAnswer:"
        )
        self.chain = (
            {"context": self.retriever | format_docs, "question": RunnablePassthrough()}
            | self.prompt
            | self.llm
            | StrOutputParser()
        )

    def stream(self, input: str) -> Iterator[str]:
        """Stream the answer to *input* as incremental string chunks.

        Fixed annotation: `Runnable.stream` yields chunks, it does not
        return a single str.
        """
        return self.chain.stream(input=input)

    def invoke(self, input: str) -> str:
        """Return the complete answer to *input* as one string."""
        return self.chain.invoke(input=input)


def format_docs(docs: list[Document]) -> str:
    """Join the page contents of *docs* with blank lines, for prompt injection."""
    contents = [doc.page_content for doc in docs]
    return "\n\n".join(contents)

@st.cache_resource
def load_vector_store() -> FAISS:
    """Load the persisted FAISS glossary index from ``../faiss_index``.

    Cached by Streamlit so the index is deserialized once per process.

    Returns:
        FAISS: the loaded vector store, embedded with Mistral's
        "mistral-embed" model (must match the model used at index-build time).
    """
    # Resolve the index directory relative to this file so the app works
    # regardless of the current working directory. (The redundant
    # function-scope `import os` was removed; `os` is imported at module top.)
    current_dir = os.path.dirname(os.path.abspath(__file__))
    faiss_path = os.path.join(os.path.dirname(current_dir), "faiss_index")
    return FAISS.load_local(
        folder_path=faiss_path,
        embeddings=MistralAIEmbeddings(model="mistral-embed", mistral_api_key=MISTRAL_API_KEY),
        # SECURITY: load_local unpickles index metadata. Only acceptable
        # because the index is generated locally — never load an index from
        # an untrusted source with this flag set.
        allow_dangerous_deserialization=True,
    )