CystronCode commited on
Commit
8a168fe
Β·
verified Β·
1 Parent(s): 71ebbfc

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +17 -9
inference.py CHANGED
@@ -8,8 +8,8 @@ Usage
8
  -----
9
  python inference.py
10
 
11
- # With LLM:
12
- OPENAI_API_KEY=sk-... python inference.py
13
 
14
  # Against a different server:
15
  ENV_BASE_URL=https://... python inference.py
@@ -21,9 +21,13 @@ import sys
21
  import urllib.request
22
  from typing import Any, Dict
23
 
24
- OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
25
- ENV_BASE_URL = os.getenv("ENV_BASE_URL", "https://cystroncode-api-gateway-defender.hf.space")
26
- LLM_MODEL = os.getenv("LLM_MODEL", "gpt-4o-mini")
 
 
 
 
27
 
28
  TASK_IDS = ["easy", "medium", "hard"]
29
 
@@ -83,6 +87,7 @@ def _heuristic_action(task_id: str, obs: Dict[str, Any]) -> Dict[str, Any]:
83
  # ─── LLM agent ────────────────────────────────────────────────────────────────
84
 
85
  def _llm_action(task_id: str, obs: Dict[str, Any]) -> Dict[str, Any]:
 
86
  inner_obs = obs.get("observation", obs)
87
  sample = inner_obs.get("recent_requests", [])[:25]
88
  payload = json.dumps({
@@ -100,11 +105,13 @@ def _llm_action(task_id: str, obs: Dict[str, Any]) -> Dict[str, Any]:
100
  "max_tokens": 256,
101
  "temperature": 0.1,
102
  }).encode()
 
 
103
  req = urllib.request.Request(
104
- "https://api.openai.com/v1/chat/completions",
105
  data=payload,
106
  headers={"Content-Type": "application/json",
107
- "Authorization": f"Bearer {OPENAI_API_KEY}"},
108
  )
109
  with urllib.request.urlopen(req, timeout=30) as resp:
110
  raw = json.loads(resp.read())["choices"][0]["message"]["content"].strip()
@@ -125,7 +132,8 @@ def run_task(task_id: str) -> Dict[str, Any]:
125
 
126
  for step_num in range(1, 6):
127
  try:
128
- action = _llm_action(task_id, obs) if OPENAI_API_KEY else _heuristic_action(task_id, obs)
 
129
  except Exception:
130
  action = _heuristic_action(task_id, obs)
131
 
@@ -161,4 +169,4 @@ def main():
161
 
162
 
163
  if __name__ == "__main__":
164
- main()
 
8
  -----
9
  python inference.py
10
 
11
+ # With LLM proxy (injected by validator):
12
+ API_BASE_URL=https://... API_KEY=... python inference.py
13
 
14
  # Against a different server:
15
  ENV_BASE_URL=https://... python inference.py
 
21
  import urllib.request
22
  from typing import Any, Dict
23
 
24
+ # Use the LiteLLM proxy credentials injected by the validator.
25
+ # API_BASE_URL must end WITHOUT a trailing slash for /chat/completions appending.
26
+ API_KEY = os.getenv("API_KEY", os.getenv("OPENAI_API_KEY", ""))
27
+ _raw_base = os.getenv("API_BASE_URL", "").rstrip("/")
28
+ LLM_BASE_URL = _raw_base if _raw_base else "https://api.openai.com/v1"
29
+ ENV_BASE_URL = os.getenv("ENV_BASE_URL", "https://cystroncode-api-gateway-defender.hf.space")
30
+ LLM_MODEL = os.getenv("LLM_MODEL", "gpt-4o-mini")
31
 
32
  TASK_IDS = ["easy", "medium", "hard"]
33
 
 
87
  # ─── LLM agent ────────────────────────────────────────────────────────────────
88
 
89
  def _llm_action(task_id: str, obs: Dict[str, Any]) -> Dict[str, Any]:
90
+ """Call the LiteLLM proxy supplied by the validator via API_BASE_URL / API_KEY."""
91
  inner_obs = obs.get("observation", obs)
92
  sample = inner_obs.get("recent_requests", [])[:25]
93
  payload = json.dumps({
 
105
  "max_tokens": 256,
106
  "temperature": 0.1,
107
  }).encode()
108
+ # Always route through the validator-injected LiteLLM proxy endpoint
109
+ llm_url = f"{LLM_BASE_URL}/chat/completions"
110
  req = urllib.request.Request(
111
+ llm_url,
112
  data=payload,
113
  headers={"Content-Type": "application/json",
114
+ "Authorization": f"Bearer {API_KEY}"},
115
  )
116
  with urllib.request.urlopen(req, timeout=30) as resp:
117
  raw = json.loads(resp.read())["choices"][0]["message"]["content"].strip()
 
132
 
133
  for step_num in range(1, 6):
134
  try:
135
+ # Use LLM if a key is available (prefers validator-injected API_KEY)
136
+ action = _llm_action(task_id, obs) if API_KEY else _heuristic_action(task_id, obs)
137
  except Exception:
138
  action = _heuristic_action(task_id, obs)
139
 
 
169
 
170
 
171
  if __name__ == "__main__":
172
+ main()