TheAarvee05 commited on
Commit
8eced79
·
verified ·
1 Parent(s): 7f499e6

Upload inference.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. inference.py +6 -9
inference.py CHANGED
@@ -23,9 +23,10 @@ from meta_ads_env.models import Action
23
  from meta_ads_env.tasks import TASK_REGISTRY
24
 
25
 
26
- API_BASE_URL = os.getenv("API_BASE_URL")
27
- MODEL_NAME = os.getenv("MODEL_NAME")
28
- API_KEY = os.getenv("HF_TOKEN")
 
29
  REQUIRED_MODEL_NAME = "Qwen/Qwen2.5-72B-Instruct"
30
  BENCHMARK = "meta_ads_attribution_openenv"
31
  MAX_TOKENS = 300
@@ -452,8 +453,8 @@ def run_task(client: OpenAI, task_id: str) -> int:
452
  def main() -> int:
453
  global API_BASE_URL, MODEL_NAME, API_KEY
454
  _load_env_file(Path(__file__).resolve().with_name(".env"))
455
- API_BASE_URL = os.getenv("API_BASE_URL")
456
- MODEL_NAME = os.getenv("MODEL_NAME")
457
  API_KEY = os.getenv("HF_TOKEN")
458
 
459
  missing = []
@@ -465,10 +466,6 @@ def main() -> int:
465
  missing.append("HF_TOKEN")
466
  if missing:
467
  raise EnvironmentError(f"Missing required environment variables: {', '.join(missing)}")
468
- if MODEL_NAME != REQUIRED_MODEL_NAME:
469
- raise EnvironmentError(
470
- f"MODEL_NAME must be '{REQUIRED_MODEL_NAME}' for this codebase. Got: '{MODEL_NAME}'"
471
- )
472
 
473
  client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
474
 
 
23
  from meta_ads_env.tasks import TASK_REGISTRY
24
 
25
 
26
+ DEFAULT_API_BASE_URL = "https://router.huggingface.co/v1"
27
+ API_BASE_URL = os.getenv("API_BASE_URL", DEFAULT_API_BASE_URL)
28
+ MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
29
+ API_KEY = os.getenv("HF_TOKEN","")
30
  REQUIRED_MODEL_NAME = "Qwen/Qwen2.5-72B-Instruct"
31
  BENCHMARK = "meta_ads_attribution_openenv"
32
  MAX_TOKENS = 300
 
453
  def main() -> int:
454
  global API_BASE_URL, MODEL_NAME, API_KEY
455
  _load_env_file(Path(__file__).resolve().with_name(".env"))
456
+ API_BASE_URL = os.getenv("API_BASE_URL", DEFAULT_API_BASE_URL)
457
+ MODEL_NAME = os.getenv("MODEL_NAME", REQUIRED_MODEL_NAME)
458
  API_KEY = os.getenv("HF_TOKEN")
459
 
460
  missing = []
 
466
  missing.append("HF_TOKEN")
467
  if missing:
468
  raise EnvironmentError(f"Missing required environment variables: {', '.join(missing)}")
 
 
 
 
469
 
470
  client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
471