soumi guria committed on
Commit
cb6df9a
·
unverified ·
2 Parent(s): 80637473c18657

Merge pull request #3 from soumiguria/soumi

Browse files
Files changed (1) hide show
  1. inference.py +394 -56
inference.py CHANGED
@@ -1,3 +1,344 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import json
3
  import urllib.request
@@ -19,6 +360,7 @@ except ImportError:
19
 
20
  from openai import OpenAI
21
 
 
22
  def post_json(url: str, payload: dict) -> dict:
23
  data = json.dumps(payload).encode("utf-8")
24
  req = urllib.request.Request(url, data=data, headers={"Content-Type": "application/json"})
@@ -28,36 +370,31 @@ def post_json(url: str, payload: dict) -> dict:
28
  except urllib.error.HTTPError as e:
29
  raise Exception(f"HTTP Error {e.code}: {e.read().decode('utf-8')}")
30
 
31
- # ── Environment variables ────────────────────────────────────────────────────
32
- # API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
33
- # HF_TOKEN = os.getenv("HF_TOKEN")
34
-
35
- # API_KEY = HF_TOKEN or os.getenv("API_KEY")
36
- # if not API_KEY:
37
- # raise ValueError("API_KEY environment variable is required")
38
 
 
39
  API_BASE_URL = os.environ.get("API_BASE_URL")
40
  API_KEY = os.environ.get("API_KEY")
41
  MODEL_NAME = os.environ.get("MODEL_NAME")
42
- ENV_BASE_URL = os.environ.get("ENV_BASE_URL", "http://localhost:7860")
43
 
44
  if not API_BASE_URL:
45
  raise ValueError("API_BASE_URL must be set")
46
-
47
  if not API_KEY:
48
  raise ValueError("API_KEY must be set")
 
 
49
 
50
- MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
51
- ENV_BASE_URL = os.getenv("ENV_BASE_URL", "http://localhost:7860")
52
 
53
  TASK_NAME = "schedule-optimization"
54
  BENCHMARK = "cognitive-load-manager"
55
  SUCCESS_SCORE_THRESHOLD = 0.5
56
  MAX_STEPS = 50
57
 
 
58
  def log_start(task: str, env: str, model: str) -> None:
59
  print(f"[START] task={task} env={env} model={model}", flush=True)
60
 
 
61
  def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
62
  error_val = error if error else "null"
63
  done_val = str(done).lower()
