Aldrimore's picture
OpenEnv Submission
6195f6a
import random
from typing import List, Optional
from openenv.core import Environment
from factory_env.models import FactoryAction, FactoryObservation, FactoryState, Machine, Job
from factory_env.tasks import TASKS
class FactoryEnv(Environment[FactoryAction, FactoryObservation, FactoryState]):
    """Smart Factory Scheduling Environment — OpenEnv compliant.

    An agent assigns pending jobs to machines and repairs broken machines.
    Every call to :meth:`step` advances the clock by one tick; during a tick
    each busy machine works on its job and may randomly break down.

    The task configuration ``TASKS[task]`` is read for the keys
    ``max_steps``, ``num_machines``, ``num_jobs``, ``job_time_range`` and
    ``deadline_slack``, plus the optional ``failure_rate`` (default ``0.0``)
    and ``max_priority`` (default ``1``).
    """

    SUPPORTS_CONCURRENT_SESSIONS = True

    def __init__(self, task: str = "easy", seed: int = 42):
        """Create an environment for the named task.

        The environment holds no machines or jobs until :meth:`reset` is
        called.

        Args:
            task: Key into ``TASKS`` selecting the configuration.
            seed: Default RNG seed, used by :meth:`reset` when no seed is
                passed to it.

        Raises:
            ValueError: If ``task`` is not a key of ``TASKS``.
        """
        super().__init__()
        if task not in TASKS:
            raise ValueError(f"Unknown task '{task}'. Choose from: {list(TASKS.keys())}")
        self.task = task
        self.seed = seed
        self.config = TASKS[task]
        self._rng = random.Random(seed)
        self.machines: List[Machine] = []
        self.jobs: List[Job] = []
        self.completed_jobs: List[Job] = []
        self.late_jobs: int = 0
        self.time: int = 0
        self.max_steps: int = self.config["max_steps"]

    def reset(self, seed: Optional[int] = None, episode_id: Optional[str] = None, **kwargs) -> FactoryObservation:
        """Start a new episode and return the initial observation.

        Args:
            seed: RNG seed for this episode; falls back to the seed given to
                the constructor when ``None``.
            episode_id: Accepted for interface compatibility; not used here.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            An observation with ``reward=None`` and ``done=False``.
        """
        use_seed = seed if seed is not None else self.seed
        # Re-seed on every reset so the same seed always yields the same jobs.
        self._rng = random.Random(use_seed)
        self.time = 0
        self.completed_jobs = []
        self.late_jobs = 0

        cfg = self.config
        self.machines = [
            Machine(id=f"M{i+1}", status="idle", failure_rate=cfg.get("failure_rate", 0.0))
            for i in range(cfg["num_machines"])
        ]

        self.jobs = []
        for i in range(cfg["num_jobs"]):
            # Draw order (processing time, slack, priority) determines the
            # generated jobs for a given seed; keep it stable.
            proc_time = self._rng.randint(*cfg["job_time_range"])
            deadline = self.time + proc_time + self._rng.randint(*cfg["deadline_slack"])
            priority = self._rng.randint(1, cfg.get("max_priority", 1))
            self.jobs.append(Job(id=f"J{i+1}", remaining_time=proc_time, deadline=deadline, priority=priority))

        return self._make_obs(reward=None, done=False)

    def step(self, action: FactoryAction, timeout_s: Optional[float] = None, **kwargs) -> FactoryObservation:
        """Apply one action, advance the clock by one tick, and score it.

        Reward components for the step:

        * ``assign_job``: +0.1 when the job is pending and unassigned and the
          machine is idle; -0.1 otherwise.
        * ``repair``: +0.05 when the machine is broken; -0.05 otherwise.
        * Any other action type has no direct effect (time still advances).
        * Job completion: ``1.0 + 0.2 * priority`` when finished by its
          deadline, ``0.3`` when late.
        * -0.05 per idle machine while any job is still pending.
        * -0.1 per pending job whose deadline has passed.

        Args:
            action: The action to apply; ``action_type``, ``job_id`` and
                ``machine_id`` are read from it.
            timeout_s: Accepted for interface compatibility; not used here.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            The observation after the tick. ``done`` is set once ``max_steps``
            ticks have elapsed or no pending jobs remain.
        """
        reward = 0.0

        if action.action_type == "assign_job":
            job = self._find_job(action.job_id)
            machine = self._find_machine(action.machine_id)
            # A job that is already running elsewhere must not be assigned
            # again: two machines would work on it at once, and the slower
            # one would stay "busy" on a job that has left the pending list.
            already_assigned = job is not None and bool(job.assigned_machine)
            if job is None or machine is None or machine.status != "idle" or already_assigned:
                reward -= 0.1
            else:
                job.assigned_machine = machine.id
                machine.status = "busy"
                machine.current_job = job.id
                reward += 0.1
        elif action.action_type == "repair":
            machine = self._find_machine(action.machine_id)
            if machine and machine.status == "broken":
                machine.status = "idle"
                reward += 0.05
            else:
                reward -= 0.05

        self.time += 1

        for machine in self.machines:
            if machine.status == "busy":
                job = self._find_job(machine.current_job)
                if job:
                    job.remaining_time -= 1
                    if job.remaining_time <= 0:
                        on_time = self.time <= job.deadline
                        reward += (1.0 + 0.2 * job.priority) if on_time else 0.3
                        if not on_time:
                            self.late_jobs += 1
                        self.jobs.remove(job)
                        self.completed_jobs.append(job)
                        machine.status = "idle"
                        machine.current_job = None

            # Only machines still busy after this tick's work can fail; a
            # machine that just finished its job is already idle.
            if machine.status == "busy" and machine.failure_rate > 0:
                if self._rng.random() < machine.failure_rate:
                    machine.status = "broken"
                    # The interrupted job returns to the unassigned pool and
                    # keeps whatever remaining_time it had reached.
                    stalled = self._find_job(machine.current_job)
                    if stalled:
                        stalled.assigned_machine = None
                    machine.current_job = None

        # Idle machines are only penalised while there is still work to do.
        if self.jobs:
            reward -= sum(1 for m in self.machines if m.status == "idle") * 0.05

        # Overdue pending jobs keep costing reward on every tick.
        for job in self.jobs:
            if self.time > job.deadline:
                reward -= 0.1

        done = self.time >= self.max_steps or len(self.jobs) == 0
        return self._make_obs(reward=reward, done=done)

    @property
    def state(self) -> FactoryState:
        """Return the current environment state.

        The machine and job lists are passed as shallow copies; ``step_count``
        is reported as the current time because every step advances the clock
        by exactly one.
        """
        return FactoryState(
            machines=list(self.machines),
            pending_jobs=list(self.jobs),
            completed_jobs=list(self.completed_jobs),
            time=self.time,
            task=self.task,
            late_jobs=self.late_jobs,
            step_count=self.time,
        )

    def _make_obs(self, reward, done: bool) -> FactoryObservation:
        """Build an observation of the current state.

        Args:
            reward: Reward for the step just taken, or ``None`` on reset.
            done: Whether the episode has ended.
        """
        return FactoryObservation(
            machines=list(self.machines),
            pending_jobs=list(self.jobs),
            completed_jobs=list(self.completed_jobs),
            time=self.time,
            max_steps=self.max_steps,
            task=self.task,
            reward=reward,
            done=done,
        )

    def _find_job(self, job_id: Optional[str]) -> Optional[Job]:
        """Return the pending job with ``job_id``, or ``None`` if absent or the id is empty."""
        return next((j for j in self.jobs if j.id == job_id), None) if job_id else None

    def _find_machine(self, machine_id: Optional[str]) -> Optional[Machine]:
        """Return the machine with ``machine_id``, or ``None`` if absent or the id is empty."""
        return next((m for m in self.machines if m.id == machine_id), None) if machine_id else None