Shreeraj Mummidivarapu commited on
Commit
bdf2b53
Β·
unverified Β·
1 Parent(s): 42f16db

Eswar Ki Krupa !!

Browse files
Files changed (1) hide show
  1. inference.py +179 -132
inference.py CHANGED
@@ -1,166 +1,213 @@
1
  #!/usr/bin/env python3
2
- """
3
- inference.py β€” Robust LLM Agent for WildfireContainment-v0
4
- Uses OpenAI-compatible client (required by hackathon validator).
5
- """
6
  import os
7
- import sys
8
  import json
9
- import time
10
-
11
- # ── Required env vars (os.getenv required by validator) ──────────────────────
12
- API_BASE_URL = os.getenv("API_BASE_URL", "https://api-inference.huggingface.co/v1")
13
- MODEL_NAME = os.getenv("MODEL_NAME", "meta-llama/Llama-3.1-8B-Instruct")
14
- HF_TOKEN = os.getenv("HF_TOKEN", "")
15
-
16
- BASE_URL = os.getenv("ENV_BASE_URL", "http://localhost:7860")
17
-
18
- if not HF_TOKEN:
19
- print("[WARN] HF_TOKEN not set β€” LLM calls will use greedy fallback.", flush=True)
20
 
21
- # ── OpenAI client (required by validator: OpenAI( + base_url=API_BASE_URL) ───
22
  try:
23
- from openai import OpenAI
24
- client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN or "missing")
25
- OPENAI_AVAILABLE = True
26
  except ImportError:
27
- client = None
28
- OPENAI_AVAILABLE = False
29
- print("[WARN] openai package not installed β€” using greedy fallback", flush=True)
30
 
31
- import requests
32
 
33
- TASK_STEPS = 3
 
 
 
 
 
 
 
34
 
 
 
35
 
36
- def log(msg):
37
- print(msg, flush=True)
 
38
 
39
 
40
- def reset():
41
- """Reset environment via API."""
42
- try:
43
- r = requests.post(f"{BASE_URL}/reset", timeout=10)
44
- r.raise_for_status()
45
- return r.json()
46
- except Exception as e:
47
- log(f"[ERROR] reset failed: {e}")
48
- return None
49
 
50
 
51
- def step(actions):
52
- """Step environment via API."""
53
- try:
54
- payload = {"actions": actions}
55
- r = requests.post(f"{BASE_URL}/step", json=payload, timeout=10)
56
- r.raise_for_status()
57
- return r.json()
58
- except Exception as e:
59
- log(f"[ERROR] step failed: {e}")
60
- return None
61
 
62
 
63
- def get_llm_action(obs_text):
64
- """Get action from LLM via OpenAI-compatible client, or fallback."""
65
- if not OPENAI_AVAILABLE or not HF_TOKEN or client is None:
66
- return [{"move": 8, "act": False}] * 3
 
 
67
  try:
68
- prompt = (
69
- f"Fire report: {obs_text[:500]}. "
70
- "Choose 3 actions (move 0-8, act true/false). "
71
- 'JSON only: {"actions": [{"move": 8, "act": false}, ...]}'
72
- )
73
- completion = client.chat.completions.create(
74
- model=MODEL_NAME,
75
- messages=[{"role": "user", "content": prompt}],
76
- temperature=0.0,
77
- max_tokens=100,
78
- )
79
- content = completion.choices[0].message.content.strip()
80
- content = content.replace("```json", "").replace("```", "").strip()
81
- parsed = json.loads(content)
82
- return parsed.get("actions", [{"move": 8, "act": False}] * 3)
83
- except Exception:
84
- return [{"move": 8, "act": False}] * 3
85
-
86
 
87
- def compute_score(obs):
88
- """Compute validator-safe score from observation."""
89
- try:
90
- if not obs:
91
- return 0.5
92
- fire_grid = obs.get("fire_grid", [])
93
- structure_grid = obs.get("structure_grid", [])
94
- if not fire_grid or not structure_grid:
95
- return 0.5
96
 
