Deva1211 committed on
Commit
de0ccfd
·
1 Parent(s): 3afd9f6

api error 3

Browse files
Files changed (1) hide show
  1. app.py +49 -36
app.py CHANGED
@@ -4,33 +4,44 @@ import gradio as gr
4
  import torch
5
  from transformers import AutoModelForCausalLM, AutoTokenizer
6
 
7
- # Load the tokenizer and model
8
  print("Loading DialoGPT-medium model...")
 
 
9
  tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
10
  model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
 
11
  print("Model loaded successfully!")
12
 
13
- # Define the chat function for the modern ChatInterface
14
- def chat_fn(message, history):
15
  try:
16
  # Validate input
17
  if not message or not message.strip():
18
  return "Please enter a message."
19
 
20
- # Format history for DialoGPT
21
- # History comes as a list of [user_msg, bot_msg] pairs
22
- history_transformer_format = []
23
- for exchange in history:
24
- if len(exchange) >= 2:
25
- user_msg, bot_msg = exchange[0], exchange[1]
26
- if user_msg:
27
- history_transformer_format.append(str(user_msg))
28
- if bot_msg:
29
- history_transformer_format.append(str(bot_msg))
30
-
31
- # Create the input text
32
- history_string = "".join(history_transformer_format)
33
- input_text = history_string + str(message) + tokenizer.eos_token
 
 
 
 
 
 
 
 
 
34
 
35
  # Tokenize the input
36
  new_user_input_ids = tokenizer.encode(input_text, return_tensors='pt')
@@ -39,14 +50,13 @@ def chat_fn(message, history):
39
  with torch.no_grad():
40
  bot_output_ids = model.generate(
41
  new_user_input_ids,
42
- max_length=1250,
43
  pad_token_id=tokenizer.eos_token_id,
44
  no_repeat_ngram_size=3,
45
  do_sample=True,
46
- top_k=100,
47
  top_p=0.7,
48
- temperature=0.8,
49
- early_stopping=True
50
  )
51
 
52
  # Decode the response
@@ -55,35 +65,38 @@ def chat_fn(message, history):
55
  skip_special_tokens=True
56
  ).strip()
57
 
58
- # Fallback for empty responses
59
  if not response:
60
- response = "I'm not sure how to respond to that. Could you try rephrasing your question?"
 
 
 
 
61
 
62
  return response
63
 
64
  except Exception as e:
65
- print(f"Error in chat function: {e}")
66
- return "Sorry, I encountered an error processing your message. Please try again."
67
 
68
- # Create the Gradio ChatInterface
69
  demo = gr.ChatInterface(
70
- fn=chat_fn,
71
- title="🤖 DialoGPT-medium Chatbot",
72
- description="Chat with Microsoft's DialoGPT-medium model. This conversational AI can engage in natural dialogue!",
73
  examples=[
74
- "Hello, how are you?",
75
- "What's your favorite movie?",
76
  "Tell me a joke",
77
- "What do you think about artificial intelligence?"
78
  ],
79
  cache_examples=False
80
  )
81
 
82
- # Launch the app
83
  if __name__ == "__main__":
84
- demo.queue(max_size=20) # Enable queue for better concurrent handling
85
  demo.launch(
 
86
  server_name="0.0.0.0",
87
- server_port=7860,
88
- share=False
89
  )
 
4
  import torch
5
  from transformers import AutoModelForCausalLM, AutoTokenizer
6
 
 
7
  print("Loading DialoGPT-medium model...")
8
+
9
+ # Load the tokenizer and model
10
  tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
11
  model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
12
+
13
  print("Model loaded successfully!")
14
 
15
+ # Define the prediction function that works with the modern format
16
+ def predict(message, history):
17
  try:
18
  # Validate input
19
  if not message or not message.strip():
20
  return "Please enter a message."
21
 
22
+ # Format history for DialoGPT - handle both old and new formats
23
+ history_transformer_format = ""
24
+
25
+ # Handle the new 'messages' format (list of dicts)
26
+ if history and isinstance(history[0], dict):
27
+ for turn in history:
28
+ if turn.get("role") == "user":
29
+ history_transformer_format += turn["content"] + tokenizer.eos_token
30
+ elif turn.get("role") == "assistant":
31
+ history_transformer_format += turn["content"] + tokenizer.eos_token
32
+
33
+ # Handle the old 'tuples' format (list of lists)
34
+ elif history and isinstance(history[0], list):
35
+ for exchange in history:
36
+ if len(exchange) >= 2:
37
+ user_msg, bot_msg = exchange[0], exchange[1]
38
+ if user_msg:
39
+ history_transformer_format += str(user_msg) + tokenizer.eos_token
40
+ if bot_msg:
41
+ history_transformer_format += str(bot_msg) + tokenizer.eos_token
42
+
43
+ # Add the current message
44
+ input_text = history_transformer_format + str(message) + tokenizer.eos_token
45
 
46
  # Tokenize the input
47
  new_user_input_ids = tokenizer.encode(input_text, return_tensors='pt')
 
50
  with torch.no_grad():
51
  bot_output_ids = model.generate(
52
  new_user_input_ids,
53
+ max_length=1000, # Reduced for better performance
54
  pad_token_id=tokenizer.eos_token_id,
55
  no_repeat_ngram_size=3,
56
  do_sample=True,
57
+ top_k=50,
58
  top_p=0.7,
59
+ temperature=0.8
 
60
  )
61
 
62
  # Decode the response
 
65
  skip_special_tokens=True
66
  ).strip()
67
 
68
+ # Clean up and validate response
69
  if not response:
70
+ response = "I'm sorry, I couldn't generate a response. Could you try rephrasing your question?"
71
+
72
+ # Limit response length to prevent protocol errors
73
+ if len(response) > 500:
74
+ response = response[:500] + "..."
75
 
76
  return response
77
 
78
  except Exception as e:
79
+ print(f"Error in predict function: {str(e)}")
80
+ return "Sorry, I encountered an error. Please try again with a different message."
81
 
82
+ # Create a simple ChatInterface
83
  demo = gr.ChatInterface(
84
+ fn=predict,
85
+ title="DialoGPT-medium Chatbot",
86
+ description="Chat with Microsoft's DialoGPT-medium model!",
87
  examples=[
88
+ "Hello!",
89
+ "How are you?",
90
  "Tell me a joke",
91
+ "What's the weather like?"
92
  ],
93
  cache_examples=False
94
  )
95
 
96
+ # Launch the app with public sharing enabled
97
  if __name__ == "__main__":
 
98
  demo.launch(
99
+ share=True, # This creates the public link
100
  server_name="0.0.0.0",
101
+ server_port=7860
 
102
  )