Spaces:
Sleeping
Sleeping
Update inference.py
Browse files- inference.py +17 -9
inference.py
CHANGED
|
@@ -8,8 +8,8 @@ Usage
|
|
| 8 |
-----
|
| 9 |
python inference.py
|
| 10 |
|
| 11 |
-
# With LLM:
|
| 12 |
-
|
| 13 |
|
| 14 |
# Against a different server:
|
| 15 |
ENV_BASE_URL=https://... python inference.py
|
|
@@ -21,9 +21,13 @@ import sys
|
|
| 21 |
import urllib.request
|
| 22 |
from typing import Any, Dict
|
| 23 |
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
TASK_IDS = ["easy", "medium", "hard"]
|
| 29 |
|
|
@@ -83,6 +87,7 @@ def _heuristic_action(task_id: str, obs: Dict[str, Any]) -> Dict[str, Any]:
|
|
| 83 |
# ─── LLM agent ────────────────────────────────────────────────────────────────
|
| 84 |
|
| 85 |
def _llm_action(task_id: str, obs: Dict[str, Any]) -> Dict[str, Any]:
|
|
|
|
| 86 |
inner_obs = obs.get("observation", obs)
|
| 87 |
sample = inner_obs.get("recent_requests", [])[:25]
|
| 88 |
payload = json.dumps({
|
|
@@ -100,11 +105,13 @@ def _llm_action(task_id: str, obs: Dict[str, Any]) -> Dict[str, Any]:
|
|
| 100 |
"max_tokens": 256,
|
| 101 |
"temperature": 0.1,
|
| 102 |
}).encode()
|
|
|
|
|
|
|
| 103 |
req = urllib.request.Request(
|
| 104 |
-
|
| 105 |
data=payload,
|
| 106 |
headers={"Content-Type": "application/json",
|
| 107 |
-
"Authorization": f"Bearer {
|
| 108 |
)
|
| 109 |
with urllib.request.urlopen(req, timeout=30) as resp:
|
| 110 |
raw = json.loads(resp.read())["choices"][0]["message"]["content"].strip()
|
|
@@ -125,7 +132,8 @@ def run_task(task_id: str) -> Dict[str, Any]:
|
|
| 125 |
|
| 126 |
for step_num in range(1, 6):
|
| 127 |
try:
|
| 128 |
-
|
|
|
|
| 129 |
except Exception:
|
| 130 |
action = _heuristic_action(task_id, obs)
|
| 131 |
|
|
@@ -161,4 +169,4 @@ def main():
|
|
| 161 |
|
| 162 |
|
| 163 |
if __name__ == "__main__":
|
| 164 |
-
main()
|
|
|
|
| 8 |
-----
|
| 9 |
python inference.py
|
| 10 |
|
| 11 |
+
# With LLM proxy (injected by validator):
|
| 12 |
+
API_BASE_URL=https://... API_KEY=... python inference.py
|
| 13 |
|
| 14 |
# Against a different server:
|
| 15 |
ENV_BASE_URL=https://... python inference.py
|
|
|
|
| 21 |
import urllib.request
|
| 22 |
from typing import Any, Dict
|
| 23 |
|
| 24 |
+
# Use the LiteLLM proxy credentials injected by the validator.
|
| 25 |
+
# API_BASE_URL must end WITHOUT a trailing slash for /chat/completions appending.
|
| 26 |
+
API_KEY = os.getenv("API_KEY", os.getenv("OPENAI_API_KEY", ""))
|
| 27 |
+
_raw_base = os.getenv("API_BASE_URL", "").rstrip("/")
|
| 28 |
+
LLM_BASE_URL = _raw_base if _raw_base else "https://api.openai.com/v1"
|
| 29 |
+
ENV_BASE_URL = os.getenv("ENV_BASE_URL", "https://cystroncode-api-gateway-defender.hf.space")
|
| 30 |
+
LLM_MODEL = os.getenv("LLM_MODEL", "gpt-4o-mini")
|
| 31 |
|
| 32 |
TASK_IDS = ["easy", "medium", "hard"]
|
| 33 |
|
|
|
|
| 87 |
# ─── LLM agent ────────────────────────────────────────────────────────────────
|
| 88 |
|
| 89 |
def _llm_action(task_id: str, obs: Dict[str, Any]) -> Dict[str, Any]:
|
| 90 |
+
"""Call the LiteLLM proxy supplied by the validator via API_BASE_URL / API_KEY."""
|
| 91 |
inner_obs = obs.get("observation", obs)
|
| 92 |
sample = inner_obs.get("recent_requests", [])[:25]
|
| 93 |
payload = json.dumps({
|
|
|
|
| 105 |
"max_tokens": 256,
|
| 106 |
"temperature": 0.1,
|
| 107 |
}).encode()
|
| 108 |
+
# Always route through the validator-injected LiteLLM proxy endpoint
|
| 109 |
+
llm_url = f"{LLM_BASE_URL}/chat/completions"
|
| 110 |
req = urllib.request.Request(
|
| 111 |
+
llm_url,
|
| 112 |
data=payload,
|
| 113 |
headers={"Content-Type": "application/json",
|
| 114 |
+
"Authorization": f"Bearer {API_KEY}"},
|
| 115 |
)
|
| 116 |
with urllib.request.urlopen(req, timeout=30) as resp:
|
| 117 |
raw = json.loads(resp.read())["choices"][0]["message"]["content"].strip()
|
|
|
|
| 132 |
|
| 133 |
for step_num in range(1, 6):
|
| 134 |
try:
|
| 135 |
+
# Use LLM if a key is available (prefers validator-injected API_KEY)
|
| 136 |
+
action = _llm_action(task_id, obs) if API_KEY else _heuristic_action(task_id, obs)
|
| 137 |
except Exception:
|
| 138 |
action = _heuristic_action(task_id, obs)
|
| 139 |
|
|
|
|
| 169 |
|
| 170 |
|
| 171 |
if __name__ == "__main__":
|
| 172 |
+
main()
|