shellyTa commited on
Commit
04f1c61
·
1 Parent(s): 4ecff45

Fix prepare_huggingface_generation_config

Browse files
Files changed (1) hide show
  1. app.py +5 -6
app.py CHANGED
@@ -149,7 +149,6 @@ def prepare_huggingface_generation_config(generation_config):
149
 
150
  # I'm not sure what is the right behavior here, but it is better to be explicit
151
  for name, params in GENERATION_CONFIG_PARAMS.items():
152
- # Checking for START to examine the a slider parameters only
153
  if (
154
  "START" in params
155
  and params["SAMPLING"]
@@ -159,21 +158,21 @@ def prepare_huggingface_generation_config(generation_config):
159
  if generation_config[name] == params["DEFAULT"]:
160
  generation_config[name] = None
161
  else:
162
- assert generation_config["do_sample"]
163
 
164
- # TODO: refactor this part
165
- if generation_config["is_chat"]:
166
  generation_config["max_tokens"] = generation_config.pop("max_new_tokens")
167
-
168
  generation_config["stop"] = generation_config.pop("stop_sequences")
169
  del generation_config["do_sample"]
170
  del generation_config["top_k"]
171
 
172
- is_chat = generation_config.pop("is_chat")
173
 
174
  return generation_config, is_chat
175
 
176
 
 
177
  def escape_markdown(text):
178
  escape_dict = {
179
  "*": r"\*",
 
149
 
150
  # I'm not sure what is the right behavior here, but it is better to be explicit
151
  for name, params in GENERATION_CONFIG_PARAMS.items():
 
152
  if (
153
  "START" in params
154
  and params["SAMPLING"]
 
158
  if generation_config[name] == params["DEFAULT"]:
159
  generation_config[name] = None
160
  else:
161
+ assert generation_config.get("do_sample", False)
162
 
163
+ # FIX: use .get() to avoid KeyError
164
+ if generation_config.get("is_chat", False):
165
  generation_config["max_tokens"] = generation_config.pop("max_new_tokens")
 
166
  generation_config["stop"] = generation_config.pop("stop_sequences")
167
  del generation_config["do_sample"]
168
  del generation_config["top_k"]
169
 
170
+ is_chat = generation_config.pop("is_chat", False)
171
 
172
  return generation_config, is_chat
173
 
174
 
175
+
176
  def escape_markdown(text):
177
  escape_dict = {
178
  "*": r"\*",