# gpt-chatbot / app.py
# Author: Dorn4449 — commit ecd4bb8 ("Update app.py")
import importlib.util
import os
import subprocess
import sys

# Import name -> pip requirement that provides it (kept in the original install order).
_RUNTIME_REQUIREMENTS = {
    "transformers": "transformers",
    "streamlit": "streamlit",
    "torch": "torch",
    "torchvision": "torchvision",
}


def _install_missing_requirements(requirements=None):
    """Best-effort install of every requirement whose module cannot be found.

    Streamlit re-executes this script on every widget interaction, so only
    packages that are actually missing are installed instead of invoking pip
    unconditionally. ``sys.executable -m pip`` installs into the interpreter
    running this app; a bare ``pip`` on PATH may belong to a different Python.
    As with the original ``os.system`` calls, an install failure is not raised
    here — the third-party imports that follow will fail instead.

    Args:
        requirements: mapping of import name -> pip requirement; defaults to
            ``_RUNTIME_REQUIREMENTS``.

    Returns:
        The list of requirements that were missing (and thus attempted).
    """
    if requirements is None:
        requirements = _RUNTIME_REQUIREMENTS
    missing = [
        requirement
        for module, requirement in requirements.items()
        if importlib.util.find_spec(module) is None
    ]
    if missing:
        # List argv, no shell: nothing here is interpreted by /bin/sh.
        subprocess.run([sys.executable, "-m", "pip", "install", *missing], check=False)
    return missing


_install_missing_requirements()
import streamlit as st
from transformers import pipeline, GPT2Tokenizer, GPT2LMHeadModel
# Load the GPT-2 model and tokenizer used by the chatbot.
model_name = 'gpt2'

# Translation needs an encoder-decoder (seq2seq) model. GPT-2 is decoder-only,
# so a 'translation' pipeline built on it does not translate: each call just
# returns the prompt followed by a GPT-2 continuation. The
# 'translation_XX_to_YY' task name selects the language pair; English -> French
# is an assumed choice — change both names together to target another pair.
translation_task = 'translation_en_to_fr'
translation_model_name = 't5-small'


# Streamlit re-executes the whole script on every widget interaction; caching
# the loaders keeps one copy of each model per process instead of reloading
# the weights on every click.
@st.cache_resource
def _load_chat_components(name):
    """Return ``(tokenizer, model, text-generation pipeline)`` for GPT-2 checkpoint *name*."""
    gpt2_tokenizer = GPT2Tokenizer.from_pretrained(name)
    gpt2_model = GPT2LMHeadModel.from_pretrained(name)
    generator = pipeline('text-generation', model=gpt2_model, tokenizer=gpt2_tokenizer)
    return gpt2_tokenizer, gpt2_model, generator


@st.cache_resource
def _load_translator(task, name):
    """Return a translation pipeline for *task* backed by the seq2seq checkpoint *name*."""
    return pipeline(task, model=name)


# Create chatbot pipeline (module-level names kept for the rest of the app).
tokenizer, model, chatbot = _load_chat_components(model_name)
# Create translation pipeline
translation = _load_translator(translation_task, translation_model_name)
# Streamlit app
def _render_result(label, text):
    """Show ``"<label>: <text>"`` inside a bordered box.

    *text* is HTML-escaped because it is interpolated into markup rendered with
    ``unsafe_allow_html=True``. Model output echoes the user's prompt, so an
    unescaped ``<`` or ``&`` would be interpreted as HTML, breaking the layout
    or injecting markup.
    """
    from html import escape

    st.markdown(
        '<div style="border: 1px solid #ccc; border-radius: 5px; padding: 10px; margin-bottom: 10px;">'
        f'{label}: {escape(text)}'
        '</div>',
        unsafe_allow_html=True,
    )


def main():
    """Render the page: a prompt box, a task selector, and the model's answer."""
    st.title("GPT-2 Chatbot & Translation")
    user_input = st.text_area("You:", "Hello, I'm a language model", height=100)
    # horizontal=True lays the options out in a row; this replaces a CSS hack
    # that depended on Streamlit's internal (non-API) class names.
    task = st.radio("Choose Task:", ("Chatbot", "Translation"), horizontal=True)
    if not st.button("Generate"):
        return
    if task == "Chatbot":
        # max_new_tokens bounds only the continuation. max_length also counts
        # the prompt, so prompts of 30+ tokens previously got no new text.
        outputs = chatbot(user_input, max_new_tokens=30, num_return_sequences=1)
        _render_result("Chatbot Response", outputs[0]['generated_text'])
    elif task == "Translation":
        outputs = translation(user_input, max_length=30)
        _render_result("Translation", outputs[0]['translation_text'])


if __name__ == "__main__":
    main()