shellyTa commited on
Commit
e8d6d85
·
1 Parent(s): 04f1c61

Fix prepare_huggingface_generation_config

Browse files
Files changed (1) hide show
  1. app.py +4 -2
app.py CHANGED
@@ -148,6 +148,9 @@ def prepare_huggingface_generation_config(generation_config):
148
  # According to experimentations, it seems that `transformers` behave similarly
149
 
150
  # I'm not sure what is the right behavior here, but it is better to be explicit
 
 
 
151
  for name, params in GENERATION_CONFIG_PARAMS.items():
152
  if (
153
  "START" in params
@@ -160,19 +163,18 @@ def prepare_huggingface_generation_config(generation_config):
160
  else:
161
  assert generation_config.get("do_sample", False)
162
 
163
- # FIX: use .get() to avoid KeyError
164
  if generation_config.get("is_chat", False):
165
  generation_config["max_tokens"] = generation_config.pop("max_new_tokens")
166
  generation_config["stop"] = generation_config.pop("stop_sequences")
167
  del generation_config["do_sample"]
168
  del generation_config["top_k"]
169
 
 
170
  is_chat = generation_config.pop("is_chat", False)
171
 
172
  return generation_config, is_chat
173
 
174
 
175
-
176
  def escape_markdown(text):
177
  escape_dict = {
178
  "*": r"\*",
 
148
  # According to experimentations, it seems that `transformers` behave similarly
149
 
150
  # I'm not sure what is the right behavior here, but it is better to be explicit
151
+ if "is_chat" not in generation_config:
152
+ generation_config["is_chat"] = False
153
+
154
  for name, params in GENERATION_CONFIG_PARAMS.items():
155
  if (
156
  "START" in params
 
163
  else:
164
  assert generation_config.get("do_sample", False)
165
 
 
166
  if generation_config.get("is_chat", False):
167
  generation_config["max_tokens"] = generation_config.pop("max_new_tokens")
168
  generation_config["stop"] = generation_config.pop("stop_sequences")
169
  del generation_config["do_sample"]
170
  del generation_config["top_k"]
171
 
172
+ # FIX: pop safely with default
173
  is_chat = generation_config.pop("is_chat", False)
174
 
175
  return generation_config, is_chat
176
 
177
 
 
178
  def escape_markdown(text):
179
  escape_dict = {
180
  "*": r"\*",