Aryanshh commited on
Commit
39605fc
·
1 Parent(s): 48e1617

Robustness: Prevent unhandled exceptions in inference.py and align with exact sample STDOUT

Browse files
Files changed (1) hide show
  1. inference.py +93 -68
inference.py CHANGED
@@ -1,95 +1,120 @@
1
  import json
2
  import os
3
  import sys
4
- from typing import Optional
 
 
5
  from openai import OpenAI
6
  import httpx
7
 
8
  # ---------------------------------------------------------------------------
9
- # Config
10
  # ---------------------------------------------------------------------------
11
- API_BASE_URL = os.getenv("API_BASE_URL", "https://api.openai.com/v1")
12
- MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o-mini")
13
- HF_TOKEN = os.getenv("HF_TOKEN")
14
 
15
- if not HF_TOKEN:
16
- print("ERROR: HF_TOKEN environment variable is required")
 
 
 
 
17
  sys.exit(1)
18
 
19
  # OpenAI Client configured via environment variables
20
- client = OpenAI(api_key=HF_TOKEN, base_url=API_BASE_URL)
21
-
22
- # Environment Server URL (defaults to local for dev)
23
- ENV_URL = os.getenv("ENV_URL", "http://localhost:7860")
24
 
25
  # ---------------------------------------------------------------------------
26
- # Prompting
27
  # ---------------------------------------------------------------------------
28
- SYSTEM_PROMPT = """\
29
- You are an Eco-Resilient Logistics Agent. Your goal is to fulfill orders while minimizing CO2.
30
- Respond ONLY with a valid JSON completion.
31
 
32
- Available Actions:
33
- {
34
- "action_type": "order_parts | produce | offset | skip",
35
- "part_type": "chips | sensors | batteries | casing",
36
- "quantity": number,
37
- "mode": "sea | air | rail | road",
38
- "product": "EcoPhone | GreenTab",
39
- "offset_amount": number
40
- }
41
- """
 
 
 
 
 
 
42
 
43
- def get_action(obs):
44
- prompt = f"Current State: {json.dumps(obs, indent=2)}\nChoose next action:"
45
- response = client.chat.completions.create(
46
- model=MODEL_NAME,
47
- messages=[
48
- {"role": "system", "content": SYSTEM_PROMPT},
49
- {"role": "user", "content": prompt}
50
- ],
51
- response_format={"type": "json_object"}
52
- )
53
- return json.loads(response.choices[0].message.content)
 
 
 
 
 
 
54
 
55
  # ---------------------------------------------------------------------------
56
  # Runner
57
  # ---------------------------------------------------------------------------
58
  def run_task(task_name: str):
59
- print(f"[START] task={task_name} env=netzero-nav model={MODEL_NAME}", flush=True)
 
 
 
60
 
61
- with httpx.Client(base_url=ENV_URL, timeout=30.0) as app:
62
- obs = app.post("/reset", json={"task": task_name}).json()
63
-
64
- done = False
65
- step = 0
66
- rewards = []
67
- score = 0.0
68
- success = False
69
-
70
- while not done and step < 50:
71
- step += 1
72
- action_json = get_action(obs)
73
- resp = app.post("/step", json=action_json).json()
74
-
75
- obs = resp["observation"]
76
- reward = float(resp["reward"] or 0.0)
77
- rewards.append(reward)
78
- done = resp["done"]
79
- info = resp.get("info", {})
80
- error_val = info.get("error", "null")
81
-
82
- # Format action string minimally without quotes inside the action bracket (for visual parsing ease)
83
- action_str = f"{action_json['action_type']}"
84
- if action_json.get("part_type"): action_str += f"-{action_json['part_type']}"
85
-
86
- print(f"[STEP] step={step} action={action_str} reward={reward:.2f} done={str(done).lower()} error={error_val}", flush=True)
87
 
