mediAI / app.py
A1ee's picture
Update app.py
a90e530 verified
Raw
History Blame Contribute Delete
2.93 kB
import os
import streamlit as st
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.chains import RetrievalQA
from langchain_community.vectorstores import FAISS
from langchain_core.prompts import PromptTemplate
from langchain_huggingface import HuggingFaceEndpoint
# Directory holding the FAISS index; a relative path, so it is resolved
# against the process's current working directory.
DB_FAISS_PATH = "vectorstore/db_faiss"


@st.cache_resource
def get_vectorstore():
    """Load the FAISS vector store stored at ``DB_FAISS_PATH``.

    Query text is embedded with sentence-transformers/all-MiniLM-L6-v2;
    presumably this must be the same model that was used to build the index
    (TODO confirm against the indexing script).

    Wrapped in ``st.cache_resource`` so the embedding model and index are
    loaded once and shared across reruns instead of on every message.

    Returns:
        The loaded ``FAISS`` vector store.
    """
    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2"
    )
    # NOTE(review): this flag opts in to pickle-based loading of the saved
    # docstore. It is only acceptable because the index is a trusted file
    # shipped with the app; never point DB_FAISS_PATH at untrusted data.
    return FAISS.load_local(
        DB_FAISS_PATH,
        embeddings,
        allow_dangerous_deserialization=True,
    )
def set_custom_prompt(custom_prompt_template):
    """Wrap a raw template string in a ``PromptTemplate``.

    Args:
        custom_prompt_template: Template text containing ``{context}`` and
            ``{question}`` placeholders.

    Returns:
        A ``PromptTemplate`` declaring ``context`` and ``question`` as its
        input variables.
    """
    expected_variables = ["context", "question"]
    prompt_template = PromptTemplate(
        input_variables=expected_variables,
        template=custom_prompt_template,
    )
    return prompt_template
def load_llm(huggingface_repo_id, HF_TOKEN, *, temperature=0.5, max_new_tokens=512):
    """Create a ``HuggingFaceEndpoint`` text-generation client.

    Args:
        huggingface_repo_id: Hugging Face Hub repo id of the model to query.
        HF_TOKEN: API token forwarded as ``huggingfacehub_api_token``; may be
            ``None`` (as when the ``HF_TOKEN`` env var is unset).
        temperature: Sampling temperature (keyword-only, default 0.5).
        max_new_tokens: Maximum number of tokens to generate (keyword-only,
            default 512).

    Returns:
        The configured ``HuggingFaceEndpoint`` instance.
    """
    # Output length for the text-generation API is controlled by
    # `max_new_tokens`. The previous `model_kwargs={"max_length": 512}` is not
    # a text-generation parameter: depending on the library version it is
    # either rejected by the inference client or silently ignored.
    return HuggingFaceEndpoint(
        repo_id=huggingface_repo_id,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        huggingfacehub_api_token=HF_TOKEN,
    )
# Instructions sent to the LLM. The {context} and {question} placeholders are
# filled by the RetrievalQA "stuff" chain with the retrieved documents and the
# user's query.
CUSTOM_PROMPT_TEMPLATE = """
Use the pieces of information provided in the context to answer user's question.
If you dont know the answer, just say that you dont know, dont try to make up an answer.
Dont provide anything out of the given context.
Context: {context}
Question: {question}
Begin your answer directly, talk in a professional way. Act like a doctor, and talk in a friendly way.
"""

# Hub model queried through HuggingFaceEndpoint (see load_llm).
HUGGINGFACE_REPO_ID = "HuggingFaceH4/zephyr-7b-beta"

# Number of documents the retriever returns for each question.
RETRIEVER_TOP_K = 3


def _build_qa_chain(vectorstore, hf_token):
    """Return a RetrievalQA chain answering from ``vectorstore`` with the hosted LLM."""
    return RetrievalQA.from_chain_type(
        llm=load_llm(huggingface_repo_id=HUGGINGFACE_REPO_ID, HF_TOKEN=hf_token),
        chain_type="stuff",
        retriever=vectorstore.as_retriever(search_kwargs={'k': RETRIEVER_TOP_K}),
        return_source_documents=True,
        chain_type_kwargs={'prompt': set_custom_prompt(CUSTOM_PROMPT_TEMPLATE)},
    )


def main():
    """Render the chat UI and answer the newest user message.

    The conversation stored in ``st.session_state.messages`` is redrawn first.
    Then, if the user submitted a new message through ``st.chat_input``, it is
    answered with retrieval-augmented generation. Both the question and the
    answer are appended to the stored history. Failures are reported with
    ``st.error`` instead of a traceback.
    """
    st.title("Ask Medi AI!")

    if 'messages' not in st.session_state:
        st.session_state.messages = []

    # Streamlit redraws the page from scratch, so replay the stored history.
    for message in st.session_state.messages:
        st.chat_message(message['role']).markdown(message['content'])

    prompt = st.chat_input("Pass your prompt here")
    if not prompt:
        return

    st.chat_message('user').markdown(prompt)
    st.session_state.messages.append({'role': 'user', 'content': prompt})

    try:
        vectorstore = get_vectorstore()
        if vectorstore is None:
            st.error("Failed to load the vector store")
            # Stop here: continuing would call as_retriever() on None.
            return
        qa_chain = _build_qa_chain(vectorstore, os.environ.get("HF_TOKEN"))
        response = qa_chain.invoke({'query': prompt})
        result = response["result"]
    except Exception as e:  # UI boundary: surface any retrieval/LLM failure to the user
        st.error(f"Error: {e}")
        return

    st.chat_message('assistant').markdown(result)
    st.session_state.messages.append({'role': 'assistant', 'content': result})
if __name__ == "__main__":
    # Run the chat app when this file is executed as the main script.
    main()