soumi guria committed on
Commit
f07ddc1
·
1 Parent(s): b8dbf99

moved files to correct location and changed the prompt

Browse files
Files changed (4) hide show
  1. README.md +0 -0
  2. baseline/inference.py +0 -141
  3. baseline/requirements.txt +0 -3
  4. inference.py +155 -0
README.md CHANGED
Binary files a/README.md and b/README.md differ
 
baseline/inference.py DELETED
@@ -1,141 +0,0 @@
1
- import os
2
- import requests
3
- import json
4
- from dotenv import load_dotenv
5
-
6
- load_dotenv()
7
-
8
- API_BASE_URL = os.getenv("API_BASE_URL", "http://localhost:8000")
9
- HF_ROUTER_URL = os.getenv(
10
- "HF_ROUTER_URL",
11
- "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3-70B-Instruct"
12
- )
13
- HF_TOKEN = os.getenv("HF_TOKEN")
14
-
15
- def call_hf_router(prompt: str) -> dict:
16
- if not HF_TOKEN:
17
- return None
18
-
19
- headers = {
20
- "Authorization": f"Bearer {HF_TOKEN}",
21
- "Content-Type": "application/json"
22
- }
23
-
24
- payload = {
25
- "inputs": prompt,
26
- "parameters": {
27
- "max_new_tokens": 150,
28
- "temperature": 0.1,
29
- "return_full_text": False
30
- }
31
- }
32
-
33
- try:
34
- response = requests.post(HF_ROUTER_URL, headers=headers, json=payload)
35
- if response.status_code == 200:
36
- result = response.json()
37
- if isinstance(result, list) and len(result) > 0:
38
- text = result[0].get("generated_text", "")
39
-
40
- # Extract JSON block
41
- start_idx = text.find("{")
42
- end_idx = text.rfind("}")
43
- if start_idx != -1 and end_idx != -1:
44
- json_str = text[start_idx:end_idx+1]
45
- return json.loads(json_str)
46
- return None
47
- except Exception as e:
48
- print(f"Error calling HF Router: {e}")
49
- return None
50
-
51
- def run_level(level: str):
52
- print(f"\n{'='*40}")
53
- print(f"--- Running Level: {level.upper()} ---")
54
- print(f"{'='*40}")
55
-
56
- # 1. Reset Environment
57
- res = requests.post(f"{API_BASE_URL}/reset", json={"level": level})
58
- if res.status_code != 200:
59
- print(f"Failed to reset: {res.text}")
60
- return
61
-
62
- data = res.json()
63
- session_id = data["session_id"]
64
- observation = data["observation"]
65
-
66
- done = False
67
- step = 0
68
- total_reward = 0.0
69
- info = {}
70
-
71
- while not done:
72
- step += 1
73
- print(f"\nStep {step}")
74
-
75
- # 2. Call LLM for next action
76
- prompt = f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
77
- You are an AI agent managing tasks with deadlines under cognitive load.
78
- Your goals: Complete all tasks efficiently, avoiding burnout and minimizing stress.
79
- Respond ONLY with a valid JSON object representing your chosen action, with no extra text surrounding it.
80
- <|eot_id|><|start_header_id|>user<|end_header_id|>
81
- Current Observation:
82
- {json.dumps(observation, indent=2)}
83
-
84
- Available Actions:
85
- - {{"type": "work", "task_id": "<id>"}} - work on a specific task
86
- - {{"type": "break"}} - increases energy, decreases stress
87
- - {{"type": "switch", "task_id": "<id>"}} - switch focus without working
88
- - {{"type": "delay"}} - delays actions slightly reducing stress
89
- <|eot_id|><|start_header_id|>assistant<|end_header_id|>"""
90
-
91
- action = call_hf_router(prompt)
92
-
93
- # Fallback heuristic logic if HF router fails or no token
94
- if not action:
95
- tasks = observation.get("tasks", [])
96
- incomp = [t for t in tasks if t.get("progress", 0.0) < 1.0]
97
- if observation.get("visible_state", {}).get("fatigue_level") == "high":
98
- action = {"type": "break"}
99
- elif incomp:
100
- action = {"type": "work", "task_id": incomp[0]["id"]}
101
- else:
102
- action = {"type": "delay"}
103
-
104
- print(f"Agent Action: {action}")
105
-
106
- # 3. Step Environment
107
- res = requests.post(f"{API_BASE_URL}/step", json={
108
- "session_id": session_id,
109
- "action": action
110
- })
111
-
112
- if res.status_code != 200:
113
- print(f"Failed to step: {res.text}")
114
- break
115
-
116
- step_data = res.json()
117
- observation = step_data["observation"]
118
- reward = step_data["reward"]
119
- done = step_data["done"]
120
- info = step_data["info"]
121
-
122
- total_reward += reward
123
- print(f"Reward: {reward:.2f}")
124
-
125
- print("\n--- Episode Finished ---")
126
- print(f"Total Reward: {total_reward:.2f}")
127
- if "final_score" in info:
128
- print(f"Final Score (Grader): {info['final_score']:.2f}")
129
-
130
- # Get final state
131
- state_res = requests.get(f"{API_BASE_URL}/state", params={"session_id": session_id})
132
- if state_res.status_code == 200:
133
- st = state_res.json()
134
- print(f"Final Energy: {st.get('energy', 0):.2f}, Final Stress: {st.get('stress', 0):.2f}")
135
-
136
- if __name__ == "__main__":
137
- if not HF_TOKEN:
138
- print("Warning: HF_TOKEN not set. Using fallback heuristic agent.")
139
-
140
- for level in ["easy", "medium", "hard"]:
141
- run_level(level)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
baseline/requirements.txt DELETED
@@ -1,3 +0,0 @@
1
- openai
2
- python-dotenv
3
- requests
 
 
 
 
inference.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import requests
4
+ from typing import List, Optional
5
+ from dotenv import load_dotenv
6
+ from openai import OpenAI
7
+
8
+ load_dotenv()
9
+
10
+ API_BASE_URL = os.getenv("API_BASE_URL", "http://localhost:8000")
11
+ MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
12
+ HF_TOKEN = os.getenv("HF_TOKEN")
13
+
14
+ TASK_NAME = "schedule-optimization"
15
+ BENCHMARK = "cognitive-load-manager"
16
+ SUCCESS_SCORE_THRESHOLD = 0.5 # Need 50% score basically
17
+ MAX_STEPS = 50
18
+
19
def log_start(task: str, env: str, model: str) -> None:
    """Print the ``[START]`` record that opens an episode log.

    Args:
        task: Value written after ``task=``.
        env: Value written after ``env=``.
        model: Value written after ``model=``.
    """
    fields = " ".join((f"task={task}", f"env={env}", f"model={model}"))
    # Flushed on every call so the record is not held in the stdout buffer.
    print(f"[START] {fields}", flush=True)
21
+
22
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Print one ``[STEP]`` record describing a single environment step.

    Args:
        step: Step counter written after ``step=``.
        action: Already-serialized action text written after ``action=``.
        reward: Reward for the step, rendered with two decimals.
        done: Whether the episode ended; rendered as lowercase text.
        error: Error description, or ``None``/empty for no error, in which
            case the literal ``null`` is written.
    """
    # An error message with embedded line breaks would split the record over
    # several lines; collapse them so each step stays on exactly one line.
    error_val = " ".join(error.splitlines()) if error else "null"
    done_val = str(done).lower()
    print(
        f"[STEP] step={step} action={action} reward={reward:.2f} "
        f"done={done_val} error={error_val}",
        flush=True,
    )
