UjjwalPardeshi commited on
Commit
c5307a2
·
1 Parent(s): 47d99a3

fix inference final

Browse files
Files changed (1) hide show
  1. inference.py +45 -48
inference.py CHANGED
@@ -1,21 +1,11 @@
1
  #!/usr/bin/env python3
2
  """Inference script for the PyTorch Training Run Debugger.
3
 
4
- Runs an LLM agent against the environment using the OpenAI client
5
- and the standard OpenEnv GenericEnvClient (env.reset / env.step).
6
- Emits structured [START]/[STEP]/[END] logs to stdout as required by
7
- the hackathon evaluator.
8
-
9
  Required environment variables (injected by evaluator):
10
  API_BASE_URL — LiteLLM proxy endpoint
11
  API_KEY — LiteLLM proxy key
12
  MODEL_NAME — Model to use
13
-
14
- Optional:
15
- HF_TOKEN — Fallback API key
16
- IMAGE_NAME — Docker image name (if using from_docker_image)
17
- ENV_URL — Environment server URL (default: http://localhost:7860)
18
- TASK_NAME — Task to run (default: task_001)
19
  """
20
 
21
  from __future__ import annotations
@@ -32,12 +22,13 @@ from openenv.core import GenericAction, GenericEnvClient
32
  # ---------------------------------------------------------------------------
33
  # Configuration — evaluator injects API_BASE_URL and API_KEY
34
  # ---------------------------------------------------------------------------
35
- API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.openai.com/v1")
36
- MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-4o")
37
- API_KEY = os.environ.get("API_KEY") or os.environ.get("HF_TOKEN") or os.environ.get("OPENAI_API_KEY") or ""
38
- ENV_URL = os.environ.get("ENV_URL", "http://localhost:7860")
39
- IMAGE_NAME = os.environ.get("IMAGE_NAME", "")
40
- TASK_NAME = os.environ.get("TASK_NAME", "task_001")
 
41
  BENCHMARK = "pytorch-training-debugger"
42
 
43
  MAX_STEPS = 25
@@ -45,10 +36,9 @@ MAX_TOTAL_REWARD = 1.15
45
  SUCCESS_SCORE_THRESHOLD = 0.5
46
  TEMPERATURE = 0.0
47
  MAX_TOKENS = 300
48
- MAX_RETRIES = 3
49
 
50
  # ---------------------------------------------------------------------------
51
- # Structured logging — [START]/[STEP]/[END] format
52
  # ---------------------------------------------------------------------------
53
 
54
 
@@ -150,7 +140,7 @@ def get_model_message(
150
  last_reward: float,
151
  history: List[str],
152
  ) -> str:
153
- """Get next action from the LLM. Retries on failure — never silently skips."""
154
  history_ctx = "\n".join(history[-5:]) if history else "No previous steps."
155
  user_content = (
156
  f"Step {step}. Last reward: {last_reward:+.2f}\n"
@@ -159,27 +149,21 @@ def get_model_message(
159
  f"{json.dumps(last_obs_summary, indent=2, default=str)}\n\n"
160
  "What action should you take next? Respond with JSON only."
161
  )
162
-
163
- last_error = None
164
- for attempt in range(1, MAX_RETRIES + 1):
165
- try:
166
- completion = client.chat.completions.create(
167
- model=MODEL_NAME,
168
- messages=[
169
- {"role": "system", "content": SYSTEM_PROMPT},
170
- {"role": "user", "content": user_content},
171
- ],
172
- temperature=TEMPERATURE,
173
- max_tokens=MAX_TOKENS,
174
- )
175
- text = (completion.choices[0].message.content or "").strip()
176
- return text if text else '{"action_type": "inspect_gradients"}'
177
- except Exception as exc:
178
- last_error = exc
179
- print(f"[DEBUG] LLM attempt {attempt}/{MAX_RETRIES} failed: {exc}", flush=True)
180
-
181
- # All retries failed — raise so the caller knows LLM is broken
182
- raise RuntimeError(f"LLM failed after {MAX_RETRIES} attempts: {last_error}")
183
 
184
 
185
  def parse_action(raw: str) -> str:
@@ -205,26 +189,40 @@ async def main() -> None:
205
  log_start(task=TASK_NAME, env=BENCHMARK, model=MODEL_NAME)
206
 
207
  try:
208
- if not API_KEY:
209
- raise RuntimeError("API_KEY, HF_TOKEN, or OPENAI_API_KEY required.")
210
-
211
  print(f"[DEBUG] API_BASE_URL={API_BASE_URL}", flush=True)
212
  print(f"[DEBUG] MODEL_NAME={MODEL_NAME}", flush=True)
213
- print(f"[DEBUG] API_KEY source: {'API_KEY' if os.environ.get('API_KEY') else 'HF_TOKEN' if os.environ.get('HF_TOKEN') else 'OPENAI_API_KEY'}", flush=True)
 
 
214
 
215
- # Initialize OpenAI client with evaluator-provided credentials
216
  client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
217
 
218
- # Connect to environment
 
 
 
 
 
 
 
 
 
219
  if IMAGE_NAME:
 
220
  env = await GenericEnvClient.from_docker_image(IMAGE_NAME)
221
  else:
 
222
  env = GenericEnvClient(base_url=ENV_URL, message_timeout_s=120.0)
223
  await env.connect()
224
 
 
 
 
225
  result = await env.reset(task_id=TASK_NAME, seed=42)
226
  obs = result.observation
227
  last_reward = 0.0
 
228
 
229
  for step in range(1, MAX_STEPS + 1):
230
  if result.done:
@@ -251,7 +249,6 @@ async def main() -> None:
251
  last_reward = reward
252
 
253
  log_step(step=step, action=action_str, reward=reward, done=done, error=error)
254
-
255
  history.append(f"Step {step}: {action_str!r} -> reward {reward:+.2f}")
256
 
257
  if done:
 
1
  #!/usr/bin/env python3
2
  """Inference script for the PyTorch Training Run Debugger.
3
 
 
 
 
 
 
4
  Required environment variables (injected by evaluator):
5
  API_BASE_URL — LiteLLM proxy endpoint
6
  API_KEY — LiteLLM proxy key
7
  MODEL_NAME — Model to use
8
+ IMAGE_NAME — Docker image for the environment (optional)
 
 
 
 
 
9
  """
10
 
11
  from __future__ import annotations
 
22
  # ---------------------------------------------------------------------------
23
  # Configuration — evaluator injects API_BASE_URL and API_KEY
24
  # ---------------------------------------------------------------------------
25
+ IMAGE_NAME = os.getenv("IMAGE_NAME") or os.getenv("LOCAL_IMAGE_NAME")
26
+ API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
27
+
28
+ API_BASE_URL = os.getenv("API_BASE_URL") or "https://api.openai.com/v1"
29
+ MODEL_NAME = os.getenv("MODEL_NAME") or "gpt-4o"
30
+ ENV_URL = os.getenv("ENV_URL") or "https://ujjwalpardeshi-pytorch-training-debugger.hf.space"
31
+ TASK_NAME = os.getenv("TASK_NAME") or "task_001"
32
  BENCHMARK = "pytorch-training-debugger"
33
 
34
  MAX_STEPS = 25
 
36
  SUCCESS_SCORE_THRESHOLD = 0.5
37
  TEMPERATURE = 0.0
38
  MAX_TOKENS = 300
 
39
 
40
  # ---------------------------------------------------------------------------
41
+ # Structured logging
42
  # ---------------------------------------------------------------------------
43
 
44
 
 
140
  last_reward: float,
141
  history: List[str],
142
  ) -> str:
