# MistralAI/src/chain.py
# Author: John Graham Reynolds
# Last change: update prompt (commit 1898c5b)
import os
from collections.abc import Iterator

import streamlit as st
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_mistralai.chat_models import ChatMistralAI
from langchain_mistralai.embeddings import MistralAIEmbeddings
# Mistral API key read from the environment; None when unset (API calls will then fail).
MISTRAL_API_KEY = os.environ.get("MISTRAL_API_KEY")
class GlossaryChain:
    """RAG chain that answers questions against the VUMC glossary index.

    Wires a FAISS retriever, a prompt template, and a Mistral chat model
    into an LCEL pipeline: retrieve -> format context -> prompt -> LLM -> str.
    """

    def __init__(self):
        # Streamlit-cached FAISS store (see load_vector_store elsewhere in this file).
        self.vector_store = load_vector_store()
        self.retriever = self.vector_store.as_retriever()
        self.llm = ChatMistralAI(
            # model="mistral-large-latest", # "error","message":"Service tier capacity exceeded for this model.","type":"service_tier_capacity_exceeded
            model="open-mistral-7b",  # we must use the open-weight model
            mistral_api_key=MISTRAL_API_KEY,
            temperature=0.2,  # low temperature: keep answers grounded in retrieved context
        )
        # NOTE: backslash continuations keep this a single-line template string;
        # continuation lines start at column 0 so no stray indentation leaks
        # into the prompt text.
        self.prompt = PromptTemplate.from_template(
            "Answer the question based on the following context given by the Vanderbilt University Medical Center Glossary: \
\
{context}\n\nQuestion: \
\
{question}\n\nAnswer:"
        )
        # LCEL pipeline: the dict fans the input out — retriever output is
        # formatted into {context}, the raw question passes through unchanged.
        self.chain = (
            {"context": self.retriever | format_docs, "question": RunnablePassthrough()}
            | self.prompt
            | self.llm
            | StrOutputParser()
        )

    def stream(self, input: str) -> Iterator[str]:
        """Stream the answer to *input* as incremental string chunks.

        BUG FIX: the return type was annotated ``str``, but ``Runnable.stream``
        yields chunks lazily — it returns an iterator, not a complete string.
        """
        return self.chain.stream(input=input)

    def invoke(self, input: str) -> str:
        """Return the complete answer to *input* as a single string."""
        return self.chain.invoke(input=input)
def format_docs(docs: list[Document]) -> str:
    """Concatenate retrieved documents' text, separated by blank lines."""
    contents = [document.page_content for document in docs]
    return "\n\n".join(contents)
@st.cache_resource
def load_vector_store() -> FAISS:
    """Load (and cache) the FAISS glossary index from disk.

    The index lives in the ``faiss_index`` directory one level above this
    file; resolving it relative to ``__file__`` makes the app independent of
    the working directory it is launched from. ``st.cache_resource`` ensures
    the index is deserialized only once per process.

    Returns:
        The deserialized FAISS vector store.
    """
    # FIX: removed a redundant function-local `import os` — os is already
    # imported at module level.
    current_dir = os.path.dirname(os.path.abspath(__file__))
    faiss_path = os.path.join(os.path.dirname(current_dir), "faiss_index")
    return FAISS.load_local(
        folder_path=faiss_path,
        embeddings=MistralAIEmbeddings(model="mistral-embed", mistral_api_key=MISTRAL_API_KEY),
        # NOTE(security): pickle-based deserialization — acceptable only
        # because the index is built locally by this project, never loaded
        # from untrusted input.
        allow_dangerous_deserialization=True,
    )