Deva1211 committed on
Commit
f6aa413
·
verified ·
1 Parent(s): de0ccfd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -79
app.py CHANGED
@@ -12,91 +12,49 @@ model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
12
 
13
  print("Model loaded successfully!")
14
 
15
- # Define the prediction function that works with the modern format
16
  def predict(message, history):
17
- try:
18
- # Validate input
19
- if not message or not message.strip():
20
- return "Please enter a message."
21
-
22
- # Format history for DialoGPT - handle both old and new formats
23
- history_transformer_format = ""
24
-
25
- # Handle the new 'messages' format (list of dicts)
26
- if history and isinstance(history[0], dict):
27
- for turn in history:
28
- if turn.get("role") == "user":
29
- history_transformer_format += turn["content"] + tokenizer.eos_token
30
- elif turn.get("role") == "assistant":
31
- history_transformer_format += turn["content"] + tokenizer.eos_token
32
-
33
- # Handle the old 'tuples' format (list of lists)
34
- elif history and isinstance(history[0], list):
35
- for exchange in history:
36
- if len(exchange) >= 2:
37
- user_msg, bot_msg = exchange[0], exchange[1]
38
- if user_msg:
39
- history_transformer_format += str(user_msg) + tokenizer.eos_token
40
- if bot_msg:
41
- history_transformer_format += str(bot_msg) + tokenizer.eos_token
42
-
43
- # Add the current message
44
- input_text = history_transformer_format + str(message) + tokenizer.eos_token
45
-
46
- # Tokenize the input
47
- new_user_input_ids = tokenizer.encode(input_text, return_tensors='pt')
48
-
49
- # Generate a response with memory management
50
- with torch.no_grad():
51
- bot_output_ids = model.generate(
52
- new_user_input_ids,
53
- max_length=1000, # Reduced for better performance
54
- pad_token_id=tokenizer.eos_token_id,
55
- no_repeat_ngram_size=3,
56
- do_sample=True,
57
- top_k=50,
58
- top_p=0.7,
59
- temperature=0.8
60
- )
61
 
62
- # Decode the response
63
- response = tokenizer.decode(
64
- bot_output_ids[:, new_user_input_ids.shape[-1]:][0],
65
- skip_special_tokens=True
66
- ).strip()
67
-
68
- # Clean up and validate response
69
- if not response:
70
- response = "I'm sorry, I couldn't generate a response. Could you try rephrasing your question?"
71
-
72
- # Limit response length to prevent protocol errors
73
- if len(response) > 500:
74
- response = response[:500] + "..."
75
-
76
- return response
77
-
78
- except Exception as e:
79
- print(f"Error in predict function: {str(e)}")
80
- return "Sorry, I encountered an error. Please try again with a different message."
81
 
82
- # Create a simple ChatInterface
 
83
  demo = gr.ChatInterface(
84
  fn=predict,
85
  title="DialoGPT-medium Chatbot",
86
- description="Chat with Microsoft's DialoGPT-medium model!",
87
- examples=[
88
- "Hello!",
89
- "How are you?",
90
- "Tell me a joke",
91
- "What's the weather like?"
92
- ],
93
- cache_examples=False
94
  )
95
 
96
- # Launch the app with public sharing enabled
97
  if __name__ == "__main__":
98
- demo.launch(
99
- share=True, # This creates the public link
100
- server_name="0.0.0.0",
101
- server_port=7860
102
- )
 
12
 
13
  print("Model loaded successfully!")
14
 
15
# Define the prediction function; it accepts both Gradio history formats.
def predict(message, history):
    """Generate DialoGPT's reply to ``message`` given the chat ``history``.

    Args:
        message: The new user message.
        history: Earlier turns of the conversation. Either a list of
            ``(user_msg, bot_msg)`` pairs or a list of
            ``{"role": ..., "content": ...}`` dicts; may be empty or None.

    Returns:
        The decoded reply text, without the prompt or special tokens.
    """
    eos = tokenizer.eos_token

    # Flatten the history into the ordered list of utterances, whichever
    # shape the caller handed over.
    past_turns = []
    for entry in history or []:
        if isinstance(entry, dict):
            past_turns.append(entry.get("content"))
        else:
            past_turns.extend(entry[:2])

    # DialoGPT expects one flat string of utterances, each terminated by the
    # EOS token. Empty/None turns (e.g. a reply that is still pending) are
    # skipped instead of breaking the string concatenation.
    prompt = "".join(f"{turn}{eos}" for turn in past_turns if turn)
    prompt += f"{message}{eos}"

    # Tokenize the input
    new_user_input_ids = tokenizer.encode(prompt, return_tensors='pt')

    # DialoGPT is GPT-2 based, so prompt plus reply must fit in the model's
    # position limit (n_positions, 1024 for the published checkpoints).
    # Keep only the most recent tokens and leave room for the reply.
    context_limit = getattr(model.config, "n_positions", 1024)
    min_reply_room = 256
    new_user_input_ids = new_user_input_ids[:, -(context_limit - min_reply_room):]

    # Generate a response
    bot_output_ids = model.generate(
        new_user_input_ids,
        max_length=context_limit,
        pad_token_id=tokenizer.eos_token_id,
        no_repeat_ngram_size=3,
        do_sample=True,
        top_k=100,
        top_p=0.7,
        temperature=0.8,
    )

    # Decode the response, skipping the input part
    prompt_length = new_user_input_ids.shape[-1]
    response = tokenizer.decode(
        bot_output_ids[0, prompt_length:],
        skip_special_tokens=True,
    )

    return response
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
# Assemble the chat UI with gr.ChatInterface, which supplies the chat widgets
# itself and calls `predict` for each submitted message.
# NOTE(review): `undo_btn` / `clear_btn` may not be accepted by every Gradio
# release -- confirm against the version this Space pins.
demo = gr.ChatInterface(
    fn=predict,
    title="DialoGPT-medium Chatbot",
    description=(
        "This chatbot uses the microsoft/DialoGPT-medium model. "
        "Start typing to chat!"
    ),
    theme="soft",
    examples=[
        "Hello!",
        "How does a computer work?",
        "Tell me a joke.",
    ],
    undo_btn="Undo Last Turn",
    clear_btn="Clear Chat",
)

# Start the app only when run as a script. No 'share=True' is passed; the
# original comment states it is not needed on Spaces.
if __name__ == "__main__":
    queued_demo = demo.queue()
    queued_demo.launch()