UjjwalPardeshi commited on
Commit
fc246c9
·
1 Parent(s): c5307a2

fix graders

Browse files
Files changed (1) hide show
  1. inference.py +6 -26
inference.py CHANGED
@@ -20,14 +20,13 @@ from openai import OpenAI
20
  from openenv.core import GenericAction, GenericEnvClient
21
 
22
  # ---------------------------------------------------------------------------
23
- # Configuration — evaluator injects API_BASE_URL and API_KEY
24
  # ---------------------------------------------------------------------------
25
  IMAGE_NAME = os.getenv("IMAGE_NAME") or os.getenv("LOCAL_IMAGE_NAME")
26
  API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
27
 
28
  API_BASE_URL = os.getenv("API_BASE_URL") or "https://api.openai.com/v1"
29
  MODEL_NAME = os.getenv("MODEL_NAME") or "gpt-4o"
30
- ENV_URL = os.getenv("ENV_URL") or "https://ujjwalpardeshi-pytorch-training-debugger.hf.space"
31
  TASK_NAME = os.getenv("TASK_NAME") or "task_001"
32
  BENCHMARK = "pytorch-training-debugger"
33
 
@@ -189,40 +188,21 @@ async def main() -> None:
189
  log_start(task=TASK_NAME, env=BENCHMARK, model=MODEL_NAME)
190
 
191
  try:
192
- # ---- 1. Create OpenAI client with evaluator credentials ----
193
- print(f"[DEBUG] API_BASE_URL={API_BASE_URL}", flush=True)
194
- print(f"[DEBUG] MODEL_NAME={MODEL_NAME}", flush=True)
195
- print(f"[DEBUG] API_KEY set: {bool(API_KEY)}", flush=True)
196
- print(f"[DEBUG] IMAGE_NAME={IMAGE_NAME}", flush=True)
197
- print(f"[DEBUG] ENV_URL={ENV_URL}", flush=True)
198
-
199
  client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
200
 
201
- # ---- 2. Test LLM call to guarantee proxy is used ----
202
- print("[DEBUG] Making test LLM call...", flush=True)
203
- test_resp = client.chat.completions.create(
204
- model=MODEL_NAME,
205
- messages=[{"role": "user", "content": "Say hello in one word."}],
206
- max_tokens=10,
207
- )
208
- print(f"[DEBUG] Test LLM call succeeded: {test_resp.choices[0].message.content}", flush=True)
209
-
210
- # ---- 3. Connect to environment ----
211
  if IMAGE_NAME:
212
- print(f"[DEBUG] Connecting via from_docker_image({IMAGE_NAME})", flush=True)
213
  env = await GenericEnvClient.from_docker_image(IMAGE_NAME)
214
  else:
215
- print(f"[DEBUG] Connecting via GenericEnvClient({ENV_URL})", flush=True)
216
- env = GenericEnvClient(base_url=ENV_URL, message_timeout_s=120.0)
 
 
217
  await env.connect()
218
 
219
- print("[DEBUG] Environment connected", flush=True)
220
-
221
- # ---- 4. Run episode ----
222
  result = await env.reset(task_id=TASK_NAME, seed=42)
223
  obs = result.observation
224
  last_reward = 0.0
225
- print(f"[DEBUG] Reset done. result.done={result.done}", flush=True)
226
 
227
  for step in range(1, MAX_STEPS + 1):
228
  if result.done:
 
20
  from openenv.core import GenericAction, GenericEnvClient
21
 
22
  # ---------------------------------------------------------------------------
23
+ # Configuration — matches sample inference script exactly
24
  # ---------------------------------------------------------------------------
25
  IMAGE_NAME = os.getenv("IMAGE_NAME") or os.getenv("LOCAL_IMAGE_NAME")
26
  API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
27
 
28
  API_BASE_URL = os.getenv("API_BASE_URL") or "https://api.openai.com/v1"
29
  MODEL_NAME = os.getenv("MODEL_NAME") or "gpt-4o"
 
30
  TASK_NAME = os.getenv("TASK_NAME") or "task_001"
31
  BENCHMARK = "pytorch-training-debugger"
32
 
 
188
  log_start(task=TASK_NAME, env=BENCHMARK, model=MODEL_NAME)
189
 
190
  try:
 
 
 
 
 
 
 
191
  client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
192
 
193
+ # Connect to environment same pattern as sample script
 
 
 
 
 
 
 
 
 
194
  if IMAGE_NAME:
 
195
  env = await GenericEnvClient.from_docker_image(IMAGE_NAME)
196
  else:
197
+ env = GenericEnvClient(
198
+ base_url=os.getenv("ENV_URL", "https://ujjwalpardeshi-pytorch-training-debugger.hf.space"),
199
+ message_timeout_s=120.0,
200
+ )
201
  await env.connect()
202
 
 
 
 
203
  result = await env.reset(task_id=TASK_NAME, seed=42)
204
  obs = result.observation
205
  last_reward = 0.0
 
206
 
207
  for step in range(1, MAX_STEPS + 1):
208
  if result.done: