Gregor Betz
commited on
config bugfix
Browse files- app.py +14 -10
- backend/config.py +3 -24
- config.yaml +3 -1
app.py
CHANGED
|
@@ -16,15 +16,6 @@ from backend.svg_processing import postprocess_svg
|
|
| 16 |
|
| 17 |
logging.basicConfig(level=logging.DEBUG)
|
| 18 |
|
| 19 |
-
with open("config.yaml") as stream:
|
| 20 |
-
try:
|
| 21 |
-
DEMO_CONFIG = yaml.safe_load(stream)
|
| 22 |
-
logging.debug(f"Config: {DEMO_CONFIG}")
|
| 23 |
-
except yaml.YAMLError as exc:
|
| 24 |
-
logging.error(f"Error loading config: {exc}")
|
| 25 |
-
raise exc
|
| 26 |
-
|
| 27 |
-
|
| 28 |
|
| 29 |
EXAMPLES = [
|
| 30 |
("We're a nature-loving family with three kids, have some money left, and no plans "
|
|
@@ -94,7 +85,20 @@ CHATBOT_INSTRUCTIONS = (
|
|
| 94 |
)
|
| 95 |
|
| 96 |
# config
|
| 97 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
logging.info(f"Reasoning guide expert model is {guide_kwargs['expert_model']}.")
|
| 99 |
|
| 100 |
|
|
|
|
| 16 |
|
| 17 |
logging.basicConfig(level=logging.DEBUG)
|
| 18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
EXAMPLES = [
|
| 21 |
("We're a nature-loving family with three kids, have some money left, and no plans "
|
|
|
|
| 85 |
)
|
| 86 |
|
| 87 |
# config
|
| 88 |
+
with open("config.yaml") as stream:
|
| 89 |
+
try:
|
| 90 |
+
demo_config = yaml.safe_load(stream)
|
| 91 |
+
logging.debug(f"Config: {demo_config}")
|
| 92 |
+
except yaml.YAMLError as exc:
|
| 93 |
+
logging.error(f"Error loading config: {exc}")
|
| 94 |
+
gr.Error("Error loading config: {exc}")
|
| 95 |
+
|
| 96 |
+
try:
|
| 97 |
+
client_kwargs, guide_kwargs = process_config(demo_config)
|
| 98 |
+
except Exception as exc:
|
| 99 |
+
logging.error(f"Error processing config: {exc}")
|
| 100 |
+
gr.Error(f"Error processing config: {exc}")
|
| 101 |
+
|
| 102 |
logging.info(f"Reasoning guide expert model is {guide_kwargs['expert_model']}.")
|
| 103 |
|
| 104 |
|
backend/config.py
CHANGED
|
@@ -1,26 +1,5 @@
|
|
| 1 |
import os
|
| 2 |
|
| 3 |
-
# Default client
|
| 4 |
-
INFERENCE_SERVER_URL = "https://api-inference.huggingface.co/models/{model_id}"
|
| 5 |
-
MODEL_ID = "HuggingFaceH4/zephyr-7b-beta"
|
| 6 |
-
CLIENT_MODEL_KWARGS = {
|
| 7 |
-
"max_tokens": 800,
|
| 8 |
-
"temperature": 0.6,
|
| 9 |
-
}
|
| 10 |
-
|
| 11 |
-
GUIDE_KWARGS = {
|
| 12 |
-
"expert_model": "HuggingFaceH4/zephyr-7b-beta",
|
| 13 |
-
# "meta-llama/Meta-Llama-3.1-70B-Instruct",
|
| 14 |
-
"inference_server_url": "https://api-inference.huggingface.co/models/HuggingFaceH4/zephyr-7b-beta",
|
| 15 |
-
# "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3.1-70B-Instruct",
|
| 16 |
-
"llm_backend": "HFChat",
|
| 17 |
-
"classifier_kwargs": {
|
| 18 |
-
"model_id": "MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli",
|
| 19 |
-
"inference_server_url": "https://api-inference.huggingface.co/models/MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli",
|
| 20 |
-
"batch_size": 8,
|
| 21 |
-
},
|
| 22 |
-
}
|
| 23 |
-
|
| 24 |
|
| 25 |
def process_config(config):
|
| 26 |
if "HF_TOKEN" not in os.environ:
|
|
@@ -37,8 +16,8 @@ def process_config(config):
|
|
| 37 |
raise ValueError("config.yaml is missing client url.")
|
| 38 |
client_kwargs["api_key"] = os.getenv("HF_TOKEN")
|
| 39 |
client_kwargs["llm_backend"] = "HFChat"
|
| 40 |
-
client_kwargs["temperature"] =
|
| 41 |
-
client_kwargs["max_tokens"] =
|
| 42 |
else:
|
| 43 |
raise ValueError("config.yaml is missing client_llm settings.")
|
| 44 |
|
|
@@ -67,7 +46,7 @@ def process_config(config):
|
|
| 67 |
else:
|
| 68 |
raise ValueError("config.yaml is missing classifier url.")
|
| 69 |
if "batch_size" in config["classifier_llm"]:
|
| 70 |
-
guide_kwargs["classifier_kwargs"]["batch_size"] = config["classifier_llm"]["batch_size"]
|
| 71 |
else:
|
| 72 |
raise ValueError("config.yaml is missing classifier batch_size.")
|
| 73 |
guide_kwargs["classifier_kwargs"]["api_key"] = os.getenv("HF_TOKEN") # classifier api key
|
|
|
|
| 1 |
import os
|
| 2 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
|
| 4 |
def process_config(config):
|
| 5 |
if "HF_TOKEN" not in os.environ:
|
|
|
|
| 16 |
raise ValueError("config.yaml is missing client url.")
|
| 17 |
client_kwargs["api_key"] = os.getenv("HF_TOKEN")
|
| 18 |
client_kwargs["llm_backend"] = "HFChat"
|
| 19 |
+
client_kwargs["temperature"] = config["client_llm"].get("temperature",.6)
|
| 20 |
+
client_kwargs["max_tokens"] = config["client_llm"].get("max_tokens",800)
|
| 21 |
else:
|
| 22 |
raise ValueError("config.yaml is missing client_llm settings.")
|
| 23 |
|
|
|
|
| 46 |
else:
|
| 47 |
raise ValueError("config.yaml is missing classifier url.")
|
| 48 |
if "batch_size" in config["classifier_llm"]:
|
| 49 |
+
guide_kwargs["classifier_kwargs"]["batch_size"] = int(config["classifier_llm"]["batch_size"])
|
| 50 |
else:
|
| 51 |
raise ValueError("config.yaml is missing classifier batch_size.")
|
| 52 |
guide_kwargs["classifier_kwargs"]["api_key"] = os.getenv("HF_TOKEN") # classifier api key
|
config.yaml
CHANGED
|
@@ -1,10 +1,12 @@
|
|
| 1 |
client_llm:
|
| 2 |
url: "https://api-inference.huggingface.co/models/HuggingFaceH4/zephyr-7b-beta"
|
| 3 |
model_id: "HuggingFaceH4/zephyr-7b-beta"
|
|
|
|
|
|
|
| 4 |
expert_llm:
|
| 5 |
url: "https://api-inference.huggingface.co/models/HuggingFaceH4/zephyr-7b-beta"
|
| 6 |
model_id: "HuggingFaceH4/zephyr-7b-beta"
|
| 7 |
classifier_llm:
|
| 8 |
model_id: "MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli"
|
| 9 |
url: "https://api-inference.huggingface.co/models/MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli"
|
| 10 |
-
batch_size: 8
|
|
|
|
| 1 |
client_llm:
|
| 2 |
url: "https://api-inference.huggingface.co/models/HuggingFaceH4/zephyr-7b-beta"
|
| 3 |
model_id: "HuggingFaceH4/zephyr-7b-beta"
|
| 4 |
+
max_tokens: 800
|
| 5 |
+
temperature: 0.6
|
| 6 |
expert_llm:
|
| 7 |
url: "https://api-inference.huggingface.co/models/HuggingFaceH4/zephyr-7b-beta"
|
| 8 |
model_id: "HuggingFaceH4/zephyr-7b-beta"
|
| 9 |
classifier_llm:
|
| 10 |
model_id: "MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli"
|
| 11 |
url: "https://api-inference.huggingface.co/models/MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli"
|
| 12 |
+
batch_size: 8
|