88
- if done:
89
- score = float(info.get("final_score", 0.0))
90
- success = score >= 0.99
91
- rewards_str = ",".join(f"{r:.2f}" for r in rewards)
92
- print(f"[END] success={str(success).lower()} steps={step} score={score:.4f} rewards={rewards_str}", flush=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
 
94
  if __name__ == "__main__":
95
  for task in ["easy", "medium", "hard"]:
 
1
  import json
2
  import os
3
  import sys
4
+ import time
5
+ import textwrap
6
+ from typing import List, Optional
7
  from openai import OpenAI
8
  import httpx
9
 
10
  # ---------------------------------------------------------------------------
11
+ # Config (MANDATORY per Checklist)
12
  # ---------------------------------------------------------------------------
13
+ API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
14
+ MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
15
+ API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
16
 
17
+ # Environment Server URL
18
+ ENV_URL = os.getenv("ENV_URL", "http://localhost:7860")
19
+
20
+ if not API_KEY:
21
+ # We print and exit to avoid unhandled exceptions later
22
+ print("ERROR: HF_TOKEN or API_KEY environment variable is required", flush=True)
23
  sys.exit(1)
24
 
25
  # OpenAI Client configured via environment variables
26
+ client = OpenAI(api_key=API_KEY, base_url=API_BASE_URL)
 
 
 
27
 
28
  # ---------------------------------------------------------------------------
29
+ # Logging Utilities
30
  # ---------------------------------------------------------------------------
31
+ def log_start(task: str, env: str, model: str) -> None:
32
+ print(f"[START] task={task} env={env} model={model}", flush=True)
 
33
 
34
+ def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
35
+ error_val = error if error else "null"
36
+ done_val = str(done).lower()
37
+ print(f"[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}", flush=True)
38
+
39
+ def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
40
+ rewards_str = ",".join(f"{r:.2f}" for r in rewards)
41
+ print(f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_str}", flush=True)
42
+
43
+ # ---------------------------------------------------------------------------
44
+ # Agent Logic
45
+ # ---------------------------------------------------------------------------
46
+ SYSTEM_PROMPT = """You are an Eco-Resilient Logistics Agent.
47
+ Your goal is to fulfill orders while minimizing CO2.
48
+ Available Actions: {"action_type": "order_parts | produce | offset | skip", "part_type": "chips | sensors | batteries | casing", "quantity": count, "mode": "sea | air | rail | road", "product": "EcoPhone | GreenTab"}
49
+ Respond ONLY with a valid JSON object."""
50
 
51
+ def get_action(obs) -> dict:
52
+ prompt = f"Current Observation: {json.dumps(obs)}\nChoose next action:"
53
+ try:
54
+ response = client.chat.completions.create(
55
+ model=MODEL_NAME,
56
+ messages=[
57
+ {"role": "system", "content": SYSTEM_PROMPT},
58
+ {"role": "user", "content": prompt}
59
+ ],
60
+ response_format={"type": "json_object"},
61
+ timeout=15.0
62
+ )
63
+ return json.loads(response.choices[0].message.content)
64
+ except Exception as e:
65
+ # Emergency fallback to prevent script crash
66
+ print(f"[DEBUG] Model error: {e}", file=sys.stderr)
67
+ return {"action_type": "skip"}
68
 
69
  # ---------------------------------------------------------------------------
70
  # Runner
71
  # ---------------------------------------------------------------------------
72
  def run_task(task_name: str):
73
+ success = False
74
+ score = 0.0
75
+ steps_taken = 0
76
+ rewards = []
77
 
78
+ log_start(task=task_name, env="netzero-nav", model=MODEL_NAME)
79
+
80
+ try:
81
+ with httpx.Client(base_url=ENV_URL, timeout=30.0) as app:
82
+ # Reset environment
83
+ resp = app.post("/reset", json={"task": task_name})
84
+ obs = resp.json()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
+ done = False
87
+ while not done and steps_taken < 50:
88
+ steps_taken += 1
89
+ action_json = get_action(obs)
90
+
91
+ # Take step
92
+ resp = app.post("/step", json=action_json).json()
93
+
94
+ obs = resp["observation"]
95
+ reward = float(resp["reward"] or 0.0)
96
+ done = resp["done"]
97
+ info = resp.get("info", {})
98
+ error = info.get("error")
99
+
100
+ rewards.append(reward)
101
+
102
+ # Format action for logs
103
+ act_type = action_json.get("action_type", "skip")
104
+ act_part = action_json.get("part_type", "")
105
+ act_str = f"{act_type}-{act_part}" if act_part else act_type
106
+
107
+ log_step(step=steps_taken, action=act_str, reward=reward, done=done, error=error)
108
+
109
+ if done:
110
+ score = float(info.get("final_score", 0.0))
111
+ success = score >= 0.99
112
+
113
+ except Exception as e:
114
+ print(f"[DEBUG] Runtime error during task {task_name}: {e}", file=sys.stderr)
115
+
116
+ finally:
117
+ log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
118
 
119
  if __name__ == "__main__":
120
  for task in ["easy", "medium", "hard"]: