vikash-nuvai committed on
Commit ·
bbe8627
1
Parent(s): 5d20aef
fix: restore top-level openai import and client init per sample format
Browse files- inference.py +9 -14
inference.py
CHANGED
|
@@ -23,6 +23,7 @@ import time
|
|
| 23 |
import traceback
|
| 24 |
|
| 25 |
import requests
|
|
|
|
| 26 |
|
| 27 |
# ---------------------------------------------------------------------------
|
| 28 |
# Required environment variables
|
|
@@ -32,6 +33,11 @@ MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-4o")
|
|
| 32 |
HF_TOKEN = os.environ.get("HF_TOKEN", "")
|
| 33 |
ENV_URL = os.environ.get("ENV_URL", "http://localhost:7860")
|
| 34 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
# ---------------------------------------------------------------------------
|
| 36 |
# System prompt
|
| 37 |
# ---------------------------------------------------------------------------
|
|
@@ -64,14 +70,6 @@ STRATEGY:
|
|
| 64 |
Respond with ONLY valid JSON. No explanation, no markdown, no extra text."""
|
| 65 |
|
| 66 |
|
| 67 |
-
def get_client():
|
| 68 |
-
"""Lazily create an OpenAI client. Returns None if openai is unavailable."""
|
| 69 |
-
try:
|
| 70 |
-
from openai import OpenAI
|
| 71 |
-
return OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN or "dummy")
|
| 72 |
-
except Exception as e:
|
| 73 |
-
print(f"WARNING: Could not create OpenAI client: {e}", flush=True)
|
| 74 |
-
return None
|
| 75 |
|
| 76 |
|
| 77 |
def parse_action(text: str) -> dict:
|
|
@@ -110,7 +108,7 @@ def parse_action(text: str) -> dict:
|
|
| 110 |
return {"command": "observe"}
|
| 111 |
|
| 112 |
|
| 113 |
-
def run_episode(task_id: str, client) -> dict:
|
| 114 |
"""Run one episode of the tiffin packing task."""
|
| 115 |
# Emit [START] structured output for the validator
|
| 116 |
print(f"[START] task={task_id}", flush=True)
|
|
@@ -162,8 +160,6 @@ def run_episode(task_id: str, client) -> dict:
|
|
| 162 |
|
| 163 |
# Get LLM decision
|
| 164 |
try:
|
| 165 |
-
if client is None:
|
| 166 |
-
raise RuntimeError("No OpenAI client available")
|
| 167 |
response = client.chat.completions.create(
|
| 168 |
model=MODEL_NAME,
|
| 169 |
messages=messages,
|
|
@@ -276,14 +272,13 @@ def main():
|
|
| 276 |
print(f" Env: {ENV_URL}", flush=True)
|
| 277 |
print("=" * 60, flush=True)
|
| 278 |
|
| 279 |
-
|
| 280 |
-
client = get_client()
|
| 281 |
|
| 282 |
start_time = time.time()
|
| 283 |
results = {}
|
| 284 |
|
| 285 |
for task_id in ["easy", "medium", "hard"]:
|
| 286 |
-
result = run_episode(task_id, client)
|
| 287 |
results[task_id] = result
|
| 288 |
|
| 289 |
elapsed = time.time() - start_time
|
|
|
|
| 23 |
import traceback
|
| 24 |
|
| 25 |
import requests
|
| 26 |
+
from openai import OpenAI
|
| 27 |
|
| 28 |
# ---------------------------------------------------------------------------
|
| 29 |
# Required environment variables
|
|
|
|
| 33 |
HF_TOKEN = os.environ.get("HF_TOKEN", "")
|
| 34 |
ENV_URL = os.environ.get("ENV_URL", "http://localhost:7860")
|
| 35 |
|
| 36 |
+
if not HF_TOKEN:
|
| 37 |
+
print("WARNING: HF_TOKEN not set. LLM calls will fail.", flush=True)
|
| 38 |
+
|
| 39 |
+
client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
|
| 40 |
+
|
| 41 |
# ---------------------------------------------------------------------------
|
| 42 |
# System prompt
|
| 43 |
# ---------------------------------------------------------------------------
|
|
|
|
| 70 |
Respond with ONLY valid JSON. No explanation, no markdown, no extra text."""
|
| 71 |
|
| 72 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
|
| 74 |
|
| 75 |
def parse_action(text: str) -> dict:
|
|
|
|
| 108 |
return {"command": "observe"}
|
| 109 |
|
| 110 |
|
| 111 |
+
def run_episode(task_id: str) -> dict:
|
| 112 |
"""Run one episode of the tiffin packing task."""
|
| 113 |
# Emit [START] structured output for the validator
|
| 114 |
print(f"[START] task={task_id}", flush=True)
|
|
|
|
| 160 |
|
| 161 |
# Get LLM decision
|
| 162 |
try:
|
|
|
|
|
|
|
| 163 |
response = client.chat.completions.create(
|
| 164 |
model=MODEL_NAME,
|
| 165 |
messages=messages,
|
|
|
|
| 272 |
print(f" Env: {ENV_URL}", flush=True)
|
| 273 |
print("=" * 60, flush=True)
|
| 274 |
|
| 275 |
+
|
|
|
|
| 276 |
|
| 277 |
start_time = time.time()
|
| 278 |
results = {}
|
| 279 |
|
| 280 |
for task_id in ["easy", "medium", "hard"]:
|
| 281 |
+
result = run_episode(task_id)
|
| 282 |
results[task_id] = result
|
| 283 |
|
| 284 |
elapsed = time.time() - start_time
|