UjjwalPardeshi commited on
Commit
47d99a3
Β·
1 Parent(s): 3956f8f

fix get model_messages

Browse files
Files changed (1) hide show
  1. inference.py +43 -44
inference.py CHANGED
@@ -6,16 +6,16 @@ and the standard OpenEnv GenericEnvClient (env.reset / env.step).
6
  Emits structured [START]/[STEP]/[END] logs to stdout as required by
7
  the hackathon evaluator.
8
 
9
- Required environment variables (set by hackathon evaluator):
10
- API_BASE_URL β€” OpenAI-compatible API endpoint
11
- MODEL_NAME β€” Model to use (e.g., "gpt-4o", "llama-3.3-70b")
12
- HF_TOKEN β€” Hugging Face token (used as API key if OPENAI_API_KEY not set)
13
 
14
  Optional:
15
- OPENAI_API_KEY β€” API key (takes precedence over HF_TOKEN)
 
16
  ENV_URL β€” Environment server URL (default: http://localhost:7860)
17
  TASK_NAME β€” Task to run (default: task_001)
18
- IMAGE_NAME β€” Docker image name (if set, uses from_docker_image)
19
  """
20
 
21
  from __future__ import annotations
@@ -26,20 +26,14 @@ import os
26
  import sys
27
  from typing import List, Optional
28
 
29
- try:
30
- from openai import OpenAI
31
- except ImportError:
32
- print("Error: openai package not installed. Run: pip install openai", flush=True)
33
- sys.exit(1)
34
-
35
  from openenv.core import GenericAction, GenericEnvClient
36
 
37
  # ---------------------------------------------------------------------------
38
- # Configuration from environment variables
39
  # ---------------------------------------------------------------------------
40
- # Evaluator injects API_BASE_URL and API_KEY β€” read them directly
41
- API_BASE_URL = os.environ.get("API_BASE_URL") or "https://api.openai.com/v1"
42
- MODEL_NAME = os.environ.get("MODEL_NAME") or "gpt-4o"
43
  API_KEY = os.environ.get("API_KEY") or os.environ.get("HF_TOKEN") or os.environ.get("OPENAI_API_KEY") or ""
44
  ENV_URL = os.environ.get("ENV_URL", "http://localhost:7860")
45
  IMAGE_NAME = os.environ.get("IMAGE_NAME", "")
@@ -47,16 +41,14 @@ TASK_NAME = os.environ.get("TASK_NAME", "task_001")
47
  BENCHMARK = "pytorch-training-debugger"
48
 
49
  MAX_STEPS = 25
50
- # Max achievable reward: +0.50 (diagnosis) +0.40 (convergence) +5*0.05 (investigations)
51
- # minus step penalties. Use 1.15 as the theoretical ceiling for normalization.
52
  MAX_TOTAL_REWARD = 1.15
53
  SUCCESS_SCORE_THRESHOLD = 0.5
54
  TEMPERATURE = 0.0
55
  MAX_TOKENS = 300
56
- FALLBACK_ACTION = '{"action_type": "inspect_gradients"}'
57
 
58
  # ---------------------------------------------------------------------------
59
- # Structured logging β€” [START]/[STEP]/[END] format (hackathon requirement)
60
  # ---------------------------------------------------------------------------
61
 
62
 
@@ -84,7 +76,7 @@ def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> No
84
 
85
 
86
  # ---------------------------------------------------------------------------
87
- # System prompt for the LLM agent
88
  # ---------------------------------------------------------------------------
89
  SYSTEM_PROMPT = """You are an expert ML engineer debugging a PyTorch training run.
90
  You are interacting with an environment that simulates a broken training job.
@@ -158,7 +150,7 @@ def get_model_message(
158
  last_reward: float,
159
  history: List[str],
160
  ) -> str:
161
- """Get next action from the LLM."""
162
  history_ctx = "\n".join(history[-5:]) if history else "No previous steps."
163
  user_content = (
164
  f"Step {step}. Last reward: {last_reward:+.2f}\n"
@@ -167,21 +159,27 @@ def get_model_message(
167
  f"{json.dumps(last_obs_summary, indent=2, default=str)}\n\n"
168
  "What action should you take next? Respond with JSON only."
169
  )
170
- try:
171
- completion = client.chat.completions.create(
172
- model=MODEL_NAME,
173
- messages=[
174
- {"role": "system", "content": SYSTEM_PROMPT},
175
- {"role": "user", "content": user_content},
176
- ],
177
- temperature=TEMPERATURE,
178
- max_tokens=MAX_TOKENS,
179
- )
180
- text = (completion.choices[0].message.content or "").strip()
181
- return text if text else FALLBACK_ACTION
182
- except Exception as exc:
183
- print(f"[DEBUG] Model request failed: {exc}", flush=True)
184
- return FALLBACK_ACTION
 
 
 
 
 
 
185
 
186
 
187
  def parse_action(raw: str) -> str:
@@ -193,7 +191,7 @@ def parse_action(raw: str) -> str:
193
  json.loads(text)
194
  return text
195
  except json.JSONDecodeError:
196
- return FALLBACK_ACTION
197
 
198
 
199
  async def main() -> None:
@@ -210,13 +208,14 @@ async def main() -> None:
210
  if not API_KEY:
211
  raise RuntimeError("API_KEY, HF_TOKEN, or OPENAI_API_KEY required.")
212
 
213
- print(f"[DEBUG] Using API_BASE_URL={API_BASE_URL}", flush=True)
214
- print(f"[DEBUG] Using MODEL_NAME={MODEL_NAME}", flush=True)
215
  print(f"[DEBUG] API_KEY source: {'API_KEY' if os.environ.get('API_KEY') else 'HF_TOKEN' if os.environ.get('HF_TOKEN') else 'OPENAI_API_KEY'}", flush=True)
216
 
 
217
  client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
218
 
219
- # Connect to environment via standard OpenEnv client
220
  if IMAGE_NAME:
221
  env = await GenericEnvClient.from_docker_image(IMAGE_NAME)
222
  else:
@@ -259,18 +258,18 @@ async def main() -> None:
259
  break
260
 
261
  score = sum(rewards) / MAX_TOTAL_REWARD if MAX_TOTAL_REWARD > 0 else 0.0
262
- score = min(max(score, 0.01), 0.99) # clamp to (0, 1) exclusive
263
  success = score >= SUCCESS_SCORE_THRESHOLD
264
 
265
  except Exception as exc:
266
- print(f"[DEBUG] Unhandled error: {exc}", flush=True)
267
 
268
  finally:
269
  if env is not None:
270
  try:
271
  await env.close()
272
  except Exception as e:
273
- print(f"[DEBUG] env.close() error (container cleanup): {e}", flush=True)
274
  log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
275
 
276
 
 
6
  Emits structured [START]/[STEP]/[END] logs to stdout as required by
7
  the hackathon evaluator.
8
 
9
+ Required environment variables (injected by evaluator):
10
+ API_BASE_URL β€” LiteLLM proxy endpoint
11
+ API_KEY β€” LiteLLM proxy key
12
+ MODEL_NAME β€” Model to use
13
 
14
  Optional:
15
+ HF_TOKEN β€” Fallback API key
16
+ IMAGE_NAME β€” Docker image name (if using from_docker_image)
17
  ENV_URL β€” Environment server URL (default: http://localhost:7860)
18
  TASK_NAME β€” Task to run (default: task_001)
 
19
  """
20
 
21
  from __future__ import annotations
 
26
  import sys
27
  from typing import List, Optional
28
 
29
+ from openai import OpenAI
 
 
 
 
 
30
  from openenv.core import GenericAction, GenericEnvClient
31
 
32
  # ---------------------------------------------------------------------------
33
+ # Configuration β€” evaluator injects API_BASE_URL and API_KEY
34
  # ---------------------------------------------------------------------------
35
+ API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.openai.com/v1")
36
+ MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-4o")
 
37
  API_KEY = os.environ.get("API_KEY") or os.environ.get("HF_TOKEN") or os.environ.get("OPENAI_API_KEY") or ""
38
  ENV_URL = os.environ.get("ENV_URL", "http://localhost:7860")
39
  IMAGE_NAME = os.environ.get("IMAGE_NAME", "")
 
41
  BENCHMARK = "pytorch-training-debugger"
42
 
43
  MAX_STEPS = 25
 
 
44
  MAX_TOTAL_REWARD = 1.15
45
  SUCCESS_SCORE_THRESHOLD = 0.5
46
  TEMPERATURE = 0.0
47
  MAX_TOKENS = 300
48
+ MAX_RETRIES = 3
49
 
50
  # ---------------------------------------------------------------------------
51
+ # Structured logging β€” [START]/[STEP]/[END] format
52
  # ---------------------------------------------------------------------------
53
 
54
 
 
76
 
77
 
78
  # ---------------------------------------------------------------------------
79
+ # System prompt
80
  # ---------------------------------------------------------------------------
81
  SYSTEM_PROMPT = """You are an expert ML engineer debugging a PyTorch training run.
82
  You are interacting with an environment that simulates a broken training job.
 
150
  last_reward: float,
151
  history: List[str],
152
  ) -> str:
153
+ """Get next action from the LLM. Retries on failure β€” never silently skips."""
154
  history_ctx = "\n".join(history[-5:]) if history else "No previous steps."
155
  user_content = (
156
  f"Step {step}. Last reward: {last_reward:+.2f}\n"
 
159
  f"{json.dumps(last_obs_summary, indent=2, default=str)}\n\n"
160
  "What action should you take next? Respond with JSON only."
161
  )
162
+
163
+ last_error = None
164
+ for attempt in range(1, MAX_RETRIES + 1):
165
+ try:
166
+ completion = client.chat.completions.create(
167
+ model=MODEL_NAME,
168
+ messages=[
169
+ {"role": "system", "content": SYSTEM_PROMPT},
170
+ {"role": "user", "content": user_content},
171
+ ],
172
+ temperature=TEMPERATURE,
173
+ max_tokens=MAX_TOKENS,
174
+ )
175
+ text = (completion.choices[0].message.content or "").strip()
176
+ return text if text else '{"action_type": "inspect_gradients"}'
177
+ except Exception as exc:
178
+ last_error = exc
179
+ print(f"[DEBUG] LLM attempt {attempt}/{MAX_RETRIES} failed: {exc}", flush=True)
180
+
181
+ # All retries failed β€” raise so the caller knows LLM is broken
182
+ raise RuntimeError(f"LLM failed after {MAX_RETRIES} attempts: {last_error}")
183
 
184
 
185
  def parse_action(raw: str) -> str:
 
191
  json.loads(text)
192
  return text
193
  except json.JSONDecodeError:
194
+ return '{"action_type": "inspect_gradients"}'
195
 
196
 
197
  async def main() -> None:
 
208
  if not API_KEY:
209
  raise RuntimeError("API_KEY, HF_TOKEN, or OPENAI_API_KEY required.")
210
 
211
+ print(f"[DEBUG] API_BASE_URL={API_BASE_URL}", flush=True)
212
+ print(f"[DEBUG] MODEL_NAME={MODEL_NAME}", flush=True)
213
  print(f"[DEBUG] API_KEY source: {'API_KEY' if os.environ.get('API_KEY') else 'HF_TOKEN' if os.environ.get('HF_TOKEN') else 'OPENAI_API_KEY'}", flush=True)
214
 
215
+ # Initialize OpenAI client with evaluator-provided credentials
216
  client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
217
 
218
+ # Connect to environment
219
  if IMAGE_NAME:
220
  env = await GenericEnvClient.from_docker_image(IMAGE_NAME)
221
  else:
 
258
  break
259
 
260
  score = sum(rewards) / MAX_TOTAL_REWARD if MAX_TOTAL_REWARD > 0 else 0.0
261
+ score = min(max(score, 0.01), 0.99)
262
  success = score >= SUCCESS_SCORE_THRESHOLD
263
 
264
  except Exception as exc:
265
+ print(f"[DEBUG] Error: {exc}", flush=True)
266
 
267
  finally:
268
  if env is not None:
269
  try:
270
  await env.close()
271
  except Exception as e:
272
+ print(f"[DEBUG] env.close() error: {e}", flush=True)
273
  log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
274
 
275