@@ -66,14 +403,14 @@ def log_step(step: int, action: str, reward: float, done: bool, error: Optional[
66
  flush=True,
67
  )
68
 
 
69
  def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
70
  rewards_str = ",".join(f"{r:.2f}" for r in rewards)
71
  print(f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_str}", flush=True)
72
 
 
73
  def main():
74
- # Always initialise the OpenAI client using the proxy URL and API key.
75
- # The hackathon validator requires ALL LLM calls to go through API_BASE_URL
76
- # with the provided API_KEY — never bypass this with hardcoded credentials.
77
  client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
78
 
79
  task_id = os.getenv("CLM_LEVEL", "hard")
@@ -97,64 +434,64 @@ def main():
97
  history = []
98
  info = {}
99
 
100
- while not done and step < MAX_STEPS:
101
- step += 1
102
-
103
- # 2. Get next action from LLM via the hackathon proxy
104
- history_str = "\n".join(history[-5:]) if history else "No previous actions."
105
- system_prompt = """
106
- You are an AI task scheduler managing cognitive load.
107
  CRITICAL RULES:
108
  1. If "fatigue_level" is "high" or "medium", output {"type": "break"}. Do NOT work until fatigue is "low".
109
  2. If "stress_warning" is true, {"type": "break"} reduces stress safely.
110
  3. Find tasks where "progress" < 1.0. Output {"type": "work", "task_id": "<id>"}. Do NOT work on 1.0 tasks.
111
- 4. Respond ONLY with raw JSON format. No markdown blocks.
112
- Valid actions: {"type": "work", "task_id": "id"}, {"type": "break"}, {"type": "delay"}, {"type": "switch", "task_id": "id"}
113
- """
114
- user_prompt = f"""
115
- Previous 5 Steps History:
116
- {history_str}
117
-
118
- Current Observation:
119
- {json.dumps(observation, indent=2)}
120
-
121
- What is your next action JSON?
122
- """
123
  action = None
124
  error_msg = None
125
 
 
 
 
 
126
  try:
127
- completion = client.chat.completions.create(
128
  model=MODEL_NAME,
129
- messages=[
130
- {"role": "system", "content": system_prompt.strip()},
131
- {"role": "user", "content": user_prompt.strip()}
132
- ],
133
  temperature=0.1,
134
- max_tokens=150
135
  )
136
- action_text = (completion.choices[0].message.content or "").strip()
137
-
138
- # Strip accidental markdown code fences
139
- if action_text.startswith("```json"):
140
- action_text = action_text[7:]
141
- if action_text.startswith("```"):
142
- action_text = action_text[3:]
143
- if action_text.endswith("```"):
144
- action_text = action_text[:-3]
145
-
146
- start_idx = action_text.find("{")
147
- end_idx = action_text.rfind("}")
 
 
 
 
 
 
 
 
 
148
  if start_idx != -1 and end_idx != -1:
149
- action = json.loads(action_text[start_idx:end_idx + 1])
 
150
  except Exception as e:
151
  error_msg = str(e)[:50]
152
 
153
- # Fallback heuristic only if LLM call failed / returned unparseable output
154
  if not action:
155
  tasks = observation.get("tasks", [])
156
  incomp = [t for t in tasks if t.get("progress", 0.0) < 1.0]
157
- if observation.get("visible_state", {}).get("fatigue_level") in ("high", "medium"):
 
158
  action = {"type": "break"}
159
  elif incomp:
160
  action = {"type": "work", "task_id": incomp[0]["id"]}
@@ -167,7 +504,7 @@ What is your next action JSON?
167
  try:
168
  step_data = post_json(f"{ENV_BASE_URL}/step", {
169
  "session_id": session_id,
170
- "action": action
171
  })
172
  observation = step_data["observation"]
173
  reward = step_data.get("reward", 0.0)
@@ -186,5 +523,6 @@ What is your next action JSON?
186
  success = score >= SUCCESS_SCORE_THRESHOLD
187
  log_end(success=success, steps=step, score=score, rewards=rewards)
188
 
 
189
  if __name__ == "__main__":
190
- main()
 
1
+ # # import os
2
+ # # import json
3
+ # # import urllib.request
4
+ # # import urllib.error
5
+ # # from typing import List, Optional
6
+
7
+ # # try:
8
+ # # from dotenv import load_dotenv
9
+ # # load_dotenv()
10
+ # # except ImportError:
11
+ # # pass
12
+
13
+ # # # /// script
14
+ # # # requires-python = ">=3.11"
15
+ # # # dependencies = [
16
+ # # # "openai",
17
+ # # # ]
18
+ # # # ///
19
+
20
+ # # from openai import OpenAI
21
+
22
+ # # def post_json(url: str, payload: dict) -> dict:
23
+ # # data = json.dumps(payload).encode("utf-8")
24
+ # # req = urllib.request.Request(url, data=data, headers={"Content-Type": "application/json"})
25
+ # # try:
26
+ # # with urllib.request.urlopen(req) as res:
27
+ # # return json.loads(res.read().decode("utf-8"))
28
+ # # except urllib.error.HTTPError as e:
29
+ # # raise Exception(f"HTTP Error {e.code}: {e.read().decode('utf-8')}")
30
+
31
+ # # # ── Environment variables ────────────────────────────────────────────────────
32
+ # # # API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
33
+ # # # HF_TOKEN = os.getenv("HF_TOKEN")
34
+
35
+ # # # API_KEY = HF_TOKEN or os.getenv("API_KEY")
36
+ # # # if not API_KEY:
37
+ # # # raise ValueError("API_KEY environment variable is required")
38
+
39
+ # # API_BASE_URL = os.environ.get("API_BASE_URL")
40
+ # # API_KEY = os.environ.get("API_KEY")
41
+ # # MODEL_NAME = os.environ.get("MODEL_NAME")
42
+ # # ENV_BASE_URL = os.environ.get("ENV_BASE_URL", "http://localhost:7860")
43
+
44
+ # # if not API_BASE_URL:
45
+ # # raise ValueError("API_BASE_URL must be set")
46
+
47
+ # # if not API_KEY:
48
+ # # raise ValueError("API_KEY must be set")
49
+
50
+ # # MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
51
+ # # ENV_BASE_URL = os.getenv("ENV_BASE_URL", "http://localhost:7860")
52
+
53
+ # # TASK_NAME = "schedule-optimization"
54
+ # # BENCHMARK = "cognitive-load-manager"
55
+ # # SUCCESS_SCORE_THRESHOLD = 0.5
56
+ # # MAX_STEPS = 50
57
+
58
+ # # def log_start(task: str, env: str, model: str) -> None:
59
+ # # print(f"[START] task={task} env={env} model={model}", flush=True)
60
+
61
+ # # def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
62
+ # # error_val = error if error else "null"
63
+ # # done_val = str(done).lower()
64
+ # # print(
65
+ # # f"[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}",
66
+ # # flush=True,
67
+ # # )
68
+
69
+ # # def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
70
+ # # rewards_str = ",".join(f"{r:.2f}" for r in rewards)
71
+ # # print(f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_str}", flush=True)
72
+
73
+ # # def main():
74
+ # # # Always initialise the OpenAI client using the proxy URL and API key.
75
+ # # # The hackathon validator requires ALL LLM calls to go through API_BASE_URL
76
+ # # # with the provided API_KEY — never bypass this with hardcoded credentials.
77
+ # # client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
78
+
79
+ # # task_id = os.getenv("CLM_LEVEL", "hard")
80
+
81
+ # # log_start(task=TASK_NAME, env=BENCHMARK, model=MODEL_NAME)
82
+
83
+ # # # 1. Reset Environment
84
+ # # try:
85
+ # # data = post_json(f"{ENV_BASE_URL}/reset", {"task_id": task_id})
86
+ # # except Exception as e:
87
+ # # log_step(step=0, action="reset", reward=0.0, done=True, error=str(e)[:50])
88
+ # # log_end(success=False, steps=0, score=0.0, rewards=[])
89
+ # # return
90
+
91
+ # # session_id = data["session_id"]
92
+ # # observation = data["observation"]
93
+
94
+ # # done = False
95
+ # # step = 0
96
+ # # rewards = []
97
+ # # history = []
98
+ # # info = {}
99
+
100
+ # # while not done and step < MAX_STEPS:
101
+ # # step += 1
102
+
103
+ # # # 2. Get next action from LLM via the hackathon proxy
104
+ # # history_str = "\n".join(history[-5:]) if history else "No previous actions."
105
+ # # system_prompt = """
106
+ # # You are an AI task scheduler managing cognitive load.
107
+ # # CRITICAL RULES:
108
+ # # 1. If "fatigue_level" is "high" or "medium", output {"type": "break"}. Do NOT work until fatigue is "low".
109
+ # # 2. If "stress_warning" is true, {"type": "break"} reduces stress safely.
110
+ # # 3. Find tasks where "progress" < 1.0. Output {"type": "work", "task_id": "<id>"}. Do NOT work on 1.0 tasks.
111
+ # # 4. Respond ONLY with raw JSON format. No markdown blocks.
112
+ # # Valid actions: {"type": "work", "task_id": "id"}, {"type": "break"}, {"type": "delay"}, {"type": "switch", "task_id": "id"}
113
+ # # """
114
+ # # user_prompt = f"""
115
+ # # Previous 5 Steps History:
116
+ # # {history_str}
117
+
118
+ # # Current Observation:
119
+ # # {json.dumps(observation, indent=2)}
120
+
121
+ # # What is your next action JSON?
122
+ # # """
123
+ # # action = None
124
+ # # error_msg = None
125
+
126
+ # # try:
127
+ # # completion = client.chat.completions.create(
128
+ # # model=MODEL_NAME,
129
+ # # messages=[
130
+ # # {"role": "system", "content": system_prompt.strip()},
131
+ # # {"role": "user", "content": user_prompt.strip()}
132
+ # # ],
133
+ # # temperature=0.1,
134
+ # # max_tokens=150
135
+ # # )
136
+ # # action_text = (completion.choices[0].message.content or "").strip()
137
+
138
+ # # # Strip accidental markdown code fences
139
+ # # if action_text.startswith("```json"):
140
+ # # action_text = action_text[7:]
141
+ # # if action_text.startswith("```"):
142
+ # # action_text = action_text[3:]
143
+ # # if action_text.endswith("```"):
144
+ # # action_text = action_text[:-3]
145
+
146
+ # # start_idx = action_text.find("{")
147
+ # # end_idx = action_text.rfind("}")
148
+ # # if start_idx != -1 and end_idx != -1:
149
+ # # action = json.loads(action_text[start_idx:end_idx + 1])
150
+ # # except Exception as e:
151
+ # # error_msg = str(e)[:50]
152
+
153
+ # # # Fallback heuristic only if LLM call failed / returned unparseable output
154
+ # # if not action:
155
+ # # tasks = observation.get("tasks", [])
156
+ # # incomp = [t for t in tasks if t.get("progress", 0.0) < 1.0]
157
+ # # if observation.get("visible_state", {}).get("fatigue_level") in ("high", "medium"):
158
+ # # action = {"type": "break"}
159
+ # # elif incomp:
160
+ # # action = {"type": "work", "task_id": incomp[0]["id"]}
161
+ # # else:
162
+ # # action = {"type": "delay"}
163
+
164
+ # # action_str = json.dumps(action).replace(" ", "")
165
+
166
+ # # # 3. Step the environment
167
+ # # try:
168
+ # # step_data = post_json(f"{ENV_BASE_URL}/step", {
169
+ # # "session_id": session_id,
170
+ # # "action": action
171
+ # # })
172
+ # # observation = step_data["observation"]
173
+ # # reward = step_data.get("reward", 0.0)
174
+ # # done = step_data.get("done", False)
175
+ # # info = step_data.get("info", {})
176
+ # # except Exception as e:
177
+ # # reward = 0.0
178
+ # # done = True
179
+ # # error_msg = error_msg or str(e)[:50]
180
+
181
+ # # rewards.append(reward)
182
+ # # history.append(f"Step {step} Action: {action_str} -> Reward: {reward}")
183
+ # # log_step(step=step, action=action_str, reward=reward, done=done, error=error_msg)
184
+
185
+ # # score = info.get("final_score", 0.0)
186
+ # # success = score >= SUCCESS_SCORE_THRESHOLD
187
+ # # log_end(success=success, steps=step, score=score, rewards=rewards)
188
+
189
+ # # if __name__ == "__main__":
190
+ # # main()
191
+
192
+
193
+
194
+ # import os
195
+ # import json
196
+ # import urllib.request
197
+ # import urllib.error
198
+ # from typing import List, Optional
199
+
200
+ # from openai import OpenAI
201
+
202
+
203
+ # # ── HTTP Helper ──────────────────────────────────────────────────────────────
204
+ # def post_json(url: str, payload: dict) -> dict:
205
+ # data = json.dumps(payload).encode("utf-8")
206
+ # req = urllib.request.Request(url, data=data, headers={"Content-Type": "application/json"})
207
+ # with urllib.request.urlopen(req) as res:
208
+ # return json.loads(res.read().decode("utf-8"))
209
+
210
+
211
+ # # ── STRICT ENV (NO FALLBACKS) ────────────────────────────────────────────────
212
+ # API_BASE_URL = os.environ.get("API_BASE_URL")
213
+ # API_KEY = os.environ.get("API_KEY")
214
+ # MODEL_NAME = os.environ.get("MODEL_NAME")
215
+
216
+ # if not API_BASE_URL:
217
+ # raise ValueError("API_BASE_URL must be set")
218
+ # if not API_KEY:
219
+ # raise ValueError("API_KEY must be set")
220
+ # if not MODEL_NAME:
221
+ # raise ValueError("MODEL_NAME must be set")
222
+
223
+ # ENV_BASE_URL = os.environ.get("ENV_BASE_URL", "http://localhost:7860")
224
+
225
+
226
+ # # ── CONFIG ───────────────────────────────────────────────────────────────────
227
+ # TASK_NAME = "schedule-optimization"
228
+ # BENCHMARK = "cognitive-load-manager"
229
+ # SUCCESS_SCORE_THRESHOLD = 0.5
230
+ # MAX_STEPS = 50
231
+
232
+
233
+ # # ── LOGGING ──────────────────────────────────────────────────────────────────
234
+ # def log_start(task: str, env: str, model: str):
235
+ # print(f"[START] task={task} env={env} model={model}", flush=True)
236
+
237
+
238
+ # def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]):
239
+ # error_val = error if error else "null"
240
+ # print(f"[STEP] step={step} action={action} reward={reward:.2f} done={str(done).lower()} error={error_val}", flush=True)
241
+
242
+
243
+ # def log_end(success: bool, steps: int, score: float, rewards: List[float]):
244
+ # rewards_str = ",".join(f"{r:.2f}" for r in rewards)
245
+ # print(f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_str}", flush=True)
246
+
247
+
248
+ # # ── MAIN ─────────────────────────────────────────────────────────────────────
249
+ # def main():
250
+ # client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
251
+
252
+ # log_start(TASK_NAME, BENCHMARK, MODEL_NAME)
253
+
254
+ # # RESET
255
+ # try:
256
+ # data = post_json(f"{ENV_BASE_URL}/reset", {"task_id": "hard"})
257
+ # except Exception as e:
258
+ # log_step(0, "reset", 0.0, True, str(e)[:50])
259
+ # log_end(False, 0, 0.0, [])
260
+ # return
261
+
262
+ # session_id = data["session_id"]
263
+ # observation = data["observation"]
264
+
265
+ # rewards = []
266
+ # done = False
267
+ # step = 0
268
+ # info = {}
269
+
270
+ # while not done and step < MAX_STEPS:
271
+ # step += 1
272
+
273
+ # action = None
274
+ # error_msg = None
275
+
276
+ # # 🔥 FORCE LLM CALL (NO SKIP)
277
+ # try:
278
+ # response = client.responses.create(
279
+ # model=MODEL_NAME,
280
+ # input=f"Return ONLY JSON action for this observation:\n{json.dumps(observation)}",
281
+ # max_output_tokens=100,
282
+ # temperature=0.1
283
+ # )
284
+
285
+ # # Extract text safely
286
+ # text = ""
287
+ # if response.output:
288
+ # for item in response.output:
289
+ # for part in item.content:
290
+ # if hasattr(part, "text"):
291
+ # text += part.text
292
+
293
+ # text = text.strip()
294
+
295
+ # start = text.find("{")
296
+ # end = text.rfind("}")
297
+ # if start != -1 and end != -1:
298
+ # action = json.loads(text[start:end+1])
299
+
300
+ # except Exception as e:
301
+ # error_msg = str(e)[:50]
302
+
303
+ # # fallback AFTER LLM attempt
304
+ # if not action:
305
+ # tasks = observation.get("tasks", [])
306
+ # if tasks:
307
+ # action = {"type": "work", "task_id": tasks[0]["id"]}
308
+ # else:
309
+ # action = {"type": "break"}
310
+
311
+ # action_str = json.dumps(action).replace(" ", "")
312
+
313
+ # # STEP ENV
314
+ # try:
315
+ # step_data = post_json(
316
+ # f"{ENV_BASE_URL}/step",
317
+ # {"session_id": session_id, "action": action}
318
+ # )
319
+ # observation = step_data["observation"]
320
+ # reward = step_data.get("reward", 0.0)
321
+ # done = step_data.get("done", False)
322
+ # info = step_data.get("info", {})
323
+ # except Exception as e:
324
+ # reward = 0.0
325
+ # done = True
326
+ # error_msg = error_msg or str(e)[:50]
327
+
328
+ # rewards.append(reward)
329
+
330
+ # log_step(step, action_str, reward, done, error_msg)
331
+
332
+ # score = info.get("final_score", 0.0)
333
+ # success = score >= SUCCESS_SCORE_THRESHOLD
334
+
335
+ # log_end(success, step, score, rewards)
336
+
337
+
338
+ # if __name__ == "__main__":
339
+ # main()
340
+
341
+
342
  import os
343
  import json
344
  import urllib.request
 
360
 
361
  from openai import OpenAI
362
 
363
+
364
  def post_json(url: str, payload: dict) -> dict:
365
  data = json.dumps(payload).encode("utf-8")
366
  req = urllib.request.Request(url, data=data, headers={"Content-Type": "application/json"})
 
370
  except urllib.error.HTTPError as e:
371
  raise Exception(f"HTTP Error {e.code}: {e.read().decode('utf-8')}")
372
 
 
 
 
 
 
 
 
373
 
374
+ # ── STRICT ENV (NO FALLBACKS) ────────────────────────────────────────────────
375
  API_BASE_URL = os.environ.get("API_BASE_URL")
376
  API_KEY = os.environ.get("API_KEY")
377
  MODEL_NAME = os.environ.get("MODEL_NAME")
 
378
 
379
  if not API_BASE_URL:
380
  raise ValueError("API_BASE_URL must be set")
 
381
  if not API_KEY:
382
  raise ValueError("API_KEY must be set")
383
+ if not MODEL_NAME:
384
+ raise ValueError("MODEL_NAME must be set")
385
 
386
+ ENV_BASE_URL = os.environ.get("ENV_BASE_URL", "http://localhost:7860")
 
387
 
388
  TASK_NAME = "schedule-optimization"
389
  BENCHMARK = "cognitive-load-manager"
390
  SUCCESS_SCORE_THRESHOLD = 0.5
391
  MAX_STEPS = 50
392
 
393
+
394
  def log_start(task: str, env: str, model: str) -> None:
395
  print(f"[START] task={task} env={env} model={model}", flush=True)
396
 
397
+
398
  def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
399
  error_val = error if error else "null"
400
  done_val = str(done).lower()
 
403
  flush=True,
404
  )
