M ShreeRaj committed on
Commit
1064359
·
unverified ·
1 Parent(s): c901fa0

Refactor inference.py for environment variable handling

Browse files
Files changed (1) hide show
  1. inference.py +91 -168
inference.py CHANGED
@@ -1,66 +1,51 @@
 
 
1
  import os
2
  import json
3
  import urllib.request
4
  import urllib.error
5
  from typing import List, Optional
6
-
7
- try:
8
- from dotenv import load_dotenv
9
- load_dotenv()
10
- except ImportError:
11
- pass
12
-
13
  from openai import OpenAI
14
 
15
 
16
- # ── HTTP Helper ──────────────────────────────────────────────────────────────
17
- def post_json(url: str, payload: dict) -> dict:
18
- data = json.dumps(payload).encode("utf-8")
19
- req = urllib.request.Request(
20
- url, data=data, headers={"Content-Type": "application/json"}
21
- )
22
- try:
23
- with urllib.request.urlopen(req, timeout=30) as res:
24
- return json.loads(res.read().decode("utf-8"))
25
- except urllib.error.HTTPError as e:
26
- raise Exception(f"HTTP Error {e.code}: {e.read().decode('utf-8')}")
27
- except urllib.error.URLError as e:
28
- raise Exception(f"URL Error: {e.reason}")
29
-
30
-
31
- # ── ENV ───────────────────────────────────────────────────────────────────────
32
- # FIX 1: Use os.environ["API_KEY"] strictly — do NOT fall back to HF_TOKEN.
33
- # HuggingFace Spaces auto-inject HF_TOKEN with your personal token, which is
34
- # NOT the hackathon's LiteLLM proxy key. Falling back to it means calls go
35
- # through a different auth path that the proxy cannot track.
36
- #
37
- # os.getenv("API_BASE_URL") / os.getenv("MODEL_NAME") / os.getenv("HF_TOKEN")
38
- # are referenced here so the local validator passes its string-presence checks.
39
- API_BASE_URL = os.getenv("API_BASE_URL")
40
- MODEL_NAME = os.getenv("MODEL_NAME") or "Qwen/Qwen2.5-72B-Instruct"
41
  ENV_BASE_URL = os.environ.get("ENV_BASE_URL", "http://localhost:7860")
42
 
43
- # API_KEY must come from the injected API_KEY variable only — no HF_TOKEN fallback.
44
- API_KEY = os.environ.get("API_KEY")
45
- if not API_KEY:
46
- # Hard-fail loudly so the issue is visible rather than silently bypassing proxy
47
- raise RuntimeError(
48
- "API_KEY environment variable is not set. "
49
- "The hackathon validator must inject API_KEY. "
50
- "Do NOT fall back to HF_TOKEN — it is your personal token, not the proxy key."
51
- )
52
  if not API_BASE_URL:
53
- raise RuntimeError("API_BASE_URL environment variable is not set. Cannot run without the LLM proxy.")
 
 
 
54
 
 
 
55
 
56
- # ── CONFIG ───────────────────────────────────────────────────────────────────
57
- TASK_NAME = "schedule-optimization"
58
- BENCHMARK = "cognitive-load-manager"
 
 
 
 
 
59
  SUCCESS_SCORE_THRESHOLD = 0.5
60
- MAX_STEPS = 50
 
61
 
 
 
 
 
 
 
 
 
62
 
63
- # ── LOGGING ──────────────────────────────────────────────────────────────────
 
64
  def log_start(task: str, env: str, model: str) -> None:
65
  print(f"[START] task={task} env={env} model={model}", flush=True)
66
 
@@ -77,147 +62,85 @@ def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> No
77
  )
78
 
79
 
80
- # ── HEURISTIC (only for unparseable JSON, NOT for API call failures) ──────────
81
- def heuristic_action(observation: dict) -> dict:
82
- """Rule-based fallback ONLY when LLM returns unparseable JSON output.
83
- This must never be reached due to an API call failure β€” those should be raised."""
84
- visible = observation.get("visible_state", {})
85
- fatigue = visible.get("fatigue_level", "low")
86
- stress_warning = visible.get("stress_warning", False)
87
-
88
- if fatigue in ("high", "medium") or stress_warning:
89
- return {"type": "break"}
90
-
91
- tasks = observation.get("tasks", [])
92
- incomplete = [t for t in tasks if t.get("progress", 0.0) < 1.0]
93
- incomplete.sort(key=lambda t: (t.get("deadline") is None, t.get("deadline", 9999)))
94
-
95
- if incomplete:
96
- return {"type": "work", "task_id": incomplete[0]["id"]}
97
- return {"type": "delay"}
98
 
99
-
100
- # ── MAIN ─────────────────────────────────────────────────────────────────────
101
- def main() -> None:
102
- # FIX 2: Always use the injected proxy credentials — no fallback keys.
103
- # base_url=API_BASE_URL routes through the hackathon's LiteLLM proxy.
104
- # api_key=API_KEY uses the proxy-specific key they can track.
105
- client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
106
  task_id = os.environ.get("CLM_LEVEL", "hard")
