xtreme86 committed
Commit 2dc57ec · 1 Parent(s): 0d7aeff
Files changed (1)
  1. app.py +171 -55
app.py CHANGED
@@ -1,68 +1,170 @@
+ # Requires gradio==3.16.2
+ # Requires huggingface_hub==0.11.0
+
  import gradio as gr
  from huggingface_hub import InferenceClient
+ import logging
+ from functools import lru_cache
+ import html
+ import signal
+
+ # Setup logging
+ logging.basicConfig(level=logging.INFO)
 
+ # Initialize the Hugging Face Inference Client
  client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
 
- def respond(
-     message: str,
-     history: list[tuple[str, str]],
-     system_message: str,
-     max_tokens: int,
-     temperature: float,
-     top_p: float,
- ):
+ # Constants
+ MAX_HISTORY_LENGTH = 5  # Adjust as needed
+
+ # Graceful shutdown handler
+ def shutdown_handler(signum, frame):
+     logging.info("Shutting down gracefully...")
+     exit(0)
+
+ signal.signal(signal.SIGINT, shutdown_handler)
+
+ def system_message_selector(choice, custom_message):
+     """
+     Selects the system message based on the user's choice or custom input.
+
+     Parameters:
+         choice (str): The persona choice selected by the user.
+         custom_message (str): A custom persona or system message provided by the user.
+
+     Returns:
+         str: The system message to be used in the conversation.
+     """
+     if custom_message:
+         return custom_message
+     elif choice == "Friendly Chatbot":
+         return "You are a friendly and helpful chatbot."
+     elif choice == "Professional Assistant":
+         return "You are a highly knowledgeable and professional assistant."
+     elif choice == "Curious Researcher":
+         return "You are a curious researcher who loves to explore new ideas."
+     else:
+         return "You are a helpful assistant."
+
+ @lru_cache(maxsize=32)
+ def get_response_from_model(messages_tuple, max_tokens, temperature, top_p):
+     """
+     Calls the Hugging Face Inference API to get a response.
+
+     Parameters:
+         messages_tuple (tuple): A tuple of messages to be sent to the model.
+         max_tokens (int): Maximum number of tokens for the response.
+         temperature (float): Sampling temperature.
+         top_p (float): Top-p (nucleus) sampling parameter.
+
+     Returns:
+         str: The generated response from the model.
+     """
+     # Convert tuple back to list of dicts
+     messages = [dict(m) for m in messages_tuple]
+     response = ""
+     for message in client.chat_completion(
+         messages,
+         max_tokens=max_tokens,
+         stream=True,
+         temperature=temperature,
+         top_p=top_p,
+     ):
+         token = message.choices[0].delta.content
+         response += token
+     return response
+
+ def sanitize_input(text):
+     """
+     Sanitizes user input to prevent code injection or XSS attacks.
+
+     Parameters:
+         text (str): The user input text.
+
+     Returns:
+         str: The sanitized text.
+     """
+     return html.escape(text)
+
+ def validate_parameters(max_tokens, temperature, top_p):
+     """
+     Validates input parameters.
+
+     Parameters:
+         max_tokens (int): Maximum number of tokens for the response.
+         temperature (float): Sampling temperature.
+         top_p (float): Top-p (nucleus) sampling parameter.
+
+     Returns:
+         tuple: (bool, str) indicating validity and an error message if invalid.
+     """
+     if not (1 <= max_tokens <= 2048):
+         return False, "Error: 'Max new tokens' must be between 1 and 2048."
+     if not (0.1 <= temperature <= 4.0):
+         return False, "Error: 'Temperature' must be between 0.1 and 4.0."
+     if not (0.1 <= top_p <= 1.0):
+         return False, "Error: 'Top-p' must be between 0.1 and 1.0."
+     return True, ""
+
+ def respond(message, history, persona_choice, custom_persona, max_tokens, temperature, top_p):
      """
      Generates a response using the Hugging Face Inference API.
 
-     Args:
+     Parameters:
          message (str): User's current input.
          history (list[tuple[str, str]]): Previous conversation history.
-         system_message (str): Instructions for the model (e.g., persona details).
+         persona_choice (str): The selected persona.
+         custom_persona (str): Custom persona or system message.
          max_tokens (int): Maximum tokens allowed for the response.
-         temperature (float): Sampling temperature for randomness in the output.
-         top_p (float): Top-p (nucleus) sampling parameter.
+         temperature (float): Sampling temperature.
+         top_p (float): Top-p (nucleus sampling) parameter.
 
      Yields:
          str: The generated chatbot response.
      """
-     messages = [{"role": "system", "content": system_message}]
-
-     for val in history:
-         if val[0]:
-             messages.append({"role": "user", "content": val[0]})
-         if val[1]:
-             messages.append({"role": "assistant", "content": val[1]})
-
-     messages.append({"role": "user", "content": message})
-
-     response = ""
-
-     try:
-         for message in client.chat_completion(
-             messages,
-             max_tokens=max_tokens,
-             stream=True,
-             temperature=temperature,
-             top_p=top_p,
-         ):
-             token = message.choices[0].delta.content
-             response += token
-             yield response
-     except Exception as e:
-         yield f"Error: {str(e)}"
+     # Validate parameters
+     is_valid, error_message = validate_parameters(max_tokens, temperature, top_p)
+     if not is_valid:
+         yield error_message
+         return
+
+     # Sanitize user input
+     safe_message = sanitize_input(message)
+     safe_history = [(sanitize_input(u), sanitize_input(b)) for u, b in history]
+
+     # Limit the history to the most recent exchanges
+     truncated_history = safe_history[-MAX_HISTORY_LENGTH:]
+
+     # Select system message
+     system_message = system_message_selector(persona_choice, custom_persona)
+
+     # Build messages with truncated history
+     messages = [{"role": "system", "content": system_message}]
+     for user_msg, bot_msg in truncated_history:
+         if user_msg:
+             messages.append({"role": "user", "content": user_msg})
+         if bot_msg:
+             messages.append({"role": "assistant", "content": bot_msg})
+     messages.append({"role": "user", "content": safe_message})
+
+     # Log the request
+     logging.info(f"Received message: {safe_message}")
+
+     try:
+         # Convert messages to a tuple of tuples for caching
+         messages_tuple = tuple(tuple(m.items()) for m in messages)
+
+         # Use caching to optimize performance
+         response = get_response_from_model(
+             messages_tuple,
+             max_tokens,
+             temperature,
+             top_p,
+         )
+         yield response
+     except Exception as e:
+         logging.error(f"An error occurred: {e}")
+         yield "I'm sorry, but something went wrong. Please try again."
 
- def system_message_selector(choice):
-     if choice == "Friendly Chatbot":
-         return "You are a friendly and helpful chatbot."
-     elif choice == "Professional Assistant":
-         return "You are a highly knowledgeable and professional assistant."
-     elif choice == "Curious Researcher":
-         return "You are a curious researcher who loves to explore new ideas."
-     else:
-         return "You are a helpful assistant."
-
  # Create the UI components
  system_message_radio = gr.Radio(
@@ -71,23 +173,37 @@ system_message_radio = gr.Radio(
      label="Choose a Persona",
  )
 
- # ChatInterface with dynamic system message selection
+ system_message_textbox = gr.Textbox(
+     placeholder="Enter custom persona or system message...",
+     label="Custom Persona (Optional)",
+ )
+
+ max_tokens_slider = gr.Slider(
+     minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"
+ )
+
+ temperature_slider = gr.Slider(
+     minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"
+ )
+
+ top_p_slider = gr.Slider(
+     minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"
+ )
+
+ # Create the ChatInterface directly with additional inputs
  demo = gr.ChatInterface(
-     respond,
+     fn=respond,
      additional_inputs=[
          system_message_radio,
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(
-             minimum=0.1,
-             maximum=1.0,
-             value=0.95,
-             step=0.05,
-             label="Top-p (nucleus sampling)",
-         ),
+         system_message_textbox,
+         max_tokens_slider,
+         temperature_slider,
+         top_p_slider,
      ],
+     allow_reset_history=True,
+     title="Customizable Chatbot Interface",
+     description="Choose a persona or enter a custom one, and adjust parameters as needed.",
  )
 
-
  if __name__ == "__main__":
      demo.launch()
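
A note on the caching step in this commit: functools.lru_cache only accepts hashable arguments, which is why respond freezes the list of message dicts into a tuple of tuples before calling get_response_from_model. A minimal sketch of that round-trip; cached_call and its body are illustrative stand-ins, not the commit's code:

from functools import lru_cache

@lru_cache(maxsize=32)
def cached_call(messages_tuple):
    # Rebuild the list-of-dicts shape the API expects from the hashable form.
    messages = [dict(items) for items in messages_tuple]
    return len(messages)  # stand-in for the real model call

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hi!"},
]
# Freeze each dict into a tuple of (key, value) pairs so the whole
# conversation is hashable and can serve as a cache key.
key = tuple(tuple(m.items()) for m in messages)
assert cached_call(key) == 2
assert cached_call(key) == 2  # second call is served from the cache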
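The sanitization step relies on html.escape, which substitutes HTML entities for markup-significant characters; note that the escaped text, not the original, is what gets replayed into the prompt:

import html

print(html.escape("<script>alert('hi')</script>"))
# &lt;script&gt;alert(&#x27;hi&#x27;)&lt;/script&gt;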
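The MAX_HISTORY_LENGTH truncation bounds the prompt by replaying only the last five (user, assistant) pairs ahead of the current message. A small self-contained sketch of the message list respond builds under that rule; build_messages is a hypothetical name for illustration:

MAX_HISTORY_LENGTH = 5

def build_messages(system_message, history, message):
    # Replay only the most recent exchanges so the prompt stays bounded.
    messages = [{"role": "system", "content": system_message}]
    for user_msg, bot_msg in history[-MAX_HISTORY_LENGTH:]:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if bot_msg:
            messages.append({"role": "assistant", "content": bot_msg})
    messages.append({"role": "user", "content": message})
    return messages

history = [(f"question {i}", f"answer {i}") for i in range(8)]
msgs = build_messages("You are a helpful assistant.", history, "next question")
# 1 system + 5 truncated pairs x 2 + 1 new user message = 12 entries
assert len(msgs) == 12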
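Both the removed respond body and the new get_response_from_model consume the chat_completion stream by accumulating each chunk's delta content. One detail worth guarding in this pattern: delta.content can be None on the final chunk, which would break the string concatenation. A hedged sketch of the consumption loop, assuming a huggingface_hub version whose InferenceClient provides chat_completion; stream_reply is an illustrative name:

from huggingface_hub import InferenceClient

client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")

def stream_reply(messages, max_tokens=128):
    response = ""
    for chunk in client.chat_completion(messages, max_tokens=max_tokens, stream=True):
        token = chunk.choices[0].delta.content
        if token:  # delta content may be None on the final chunk
            response += token
            yield response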