29
+
30
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the ``[END]`` record that closes an episode log.

    Args:
        success: Whether the episode counts as successful; rendered lowercase.
        steps: Number of steps taken.
        score: Final score, rendered with three decimals.
        rewards: Per-step rewards, rendered with two decimals each and joined
            with commas (an empty list yields an empty ``rewards=`` field).
    """
    formatted_rewards = [format(value, ".2f") for value in rewards]
    success_text = str(success).lower()
    score_text = format(score, ".3f")
    joined = ",".join(formatted_rewards)
    print(
        f"[END] success={success_text} steps={steps} score={score_text} rewards={joined}",
        flush=True,
    )
33
+
34
# Upper bound, in seconds, for each HTTP call to the environment server so a
# stalled server cannot hang the run indefinitely.
_HTTP_TIMEOUT_SECONDS = 30.0


def _build_prompt(observation: dict, history: List[str]) -> str:
    """Return the user prompt for the next action, including recent history."""
    history_str = "\n".join(history[-5:]) if history else "No previous actions."
    return (
        "\n"
        "You are an AI agent managing tasks with deadlines under cognitive load.\n"
        "Your goals: Complete all tasks efficiently, avoiding burnout and minimizing stress.\n"
        "\n"
        "CRITICAL RULES:\n"
        '1. If your fatigue_level is "high" or energy drops too low, you MUST prioritize '
        '{"type": "break"} otherwise you will hit Burnout and fail!\n'
        "2. Do not work on a task if its progress is 1.0 (completed). Keep track of task statuses!\n"
        "\n"
        "Previous 5 Steps History:\n"
        f"{history_str}\n"
        "\n"
        "Current Observation:\n"
        f"{json.dumps(observation, indent=2)}\n"
        "\n"
        "Respond ONLY with a valid JSON object representing your next action:\n"
        '{"type": "work", "task_id": "id"} or {"type": "break"} or {"type": "delay"} '
        'or {"type": "switch", "task_id": "id"}\n'
    )


def _parse_action(text: str) -> Optional[dict]:
    """Extract the action object from a model reply, or ``None`` if absent.

    The span from the first ``{`` to the last ``}`` is parsed, which also
    skips any Markdown code fence around the JSON. A result that is not an
    object with a string ``"type"`` is rejected so the caller falls back to
    the heuristic instead of sending it to the environment.

    Raises:
        ValueError: If the brace-delimited span is not valid JSON
            (``json.JSONDecodeError`` is a ``ValueError`` subclass).
    """
    start_idx = text.find("{")
    end_idx = text.rfind("}")
    if start_idx == -1 or end_idx < start_idx:
        return None
    parsed = json.loads(text[start_idx:end_idx + 1])
    if isinstance(parsed, dict) and isinstance(parsed.get("type"), str):
        return parsed
    return None


def _fallback_action(observation: dict) -> dict:
    """Pick a heuristic action: break when fatigued, else first unfinished task."""
    if observation.get("visible_state", {}).get("fatigue_level") == "high":
        return {"type": "break"}
    for task in observation.get("tasks", []):
        if task.get("progress", 0.0) < 1.0:
            return {"type": "work", "task_id": task["id"]}
    return {"type": "delay"}


def main() -> None:
    """Run one episode against the environment server and log it to stdout.

    Resets the environment at the level named by ``CLM_LEVEL`` (default
    ``"hard"``), then loops up to ``MAX_STEPS`` times: ask the model for an
    action when ``HF_TOKEN`` is set, fall back to a heuristic otherwise or
    when no usable action was parsed, and post the action to ``/step``.
    Emits one ``[START]`` record, one ``[STEP]`` record per step, and one
    ``[END]`` record, including when the reset or a step fails.
    """
    # OpenAI-compatible client pointed at the Hugging Face router; without a
    # token the loop relies on the heuristic alone.
    client = None
    if HF_TOKEN:
        client = OpenAI(base_url="https://router.huggingface.co/v1", api_key=HF_TOKEN)

    level = os.getenv("CLM_LEVEL", "hard")

    log_start(task=TASK_NAME, env=BENCHMARK, model=MODEL_NAME)

    # 1. Reset the environment. Reading the response keys inside the try
    # means a malformed payload is reported through the log records instead
    # of crashing before [END] is written.
    try:
        res = requests.post(
            f"{API_BASE_URL}/reset",
            json={"level": level},
            timeout=_HTTP_TIMEOUT_SECONDS,
        )
        res.raise_for_status()
        data = res.json()
        session_id = data["session_id"]
        observation = data["observation"]
    except Exception as e:
        log_step(step=0, action="reset", reward=0.0, done=True, error=str(e)[:50])
        log_end(success=False, steps=0, score=0.0, rewards=[])
        return

    done = False
    step = 0
    rewards: List[float] = []
    history: List[str] = []
    info: dict = {}

    while not done and step < MAX_STEPS:
        step += 1

        # 2. Ask the model for the next action.
        action = None
        error_msg = None

        if client:
            try:
                completion = client.chat.completions.create(
                    model=MODEL_NAME,
                    messages=[
                        {"role": "user", "content": _build_prompt(observation, history)}
                    ],
                    temperature=0.1,
                    max_tokens=150,
                )
                action_text = (completion.choices[0].message.content or "").strip()
                action = _parse_action(action_text)
            except Exception as e:
                # Any model or parsing failure is logged on this step and the
                # heuristic takes over.
                error_msg = str(e)[:50]

        if not action:
            action = _fallback_action(observation)

        # Compact separators drop only structural whitespace; spaces inside
        # string values (e.g. task ids) are kept intact in the log.
        action_str = json.dumps(action, separators=(",", ":"))

        # 3. Apply the action in the environment.
        try:
            res = requests.post(
                f"{API_BASE_URL}/step",
                json={"session_id": session_id, "action": action},
                timeout=_HTTP_TIMEOUT_SECONDS,
            )
            res.raise_for_status()
            step_data = res.json()

            observation = step_data["observation"]
            # `or` guards against explicit nulls, which would otherwise break
            # the numeric formatting in the log helpers.
            reward = step_data.get("reward") or 0.0
            done = bool(step_data.get("done", False))
            info = step_data.get("info") or {}
        except Exception as e:
            reward = 0.0
            done = True
            error_msg = error_msg or str(e)[:50]

        rewards.append(reward)
        history.append(f"Step {step} Action: {action_str} -> Reward: {reward}")
        log_step(step=step, action=action_str, reward=reward, done=done, error=error_msg)

    score = info.get("final_score") or 0.0
    success = score >= SUCCESS_SCORE_THRESHOLD
    log_end(success=success, steps=step, score=score, rewards=rewards)
153
+
154
+ if __name__ == "__main__":
155
+ main()