Spaces:
Sleeping
Sleeping
Fix task selection: reset() accepts task param, all 3 tasks verified
Browse files- factory_env/env.py +7 -0
- inference.py +1 -1
- pyproject.toml +20 -0
factory_env/env.py
CHANGED
|
@@ -28,6 +28,13 @@ class FactoryEnv(Environment[FactoryAction, FactoryObservation, FactoryState]):
|
|
| 28 |
self.max_steps: int = self.config["max_steps"]
|
| 29 |
|
| 30 |
def reset(self, seed: Optional[int] = None, episode_id: Optional[str] = None, **kwargs) -> FactoryObservation:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
use_seed = seed if seed is not None else self.seed
|
| 32 |
self._rng = random.Random(use_seed)
|
| 33 |
self.time = 0
|
|
|
|
| 28 |
self.max_steps: int = self.config["max_steps"]
|
| 29 |
|
| 30 |
def reset(self, seed: Optional[int] = None, episode_id: Optional[str] = None, **kwargs) -> FactoryObservation:
|
| 31 |
+
# Allow task to be overridden at reset time (e.g. from inference script)
|
| 32 |
+
task = kwargs.get("task", self.task)
|
| 33 |
+
if task != self.task and task in TASKS:
|
| 34 |
+
self.task = task
|
| 35 |
+
self.config = TASKS[task]
|
| 36 |
+
self.max_steps = self.config["max_steps"]
|
| 37 |
+
|
| 38 |
use_seed = seed if seed is not None else self.seed
|
| 39 |
self._rng = random.Random(use_seed)
|
| 40 |
self.time = 0
|
inference.py
CHANGED
|
@@ -189,7 +189,7 @@ async def main() -> None:
|
|
| 189 |
log_start(task=TASK_NAME, env=BENCHMARK, model=MODEL_NAME)
|
| 190 |
|
| 191 |
try:
|
| 192 |
-
result = await env.reset()
|
| 193 |
obs = result.observation
|
| 194 |
last_reward = 0.0
|
| 195 |
|
|
|
|
| 189 |
log_start(task=TASK_NAME, env=BENCHMARK, model=MODEL_NAME)
|
| 190 |
|
| 191 |
try:
|
| 192 |
+
result = await env.reset(task=TASK_NAME)
|
| 193 |
obs = result.observation
|
| 194 |
last_reward = 0.0
|
| 195 |
|
pyproject.toml
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools>=68"]
|
| 3 |
+
build-backend = "setuptools.backends.legacy:build"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "factory-env"
|
| 7 |
+
version = "1.0.0"
|
| 8 |
+
description = "Smart Factory Scheduling — OpenEnv RL Environment"
|
| 9 |
+
requires-python = ">=3.11"
|
| 10 |
+
dependencies = [
|
| 11 |
+
"pydantic>=2.0",
|
| 12 |
+
"openai>=1.0",
|
| 13 |
+
"anthropic>=0.90",
|
| 14 |
+
"gradio>=6.0",
|
| 15 |
+
"openenv-core>=0.2.3",
|
| 16 |
+
"fastapi>=0.100",
|
| 17 |
+
"uvicorn>=0.23",
|
| 18 |
+
"websockets>=12.0",
|
| 19 |
+
"httpx>=0.27",
|
| 20 |
+
]
|