Spaces:
Sleeping
Sleeping
OpenEnv Submission
Browse files- .gitignore +38 -0
- Dockerfile +17 -0
- README.md +105 -1
- app.py +212 -0
- factory_env/__init__.py +2 -0
- factory_env/env.py +133 -92
- factory_env/grader.py +15 -3
- factory_env/models.py +50 -21
- factory_env/tasks.py +33 -15
- inference.py +71 -175
- requirements.txt +7 -3
- server.py +21 -0
- train.py +217 -0
.gitignore
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*.pyo
|
| 5 |
+
*.egg-info/
|
| 6 |
+
dist/
|
| 7 |
+
build/
|
| 8 |
+
|
| 9 |
+
# Virtual environments
|
| 10 |
+
venv/
|
| 11 |
+
.venv/
|
| 12 |
+
env/
|
| 13 |
+
|
| 14 |
+
# Secrets
|
| 15 |
+
.env
|
| 16 |
+
.env.*
|
| 17 |
+
|
| 18 |
+
# OS
|
| 19 |
+
.DS_Store
|
| 20 |
+
Thumbs.db
|
| 21 |
+
|
| 22 |
+
# IDE
|
| 23 |
+
.vscode/
|
| 24 |
+
.idea/
|
| 25 |
+
|
| 26 |
+
# Logs
|
| 27 |
+
*.log
|
| 28 |
+
|
| 29 |
+
# Training runs
|
| 30 |
+
runs/
|
| 31 |
+
|
| 32 |
+
# Docker
|
| 33 |
+
*.tar
|
| 34 |
+
|
| 35 |
+
# Hackathon docs
|
| 36 |
+
rule.txt
|
| 37 |
+
"Meta RL Hackathon.docx"
|
| 38 |
+
Meta\ RL\ Hackathon.docx
|
Dockerfile
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.11-slim
|
| 2 |
+
|
| 3 |
+
WORKDIR /app
|
| 4 |
+
|
| 5 |
+
COPY requirements.txt .
|
| 6 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 7 |
+
|
| 8 |
+
COPY . .
|
| 9 |
+
|
| 10 |
+
ENV FACTORY_TASK=easy
|
| 11 |
+
ENV API_BASE_URL=https://router.huggingface.co/v1
|
| 12 |
+
ENV MODEL_NAME=Qwen/Qwen2.5-72B-Instruct
|
| 13 |
+
ENV PORT=7860
|
| 14 |
+
|
| 15 |
+
EXPOSE 7860
|
| 16 |
+
|
| 17 |
+
CMD ["python", "server.py"]
|
README.md
CHANGED
|
@@ -1 +1,105 @@
|
|
| 1 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Smart Factory Scheduling Environment
|
| 2 |
+
|
| 3 |
+
An [OpenEnv](https://github.com/openenv/openenv)-compliant RL environment simulating real-world industrial scheduling: assign jobs to machines, handle breakdowns, and maximise throughput within deadlines.
|
| 4 |
+
|
| 5 |
+
## Observation Space
|
| 6 |
+
|
| 7 |
+
| Field | Type | Description |
|
| 8 |
+
|-------|------|-------------|
|
| 9 |
+
| `machines` | List[Machine] | id, status (idle/busy/broken), current_job, failure_rate |
|
| 10 |
+
| `pending_jobs` | List[Job] | id, remaining_time, deadline, priority (1-3), assigned_machine |
|
| 11 |
+
| `completed_jobs` | List[Job] | Jobs finished this episode |
|
| 12 |
+
| `time` | int | Current time step |
|
| 13 |
+
| `max_steps` | int | Episode length |
|
| 14 |
+
| `done` | bool | Episode terminated |
|
| 15 |
+
| `reward` | float | Reward from last action |
|
| 16 |
+
|
| 17 |
+
## Action Space
|
| 18 |
+
|
| 19 |
+
| Action | Effect |
|
| 20 |
+
|--------|--------|
|
| 21 |
+
| `assign_job <job_id> <machine_id>` | Assign pending job to idle machine |
|
| 22 |
+
| `repair <machine_id>` | Restore broken machine to idle |
|
| 23 |
+
| `wait` | Advance time with no change |
|
| 24 |
+
|
| 25 |
+
## Reward Function
|
| 26 |
+
|
| 27 |
+
| Event | Reward |
|
| 28 |
+
|-------|--------|
|
| 29 |
+
| Job completed on time | +1.00 + 0.20 × priority |
|
| 30 |
+
| Job completed late | +0.30 |
|
| 31 |
+
| Valid assignment | +0.10 |
|
| 32 |
+
| Invalid action | −0.10 |
|
| 33 |
+
| Idle machine (pending jobs exist) | −0.05 per machine |
|
| 34 |
+
| Job past deadline | −0.10 per step |
|
| 35 |
+
| Repair broken machine | +0.05 |
|
| 36 |
+
|
| 37 |
+
## Tasks
|
| 38 |
+
|
| 39 |
+
| Task | Machines | Jobs | Failure Rate | Max Steps | Baseline Score |
|
| 40 |
+
|------|----------|------|-------------|-----------|----------------|
|
| 41 |
+
| easy | 2 | 3 | 0% | 20 | 1.000 |
|
| 42 |
+
| medium | 4 | 7 | 8% | 30 | ~0.557 |
|
| 43 |
+
| hard | 6 | 12 | 15% | 40 | ~0.457 |
|
| 44 |
+
|
| 45 |
+
**Score formula:** `0.5 × completion_rate + 0.3 × on_time_rate + 0.2 × utilization_bonus`
|
| 46 |
+
|
| 47 |
+
## Setup
|
| 48 |
+
|
| 49 |
+
```bash
|
| 50 |
+
pip install -r requirements.txt
|
| 51 |
+
```
|
| 52 |
+
|
| 53 |
+
### Run HTTP Server (HF Space)
|
| 54 |
+
```bash
|
| 55 |
+
python server.py
|
| 56 |
+
# Routes: GET /health POST /reset POST /step GET /state GET /schema
|
| 57 |
+
```
|
| 58 |
+
|
| 59 |
+
### Run Inference (LLM agent)
|
| 60 |
+
```bash
|
| 61 |
+
export OPENAI_API_KEY=<your-key>
|
| 62 |
+
export FACTORY_TASK=easy # easy | medium | hard
|
| 63 |
+
python inference.py
|
| 64 |
+
```
|
| 65 |
+
|
| 66 |
+
### Run RL Training
|
| 67 |
+
```bash
|
| 68 |
+
python train.py --task easy --episodes 10 --provider openai
|
| 69 |
+
python train.py --task medium --episodes 10 --provider claude
|
| 70 |
+
```
|
| 71 |
+
|
| 72 |
+
### Interactive Demo
|
| 73 |
+
```bash
|
| 74 |
+
python app.py # opens at http://localhost:7860
|
| 75 |
+
```
|
| 76 |
+
|
| 77 |
+
### Docker
|
| 78 |
+
```bash
|
| 79 |
+
docker build -t factory-env .
|
| 80 |
+
docker run -e OPENAI_API_KEY=<key> -e FACTORY_TASK=easy -p 7860:7860 factory-env
|
| 81 |
+
```
|
| 82 |
+
|
| 83 |
+
## Baseline Scores
|
| 84 |
+
|
| 85 |
+
| Task | Score | Steps |
|
| 86 |
+
|------|-------|-------|
|
| 87 |
+
| easy | 1.000 | 4 |
|
| 88 |
+
| medium | ~0.529 | 12 |
|
| 89 |
+
| hard | ~0.533 | 34 |
|
| 90 |
+
|
| 91 |
+
## Project Structure
|
| 92 |
+
|
| 93 |
+
```
|
| 94 |
+
├── factory_env/
|
| 95 |
+
│ ├── env.py # FactoryEnv (openenv.core.Environment)
|
| 96 |
+
│ ├── models.py # FactoryAction, FactoryObservation, FactoryState
|
| 97 |
+
│ ├── tasks.py # Task configurations
|
| 98 |
+
│ └── grader.py # Score computation
|
| 99 |
+
├── inference.py # LLM baseline agent
|
| 100 |
+
├── train.py # Multi-episode RL training loop
|
| 101 |
+
├── server.py # FastAPI HTTP server for HF Space
|
| 102 |
+
├── app.py # Gradio interactive demo
|
| 103 |
+
├── openenv.yaml # OpenEnv metadata
|
| 104 |
+
└── Dockerfile
|
| 105 |
+
```
|
app.py
ADDED
|
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Smart Factory Scheduling — Interactive Gradio Demo
|
| 3 |
+
Run: python app.py → http://localhost:7860
|
| 4 |
+
"""
|
| 5 |
+
import asyncio, os
|
| 6 |
+
from typing import List, Optional, Tuple
|
| 7 |
+
import gradio as gr
|
| 8 |
+
from factory_env.env import FactoryEnv
|
| 9 |
+
from factory_env.grader import score_episode
|
| 10 |
+
from factory_env.models import FactoryAction as Action
|
| 11 |
+
|
| 12 |
+
_env: Optional[FactoryEnv] = None
|
| 13 |
+
_obs = None
|
| 14 |
+
_rewards: List[float] = []
|
| 15 |
+
_history: List[dict] = []
|
| 16 |
+
_step_num: int = 0
|
| 17 |
+
|
| 18 |
+
STATUS_EMOJI = {"idle": "🟢", "busy": "🔵", "broken": "🔴"}
|
| 19 |
+
SYSTEM_PROMPT = "You are a factory scheduler. Reply with ONE action:\n assign_job <job_id> <machine_id>\n repair <machine_id>\n wait"
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def _llm_client(provider, api_key):
|
| 23 |
+
if "Claude" in provider:
|
| 24 |
+
import anthropic
|
| 25 |
+
return ("claude", anthropic.Anthropic(api_key=api_key or os.getenv("ANTHROPIC_API_KEY")))
|
| 26 |
+
from openai import OpenAI
|
| 27 |
+
base = "https://api.openai.com/v1" if "OpenAI" in provider else "https://router.huggingface.co/v1"
|
| 28 |
+
return ("openai", OpenAI(api_key=api_key or os.getenv("OPENAI_API_KEY") or os.getenv("HF_TOKEN"), base_url=base))
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def _call_llm(provider_tuple, model, obs, last_reward, step):
|
| 32 |
+
kind, client = provider_tuple
|
| 33 |
+
machines = "\n".join(f" {m.id}: {m.status}" + (f" ({m.current_job})" if m.current_job else "") for m in obs.machines)
|
| 34 |
+
jobs = "\n".join(f" {j.id}: t={j.remaining_time} dl={j.deadline} p={j.priority}" for j in obs.pending_jobs) or " (none)"
|
| 35 |
+
user = f"Step {step}/{obs.max_steps} | t={obs.time} | reward={last_reward:+.2f}\nMachines:\n{machines}\nJobs:\n{jobs}\nAction:"
|
| 36 |
+
try:
|
| 37 |
+
if kind == "claude":
|
| 38 |
+
r = client.messages.create(model=model, max_tokens=50, system=SYSTEM_PROMPT, messages=[{"role":"user","content":user}])
|
| 39 |
+
return r.content[0].text.strip().splitlines()[0]
|
| 40 |
+
else:
|
| 41 |
+
r = client.chat.completions.create(model=model, temperature=0.2, max_tokens=50,
|
| 42 |
+
messages=[{"role":"system","content":SYSTEM_PROMPT},{"role":"user","content":user}])
|
| 43 |
+
return (r.choices[0].message.content or "wait").strip().splitlines()[0]
|
| 44 |
+
except Exception as e:
|
| 45 |
+
return f"wait # {e}"
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def _parse(text):
|
| 49 |
+
try:
|
| 50 |
+
p = text.strip().split()
|
| 51 |
+
if p[0] == "assign_job" and len(p) == 3: return Action(action_type="assign_job", job_id=p[1], machine_id=p[2])
|
| 52 |
+
if p[0] == "repair" and len(p) == 2: return Action(action_type="repair", machine_id=p[1])
|
| 53 |
+
except: pass
|
| 54 |
+
return Action(action_type="wait")
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def _heuristic(obs) -> Tuple[Action, str]:
|
| 58 |
+
for m in obs.machines:
|
| 59 |
+
if m.status == "broken": return Action(action_type="repair", machine_id=m.id), f"repair {m.id}"
|
| 60 |
+
for j in sorted(obs.pending_jobs, key=lambda x: (x.deadline, -x.priority)):
|
| 61 |
+
for m in obs.machines:
|
| 62 |
+
if m.status == "idle":
|
| 63 |
+
return Action(action_type="assign_job", job_id=j.id, machine_id=m.id), f"assign_job {j.id} {m.id}"
|
| 64 |
+
return Action(action_type="wait"), "wait"
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def _render_state(obs):
|
| 68 |
+
if obs is None: return "*Reset to start*"
|
| 69 |
+
lines = [f"### ⏱ Time: {obs.time} / {obs.max_steps}",
|
| 70 |
+
"\n**Machines**", "| ID | Status | Job |", "|---|---|---|"]
|
| 71 |
+
for m in obs.machines:
|
| 72 |
+
lines.append(f"| {m.id} | {STATUS_EMOJI.get(m.status,'')} {m.status} | {m.current_job or '—'} |")
|
| 73 |
+
lines.append("\n**Pending Jobs**")
|
| 74 |
+
if obs.pending_jobs:
|
| 75 |
+
lines += ["| ID | Remaining | Deadline | Priority |", "|---|---|---|---|"]
|
| 76 |
+
for j in sorted(obs.pending_jobs, key=lambda x: x.deadline):
|
| 77 |
+
urgent = "🔥" if obs.time + j.remaining_time > j.deadline else ""
|
| 78 |
+
lines.append(f"| {j.id} {urgent} | {j.remaining_time} | {j.deadline} | {'★'*j.priority} |")
|
| 79 |
+
else:
|
| 80 |
+
lines.append("*All jobs completed! ✅*")
|
| 81 |
+
if obs.completed_jobs:
|
| 82 |
+
lines.append(f"\n**Completed:** {len(obs.completed_jobs)} ✅")
|
| 83 |
+
return "\n".join(lines)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def _render_log(history):
|
| 87 |
+
if not history: return "*No steps yet*"
|
| 88 |
+
rows = ["| Step | Action | Reward | Done |", "|---|---|---|---|"]
|
| 89 |
+
for h in history[-15:]:
|
| 90 |
+
r = h["reward"]; icon = "🟢" if r > 0.3 else ("🔴" if r < -0.05 else "🟡")
|
| 91 |
+
rows.append(f"| {h['step']} | `{h['action']}` | {icon} {r:+.2f} | {'✅' if h['done'] else ''} |")
|
| 92 |
+
return "\n".join(rows)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def _render_score(rewards, env):
|
| 96 |
+
if not rewards or not env: return ""
|
| 97 |
+
s = score_episode(env)
|
| 98 |
+
bar = "█" * int(s * 20) + "░" * (20 - int(s * 20))
|
| 99 |
+
return f"**Score:** {s:.4f} `[{bar}]`\n**Completed:** {len(env.completed_jobs)} | **Late:** {env.late_jobs} | **Total Reward:** {sum(rewards):.2f}"
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def reset_env(task):
|
| 103 |
+
global _env, _obs, _rewards, _history, _step_num
|
| 104 |
+
_env = FactoryEnv(task=task, seed=42); _obs = _env.reset()
|
| 105 |
+
_rewards = []; _history = []; _step_num = 0
|
| 106 |
+
return _render_state(_obs), _render_log([]), "", f"✅ Reset — **{task}**: {len(_obs.machines)} machines, {len(_obs.pending_jobs)} jobs"
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def manual_step(text):
|
| 110 |
+
global _obs, _rewards, _history, _step_num
|
| 111 |
+
if _env is None: return _render_state(None), _render_log([]), "", "⚠ Reset first."
|
| 112 |
+
if _obs.done: return _render_state(_obs), _render_log(_history), _render_score(_rewards, _env), "✅ Episode done."
|
| 113 |
+
_step_num += 1
|
| 114 |
+
_obs = _env.step(_parse(text.strip()))
|
| 115 |
+
r = _obs.reward or 0.0; _rewards.append(r); _history.append({"step": _step_num, "action": text.strip(), "reward": r, "done": _obs.done})
|
| 116 |
+
status = f"Step {_step_num}: `{text.strip()}` → **{r:+.2f}**"
|
| 117 |
+
if _obs.done: status += f"\n\n🏁 Done! Score: **{score_episode(_env):.4f}**"
|
| 118 |
+
return _render_state(_obs), _render_log(_history), _render_score(_rewards, _env), status
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def heuristic_step():
|
| 122 |
+
global _obs, _rewards, _history, _step_num
|
| 123 |
+
if _env is None: return _render_state(None), _render_log([]), "", "⚠ Reset first."
|
| 124 |
+
if _obs.done: return _render_state(_obs), _render_log(_history), _render_score(_rewards, _env), "✅ Episode done."
|
| 125 |
+
action, action_text = _heuristic(_obs)
|
| 126 |
+
_step_num += 1
|
| 127 |
+
_obs = _env.step(action)
|
| 128 |
+
r = _obs.reward or 0.0; _rewards.append(r); _history.append({"step": _step_num, "action": f"[H] {action_text}", "reward": r, "done": _obs.done})
|
| 129 |
+
status = f"[Heuristic] Step {_step_num}: `{action_text}` → **{r:+.2f}**"
|
| 130 |
+
if _obs.done: status += f"\n\n🏁 Done! Score: **{score_episode(_env):.4f}**"
|
| 131 |
+
return _render_state(_obs), _render_log(_history), _render_score(_rewards, _env), status
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def llm_step(provider, api_key, model):
|
| 135 |
+
global _obs, _rewards, _history, _step_num
|
| 136 |
+
if _env is None: return _render_state(None), _render_log([]), "", "⚠ Reset first.", ""
|
| 137 |
+
if _obs.done: return _render_state(_obs), _render_log(_history), _render_score(_rewards, _env), "✅ Episode done.", ""
|
| 138 |
+
try: client = _llm_client(provider, api_key)
|
| 139 |
+
except Exception as e: return _render_state(_obs), _render_log(_history), "", f"⚠ {e}", ""
|
| 140 |
+
action_text = _call_llm(client, model, _obs, _rewards[-1] if _rewards else 0.0, _step_num + 1)
|
| 141 |
+
action = _parse(action_text)
|
| 142 |
+
if action.action_type == "wait" and (_obs.pending_jobs or any(m.status == "broken" for m in _obs.machines)):
|
| 143 |
+
action, action_text = _heuristic(_obs)
|
| 144 |
+
action_text = f"[fallback] {action_text}"
|
| 145 |
+
_step_num += 1
|
| 146 |
+
_obs = _env.step(action)
|
| 147 |
+
r = _obs.reward or 0.0; _rewards.append(r); _history.append({"step": _step_num, "action": f"[LLM] {action_text}", "reward": r, "done": _obs.done})
|
| 148 |
+
status = f"[LLM] Step {_step_num}: `{action_text}` → **{r:+.2f}**"
|
| 149 |
+
if _obs.done: status += f"\n\n🏁 Done! Score: **{score_episode(_env):.4f}**"
|
| 150 |
+
return _render_state(_obs), _render_log(_history), _render_score(_rewards, _env), status, action_text
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def run_full_episode(provider, api_key, model, task):
|
| 154 |
+
global _env, _obs, _rewards, _history, _step_num
|
| 155 |
+
_env = FactoryEnv(task=task, seed=42); _obs = _env.reset()
|
| 156 |
+
_rewards = []; _history = []; _step_num = 0
|
| 157 |
+
try: client = _llm_client(provider, api_key)
|
| 158 |
+
except Exception as e: return _render_state(_obs), _render_log([]), "", f"⚠ {e}", ""
|
| 159 |
+
log_lines = []
|
| 160 |
+
while not _obs.done and _step_num < _obs.max_steps:
|
| 161 |
+
action_text = _call_llm(client, model, _obs, _rewards[-1] if _rewards else 0.0, _step_num + 1)
|
| 162 |
+
action = _parse(action_text)
|
| 163 |
+
if action.action_type == "wait" and (_obs.pending_jobs or any(m.status == "broken" for m in _obs.machines)):
|
| 164 |
+
action, action_text = _heuristic(_obs); action_text = f"[fb] {action_text}"
|
| 165 |
+
_step_num += 1; _obs = _env.step(action)
|
| 166 |
+
r = _obs.reward or 0.0; _rewards.append(r)
|
| 167 |
+
_history.append({"step": _step_num, "action": action_text, "reward": r, "done": _obs.done})
|
| 168 |
+
log_lines.append(f"Step {_step_num:2d}: {action_text:<35s} r={r:+.2f}")
|
| 169 |
+
s = score_episode(_env)
|
| 170 |
+
status = f"🏁 **Done!** Score: **{s:.4f}** | Completed: {len(_env.completed_jobs)} | Late: {_env.late_jobs}"
|
| 171 |
+
return _render_state(_obs), _render_log(_history), _render_score(_rewards, _env), status, "\n".join(log_lines)
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def build_ui():
|
| 175 |
+
with gr.Blocks(title="Smart Factory RL") as demo:
|
| 176 |
+
gr.Markdown("# 🏭 Smart Factory Scheduling — Interactive RL Demo")
|
| 177 |
+
with gr.Row():
|
| 178 |
+
with gr.Column(scale=1):
|
| 179 |
+
gr.Markdown("### ⚙️ Setup")
|
| 180 |
+
task_dd = gr.Dropdown(["easy","medium","hard"], value="easy", label="Task")
|
| 181 |
+
provider_dd = gr.Dropdown(["OpenAI (GPT)","Claude (Anthropic)","HuggingFace Router"], value="OpenAI (GPT)", label="Provider")
|
| 182 |
+
api_key_box = gr.Textbox(label="API Key", type="password", placeholder="sk-... or sk-ant-...")
|
| 183 |
+
model_box = gr.Textbox(label="Model", value="gpt-4o-mini")
|
| 184 |
+
reset_btn = gr.Button("🔄 Reset", variant="primary")
|
| 185 |
+
gr.Markdown("### 🎮 Manual")
|
| 186 |
+
manual_input = gr.Textbox(label="Action", placeholder="assign_job J1 M1 | repair M2 | wait")
|
| 187 |
+
with gr.Row():
|
| 188 |
+
manual_btn = gr.Button("▶ Execute")
|
| 189 |
+
heuristic_btn = gr.Button("🤖 Heuristic Step")
|
| 190 |
+
gr.Markdown("### 🧠 LLM")
|
| 191 |
+
with gr.Row():
|
| 192 |
+
llm_step_btn = gr.Button("🔮 LLM Step", variant="secondary")
|
| 193 |
+
llm_ep_btn = gr.Button("⚡ Run Full Episode", variant="primary")
|
| 194 |
+
llm_out = gr.Textbox(label="LLM Output", interactive=False)
|
| 195 |
+
status_md = gr.Markdown("*Press Reset to start*")
|
| 196 |
+
with gr.Column(scale=2):
|
| 197 |
+
gr.Markdown("### 🏭 Factory State")
|
| 198 |
+
state_md = gr.Markdown("*Reset to start*")
|
| 199 |
+
gr.Markdown("### 📊 Score")
|
| 200 |
+
score_md = gr.Markdown("")
|
| 201 |
+
gr.Markdown("### 📋 Step Log")
|
| 202 |
+
log_md = gr.Markdown("*No steps yet*")
|
| 203 |
+
reset_btn.click(reset_env, [task_dd], [state_md, log_md, score_md, status_md])
|
| 204 |
+
manual_btn.click(manual_step, [manual_input], [state_md, log_md, score_md, status_md])
|
| 205 |
+
heuristic_btn.click(heuristic_step, [], [state_md, log_md, score_md, status_md])
|
| 206 |
+
llm_step_btn.click(llm_step, [provider_dd, api_key_box, model_box], [state_md, log_md, score_md, status_md, llm_out])
|
| 207 |
+
llm_ep_btn.click(run_full_episode, [provider_dd, api_key_box, model_box, task_dd], [state_md, log_md, score_md, status_md, llm_out])
|
| 208 |
+
return demo
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
if __name__ == "__main__":
|
| 212 |
+
build_ui().launch(server_name="0.0.0.0", server_port=7860, show_error=True, theme=gr.themes.Soft())
|
factory_env/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from factory_env.env import FactoryEnv
|
| 2 |
+
from factory_env.models import FactoryAction, FactoryObservation, FactoryState, Machine, Job
|
factory_env/env.py
CHANGED
|
@@ -1,93 +1,134 @@
|
|
| 1 |
import random
|
| 2 |
-
from typing import List
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import random
|
| 2 |
+
from typing import List, Optional
|
| 3 |
+
|
| 4 |
+
from openenv.core import Environment
|
| 5 |
+
|
| 6 |
+
from factory_env.models import FactoryAction, FactoryObservation, FactoryState, Machine, Job
|
| 7 |
+
from factory_env.tasks import TASKS
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class FactoryEnv(Environment[FactoryAction, FactoryObservation, FactoryState]):
|
| 11 |
+
"""Smart Factory Scheduling Environment — OpenEnv compliant."""
|
| 12 |
+
|
| 13 |
+
SUPPORTS_CONCURRENT_SESSIONS = True
|
| 14 |
+
|
| 15 |
+
def __init__(self, task: str = "easy", seed: int = 42):
|
| 16 |
+
super().__init__()
|
| 17 |
+
if task not in TASKS:
|
| 18 |
+
raise ValueError(f"Unknown task '{task}'. Choose from: {list(TASKS.keys())}")
|
| 19 |
+
self.task = task
|
| 20 |
+
self.seed = seed
|
| 21 |
+
self.config = TASKS[task]
|
| 22 |
+
self._rng = random.Random(seed)
|
| 23 |
+
self.machines: List[Machine] = []
|
| 24 |
+
self.jobs: List[Job] = []
|
| 25 |
+
self.completed_jobs: List[Job] = []
|
| 26 |
+
self.late_jobs: int = 0
|
| 27 |
+
self.time: int = 0
|
| 28 |
+
self.max_steps: int = self.config["max_steps"]
|
| 29 |
+
|
| 30 |
+
def reset(self, seed: Optional[int] = None, episode_id: Optional[str] = None, **kwargs) -> FactoryObservation:
|
| 31 |
+
use_seed = seed if seed is not None else self.seed
|
| 32 |
+
self._rng = random.Random(use_seed)
|
| 33 |
+
self.time = 0
|
| 34 |
+
self.completed_jobs = []
|
| 35 |
+
self.late_jobs = 0
|
| 36 |
+
|
| 37 |
+
cfg = self.config
|
| 38 |
+
self.machines = [
|
| 39 |
+
Machine(id=f"M{i+1}", status="idle", failure_rate=cfg.get("failure_rate", 0.0))
|
| 40 |
+
for i in range(cfg["num_machines"])
|
| 41 |
+
]
|
| 42 |
+
self.jobs = []
|
| 43 |
+
for i in range(cfg["num_jobs"]):
|
| 44 |
+
proc_time = self._rng.randint(*cfg["job_time_range"])
|
| 45 |
+
deadline = self.time + proc_time + self._rng.randint(*cfg["deadline_slack"])
|
| 46 |
+
priority = self._rng.randint(1, cfg.get("max_priority", 1))
|
| 47 |
+
self.jobs.append(Job(id=f"J{i+1}", remaining_time=proc_time, deadline=deadline, priority=priority))
|
| 48 |
+
|
| 49 |
+
return self._make_obs(reward=None, done=False)
|
| 50 |
+
|
| 51 |
+
def step(self, action: FactoryAction, timeout_s: Optional[float] = None, **kwargs) -> FactoryObservation:
|
| 52 |
+
reward = 0.0
|
| 53 |
+
|
| 54 |
+
if action.action_type == "assign_job":
|
| 55 |
+
job = self._find_job(action.job_id)
|
| 56 |
+
machine = self._find_machine(action.machine_id)
|
| 57 |
+
if job is None or machine is None or machine.status != "idle":
|
| 58 |
+
reward -= 0.1
|
| 59 |
+
else:
|
| 60 |
+
job.assigned_machine = machine.id
|
| 61 |
+
machine.status = "busy"
|
| 62 |
+
machine.current_job = job.id
|
| 63 |
+
reward += 0.1
|
| 64 |
+
elif action.action_type == "repair":
|
| 65 |
+
machine = self._find_machine(action.machine_id)
|
| 66 |
+
if machine and machine.status == "broken":
|
| 67 |
+
machine.status = "idle"
|
| 68 |
+
reward += 0.05
|
| 69 |
+
else:
|
| 70 |
+
reward -= 0.05
|
| 71 |
+
|
| 72 |
+
self.time += 1
|
| 73 |
+
|
| 74 |
+
for machine in self.machines:
|
| 75 |
+
if machine.status == "busy":
|
| 76 |
+
job = self._find_job(machine.current_job)
|
| 77 |
+
if job:
|
| 78 |
+
job.remaining_time -= 1
|
| 79 |
+
if job.remaining_time <= 0:
|
| 80 |
+
on_time = self.time <= job.deadline
|
| 81 |
+
reward += (1.0 + 0.2 * job.priority) if on_time else 0.3
|
| 82 |
+
if not on_time:
|
| 83 |
+
self.late_jobs += 1
|
| 84 |
+
self.jobs.remove(job)
|
| 85 |
+
self.completed_jobs.append(job)
|
| 86 |
+
machine.status = "idle"
|
| 87 |
+
machine.current_job = None
|
| 88 |
+
|
| 89 |
+
if machine.status == "busy" and machine.failure_rate > 0:
|
| 90 |
+
if self._rng.random() < machine.failure_rate:
|
| 91 |
+
machine.status = "broken"
|
| 92 |
+
stalled = self._find_job(machine.current_job)
|
| 93 |
+
if stalled:
|
| 94 |
+
stalled.assigned_machine = None
|
| 95 |
+
machine.current_job = None
|
| 96 |
+
|
| 97 |
+
if self.jobs:
|
| 98 |
+
reward -= sum(1 for m in self.machines if m.status == "idle") * 0.05
|
| 99 |
+
for job in self.jobs:
|
| 100 |
+
if self.time > job.deadline:
|
| 101 |
+
reward -= 0.1
|
| 102 |
+
|
| 103 |
+
done = self.time >= self.max_steps or len(self.jobs) == 0
|
| 104 |
+
return self._make_obs(reward=reward, done=done)
|
| 105 |
+
|
| 106 |
+
@property
|
| 107 |
+
def state(self) -> FactoryState:
|
| 108 |
+
return FactoryState(
|
| 109 |
+
machines=list(self.machines),
|
| 110 |
+
pending_jobs=list(self.jobs),
|
| 111 |
+
completed_jobs=list(self.completed_jobs),
|
| 112 |
+
time=self.time,
|
| 113 |
+
task=self.task,
|
| 114 |
+
late_jobs=self.late_jobs,
|
| 115 |
+
step_count=self.time,
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
def _make_obs(self, reward, done: bool) -> FactoryObservation:
|
| 119 |
+
return FactoryObservation(
|
| 120 |
+
machines=list(self.machines),
|
| 121 |
+
pending_jobs=list(self.jobs),
|
| 122 |
+
completed_jobs=list(self.completed_jobs),
|
| 123 |
+
time=self.time,
|
| 124 |
+
max_steps=self.max_steps,
|
| 125 |
+
task=self.task,
|
| 126 |
+
reward=reward,
|
| 127 |
+
done=done,
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
def _find_job(self, job_id: Optional[str]) -> Optional[Job]:
|
| 131 |
+
return next((j for j in self.jobs if j.id == job_id), None) if job_id else None
|
| 132 |
+
|
| 133 |
+
def _find_machine(self, machine_id: Optional[str]) -> Optional[Machine]:
|
| 134 |
+
return next((m for m in self.machines if m.id == machine_id), None) if machine_id else None
|
factory_env/grader.py
CHANGED
|
@@ -1,3 +1,15 @@
|
|
| 1 |
-
def compute_score(
|
| 2 |
-
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
def compute_score(completed, on_time, total_jobs, late_jobs, task="easy"):
|
| 2 |
+
if total_jobs == 0:
|
| 3 |
+
return 0.0
|
| 4 |
+
completion_rate = completed / total_jobs
|
| 5 |
+
on_time_rate = on_time / max(completed, 1)
|
| 6 |
+
utilization_bonus = max(0.0, 1.0 - late_jobs / max(completed, 1))
|
| 7 |
+
score = 0.5 * completion_rate + 0.3 * on_time_rate + 0.2 * utilization_bonus
|
| 8 |
+
return round(max(0.0, min(1.0, score)), 4)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def score_episode(env) -> float:
|
| 12 |
+
total = len(env.completed_jobs) + len(env.jobs)
|
| 13 |
+
completed = len(env.completed_jobs)
|
| 14 |
+
on_time = sum(1 for j in env.completed_jobs if env.time <= j.deadline)
|
| 15 |
+
return compute_score(completed, on_time, total, env.late_jobs, env.task)
|
factory_env/models.py
CHANGED
|
@@ -1,26 +1,55 @@
|
|
| 1 |
-
from pydantic import BaseModel
|
| 2 |
from typing import List, Optional
|
|
|
|
|
|
|
|
|
|
| 3 |
|
| 4 |
class Machine(BaseModel):
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
class Job(BaseModel):
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from typing import List, Optional
|
| 2 |
+
from pydantic import BaseModel, ConfigDict, Field
|
| 3 |
+
from openenv.core import Action as BaseAction, Observation as BaseObservation, State as BaseState
|
| 4 |
+
|
| 5 |
|
| 6 |
class Machine(BaseModel):
|
| 7 |
+
model_config = ConfigDict(extra="forbid")
|
| 8 |
+
id: str
|
| 9 |
+
status: str # idle | busy | broken
|
| 10 |
+
current_job: Optional[str] = None
|
| 11 |
+
failure_rate: float = 0.0
|
| 12 |
+
|
| 13 |
|
| 14 |
class Job(BaseModel):
|
| 15 |
+
model_config = ConfigDict(extra="forbid")
|
| 16 |
+
id: str
|
| 17 |
+
remaining_time: int
|
| 18 |
+
deadline: int
|
| 19 |
+
priority: int = 1
|
| 20 |
+
assigned_machine: Optional[str] = None
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class FactoryAction(BaseAction):
|
| 24 |
+
"""
|
| 25 |
+
action_type: assign_job | repair | wait
|
| 26 |
+
job_id: required for assign_job
|
| 27 |
+
machine_id: required for assign_job / repair
|
| 28 |
+
"""
|
| 29 |
+
action_type: str
|
| 30 |
+
job_id: Optional[str] = None
|
| 31 |
+
machine_id: Optional[str] = None
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class FactoryObservation(BaseObservation):
|
| 35 |
+
"""Inherits: done (bool), reward (float|None), metadata (dict)"""
|
| 36 |
+
machines: List[Machine] = Field(default_factory=list)
|
| 37 |
+
pending_jobs: List[Job] = Field(default_factory=list)
|
| 38 |
+
completed_jobs: List[Job] = Field(default_factory=list)
|
| 39 |
+
time: int = 0
|
| 40 |
+
max_steps: int = 20
|
| 41 |
+
task: str = "easy"
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class FactoryState(BaseState):
|
| 45 |
+
machines: List[Machine] = Field(default_factory=list)
|
| 46 |
+
pending_jobs: List[Job] = Field(default_factory=list)
|
| 47 |
+
completed_jobs: List[Job] = Field(default_factory=list)
|
| 48 |
+
time: int = 0
|
| 49 |
+
task: str = "easy"
|
| 50 |
+
late_jobs: int = 0
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
# Aliases for backward compatibility
|
| 54 |
+
Action = FactoryAction
|
| 55 |
+
Observation = FactoryObservation
|
factory_env/tasks.py
CHANGED
|
@@ -1,17 +1,35 @@
|
|
| 1 |
TASKS = {
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
}
|
|
|
|
| 1 |
TASKS = {
|
| 2 |
+
"easy": {
|
| 3 |
+
"num_machines": 2,
|
| 4 |
+
"num_jobs": 3,
|
| 5 |
+
"failures": False,
|
| 6 |
+
"failure_rate": 0.0,
|
| 7 |
+
"max_priority": 1,
|
| 8 |
+
"job_time_range": (2, 5),
|
| 9 |
+
"deadline_slack": (4, 8),
|
| 10 |
+
"max_steps": 20,
|
| 11 |
+
"description": "Assign 3 jobs to 2 machines with no failures.",
|
| 12 |
+
},
|
| 13 |
+
"medium": {
|
| 14 |
+
"num_machines": 4,
|
| 15 |
+
"num_jobs": 7,
|
| 16 |
+
"failures": True,
|
| 17 |
+
"failure_rate": 0.08,
|
| 18 |
+
"max_priority": 2,
|
| 19 |
+
"job_time_range": (3, 7),
|
| 20 |
+
"deadline_slack": (2, 5),
|
| 21 |
+
"max_steps": 30,
|
| 22 |
+
"description": "Manage 7 jobs across 4 machines with random breakdowns.",
|
| 23 |
+
},
|
| 24 |
+
"hard": {
|
| 25 |
+
"num_machines": 6,
|
| 26 |
+
"num_jobs": 12,
|
| 27 |
+
"failures": True,
|
| 28 |
+
"failure_rate": 0.15,
|
| 29 |
+
"max_priority": 3,
|
| 30 |
+
"job_time_range": (3, 8),
|
| 31 |
+
"deadline_slack": (1, 4),
|
| 32 |
+
"max_steps": 40,
|
| 33 |
+
"description": "Optimize throughput across 12 jobs and 6 machines under frequent failures.",
|
| 34 |
+
},
|
| 35 |
}
|
inference.py
CHANGED
|
@@ -1,248 +1,144 @@
|
|
| 1 |
"""
|
| 2 |
-
|
| 3 |
-
===================================
|
| 4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
"""
|
| 6 |
|
| 7 |
-
import asyncio
|
| 8 |
import os
|
| 9 |
import textwrap
|
| 10 |
-
from typing import List, Optional
|
| 11 |
|
| 12 |
from openai import OpenAI
|
| 13 |
|
| 14 |
from factory_env.env import FactoryEnv
|
| 15 |
-
from factory_env.models import Action
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
|
|
|
|
|
|
| 21 |
API_BASE_URL = os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1"
|
| 22 |
MODEL_NAME = os.getenv("MODEL_NAME") or "Qwen/Qwen2.5-72B-Instruct"
|
| 23 |
-
|
| 24 |
TASK_NAME = os.getenv("FACTORY_TASK", "easy")
|
| 25 |
BENCHMARK = "factory_env"
|
| 26 |
-
|
| 27 |
-
MAX_STEPS = 20
|
| 28 |
TEMPERATURE = 0.2
|
| 29 |
-
MAX_TOKENS =
|
| 30 |
SUCCESS_SCORE_THRESHOLD = 0.5
|
| 31 |
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
Available actions:
|
| 45 |
-
1. assign_job <job_id> <machine_id>
|
| 46 |
-
2. wait
|
| 47 |
-
|
| 48 |
-
Rules:
|
| 49 |
-
- Only assign jobs that exist
|
| 50 |
-
- Only assign to idle machines
|
| 51 |
-
- One action per step
|
| 52 |
-
|
| 53 |
-
Respond ONLY with the action string.
|
| 54 |
-
Example:
|
| 55 |
-
assign_job J1 M1
|
| 56 |
-
"""
|
| 57 |
-
).strip()
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
# =========================
|
| 61 |
-
# LOGGING FUNCTIONS (STRICT FORMAT)
|
| 62 |
-
# =========================
|
| 63 |
def log_start(task: str, env: str, model: str) -> None:
|
| 64 |
print(f"[START] task={task} env={env} model={model}", flush=True)
|
| 65 |
|
| 66 |
|
| 67 |
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
|
| 68 |
-
|
| 69 |
-
done_val = str(done).lower()
|
| 70 |
-
print(
|
| 71 |
-
f"[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}",
|
| 72 |
-
flush=True,
|
| 73 |
-
)
|
| 74 |
|
| 75 |
|
| 76 |
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
|
| 77 |
-
|
| 78 |
-
print(
|
| 79 |
-
f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_str}",
|
| 80 |
-
flush=True,
|
| 81 |
-
)
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
# =========================
|
| 85 |
-
# PROMPT BUILDER
|
| 86 |
-
# =========================
|
| 87 |
-
def build_user_prompt(step, obs, last_reward):
|
| 88 |
-
machines_str = "\n".join(
|
| 89 |
-
[f"{m.id}: {m.status} (job={m.current_job})" for m in obs.machines]
|
| 90 |
-
)
|
| 91 |
-
|
| 92 |
-
jobs_str = "\n".join(
|
| 93 |
-
[f"{j.id}: remaining={j.remaining_time}, deadline={j.deadline}" for j in obs.pending_jobs]
|
| 94 |
-
) or "None"
|
| 95 |
-
|
| 96 |
-
return textwrap.dedent(
|
| 97 |
-
f"""
|
| 98 |
-
Step: {step}
|
| 99 |
|
| 100 |
-
Current Time: {obs.time}
|
| 101 |
|
| 102 |
-
|
| 103 |
-
|
|
|
|
|
|
|
| 104 |
|
| 105 |
-
Pending Jobs:
|
| 106 |
-
{jobs_str}
|
| 107 |
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
What action do you take?
|
| 111 |
-
"""
|
| 112 |
-
).strip()
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
# =========================
|
| 116 |
-
# LLM CALL
|
| 117 |
-
# =========================
|
| 118 |
-
def get_model_action(client: OpenAI, step, obs, last_reward) -> str:
|
| 119 |
try:
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
completion = client.chat.completions.create(
|
| 123 |
model=MODEL_NAME,
|
| 124 |
-
messages=[
|
| 125 |
-
{"role": "system", "content": SYSTEM_PROMPT},
|
| 126 |
-
{"role": "user", "content": user_prompt},
|
| 127 |
-
],
|
| 128 |
temperature=TEMPERATURE,
|
| 129 |
max_tokens=MAX_TOKENS,
|
| 130 |
)
|
| 131 |
-
|
| 132 |
-
text = (completion.choices[0].message.content or "").strip()
|
| 133 |
-
return text if text else "wait"
|
| 134 |
-
|
| 135 |
except Exception as e:
|
| 136 |
print(f"[DEBUG] LLM error: {e}", flush=True)
|
| 137 |
return "wait"
|
| 138 |
|
| 139 |
|
| 140 |
-
# =========================
|
| 141 |
-
# ACTION PARSER
|
| 142 |
-
# =========================
|
| 143 |
def parse_action(text: str) -> Action:
|
| 144 |
try:
|
| 145 |
parts = text.strip().split()
|
| 146 |
-
|
| 147 |
if parts[0] == "assign_job" and len(parts) == 3:
|
| 148 |
-
return Action(
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
machine_id=parts[2],
|
| 152 |
-
)
|
| 153 |
-
|
| 154 |
-
elif parts[0] == "wait":
|
| 155 |
-
return Action(action_type="wait")
|
| 156 |
-
|
| 157 |
except Exception:
|
| 158 |
pass
|
| 159 |
-
|
| 160 |
-
# fallback safe action
|
| 161 |
return Action(action_type="wait")
|
| 162 |
|
| 163 |
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
for
|
| 169 |
-
for
|
| 170 |
-
if
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
machine_id=machine.id,
|
| 175 |
-
)
|
| 176 |
-
return Action(action_type="wait")
|
| 177 |
|
| 178 |
|
| 179 |
-
|
| 180 |
-
# MAIN LOOP
|
| 181 |
-
# =========================
|
| 182 |
-
async def main():
|
| 183 |
client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
|
| 184 |
-
|
| 185 |
-
env = FactoryEnv(task=TASK_NAME)
|
| 186 |
-
|
| 187 |
rewards: List[float] = []
|
| 188 |
steps_taken = 0
|
| 189 |
score = 0.0
|
| 190 |
success = False
|
| 191 |
|
| 192 |
-
log_start(task=
|
| 193 |
|
| 194 |
try:
|
| 195 |
-
|
| 196 |
-
obs = result.observation
|
| 197 |
last_reward = 0.0
|
| 198 |
|
| 199 |
-
for step in range(1,
|
| 200 |
-
if
|
| 201 |
break
|
| 202 |
-
|
| 203 |
-
# LLM decision
|
| 204 |
action_text = get_model_action(client, step, obs, last_reward)
|
| 205 |
-
|
| 206 |
-
# Parse action
|
| 207 |
action = parse_action(action_text)
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
action_text = "heuristic_assign"
|
| 213 |
-
|
| 214 |
-
# Step env
|
| 215 |
-
result = await env.step(action)
|
| 216 |
-
|
| 217 |
-
obs = result.observation
|
| 218 |
-
reward = result.reward or 0.0
|
| 219 |
-
done = result.done
|
| 220 |
-
error = None
|
| 221 |
-
|
| 222 |
rewards.append(reward)
|
| 223 |
steps_taken = step
|
| 224 |
last_reward = reward
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
if done:
|
| 229 |
break
|
| 230 |
|
| 231 |
-
|
| 232 |
-
if rewards:
|
| 233 |
-
score = sum(rewards) / len(rewards)
|
| 234 |
-
score = max(0.0, min(1.0, score))
|
| 235 |
-
|
| 236 |
success = score >= SUCCESS_SCORE_THRESHOLD
|
| 237 |
-
|
| 238 |
finally:
|
| 239 |
-
try:
|
| 240 |
-
await env.close()
|
| 241 |
-
except Exception as e:
|
| 242 |
-
print(f"[DEBUG] env.close error: {e}", flush=True)
|
| 243 |
-
|
| 244 |
log_end(success, steps_taken, score, rewards)
|
| 245 |
|
| 246 |
|
| 247 |
if __name__ == "__main__":
|
| 248 |
-
|
|
|
|
| 1 |
"""
|
| 2 |
+
Inference Script — Smart Factory Scheduling Environment
|
| 3 |
+
========================================================
|
| 4 |
+
Mandatory env vars (per hackathon spec):
|
| 5 |
+
OPENAI_API_KEY API key (also accepts HF_TOKEN for HF router)
|
| 6 |
+
API_BASE_URL LLM endpoint (default: HF router)
|
| 7 |
+
MODEL_NAME Model ID (default: Qwen/Qwen2.5-72B-Instruct)
|
| 8 |
+
HF_TOKEN HuggingFace token
|
| 9 |
+
FACTORY_TASK easy | medium | hard (default: easy)
|
| 10 |
+
|
| 11 |
+
STDOUT FORMAT:
|
| 12 |
+
[START] task=<name> env=factory_env model=<model>
|
| 13 |
+
[STEP] step=<n> action=<str> reward=<0.00> done=<true|false> error=<msg|null>
|
| 14 |
+
[END] success=<true|false> steps=<n> score=<0.000> rewards=<r1,r2,...>
|
| 15 |
"""
|
| 16 |
|
|
|
|
| 17 |
import os
|
| 18 |
import textwrap
|
| 19 |
+
from typing import List, Optional, Tuple
|
| 20 |
|
| 21 |
from openai import OpenAI
|
| 22 |
|
| 23 |
from factory_env.env import FactoryEnv
|
| 24 |
+
from factory_env.models import FactoryAction as Action
|
| 25 |
+
from factory_env.grader import score_episode
|
| 26 |
+
|
| 27 |
+
API_KEY = (
|
| 28 |
+
os.getenv("OPENAI_API_KEY")
|
| 29 |
+
or os.getenv("HF_TOKEN")
|
| 30 |
+
or os.getenv("API_KEY")
|
| 31 |
+
)
|
| 32 |
API_BASE_URL = os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1"
|
| 33 |
MODEL_NAME = os.getenv("MODEL_NAME") or "Qwen/Qwen2.5-72B-Instruct"
|
|
|
|
| 34 |
TASK_NAME = os.getenv("FACTORY_TASK", "easy")
|
| 35 |
BENCHMARK = "factory_env"
|
|
|
|
|
|
|
| 36 |
TEMPERATURE = 0.2
|
| 37 |
+
MAX_TOKENS = 80
|
| 38 |
SUCCESS_SCORE_THRESHOLD = 0.5
|
| 39 |
|
| 40 |
+
SYSTEM_PROMPT = textwrap.dedent("""
|
| 41 |
+
You are controlling a smart factory scheduling system.
|
| 42 |
+
Goal: complete all jobs before their deadlines, keep machines busy, repair broken machines.
|
| 43 |
+
Actions (respond with EXACTLY one line):
|
| 44 |
+
assign_job <job_id> <machine_id>
|
| 45 |
+
repair <machine_id>
|
| 46 |
+
wait
|
| 47 |
+
Respond with ONLY the action string.
|
| 48 |
+
""").strip()
|
| 49 |
+
|
| 50 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
def log_start(task: str, env: str, model: str) -> None:
|
| 52 |
print(f"[START] task={task} env={env} model={model}", flush=True)
|
| 53 |
|
| 54 |
|
| 55 |
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
|
| 56 |
+
print(f"[STEP] step={step} action={action} reward={reward:.2f} done={str(done).lower()} error={error or 'null'}", flush=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
|
| 58 |
|
| 59 |
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
|
| 60 |
+
print(f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={','.join(f'{r:.2f}' for r in rewards)}", flush=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
|
|
|
|
| 62 |
|
| 63 |
+
def build_prompt(step: int, obs, last_reward: float) -> str:
|
| 64 |
+
machines = "\n".join(f" {m.id}: {m.status}" + (f" ({m.current_job})" if m.current_job else "") for m in obs.machines)
|
| 65 |
+
jobs = "\n".join(f" {j.id}: remaining={j.remaining_time}, deadline={j.deadline}, priority={j.priority}" for j in obs.pending_jobs) or " (none)"
|
| 66 |
+
return f"Step {step}/{obs.max_steps} | time={obs.time} | last_reward={last_reward:+.2f}\nMachines:\n{machines}\nPending Jobs:\n{jobs}\nAction:"
|
| 67 |
|
|
|
|
|
|
|
| 68 |
|
| 69 |
+
def get_model_action(client: OpenAI, step: int, obs, last_reward: float) -> str:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
try:
|
| 71 |
+
resp = client.chat.completions.create(
|
|
|
|
|
|
|
| 72 |
model=MODEL_NAME,
|
| 73 |
+
messages=[{"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": build_prompt(step, obs, last_reward)}],
|
|
|
|
|
|
|
|
|
|
| 74 |
temperature=TEMPERATURE,
|
| 75 |
max_tokens=MAX_TOKENS,
|
| 76 |
)
|
| 77 |
+
return (resp.choices[0].message.content or "wait").strip().splitlines()[0]
|
|
|
|
|
|
|
|
|
|
| 78 |
except Exception as e:
|
| 79 |
print(f"[DEBUG] LLM error: {e}", flush=True)
|
| 80 |
return "wait"
|
| 81 |
|
| 82 |
|
|
|
|
|
|
|
|
|
|
| 83 |
def parse_action(text: str) -> Action:
|
| 84 |
try:
|
| 85 |
parts = text.strip().split()
|
|
|
|
| 86 |
if parts[0] == "assign_job" and len(parts) == 3:
|
| 87 |
+
return Action(action_type="assign_job", job_id=parts[1], machine_id=parts[2])
|
| 88 |
+
if parts[0] == "repair" and len(parts) == 2:
|
| 89 |
+
return Action(action_type="repair", machine_id=parts[1])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
except Exception:
|
| 91 |
pass
|
|
|
|
|
|
|
| 92 |
return Action(action_type="wait")
|
| 93 |
|
| 94 |
|
| 95 |
+
def heuristic_action(obs) -> Tuple[Action, str]:
|
| 96 |
+
for m in obs.machines:
|
| 97 |
+
if m.status == "broken":
|
| 98 |
+
return Action(action_type="repair", machine_id=m.id), f"repair {m.id}"
|
| 99 |
+
for j in sorted(obs.pending_jobs, key=lambda x: (x.deadline, -x.priority)):
|
| 100 |
+
for m in obs.machines:
|
| 101 |
+
if m.status == "idle":
|
| 102 |
+
s = f"assign_job {j.id} {m.id}"
|
| 103 |
+
return Action(action_type="assign_job", job_id=j.id, machine_id=m.id), s
|
| 104 |
+
return Action(action_type="wait"), "wait"
|
|
|
|
|
|
|
|
|
|
| 105 |
|
| 106 |
|
| 107 |
+
def run_task(task_name: str) -> None:
|
|
|
|
|
|
|
|
|
|
| 108 |
client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
|
| 109 |
+
env = FactoryEnv(task=task_name)
|
|
|
|
|
|
|
| 110 |
rewards: List[float] = []
|
| 111 |
steps_taken = 0
|
| 112 |
score = 0.0
|
| 113 |
success = False
|
| 114 |
|
| 115 |
+
log_start(task=task_name, env=BENCHMARK, model=MODEL_NAME)
|
| 116 |
|
| 117 |
try:
|
| 118 |
+
obs = env.reset()
|
|
|
|
| 119 |
last_reward = 0.0
|
| 120 |
|
| 121 |
+
for step in range(1, obs.max_steps + 1):
|
| 122 |
+
if obs.done:
|
| 123 |
break
|
|
|
|
|
|
|
| 124 |
action_text = get_model_action(client, step, obs, last_reward)
|
|
|
|
|
|
|
| 125 |
action = parse_action(action_text)
|
| 126 |
+
if action.action_type == "wait" and (obs.pending_jobs or any(m.status == "broken" for m in obs.machines)):
|
| 127 |
+
action, action_text = heuristic_action(obs)
|
| 128 |
+
obs = env.step(action)
|
| 129 |
+
reward = obs.reward or 0.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
rewards.append(reward)
|
| 131 |
steps_taken = step
|
| 132 |
last_reward = reward
|
| 133 |
+
log_step(step, action_text, reward, obs.done, None)
|
| 134 |
+
if obs.done:
|
|
|
|
|
|
|
| 135 |
break
|
| 136 |
|
| 137 |
+
score = score_episode(env)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
success = score >= SUCCESS_SCORE_THRESHOLD
|
|
|
|
| 139 |
finally:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 140 |
log_end(success, steps_taken, score, rewards)
|
| 141 |
|
| 142 |
|
| 143 |
if __name__ == "__main__":
|
| 144 |
+
run_task(TASK_NAME)
|
requirements.txt
CHANGED
|
@@ -1,3 +1,7 @@
|
|
| 1 |
-
pydantic
|
| 2 |
-
openai
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
pydantic>=2.0
|
| 2 |
+
openai>=1.0
|
| 3 |
+
anthropic>=0.90
|
| 4 |
+
gradio>=6.0
|
| 5 |
+
openenv-core>=0.2.3
|
| 6 |
+
fastapi>=0.100
|
| 7 |
+
uvicorn>=0.23
|
server.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
OpenEnv HTTP Server — Smart Factory Scheduling
|
| 3 |
+
Routes: GET /health POST /reset POST /step GET /state GET /schema
|
| 4 |
+
"""
|
| 5 |
+
import os
|
| 6 |
+
from openenv.core import create_app
|
| 7 |
+
from factory_env.env import FactoryEnv
|
| 8 |
+
from factory_env.models import FactoryAction, FactoryObservation
|
| 9 |
+
|
| 10 |
+
TASK = os.getenv("FACTORY_TASK", "easy")
|
| 11 |
+
|
| 12 |
+
app = create_app(
|
| 13 |
+
env=lambda: FactoryEnv(task=TASK, seed=42),
|
| 14 |
+
action_cls=FactoryAction,
|
| 15 |
+
observation_cls=FactoryObservation,
|
| 16 |
+
env_name="factory_env",
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
if __name__ == "__main__":
|
| 20 |
+
import uvicorn
|
| 21 |
+
uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", 7860)))
|
train.py
ADDED
|
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
RL Training Loop — Smart Factory Scheduling
|
| 3 |
+
============================================
|
| 4 |
+
Strategy: Online In-Context RL — best trajectory fed as few-shot example each episode.
|
| 5 |
+
|
| 6 |
+
Usage:
|
| 7 |
+
export OPENAI_API_KEY=sk-... # OpenAI
|
| 8 |
+
export ANTHROPIC_API_KEY=sk-ant-... # Claude
|
| 9 |
+
python train.py --task easy --episodes 10 --provider openai
|
| 10 |
+
python train.py --task medium --episodes 10 --provider claude
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import argparse
|
| 14 |
+
import json
|
| 15 |
+
import os
|
| 16 |
+
import time
|
| 17 |
+
from dataclasses import dataclass, field
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
from typing import List, Optional, Tuple
|
| 20 |
+
|
| 21 |
+
from factory_env.env import FactoryEnv
|
| 22 |
+
from factory_env.grader import score_episode
|
| 23 |
+
from factory_env.models import FactoryAction as Action
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def get_openai_client():
|
| 27 |
+
from openai import OpenAI
|
| 28 |
+
key = os.getenv("OPENAI_API_KEY") or os.getenv("HF_TOKEN") or os.getenv("API_KEY")
|
| 29 |
+
base = os.getenv("API_BASE_URL") or "https://api.openai.com/v1"
|
| 30 |
+
return OpenAI(api_key=key, base_url=base)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def get_claude_client():
|
| 34 |
+
import anthropic
|
| 35 |
+
return anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@dataclass
|
| 39 |
+
class Step:
|
| 40 |
+
step: int
|
| 41 |
+
obs_text: str
|
| 42 |
+
action_text: str
|
| 43 |
+
reward: float
|
| 44 |
+
done: bool
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
@dataclass
|
| 48 |
+
class Episode:
|
| 49 |
+
episode_num: int
|
| 50 |
+
task: str
|
| 51 |
+
steps: List[Step] = field(default_factory=list)
|
| 52 |
+
total_reward: float = 0.0
|
| 53 |
+
score: float = 0.0
|
| 54 |
+
completed: int = 0
|
| 55 |
+
late: int = 0
|
| 56 |
+
|
| 57 |
+
def to_few_shot(self, max_steps: int = 6) -> str:
|
| 58 |
+
lines = [f"# Best trajectory so far (score={self.score:.2f}, completed={self.completed} jobs)"]
|
| 59 |
+
for s in self.steps[:max_steps]:
|
| 60 |
+
lines.append(f"[Obs] {s.obs_text}")
|
| 61 |
+
lines.append(f"[Action] {s.action_text} → reward: {s.reward:+.2f}")
|
| 62 |
+
return "\n".join(lines)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
SYSTEM_PROMPT = """You are an expert factory scheduling AI.
|
| 66 |
+
Goal: complete all jobs before deadlines, keep machines busy, repair broken machines.
|
| 67 |
+
Actions (one per step):
|
| 68 |
+
assign_job <job_id> <machine_id>
|
| 69 |
+
repair <machine_id>
|
| 70 |
+
wait
|
| 71 |
+
Tips: Fix broken machines first. Sort by earliest deadline. High-priority jobs give bonus reward."""
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def obs_to_text(obs) -> str:
|
| 75 |
+
machines = ", ".join(f"{m.id}:{m.status}" + (f"({m.current_job})" if m.current_job else "") for m in obs.machines)
|
| 76 |
+
jobs = ", ".join(f"{j.id}[t={j.remaining_time},dl={j.deadline},p={j.priority}]" for j in obs.pending_jobs) or "none"
|
| 77 |
+
return f"t={obs.time} | machines: {machines} | pending: {jobs}"
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def call_llm(messages: list, provider: str, client, model: str) -> str:
|
| 81 |
+
try:
|
| 82 |
+
if provider == "claude":
|
| 83 |
+
system = next((m["content"] for m in messages if m["role"] == "system"), "")
|
| 84 |
+
user_msgs = [m for m in messages if m["role"] != "system"]
|
| 85 |
+
resp = client.messages.create(model=model, max_tokens=60, system=system, messages=user_msgs)
|
| 86 |
+
return resp.content[0].text.strip().splitlines()[0]
|
| 87 |
+
else:
|
| 88 |
+
resp = client.chat.completions.create(model=model, messages=messages, temperature=0.2, max_tokens=60)
|
| 89 |
+
return (resp.choices[0].message.content or "wait").strip().splitlines()[0]
|
| 90 |
+
except Exception as e:
|
| 91 |
+
print(f" [LLM error] {e}")
|
| 92 |
+
return "wait"
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def parse_action(text: str) -> Action:
|
| 96 |
+
try:
|
| 97 |
+
parts = text.strip().split()
|
| 98 |
+
if parts[0] == "assign_job" and len(parts) == 3:
|
| 99 |
+
return Action(action_type="assign_job", job_id=parts[1], machine_id=parts[2])
|
| 100 |
+
if parts[0] == "repair" and len(parts) == 2:
|
| 101 |
+
return Action(action_type="repair", machine_id=parts[1])
|
| 102 |
+
except Exception:
|
| 103 |
+
pass
|
| 104 |
+
return Action(action_type="wait")
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def heuristic_action(obs) -> Tuple[Action, str]:
|
| 108 |
+
for m in obs.machines:
|
| 109 |
+
if m.status == "broken":
|
| 110 |
+
return Action(action_type="repair", machine_id=m.id), f"repair {m.id}"
|
| 111 |
+
for j in sorted(obs.pending_jobs, key=lambda x: (x.deadline, -x.priority)):
|
| 112 |
+
for m in obs.machines:
|
| 113 |
+
if m.status == "idle":
|
| 114 |
+
s = f"assign_job {j.id} {m.id}"
|
| 115 |
+
return Action(action_type="assign_job", job_id=j.id, machine_id=m.id), s
|
| 116 |
+
return Action(action_type="wait"), "wait"
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def run_episode(task, episode_num, provider, client, model, best_episode, seed=42, verbose=True) -> Episode:
|
| 120 |
+
env = FactoryEnv(task=task, seed=seed)
|
| 121 |
+
obs = env.reset()
|
| 122 |
+
last_reward = 0.0
|
| 123 |
+
ep = Episode(episode_num=episode_num, task=task)
|
| 124 |
+
|
| 125 |
+
if verbose:
|
| 126 |
+
print(f"\n Episode {episode_num} | task={task} | seed={seed}")
|
| 127 |
+
print(f" {len(obs.machines)} machines, {len(obs.pending_jobs)} jobs, {obs.max_steps} steps")
|
| 128 |
+
|
| 129 |
+
for step in range(1, obs.max_steps + 1):
|
| 130 |
+
if obs.done:
|
| 131 |
+
break
|
| 132 |
+
|
| 133 |
+
obs_text = obs_to_text(obs)
|
| 134 |
+
few_shot = best_episode.to_few_shot() if best_episode and step == 1 else ""
|
| 135 |
+
user = f"{few_shot}\n\n---\n" if few_shot else ""
|
| 136 |
+
user += f"Step {step} | Last reward: {last_reward:+.2f}\n{obs_text}\n\nAction:"
|
| 137 |
+
|
| 138 |
+
messages = [{"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": user}]
|
| 139 |
+
action_text = call_llm(messages, provider, client, model)
|
| 140 |
+
action = parse_action(action_text)
|
| 141 |
+
|
| 142 |
+
if action.action_type == "wait" and (obs.pending_jobs or any(m.status == "broken" for m in obs.machines)):
|
| 143 |
+
action, action_text = heuristic_action(obs)
|
| 144 |
+
|
| 145 |
+
obs = env.step(action)
|
| 146 |
+
reward = obs.reward or 0.0
|
| 147 |
+
last_reward = reward
|
| 148 |
+
ep.steps.append(Step(step, obs_text, action_text, reward, obs.done))
|
| 149 |
+
ep.total_reward += reward
|
| 150 |
+
|
| 151 |
+
if verbose:
|
| 152 |
+
marker = "✓" if reward > 0.5 else ("✗" if reward < -0.05 else "·")
|
| 153 |
+
print(f" [{marker}] step={step:2d} {action_text:<30s} r={reward:+.2f}")
|
| 154 |
+
|
| 155 |
+
if obs.done:
|
| 156 |
+
break
|
| 157 |
+
|
| 158 |
+
ep.score = score_episode(env)
|
| 159 |
+
ep.completed = len(env.completed_jobs)
|
| 160 |
+
ep.late = env.late_jobs
|
| 161 |
+
|
| 162 |
+
if verbose:
|
| 163 |
+
print(f" → score={ep.score:.4f} completed={ep.completed} late={ep.late}")
|
| 164 |
+
|
| 165 |
+
return ep
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def train(task, num_episodes, provider, model, save_dir="runs", verbose=True):
|
| 169 |
+
print(f"\n{'='*60}")
|
| 170 |
+
print(f" Smart Factory RL Training")
|
| 171 |
+
print(f" Task: {task} | Episodes: {num_episodes} | Provider: {provider} | Model: {model}")
|
| 172 |
+
print(f"{'='*60}")
|
| 173 |
+
|
| 174 |
+
client = get_claude_client() if provider == "claude" else get_openai_client()
|
| 175 |
+
Path(save_dir).mkdir(exist_ok=True)
|
| 176 |
+
|
| 177 |
+
scores = []
|
| 178 |
+
best_episode = None
|
| 179 |
+
|
| 180 |
+
for ep_num in range(1, num_episodes + 1):
|
| 181 |
+
ep = run_episode(task, ep_num, provider, client, model, best_episode, seed=42 + ep_num - 1, verbose=verbose)
|
| 182 |
+
scores.append(ep.score)
|
| 183 |
+
if best_episode is None or ep.score > best_episode.score:
|
| 184 |
+
best_episode = ep
|
| 185 |
+
print(f" ★ New best: score={ep.score:.4f}")
|
| 186 |
+
if ep_num < num_episodes:
|
| 187 |
+
time.sleep(1.0)
|
| 188 |
+
|
| 189 |
+
print(f"\n{'='*60}")
|
| 190 |
+
print(f" Training Complete — {num_episodes} episodes | Task: {task}")
|
| 191 |
+
print(f" First: {scores[0]:.4f} | Last: {scores[-1]:.4f} | Best: {max(scores):.4f}")
|
| 192 |
+
print(f"\n Score per episode:")
|
| 193 |
+
for i, s in enumerate(scores, 1):
|
| 194 |
+
print(f" ep{i:02d}: {s:.4f} {'█' * int(s * 20)}")
|
| 195 |
+
|
| 196 |
+
out = Path(save_dir) / f"{task}_{provider}_{num_episodes}ep.json"
|
| 197 |
+
out.write_text(json.dumps({"task": task, "provider": provider, "model": model, "num_episodes": num_episodes, "scores": scores, "best_score": max(scores), "final_score": scores[-1]}, indent=2))
|
| 198 |
+
print(f"\n Results saved → {out}")
|
| 199 |
+
return scores
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def main():
|
| 203 |
+
parser = argparse.ArgumentParser()
|
| 204 |
+
parser.add_argument("--task", default="easy", choices=["easy", "medium", "hard"])
|
| 205 |
+
parser.add_argument("--episodes", type=int, default=5)
|
| 206 |
+
parser.add_argument("--provider", default="openai", choices=["openai", "claude"])
|
| 207 |
+
parser.add_argument("--model", default="")
|
| 208 |
+
parser.add_argument("--save-dir", default="runs")
|
| 209 |
+
parser.add_argument("--quiet", action="store_true")
|
| 210 |
+
args = parser.parse_args()
|
| 211 |
+
if not args.model:
|
| 212 |
+
args.model = "claude-sonnet-4-6" if args.provider == "claude" else "gpt-4o-mini"
|
| 213 |
+
train(args.task, args.episodes, args.provider, args.model, args.save_dir, not args.quiet)
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
if __name__ == "__main__":
|
| 217 |
+
main()
|