BART-ender committed on
Commit
83df01b
·
verified ·
1 Parent(s): 685c05b

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. inference.py +23 -21
inference.py CHANGED
@@ -7,8 +7,7 @@ endpoint and keeps request pacing under a hard RPM ceiling.
7
  Environment Variables:
8
  API_BASE_URL - LLM API endpoint
9
  MODEL_NAME - model identifier
10
- OPENAI_API_KEY - API key for auth
11
- GEMINI_API_KEY - Gemini API key (preferred for Gemini endpoint)
12
  ENV_URL - environment server URL (default: http://localhost:8000)
13
  LLM_RPM_LIMIT - max model requests per minute (default: 5)
14
  LLM_MAX_RETRIES - max rate-limit retries per request (default: 3)
@@ -42,12 +41,14 @@ API_BASE_URL = os.environ.get(
42
  "https://api.openai.com/v1",
43
  )
44
  MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-4o-mini")
45
- OPENAI_API_KEY = (
 
46
  os.environ.get("HF_TOKEN")
47
  or os.environ.get("OPENAI_API_KEY")
48
  or os.environ.get("OPENROUTER_API_KEY")
49
  or os.environ.get("GEMINI_API_KEY")
50
  )
 
51
  ENV_URL = os.environ.get("ENV_URL", "http://localhost:8000")
52
 
53
  LLM_RPM_LIMIT = max(1, int(os.environ.get("LLM_RPM_LIMIT", "5")))
@@ -282,7 +283,7 @@ def create_client() -> OpenAI:
282
  }
283
 
284
  return OpenAI(
285
- api_key=OPENAI_API_KEY,
286
  base_url=API_BASE_URL,
287
  default_headers=extra_headers,
288
  )
@@ -857,16 +858,16 @@ def run_task(client: OpenAI, scheduler: RequestScheduler, task_id: str) -> Dict[
857
 
858
  def main() -> None:
859
  """Run the baseline across all tasks."""
860
- if not OPENAI_API_KEY:
861
- print("ERROR: Set OPENAI_API_KEY, OPENROUTER_API_KEY or GEMINI_API_KEY environment variable")
862
  sys.exit(1)
863
 
864
- print(f"API Base: {API_BASE_URL}")
865
- print(f"Model: {MODEL_NAME}")
866
- print(f"Environment: {ENV_URL}")
867
- print(f"Tasks: {TASKS}")
868
- print(f"RPM limit: {LLM_RPM_LIMIT} (min gap {MIN_REQUEST_GAP_SECONDS:.1f}s)")
869
- print(f"Retries: {LLM_MAX_RETRIES}")
870
 
871
  client = create_client()
872
  scheduler = RequestScheduler(
@@ -887,20 +888,21 @@ def main() -> None:
887
  total_requests = sum(result["requests"] for result in results)
888
  total_retries = sum(result["retries"] for result in results)
889
 
890
- print(f"\n{'=' * 60}")
891
- print("BASELINE RESULTS")
892
- print(f"{'=' * 60}")
893
  for result in results:
894
  print(
895
  " "
896
  f"{result['task_id']:20s} score={result['final_score']} "
897
  f"decisions={result['llm_decisions']} requests={result['requests']} "
898
- f"retries={result['retries']} elapsed={result['elapsed_seconds']}s"
 
899
  )
900
- print(f" Total requests: {total_requests}")
901
- print(f" Total retries: {total_retries}")
902
- print(f" Total elapsed: {elapsed}s ({elapsed / 60:.1f} min)")
903
- print(f"{'=' * 60}")
904
 
905
  output = {
906
  "results": results,
@@ -911,7 +913,7 @@ def main() -> None:
911
  "rpm_limit": LLM_RPM_LIMIT,
912
  "min_request_gap_seconds": round(MIN_REQUEST_GAP_SECONDS, 1),
913
  }
914
- print(f"\n{json.dumps(output, indent=2, default=str)}")
915
 
916
 
917
  if __name__ == "__main__":
 
7
  Environment Variables:
8
  API_BASE_URL - LLM API endpoint
9
  MODEL_NAME - model identifier
10
+ HF_TOKEN - API key for auth
 
11
  ENV_URL - environment server URL (default: http://localhost:8000)
12
  LLM_RPM_LIMIT - max model requests per minute (default: 5)
13
  LLM_MAX_RETRIES - max rate-limit retries per request (default: 3)
 
41
  "https://api.openai.com/v1",
42
  )
43
  MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-4o-mini")
44
+
45
+ HF_TOKEN = (
46
  os.environ.get("HF_TOKEN")
47
  or os.environ.get("OPENAI_API_KEY")
48
  or os.environ.get("OPENROUTER_API_KEY")
49
  or os.environ.get("GEMINI_API_KEY")
50
  )
51
+
52
  ENV_URL = os.environ.get("ENV_URL", "http://localhost:8000")
53
 
54
  LLM_RPM_LIMIT = max(1, int(os.environ.get("LLM_RPM_LIMIT", "5")))
 
283
  }
284
 
285
  return OpenAI(
286
+ api_key=HF_TOKEN,
287
  base_url=API_BASE_URL,
288
  default_headers=extra_headers,
289
  )
 
858
 
859
  def main() -> None:
860
  """Run the baseline across all tasks."""
861
+ if not HF_TOKEN:
862
+ print("ERROR: Set HF_TOKEN, OPENAI_API_KEY, OPENROUTER_API_KEY or GEMINI_API_KEY environment variable")
863
  sys.exit(1)
864
 
865
+ print(f"API Base: {API_BASE_URL}", file=sys.stderr)
866
+ print(f"Model: {MODEL_NAME}", file=sys.stderr)
867
+ print(f"Environment: {ENV_URL}", file=sys.stderr)
868
+ print(f"Tasks: {TASKS}", file=sys.stderr)
869
+ print(f"RPM limit: {LLM_RPM_LIMIT} (min gap {MIN_REQUEST_GAP_SECONDS:.1f}s)", file=sys.stderr)
870
+ print(f"Retries: {LLM_MAX_RETRIES}", file=sys.stderr)
871
 
872
  client = create_client()
873
  scheduler = RequestScheduler(
 
888
  total_requests = sum(result["requests"] for result in results)
889
  total_retries = sum(result["retries"] for result in results)
890
 
891
+ print(f"\n{'=' * 60}", file=sys.stderr)
892
+ print("BASELINE RESULTS", file=sys.stderr)
893
+ print(f"{'=' * 60}", file=sys.stderr)
894
  for result in results:
895
  print(
896
  " "
897
  f"{result['task_id']:20s} score={result['final_score']} "
898
  f"decisions={result['llm_decisions']} requests={result['requests']} "
899
+ f"retries={result['retries']} elapsed={result['elapsed_seconds']}s",
900
+ file=sys.stderr
901
  )
902
+ print(f" Total requests: {total_requests}", file=sys.stderr)
903
+ print(f" Total retries: {total_retries}", file=sys.stderr)
904
+ print(f" Total elapsed: {elapsed}s ({elapsed / 60:.1f} min)", file=sys.stderr)
905
+ print(f"{'=' * 60}", file=sys.stderr)
906
 
907
  output = {
908
  "results": results,
 
913
  "rpm_limit": LLM_RPM_LIMIT,
914
  "min_request_gap_seconds": round(MIN_REQUEST_GAP_SECONDS, 1),
915
  }
916
+ print(f"\n{json.dumps(output, indent=2, default=str)}", file=sys.stderr)
917
 
918
 
919
  if __name__ == "__main__":