"""Streamlit RAG chain that answers questions from the VUMC glossary via Mistral."""

import os
from collections.abc import Iterator

import streamlit as st
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_mistralai.chat_models import ChatMistralAI
from langchain_mistralai.embeddings import MistralAIEmbeddings

# NOTE(review): may be None if the env var is unset — Mistral API calls will
# then fail at request time rather than at import time; confirm deployment sets it.
MISTRAL_API_KEY = os.environ.get("MISTRAL_API_KEY")


def format_docs(docs: list[Document]) -> str:
    """Join retrieved documents' text into one prompt-ready context string."""
    return "\n\n".join(doc.page_content for doc in docs)


@st.cache_resource
def load_vector_store() -> FAISS:
    """Load the FAISS index stored beside this package (cached by Streamlit).

    ``allow_dangerous_deserialization=True`` is required by ``FAISS.load_local``
    for pickle-backed indexes; acceptable here only because the index is a
    local, trusted artifact, never untrusted input.
    """
    # Resolve ../faiss_index relative to this file so the app works from any CWD.
    current_dir = os.path.dirname(os.path.abspath(__file__))
    faiss_path = os.path.join(os.path.dirname(current_dir), "faiss_index")
    return FAISS.load_local(
        folder_path=faiss_path,
        embeddings=MistralAIEmbeddings(
            model="mistral-embed", mistral_api_key=MISTRAL_API_KEY
        ),
        allow_dangerous_deserialization=True,
    )


class GlossaryChain:
    """RAG chain: retrieve glossary entries from FAISS, answer with Mistral.

    Pipeline: retriever -> format_docs -> prompt -> LLM -> string parser.
    """

    def __init__(self) -> None:
        self.vector_store = load_vector_store()
        self.retriever = self.vector_store.as_retriever()
        # "mistral-large-latest" was rejected with service_tier_capacity_exceeded,
        # so we must use the open-weight model instead.
        self.llm = ChatMistralAI(
            model="open-mistral-7b",
            mistral_api_key=MISTRAL_API_KEY,
            temperature=0.2,
        )
        self.prompt = PromptTemplate.from_template(
            "Answer the question based on the following context given by the "
            "Vanderbilt University Medical Center Glossary: {context}\n\n"
            "Question: {question}\n\nAnswer:"
        )
        self.chain = (
            {"context": self.retriever | format_docs, "question": RunnablePassthrough()}
            | self.prompt
            | self.llm
            | StrOutputParser()
        )

    def stream(self, input: str) -> Iterator[str]:
        """Yield answer chunks for ``input`` (a user question) as they arrive.

        Return annotation fixed: ``Runnable.stream`` returns an iterator of
        string chunks, not a single ``str``.
        """
        return self.chain.stream(input=input)

    def invoke(self, input: str) -> str:
        """Return the complete answer string for ``input`` (a user question)."""
        return self.chain.invoke(input=input)