vikash-nuvai commited on
Commit
bbe8627
·
1 Parent(s): 5d20aef

fix: restore top-level openai import and client init per sample format

Browse files
Files changed (1) hide show
  1. inference.py +9 -14
inference.py CHANGED
@@ -23,6 +23,7 @@ import time
23
  import traceback
24
 
25
  import requests
 
26
 
27
  # ---------------------------------------------------------------------------
28
  # Required environment variables
@@ -32,6 +33,11 @@ MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-4o")
32
  HF_TOKEN = os.environ.get("HF_TOKEN", "")
33
  ENV_URL = os.environ.get("ENV_URL", "http://localhost:7860")
34
 
 
 
 
 
 
35
  # ---------------------------------------------------------------------------
36
  # System prompt
37
  # ---------------------------------------------------------------------------
@@ -64,14 +70,6 @@ STRATEGY:
64
  Respond with ONLY valid JSON. No explanation, no markdown, no extra text."""
65
 
66
 
67
- def get_client():
68
- """Lazily create an OpenAI client. Returns None if openai is unavailable."""
69
- try:
70
- from openai import OpenAI
71
- return OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN or "dummy")
72
- except Exception as e:
73
- print(f"WARNING: Could not create OpenAI client: {e}", flush=True)
74
- return None
75
 
76
 
77
  def parse_action(text: str) -> dict:
@@ -110,7 +108,7 @@ def parse_action(text: str) -> dict:
110
  return {"command": "observe"}
111
 
112
 
113
- def run_episode(task_id: str, client) -> dict:
114
  """Run one episode of the tiffin packing task."""
115
  # Emit [START] structured output for the validator
116
  print(f"[START] task={task_id}", flush=True)
@@ -162,8 +160,6 @@ def run_episode(task_id: str, client) -> dict:
162
 
163
  # Get LLM decision
164
  try:
165
- if client is None:
166
- raise RuntimeError("No OpenAI client available")
167
  response = client.chat.completions.create(
168
  model=MODEL_NAME,
169
  messages=messages,
@@ -276,14 +272,13 @@ def main():
276
  print(f" Env: {ENV_URL}", flush=True)
277
  print("=" * 60, flush=True)
278
 
279
- # Create client lazily — don't crash on import
280
- client = get_client()
281
 
282
  start_time = time.time()
283
  results = {}
284
 
285
  for task_id in ["easy", "medium", "hard"]:
286
- result = run_episode(task_id, client)
287
  results[task_id] = result
288
 
289
  elapsed = time.time() - start_time
 
23
  import traceback
24
 
25
  import requests
26
+ from openai import OpenAI
27
 
28
  # ---------------------------------------------------------------------------
29
  # Required environment variables
 
33
  HF_TOKEN = os.environ.get("HF_TOKEN", "")
34
  ENV_URL = os.environ.get("ENV_URL", "http://localhost:7860")
35
 
36
+ if not HF_TOKEN:
37
+ print("WARNING: HF_TOKEN not set. LLM calls will fail.", flush=True)
38
+
39
+ client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
40
+
41
  # ---------------------------------------------------------------------------
42
  # System prompt
43
  # ---------------------------------------------------------------------------
 
70
  Respond with ONLY valid JSON. No explanation, no markdown, no extra text."""
71
 
72
 
 
 
 
 
 
 
 
 
73
 
74
 
75
  def parse_action(text: str) -> dict:
 
108
  return {"command": "observe"}
109
 
110
 
111
+ def run_episode(task_id: str) -> dict:
112
  """Run one episode of the tiffin packing task."""
113
  # Emit [START] structured output for the validator
114
  print(f"[START] task={task_id}", flush=True)
 
160
 
161
  # Get LLM decision
162
  try:
 
 
163
  response = client.chat.completions.create(
164
  model=MODEL_NAME,
165
  messages=messages,
 
272
  print(f" Env: {ENV_URL}", flush=True)
273
  print("=" * 60, flush=True)
274
 
275
+
 
276
 
277
  start_time = time.time()
278
  results = {}
279
 
280
  for task_id in ["easy", "medium", "hard"]:
281
+ result = run_episode(task_id)
282
  results[task_id] = result
283
 
284
  elapsed = time.time() - start_time