# dumbot / app.py
# Author: Wilame Lima
# First commit (40b3f8e)
from functions import *
# Sidebar header for the dashboard.
st.sidebar.title(DASHBOARD_TITLE)

# Placeholder for transient status messages (tokenizer/model loading notices).
info_section = st.empty()

# Sidebar blurb: what this chatbot is and which model powers it.
# NOTE(review): an earlier comment here mentioned NER/medical tasks — that was
# leftover from another project; this app is a BlenderBot chat demo.
about_text = f"""
Facebook blenderbot is a family of conversational models that are trained on a large dataset of conversations and can generate limited, yet sometimes coherent responses.
For this project, we are using the 400M distill version of the model. This model is smaller and faster than the original model, but it may not be as accurate. I have used Streamlit to create a simple chatbot interface that allows you to chat with the model and demonstrate how easy it is to use these models for conversational AI tasks.
Have fun, but don't expect too much from the model! It is a little dumb sometimes.
Model used: [{MODEL_PATH}]({MODEL_LINK})
"""
st.sidebar.markdown(about_text)
# Greeting the assistant shows when a conversation starts or is reset.
first_assistant_message = "Hello! I am a dumb bot. What is your dumb question?"

# Reset the transcript to just the greeting, then redraw the page.
# (st.rerun() raises internally, so nothing after it executes.)
if st.sidebar.button("Clear conversation"):
    st.session_state['chat_history'] = [
        {'user': 'assistant', 'content': first_assistant_message},
    ]
    st.rerun()
# Fetch the conversation history from session state, seeding it with the
# assistant's greeting on the very first run.
if 'chat_history' not in st.session_state:
    st.session_state['chat_history'] = [
        {'user': 'assistant', 'content': first_assistant_message},
    ]
chat_history = st.session_state['chat_history']
# Render every stored message under its author's chat avatar.
# NOTE(review): the dict key 'user' holds the ROLE ('user'/'assistant'),
# not a username — preserved as-is since the whole file relies on it.
for entry in chat_history:
    with st.chat_message(entry['user']):
        st.write(entry['content'])
# Flatten the last four messages into one newline-separated context string
# for the model, skipping any malformed entries lacking a 'content' key.
recent_messages = chat_history[-4:]
chat_history_str = "\n".join(
    entry['content'] for entry in recent_messages if 'content' in entry
)
@st.cache_resource(show_spinner=False)
def _load_tokenizer():
    """Load the tokenizer once per server process; reused across reruns."""
    return AutoTokenizer.from_pretrained(MODEL_PATH)


@st.cache_resource(show_spinner=False)
def _load_model():
    """Load the seq2seq model once per server process; reused across reruns."""
    return AutoModelForSeq2SeqLM.from_pretrained(MODEL_PATH)


# get the input from user
user_input = st.chat_input("Write something...")
if user_input:
    # Echo the user's message right away; the full transcript is redrawn
    # from session state after st.rerun() below.
    with st.chat_message("user"):
        st.write(user_input)

    # Load the tokenizer. FIX: previously this re-instantiated the tokenizer
    # and model from disk on EVERY message; st.cache_resource makes the
    # loading cost a one-time hit per server process.
    info_section.info("Loading the tokenizer. This may take a while...")
    tokenizer = _load_tokenizer()
    # Encode the recent history and the new message as a text pair.
    inputs = tokenizer.encode_plus(chat_history_str,
                                   user_input,
                                   return_tensors="pt")

    # get the model's response
    info_section.info("Loading the model. This also may take a while...")
    model = _load_model()
    info_section.empty()

    with st.spinner("Generating the response..."):
        # generate the response
        outputs = model.generate(**inputs)
        # decode the outputs, dropping special tokens and edge whitespace
        response = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()

    # Append both turns to the history, persist, and rerun so the
    # transcript at the top of the script redraws with the new messages.
    chat_history.append({'content': user_input, 'user': 'user'})
    chat_history.append({'content': response, 'user': 'assistant'})
    st.session_state['chat_history'] = chat_history
    st.rerun()