iitm-chatbot / app.py
Photon08's picture
Update app.py
9cb093c
from langchain.vectorstores import Chroma
from langchain.embeddings import SentenceTransformerEmbeddings
from langchain.llms import HuggingFaceHub
from langchain.chains import RetrievalQA
import streamlit as st
persist_direc ="chroma_db"
embeddings = SentenceTransformerEmbeddings(model_name="multi-qa-MiniLM-L6-cos-v1")
vectordb = Chroma(persist_directory=persist_direc, embedding_function=embeddings)
print("Db Fetched!")
st.header("Chatbot for IIT Madras BS Degree: Ask questions to the Handbook")
st.caption("Developed by Indranil Bhattacharyya (21F1005840)")
st.text("Please note that this is an experimental app, if you find something inappropiate report it to the developer")
with st.sidebar:
st.write("Chat history")
l = len(st.session_state)
if l>0:
for i in st.session_state:
with st.chat_message("human"):
st.write(i)
with st.chat_message('assitant'):
st.write(st.session_state[i])
st.divider()
model_name = st.radio("Choose the LLM for the chat:", ["google/flan-t5-large",
"google/flan-t5-xxl",
"google/flan-ul2",
"facebook/m2m100_1.2B"])
repo_id = model_name
llm = HuggingFaceHub(
repo_id=repo_id, model_kwargs={"temperature": 1},
huggingfacehub_api_token="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"
)
qa_chain = RetrievalQA.from_chain_type(
llm,
retriever=vectordb.as_retriever()
)
c = 0
prompt = st.chat_input("Say something")
if prompt:
c += 1
with st.spinner('Wait for it...'):
result = qa_chain({"query":prompt})
with st.chat_message("user"):
st.write(prompt)
with st.chat_message('assistant'):
st.write(result["result"])
key_name = 'question' + str(c)
st.session_state[prompt] = result["result"]