xtreme86 commited on
Commit
8d2d1dc
·
1 Parent(s): 50bb5db
Files changed (2) hide show
  1. app.py +17 -94
  2. requirements.txt +3 -2
app.py CHANGED
@@ -4,7 +4,6 @@ import torch
4
  import logging
5
  import html
6
  import signal
7
- from functools import lru_cache
8
 
9
  # Setup logging
10
  logging.basicConfig(level=logging.INFO)
@@ -20,16 +19,6 @@ def shutdown_handler(signum, frame):
20
  signal.signal(signal.SIGINT, shutdown_handler)
21
 
22
  def system_message_selector(choice, custom_message):
23
- """
24
- Selects the system message based on the user's choice or custom input.
25
-
26
- Parameters:
27
- choice (str): The persona choice selected by the user.
28
- custom_message (str): A custom persona or system message provided by the user.
29
-
30
- Returns:
31
- str: The system message to be used in the conversation.
32
- """
33
  if custom_message:
34
  return custom_message
35
  elif choice == "Friendly Chatbot":
@@ -42,29 +31,9 @@ def system_message_selector(choice, custom_message):
42
  return "You are a helpful assistant."
43
 
44
  def sanitize_input(text):
45
- """
46
- Sanitizes user input to prevent code injection or XSS attacks.
47
-
48
- Parameters:
49
- text (str): The user input text.
50
-
51
- Returns:
52
- str: The sanitized text.
53
- """
54
  return html.escape(text)
55
 
56
  def validate_parameters(max_tokens, temperature, top_p):
57
- """
58
- Validates input parameters.
59
-
60
- Parameters:
61
- max_tokens (int): Maximum number of tokens for the response.
62
- temperature (float): Sampling temperature.
63
- top_p (float): Top-p (nucleus) sampling parameter.
64
-
65
- Returns:
66
- tuple: (bool, str) indicating validity and an error message if invalid.
67
- """
68
  if not (1 <= max_tokens <= 2048):
69
  return False, "Error: 'Max new tokens' must be between 1 and 2048."
70
  if not (0.1 <= temperature <= 4.0):
@@ -74,101 +43,55 @@ def validate_parameters(max_tokens, temperature, top_p):
74
  return True, ""
75
 
76
  # Load the model and tokenizer
77
- model_name = "HuggingFaceH4/zephyr-7b-beta" # Replace with your actual model name
78
 
79
  try:
80
- tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
81
- model = transformers.AutoModelForCausalLM.from_pretrained(
 
 
82
  model_name,
83
  torch_dtype=torch.float16,
84
- device_map="auto" # Automatically places model layers on available GPUs
85
  )
86
  model.eval()
87
  except Exception as e:
88
  logging.error(f"Failed to load model {model_name}: {e}")
89
  exit(1)
90
 
91
- @lru_cache(maxsize=32)
92
- def generate_response(prompt, max_tokens, temperature, top_p):
93
- """
94
- Generates a response using the loaded language model.
95
-
96
- Parameters:
97
- prompt (str): The input prompt for the model.
98
- max_tokens (int): Maximum number of tokens for the response.
99
- temperature (float): Sampling temperature.
100
- top_p (float): Top-p (nucleus) sampling parameter.
101
-
102
- Returns:
103
- str: The generated response from the model.
104
- """
105
- input_ids = tokenizer.encode(prompt, return_tensors="pt")
106
- input_ids = input_ids.to(model.device)
107
-
108
- with torch.no_grad():
109
- output_ids = model.generate(
110
- input_ids,
111
- max_length=input_ids.shape[1] + max_tokens,
112
- temperature=temperature,
113
- top_p=top_p,
114
- do_sample=True,
115
- pad_token_id=tokenizer.eos_token_id,
116
- eos_token_id=tokenizer.eos_token_id,
117
- )
118
-
119
- generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
120
- return generated_text[len(prompt):].strip()
121
-
122
  def respond(message, history, persona_choice, custom_persona, max_tokens, temperature, top_p):
123
- """
124
- Generates a response using the loaded language model.
125
-
126
- Parameters:
127
- message (str): User's current input.
128
- history (list[tuple[str, str]]): Previous conversation history.
129
- persona_choice (str): The selected persona.
130
- custom_persona (str): Custom persona or system message.
131
- max_tokens (int): Maximum tokens allowed for the response.
132
- temperature (float): Sampling temperature.
133
- top_p (float): Top-p (nucleus sampling) parameter.
134
-
135
- Returns:
136
- str: The generated chatbot response.
137
- """
138
- # Validate parameters
139
  is_valid, error_message = validate_parameters(max_tokens, temperature, top_p)
140
  if not is_valid:
141
  return error_message
142
 
