Upload inference.py with huggingface_hub
Browse files- inference.py +6 -9
inference.py
CHANGED
|
@@ -23,9 +23,10 @@ from meta_ads_env.models import Action
|
|
| 23 |
from meta_ads_env.tasks import TASK_REGISTRY
|
| 24 |
|
| 25 |
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
|
|
|
| 29 |
REQUIRED_MODEL_NAME = "Qwen/Qwen2.5-72B-Instruct"
|
| 30 |
BENCHMARK = "meta_ads_attribution_openenv"
|
| 31 |
MAX_TOKENS = 300
|
|
@@ -452,8 +453,8 @@ def run_task(client: OpenAI, task_id: str) -> int:
|
|
| 452 |
def main() -> int:
|
| 453 |
global API_BASE_URL, MODEL_NAME, API_KEY
|
| 454 |
_load_env_file(Path(__file__).resolve().with_name(".env"))
|
| 455 |
-
API_BASE_URL = os.getenv("API_BASE_URL")
|
| 456 |
-
MODEL_NAME = os.getenv("MODEL_NAME")
|
| 457 |
API_KEY = os.getenv("HF_TOKEN")
|
| 458 |
|
| 459 |
missing = []
|
|
@@ -465,10 +466,6 @@ def main() -> int:
|
|
| 465 |
missing.append("HF_TOKEN")
|
| 466 |
if missing:
|
| 467 |
raise EnvironmentError(f"Missing required environment variables: {', '.join(missing)}")
|
| 468 |
-
if MODEL_NAME != REQUIRED_MODEL_NAME:
|
| 469 |
-
raise EnvironmentError(
|
| 470 |
-
f"MODEL_NAME must be '{REQUIRED_MODEL_NAME}' for this codebase. Got: '{MODEL_NAME}'"
|
| 471 |
-
)
|
| 472 |
|
| 473 |
client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
|
| 474 |
|
|
|
|
| 23 |
from meta_ads_env.tasks import TASK_REGISTRY
|
| 24 |
|
| 25 |
|
| 26 |
+
DEFAULT_API_BASE_URL = "https://router.huggingface.co/v1"
|
| 27 |
+
API_BASE_URL = os.getenv("API_BASE_URL", DEFAULT_API_BASE_URL)
|
| 28 |
+
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
|
| 29 |
+
API_KEY = os.getenv("HF_TOKEN","")
|
| 30 |
REQUIRED_MODEL_NAME = "Qwen/Qwen2.5-72B-Instruct"
|
| 31 |
BENCHMARK = "meta_ads_attribution_openenv"
|
| 32 |
MAX_TOKENS = 300
|
|
|
|
| 453 |
def main() -> int:
|
| 454 |
global API_BASE_URL, MODEL_NAME, API_KEY
|
| 455 |
_load_env_file(Path(__file__).resolve().with_name(".env"))
|
| 456 |
+
API_BASE_URL = os.getenv("API_BASE_URL", DEFAULT_API_BASE_URL)
|
| 457 |
+
MODEL_NAME = os.getenv("MODEL_NAME", REQUIRED_MODEL_NAME)
|
| 458 |
API_KEY = os.getenv("HF_TOKEN")
|
| 459 |
|
| 460 |
missing = []
|
|
|
|
| 466 |
missing.append("HF_TOKEN")
|
| 467 |
if missing:
|
| 468 |
raise EnvironmentError(f"Missing required environment variables: {', '.join(missing)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 469 |
|
| 470 |
client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
|
| 471 |
|