97
- fire_cells = sum(1 for row in fire_grid for cell in row if cell > 0.1)
98
- structures_remaining = sum(1 for row in structure_grid for cell in row if cell == 1)
99
- total_cells = 20 * 20
100
- initial_structures = 10
101
 
102
- struct_score = structures_remaining / max(initial_structures, 1)
103
- fire_score = max(0.0, 1.0 - (fire_cells / total_cells))
104
- raw = (struct_score * 0.6) + (fire_score * 0.4)
105
 
106
- return round(max(0.01, min(0.99, raw)), 3)
107
- except Exception:
108
- return 0.5
 
 
 
109
 
110
 
111
- def run_task(task_id):
112
- """Run one task and emit logs."""
113
- log(f"[START] task={task_id} steps={TASK_STEPS}")
 
 
 
 
114
 
115
- result = reset()
116
- if not result:
117
- log(f"[END] task={task_id} score=0.5")
118
- return 0.5
119
 
120
- obs = result.get("observation", {})
121
- scores = []
122
-
123
- for step_num in range(1, TASK_STEPS + 1):
124
- obs_text = json.dumps(obs)[:500]
125
- actions = get_llm_action(obs_text)
126
-
127
- step_result = step(actions)
128
- if not step_result:
129
- break
130
-
131
- obs = step_result.get("observation", {})
132
- reward = step_result.get("reward", 0.0)
133
- done = step_result.get("done", False)
134
-
135
- score = compute_score(obs)
136
- scores.append(score)
137
-
138
- safe_reward = max(0.01, min(0.99, reward)) if reward else 0.5
139
-
140
- log(f"[STEP] task={task_id} step={step_num} reward={safe_reward:.3f} score={score:.3f} done={done}")
141
 
142
- if done:
143
- break
144
 
145
- final_score = max(0.01, min(0.99, sum(scores) / len(scores))) if scores else 0.5
146
- log(f"[END] task={task_id} score={final_score:.3f}")
147
- return final_score
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
 
 
 
 
 
 
149
 
150
- def main():
151
- tasks = ["easy", "medium", "hard"]
152
- all_scores = {}
153
 