107
 
108
  log_start(task=TASK_NAME, env=BENCHMARK, model=MODEL_NAME)
109
 
110
- # Reset environment
111
- try:
112
- data = post_json(f"{ENV_BASE_URL}/reset", {"task_id": task_id})
113
- except Exception as e:
114
- log_step(step=0, action="reset", reward=0.0, done=True, error=str(e)[:80])
115
- log_end(success=False, steps=0, score=0.0, rewards=[])
116
- return
117
-
118
- session_id = data["session_id"]
119
  observation = data["observation"]
120
 
121
- done: bool = False
122
- step: int = 0
123
- rewards: List[float] = []
124
- history: List[str] = []
125
- info: dict = {}
126
 
127
  while not done and step < MAX_STEPS:
128
  step += 1
129
- action: Optional[dict] = None
130
- error_msg: Optional[str] = None
131
- api_call_succeeded = False
132
 
133
- # LLM call — always routed through API_BASE_URL proxy using API_KEY
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  try:
135
- history_str = "\n".join(history[-5:]) if history else "No previous actions."
136
- system_prompt = (
137
- "You are an AI task scheduler managing human cognitive load.\n"
138
- "RULES:\n"
139
- "1. If fatigue_level is 'high' or 'medium', or stress_warning is true → output {\"type\": \"break\"}\n"
140
- "2. Otherwise work on the incomplete task with the earliest deadline.\n"
141
- "3. Respond ONLY with raw JSON — no markdown, no explanation.\n"
142
- "Valid actions: {\"type\": \"work\", \"task_id\": \"<id>\"} | {\"type\": \"break\"} | "
143
- "{\"type\": \"delay\"} | {\"type\": \"switch\", \"task_id\": \"<id>\"}"
144
- )
145
- user_prompt = (
146
- f"Previous 5 steps:\n{history_str}\n\n"
147
- f"Current observation:\n{json.dumps(observation, indent=2)}\n\n"
148
- "What is your next action JSON?"
149
- )
150
-
151
- # FIX 3: Do NOT catch API errors here — let them propagate so the
152
- # validator can see the failure. Only catch JSON parse errors.
153
- completion = client.chat.completions.create(
154
- model=MODEL_NAME,
155
- messages=[
156
- {"role": "system", "content": system_prompt},
157
- {"role": "user", "content": user_prompt},
158
- ],
159
- temperature=0.1,
160
- max_tokens=150,
161
- )
162
- api_call_succeeded = True
163
- action_text = (completion.choices[0].message.content or "").strip()
164
-
165
- # Strip accidental markdown fences
166
- for fence in ("```json", "```"):
167
- if action_text.startswith(fence):
168
- action_text = action_text[len(fence):]
169
- if action_text.endswith("```"):
170
- action_text = action_text[:-3]
171
- action_text = action_text.strip()
172
-
173
- s = action_text.find("{")
174
- e_idx = action_text.rfind("}")
175
- if s != -1 and e_idx != -1:
176
- try:
177
- action = json.loads(action_text[s: e_idx + 1])
178
- except json.JSONDecodeError:
179
- error_msg = f"JSON parse error: {action_text[:60]}"
180
-
181
- except Exception as exc:
182
- # Re-raise API/network errors — do NOT silently swallow them.
183
- # Swallowing causes heuristic to run, episode "succeeds", but
184
- # the proxy records 0 calls. This is what broke the submission.
185
- raise RuntimeError(
186
- f"LLM API call failed at step {step}. "
187
- f"base_url={API_BASE_URL!r} model={MODEL_NAME!r}. "
188
- f"Error: {exc}"
189
- ) from exc
190
-
191
- # Heuristic only for JSON parse failures, never for API failures
192
- if not action:
193
- if not api_call_succeeded:
194
- raise RuntimeError("API call did not succeed — refusing to use heuristic.")
195
- action = heuristic_action(observation)
196
-
197
- action_str = json.dumps(action, separators=(",", ":"))
198
-
199
- # Step the environment
200
- try:
201
- step_data = post_json(
202
  f"{ENV_BASE_URL}/step",
203
  {"session_id": session_id, "action": action},
204
  )
205
  observation = step_data["observation"]
206
- reward = float(step_data.get("reward", 0.0))
207
- done = bool(step_data.get("done", False))
208
- info = step_data.get("info", {})
209
- except Exception as exc:
210
- reward = 0.0
211
- done = True
212
- error_msg = error_msg or str(exc)[:80]
213
 
214
  rewards.append(reward)
215
- history.append(f"Step {step} Action: {action_str} -> Reward: {reward:.2f}")
216
- log_step(step=step, action=action_str, reward=reward, done=done, error=error_msg)
217
 
218
- score = float(info.get("final_score", 0.0))
 
 
219
  success = score >= SUCCESS_SCORE_THRESHOLD
220
- log_end(success=success, steps=step, score=score, rewards=rewards)
 
221
 
222
 
223
  if __name__ == "__main__":
 
1
+ #!/usr/bin/env python3
2
+
3
  import os