405
 
406
+
407
  def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
408
  rewards_str = ",".join(f"{r:.2f}" for r in rewards)
409
  print(f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_str}", flush=True)
410
 
411
+
412
  def main():
413
+ # ALWAYS use API_BASE_URL + API_KEY from environment — never bypass the proxy.
 
 
414
  client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
415
 
416
  task_id = os.getenv("CLM_LEVEL", "hard")
 
434
  history = []
435
  info = {}
436
 
437
+ system_prompt = """You are an AI task scheduler managing cognitive load.
 
 
 
 
 
 
438
  CRITICAL RULES:
439
  1. If "fatigue_level" is "high" or "medium", output {"type": "break"}. Do NOT work until fatigue is "low".
440
  2. If "stress_warning" is true, {"type": "break"} reduces stress safely.
441
  3. Find tasks where "progress" < 1.0. Output {"type": "work", "task_id": "<id>"}. Do NOT work on 1.0 tasks.
442
+ 4. Respond ONLY with raw JSON. No markdown, no explanation.
443
+ Valid actions: {"type": "work", "task_id": "id"}, {"type": "break"}, {"type": "delay"}, {"type": "switch", "task_id": "id"}"""
444
+
445
+ while not done and step < MAX_STEPS:
446
+ step += 1
447
+
 
 
 
 
 
 
448
  action = None