154
- for task_id in tasks:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  try:
156
- score = run_task(task_id)
157
- all_scores[task_id] = score
158
- except Exception as e:
159
- log(f"[ERROR] task {task_id} failed: {e}")
160
- all_scores[task_id] = 0.5
161
-
162
- avg = max(0.01, min(0.99, sum(all_scores.values()) / len(all_scores))) if all_scores else 0.5
163
- log(f"[SUMMARY] scores={json.dumps(all_scores)} average={avg:.3f}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
 
165
 
166
  if __name__ == "__main__":
 
1
  #!/usr/bin/env python3
2
+
 
 
 
3
  import os
 
4
  import json
5
+ import sys
6
+ import urllib.request
7
+ import urllib.error
8
+ from typing import List, Optional
 
 
 
 
 
 
 
9
 
10
+ # ── Load .env for local development ──────────────────────────────────────────
11
  try:
12
+ from dotenv import load_dotenv
13
+ load_dotenv()
 
14
  except ImportError:
15
+ pass # dotenv not available in validator β€” env vars are injected directly
 
 
16
 
17
+ from openai import OpenAI
18
 
19
+ # ── Credentials ───────────────────────────────────────────────────────────────
20
+ # The hackathon validator INJECTS API_BASE_URL and API_KEY into the environment.
21
+ # We MUST use those values directly β€” never override them with HF_TOKEN or defaults.
22
+ API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
23
+ API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
24
+ if not API_KEY:
25
+ print("WARNING: API_KEY not set. LLM calls will fail.", file=sys.stderr, flush=True)
26
+ API_KEY = "missing"
27
 
28
+ MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
29
+ ENV_BASE_URL = os.getenv("ENV_BASE_URL", "https://huggingface.co/spaces/anonymousDevil/cognitive-load-manager")
30
 
31
+ print("DEBUG BASE URL:", API_BASE_URL, flush=True)
32
+ print("DEBUG MODEL:", MODEL_NAME, flush=True)
33
+ print("DEBUG ENV URL:", ENV_BASE_URL, flush=True)
34
 
35
 
36
+ # ── CLIENT ─────────────────────────────────────────────────────
37
+ client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
 
 
 
 
 
 
 
38
 
39
 
40
+ # ── CONFIG ─────────────────────────────────────────────────────
41
+ TASK_NAME = "schedule-optimization"
42
+ BENCHMARK = "cognitive-load-manager"
43
+ SUCCESS_SCORE_THRESHOLD = 0.5
44
+ MAX_STEPS = 50
 
 
 
 
 
45
 
46
 
47
+ # ── HTTP ───────────────────────────────────────────────────────
48
+ def post_json(url: str, payload: dict) -> dict:
49
+ data = json.dumps(payload).encode("utf-8")
50
+ req = urllib.request.Request(
51
+ url, data=data, headers={"Content-Type": "application/json"}
52
+ )
53
  try:
54
+ with urllib.request.urlopen(req, timeout=30) as res:
55
+ return json.loads(res.read().decode("utf-8"))
56
+ except urllib.error.HTTPError as e:
57
+ raise Exception(f"HTTP {e.code}: {e.read().decode('utf-8')[:200]}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
 
 
 
 
 
 
 
 
 
59
 
60
+ # ── LOGGING ────────────────────────────────────────────────────
61
+ def log_start(task: str, env: str, model: str) -> None:
62
+ print(f"[START] task={task} env={env} model={model}", flush=True)
 
63
 
 
 
 
64
 
65
+ def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
66
+ print(
67
+ f"[STEP] step={step} action={action} reward={reward:.2f} "
68
+ f"done={str(done).lower()} error={error or 'null'}",
69
+ flush=True,
70
+ )
71
 
72
 
73
+ def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
74
+ rewards_str = ",".join(f"{r:.2f}" for r in rewards)
75
+ print(
76
+ f"[END] success={str(success).lower()} steps={steps} "
77
+ f"score={score:.3f} rewards={rewards_str}",
78
+ flush=True,
79
+ )
80
 
 
 
 
 
81
 
82
+ # ── MAIN ───────────────────────────────────────────────────────
83
+ def main():
84
+ task_id = os.getenv("CLM_LEVEL", "hard")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
+ log_start(task=TASK_NAME, env=BENCHMARK, model=MODEL_NAME)
 
87
 
88
+ # ── 1. Reset environment ─────────────────────────────────────
89
+ try:
90
+ data = post_json(f"{ENV_BASE_URL}/reset", {"task_id": task_id})
91
+ session_id = data.get("session_id", "default")
92
+ observation = data["observation"]
93
+ except Exception as e:
94
+ log_step(step=0, action="reset", reward=0.0, done=True, error=str(e)[:80])
95
+ log_end(success=False, steps=0, score=0.0, rewards=[])
96
+ return
97
+
98
+ done = False
99
+ step = 0
100
+ rewards: List[float] = []
101
+ history: List[str] = []
102
+ info: dict = {}
103
+
104
+ # ── 2. Agent loop ────────────────────────────────────────────
105
+ while not done and step < MAX_STEPS:
106
+ step += 1
107
+
108
+ history_str = "\n".join(history[-5:]) if history else "No previous actions."
109
+
110
+ system_prompt = (
111
+ "You are an AI task scheduler managing human cognitive load.\n"
112
+ "You MUST respond with ONLY a JSON object (no markdown, no explanation).\n\n"
113
+ "ACTION FORMAT: {\"type\": \"<action>\", \"task_id\": \"<id or null>\"}\n"
114
+ "Valid types:\n"
115
+ " - \"work\" : work on task_id (requires task_id)\n"
116
+ " - \"break\" : rest to recover energy (task_id: null)\n"
117
+ " - \"switch\": switch to a different task_id (requires task_id)\n"
118
+ " - \"delay\" : wait/do nothing (task_id: null)\n\n"
119
+ "STRATEGY:\n"
120
+ "1. If fatigue_level is 'high' OR stress_warning is true β†’ {\"type\": \"break\", \"task_id\": null}\n"
121
+ "2. If fatigue_level is 'medium' and stress is manageable β†’ {\"type\": \"work\", \"task_id\": \"<earliest deadline incomplete task>\"}\n"
122
+ "3. Otherwise β†’ {\"type\": \"work\", \"task_id\": \"<earliest deadline incomplete task>\"}\n"
123
+ "4. Pick incomplete tasks (progress < 1.0) with the earliest deadline first.\n"
124
+ )
125
 
126
+ user_prompt = (
127
+ f"Previous 5 steps:\n{history_str}\n\n"
128
+ f"Current observation:\n{json.dumps(observation, indent=2)}\n\n"
129
+ "What is your next action JSON?"
130
+ )
131
 
132
+ action: Optional[dict] = None
133
+ error_msg: Optional[str] = None
 
134
 
135
+ # ── LLM call through the validator proxy ─────────────────
136
+ try:
137
+ completion = client.chat.completions.create(
138
+ model=MODEL_NAME,
139
+ messages=[
140
+ {"role": "system", "content": system_prompt},
141
+ {"role": "user", "content": user_prompt},
142
+ ],
143
+ temperature=0.1,
144
+ max_tokens=150,
145
+ )
146
+ text = (completion.choices[0].message.content or "").strip()
147
+
148
+ # Strip markdown fences if present
149
+ if text.startswith("```json"):
150
+ text = text[7:]
151
+ if text.startswith("```"):
152
+ text = text[3:]
153
+ if text.endswith("```"):
154
+ text = text[:-3]
155
+ text = text.strip()
156
+
157
+ # Extract JSON
158
+ s = text.find("{")
159
+ e = text.rfind("}")
160
+ if s != -1 and e != -1:
161
+ action = json.loads(text[s : e + 1])
162
+ except Exception as ex:
163
+ error_msg = str(ex)[:80]
164
+
165
+ # ── Heuristic fallback (only if LLM call failed / unparseable) ───
166
+ if not action:
167
+ tasks = observation.get("tasks", [])
168
+ incomp = [t for t in tasks if t.get("progress", 0.0) < 1.0]
169
+ fs = observation.get("visible_state", {})
170
+ if fs.get("fatigue_level") in ("high", "medium") or fs.get("stress_warning"):
171
+ action = {"type": "break"}
172
+ elif incomp:
173
+ action = {"type": "work", "task_id": incomp[0]["id"]}
174
+ else:
175
+ action = {"type": "delay"}
176
+
177
+ # Validate action type
178
+ valid_types = {"work", "break", "switch", "delay"}
179
+ if action.get("type") not in valid_types:
180
+ action = {"type": "delay"}
181
+
182
+ action_str = json.dumps(action, separators=(",", ":"))
183
+
184
+ # ── ENV STEP ─────────────────────────────────────────────
185
  try:
186
+ step_data = post_json(
187
+ f"{ENV_BASE_URL}/step",
188
+ {"session_id": session_id, "action": action},
189
+ )
190
+ observation = step_data["observation"]
191
+ reward = float(step_data.get("reward", 0.0))
192
+ done = bool(step_data.get("done", False))
193
+ info = step_data.get("info", {})
194
+ except Exception as ex:
195
+ reward = 0.0
196
+ done = True
197
+ error_msg = error_msg or str(ex)[:80]
198
+
199
+ rewards.append(reward)
200
+ history.append(f"Step {step}: {action_str} -> reward={reward:.2f}")
201
+
202
+ log_step(step=step, action=action_str, reward=reward, done=done, error=error_msg)
203
+
204
+ # ── 3. Final scoring ─────────────────────────────────────────
205
+ score = float(info.get("final_score", 0.0))
206
+ if score == 0.0 and rewards:
207
+ score = sum(rewards) / len(rewards)
208
+ success = score >= SUCCESS_SCORE_THRESHOLD
209
+
210
+ log_end(success, step, score, rewards)
211
 
212
 
213
  if __name__ == "__main__":