Deva1211 committed on
Commit
68e5f1b
·
1 Parent(s): 9a3c26a

Fixed some errors

Browse files
Files changed (2) hide show
  1. app.py +53 -30
  2. requirements.txt +1 -1
app.py CHANGED
@@ -11,46 +11,69 @@ model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
11
 
12
  # Define the prediction function
13
def predict(message, history):
    """Generate a DialoGPT reply to *message* and append it to the chat history.

    Parameters:
        message: The new user message from the textbox.
        history: Gradio chat history — a list of [user, bot] message pairs.

    Returns:
        A tuple of ("", updated_history): the empty string clears the textbox,
        and the history gains the new [message, response] pair.
    """
    # Flatten the [user, bot] pairs into a single chronological list of turns,
    # ending with the new user message.
    history_transformer_format = []
    for user, bot in history:
        history_transformer_format.append(user)
        history_transformer_format.append(bot)
    history_transformer_format.append(message)

    # DialoGPT expects EVERY turn to be terminated by the EOS token.
    # BUG FIX: the previous code joined turns with no separator at all and
    # appended a single EOS at the end, blurring turn boundaries — join each
    # turn with EOS instead.
    input_text = tokenizer.eos_token.join(history_transformer_format) + tokenizer.eos_token

    # Tokenize the full conversation context.
    new_user_input_ids = tokenizer.encode(input_text, return_tensors='pt')

    # Generate a response.
    # The max_length is set to 1250 to allow for a decent conversation history.
    bot_output_ids = model.generate(
        new_user_input_ids,
        max_length=1250,
        pad_token_id=tokenizer.eos_token_id,
        no_repeat_ngram_size=3,
        do_sample=True,
        top_k=100,
        top_p=0.7,
        temperature=0.8
    )

    # Decode only the newly generated tokens (skip the input prefix).
    response = tokenizer.decode(
        bot_output_ids[:, new_user_input_ids.shape[-1]:][0],
        skip_special_tokens=True,
    )

    # Return an empty string to clear the textbox and the updated history.
    return "", history + [[message, response]]
 
 
 
 
 
 
 
 
 
 
46
 
47
  # Build the Gradio interface
48
  with gr.Blocks() as demo:
49
  gr.Markdown("## DialoGPT-medium Chatbot")
50
  gr.Markdown("This chatbot uses the microsoft/DialoGPT-medium model. Start typing to chat!")
51
 
52
- chatbot = gr.Chatbot()
53
- textbox = gr.Textbox(placeholder="Type your message here and press Enter")
54
 
55
  # When the user submits the textbox, call the 'predict' function
56
  textbox.submit(
 
11
 
12
  # Define the prediction function
13
def predict(message, history):
    """Generate a DialoGPT reply to *message* and append it to the chat history.

    Parameters:
        message: The new user message from the textbox.
        history: Gradio chat history — a list of [user, bot] message pairs
            (may be None on the first call).

    Returns:
        A tuple of ("", updated_history): the empty string clears the textbox,
        and the history gains the new [message, response] pair. On any internal
        error a canned apology is appended instead of crashing the UI.
    """
    try:
        # Validate inputs: ignore empty/whitespace-only submissions.
        if not message or not message.strip():
            return "", history

        if history is None:
            history = []

        # Flatten valid [user, bot] pairs into a chronological list of turns,
        # defensively skipping malformed entries, then add the new message.
        history_transformer_format = []
        for exchange in history:
            if isinstance(exchange, list) and len(exchange) >= 2:
                user_msg, bot_msg = exchange[0], exchange[1]
                if user_msg:
                    history_transformer_format.append(str(user_msg))
                if bot_msg:
                    history_transformer_format.append(str(bot_msg))
        history_transformer_format.append(str(message))

        # DialoGPT expects EVERY turn to be terminated by the EOS token.
        # BUG FIX: the previous code joined turns with no separator and
        # appended a single EOS at the end, blurring turn boundaries — join
        # each turn with EOS instead.
        input_text = tokenizer.eos_token.join(history_transformer_format) + tokenizer.eos_token

        # Tokenize the full conversation context.
        new_user_input_ids = tokenizer.encode(input_text, return_tensors='pt')

        # Generate a response without building an autograd graph.
        # The max_length is set to 1250 to allow for a decent conversation history.
        with torch.no_grad():
            bot_output_ids = model.generate(
                new_user_input_ids,
                max_length=1250,
                pad_token_id=tokenizer.eos_token_id,
                no_repeat_ngram_size=3,
                do_sample=True,
                top_k=100,
                top_p=0.7,
                temperature=0.8
            )

        # Decode only the newly generated tokens (skip the input prefix).
        response = tokenizer.decode(
            bot_output_ids[:, new_user_input_ids.shape[-1]:][0],
            skip_special_tokens=True,
        )

        # Clean up response; fall back to a canned line if generation was empty.
        response = response.strip()
        if not response:
            response = "I'm not sure how to respond to that. Could you try rephrasing?"

        # Return an empty string to clear the textbox and the updated history.
        return "", history + [[message, response]]

    except Exception as e:
        # Top-level boundary for the Gradio callback: report the error and keep
        # the UI alive rather than propagating.
        print(f"Error in predict function: {e}")
        error_response = "Sorry, I encountered an error. Please try again."
        return "", history + [[message, error_response]]
69
 
70
  # Build the Gradio interface
71
  with gr.Blocks() as demo:
72
  gr.Markdown("## DialoGPT-medium Chatbot")
73
  gr.Markdown("This chatbot uses the microsoft/DialoGPT-medium model. Start typing to chat!")
74
 
75
+ chatbot = gr.Chatbot(value=[], label="DialoGPT Conversation")
76
+ textbox = gr.Textbox(placeholder="Type your message here and press Enter", label="Message")
77
 
78
  # When the user submits the textbox, call the 'predict' function
79
  textbox.submit(
requirements.txt CHANGED
@@ -1,3 +1,3 @@
1
  torch
2
  transformers
3
- gradio
 
1
  torch
2
  transformers
3
+ gradio>=3.50.0,<4.0.0