Deva1211 committed on
Commit
12a2568
·
1 Parent(s): 68e5f1b

API issue fix

Browse files
Files changed (2) hide show
  1. app.py +51 -41
  2. requirements.txt +3 -3
app.py CHANGED
@@ -5,40 +5,37 @@ import torch
5
  from transformers import AutoModelForCausalLM, AutoTokenizer
6
 
7
  # Load the tokenizer and model
8
- # Using a specific revision to ensure compatibility
9
  tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
10
  model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
 
11
 
12
- # Define the prediction function
13
- def predict(message, history):
14
  try:
15
- # Validate inputs
16
  if not message or not message.strip():
17
- return "", history
18
 
19
- if history is None:
20
- history = []
21
-
22
- # 'history' is a list of lists, where each inner list has a user and a bot message.
23
- # We need to format it for DialoGPT.
24
  history_transformer_format = []
25
  for exchange in history:
26
- if isinstance(exchange, list) and len(exchange) >= 2:
27
  user_msg, bot_msg = exchange[0], exchange[1]
28
  if user_msg:
29
  history_transformer_format.append(str(user_msg))
30
  if bot_msg:
31
  history_transformer_format.append(str(bot_msg))
32
 
33
- # Join the history and the new message, separated by the EOS token
34
  history_string = "".join(history_transformer_format)
35
  input_text = history_string + str(message) + tokenizer.eos_token
36
 
37
  # Tokenize the input
38
  new_user_input_ids = tokenizer.encode(input_text, return_tensors='pt')
39
 
40
- # Generate a response
41
- # The max_length is set to 1250 to allow for a decent conversation history.
42
  with torch.no_grad():
43
  bot_output_ids = model.generate(
44
  new_user_input_ids,
@@ -48,41 +45,54 @@ def predict(message, history):
48
  do_sample=True,
49
  top_k=100,
50
  top_p=0.7,
51
- temperature=0.8
 
52
  )
53
 
54
- # Decode the response, skipping the input part
55
- response = tokenizer.decode(bot_output_ids[:, new_user_input_ids.shape[-1]:][0], skip_special_tokens=True)
 
 
 
56
 
57
- # Clean up response
58
- response = response.strip()
59
  if not response:
60
- response = "I'm not sure how to respond to that. Could you try rephrasing?"
61
 
62
- # Return an empty string to clear the textbox and the updated history
63
- return "", history + [[message, response]]
64
 
65
  except Exception as e:
66
- print(f"Error in predict function: {e}")
67
- error_response = "Sorry, I encountered an error. Please try again."
68
- return "", history + [[message, error_response]]
69
-
70
- # Build the Gradio interface
71
- with gr.Blocks() as demo:
72
- gr.Markdown("## DialoGPT-medium Chatbot")
73
- gr.Markdown("This chatbot uses the microsoft/DialoGPT-medium model. Start typing to chat!")
74
-
75
- chatbot = gr.Chatbot(value=[], label="DialoGPT Conversation")
76
- textbox = gr.Textbox(placeholder="Type your message here and press Enter", label="Message")
77
 
78
- # When the user submits the textbox, call the 'predict' function
79
- textbox.submit(
80
- predict,
81
- inputs=[textbox, chatbot],
82
- outputs=[textbox, chatbot]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  )
 
84
 
85
- # Enable the queue for better handling of multiple users and to enable API usage
86
- demo.queue()
87
  # Launch the app
88
- demo.launch()
 
 
 
 
 
 
 
5
  from transformers import AutoModelForCausalLM, AutoTokenizer
6
 
7
  # Load the tokenizer and model
8
+ print("Loading DialoGPT-medium model...")
9
  tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
10
  model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
11
+ print("Model loaded successfully!")
12
 
13
+ # Define the chat function for the modern ChatInterface
14
+ def chat_fn(message, history):
15
  try:
16
+ # Validate input
17
  if not message or not message.strip():
18
+ return "Please enter a message."
19
 
20
+ # Format history for DialoGPT
21
+ # History comes as a list of [user_msg, bot_msg] pairs
 
 
 
22
  history_transformer_format = []
23
  for exchange in history:
24
+ if len(exchange) >= 2:
25
  user_msg, bot_msg = exchange[0], exchange[1]
26
  if user_msg:
27
  history_transformer_format.append(str(user_msg))
28
  if bot_msg:
29
  history_transformer_format.append(str(bot_msg))
30
 
31
+ # Create the input text
32
  history_string = "".join(history_transformer_format)
33
  input_text = history_string + str(message) + tokenizer.eos_token
34
 
35
  # Tokenize the input
36
  new_user_input_ids = tokenizer.encode(input_text, return_tensors='pt')
37
 
38
+ # Generate a response; torch.no_grad() disables gradient tracking to save memory
 
39
  with torch.no_grad():
40
  bot_output_ids = model.generate(
41
  new_user_input_ids,
 
45
  do_sample=True,
46
  top_k=100,
47
  top_p=0.7,
48
+ temperature=0.8,
49
+ early_stopping=True
50
  )
51
 
52
+ # Decode the response
53
+ response = tokenizer.decode(
54
+ bot_output_ids[:, new_user_input_ids.shape[-1]:][0],
55
+ skip_special_tokens=True
56
+ ).strip()
57
 
58
+ # Fallback for empty responses
 
59
  if not response:
60
+ response = "I'm not sure how to respond to that. Could you try rephrasing your question?"
61
 
62
+ return response
 
63
 
64
  except Exception as e:
65
+ print(f"Error in chat function: {e}")
66
+ return "Sorry, I encountered an error processing your message. Please try again."
 
 
 
 
 
 
 
 
 
67
 
68
+ # Create the Gradio ChatInterface
69
+ demo = gr.ChatInterface(
70
+ fn=chat_fn,
71
+ title="🤖 DialoGPT-medium Chatbot",
72
+ description="Chat with Microsoft's DialoGPT-medium model. This conversational AI can engage in natural dialogue!",
73
+ examples=[
74
+ "Hello, how are you?",
75
+ "What's your favorite movie?",
76
+ "Tell me a joke",
77
+ "What do you think about artificial intelligence?"
78
+ ],
79
+ cache_examples=False,
80
+ retry_btn="🔄 Retry",
81
+ undo_btn="↶ Undo",
82
+ clear_btn="🗑️ Clear",
83
+ submit_btn="Send",
84
+ textbox=gr.Textbox(
85
+ placeholder="Type your message here...",
86
+ container=False,
87
+ scale=7
88
  )
89
+ )
90
 
 
 
91
  # Launch the app
92
+ if __name__ == "__main__":
93
+ demo.queue(max_size=20) # Enable queue for better concurrent handling
94
+ demo.launch(
95
+ server_name="0.0.0.0",
96
+ server_port=7860,
97
+ share=False
98
+ )
requirements.txt CHANGED
@@ -1,3 +1,3 @@
1
- torch
2
- transformers
3
- gradio>=3.50.0,<4.0.0
 
1
+ torch>=1.9.0
2
+ transformers>=4.21.0
3
+ gradio>=4.0.0