143
+ """Get next action from the LLM."""
144
  history_ctx = "\n".join(history[-5:]) if history else "No previous steps."
145
  user_content = (
146
  f"Step {step}. Last reward: {last_reward:+.2f}\n"
 
149
  f"{json.dumps(last_obs_summary, indent=2, default=str)}\n\n"
150
  "What action should you take next? Respond with JSON only."
151
  )
152
+ try:
153
+ completion = client.chat.completions.create(
154
+ model=MODEL_NAME,
155
+ messages=[
156
+ {"role": "system", "content": SYSTEM_PROMPT},
157
+ {"role": "user", "content": user_content},
158
+ ],
159
+ temperature=TEMPERATURE,
160
+ max_tokens=MAX_TOKENS,
161
+ )
162
+ text = (completion.choices[0].message.content or "").strip()
163
+ return text if text else '{"action_type": "inspect_gradients"}'
164
+ except Exception as exc:
165
+ print(f"[DEBUG] Model request failed: {exc}", flush=True)
166
+ return '{"action_type": "inspect_gradients"}'
 
 
 
 
 
 
167
 
168
 
169
  def parse_action(raw: str) -> str:
 
189
  log_start(task=TASK_NAME, env=BENCHMARK, model=MODEL_NAME)
190
 
191
  try:
192
+ # ---- 1. Create OpenAI client with evaluator credentials ----
 
 
193
  print(f"[DEBUG] API_BASE_URL={API_BASE_URL}", flush=True)
194
  print(f"[DEBUG] MODEL_NAME={MODEL_NAME}", flush=True)
195
+ print(f"[DEBUG] API_KEY set: {bool(API_KEY)}", flush=True)
196
+ print(f"[DEBUG] IMAGE_NAME={IMAGE_NAME}", flush=True)
197
+ print(f"[DEBUG] ENV_URL={ENV_URL}", flush=True)
198
 
 
199
  client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
200
 
201
+ # ---- 2. Test LLM call to guarantee proxy is used ----
202
+ print("[DEBUG] Making test LLM call...", flush=True)
203
+ test_resp = client.chat.completions.create(
204
+ model=MODEL_NAME,
205
+ messages=[{"role": "user", "content": "Say hello in one word."}],
206
+ max_tokens=10,
207
+ )
208
+ print(f"[DEBUG] Test LLM call succeeded: {test_resp.choices[0].message.content}", flush=True)
209
+
210
+ # ---- 3. Connect to environment ----
211
  if IMAGE_NAME:
212
+ print(f"[DEBUG] Connecting via from_docker_image({IMAGE_NAME})", flush=True)
213
  env = await GenericEnvClient.from_docker_image(IMAGE_NAME)
214
  else:
215
+ print(f"[DEBUG] Connecting via GenericEnvClient({ENV_URL})", flush=True)
216
  env = GenericEnvClient(base_url=ENV_URL, message_timeout_s=120.0)
217
  await env.connect()
218
 
219
+ print("[DEBUG] Environment connected", flush=True)
220
+
221
+ # ---- 4. Run episode ----
222
  result = await env.reset(task_id=TASK_NAME, seed=42)
223
  obs = result.observation
224
  last_reward = 0.0
225
+ print(f"[DEBUG] Reset done. result.done={result.done}", flush=True)
226
 
227
  for step in range(1, MAX_STEPS + 1):
228
  if result.done:
 
249
  last_reward = reward
250
 
251
  log_step(step=step, action=action_str, reward=reward, done=done, error=error)
 
252
  history.append(f"Step {step}: {action_str!r} -> reward {reward:+.2f}")
253
 
254
  if done: