Roshan818 commited on
Commit
4c41e84
Β·
1 Parent(s): 4754135

fix: zero-arg self-contained graders, inference runs all 3 tasks

Browse files
Files changed (2) hide show
  1. grader.py +75 -28
  2. inference.py +90 -81
grader.py CHANGED
@@ -1,16 +1,21 @@
1
  """
2
  Graders for Smart Factory Scheduling tasks.
3
- Called by the OpenEnv validator to score an episode.
4
 
5
- Each grader accepts either:
6
- - an env object (has .completed_jobs, .jobs, .time, .late_jobs attributes)
7
- - a state dict (has "completed_jobs", "pending_jobs", "time", "late_jobs" keys)
8
 
9
- Returns a float strictly in (0, 1).
 
 
10
  """
11
 
 
12
 
13
- def _compute(completed, on_time, total, late):
 
 
 
14
  if total == 0:
15
  return 0.001
16
  score = (
@@ -21,37 +26,79 @@ def _compute(completed, on_time, total, late):
21
  return round(max(0.001, min(0.999, score)), 4)
22
 
23
 
24
- def _score(state_or_env):
25
- if isinstance(state_or_env, dict):
26
- done = state_or_env.get("completed_jobs", []) or []
27
- pending = state_or_env.get("pending_jobs", []) or []
28
- late = state_or_env.get("late_jobs", 0) or 0
29
- t = state_or_env.get("time", 0) or 0
 
30
  else:
31
- done = list(getattr(state_or_env, "completed_jobs", []) or [])
32
- pending = list(getattr(state_or_env, "jobs", getattr(state_or_env, "pending_jobs", [])) or [])
33
- late = getattr(state_or_env, "late_jobs", 0) or 0
34
- t = getattr(state_or_env, "time", 0) or 0
35
-
36
- completed = len(done)
37
- total = completed + len(pending)
38
- on_time = sum(
39
- 1 for j in done
40
- if (j.get("deadline", 0) if isinstance(j, dict) else getattr(j, "deadline", 0)) >= t
 
 
 
41
  )
42
  return _compute(completed, on_time, total, late)
43
 
44
 
45
- def score_easy(state_or_env):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  """Grade an easy-task episode (2 machines, 3 jobs, no failures)."""
47
- return _score(state_or_env)
 
 
48
 
49
 
50
- def score_medium(state_or_env):
51
  """Grade a medium-task episode (4 machines, 7 jobs, 8% failures)."""
52
- return _score(state_or_env)
 
 
53
 
54
 
55
- def score_hard(state_or_env):
56
  """Grade a hard-task episode (6 machines, 12 jobs, 15% failures)."""
57
- return _score(state_or_env)
 
 
 
1
  """
2
  Graders for Smart Factory Scheduling tasks.
 
3
 
4
+ Each grader is self-contained: when called with no arguments it creates a
5
+ FactoryEnv, runs a deterministic heuristic episode, and returns a score
6
+ strictly in (0, 1).
7
 
8
+ Alternatively, pass an env object or state dict from an already-run episode:
9
+ score_easy(env) # env object with .completed_jobs, .jobs, .time …
10
+ score_easy(state) # dict with "completed_jobs", "pending_jobs", "time" …
11
  """
12
 
13
+ from __future__ import annotations
14
 
15
+
16
+ # ── internal helpers ──────────────────────────────────────────────────────────
17
+
18
+ def _compute(completed: int, on_time: int, total: int, late: int) -> float:
19
  if total == 0:
20
  return 0.001
21
  score = (
 
26
  return round(max(0.001, min(0.999, score)), 4)
27
 
28
 
29
+ def _score_from(obj) -> float:
30
+ """Accept env object or state dict and return a score."""
31
+ if isinstance(obj, dict):
32
+ done_jobs = obj.get("completed_jobs", []) or []
33
+ pending = obj.get("pending_jobs", []) or []
34
+ late = obj.get("late_jobs", 0) or 0
35
+ t = obj.get("time", 0) or 0
36
  else:
37
+ done_jobs = list(getattr(obj, "completed_jobs", []) or [])
38
+ pending = list(
39
+ getattr(obj, "jobs", getattr(obj, "pending_jobs", []))
40
+ ) or []
41
+ late = getattr(obj, "late_jobs", 0) or 0
42
+ t = getattr(obj, "time", 0) or 0
43
+
44
+ completed = len(done_jobs)
45
+ total = completed + len(pending)
46
+ on_time = sum(
47
+ 1 for j in done_jobs
48
+ if (j.get("deadline", 0) if isinstance(j, dict)
49
+ else getattr(j, "deadline", 0)) >= t
50
  )
51
  return _compute(completed, on_time, total, late)
52
 
53
 
54
+ def _heuristic_action(obs):
55
+ """Simple earliest-deadline-first heuristic."""
56
+ from factory_env.models import FactoryAction
57
+ for m in obs.machines:
58
+ if m.status == "broken":
59
+ return FactoryAction(action_type="repair", machine_id=m.id)
60
+ for j in sorted(obs.pending_jobs, key=lambda x: (x.deadline, -x.priority)):
61
+ for m in obs.machines:
62
+ if m.status == "idle":
63
+ return FactoryAction(
64
+ action_type="assign_job", job_id=j.id, machine_id=m.id
65
+ )
66
+ return None # wait
67
+
68
+
69
+ def _run_episode(task: str, seed: int = 42) -> float:
70
+ """Run one full heuristic episode and return the graded score."""
71
+ from factory_env.env import FactoryEnv
72
+ from factory_env.models import FactoryAction
73
+
74
+ env = FactoryEnv(task=task, seed=seed)
75
+ obs = env.reset()
76
+ for _ in range(obs.max_steps):
77
+ if obs.done:
78
+ break
79
+ action = _heuristic_action(obs) or FactoryAction(action_type="wait")
80
+ obs = env.step(action)
81
+ return _score_from(env)
82
+
83
+
84
+ # ── public graders ────────────────────────────────────────────────────────────
85
+
86
+ def score_easy(state_or_env=None) -> float:
87
  """Grade an easy-task episode (2 machines, 3 jobs, no failures)."""
88
+ if state_or_env is not None:
89
+ return _score_from(state_or_env)
90
+ return _run_episode("easy")
91
 
92
 
93
+ def score_medium(state_or_env=None) -> float:
94
  """Grade a medium-task episode (4 machines, 7 jobs, 8% failures)."""
95
+ if state_or_env is not None:
96
+ return _score_from(state_or_env)
97
+ return _run_episode("medium")
98
 
99
 
100
+ def score_hard(state_or_env=None) -> float:
101
  """Grade a hard-task episode (6 machines, 12 jobs, 15% failures)."""
102
+ if state_or_env is not None:
103
+ return _score_from(state_or_env)
104
+ return _run_episode("hard")
inference.py CHANGED
@@ -1,20 +1,19 @@
1
  """
2
  Inference Script β€” Smart Factory Scheduling Environment
3
  =======================================================
4
- Connects to a running factory_env server via WebSocket and runs an LLM agent.
5
-
6
- Mandatory env vars (per hackathon spec):
7
- HF_TOKEN HuggingFace API key (also used as OPENAI_API_KEY)
8
- API_BASE_URL LLM endpoint (default: HF router)
9
- MODEL_NAME Model ID (default: Qwen/Qwen2.5-72B-Instruct)
10
-
11
- Optional env vars:
12
- ENV_URL URL of running factory_env server (default: http://localhost:7860)
13
- IMAGE_NAME Docker image name β€” if set, spins up a container instead of ENV_URL
14
- FACTORY_TASK easy | medium | hard (default: easy)
15
-
16
- STDOUT FORMAT (strict β€” do not alter):
17
- [START] task=<name> env=factory_env model=<model>
18
  [STEP] step=<n> action=<str> reward=<0.00> done=<true|false> error=<msg|null>
19
  [END] success=<true|false> steps=<n> score=<0.000> rewards=<r1,r2,...>
20
  """
@@ -32,15 +31,19 @@ from factory_env.models import FactoryAction
32
 
33
  # ── Configuration ────────────────────────────────────────────────────────────
34
  HF_TOKEN = os.getenv("HF_TOKEN")
 
35
  API_BASE_URL: str = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
36
- MODEL_NAME: str = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
37
- TASK_NAME: str = os.getenv("FACTORY_TASK", "easy")
38
- ENV_URL: str = os.getenv("ENV_URL", "http://localhost:7860")
39
- LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
40
- BENCHMARK: str = "factory_env"
41
- TEMPERATURE: float = 0.2
42
- MAX_TOKENS: int = 80
43
- SUCCESS_SCORE_THRESHOLD: float = 0.5
 
 
 
44
 
45
  SYSTEM_PROMPT = textwrap.dedent("""
46
  You are controlling a smart factory scheduling system.
@@ -53,14 +56,13 @@ SYSTEM_PROMPT = textwrap.dedent("""
53
  """).strip()
54
 
55
 
56
- # ── Log helpers (strict format required by judges) ────────────────────────────
57
  def log_start(task: str, env: str, model: str) -> None:
58
  print(f"[START] task={task} env={env} model={model}", flush=True)
59
 
60
 
61
- def log_step(
62
- step: int, action: str, reward: float, done: bool, error: Optional[str]
63
- ) -> None:
64
  print(
65
  f"[STEP] step={step} action={action.replace(' ', '_')} reward={reward:.2f} "
66
  f"done={str(done).lower()} error={error or 'null'}",
@@ -68,9 +70,8 @@ def log_step(
68
  )
69
 
70
 
71
- def log_end(
72
- success: bool, steps: int, score: float, rewards: List[float]
73
- ) -> None:
74
  print(
75
  f"[END] success={str(success).lower()} steps={steps} "
76
  f"score={score:.3f} rewards={','.join(f'{r:.2f}' for r in rewards)}",
@@ -86,7 +87,8 @@ def build_prompt(step: int, obs, last_reward: float) -> str:
86
  )
87
  jobs = (
88
  "\n".join(
89
- f" {j.id}: remaining={j.remaining_time}, deadline={j.deadline}, priority={j.priority}"
 
90
  for j in obs.pending_jobs
91
  )
92
  or " (none)"
@@ -104,7 +106,7 @@ def get_model_action(client: OpenAI, step: int, obs, last_reward: float) -> str:
104
  model=MODEL_NAME,
105
  messages=[
106
  {"role": "system", "content": SYSTEM_PROMPT},
107
- {"role": "user", "content": build_prompt(step, obs, last_reward)},
108
  ],
109
  temperature=TEMPERATURE,
110
  max_tokens=MAX_TOKENS,
@@ -122,7 +124,8 @@ def parse_action(text: str) -> FactoryAction:
122
  try:
123
  parts = text.strip().split()
124
  if parts[0] == "assign_job" and len(parts) == 3:
125
- return FactoryAction(action_type="assign_job", job_id=parts[1], machine_id=parts[2])
 
126
  if parts[0] == "repair" and len(parts) == 2:
127
  return FactoryAction(action_type="repair", machine_id=parts[1])
128
  except Exception:
@@ -131,7 +134,6 @@ def parse_action(text: str) -> FactoryAction:
131
 
132
 
133
  def heuristic_action(obs):
134
- """Fallback heuristic when LLM returns an ineffective wait."""
135
  for m in obs.machines:
136
  if m.status == "broken":
137
  return FactoryAction(action_type="repair", machine_id=m.id), f"repair {m.id}"
@@ -139,54 +141,41 @@ def heuristic_action(obs):
139
  for m in obs.machines:
140
  if m.status == "idle":
141
  s = f"assign_job {j.id} {m.id}"
142
- return FactoryAction(action_type="assign_job", job_id=j.id, machine_id=m.id), s
 
143
  return FactoryAction(action_type="wait"), "wait"
144
 
145
 
146
  # ── Score from final state ────────────────────────────────────────────────────
147
  def score_from_state(state, task: str) -> float:
148
- """Compute episode score from WebSocket state response."""
149
  completed_jobs = getattr(state, "completed_jobs", []) or []
150
- pending_jobs = getattr(state, "pending_jobs", []) or []
151
- late_jobs = getattr(state, "late_jobs", 0) or 0
152
- time = getattr(state, "time", 0) or 0
153
-
154
- completed = len(completed_jobs)
155
- total = completed + len(pending_jobs)
156
-
157
- # on_time: jobs whose deadline hasn't passed by end of episode (matches grader)
158
  on_time = sum(
159
  1 for j in completed_jobs
160
- if (j.get("deadline", 0) if isinstance(j, dict) else j.deadline) >= time
 
161
  )
162
-
163
  return compute_score(completed, on_time, total, late_jobs, task)
164
 
165
 
166
- # ── Main ──────────────────────────────────────────────────────────────────────
167
- async def main() -> None:
168
- llm_client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
169
-
170
  rewards: List[float] = []
171
  steps_taken = 0
172
- score = 0.0
173
- success = False
174
-
175
- log_start(task=TASK_NAME, env=BENCHMARK, model=MODEL_NAME)
176
 
177
- # Connect to environment β€” Docker image, direct URL, or localhost
178
- if LOCAL_IMAGE_NAME:
179
- print(f"[DEBUG] Spinning up Docker image: {LOCAL_IMAGE_NAME}", flush=True)
180
- env = await FactoryEnvClient.from_docker_image(LOCAL_IMAGE_NAME)
181
- else:
182
- url = ENV_URL or "http://localhost:7860"
183
- print(f"[DEBUG] Connecting to: {url}", flush=True)
184
- env = FactoryEnvClient(base_url=url)
185
- await env.connect()
186
 
187
  try:
188
- result = await env.reset(task=TASK_NAME)
189
- obs = result.observation
190
  last_reward = 0.0
191
 
192
  for step in range(1, obs.max_steps + 1):
@@ -194,44 +183,64 @@ async def main() -> None:
194
  break
195
 
196
  action_text = get_model_action(llm_client, step, obs, last_reward)
197
- action = parse_action(action_text)
198
 
199
- # Heuristic fallback: if LLM returns wait but there's work to do
200
  if action.action_type == "wait" and (
201
- obs.pending_jobs or any(m.status == "broken" for m in obs.machines)
 
202
  ):
203
  action, action_text = heuristic_action(obs)
204
 
205
- result = await env.step(action)
206
- obs = result.observation
207
- reward = result.reward or 0.0
208
- done = result.done
209
  rewards.append(reward)
210
  steps_taken = step
211
  last_reward = reward
212
 
213
- log_step(step=step, action=action_text, reward=reward, done=done, error=None)
214
-
215
  if done:
216
  break
217
 
218
- # Compute score from final WebSocket state
219
  try:
220
- state = await env.state()
221
- score = score_from_state(state, TASK_NAME)
222
  except Exception as exc:
223
- print(f"[DEBUG] state() failed, falling back to reward sum: {exc}", flush=True)
224
- max_reward = {"easy": 4.0, "medium": 12.0, "hard": 20.0}.get(TASK_NAME, 10.0)
225
- score = min(max(sum(rewards) / max_reward, 0.0), 1.0)
 
226
 
227
  success = score >= SUCCESS_SCORE_THRESHOLD
228
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
  finally:
230
  try:
231
- await env.close()
232
  except Exception as exc:
233
  print(f"[DEBUG] env.close() error: {exc}", flush=True)
234
- log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
235
 
236
 
237
  if __name__ == "__main__":
 
1
  """
2
  Inference Script β€” Smart Factory Scheduling Environment
3
  =======================================================
4
+ Runs an LLM agent against the factory_env server for all 3 tasks
5
+ (easy, medium, hard) and emits structured stdout logs.
6
+
7
+ Environment variables:
8
+ HF_TOKEN HuggingFace / API key (no default β€” required)
9
+ API_BASE_URL LLM endpoint (default: HF router)
10
+ MODEL_NAME Model identifier (default: Qwen/Qwen2.5-72B-Instruct)
11
+ IMAGE_NAME Docker image name β€” if set, spins up a container
12
+ ENV_URL Server URL (default: http://localhost:7860)
13
+ FACTORY_TASK Run a single task: easy | medium | hard (default: run all 3)
14
+
15
+ STDOUT FORMAT (one [START] / N [STEP] / one [END] per task):
16
+ [START] task=<task> env=factory_env model=<model>
 
17
  [STEP] step=<n> action=<str> reward=<0.00> done=<true|false> error=<msg|null>
18
  [END] success=<true|false> steps=<n> score=<0.000> rewards=<r1,r2,...>
19
  """
 
31
 
32
  # ── Configuration ────────────────────────────────────────────────────────────
33
  HF_TOKEN = os.getenv("HF_TOKEN")
34
+ API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
35
  API_BASE_URL: str = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
36
+ MODEL_NAME: str = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
37
+ IMAGE_NAME = os.getenv("IMAGE_NAME") or os.getenv("LOCAL_IMAGE_NAME")
38
+ ENV_URL: str = os.getenv("ENV_URL", "http://localhost:7860")
39
+ BENCHMARK: str = "factory_env"
40
+ TEMPERATURE: float = 0.2
41
+ MAX_TOKENS: int = 80
42
+ SUCCESS_SCORE_THRESHOLD = 0.5
43
+
44
+ # Run a single task if FACTORY_TASK is set, otherwise run all three
45
+ _single = os.getenv("FACTORY_TASK", "").strip()
46
+ TASKS: List[str] = [_single] if _single else ["easy", "medium", "hard"]
47
 
48
  SYSTEM_PROMPT = textwrap.dedent("""
49
  You are controlling a smart factory scheduling system.
 
56
  """).strip()
57
 
58
 
59
+ # ── Log helpers ───────────────────────────────────────────────────────────────
60
  def log_start(task: str, env: str, model: str) -> None:
61
  print(f"[START] task={task} env={env} model={model}", flush=True)
62
 
63
 
64
+ def log_step(step: int, action: str, reward: float, done: bool,
65
+ error: Optional[str]) -> None:
 
66
  print(
67
  f"[STEP] step={step} action={action.replace(' ', '_')} reward={reward:.2f} "
68
  f"done={str(done).lower()} error={error or 'null'}",
 
70
  )
71
 
72
 
73
+ def log_end(success: bool, steps: int, score: float,
74
+ rewards: List[float]) -> None:
 
75
  print(
76
  f"[END] success={str(success).lower()} steps={steps} "
77
  f"score={score:.3f} rewards={','.join(f'{r:.2f}' for r in rewards)}",
 
87
  )
88
  jobs = (
89
  "\n".join(
90
+ f" {j.id}: remaining={j.remaining_time}, deadline={j.deadline},"
91
+ f" priority={j.priority}"
92
  for j in obs.pending_jobs
93
  )
94
  or " (none)"
 
106
  model=MODEL_NAME,
107
  messages=[
108
  {"role": "system", "content": SYSTEM_PROMPT},
109
+ {"role": "user", "content": build_prompt(step, obs, last_reward)},
110
  ],
111
  temperature=TEMPERATURE,
112
  max_tokens=MAX_TOKENS,
 
124
  try:
125
  parts = text.strip().split()
126
  if parts[0] == "assign_job" and len(parts) == 3:
127
+ return FactoryAction(action_type="assign_job",
128
+ job_id=parts[1], machine_id=parts[2])
129
  if parts[0] == "repair" and len(parts) == 2:
130
  return FactoryAction(action_type="repair", machine_id=parts[1])
131
  except Exception:
 
134
 
135
 
136
  def heuristic_action(obs):
 
137
  for m in obs.machines:
138
  if m.status == "broken":
139
  return FactoryAction(action_type="repair", machine_id=m.id), f"repair {m.id}"
 
141
  for m in obs.machines:
142
  if m.status == "idle":
143
  s = f"assign_job {j.id} {m.id}"
144
+ return FactoryAction(action_type="assign_job",
145
+ job_id=j.id, machine_id=m.id), s
146
  return FactoryAction(action_type="wait"), "wait"
147
 
148
 
149
  # ── Score from final state ────────────────────────────────────────────────────
150
  def score_from_state(state, task: str) -> float:
 
151
  completed_jobs = getattr(state, "completed_jobs", []) or []
152
+ pending_jobs = getattr(state, "pending_jobs", []) or []
153
+ late_jobs = getattr(state, "late_jobs", 0) or 0
154
+ time = getattr(state, "time", 0) or 0
155
+ completed = len(completed_jobs)
156
+ total = completed + len(pending_jobs)
 
 
 
157
  on_time = sum(
158
  1 for j in completed_jobs
159
+ if (j.get("deadline", 0) if isinstance(j, dict)
160
+ else j.deadline) >= time
161
  )
 
162
  return compute_score(completed, on_time, total, late_jobs, task)
163
 
164
 
165
+ # ── Single-task episode ───────────────────────────────────────────────────────
166
+ async def run_task(env_client: FactoryEnvClient,
167
+ llm_client: OpenAI,
168
+ task: str) -> None:
169
  rewards: List[float] = []
170
  steps_taken = 0
171
+ score = 0.0
172
+ success = False
 
 
173
 
174
+ log_start(task=task, env=BENCHMARK, model=MODEL_NAME)
 
 
 
 
 
 
 
 
175
 
176
  try:
177
+ result = await env_client.reset(task=task)
178
+ obs = result.observation
179
  last_reward = 0.0
180
 
181
  for step in range(1, obs.max_steps + 1):
 
183
  break
184
 
185
  action_text = get_model_action(llm_client, step, obs, last_reward)
186
+ action = parse_action(action_text)
187
 
 
188
  if action.action_type == "wait" and (
189
+ obs.pending_jobs
190
+ or any(m.status == "broken" for m in obs.machines)
191
  ):
192
  action, action_text = heuristic_action(obs)
193
 
194
+ result = await env_client.step(action)
195
+ obs = result.observation
196
+ reward = result.reward or 0.0
197
+ done = result.done
198
  rewards.append(reward)
199
  steps_taken = step
200
  last_reward = reward
201
 
202
+ log_step(step=step, action=action_text,
203
+ reward=reward, done=done, error=None)
204
  if done:
205
  break
206
 
 
207
  try:
208
+ state = await env_client.state()
209
+ score = score_from_state(state, task)
210
  except Exception as exc:
211
+ print(f"[DEBUG] state() failed: {exc}", flush=True)
212
+ max_r = {"easy": 4.0, "medium": 12.0, "hard": 20.0}.get(task, 10.0)
213
+ raw = sum(rewards) / max_r if max_r > 0 else 0.0
214
+ score = round(max(0.001, min(0.999, raw)), 4)
215
 
216
  success = score >= SUCCESS_SCORE_THRESHOLD
217
 
218
+ finally:
219
+ log_end(success=success, steps=steps_taken,
220
+ score=score, rewards=rewards)
221
+
222
+
223
+ # ── Main ──────────────────────────────────────────────────────────────────────
224
+ async def main() -> None:
225
+ llm_client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
226
+
227
+ if IMAGE_NAME:
228
+ print(f"[DEBUG] Spinning up Docker image: {IMAGE_NAME}", flush=True)
229
+ env_client = await FactoryEnvClient.from_docker_image(IMAGE_NAME)
230
+ else:
231
+ url = ENV_URL or "http://localhost:7860"
232
+ print(f"[DEBUG] Connecting to: {url}", flush=True)
233
+ env_client = FactoryEnvClient(base_url=url)
234
+ await env_client.connect()
235
+
236
+ try:
237
+ for task in TASKS:
238
+ await run_task(env_client, llm_client, task)
239
  finally:
240
  try:
241
+ await env_client.close()
242
  except Exception as exc:
243
  print(f"[DEBUG] env.close() error: {exc}", flush=True)
 
244
 
245
 
246
  if __name__ == "__main__":