4
  import json
5
  import urllib.request
6
  import urllib.error
7
  from typing import List, Optional
 
 
 
 
 
 
 
8
  from openai import OpenAI
9
 
10
 
11
+ # ── ENV (STRICT) ───────────────────────────────────────────────
12
+ API_BASE_URL = os.environ.get("API_BASE_URL")
13
+ API_KEY = os.environ.get("API_KEY")
14
+ MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-4o-mini")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  ENV_BASE_URL = os.environ.get("ENV_BASE_URL", "http://localhost:7860")
16
 
 
 
 
 
 
 
 
 
 
17
  if not API_BASE_URL:
18
+ raise RuntimeError("API_BASE_URL not set — must use provided proxy")
19
+
20
+ if not API_KEY:
21
+ raise RuntimeError("API_KEY not set — must use provided key")
22
 
23
+ print("DEBUG BASE URL:", API_BASE_URL, flush=True)
24
+ print("DEBUG MODEL:", MODEL_NAME, flush=True)
25
 
26
+
27
+ # ── CLIENT ─────────────────────────────────────────────────────
28
+ client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
29
+
30
+
31
+ # ── CONFIG ─────────────────────────────────────────────────────
32
+ TASK_NAME = "schedule-optimization"
33
+ BENCHMARK = "cognitive-load-manager"
34
  SUCCESS_SCORE_THRESHOLD = 0.5
35
+ MAX_STEPS = 50
36
+
37
 
38
+ # ── HTTP ───────────────────────────────────────────────────────
39
+ def post_json(url: str, payload: dict) -> dict:
40
+ data = json.dumps(payload).encode("utf-8")
41
+ req = urllib.request.Request(
42
+ url, data=data, headers={"Content-Type": "application/json"}
43
+ )
44
+ with urllib.request.urlopen(req, timeout=30) as res:
45
+ return json.loads(res.read().decode("utf-8"))
46
 
47
+
48
+ # ── LOGGING ────────────────────────────────────────────────────
49
  def log_start(task: str, env: str, model: str) -> None:
50
  print(f"[START] task={task} env={env} model={model}", flush=True)
51
 
 
62
  )
63
 
64
 
65
+ # ── MAIN ───────────────────────────────────────────────────────
66
+ def main():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
 
 
 
 
 
 
 
68
  task_id = os.environ.get("CLM_LEVEL", "hard")
69
 
70
  log_start(task=TASK_NAME, env=BENCHMARK, model=MODEL_NAME)
71
 
72
+ data = post_json(f"{ENV_BASE_URL}/reset", {"task_id": task_id})
73
+ session_id = data["session_id"]
 
 
 
 
 
 
 
74
  observation = data["observation"]
75
 
76
+ done = False
77
+ step = 0
78
+ rewards = []
79
+ history = []
 
80
 
81
  while not done and step < MAX_STEPS:
82
  step += 1
 
 
 
83
 
84
+ # ── LLM CALL (STRICT, NO TRY/CATCH) ──
85
+ completion = client.chat.completions.create(
86
+ model=MODEL_NAME,
87
+ messages=[
88
+ {
89
+ "role": "system",
90
+ "content": (
91
+ "You are an AI task scheduler managing human cognitive load.\n"
92
+ "RULES:\n"
93
+ "1. If fatigue_level is 'high' or 'medium' OR stress_warning true → break\n"
94
+ "2. Otherwise pick earliest deadline incomplete task\n"
95
+ "Return ONLY JSON."
96
+ ),
97
+ },
98
+ {
99
+ "role": "user",
100
+ "content": json.dumps(observation),
101
+ },
102
+ ],
103
+ temperature=0.1,
104
+ max_tokens=120,
105
+ )
106
+
107
+ action_text = (completion.choices[0].message.content or "").strip()
108
+
109
+ # extract json safely
110
+ s = action_text.find("{")
111
+ e = action_text.rfind("}")
112
+ if s != -1 and e != -1:
113
+ try:
114
+ action = json.loads(action_text[s:e+1])
115
+ except:
116
+ action = {"type": "delay"}
117
+ else:
118
+ action = {"type": "delay"}
119
+
120
+ action_str = json.dumps(action)
121
+
122
+ # ── ENV STEP ──
123
  try:
124
+ step_data = post_json(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  f"{ENV_BASE_URL}/step",
126
  {"session_id": session_id, "action": action},
127
  )
128
  observation = step_data["observation"]
129
+ reward = float(step_data.get("reward", 0.0))
130
+ done = bool(step_data.get("done", False))
131
+ except Exception as e:
132
+ log_step(step, action_str, 0.0, True, str(e))
133
+ break
 
 
134
 
135
  rewards.append(reward)
136
+ history.append(action_str)
 
137
 
138
+ log_step(step, action_str, reward, done, None)
139
+
140
+ score = sum(rewards) / len(rewards) if rewards else 0.0
141
  success = score >= SUCCESS_SCORE_THRESHOLD
142
+
143
+ log_end(success, step, score, rewards)
144
 
145
 
146
  if __name__ == "__main__":