Dorn4449 commited on
Commit
ecd4bb8
·
1 Parent(s): c638480

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -9
app.py CHANGED
@@ -4,27 +4,43 @@ os.system("pip install streamlit")
4
  os.system("pip install torch torchvision")
5
 
6
  import streamlit as st
7
- from transformers import pipeline
8
 
9
- # Load the GPT-2 model
10
- generator = pipeline('text-generation', model='gpt2')
 
 
 
 
 
 
 
 
11
 
12
  # Streamlit app
13
  def main():
14
- st.title("GPT-2 Chatbot")
15
 
16
  # Style the user input text area
17
  st.markdown('<style>div.Widget.row-widget.stRadio > div{flex-direction:row;}</style>', unsafe_allow_html=True)
18
  user_input = st.text_area("You:", "Hello, I'm a language model", height=100)
19
 
 
 
20
  if st.button("Generate"):
21
- generated_responses = generator(user_input, max_length=200, num_return_sequences=2)
22
-
23
- for i, response in enumerate(generated_responses):
24
- # Style the response box
 
 
 
 
 
 
25
  st.markdown(
26
  f'<div style="border: 1px solid #ccc; border-radius: 5px; padding: 10px; margin-bottom: 10px;">'
27
- f'Response {i + 1}: {response["generated_text"]}'
28
  '</div>',
29
  unsafe_allow_html=True
30
  )
 
4
  os.system("pip install torch torchvision")
5
 
6
  import streamlit as st
7
+ from transformers import pipeline, GPT2Tokenizer, GPT2LMHeadModel
8
 
9
+ # Load the GPT-2 model and tokenizer
10
+ model_name = 'gpt2'
11
+ tokenizer = GPT2Tokenizer.from_pretrained(model_name)
12
+ model = GPT2LMHeadModel.from_pretrained(model_name)
13
+
14
+ # Create chatbot pipeline
15
+ chatbot = pipeline('text-generation', model=model, tokenizer=tokenizer)
16
+
17
+ # Create translation pipeline
18
+ translation = pipeline('translation', model=model, tokenizer=tokenizer)
19
 
20
  # Streamlit app
21
  def main():
22
+ st.title("GPT-2 Chatbot & Translation")
23
 
24
  # Style the user input text area
25
  st.markdown('<style>div.Widget.row-widget.stRadio > div{flex-direction:row;}</style>', unsafe_allow_html=True)
26
  user_input = st.text_area("You:", "Hello, I'm a language model", height=100)
27
 
28
+ task = st.radio("Choose Task:", ("Chatbot", "Translation"))
29
+
30
  if st.button("Generate"):
31
+ if task == "Chatbot":
32
+ generated_response = chatbot(user_input, max_length=30, num_return_sequences=1)[0]['generated_text']
33
+ st.markdown(
34
+ f'<div style="border: 1px solid #ccc; border-radius: 5px; padding: 10px; margin-bottom: 10px;">'
35
+ f'Chatbot Response: {generated_response}'
36
+ '</div>',
37
+ unsafe_allow_html=True
38
+ )
39
+ elif task == "Translation":
40
+ translated_text = translation(user_input, max_length=30)[0]['translation_text']
41
  st.markdown(
42
  f'<div style="border: 1px solid #ccc; border-radius: 5px; padding: 10px; margin-bottom: 10px;">'
43
+ f'Translation: {translated_text}'
44
  '</div>',
45
  unsafe_allow_html=True
46
  )