Roshan818 commited on
Commit
a0f94c2
Β·
1 Parent(s): b77936a

fix: lazy openenv.core imports so FactoryEnv works on all Python envs

Browse files

- factory_env/env.py: wrap openenv.core Environment import in try/except
- factory_env/models.py: lazy imports for Action/Observation/State base classes,
add explicit done/reward fields to FactoryObservation for fallback case
- grader.py: clean FactoryEnv-only grader, no pure-Python fallback

Files changed (3) hide show
  1. factory_env/env.py +8 -2
  2. factory_env/models.py +13 -2
  3. grader.py +18 -102
factory_env/env.py CHANGED
@@ -1,13 +1,19 @@
1
  import random
2
  from typing import List, Optional
3
 
4
- from openenv.core import Environment
 
 
 
 
 
 
5
 
6
  from factory_env.models import FactoryAction, FactoryObservation, FactoryState, Machine, Job
7
  from factory_env.tasks import TASKS
8
 
9
 
10
- class FactoryEnv(Environment[FactoryAction, FactoryObservation, FactoryState]):
11
  """Smart Factory Scheduling Environment β€” OpenEnv compliant."""
12
 
13
  SUPPORTS_CONCURRENT_SESSIONS = True
 
1
  import random
2
  from typing import List, Optional
3
 
4
+ # Lazy base-class: import openenv.core only when it's available.
5
+ # This lets FactoryEnv be imported (e.g. by the grader) even in minimal
6
+ # environments where openenv-core's gradio/PIL chain fails to load.
7
+ try:
8
+ from openenv.core import Environment as _EnvBase
9
+ except Exception:
10
+ _EnvBase = object # type: ignore[assignment,misc]
11
 
12
  from factory_env.models import FactoryAction, FactoryObservation, FactoryState, Machine, Job
13
  from factory_env.tasks import TASKS
14
 
15
 
16
+ class FactoryEnv(_EnvBase):
17
  """Smart Factory Scheduling Environment β€” OpenEnv compliant."""
18
 
19
  SUPPORTS_CONCURRENT_SESSIONS = True
factory_env/models.py CHANGED
@@ -1,6 +1,14 @@
1
  from typing import List, Optional
2
  from pydantic import BaseModel, ConfigDict, Field
3
- from openenv.core import Action as BaseAction, Observation as BaseObservation, State as BaseState
 
 
 
 
 
 
 
 
4
 
5
 
6
  class Machine(BaseModel):
@@ -32,7 +40,10 @@ class FactoryAction(BaseAction):
32
 
33
 
34
  class FactoryObservation(BaseObservation):
35
- """Inherits: done (bool), reward (float|None), metadata (dict)"""
 
 
 
36
  machines: List[Machine] = Field(default_factory=list)
37
  pending_jobs: List[Job] = Field(default_factory=list)
38
  completed_jobs: List[Job] = Field(default_factory=list)
 
1
  from typing import List, Optional
2
  from pydantic import BaseModel, ConfigDict, Field
3
+
4
+ # Lazy openenv base classes β€” fall back to pydantic BaseModel when the
5
+ # openenv.core import chain (which pulls in gradio/PIL) is unavailable.
6
+ try:
7
+ from openenv.core import Action as BaseAction, Observation as BaseObservation, State as BaseState
8
+ except Exception:
9
+ BaseAction = BaseModel # type: ignore[assignment,misc]
10
+ BaseObservation = BaseModel # type: ignore[assignment,misc]
11
+ BaseState = BaseModel # type: ignore[assignment,misc]
12
 
13
 
14
  class Machine(BaseModel):
 
40
 
41
 
42
  class FactoryObservation(BaseObservation):
43
+ """Inherits done/reward/metadata from openenv base when available;
44
+ defined here explicitly so the class works when falling back to BaseModel."""
45
+ done: bool = False
46
+ reward: Optional[float] = None
47
  machines: List[Machine] = Field(default_factory=list)
48
  pending_jobs: List[Job] = Field(default_factory=list)
49
  completed_jobs: List[Job] = Field(default_factory=list)
grader.py CHANGED
@@ -1,25 +1,17 @@
1
  """
2
  Graders for Smart Factory Scheduling tasks.
3
 
4
- Primary path: imports FactoryEnv, runs a full deterministic heuristic episode,
5
- and scores the result using the real environment state.
6
-
7
- Fallback path (if factory_env is unavailable): a minimal pure-Python
8
- simulation is used so the validator can still load and call these functions.
9
-
10
- All three public functions:
11
- - Accept an optional state/env argument for scoring a finished episode.
12
- - When called with no argument, run their own deterministic episode.
13
- - Always return a float strictly in (0.0, 1.0).
14
  """
