xtreme86 committed on
Commit
6df87c2
·
1 Parent(s): 8d2d1dc
Files changed (2) hide show
  1. app.py +12 -12
  2. requirements.txt +2 -3
app.py CHANGED
@@ -34,8 +34,8 @@ def sanitize_input(text):
34
  return html.escape(text)
35
 
36
  def validate_parameters(max_tokens, temperature, top_p):
37
- if not (1 <= max_tokens <= 2048):
38
- return False, "Error: 'Max new tokens' must be between 1 and 2048."
39
  if not (0.1 <= temperature <= 4.0):
40
  return False, "Error: 'Temperature' must be between 0.1 and 4.0."
41
  if not (0.1 <= top_p <= 1.0):
@@ -43,15 +43,12 @@ def validate_parameters(max_tokens, temperature, top_p):
43
  return True, ""
44
 
45
  # Load the model and tokenizer
46
- model_name = "HuggingFaceH4/mistral-7b-instruct" # Update with the correct model name
47
 
48
  try:
49
- from transformers import MistralForCausalLM, MistralTokenizer
50
-
51
- tokenizer = MistralTokenizer.from_pretrained(model_name)
52
- model = MistralForCausalLM.from_pretrained(
53
  model_name,
54
- torch_dtype=torch.float16,
55
  device_map="auto",
56
  )
57
  model.eval()
@@ -69,6 +66,7 @@ def respond(message, history, persona_choice, custom_persona, max_tokens, temper
69
  truncated_history = safe_history[-MAX_HISTORY_LENGTH:]
70
  system_message = system_message_selector(persona_choice, custom_persona)
71
 
 
72
  conversation = system_message + "\n\n"
73
  for user_msg, bot_msg in truncated_history:
74
  conversation += f"User: {user_msg}\n"
@@ -90,7 +88,9 @@ def respond(message, history, persona_choice, custom_persona, max_tokens, temper
90
  eos_token_id=tokenizer.eos_token_id,
91
  )
92
 
93
- generated_text = tokenizer.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True)
 
 
94
  return generated_text.strip()
95
  except Exception as e:
96
  logging.error(f"An error occurred: {e}")
@@ -109,15 +109,15 @@ system_message_textbox = gr.Textbox(
109
  )
110
 
111
  max_tokens_slider = gr.Slider(
112
- minimum=1, maximum=1024, value=512, step=1, label="Max new tokens"
113
  )
114
 
115
  temperature_slider = gr.Slider(
116
- minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"
117
  )
118
 
119
  top_p_slider = gr.Slider(
120
- minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"
121
  )
122
 
123
  # Create the ChatInterface
 
34
  return html.escape(text)
35
 
36
  def validate_parameters(max_tokens, temperature, top_p):
37
+ if not (1 <= max_tokens <= 1024):
38
+ return False, "Error: 'Max new tokens' must be between 1 and 1024."
39
  if not (0.1 <= temperature <= 4.0):
40
  return False, "Error: 'Temperature' must be between 0.1 and 4.0."
41
  if not (0.1 <= top_p <= 1.0):
 
43
  return True, ""
44
 
45
  # Load the model and tokenizer
46
+ model_name = "gpt2" # Use GPT-2 model
47
 
48
  try:
49
+ tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
50
+ model = transformers.AutoModelForCausalLM.from_pretrained(
 
 
51
  model_name,
 
52
  device_map="auto",
53
  )
54
  model.eval()
 
66
  truncated_history = safe_history[-MAX_HISTORY_LENGTH:]
67
  system_message = system_message_selector(persona_choice, custom_persona)
68
 
69
+ # Build the conversation prompt
70
  conversation = system_message + "\n\n"
71
  for user_msg, bot_msg in truncated_history:
72
  conversation += f"User: {user_msg}\n"
 
88
  eos_token_id=tokenizer.eos_token_id,
89
  )
90
 
91
+ generated_text = tokenizer.decode(
92
+ output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True
93
+ )
94
  return generated_text.strip()
95
  except Exception as e:
96
  logging.error(f"An error occurred: {e}")
 
109
  )
110
 
111
  max_tokens_slider = gr.Slider(
112
+ minimum=1, maximum=1024, value=50, step=1, label="Max new tokens"
113
  )
114
 
115
  temperature_slider = gr.Slider(
116
+ minimum=0.1, maximum=4.0, value=1.0, step=0.1, label="Temperature"
117
  )
118
 
119
  top_p_slider = gr.Slider(
120
+ minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-p (nucleus sampling)"
121
  )
122
 
123
  # Create the ChatInterface
requirements.txt CHANGED
@@ -1,4 +1,3 @@
1
- transformers>=4.34.0
2
  gradio==3.40.1
3
- torch>=2.0.1
4
- xformers
 
1
+ transformers==4.31.0
2
  gradio==3.40.1
3
+ torch==2.0.1