Aldrimore commited on
Commit
c01d448
·
1 Parent(s): fe75421

Fix task selection: reset() accepts task param, all 3 tasks verified

Browse files
Files changed (3) hide show
  1. factory_env/env.py +7 -0
  2. inference.py +1 -1
  3. pyproject.toml +20 -0
factory_env/env.py CHANGED
@@ -28,6 +28,13 @@ class FactoryEnv(Environment[FactoryAction, FactoryObservation, FactoryState]):
28
  self.max_steps: int = self.config["max_steps"]
29
 
30
  def reset(self, seed: Optional[int] = None, episode_id: Optional[str] = None, **kwargs) -> FactoryObservation:
 
 
 
 
 
 
 
31
  use_seed = seed if seed is not None else self.seed
32
  self._rng = random.Random(use_seed)
33
  self.time = 0
 
28
  self.max_steps: int = self.config["max_steps"]
29
 
30
  def reset(self, seed: Optional[int] = None, episode_id: Optional[str] = None, **kwargs) -> FactoryObservation:
31
+ # Allow task to be overridden at reset time (e.g. from inference script)
32
+ task = kwargs.get("task", self.task)
33
+ if task != self.task and task in TASKS:
34
+ self.task = task
35
+ self.config = TASKS[task]
36
+ self.max_steps = self.config["max_steps"]
37
+
38
  use_seed = seed if seed is not None else self.seed
39
  self._rng = random.Random(use_seed)
40
  self.time = 0
inference.py CHANGED
@@ -189,7 +189,7 @@ async def main() -> None:
189
  log_start(task=TASK_NAME, env=BENCHMARK, model=MODEL_NAME)
190
 
191
  try:
192
- result = await env.reset()
193
  obs = result.observation
194
  last_reward = 0.0
195
 
 
189
  log_start(task=TASK_NAME, env=BENCHMARK, model=MODEL_NAME)
190
 
191
  try:
192
+ result = await env.reset(task=TASK_NAME)
193
  obs = result.observation
194
  last_reward = 0.0
195
 
pyproject.toml ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [build-system]
2
+ requires = ["setuptools>=68"]
3
+ build-backend = "setuptools.backends.legacy:build"
4
+
5
+ [project]
6
+ name = "factory-env"
7
+ version = "1.0.0"
8
+ description = "Smart Factory Scheduling — OpenEnv RL Environment"
9
+ requires-python = ">=3.11"
10
+ dependencies = [
11
+ "pydantic>=2.0",
12
+ "openai>=1.0",
13
+ "anthropic>=0.90",
14
+ "gradio>=6.0",
15
+ "openenv-core>=0.2.3",
16
+ "fastapi>=0.100",
17
+ "uvicorn>=0.23",
18
+ "websockets>=12.0",
19
+ "httpx>=0.27",
20
+ ]