Spaces:
Sleeping
Sleeping
Upload folder using huggingface_hub
Browse files- inference.py +23 -21
inference.py
CHANGED
|
@@ -7,8 +7,7 @@ endpoint and keeps request pacing under a hard RPM ceiling.
|
|
| 7 |
Environment Variables:
|
| 8 |
API_BASE_URL - LLM API endpoint
|
| 9 |
MODEL_NAME - model identifier
|
| 10 |
-
|
| 11 |
-
GEMINI_API_KEY - Gemini API key (preferred for Gemini endpoint)
|
| 12 |
ENV_URL - environment server URL (default: http://localhost:8000)
|
| 13 |
LLM_RPM_LIMIT - max model requests per minute (default: 5)
|
| 14 |
LLM_MAX_RETRIES - max rate-limit retries per request (default: 3)
|
|
@@ -42,12 +41,14 @@ API_BASE_URL = os.environ.get(
|
|
| 42 |
"https://api.openai.com/v1",
|
| 43 |
)
|
| 44 |
MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-4o-mini")
|
| 45 |
-
|
|
|
|
| 46 |
os.environ.get("HF_TOKEN")
|
| 47 |
or os.environ.get("OPENAI_API_KEY")
|
| 48 |
or os.environ.get("OPENROUTER_API_KEY")
|
| 49 |
or os.environ.get("GEMINI_API_KEY")
|
| 50 |
)
|
|
|
|
| 51 |
ENV_URL = os.environ.get("ENV_URL", "http://localhost:8000")
|
| 52 |
|
| 53 |
LLM_RPM_LIMIT = max(1, int(os.environ.get("LLM_RPM_LIMIT", "5")))
|
|
@@ -282,7 +283,7 @@ def create_client() -> OpenAI:
|
|
| 282 |
}
|
| 283 |
|
| 284 |
return OpenAI(
|
| 285 |
-
api_key=
|
| 286 |
base_url=API_BASE_URL,
|
| 287 |
default_headers=extra_headers,
|
| 288 |
)
|
|
@@ -857,16 +858,16 @@ def run_task(client: OpenAI, scheduler: RequestScheduler, task_id: str) -> Dict[
|
|
| 857 |
|
| 858 |
def main() -> None:
|
| 859 |
"""Run the baseline across all tasks."""
|
| 860 |
-
if not
|
| 861 |
-
print("ERROR: Set OPENAI_API_KEY, OPENROUTER_API_KEY or GEMINI_API_KEY environment variable")
|
| 862 |
sys.exit(1)
|
| 863 |
|
| 864 |
-
print(f"API Base: {API_BASE_URL}")
|
| 865 |
-
print(f"Model: {MODEL_NAME}")
|
| 866 |
-
print(f"Environment: {ENV_URL}")
|
| 867 |
-
print(f"Tasks: {TASKS}")
|
| 868 |
-
print(f"RPM limit: {LLM_RPM_LIMIT} (min gap {MIN_REQUEST_GAP_SECONDS:.1f}s)")
|
| 869 |
-
print(f"Retries: {LLM_MAX_RETRIES}")
|
| 870 |
|
| 871 |
client = create_client()
|
| 872 |
scheduler = RequestScheduler(
|
|
@@ -887,20 +888,21 @@ def main() -> None:
|
|
| 887 |
total_requests = sum(result["requests"] for result in results)
|
| 888 |
total_retries = sum(result["retries"] for result in results)
|
| 889 |
|
| 890 |
-
print(f"\n{'=' * 60}")
|
| 891 |
-
print("BASELINE RESULTS")
|
| 892 |
-
print(f"{'=' * 60}")
|
| 893 |
for result in results:
|
| 894 |
print(
|
| 895 |
" "
|
| 896 |
f"{result['task_id']:20s} score={result['final_score']} "
|
| 897 |
f"decisions={result['llm_decisions']} requests={result['requests']} "
|
| 898 |
-
f"retries={result['retries']} elapsed={result['elapsed_seconds']}s"
|
|
|
|
| 899 |
)
|
| 900 |
-
print(f" Total requests: {total_requests}")
|
| 901 |
-
print(f" Total retries: {total_retries}")
|
| 902 |
-
print(f" Total elapsed: {elapsed}s ({elapsed / 60:.1f} min)")
|
| 903 |
-
print(f"{'=' * 60}")
|
| 904 |
|
| 905 |
output = {
|
| 906 |
"results": results,
|
|
@@ -911,7 +913,7 @@ def main() -> None:
|
|
| 911 |
"rpm_limit": LLM_RPM_LIMIT,
|
| 912 |
"min_request_gap_seconds": round(MIN_REQUEST_GAP_SECONDS, 1),
|
| 913 |
}
|
| 914 |
-
print(f"\n{json.dumps(output, indent=2, default=str)}")
|
| 915 |
|
| 916 |
|
| 917 |
if __name__ == "__main__":
|
|
|
|
| 7 |
Environment Variables:
|
| 8 |
API_BASE_URL - LLM API endpoint
|
| 9 |
MODEL_NAME - model identifier
|
| 10 |
+
HF_TOKEN - API key for auth
|
|
|
|
| 11 |
ENV_URL - environment server URL (default: http://localhost:8000)
|
| 12 |
LLM_RPM_LIMIT - max model requests per minute (default: 5)
|
| 13 |
LLM_MAX_RETRIES - max rate-limit retries per request (default: 3)
|
|
|
|
| 41 |
"https://api.openai.com/v1",
|
| 42 |
)
|
| 43 |
MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-4o-mini")
|
| 44 |
+
|
| 45 |
+
HF_TOKEN = (
|
| 46 |
os.environ.get("HF_TOKEN")
|
| 47 |
or os.environ.get("OPENAI_API_KEY")
|
| 48 |
or os.environ.get("OPENROUTER_API_KEY")
|
| 49 |
or os.environ.get("GEMINI_API_KEY")
|
| 50 |
)
|
| 51 |
+
|
| 52 |
ENV_URL = os.environ.get("ENV_URL", "http://localhost:8000")
|
| 53 |
|
| 54 |
LLM_RPM_LIMIT = max(1, int(os.environ.get("LLM_RPM_LIMIT", "5")))
|
|
|
|
| 283 |
}
|
| 284 |
|
| 285 |
return OpenAI(
|
| 286 |
+
api_key=HF_TOKEN,
|
| 287 |
base_url=API_BASE_URL,
|
| 288 |
default_headers=extra_headers,
|
| 289 |
)
|
|
|
|
| 858 |
|
| 859 |
def main() -> None:
|
| 860 |
"""Run the baseline across all tasks."""
|
| 861 |
+
if not HF_TOKEN:
|
| 862 |
+
print("ERROR: Set HF_TOKEN, OPENAI_API_KEY, OPENROUTER_API_KEY or GEMINI_API_KEY environment variable")
|
| 863 |
sys.exit(1)
|
| 864 |
|
| 865 |
+
print(f"API Base: {API_BASE_URL}", file=sys.stderr)
|
| 866 |
+
print(f"Model: {MODEL_NAME}", file=sys.stderr)
|
| 867 |
+
print(f"Environment: {ENV_URL}", file=sys.stderr)
|
| 868 |
+
print(f"Tasks: {TASKS}", file=sys.stderr)
|
| 869 |
+
print(f"RPM limit: {LLM_RPM_LIMIT} (min gap {MIN_REQUEST_GAP_SECONDS:.1f}s)", file=sys.stderr)
|
| 870 |
+
print(f"Retries: {LLM_MAX_RETRIES}", file=sys.stderr)
|
| 871 |
|
| 872 |
client = create_client()
|
| 873 |
scheduler = RequestScheduler(
|
|
|
|
| 888 |
total_requests = sum(result["requests"] for result in results)
|
| 889 |
total_retries = sum(result["retries"] for result in results)
|
| 890 |
|
| 891 |
+
print(f"\n{'=' * 60}", file=sys.stderr)
|
| 892 |
+
print("BASELINE RESULTS", file=sys.stderr)
|
| 893 |
+
print(f"{'=' * 60}", file=sys.stderr)
|
| 894 |
for result in results:
|
| 895 |
print(
|
| 896 |
" "
|
| 897 |
f"{result['task_id']:20s} score={result['final_score']} "
|
| 898 |
f"decisions={result['llm_decisions']} requests={result['requests']} "
|
| 899 |
+
f"retries={result['retries']} elapsed={result['elapsed_seconds']}s",
|
| 900 |
+
file=sys.stderr
|
| 901 |
)
|
| 902 |
+
print(f" Total requests: {total_requests}", file=sys.stderr)
|
| 903 |
+
print(f" Total retries: {total_retries}", file=sys.stderr)
|
| 904 |
+
print(f" Total elapsed: {elapsed}s ({elapsed / 60:.1f} min)", file=sys.stderr)
|
| 905 |
+
print(f"{'=' * 60}", file=sys.stderr)
|
| 906 |
|
| 907 |
output = {
|
| 908 |
"results": results,
|
|
|
|
| 913 |
"rpm_limit": LLM_RPM_LIMIT,
|
| 914 |
"min_request_gap_seconds": round(MIN_REQUEST_GAP_SECONDS, 1),
|
| 915 |
}
|
| 916 |
+
print(f"\n{json.dumps(output, indent=2, default=str)}", file=sys.stderr)
|
| 917 |
|
| 918 |
|
| 919 |
if __name__ == "__main__":
|