mishiawan committed on
Commit
dffa8fd
·
verified ·
1 Parent(s): cea23c0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -30
app.py CHANGED
@@ -2,10 +2,10 @@ import streamlit as st
2
  from transformers import GPT2LMHeadModel, GPT2Tokenizer
3
  import torch
4
 
5
- # Check if PyTorch is installed correctly and configure device
6
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7
 
8
- # Load pre-trained GPT-2 (DialoGPT is fine-tuned from GPT-2 for conversations)
9
  @st.cache_resource
10
  def load_model():
11
  model = GPT2LMHeadModel.from_pretrained("microsoft/DialoGPT-medium").to(device)
@@ -14,45 +14,49 @@ def load_model():
14
 
15
  model, tokenizer = load_model()
16
 
17
- # Function to generate responses from the model
18
- def get_response(input_text):
19
- # Encode the input and append the end-of-sentence token
20
- inputs = tokenizer.encode(input_text + tokenizer.eos_token, return_tensors="pt").to(device)
21
-
22
- # Generate the response
23
  with torch.no_grad():
24
- outputs = model.generate(inputs, max_length=150, num_return_sequences=1,
25
- no_repeat_ngram_size=3, top_k=50, top_p=0.95, temperature=0.7)
26
-
27
- # Decode the response
28
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
29
- return response
 
 
 
 
30
 
31
  # Streamlit app interface
32
- st.title("Chat with ChatGPT-like Bot")
33
-
34
- # Display instructions
35
- st.write("### Chat with the DialoGPT model. Type your message below and press Enter!")
36
 
37
- # Maintain a conversation history
38
  if "messages" not in st.session_state:
39
  st.session_state.messages = []
40
 
41
- # Display the conversation so far
42
  for message in st.session_state.messages:
43
- st.markdown(f"**{message['role']}**: {message['content']}")
44
 
45
- # Text input for user to interact with the chatbot
46
- user_input = st.text_input("You: ", "")
47
 
48
- # When the user submits a message, generate a response
49
  if user_input:
50
- # Add user message to conversation history
51
  st.session_state.messages.append({"role": "User", "content": user_input})
52
-
53
- # Generate and display the chatbot's response
54
- response = get_response(user_input)
 
 
 
55
  st.session_state.messages.append({"role": "Chatbot", "content": response})
56
 
57
- # Keep the conversation going
58
- st.text_input("You: ", "", key="input")
 
2
  from transformers import GPT2LMHeadModel, GPT2Tokenizer
3
  import torch
4
 
5
+ # Configure device (CPU/GPU)
6
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7
 
8
+ # Load the model and tokenizer
9
  @st.cache_resource
10
  def load_model():
11
  model = GPT2LMHeadModel.from_pretrained("microsoft/DialoGPT-medium").to(device)
 
14
 
15
  model, tokenizer = load_model()
16
 
17
# Function to generate a chatbot reply from recent context plus the new message.
def get_response(conversation_history, user_input):
    """Return DialoGPT's reply to *user_input*, conditioned on prior turns.

    conversation_history: list of earlier message strings (may be empty).
    user_input: the latest user message string.
    """
    # DialoGPT separates conversational turns with the EOS token, so the
    # prompt is context + new message, terminated by eos_token.
    input_text = " ".join(conversation_history) + " " + user_input + tokenizer.eos_token
    inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model.generate(
            inputs,
            # BUG FIX: max_length counted the prompt too, so a long history
            # left no room for a reply; max_new_tokens bounds only the reply.
            max_new_tokens=75,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            # BUG FIX: top_k/top_p/temperature are ignored by greedy decoding;
            # do_sample=True is required for them to take effect.
            do_sample=True,
            top_k=40,
            top_p=0.8,
            temperature=0.7,
            # GPT-2 defines no pad token; pad with EOS to silence the warning.
            pad_token_id=tokenizer.eos_token_id,
        )

    # BUG FIX: generate() returns prompt + continuation. Decode only the newly
    # generated tokens so the user's own words are not echoed in the response.
    return tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True)
33
 
34
# Streamlit app interface
st.title("Chatbot with Hugging Face Model")
st.write("### Chat with the chatbot powered by DialoGPT. Type your message below!")

# Initialize conversation history once per browser session.
if "messages" not in st.session_state:
    st.session_state.messages = []

# Display chat history so far.
for message in st.session_state.messages:
    st.markdown(f"{message['role']}: {message['content']}")

# User input
user_input = st.text_input("You: ", key="user_input")

# Generate a response when the user submits a message.
if user_input:
    # Record the user's turn.
    st.session_state.messages.append({"role": "User", "content": user_input})

    # Context: up to the three most recent prior user messages.
    # BUG FIX: the original sliced messages[-3:] AFTER appending user_input,
    # so the new message appeared in the history AND again as user_input,
    # duplicating it in the prompt. Exclude the just-appended entry.
    history = [
        msg["content"]
        for msg in st.session_state.messages[:-1][-3:]
        if msg["role"] == "User"
    ]

    # Generate chatbot response and record it.
    response = get_response(history, user_input)
    st.session_state.messages.append({"role": "Chatbot", "content": response})

    # BUG FIX: the original called st.text_input("You: ", key="user_input",
    # value="") a second time here to "clear" the box. Two widgets with the
    # same key raise StreamlitDuplicateWidgetID, and a widget's value cannot
    # be overwritten via its key after it is rendered, so that call is
    # removed; the updated history is shown on Streamlit's automatic rerun.