UjjwalPardeshi commited on
Commit
dbae750
·
1 Parent(s): fc246c9

fix: prioritize API_KEY over HF_TOKEN for LLM proxy, fail loudly on first call

Browse files
Files changed (1) hide show
  1. inference.py +12 -2
inference.py CHANGED
@@ -23,7 +23,7 @@ from openenv.core import GenericAction, GenericEnvClient
23
  # Configuration — matches sample inference script exactly
24
  # ---------------------------------------------------------------------------
25
  IMAGE_NAME = os.getenv("IMAGE_NAME") or os.getenv("LOCAL_IMAGE_NAME")
26
- API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
27
 
28
  API_BASE_URL = os.getenv("API_BASE_URL") or "https://api.openai.com/v1"
29
  MODEL_NAME = os.getenv("MODEL_NAME") or "gpt-4o"
@@ -138,6 +138,7 @@ def get_model_message(
138
  last_obs_summary: dict,
139
  last_reward: float,
140
  history: List[str],
 
141
  ) -> str:
142
  """Get next action from the LLM."""
143
  history_ctx = "\n".join(history[-5:]) if history else "No previous steps."
@@ -162,6 +163,9 @@ def get_model_message(
162
  return text if text else '{"action_type": "inspect_gradients"}'
163
  except Exception as exc:
164
  print(f"[DEBUG] Model request failed: {exc}", flush=True)
 
 
 
165
  return '{"action_type": "inspect_gradients"}'
166
 
167
 
@@ -186,6 +190,9 @@ async def main() -> None:
186
  env = None
187
 
188
  log_start(task=TASK_NAME, env=BENCHMARK, model=MODEL_NAME)
 
 
 
189
 
190
  try:
191
  client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
@@ -209,7 +216,10 @@ async def main() -> None:
209
  break
210
 
211
  obs_summary = _build_obs_summary(obs)
212
- raw = get_model_message(client, step, obs_summary, last_reward, history)
 
 
 
213
  action_str = parse_action(raw)
214
 
215
  action = GenericAction(json.loads(action_str))
 
23
  # Configuration — matches sample inference script exactly
24
  # ---------------------------------------------------------------------------
25
  IMAGE_NAME = os.getenv("IMAGE_NAME") or os.getenv("LOCAL_IMAGE_NAME")
26
+ API_KEY = os.getenv("API_KEY") or os.getenv("HF_TOKEN")
27
 
28
  API_BASE_URL = os.getenv("API_BASE_URL") or "https://api.openai.com/v1"
29
  MODEL_NAME = os.getenv("MODEL_NAME") or "gpt-4o"
 
138
  last_obs_summary: dict,
139
  last_reward: float,
140
  history: List[str],
141
+ is_first_call: bool = False,
142
  ) -> str:
143
  """Get next action from the LLM."""
144
  history_ctx = "\n".join(history[-5:]) if history else "No previous steps."
 
163
  return text if text else '{"action_type": "inspect_gradients"}'
164
  except Exception as exc:
165
  print(f"[DEBUG] Model request failed: {exc}", flush=True)
166
+ # On first call, re-raise so we know the proxy isn't working
167
+ if is_first_call:
168
+ raise
169
  return '{"action_type": "inspect_gradients"}'
170
 
171
 
 
190
  env = None
191
 
192
  log_start(task=TASK_NAME, env=BENCHMARK, model=MODEL_NAME)
193
+ print(f"[DEBUG] API_BASE_URL={API_BASE_URL}", flush=True)
194
+ print(f"[DEBUG] API_KEY={'set' if API_KEY else 'NOT SET'} (source={'API_KEY' if os.getenv('API_KEY') else 'HF_TOKEN' if os.getenv('HF_TOKEN') else 'NONE'})", flush=True)
195
+ print(f"[DEBUG] IMAGE_NAME={IMAGE_NAME}", flush=True)
196
 
197
  try:
198
  client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
 
216
  break
217
 
218
  obs_summary = _build_obs_summary(obs)
219
+ raw = get_model_message(
220
+ client, step, obs_summary, last_reward, history,
221
+ is_first_call=(step == 1),
222
+ )
223
  action_str = parse_action(raw)
224
 
225
  action = GenericAction(json.loads(action_str))