Spaces:
Sleeping
Sleeping
import os
from collections.abc import Iterator

import streamlit as st
from langchain_community.vectorstores import FAISS
from langchain_mistralai.embeddings import MistralAIEmbeddings
from langchain_mistralai.chat_models import ChatMistralAI
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain_core.documents import Document
# Mistral API key taken from the process environment; None if the variable is unset.
MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY")
class GlossaryChain:
    """Question-answering chain grounded in the VUMC glossary FAISS index.

    On construction, builds a LangChain runnable pipeline:
    retriever -> format_docs -> prompt -> ChatMistralAI -> StrOutputParser.
    """

    def __init__(self):
        # Loads the FAISS index from disk via load_vector_store().
        self.vector_store = load_vector_store()
        self.retriever = self.vector_store.as_retriever()
        # NOTE: "mistral-large-latest" was tried and returned a
        # "service_tier_capacity_exceeded" error, so the open-weight model is used.
        self.llm = ChatMistralAI(
            model="open-mistral-7b",
            mistral_api_key=MISTRAL_API_KEY,
            temperature=0.2,
        )
        # Explicit "\n\n" separators put the context, question and answer cue in
        # their own paragraphs. A backslash line-continuation inside a string
        # literal would instead drop the newline and embed the source indentation.
        self.prompt = PromptTemplate.from_template(
            "Answer the question based on the following context given by the "
            "Vanderbilt University Medical Center Glossary:\n\n"
            "{context}\n\n"
            "Question:\n\n"
            "{question}\n\n"
            "Answer:"
        )
        # The same input string goes to the retriever (whose documents are joined
        # into "context") and, unchanged, into the prompt's "question" slot.
        self.chain = (
            {"context": self.retriever | format_docs, "question": RunnablePassthrough()}
            | self.prompt
            | self.llm
            | StrOutputParser()
        )

    def stream(self, input: str) -> Iterator[str]:
        """Answer *input*, yielding the response as incremental string chunks."""
        return self.chain.stream(input=input)

    def invoke(self, input: str) -> str:
        """Answer *input* and return the complete response text."""
        return self.chain.invoke(input=input)
def format_docs(docs: list[Document]) -> str:
    """Join the text of retrieved documents into one context string.

    Consecutive documents' ``page_content`` values are separated by a blank
    line; an empty list produces an empty string.
    """
    separator = "\n\n"
    contents = [document.page_content for document in docs]
    return separator.join(contents)
def load_vector_store() -> FAISS:
    """Load the pre-built glossary FAISS index from disk.

    The index directory is ``faiss_index`` located in the parent of the
    directory that contains this file (``<this file's dir>/../faiss_index``).
    It is loaded together with a ``mistral-embed`` embeddings client.

    Raises:
        FileNotFoundError: if the ``faiss_index`` directory does not exist.
    """
    current_dir = os.path.dirname(os.path.abspath(__file__))
    faiss_path = os.path.join(os.path.dirname(current_dir), "faiss_index")
    # Fail early with a clear message instead of an opaque error from deep
    # inside the FAISS / pickle loading code.
    if not os.path.isdir(faiss_path):
        raise FileNotFoundError(f"FAISS index directory not found: {faiss_path}")
    embeddings = MistralAIEmbeddings(model="mistral-embed", mistral_api_key=MISTRAL_API_KEY)
    # SECURITY: load_local unpickles the stored docstore. This is acceptable only
    # when the index directory is a trusted artifact (presumably built and shipped
    # with this app — confirm); never point it at user-supplied files.
    return FAISS.load_local(
        folder_path=faiss_path,
        embeddings=embeddings,
        allow_dangerous_deserialization=True,
    )