|
|
import streamlit as st |
|
|
from transformers import BlenderbotTokenizer, BlenderbotForConditionalGeneration |
|
|
|
|
|
model_str = 'facebook/blenderbot-400M-distill'

# Streamlit re-executes this script from the top on every widget interaction,
# so loading the checkpoint with bare module-level statements would repeat the
# expensive load for each message. Pick whichever resource cache this Streamlit
# version offers; fall back to an identity decorator (no caching, original
# behavior) when neither exists.
_cache_resource = (
    getattr(st, 'cache_resource', None)
    or getattr(st, 'experimental_singleton', None)
    or (lambda func: func)
)


@_cache_resource
def _load_blenderbot(name):
    """Load and return the ``(tokenizer, model)`` pair for checkpoint ``name``.

    Args:
        name: Checkpoint identifier passed to ``from_pretrained``.

    Returns:
        A 2-tuple of the ``BlenderbotTokenizer`` and the
        ``BlenderbotForConditionalGeneration`` loaded from ``name``.
    """
    return (
        BlenderbotTokenizer.from_pretrained(name),
        BlenderbotForConditionalGeneration.from_pretrained(name),
    )


# Keep the original module-level names; the rest of the file reads them.
tokenizer, model = _load_blenderbot(model_str)
|
|
|
|
|
def chatbot(text):
    """Generate a single model reply to ``text``.

    Args:
        text: The user's message as a string.

    Returns:
        The decoded reply string, with special tokens removed.
    """
    # truncation=True clips the encoded message to the tokenizer's configured
    # maximum length. Without it, a very long message is passed to the model
    # at full length and fails inside generate() instead of producing a reply.
    inputs = tokenizer([text], return_tensors='pt', truncation=True)

    response = model.generate(
        inputs['input_ids'],
        attention_mask=inputs['attention_mask'],
        max_length=100,
    )

    # generate() returns a batch; only one message was sent, so take row 0.
    output = tokenizer.decode(response[0], skip_special_tokens=True)
    return output
|
|
|
|
|
def main():
    """Render the chat page: a title, an input box, and the bot's reply.

    The reply area is only drawn once the input box holds a non-empty value.
    """
    st.title("convo-bot")

    prompt = st.text_input("user:")
    if not prompt:
        # Nothing typed yet, so there is no reply to show.
        return

    st.text_area("bot:", value=chatbot(prompt), height=200)
|
|
|
|
|
# Script entry point: build the page only when this file is run as the main
# module, not when it is imported.
if __name__ == "__main__":
    main()
|
|
|
|
|
|
|
|
|
|
|
|