143
- # Sanitize user input
144
  safe_message = sanitize_input(message)
145
  safe_history = [(sanitize_input(u), sanitize_input(b)) for u, b in history]
146
-
147
- # Limit the history to the most recent exchanges
148
  truncated_history = safe_history[-MAX_HISTORY_LENGTH:]
149
-
150
- # Select system message
151
  system_message = system_message_selector(persona_choice, custom_persona)
152
 
153
- # Build the conversation prompt
154
  conversation = system_message + "\n\n"
155
  for user_msg, bot_msg in truncated_history:
156
  conversation += f"User: {user_msg}\n"
157
  conversation += f"Assistant: {bot_msg}\n"
158
  conversation += f"User: {safe_message}\nAssistant:"
159
 
160
- # Log the request
161
  logging.info(f"Received message: {safe_message}")
162
 
163
  try:
164
- # Use caching to optimize performance
165
- response = generate_response(
166
- prompt=conversation,
167
- max_tokens=max_tokens,
 
168
  temperature=temperature,
169
  top_p=top_p,
 
 
 
170
  )
171
- return response
 
 
172
  except Exception as e:
173
  logging.error(f"An error occurred: {e}")
174
  return "I'm sorry, but something went wrong. Please try again."
 
4
  import logging
5
  import html
6
  import signal
 
7
 
8
  # Setup logging
9
  logging.basicConfig(level=logging.INFO)
 
19
  signal.signal(signal.SIGINT, shutdown_handler)
20
 
21
  def system_message_selector(choice, custom_message):
 
 
 
 
 
 
 
 
 
 
22
  if custom_message:
23
  return custom_message
24
  elif choice == "Friendly Chatbot":
 
31
  return "You are a helpful assistant."
32
 
33
  def sanitize_input(text):
 
 
 
 
 
 
 
 
 
34
  return html.escape(text)
35
 
36
  def validate_parameters(max_tokens, temperature, top_p):
 
 
 
 
 
 
 
 
 
 
 
37
  if not (1 <= max_tokens <= 2048):
38
  return False, "Error: 'Max new tokens' must be between 1 and 2048."
39
  if not (0.1 <= temperature <= 4.0):
 
43
  return True, ""
44
 
45
  # Load the model and tokenizer
46
+ model_name = "HuggingFaceH4/mistral-7b-instruct" # Update with the correct model name
47
 
48
  try:
49
+ from transformers import MistralForCausalLM, MistralTokenizer
50
+
51
+ tokenizer = MistralTokenizer.from_pretrained(model_name)
52
+ model = MistralForCausalLM.from_pretrained(
53
  model_name,
54
  torch_dtype=torch.float16,
55
+ device_map="auto",
56
  )
57
  model.eval()
58
  except Exception as e:
59
  logging.error(f"Failed to load model {model_name}: {e}")
60
  exit(1)
61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  def respond(message, history, persona_choice, custom_persona, max_tokens, temperature, top_p):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  is_valid, error_message = validate_parameters(max_tokens, temperature, top_p)
64
  if not is_valid:
65
  return error_message
66
 
 
67
  safe_message = sanitize_input(message)
68
  safe_history = [(sanitize_input(u), sanitize_input(b)) for u, b in history]
 
 
69
  truncated_history = safe_history[-MAX_HISTORY_LENGTH:]
 
 
70
  system_message = system_message_selector(persona_choice, custom_persona)
71
 
 
72
  conversation = system_message + "\n\n"
73
  for user_msg, bot_msg in truncated_history:
74
  conversation += f"User: {user_msg}\n"
75
  conversation += f"Assistant: {bot_msg}\n"
76
  conversation += f"User: {safe_message}\nAssistant:"
77
 
 
78
  logging.info(f"Received message: {safe_message}")
79
 
80
  try:
81
+ input_ids = tokenizer.encode(conversation, return_tensors="pt").to(model.device)
82
+
83
+ output_ids = model.generate(
84
+ input_ids,
85
+ max_new_tokens=max_tokens,
86
  temperature=temperature,
87
  top_p=top_p,
88
+ do_sample=True,
89
+ pad_token_id=tokenizer.eos_token_id,
90
+ eos_token_id=tokenizer.eos_token_id,
91
  )
92
+
93
+ generated_text = tokenizer.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True)
94
+ return generated_text.strip()
95
  except Exception as e:
96
  logging.error(f"An error occurred: {e}")
97
  return "I'm sorry, but something went wrong. Please try again."
requirements.txt CHANGED
@@ -1,3 +1,4 @@
1
- transformers==4.31.0
2
  gradio==3.40.1
3
- torch==2.0.1
 
 
1
+ transformers>=4.34.0
2
  gradio==3.40.1
3
+ torch>=2.0.1
4
+ xformers