UjjwalPardeshi commited on
Commit ·
dbae750
1
Parent(s): fc246c9
fix: prioritize API_KEY over HF_TOKEN for LLM proxy, fail loudly on first call
Browse files- inference.py +12 -2
inference.py
CHANGED
|
@@ -23,7 +23,7 @@ from openenv.core import GenericAction, GenericEnvClient
|
|
| 23 |
# Configuration — matches sample inference script exactly
|
| 24 |
# ---------------------------------------------------------------------------
|
| 25 |
IMAGE_NAME = os.getenv("IMAGE_NAME") or os.getenv("LOCAL_IMAGE_NAME")
|
| 26 |
-
API_KEY = os.getenv("
|
| 27 |
|
| 28 |
API_BASE_URL = os.getenv("API_BASE_URL") or "https://api.openai.com/v1"
|
| 29 |
MODEL_NAME = os.getenv("MODEL_NAME") or "gpt-4o"
|
|
@@ -138,6 +138,7 @@ def get_model_message(
|
|
| 138 |
last_obs_summary: dict,
|
| 139 |
last_reward: float,
|
| 140 |
history: List[str],
|
|
|
|
| 141 |
) -> str:
|
| 142 |
"""Get next action from the LLM."""
|
| 143 |
history_ctx = "\n".join(history[-5:]) if history else "No previous steps."
|
|
@@ -162,6 +163,9 @@ def get_model_message(
|
|
| 162 |
return text if text else '{"action_type": "inspect_gradients"}'
|
| 163 |
except Exception as exc:
|
| 164 |
print(f"[DEBUG] Model request failed: {exc}", flush=True)
|
|
|
|
|
|
|
|
|
|
| 165 |
return '{"action_type": "inspect_gradients"}'
|
| 166 |
|
| 167 |
|
|
@@ -186,6 +190,9 @@ async def main() -> None:
|
|
| 186 |
env = None
|
| 187 |
|
| 188 |
log_start(task=TASK_NAME, env=BENCHMARK, model=MODEL_NAME)
|
|
|
|
|
|
|
|
|
|
| 189 |
|
| 190 |
try:
|
| 191 |
client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
|
|
@@ -209,7 +216,10 @@ async def main() -> None:
|
|
| 209 |
break
|
| 210 |
|
| 211 |
obs_summary = _build_obs_summary(obs)
|
| 212 |
-
raw = get_model_message(
|
|
|
|
|
|
|
|
|
|
| 213 |
action_str = parse_action(raw)
|
| 214 |
|
| 215 |
action = GenericAction(json.loads(action_str))
|
|
|
|
| 23 |
# Configuration — matches sample inference script exactly
|
| 24 |
# ---------------------------------------------------------------------------
|
| 25 |
IMAGE_NAME = os.getenv("IMAGE_NAME") or os.getenv("LOCAL_IMAGE_NAME")
|
| 26 |
+
API_KEY = os.getenv("API_KEY") or os.getenv("HF_TOKEN")
|
| 27 |
|
| 28 |
API_BASE_URL = os.getenv("API_BASE_URL") or "https://api.openai.com/v1"
|
| 29 |
MODEL_NAME = os.getenv("MODEL_NAME") or "gpt-4o"
|
|
|
|
| 138 |
last_obs_summary: dict,
|
| 139 |
last_reward: float,
|
| 140 |
history: List[str],
|
| 141 |
+
is_first_call: bool = False,
|
| 142 |
) -> str:
|
| 143 |
"""Get next action from the LLM."""
|
| 144 |
history_ctx = "\n".join(history[-5:]) if history else "No previous steps."
|
|
|
|
| 163 |
return text if text else '{"action_type": "inspect_gradients"}'
|
| 164 |
except Exception as exc:
|
| 165 |
print(f"[DEBUG] Model request failed: {exc}", flush=True)
|
| 166 |
+
# On first call, re-raise so we know the proxy isn't working
|
| 167 |
+
if is_first_call:
|
| 168 |
+
raise
|
| 169 |
return '{"action_type": "inspect_gradients"}'
|
| 170 |
|
| 171 |
|
|
|
|
| 190 |
env = None
|
| 191 |
|
| 192 |
log_start(task=TASK_NAME, env=BENCHMARK, model=MODEL_NAME)
|
| 193 |
+
print(f"[DEBUG] API_BASE_URL={API_BASE_URL}", flush=True)
|
| 194 |
+
print(f"[DEBUG] API_KEY={'set' if API_KEY else 'NOT SET'} (source={'API_KEY' if os.getenv('API_KEY') else 'HF_TOKEN' if os.getenv('HF_TOKEN') else 'NONE'})", flush=True)
|
| 195 |
+
print(f"[DEBUG] IMAGE_NAME={IMAGE_NAME}", flush=True)
|
| 196 |
|
| 197 |
try:
|
| 198 |
client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
|
|
|
|
| 216 |
break
|
| 217 |
|
| 218 |
obs_summary = _build_obs_summary(obs)
|
| 219 |
+
raw = get_model_message(
|
| 220 |
+
client, step, obs_summary, last_reward, history,
|
| 221 |
+
is_first_call=(step == 1),
|
| 222 |
+
)
|
| 223 |
action_str = parse_action(raw)
|
| 224 |
|
| 225 |
action = GenericAction(json.loads(action_str))
|