449
  error_msg = None
450
 
451
+ # 2. 🔥 FORCE LLM CALL via proxy — uses client.responses.create (required by validator)
452
+ history_str = "\n".join(history[-5:]) if history else "No previous actions."
453
+ user_prompt = f"{system_prompt}\n\nPrevious 5 Steps:\n{history_str}\n\nCurrent Observation:\n{json.dumps(observation)}\n\nReturn ONLY a JSON action:"
454
+
455
  try:
456
+ response = client.responses.create(
457
  model=MODEL_NAME,
458
+ input=user_prompt,
459
+ max_output_tokens=100,
 
 
460
  temperature=0.1,
 
461
  )
462
+
463
+ # Extract text from response safely
464
+ text = ""
465
+ if response.output:
466
+ for item in response.output:
467
+ for part in item.content:
468
+ if hasattr(part, "text"):
469
+ text += part.text
470
+
471
+ text = text.strip()
472
+
473
+ # Strip markdown fences if present
474
+ if text.startswith("```json"):
475
+ text = text[7:]
476
+ if text.startswith("```"):
477
+ text = text[3:]
478
+ if text.endswith("```"):
479
+ text = text[:-3]
480
+
481
+ start_idx = text.find("{")
482
+ end_idx = text.rfind("}")
483
  if start_idx != -1 and end_idx != -1:
484
+ action = json.loads(text[start_idx:end_idx + 1])
485
+
486
  except Exception as e:
487
  error_msg = str(e)[:50]
488
 
489
+ # Fallback heuristic ONLY if LLM call failed / returned unparseable output
490
  if not action:
491
  tasks = observation.get("tasks", [])
492
  incomp = [t for t in tasks if t.get("progress", 0.0) < 1.0]
493
+ fatigue = observation.get("visible_state", {}).get("fatigue_level")
494
+ if fatigue in ("high", "medium"):
495
  action = {"type": "break"}
496
  elif incomp:
497
  action = {"type": "work", "task_id": incomp[0]["id"]}
 
504
  try:
505
  step_data = post_json(f"{ENV_BASE_URL}/step", {
506
  "session_id": session_id,
507
+ "action": action,
508
  })
509
  observation = step_data["observation"]
510
  reward = step_data.get("reward", 0.0)
 
523
  success = score >= SUCCESS_SCORE_THRESHOLD
524
  log_end(success=success, steps=step, score=score, rewards=rewards)
525
 
526
+
527
  if __name__ == "__main__":
528
+ main()