15
 
16
  from __future__ import annotations
17
 
18
- import random
19
- from typing import Any, List, Optional
20
 
21
-
22
- # ── Score formula (shared by both paths) ─────────────────────────────────────
23
 
24
  def _compute(completed: int, on_time: int, total: int, late: int) -> float:
25
  if total == 0:
@@ -32,8 +24,8 @@ def _compute(completed: int, on_time: int, total: int, late: int) -> float:
32
  return round(max(0.001, min(0.999, score)), 4)
33
 
34
 
35
- def _score_obj(obj: Any) -> float:
36
- """Score from a finished env object or state dict."""
37
  if isinstance(obj, dict):
38
  done_list = obj.get("completed_jobs", []) or []
39
  pend_list = obj.get("pending_jobs", []) or []
@@ -53,17 +45,14 @@ def _score_obj(obj: Any) -> float:
53
  t = int(getattr(obj, "time", 0) or 0)
54
  completed = len(done_list)
55
  total = completed + len(pend_list)
56
- on_time = sum(
57
- 1 for j in done_list
58
- if getattr(j, "deadline", 0) >= t
59
- )
60
  return _compute(completed, on_time, total, late)
61
 
62
 
63
- # ── Primary path: use the real FactoryEnv ────────────────────────────────────
64
 
65
  def _heuristic(obs):
66
- """Earliest-deadline-first heuristic action (works on FactoryObservation)."""
67
  from factory_env.models import FactoryAction
68
  for m in obs.machines:
69
  if m.status == "broken":
@@ -76,8 +65,10 @@ def _heuristic(obs):
76
  return None
77
 
78
 
79
- def _run_factory_episode(task: str, seed: int = 42) -> float:
80
- """Run a full heuristic episode on the real FactoryEnv and return score."""
 
 
81
  from factory_env.env import FactoryEnv
82
  from factory_env.models import FactoryAction
83
 
@@ -91,81 +82,6 @@ def _run_factory_episode(task: str, seed: int = 42) -> float:
91
  return _score_obj(env)
92
 
93
 
94
- # ── Fallback path: pure-Python mini-simulation ───────────────────────────────
95
-
96
- _TASK_CFG = {
97
- "easy": dict(nm=2, nj=3, fr=0.00, ms=20, jtr=(2,4), ds=(2,5), mp=1),
98
- "medium": dict(nm=4, nj=7, fr=0.08, ms=30, jtr=(2,5), ds=(2,6), mp=2),
99
- "hard": dict(nm=6, nj=12, fr=0.15, ms=40, jtr=(2,6), ds=(1,5), mp=3),
100
- }
101
-
102
-
103
- def _run_mini_episode(task: str, seed: int = 42) -> float:
104
- """Pure-Python fallback simulation (no external deps)."""
105
- cfg = _TASK_CFG[task]
106
- rng = random.Random(seed)
107
-
108
- machines = [{"id": f"M{i+1}", "status": "idle", "job": None,
109
- "fr": cfg["fr"]} for i in range(cfg["nm"])]
110
- jobs = []
111
- for i in range(cfg["nj"]):
112
- pt = rng.randint(*cfg["jtr"])
113
- dl = pt + rng.randint(*cfg["ds"])
114
- jobs.append({"id": f"J{i+1}", "rt": pt, "dl": dl,
115
- "pr": rng.randint(1, cfg["mp"])})
116
-
117
- completed, late, t = [], 0, 0
118
-
119
- for _ in range(cfg["ms"]):
120
- if not jobs:
121
- break
122
- # repair broken machines
123
- for m in machines:
124
- if m["status"] == "broken":
125
- m["status"] = "idle"
126
- break
127
- # assign jobs EDF
128
- for j in sorted(jobs, key=lambda x: (x["dl"], -x["pr"])):
129
- for m in machines:
130
- if m["status"] == "idle":
131
- m["status"] = "busy"
132
- m["job"] = j["id"]
133
- j["m"] = m["id"]
134
- break
135
-
136
- t += 1
137
- for m in machines:
138
- if m["status"] == "busy":
139
- j = next((x for x in jobs if x["id"] == m["job"]), None)
140
- if j:
141
- j["rt"] -= 1
142
- if j["rt"] <= 0:
143
- if t > j["dl"]:
144
- late += 1
145
- completed.append(j)
146
- jobs.remove(j)
147
- m["status"] = "idle"
148
- m["job"] = None
149
- if m["status"] == "busy" and cfg["fr"] > 0:
150
- if rng.random() < cfg["fr"]:
151
- m["status"] = "broken"
152
- m["job"] = None
153
-
154
- total = len(completed) + len(jobs)
155
- n = len(completed)
156
- on_time = max(0, n - late)
157
- return _compute(n, on_time, total, late)
158
-
159
-
160
- # ── Episode runner (tries FactoryEnv, falls back if unavailable) ─────────────
161
-
162
- def _episode(task: str) -> float:
163
- try:
164
- return _run_factory_episode(task)
165
- except Exception:
166
- return _run_mini_episode(task)
167
-
168
-
169
  # ── Public graders ────────────────────────────────────────────────────────────
