Spaces:
Sleeping
Sleeping
| import random | |
| from typing import List, Optional | |
| from openenv.core import Environment | |
| from factory_env.models import FactoryAction, FactoryObservation, FactoryState, Machine, Job | |
| from factory_env.tasks import TASKS | |
class FactoryEnv(Environment[FactoryAction, FactoryObservation, FactoryState]):
    """Smart Factory Scheduling Environment — OpenEnv compliant.

    An agent assigns pending jobs to idle machines and repairs broken
    machines. Every call to :meth:`step` applies one action and then advances
    the simulation clock by one tick. An episode is reported as done once no
    pending jobs remain or ``max_steps`` ticks have elapsed.
    """

    SUPPORTS_CONCURRENT_SESSIONS = True

    def __init__(self, task: str = "easy", seed: int = 42):
        """Create an environment for one of the configured tasks.

        Machines and jobs are only generated by :meth:`reset`; until then the
        environment holds empty lists.

        Args:
            task: Key into ``TASKS`` selecting the task configuration.
            seed: Default RNG seed, used by :meth:`reset` when it is called
                without an explicit seed.

        Raises:
            ValueError: If ``task`` is not a key of ``TASKS``.
        """
        super().__init__()
        if task not in TASKS:
            raise ValueError(f"Unknown task '{task}'. Choose from: {list(TASKS.keys())}")
        self.task = task
        self.seed = seed
        self.config = TASKS[task]
        self._rng = random.Random(seed)
        self.machines: List[Machine] = []
        self.jobs: List[Job] = []
        self.completed_jobs: List[Job] = []
        self.late_jobs: int = 0
        self.time: int = 0
        self.max_steps: int = self.config["max_steps"]

    def reset(self, seed: Optional[int] = None, episode_id: Optional[str] = None, **kwargs) -> FactoryObservation:
        """Start a new episode and return the initial observation.

        Re-seeds the RNG, zeroes the clock and counters, and regenerates the
        machines and jobs from the task configuration.

        Args:
            seed: RNG seed for this episode; falls back to the seed given to
                the constructor when ``None``.
            episode_id: Accepted for interface compatibility; not used here.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            The initial observation, with ``reward=None`` and ``done=False``.
        """
        use_seed = seed if seed is not None else self.seed
        self._rng = random.Random(use_seed)
        self.time = 0
        self.completed_jobs = []
        self.late_jobs = 0

        cfg = self.config
        self.machines = [
            Machine(id=f"M{i+1}", status="idle", failure_rate=cfg.get("failure_rate", 0.0))
            for i in range(cfg["num_machines"])
        ]

        self.jobs = []
        for i in range(cfg["num_jobs"]):
            # Keep this draw order (time, slack, priority): it determines the
            # job set produced by a given seed.
            proc_time = self._rng.randint(*cfg["job_time_range"])
            deadline = self.time + proc_time + self._rng.randint(*cfg["deadline_slack"])
            priority = self._rng.randint(1, cfg.get("max_priority", 1))
            self.jobs.append(Job(id=f"J{i+1}", remaining_time=proc_time, deadline=deadline, priority=priority))

        return self._make_obs(reward=None, done=False)

    def step(self, action: FactoryAction, timeout_s: Optional[float] = None, **kwargs) -> FactoryObservation:
        """Apply one action, advance the clock by one tick, and score the step.

        Reward components, accumulated in this order:

        * ``assign_job``: +0.1 for a valid assignment, -0.1 when the job or
          machine is unknown, the machine is not idle, or the job is already
          running on a machine.
        * ``repair``: +0.05 when the target machine is broken, otherwise -0.05.
        * Any other action type only lets time pass.
        * Each job finishing this tick: ``1.0 + 0.2 * priority`` if on time,
          otherwise 0.3 (and ``late_jobs`` is incremented).
        * While jobs are pending: -0.05 per idle machine and -0.1 per pending
          job whose deadline has passed.

        Args:
            action: The action to apply.
            timeout_s: Accepted for interface compatibility; not used here.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            The observation after the tick, carrying the step reward and the
            ``done`` flag.
        """
        reward = 0.0

        if action.action_type == "assign_job":
            job = self._find_job(action.job_id)
            machine = self._find_machine(action.machine_id)
            # A job that is already running must not be handed to a second
            # machine: it would be processed twice per tick, and once one
            # machine completed and removed it, the other would stay "busy"
            # with a dangling current_job.
            already_assigned = job is not None and bool(getattr(job, "assigned_machine", None))
            if job is None or machine is None or machine.status != "idle" or already_assigned:
                reward -= 0.1
            else:
                job.assigned_machine = machine.id
                machine.status = "busy"
                machine.current_job = job.id
                reward += 0.1
        elif action.action_type == "repair":
            machine = self._find_machine(action.machine_id)
            if machine and machine.status == "broken":
                machine.status = "idle"
                reward += 0.05
            else:
                reward -= 0.05

        self.time += 1

        for machine in self.machines:
            if machine.status == "busy":
                job = self._find_job(machine.current_job)
                if job:
                    job.remaining_time -= 1
                    if job.remaining_time <= 0:
                        on_time = self.time <= job.deadline
                        reward += (1.0 + 0.2 * job.priority) if on_time else 0.3
                        if not on_time:
                            self.late_jobs += 1
                        self.jobs.remove(job)
                        self.completed_jobs.append(job)
                        machine.status = "idle"
                        machine.current_job = None
            # Only machines still busy after processing can fail; a machine
            # that just finished its job is idle and skips the failure roll.
            if machine.status == "busy" and machine.failure_rate > 0:
                if self._rng.random() < machine.failure_rate:
                    machine.status = "broken"
                    stalled = self._find_job(machine.current_job)
                    if stalled:
                        # Release the job so it can be assigned again; its
                        # remaining_time is left as it was.
                        stalled.assigned_machine = None
                    machine.current_job = None

        if self.jobs:
            # Idle capacity is only penalised while work is still pending.
            reward -= sum(1 for m in self.machines if m.status == "idle") * 0.05
            for job in self.jobs:
                if self.time > job.deadline:
                    reward -= 0.1

        done = self.time >= self.max_steps or len(self.jobs) == 0
        return self._make_obs(reward=reward, done=done)

    def state(self) -> FactoryState:
        """Return a snapshot of the current environment state.

        The lists are shallow copies: the ``Machine`` and ``Job`` objects in
        them are the same instances the environment keeps mutating.

        NOTE(review): the OpenEnv base class may declare ``state`` as a
        property rather than a method — confirm against the installed
        interface before relying on ``env.state`` without a call.
        """
        return FactoryState(
            machines=list(self.machines),
            pending_jobs=list(self.jobs),
            completed_jobs=list(self.completed_jobs),
            time=self.time,
            task=self.task,
            late_jobs=self.late_jobs,
            step_count=self.time,  # one tick elapses per step, so both match
        )

    def _make_obs(self, reward: Optional[float], done: bool) -> FactoryObservation:
        """Build an observation from the current state.

        Args:
            reward: Reward for the step just taken; ``None`` on reset.
            done: Whether the episode has ended.
        """
        return FactoryObservation(
            machines=list(self.machines),
            pending_jobs=list(self.jobs),
            completed_jobs=list(self.completed_jobs),
            time=self.time,
            max_steps=self.max_steps,
            task=self.task,
            reward=reward,
            done=done,
        )

    def _find_job(self, job_id: Optional[str]) -> Optional[Job]:
        """Return the pending job with ``job_id``, or ``None`` if absent or the id is falsy."""
        return next((j for j in self.jobs if j.id == job_id), None) if job_id else None

    def _find_machine(self, machine_id: Optional[str]) -> Optional[Machine]:
        """Return the machine with ``machine_id``, or ``None`` if absent or the id is falsy."""
        return next((m for m in self.machines if m.id == machine_id), None) if machine_id else None