Fix prepare_huggingface_generation_config
Browse files
app.py
CHANGED
|
@@ -149,7 +149,6 @@ def prepare_huggingface_generation_config(generation_config):
|
|
| 149 |
|
| 150 |
# I'm not sure what is the right behavior here, but it is better to be explicit
|
| 151 |
for name, params in GENERATION_CONFIG_PARAMS.items():
|
| 152 |
-
# Checking for START to examine the a slider parameters only
|
| 153 |
if (
|
| 154 |
"START" in params
|
| 155 |
and params["SAMPLING"]
|
|
@@ -159,21 +158,21 @@ def prepare_huggingface_generation_config(generation_config):
|
|
| 159 |
if generation_config[name] == params["DEFAULT"]:
|
| 160 |
generation_config[name] = None
|
| 161 |
else:
|
| 162 |
-
assert generation_config
|
| 163 |
|
| 164 |
-
#
|
| 165 |
-
if generation_config
|
| 166 |
generation_config["max_tokens"] = generation_config.pop("max_new_tokens")
|
| 167 |
-
|
| 168 |
generation_config["stop"] = generation_config.pop("stop_sequences")
|
| 169 |
del generation_config["do_sample"]
|
| 170 |
del generation_config["top_k"]
|
| 171 |
|
| 172 |
-
is_chat = generation_config.pop("is_chat")
|
| 173 |
|
| 174 |
return generation_config, is_chat
|
| 175 |
|
| 176 |
|
|
|
|
| 177 |
def escape_markdown(text):
|
| 178 |
escape_dict = {
|
| 179 |
"*": r"\*",
|
|
|
|
| 149 |
|
| 150 |
# I'm not sure what is the right behavior here, but it is better to be explicit
|
| 151 |
for name, params in GENERATION_CONFIG_PARAMS.items():
|
|
|
|
| 152 |
if (
|
| 153 |
"START" in params
|
| 154 |
and params["SAMPLING"]
|
|
|
|
| 158 |
if generation_config[name] == params["DEFAULT"]:
|
| 159 |
generation_config[name] = None
|
| 160 |
else:
|
| 161 |
+
assert generation_config.get("do_sample", False)
|
| 162 |
|
| 163 |
+
# FIX: use .get() to avoid KeyError
|
| 164 |
+
if generation_config.get("is_chat", False):
|
| 165 |
generation_config["max_tokens"] = generation_config.pop("max_new_tokens")
|
|
|
|
| 166 |
generation_config["stop"] = generation_config.pop("stop_sequences")
|
| 167 |
del generation_config["do_sample"]
|
| 168 |
del generation_config["top_k"]
|
| 169 |
|
| 170 |
+
is_chat = generation_config.pop("is_chat", False)
|
| 171 |
|
| 172 |
return generation_config, is_chat
|
| 173 |
|
| 174 |
|
| 175 |
+
|
| 176 |
def escape_markdown(text):
|
| 177 |
escape_dict = {
|
| 178 |
"*": r"\*",
|