170
 
171
  def score_easy(state_or_env=None) -> float:
@@ -173,7 +89,7 @@ def score_easy(state_or_env=None) -> float:
173
  Returns float in (0.0, 1.0)."""
174
  if state_or_env is not None:
175
  return _score_obj(state_or_env)
176
- return _episode("easy")
177
 
178
 
179
  def score_medium(state_or_env=None) -> float:
@@ -181,7 +97,7 @@ def score_medium(state_or_env=None) -> float:
181
  Returns float in (0.0, 1.0)."""
182
  if state_or_env is not None:
183
  return _score_obj(state_or_env)
184
- return _episode("medium")
185
 
186
 
187
  def score_hard(state_or_env=None) -> float:
@@ -189,4 +105,4 @@ def score_hard(state_or_env=None) -> float:
189
  Returns float in (0.0, 1.0)."""
190
  if state_or_env is not None:
191
  return _score_obj(state_or_env)
192
- return _episode("hard")
 
1
  """
2
  Graders for Smart Factory Scheduling tasks.
3
 
4
+ Each public function:
5
+ - Accepts an optional state/env argument to score a finished episode.
6
+ - When called with no argument, runs a deterministic heuristic episode
7
+ on the real FactoryEnv and returns the score.
8
+ - Always returns a float strictly in (0.0, 1.0).
 
 
 
 
 
9
  """
10
 
11
  from __future__ import annotations
12
 
 
 
13
 
14
+ # ── Score formula ─────────────────────────────────────────────────────────────
 
15
 
16
  def _compute(completed: int, on_time: int, total: int, late: int) -> float:
17
  if total == 0:
 
24
  return round(max(0.001, min(0.999, score)), 4)
25
 
26
 
27
+ def _score_obj(obj) -> float:
28
+ """Score from a finished FactoryEnv object or state dict."""
29
  if isinstance(obj, dict):
30
  done_list = obj.get("completed_jobs", []) or []
31
  pend_list = obj.get("pending_jobs", []) or []
 
45
  t = int(getattr(obj, "time", 0) or 0)
46
  completed = len(done_list)
47
  total = completed + len(pend_list)
48
+ on_time = sum(1 for j in done_list if getattr(j, "deadline", 0) >= t)
 
 
 
49
  return _compute(completed, on_time, total, late)
50
 
51
 
52
+ # ── Heuristic agent ───────────────────────────────────────────────────────────
53
 
54
  def _heuristic(obs):
55
+ """Earliest-deadline-first heuristic that runs on a FactoryObservation."""
56
  from factory_env.models import FactoryAction
57
  for m in obs.machines:
58
  if m.status == "broken":
 
65
  return None
66
 
67
 
68
+ # ── Episode runner ────────────────────────────────────────────────────────────
69
+
70
+ def _run_episode(task: str, seed: int = 42) -> float:
71
+ """Run a full heuristic episode on FactoryEnv and return the graded score."""
72
  from factory_env.env import FactoryEnv
73
  from factory_env.models import FactoryAction
74
 
 
82
  return _score_obj(env)
83
 
84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  # ── Public graders ────────────────────────────────────────────────────────────
86
 
87
  def score_easy(state_or_env=None) -> float:
 
89
  Returns float in (0.0, 1.0)."""
90
  if state_or_env is not None:
91
  return _score_obj(state_or_env)
92
+ return _run_episode("easy")
93
 
94
 
95
  def score_medium(state_or_env=None) -> float:
 
97
  Returns float in (0.0, 1.0)."""
98
  if state_or_env is not None:
99
  return _score_obj(state_or_env)
100
+ return _run_episode("medium")
101
 
102
 
103
  def score_hard(state_or_env=None) -> float:
 
105
  Returns float in (0.0, 1.0)."""
106
  if state_or_env is not None:
107
  return _score_obj(state_or_env)
108
+ return _run_episode("hard")