xtreme86 committed
Commit 50bb5db · 1 Parent(s): 2dc57ec
Files changed (2):
  1. app.py +66 -59
  2. requirements.txt +1 -1
app.py CHANGED
@@ -1,19 +1,14 @@
-# Requires gradio==3.16.2
-# Requires huggingface_hub==0.11.0
-
 import gradio as gr
-from huggingface_hub import InferenceClient
+import transformers
+import torch
 import logging
-from functools import lru_cache
 import html
 import signal
+from functools import lru_cache
 
 # Setup logging
 logging.basicConfig(level=logging.INFO)
 
-# Initialize the Hugging Face Inference Client
-client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
-
 # Constants
 MAX_HISTORY_LENGTH = 5  # Adjust as needed
 
@@ -46,34 +41,6 @@ def system_message_selector(choice, custom_message):
     else:
         return "You are a helpful assistant."
 
-@lru_cache(maxsize=32)
-def get_response_from_model(messages_tuple, max_tokens, temperature, top_p):
-    """
-    Calls the Hugging Face Inference API to get a response.
-
-    Parameters:
-        messages_tuple (tuple): A tuple of messages to be sent to the model.
-        max_tokens (int): Maximum number of tokens for the response.
-        temperature (float): Sampling temperature.
-        top_p (float): Top-p (nucleus) sampling parameter.
-
-    Returns:
-        str: The generated response from the model.
-    """
-    # Convert tuple back to list of dicts
-    messages = [dict(m) for m in messages_tuple]
-    response = ""
-    for message in client.chat_completion(
-        messages,
-        max_tokens=max_tokens,
-        stream=True,
-        temperature=temperature,
-        top_p=top_p,
-    ):
-        token = message.choices[0].delta.content
-        response += token
-    return response
-
 def sanitize_input(text):
     """
     Sanitizes user input to prevent code injection or XSS attacks.
@@ -106,9 +73,55 @@ def validate_parameters(max_tokens, temperature, top_p):
         return False, "Error: 'Top-p' must be between 0.1 and 1.0."
     return True, ""
 
+# Load the model and tokenizer
+model_name = "HuggingFaceH4/zephyr-7b-beta"  # Replace with your actual model name
+
+try:
+    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
+    model = transformers.AutoModelForCausalLM.from_pretrained(
+        model_name,
+        torch_dtype=torch.float16,
+        device_map="auto"  # Automatically places model layers on available GPUs
+    )
+    model.eval()
+except Exception as e:
+    logging.error(f"Failed to load model {model_name}: {e}")
+    exit(1)
+
+@lru_cache(maxsize=32)
+def generate_response(prompt, max_tokens, temperature, top_p):
+    """
+    Generates a response using the loaded language model.
+
+    Parameters:
+        prompt (str): The input prompt for the model.
+        max_tokens (int): Maximum number of tokens for the response.
+        temperature (float): Sampling temperature.
+        top_p (float): Top-p (nucleus) sampling parameter.
+
+    Returns:
+        str: The generated response from the model.
+    """
+    input_ids = tokenizer.encode(prompt, return_tensors="pt")
+    input_ids = input_ids.to(model.device)
+
+    with torch.no_grad():
+        output_ids = model.generate(
+            input_ids,
+            max_length=input_ids.shape[1] + max_tokens,
+            temperature=temperature,
+            top_p=top_p,
+            do_sample=True,
+            pad_token_id=tokenizer.eos_token_id,
+            eos_token_id=tokenizer.eos_token_id,
+        )
+
+    generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
+    return generated_text[len(prompt):].strip()
+
 def respond(message, history, persona_choice, custom_persona, max_tokens, temperature, top_p):
     """
-    Generates a response using the Hugging Face Inference API.
+    Generates a response using the loaded language model.
 
     Parameters:
         message (str): User's current input.
@@ -119,14 +132,13 @@ def respond(message, history, persona_choice, custom_persona, max_tokens, temper
         temperature (float): Sampling temperature.
        top_p (float): Top-p (nucleus sampling) parameter.
 
-    Yields:
+    Returns:
         str: The generated chatbot response.
     """
     # Validate parameters
     is_valid, error_message = validate_parameters(max_tokens, temperature, top_p)
     if not is_valid:
-        yield error_message
-        return
+        return error_message
 
     # Sanitize user input
     safe_message = sanitize_input(message)
@@ -138,33 +150,28 @@ def respond(message, history, persona_choice, custom_persona, max_tokens, temper
     # Select system message
     system_message = system_message_selector(persona_choice, custom_persona)
 
-    # Build messages with truncated history
-    messages = [{"role": "system", "content": system_message}]
+    # Build the conversation prompt
+    conversation = system_message + "\n\n"
     for user_msg, bot_msg in truncated_history:
-        if user_msg:
-            messages.append({"role": "user", "content": user_msg})
-        if bot_msg:
-            messages.append({"role": "assistant", "content": bot_msg})
-    messages.append({"role": "user", "content": safe_message})
+        conversation += f"User: {user_msg}\n"
+        conversation += f"Assistant: {bot_msg}\n"
+    conversation += f"User: {safe_message}\nAssistant:"
 
     # Log the request
     logging.info(f"Received message: {safe_message}")
 
     try:
-        # Convert messages to a tuple of tuples for caching
-        messages_tuple = tuple(tuple(m.items()) for m in messages)
-
         # Use caching to optimize performance
-        response = get_response_from_model(
-            messages_tuple,
-            max_tokens,
-            temperature,
-            top_p,
+        response = generate_response(
+            prompt=conversation,
+            max_tokens=max_tokens,
+            temperature=temperature,
+            top_p=top_p,
         )
-        yield response
+        return response
     except Exception as e:
         logging.error(f"An error occurred: {e}")
-        yield "I'm sorry, but something went wrong. Please try again."
+        return "I'm sorry, but something went wrong. Please try again."
 
 # Create the UI components
 system_message_radio = gr.Radio(
@@ -179,7 +186,7 @@ system_message_textbox = gr.Textbox(
 )
 
 max_tokens_slider = gr.Slider(
-    minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"
+    minimum=1, maximum=1024, value=512, step=1, label="Max new tokens"
 )
 
 temperature_slider = gr.Slider(
@@ -190,7 +197,7 @@ top_p_slider = gr.Slider(
     minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"
 )
 
-# Create the ChatInterface directly with additional inputs
+# Create the ChatInterface
 demo = gr.ChatInterface(
     fn=respond,
     additional_inputs=[
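Note on the caching added above: lru_cache keys on the full (prompt, max_tokens, temperature, top_p) tuple, which works because every argument is hashable. Below is a minimal sketch of those cache semantics with the model call stubbed out so it runs without the 7B weights; generate_response_stub and its echo output are illustrative only, not part of the commit.

from functools import lru_cache

CALLS = 0

@lru_cache(maxsize=32)
def generate_response_stub(prompt, max_tokens, temperature, top_p):
    """Stand-in for generate_response; counts real "model" invocations."""
    global CALLS
    CALLS += 1
    return f"echo: {prompt[-20:]}"

# Identical argument tuples hit the cache: the "model" runs once, not twice.
generate_response_stub("User: hi\nAssistant:", 64, 0.7, 0.95)
generate_response_stub("User: hi\nAssistant:", 64, 0.7, 0.95)
assert CALLS == 1

# Changing any argument (here, temperature) forces a fresh generation.
generate_response_stub("User: hi\nAssistant:", 64, 0.8, 0.95)
assert CALLS == 2

A side effect worth knowing: because the real function samples (do_sample=True), a cache hit replays one particular sampled completion, so identical inputs at identical settings return the identical reply until the entry is evicted (maxsize=32).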
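The rewritten prompt builder concatenates plain "User:"/"Assistant:" turns, whereas Zephyr-7B-beta was fine-tuned on a <|system|>/<|user|>/<|assistant|> chat format with </s> separators; a plain prompt can work but sits off the model's training distribution. Here is a hedged sketch of assembling that format by hand; build_zephyr_prompt is a hypothetical helper, written manually because the pinned transformers==4.31.0 predates tokenizer.apply_chat_template (added in 4.34).

def build_zephyr_prompt(system_message, history, user_message):
    """Assemble the <|system|>/<|user|>/<|assistant|> format Zephyr-7B-beta
    was trained on; an assumed drop-in for the concatenation in respond()."""
    prompt = f"<|system|>\n{system_message}</s>\n"
    for user_msg, bot_msg in history:
        prompt += f"<|user|>\n{user_msg}</s>\n"
        prompt += f"<|assistant|>\n{bot_msg}</s>\n"
    prompt += f"<|user|>\n{user_message}</s>\n<|assistant|>\n"
    return prompt

# Example usage against the same (user, bot) history tuples respond() iterates:
print(build_zephyr_prompt("You are a helpful assistant.",
                          [("Hi", "Hello!")], "What does top-p do?"))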
requirements.txt CHANGED
@@ -1,3 +1,3 @@
 transformers==4.31.0
 gradio==3.40.1
-torch==2.0.1
+torch==2.0.1
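With the model now loaded in-process, a mismatch between the pins above and the installed environment is a likely failure mode (for example, a torch build that differs from the pinned 2.0.1). A small startup check one could add; this sketch is not part of the commit, and EXPECTED simply mirrors requirements.txt.

import gradio
import torch
import transformers

EXPECTED = {"transformers": "4.31.0", "gradio": "3.40.1", "torch": "2.0.1"}
MODULES = {"transformers": transformers, "gradio": gradio, "torch": torch}
for name, want in EXPECTED.items():
    got = MODULES[name].__version__
    # torch may report a local build tag such as "2.0.1+cu117", so match the prefix
    if not got.startswith(want):
        raise RuntimeError(f"{name} {got} installed but {want} pinned")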