Aldrimore commited on
Commit
6195f6a
·
1 Parent(s): bc0dd7f

OpenEnv Submission

Browse files
Files changed (13) hide show
  1. .gitignore +38 -0
  2. Dockerfile +17 -0
  3. README.md +105 -1
  4. app.py +212 -0
  5. factory_env/__init__.py +2 -0
  6. factory_env/env.py +133 -92
  7. factory_env/grader.py +15 -3
  8. factory_env/models.py +50 -21
  9. factory_env/tasks.py +33 -15
  10. inference.py +71 -175
  11. requirements.txt +7 -3
  12. server.py +21 -0
  13. train.py +217 -0
.gitignore ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ *.pyo
5
+ *.egg-info/
6
+ dist/
7
+ build/
8
+
9
+ # Virtual environments
10
+ venv/
11
+ .venv/
12
+ env/
13
+
14
+ # Secrets
15
+ .env
16
+ .env.*
17
+
18
+ # OS
19
+ .DS_Store
20
+ Thumbs.db
21
+
22
+ # IDE
23
+ .vscode/
24
+ .idea/
25
+
26
+ # Logs
27
+ *.log
28
+
29
+ # Training runs
30
+ runs/
31
+
32
+ # Docker
33
+ *.tar
34
+
35
+ # Hackathon docs
36
+ rule.txt
37
+ "Meta RL Hackathon.docx"
38
+ Meta\ RL\ Hackathon.docx
Dockerfile ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.11-slim
2
+
3
+ WORKDIR /app
4
+
5
+ COPY requirements.txt .
6
+ RUN pip install --no-cache-dir -r requirements.txt
7
+
8
+ COPY . .
9
+
10
+ ENV FACTORY_TASK=easy
11
+ ENV API_BASE_URL=https://router.huggingface.co/v1
12
+ ENV MODEL_NAME=Qwen/Qwen2.5-72B-Instruct
13
+ ENV PORT=7860
14
+
15
+ EXPOSE 7860
16
+
17
+ CMD ["python", "server.py"]
README.md CHANGED
@@ -1 +1,105 @@
1
- # OpenEnvRLScheduling
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Smart Factory Scheduling Environment
2
+
3
+ An [OpenEnv](https://github.com/openenv/openenv)-compliant RL environment simulating real-world industrial scheduling: assign jobs to machines, handle breakdowns, and maximise throughput within deadlines.
4
+
5
+ ## Observation Space
6
+
7
+ | Field | Type | Description |
8
+ |-------|------|-------------|
9
+ | `machines` | List[Machine] | id, status (idle/busy/broken), current_job, failure_rate |
10
+ | `pending_jobs` | List[Job] | id, remaining_time, deadline, priority (1-3), assigned_machine |
11
+ | `completed_jobs` | List[Job] | Jobs finished this episode |
12
+ | `time` | int | Current time step |
13
+ | `max_steps` | int | Episode length |
14
+ | `done` | bool | Episode terminated |
15
+ | `reward` | float | Reward from last action |
16
+
17
+ ## Action Space
18
+
19
+ | Action | Effect |
20
+ |--------|--------|
21
+ | `assign_job <job_id> <machine_id>` | Assign pending job to idle machine |
22
+ | `repair <machine_id>` | Restore broken machine to idle |
23
+ | `wait` | Advance time with no change |
24
+
25
+ ## Reward Function
26
+
27
+ | Event | Reward |
28
+ |-------|--------|
29
+ | Job completed on time | +1.00 + 0.20 × priority |
30
+ | Job completed late | +0.30 |
31
+ | Valid assignment | +0.10 |
32
+ | Invalid action | −0.10 |
33
+ | Idle machine (pending jobs exist) | −0.05 per machine |
34
+ | Job past deadline | −0.10 per step |
35
+ | Repair broken machine | +0.05 |
36
+
37
+ ## Tasks
38
+
39
+ | Task | Machines | Jobs | Failure Rate | Max Steps | Baseline Score |
40
+ |------|----------|------|-------------|-----------|----------------|
41
+ | easy | 2 | 3 | 0% | 20 | 1.000 |
42
+ | medium | 4 | 7 | 8% | 30 | ~0.557 |
43
+ | hard | 6 | 12 | 15% | 40 | ~0.457 |
44
+
45
+ **Score formula:** `0.5 × completion_rate + 0.3 × on_time_rate + 0.2 × utilization_bonus`
46
+
47
+ ## Setup
48
+
49
+ ```bash
50
+ pip install -r requirements.txt
51
+ ```
52
+
53
+ ### Run HTTP Server (HF Space)
54
+ ```bash
55
+ python server.py
56
+ # Routes: GET /health POST /reset POST /step GET /state GET /schema
57
+ ```
58
+
59
+ ### Run Inference (LLM agent)
60
+ ```bash
61
+ export OPENAI_API_KEY=<your-key>
62
+ export FACTORY_TASK=easy # easy | medium | hard
63
+ python inference.py
64
+ ```
65
+
66
+ ### Run RL Training
67
+ ```bash
68
+ python train.py --task easy --episodes 10 --provider openai
69
+ python train.py --task medium --episodes 10 --provider claude
70
+ ```
71
+
72
+ ### Interactive Demo
73
+ ```bash
74
+ python app.py # opens at http://localhost:7860
75
+ ```
76
+
77
+ ### Docker
78
+ ```bash
79
+ docker build -t factory-env .
80
+ docker run -e OPENAI_API_KEY=<key> -e FACTORY_TASK=easy -p 7860:7860 factory-env
81
+ ```
82
+
83
+ ## Baseline Scores
84
+
85
+ | Task | Score | Steps |
86
+ |------|-------|-------|
87
+ | easy | 1.000 | 4 |
88
+ | medium | ~0.529 | 12 |
89
+ | hard | ~0.533 | 34 |
90
+
91
+ ## Project Structure
92
+
93
+ ```
94
+ ├── factory_env/
95
+ │ ├── env.py # FactoryEnv (openenv.core.Environment)
96
+ │ ├── models.py # FactoryAction, FactoryObservation, FactoryState
97
+ │ ├── tasks.py # Task configurations
98
+ │ └── grader.py # Score computation
99
+ ├── inference.py # LLM baseline agent
100
+ ├── train.py # Multi-episode RL training loop
101
+ ├── server.py # FastAPI HTTP server for HF Space
102
+ ├── app.py # Gradio interactive demo
103
+ ├── openenv.yaml # OpenEnv metadata
104
+ └── Dockerfile
105
+ ```
app.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Smart Factory Scheduling — Interactive Gradio Demo
3
+ Run: python app.py → http://localhost:7860
4
+ """
5
+ import asyncio, os
6
+ from typing import List, Optional, Tuple
7
+ import gradio as gr
8
+ from factory_env.env import FactoryEnv
9
+ from factory_env.grader import score_episode
10
+ from factory_env.models import FactoryAction as Action
11
+
12
+ _env: Optional[FactoryEnv] = None
13
+ _obs = None
14
+ _rewards: List[float] = []
15
+ _history: List[dict] = []
16
+ _step_num: int = 0
17
+
18
+ STATUS_EMOJI = {"idle": "🟢", "busy": "🔵", "broken": "🔴"}
19
+ SYSTEM_PROMPT = "You are a factory scheduler. Reply with ONE action:\n assign_job <job_id> <machine_id>\n repair <machine_id>\n wait"
20
+
21
+
22
+ def _llm_client(provider, api_key):
23
+ if "Claude" in provider:
24
+ import anthropic
25
+ return ("claude", anthropic.Anthropic(api_key=api_key or os.getenv("ANTHROPIC_API_KEY")))
26
+ from openai import OpenAI
27
+ base = "https://api.openai.com/v1" if "OpenAI" in provider else "https://router.huggingface.co/v1"
28
+ return ("openai", OpenAI(api_key=api_key or os.getenv("OPENAI_API_KEY") or os.getenv("HF_TOKEN"), base_url=base))
29
+
30
+
31
+ def _call_llm(provider_tuple, model, obs, last_reward, step):
32
+ kind, client = provider_tuple
33
+ machines = "\n".join(f" {m.id}: {m.status}" + (f" ({m.current_job})" if m.current_job else "") for m in obs.machines)
34
+ jobs = "\n".join(f" {j.id}: t={j.remaining_time} dl={j.deadline} p={j.priority}" for j in obs.pending_jobs) or " (none)"
35
+ user = f"Step {step}/{obs.max_steps} | t={obs.time} | reward={last_reward:+.2f}\nMachines:\n{machines}\nJobs:\n{jobs}\nAction:"
36
+ try:
37
+ if kind == "claude":
38
+ r = client.messages.create(model=model, max_tokens=50, system=SYSTEM_PROMPT, messages=[{"role":"user","content":user}])
39
+ return r.content[0].text.strip().splitlines()[0]
40
+ else:
41
+ r = client.chat.completions.create(model=model, temperature=0.2, max_tokens=50,
42
+ messages=[{"role":"system","content":SYSTEM_PROMPT},{"role":"user","content":user}])
43
+ return (r.choices[0].message.content or "wait").strip().splitlines()[0]
44
+ except Exception as e:
45
+ return f"wait # {e}"
46
+
47
+
48
+ def _parse(text):
49
+ try:
50
+ p = text.strip().split()
51
+ if p[0] == "assign_job" and len(p) == 3: return Action(action_type="assign_job", job_id=p[1], machine_id=p[2])
52
+ if p[0] == "repair" and len(p) == 2: return Action(action_type="repair", machine_id=p[1])
53
+ except: pass
54
+ return Action(action_type="wait")
55
+
56
+
57
+ def _heuristic(obs) -> Tuple[Action, str]:
58
+ for m in obs.machines:
59
+ if m.status == "broken": return Action(action_type="repair", machine_id=m.id), f"repair {m.id}"
60
+ for j in sorted(obs.pending_jobs, key=lambda x: (x.deadline, -x.priority)):
61
+ for m in obs.machines:
62
+ if m.status == "idle":
63
+ return Action(action_type="assign_job", job_id=j.id, machine_id=m.id), f"assign_job {j.id} {m.id}"
64
+ return Action(action_type="wait"), "wait"
65
+
66
+
67
+ def _render_state(obs):
68
+ if obs is None: return "*Reset to start*"
69
+ lines = [f"### ⏱ Time: {obs.time} / {obs.max_steps}",
70
+ "\n**Machines**", "| ID | Status | Job |", "|---|---|---|"]
71
+ for m in obs.machines:
72
+ lines.append(f"| {m.id} | {STATUS_EMOJI.get(m.status,'')} {m.status} | {m.current_job or '—'} |")
73
+ lines.append("\n**Pending Jobs**")
74
+ if obs.pending_jobs:
75
+ lines += ["| ID | Remaining | Deadline | Priority |", "|---|---|---|---|"]
76
+ for j in sorted(obs.pending_jobs, key=lambda x: x.deadline):
77
+ urgent = "🔥" if obs.time + j.remaining_time > j.deadline else ""
78
+ lines.append(f"| {j.id} {urgent} | {j.remaining_time} | {j.deadline} | {'★'*j.priority} |")
79
+ else:
80
+ lines.append("*All jobs completed! ✅*")
81
+ if obs.completed_jobs:
82
+ lines.append(f"\n**Completed:** {len(obs.completed_jobs)} ✅")
83
+ return "\n".join(lines)
84
+
85
+
86
+ def _render_log(history):
87
+ if not history: return "*No steps yet*"
88
+ rows = ["| Step | Action | Reward | Done |", "|---|---|---|---|"]
89
+ for h in history[-15:]:
90
+ r = h["reward"]; icon = "🟢" if r > 0.3 else ("🔴" if r < -0.05 else "🟡")
91
+ rows.append(f"| {h['step']} | `{h['action']}` | {icon} {r:+.2f} | {'✅' if h['done'] else ''} |")
92
+ return "\n".join(rows)
93
+
94
+
95
+ def _render_score(rewards, env):
96
+ if not rewards or not env: return ""
97
+ s = score_episode(env)
98
+ bar = "█" * int(s * 20) + "░" * (20 - int(s * 20))
99
+ return f"**Score:** {s:.4f} `[{bar}]`\n**Completed:** {len(env.completed_jobs)} | **Late:** {env.late_jobs} | **Total Reward:** {sum(rewards):.2f}"
100
+
101
+
102
+ def reset_env(task):
103
+ global _env, _obs, _rewards, _history, _step_num
104
+ _env = FactoryEnv(task=task, seed=42); _obs = _env.reset()
105
+ _rewards = []; _history = []; _step_num = 0
106
+ return _render_state(_obs), _render_log([]), "", f"✅ Reset — **{task}**: {len(_obs.machines)} machines, {len(_obs.pending_jobs)} jobs"
107
+
108
+
109
+ def manual_step(text):
110
+ global _obs, _rewards, _history, _step_num
111
+ if _env is None: return _render_state(None), _render_log([]), "", "⚠ Reset first."
112
+ if _obs.done: return _render_state(_obs), _render_log(_history), _render_score(_rewards, _env), "✅ Episode done."
113
+ _step_num += 1
114
+ _obs = _env.step(_parse(text.strip()))
115
+ r = _obs.reward or 0.0; _rewards.append(r); _history.append({"step": _step_num, "action": text.strip(), "reward": r, "done": _obs.done})
116
+ status = f"Step {_step_num}: `{text.strip()}` → **{r:+.2f}**"
117
+ if _obs.done: status += f"\n\n🏁 Done! Score: **{score_episode(_env):.4f}**"
118
+ return _render_state(_obs), _render_log(_history), _render_score(_rewards, _env), status
119
+
120
+
121
+ def heuristic_step():
122
+ global _obs, _rewards, _history, _step_num
123
+ if _env is None: return _render_state(None), _render_log([]), "", "⚠ Reset first."
124
+ if _obs.done: return _render_state(_obs), _render_log(_history), _render_score(_rewards, _env), "✅ Episode done."
125
+ action, action_text = _heuristic(_obs)
126
+ _step_num += 1
127
+ _obs = _env.step(action)
128
+ r = _obs.reward or 0.0; _rewards.append(r); _history.append({"step": _step_num, "action": f"[H] {action_text}", "reward": r, "done": _obs.done})
129
+ status = f"[Heuristic] Step {_step_num}: `{action_text}` → **{r:+.2f}**"
130
+ if _obs.done: status += f"\n\n🏁 Done! Score: **{score_episode(_env):.4f}**"
131
+ return _render_state(_obs), _render_log(_history), _render_score(_rewards, _env), status
132
+
133
+
134
+ def llm_step(provider, api_key, model):
135
+ global _obs, _rewards, _history, _step_num
136
+ if _env is None: return _render_state(None), _render_log([]), "", "⚠ Reset first.", ""
137
+ if _obs.done: return _render_state(_obs), _render_log(_history), _render_score(_rewards, _env), "✅ Episode done.", ""
138
+ try: client = _llm_client(provider, api_key)
139
+ except Exception as e: return _render_state(_obs), _render_log(_history), "", f"⚠ {e}", ""
140
+ action_text = _call_llm(client, model, _obs, _rewards[-1] if _rewards else 0.0, _step_num + 1)
141
+ action = _parse(action_text)
142
+ if action.action_type == "wait" and (_obs.pending_jobs or any(m.status == "broken" for m in _obs.machines)):
143
+ action, action_text = _heuristic(_obs)
144
+ action_text = f"[fallback] {action_text}"
145
+ _step_num += 1
146
+ _obs = _env.step(action)
147
+ r = _obs.reward or 0.0; _rewards.append(r); _history.append({"step": _step_num, "action": f"[LLM] {action_text}", "reward": r, "done": _obs.done})
148
+ status = f"[LLM] Step {_step_num}: `{action_text}` → **{r:+.2f}**"
149
+ if _obs.done: status += f"\n\n🏁 Done! Score: **{score_episode(_env):.4f}**"
150
+ return _render_state(_obs), _render_log(_history), _render_score(_rewards, _env), status, action_text
151
+
152
+
153
+ def run_full_episode(provider, api_key, model, task):
154
+ global _env, _obs, _rewards, _history, _step_num
155
+ _env = FactoryEnv(task=task, seed=42); _obs = _env.reset()
156
+ _rewards = []; _history = []; _step_num = 0
157
+ try: client = _llm_client(provider, api_key)
158
+ except Exception as e: return _render_state(_obs), _render_log([]), "", f"⚠ {e}", ""
159
+ log_lines = []
160
+ while not _obs.done and _step_num < _obs.max_steps:
161
+ action_text = _call_llm(client, model, _obs, _rewards[-1] if _rewards else 0.0, _step_num + 1)
162
+ action = _parse(action_text)
163
+ if action.action_type == "wait" and (_obs.pending_jobs or any(m.status == "broken" for m in _obs.machines)):
164
+ action, action_text = _heuristic(_obs); action_text = f"[fb] {action_text}"
165
+ _step_num += 1; _obs = _env.step(action)
166
+ r = _obs.reward or 0.0; _rewards.append(r)
167
+ _history.append({"step": _step_num, "action": action_text, "reward": r, "done": _obs.done})
168
+ log_lines.append(f"Step {_step_num:2d}: {action_text:<35s} r={r:+.2f}")
169
+ s = score_episode(_env)
170
+ status = f"🏁 **Done!** Score: **{s:.4f}** | Completed: {len(_env.completed_jobs)} | Late: {_env.late_jobs}"
171
+ return _render_state(_obs), _render_log(_history), _render_score(_rewards, _env), status, "\n".join(log_lines)
172
+
173
+
174
+ def build_ui():
175
+ with gr.Blocks(title="Smart Factory RL") as demo:
176
+ gr.Markdown("# 🏭 Smart Factory Scheduling — Interactive RL Demo")
177
+ with gr.Row():
178
+ with gr.Column(scale=1):
179
+ gr.Markdown("### ⚙️ Setup")
180
+ task_dd = gr.Dropdown(["easy","medium","hard"], value="easy", label="Task")
181
+ provider_dd = gr.Dropdown(["OpenAI (GPT)","Claude (Anthropic)","HuggingFace Router"], value="OpenAI (GPT)", label="Provider")
182
+ api_key_box = gr.Textbox(label="API Key", type="password", placeholder="sk-... or sk-ant-...")
183
+ model_box = gr.Textbox(label="Model", value="gpt-4o-mini")
184
+ reset_btn = gr.Button("🔄 Reset", variant="primary")
185
+ gr.Markdown("### 🎮 Manual")
186
+ manual_input = gr.Textbox(label="Action", placeholder="assign_job J1 M1 | repair M2 | wait")
187
+ with gr.Row():
188
+ manual_btn = gr.Button("▶ Execute")
189
+ heuristic_btn = gr.Button("🤖 Heuristic Step")
190
+ gr.Markdown("### 🧠 LLM")
191
+ with gr.Row():
192
+ llm_step_btn = gr.Button("🔮 LLM Step", variant="secondary")
193
+ llm_ep_btn = gr.Button("⚡ Run Full Episode", variant="primary")
194
+ llm_out = gr.Textbox(label="LLM Output", interactive=False)
195
+ status_md = gr.Markdown("*Press Reset to start*")
196
+ with gr.Column(scale=2):
197
+ gr.Markdown("### 🏭 Factory State")
198
+ state_md = gr.Markdown("*Reset to start*")
199
+ gr.Markdown("### 📊 Score")
200
+ score_md = gr.Markdown("")
201
+ gr.Markdown("### 📋 Step Log")
202
+ log_md = gr.Markdown("*No steps yet*")
203
+ reset_btn.click(reset_env, [task_dd], [state_md, log_md, score_md, status_md])
204
+ manual_btn.click(manual_step, [manual_input], [state_md, log_md, score_md, status_md])
205
+ heuristic_btn.click(heuristic_step, [], [state_md, log_md, score_md, status_md])
206
+ llm_step_btn.click(llm_step, [provider_dd, api_key_box, model_box], [state_md, log_md, score_md, status_md, llm_out])
207
+ llm_ep_btn.click(run_full_episode, [provider_dd, api_key_box, model_box, task_dd], [state_md, log_md, score_md, status_md, llm_out])
208
+ return demo
209
+
210
+
211
+ if __name__ == "__main__":
212
+ build_ui().launch(server_name="0.0.0.0", server_port=7860, show_error=True, theme=gr.themes.Soft())
factory_env/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from factory_env.env import FactoryEnv
2
+ from factory_env.models import FactoryAction, FactoryObservation, FactoryState, Machine, Job
factory_env/env.py CHANGED
@@ -1,93 +1,134 @@
1
  import random
2
- from typing import List
3
- from factory_env.models import Observation, Action, Machine, Job
4
-
5
- class FactoryEnv:
6
- def __init__(self, task="easy"):
7
- self.task = task
8
- self.time = 0
9
- self.max_steps = 20
10
-
11
- async def reset(self):
12
- random.seed(42)
13
-
14
- self.time = 0
15
-
16
- self.machines = [
17
- Machine(id="M1", status="idle"),
18
- Machine(id="M2", status="idle"),
19
- ]
20
-
21
- self.jobs = [
22
- Job(id="J1", remaining_time=3, deadline=10),
23
- Job(id="J2", remaining_time=2, deadline=8),
24
- ]
25
-
26
- return self._get_result(0.0, False)
27
-
28
-
29
- async def step(self, action: Action):
30
- reward = 0.0
31
-
32
- # Apply action
33
- if action.action_type == "assign_job":
34
- job = self._find_job(action.job_id)
35
- machine = self._find_machine(action.machine_id)
36
-
37
- if job and machine and machine.status == "idle":
38
- job.assigned_machine = machine.id
39
- machine.status = "busy"
40
- machine.current_job = job.id
41
- reward += 0.2
42
- else:
43
- reward -= 0.2 # invalid action
44
-
45
- # Simulate time
46
- self.time += 1
47
-
48
- for machine in self.machines:
49
- if machine.status == "busy":
50
- job = self._find_job(machine.current_job)
51
- job.remaining_time -= 1
52
-
53
- if job.remaining_time <= 0:
54
- reward += 1.0
55
- self.jobs.remove(job)
56
- machine.status = "idle"
57
- machine.current_job = None
58
-
59
- # Penalty for idle machines
60
- idle_count = sum(1 for m in self.machines if m.status == "idle")
61
- reward -= idle_count * 0.05
62
-
63
- done = self.time >= self.max_steps or len(self.jobs) == 0
64
-
65
- return self._get_result(reward, done)
66
-
67
-
68
- def state(self):
69
- return self._get_observation()
70
-
71
- def _get_observation(self):
72
- return Observation(
73
- machines=self.machines,
74
- pending_jobs=self.jobs,
75
- time=self.time,
76
- )
77
-
78
- def _get_result(self, reward, done):
79
- return type("Result", (), {
80
- "observation": self._get_observation(),
81
- "reward": reward,
82
- "done": done
83
- })
84
-
85
- def _find_job(self, job_id):
86
- return next((j for j in self.jobs if j.id == job_id), None)
87
-
88
- def _find_machine(self, machine_id):
89
- return next((m for m in self.machines if m.id == machine_id), None)
90
-
91
- async def close(self):
92
- pass
93
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import random
2
+ from typing import List, Optional
3
+
4
+ from openenv.core import Environment
5
+
6
+ from factory_env.models import FactoryAction, FactoryObservation, FactoryState, Machine, Job
7
+ from factory_env.tasks import TASKS
8
+
9
+
10
+ class FactoryEnv(Environment[FactoryAction, FactoryObservation, FactoryState]):
11
+ """Smart Factory Scheduling Environment — OpenEnv compliant."""
12
+
13
+ SUPPORTS_CONCURRENT_SESSIONS = True
14
+
15
+ def __init__(self, task: str = "easy", seed: int = 42):
16
+ super().__init__()
17
+ if task not in TASKS:
18
+ raise ValueError(f"Unknown task '{task}'. Choose from: {list(TASKS.keys())}")
19
+ self.task = task
20
+ self.seed = seed
21
+ self.config = TASKS[task]
22
+ self._rng = random.Random(seed)
23
+ self.machines: List[Machine] = []
24
+ self.jobs: List[Job] = []
25
+ self.completed_jobs: List[Job] = []
26
+ self.late_jobs: int = 0
27
+ self.time: int = 0
28
+ self.max_steps: int = self.config["max_steps"]
29
+
30
+ def reset(self, seed: Optional[int] = None, episode_id: Optional[str] = None, **kwargs) -> FactoryObservation:
31
+ use_seed = seed if seed is not None else self.seed
32
+ self._rng = random.Random(use_seed)
33
+ self.time = 0
34
+ self.completed_jobs = []
35
+ self.late_jobs = 0
36
+
37
+ cfg = self.config
38
+ self.machines = [
39
+ Machine(id=f"M{i+1}", status="idle", failure_rate=cfg.get("failure_rate", 0.0))
40
+ for i in range(cfg["num_machines"])
41
+ ]
42
+ self.jobs = []
43
+ for i in range(cfg["num_jobs"]):
44
+ proc_time = self._rng.randint(*cfg["job_time_range"])
45
+ deadline = self.time + proc_time + self._rng.randint(*cfg["deadline_slack"])
46
+ priority = self._rng.randint(1, cfg.get("max_priority", 1))
47
+ self.jobs.append(Job(id=f"J{i+1}", remaining_time=proc_time, deadline=deadline, priority=priority))
48
+
49
+ return self._make_obs(reward=None, done=False)
50
+
51
+ def step(self, action: FactoryAction, timeout_s: Optional[float] = None, **kwargs) -> FactoryObservation:
52
+ reward = 0.0
53
+
54
+ if action.action_type == "assign_job":
55
+ job = self._find_job(action.job_id)
56
+ machine = self._find_machine(action.machine_id)
57
+ if job is None or machine is None or machine.status != "idle":
58
+ reward -= 0.1
59
+ else:
60
+ job.assigned_machine = machine.id
61
+ machine.status = "busy"
62
+ machine.current_job = job.id
63
+ reward += 0.1
64
+ elif action.action_type == "repair":
65
+ machine = self._find_machine(action.machine_id)
66
+ if machine and machine.status == "broken":
67
+ machine.status = "idle"
68
+ reward += 0.05
69
+ else:
70
+ reward -= 0.05
71
+
72
+ self.time += 1
73
+
74
+ for machine in self.machines:
75
+ if machine.status == "busy":
76
+ job = self._find_job(machine.current_job)
77
+ if job:
78
+ job.remaining_time -= 1
79
+ if job.remaining_time <= 0:
80
+ on_time = self.time <= job.deadline
81
+ reward += (1.0 + 0.2 * job.priority) if on_time else 0.3
82
+ if not on_time:
83
+ self.late_jobs += 1
84
+ self.jobs.remove(job)
85
+ self.completed_jobs.append(job)
86
+ machine.status = "idle"
87
+ machine.current_job = None
88
+
89
+ if machine.status == "busy" and machine.failure_rate > 0:
90
+ if self._rng.random() < machine.failure_rate:
91
+ machine.status = "broken"
92
+ stalled = self._find_job(machine.current_job)
93
+ if stalled:
94
+ stalled.assigned_machine = None
95
+ machine.current_job = None
96
+
97
+ if self.jobs:
98
+ reward -= sum(1 for m in self.machines if m.status == "idle") * 0.05
99
+ for job in self.jobs:
100
+ if self.time > job.deadline:
101
+ reward -= 0.1
102
+
103
+ done = self.time >= self.max_steps or len(self.jobs) == 0
104
+ return self._make_obs(reward=reward, done=done)
105
+
106
+ @property
107
+ def state(self) -> FactoryState:
108
+ return FactoryState(
109
+ machines=list(self.machines),
110
+ pending_jobs=list(self.jobs),
111
+ completed_jobs=list(self.completed_jobs),
112
+ time=self.time,
113
+ task=self.task,
114
+ late_jobs=self.late_jobs,
115
+ step_count=self.time,
116
+ )
117
+
118
+ def _make_obs(self, reward, done: bool) -> FactoryObservation:
119
+ return FactoryObservation(
120
+ machines=list(self.machines),
121
+ pending_jobs=list(self.jobs),
122
+ completed_jobs=list(self.completed_jobs),
123
+ time=self.time,
124
+ max_steps=self.max_steps,
125
+ task=self.task,
126
+ reward=reward,
127
+ done=done,
128
+ )
129
+
130
+ def _find_job(self, job_id: Optional[str]) -> Optional[Job]:
131
+ return next((j for j in self.jobs if j.id == job_id), None) if job_id else None
132
+
133
+ def _find_machine(self, machine_id: Optional[str]) -> Optional[Machine]:
134
+ return next((m for m in self.machines if m.id == machine_id), None) if machine_id else None
factory_env/grader.py CHANGED
@@ -1,3 +1,15 @@
1
- def compute_score(total_reward, max_possible=20):
2
- score = total_reward / max_possible
3
- return max(0.0, min(1.0, score))
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ def compute_score(completed, on_time, total_jobs, late_jobs, task="easy"):
2
+ if total_jobs == 0:
3
+ return 0.0
4
+ completion_rate = completed / total_jobs
5
+ on_time_rate = on_time / max(completed, 1)
6
+ utilization_bonus = max(0.0, 1.0 - late_jobs / max(completed, 1))
7
+ score = 0.5 * completion_rate + 0.3 * on_time_rate + 0.2 * utilization_bonus
8
+ return round(max(0.0, min(1.0, score)), 4)
9
+
10
+
11
+ def score_episode(env) -> float:
12
+ total = len(env.completed_jobs) + len(env.jobs)
13
+ completed = len(env.completed_jobs)
14
+ on_time = sum(1 for j in env.completed_jobs if env.time <= j.deadline)
15
+ return compute_score(completed, on_time, total, env.late_jobs, env.task)
factory_env/models.py CHANGED
@@ -1,26 +1,55 @@
1
- from pydantic import BaseModel
2
  from typing import List, Optional
 
 
 
3
 
4
  class Machine(BaseModel):
5
- id: str
6
- status: str # idle, busy, broken
7
- current_job: Optional[str] = None
 
 
 
8
 
9
  class Job(BaseModel):
10
- id: str
11
- remaining_time: int
12
- deadline: int
13
- assigned_machine: Optional[str] = None
14
-
15
- class Observation(BaseModel):
16
- machines: List[Machine]
17
- pending_jobs: List[Job]
18
- time: int
19
-
20
- class Action(BaseModel):
21
- action_type: str # assign_job, wait
22
- job_id: Optional[str] = None
23
- machine_id: Optional[str] = None
24
-
25
- class Reward(BaseModel):
26
- value: float
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from typing import List, Optional
2
+ from pydantic import BaseModel, ConfigDict, Field
3
+ from openenv.core import Action as BaseAction, Observation as BaseObservation, State as BaseState
4
+
5
 
6
  class Machine(BaseModel):
7
+ model_config = ConfigDict(extra="forbid")
8
+ id: str
9
+ status: str # idle | busy | broken
10
+ current_job: Optional[str] = None
11
+ failure_rate: float = 0.0
12
+
13
 
14
  class Job(BaseModel):
15
+ model_config = ConfigDict(extra="forbid")
16
+ id: str
17
+ remaining_time: int
18
+ deadline: int
19
+ priority: int = 1
20
+ assigned_machine: Optional[str] = None
21
+
22
+
23
+ class FactoryAction(BaseAction):
24
+ """
25
+ action_type: assign_job | repair | wait
26
+ job_id: required for assign_job
27
+ machine_id: required for assign_job / repair
28
+ """
29
+ action_type: str
30
+ job_id: Optional[str] = None
31
+ machine_id: Optional[str] = None
32
+
33
+
34
+ class FactoryObservation(BaseObservation):
35
+ """Inherits: done (bool), reward (float|None), metadata (dict)"""
36
+ machines: List[Machine] = Field(default_factory=list)
37
+ pending_jobs: List[Job] = Field(default_factory=list)
38
+ completed_jobs: List[Job] = Field(default_factory=list)
39
+ time: int = 0
40
+ max_steps: int = 20
41
+ task: str = "easy"
42
+
43
+
44
+ class FactoryState(BaseState):
45
+ machines: List[Machine] = Field(default_factory=list)
46
+ pending_jobs: List[Job] = Field(default_factory=list)
47
+ completed_jobs: List[Job] = Field(default_factory=list)
48
+ time: int = 0
49
+ task: str = "easy"
50
+ late_jobs: int = 0
51
+
52
+
53
+ # Aliases for backward compatibility
54
+ Action = FactoryAction
55
+ Observation = FactoryObservation
factory_env/tasks.py CHANGED
@@ -1,17 +1,35 @@
1
  TASKS = {
2
- "easy": {
3
- "machines": 2,
4
- "jobs": 2,
5
- "failures": False,
6
- },
7
- "medium": {
8
- "machines": 3,
9
- "jobs": 5,
10
- "failures": True,
11
- },
12
- "hard": {
13
- "machines": 5,
14
- "jobs": 10,
15
- "failures": True,
16
- },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  }
 
1
  TASKS = {
2
+ "easy": {
3
+ "num_machines": 2,
4
+ "num_jobs": 3,
5
+ "failures": False,
6
+ "failure_rate": 0.0,
7
+ "max_priority": 1,
8
+ "job_time_range": (2, 5),
9
+ "deadline_slack": (4, 8),
10
+ "max_steps": 20,
11
+ "description": "Assign 3 jobs to 2 machines with no failures.",
12
+ },
13
+ "medium": {
14
+ "num_machines": 4,
15
+ "num_jobs": 7,
16
+ "failures": True,
17
+ "failure_rate": 0.08,
18
+ "max_priority": 2,
19
+ "job_time_range": (3, 7),
20
+ "deadline_slack": (2, 5),
21
+ "max_steps": 30,
22
+ "description": "Manage 7 jobs across 4 machines with random breakdowns.",
23
+ },
24
+ "hard": {
25
+ "num_machines": 6,
26
+ "num_jobs": 12,
27
+ "failures": True,
28
+ "failure_rate": 0.15,
29
+ "max_priority": 3,
30
+ "job_time_range": (3, 8),
31
+ "deadline_slack": (1, 4),
32
+ "max_steps": 40,
33
+ "description": "Optimize throughput across 12 jobs and 6 machines under frequent failures.",
34
+ },
35
  }
inference.py CHANGED
@@ -1,248 +1,144 @@
1
  """
2
- Factory Environment Inference Script
3
- ===================================
4
- Follows OpenEnv evaluation format strictly.
 
 
 
 
 
 
 
 
 
 
5
  """
6
 
7
- import asyncio
8
  import os
9
  import textwrap
10
- from typing import List, Optional
11
 
12
  from openai import OpenAI
13
 
14
  from factory_env.env import FactoryEnv
15
- from factory_env.models import Action
16
-
17
- # =========================
18
- # ENV VARIABLES (MANDATORY)
19
- # =========================
20
- API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
 
 
21
  API_BASE_URL = os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1"
22
  MODEL_NAME = os.getenv("MODEL_NAME") or "Qwen/Qwen2.5-72B-Instruct"
23
-
24
  TASK_NAME = os.getenv("FACTORY_TASK", "easy")
25
  BENCHMARK = "factory_env"
26
-
27
- MAX_STEPS = 20
28
  TEMPERATURE = 0.2
29
- MAX_TOKENS = 100
30
  SUCCESS_SCORE_THRESHOLD = 0.5
31
 
32
- # =========================
33
- # PROMPTS
34
- # =========================
35
- SYSTEM_PROMPT = textwrap.dedent(
36
- """
37
- You are controlling a factory scheduling system.
38
-
39
- Your goal:
40
- - Assign jobs to machines efficiently
41
- - Minimize idle machines
42
- - Finish all jobs as fast as possible
43
-
44
- Available actions:
45
- 1. assign_job <job_id> <machine_id>
46
- 2. wait
47
-
48
- Rules:
49
- - Only assign jobs that exist
50
- - Only assign to idle machines
51
- - One action per step
52
-
53
- Respond ONLY with the action string.
54
- Example:
55
- assign_job J1 M1
56
- """
57
- ).strip()
58
-
59
-
60
- # =========================
61
- # LOGGING FUNCTIONS (STRICT FORMAT)
62
- # =========================
63
  def log_start(task: str, env: str, model: str) -> None:
64
  print(f"[START] task={task} env={env} model={model}", flush=True)
65
 
66
 
67
  def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
68
- error_val = error if error else "null"
69
- done_val = str(done).lower()
70
- print(
71
- f"[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}",
72
- flush=True,
73
- )
74
 
75
 
76
  def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
77
- rewards_str = ",".join(f"{r:.2f}" for r in rewards)
78
- print(
79
- f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_str}",
80
- flush=True,
81
- )
82
-
83
-
84
- # =========================
85
- # PROMPT BUILDER
86
- # =========================
87
- def build_user_prompt(step, obs, last_reward):
88
- machines_str = "\n".join(
89
- [f"{m.id}: {m.status} (job={m.current_job})" for m in obs.machines]
90
- )
91
-
92
- jobs_str = "\n".join(
93
- [f"{j.id}: remaining={j.remaining_time}, deadline={j.deadline}" for j in obs.pending_jobs]
94
- ) or "None"
95
-
96
- return textwrap.dedent(
97
- f"""
98
- Step: {step}
99
 
100
- Current Time: {obs.time}
101
 
102
- Machines:
103
- {machines_str}
 
 
104
 
105
- Pending Jobs:
106
- {jobs_str}
107
 
108
- Last reward: {last_reward:.2f}
109
-
110
- What action do you take?
111
- """
112
- ).strip()
113
-
114
-
115
- # =========================
116
- # LLM CALL
117
- # =========================
118
- def get_model_action(client: OpenAI, step, obs, last_reward) -> str:
119
  try:
120
- user_prompt = build_user_prompt(step, obs, last_reward)
121
-
122
- completion = client.chat.completions.create(
123
  model=MODEL_NAME,
124
- messages=[
125
- {"role": "system", "content": SYSTEM_PROMPT},
126
- {"role": "user", "content": user_prompt},
127
- ],
128
  temperature=TEMPERATURE,
129
  max_tokens=MAX_TOKENS,
130
  )
131
-
132
- text = (completion.choices[0].message.content or "").strip()
133
- return text if text else "wait"
134
-
135
  except Exception as e:
136
  print(f"[DEBUG] LLM error: {e}", flush=True)
137
  return "wait"
138
 
139
 
140
- # =========================
141
- # ACTION PARSER
142
- # =========================
143
  def parse_action(text: str) -> Action:
144
  try:
145
  parts = text.strip().split()
146
-
147
  if parts[0] == "assign_job" and len(parts) == 3:
148
- return Action(
149
- action_type="assign_job",
150
- job_id=parts[1],
151
- machine_id=parts[2],
152
- )
153
-
154
- elif parts[0] == "wait":
155
- return Action(action_type="wait")
156
-
157
  except Exception:
158
  pass
159
-
160
- # fallback safe action
161
  return Action(action_type="wait")
162
 
163
 
164
- # =========================
165
- # SIMPLE HEURISTIC FALLBACK
166
- # =========================
167
- def heuristic_action(obs) -> Action:
168
- for job in obs.pending_jobs:
169
- for machine in obs.machines:
170
- if machine.status == "idle":
171
- return Action(
172
- action_type="assign_job",
173
- job_id=job.id,
174
- machine_id=machine.id,
175
- )
176
- return Action(action_type="wait")
177
 
178
 
179
- # =========================
180
- # MAIN LOOP
181
- # =========================
182
- async def main():
183
  client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
184
-
185
- env = FactoryEnv(task=TASK_NAME)
186
-
187
  rewards: List[float] = []
188
  steps_taken = 0
189
  score = 0.0
190
  success = False
191
 
192
- log_start(task=TASK_NAME, env=BENCHMARK, model=MODEL_NAME)
193
 
194
  try:
195
- result = await env.reset()
196
- obs = result.observation
197
  last_reward = 0.0
198
 
199
- for step in range(1, MAX_STEPS + 1):
200
- if result.done:
201
  break
202
-
203
- # LLM decision
204
  action_text = get_model_action(client, step, obs, last_reward)
205
-
206
- # Parse action
207
  action = parse_action(action_text)
208
-
209
- # Fallback if invalid
210
- if action.action_type == "wait" and len(obs.pending_jobs) > 0:
211
- action = heuristic_action(obs)
212
- action_text = "heuristic_assign"
213
-
214
- # Step env
215
- result = await env.step(action)
216
-
217
- obs = result.observation
218
- reward = result.reward or 0.0
219
- done = result.done
220
- error = None
221
-
222
  rewards.append(reward)
223
  steps_taken = step
224
  last_reward = reward
225
-
226
- log_step(step, action_text, reward, done, error)
227
-
228
- if done:
229
  break
230
 
231
- # Normalize score
232
- if rewards:
233
- score = sum(rewards) / len(rewards)
234
- score = max(0.0, min(1.0, score))
235
-
236
  success = score >= SUCCESS_SCORE_THRESHOLD
237
-
238
  finally:
239
- try:
240
- await env.close()
241
- except Exception as e:
242
- print(f"[DEBUG] env.close error: {e}", flush=True)
243
-
244
  log_end(success, steps_taken, score, rewards)
245
 
246
 
247
  if __name__ == "__main__":
248
- asyncio.run(main())
 
1
  """
2
+ Inference Script — Smart Factory Scheduling Environment
3
+ ========================================================
4
+ Mandatory env vars (per hackathon spec):
5
+ OPENAI_API_KEY API key (also accepts HF_TOKEN for HF router)
6
+ API_BASE_URL LLM endpoint (default: HF router)
7
+ MODEL_NAME Model ID (default: Qwen/Qwen2.5-72B-Instruct)
8
+ HF_TOKEN HuggingFace token
9
+ FACTORY_TASK easy | medium | hard (default: easy)
10
+
11
+ STDOUT FORMAT:
12
+ [START] task=<name> env=factory_env model=<model>
13
+ [STEP] step=<n> action=<str> reward=<0.00> done=<true|false> error=<msg|null>
14
+ [END] success=<true|false> steps=<n> score=<0.000> rewards=<r1,r2,...>
15
  """
16
 
 
17
  import os
18
  import textwrap
19
+ from typing import List, Optional, Tuple
20
 
21
  from openai import OpenAI
22
 
23
  from factory_env.env import FactoryEnv
24
+ from factory_env.models import FactoryAction as Action
25
+ from factory_env.grader import score_episode
26
+
27
+ API_KEY = (
28
+ os.getenv("OPENAI_API_KEY")
29
+ or os.getenv("HF_TOKEN")
30
+ or os.getenv("API_KEY")
31
+ )
32
  API_BASE_URL = os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1"
33
  MODEL_NAME = os.getenv("MODEL_NAME") or "Qwen/Qwen2.5-72B-Instruct"
 
34
  TASK_NAME = os.getenv("FACTORY_TASK", "easy")
35
  BENCHMARK = "factory_env"
 
 
36
  TEMPERATURE = 0.2
37
+ MAX_TOKENS = 80
38
  SUCCESS_SCORE_THRESHOLD = 0.5
39
 
40
+ SYSTEM_PROMPT = textwrap.dedent("""
41
+ You are controlling a smart factory scheduling system.
42
+ Goal: complete all jobs before their deadlines, keep machines busy, repair broken machines.
43
+ Actions (respond with EXACTLY one line):
44
+ assign_job <job_id> <machine_id>
45
+ repair <machine_id>
46
+ wait
47
+ Respond with ONLY the action string.
48
+ """).strip()
49
+
50
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  def log_start(task: str, env: str, model: str) -> None:
52
  print(f"[START] task={task} env={env} model={model}", flush=True)
53
 
54
 
55
  def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
56
+ print(f"[STEP] step={step} action={action} reward={reward:.2f} done={str(done).lower()} error={error or 'null'}", flush=True)
 
 
 
 
 
57
 
58
 
59
  def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
60
+ print(f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={','.join(f'{r:.2f}' for r in rewards)}", flush=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
 
62
 
63
+ def build_prompt(step: int, obs, last_reward: float) -> str:
64
+ machines = "\n".join(f" {m.id}: {m.status}" + (f" ({m.current_job})" if m.current_job else "") for m in obs.machines)
65
+ jobs = "\n".join(f" {j.id}: remaining={j.remaining_time}, deadline={j.deadline}, priority={j.priority}" for j in obs.pending_jobs) or " (none)"
66
+ return f"Step {step}/{obs.max_steps} | time={obs.time} | last_reward={last_reward:+.2f}\nMachines:\n{machines}\nPending Jobs:\n{jobs}\nAction:"
67
 
 
 
68
 
69
+ def get_model_action(client: OpenAI, step: int, obs, last_reward: float) -> str:
 
 
 
 
 
 
 
 
 
 
70
  try:
71
+ resp = client.chat.completions.create(
 
 
72
  model=MODEL_NAME,
73
+ messages=[{"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": build_prompt(step, obs, last_reward)}],
 
 
 
74
  temperature=TEMPERATURE,
75
  max_tokens=MAX_TOKENS,
76
  )
77
+ return (resp.choices[0].message.content or "wait").strip().splitlines()[0]
 
 
 
78
  except Exception as e:
79
  print(f"[DEBUG] LLM error: {e}", flush=True)
80
  return "wait"
81
 
82
 
 
 
 
83
  def parse_action(text: str) -> Action:
84
  try:
85
  parts = text.strip().split()
 
86
  if parts[0] == "assign_job" and len(parts) == 3:
87
+ return Action(action_type="assign_job", job_id=parts[1], machine_id=parts[2])
88
+ if parts[0] == "repair" and len(parts) == 2:
89
+ return Action(action_type="repair", machine_id=parts[1])
 
 
 
 
 
 
90
  except Exception:
91
  pass
 
 
92
  return Action(action_type="wait")
93
 
94
 
95
+ def heuristic_action(obs) -> Tuple[Action, str]:
96
+ for m in obs.machines:
97
+ if m.status == "broken":
98
+ return Action(action_type="repair", machine_id=m.id), f"repair {m.id}"
99
+ for j in sorted(obs.pending_jobs, key=lambda x: (x.deadline, -x.priority)):
100
+ for m in obs.machines:
101
+ if m.status == "idle":
102
+ s = f"assign_job {j.id} {m.id}"
103
+ return Action(action_type="assign_job", job_id=j.id, machine_id=m.id), s
104
+ return Action(action_type="wait"), "wait"
 
 
 
105
 
106
 
107
+ def run_task(task_name: str) -> None:
 
 
 
108
  client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
109
+ env = FactoryEnv(task=task_name)
 
 
110
  rewards: List[float] = []
111
  steps_taken = 0
112
  score = 0.0
113
  success = False
114
 
115
+ log_start(task=task_name, env=BENCHMARK, model=MODEL_NAME)
116
 
117
  try:
118
+ obs = env.reset()
 
119
  last_reward = 0.0
120
 
121
+ for step in range(1, obs.max_steps + 1):
122
+ if obs.done:
123
  break
 
 
124
  action_text = get_model_action(client, step, obs, last_reward)
 
 
125
  action = parse_action(action_text)
126
+ if action.action_type == "wait" and (obs.pending_jobs or any(m.status == "broken" for m in obs.machines)):
127
+ action, action_text = heuristic_action(obs)
128
+ obs = env.step(action)
129
+ reward = obs.reward or 0.0
 
 
 
 
 
 
 
 
 
 
130
  rewards.append(reward)
131
  steps_taken = step
132
  last_reward = reward
133
+ log_step(step, action_text, reward, obs.done, None)
134
+ if obs.done:
 
 
135
  break
136
 
137
+ score = score_episode(env)
 
 
 
 
138
  success = score >= SUCCESS_SCORE_THRESHOLD
 
139
  finally:
 
 
 
 
 
140
  log_end(success, steps_taken, score, rewards)
141
 
142
 
143
  if __name__ == "__main__":
144
+ run_task(TASK_NAME)
requirements.txt CHANGED
@@ -1,3 +1,7 @@
1
- pydantic
2
- openai
3
- asyncio
 
 
 
 
 
1
+ pydantic>=2.0
2
+ openai>=1.0
3
+ anthropic>=0.90
4
+ gradio>=6.0
5
+ openenv-core>=0.2.3
6
+ fastapi>=0.100
7
+ uvicorn>=0.23
server.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ OpenEnv HTTP Server — Smart Factory Scheduling
3
+ Routes: GET /health POST /reset POST /step GET /state GET /schema
4
+ """
5
+ import os
6
+ from openenv.core import create_app
7
+ from factory_env.env import FactoryEnv
8
+ from factory_env.models import FactoryAction, FactoryObservation
9
+
10
+ TASK = os.getenv("FACTORY_TASK", "easy")
11
+
12
+ app = create_app(
13
+ env=lambda: FactoryEnv(task=TASK, seed=42),
14
+ action_cls=FactoryAction,
15
+ observation_cls=FactoryObservation,
16
+ env_name="factory_env",
17
+ )
18
+
19
+ if __name__ == "__main__":
20
+ import uvicorn
21
+ uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", 7860)))
train.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ RL Training Loop — Smart Factory Scheduling
3
+ ============================================
4
+ Strategy: Online In-Context RL — best trajectory fed as few-shot example each episode.
5
+
6
+ Usage:
7
+ export OPENAI_API_KEY=sk-... # OpenAI
8
+ export ANTHROPIC_API_KEY=sk-ant-... # Claude
9
+ python train.py --task easy --episodes 10 --provider openai
10
+ python train.py --task medium --episodes 10 --provider claude
11
+ """
12
+
13
+ import argparse
14
+ import json
15
+ import os
16
+ import time
17
+ from dataclasses import dataclass, field
18
+ from pathlib import Path
19
+ from typing import List, Optional, Tuple
20
+
21
+ from factory_env.env import FactoryEnv
22
+ from factory_env.grader import score_episode
23
+ from factory_env.models import FactoryAction as Action
24
+
25
+
26
+ def get_openai_client():
27
+ from openai import OpenAI
28
+ key = os.getenv("OPENAI_API_KEY") or os.getenv("HF_TOKEN") or os.getenv("API_KEY")
29
+ base = os.getenv("API_BASE_URL") or "https://api.openai.com/v1"
30
+ return OpenAI(api_key=key, base_url=base)
31
+
32
+
33
+ def get_claude_client():
34
+ import anthropic
35
+ return anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))
36
+
37
+
38
+ @dataclass
39
+ class Step:
40
+ step: int
41
+ obs_text: str
42
+ action_text: str
43
+ reward: float
44
+ done: bool
45
+
46
+
47
+ @dataclass
48
+ class Episode:
49
+ episode_num: int
50
+ task: str
51
+ steps: List[Step] = field(default_factory=list)
52
+ total_reward: float = 0.0
53
+ score: float = 0.0
54
+ completed: int = 0
55
+ late: int = 0
56
+
57
+ def to_few_shot(self, max_steps: int = 6) -> str:
58
+ lines = [f"# Best trajectory so far (score={self.score:.2f}, completed={self.completed} jobs)"]
59
+ for s in self.steps[:max_steps]:
60
+ lines.append(f"[Obs] {s.obs_text}")
61
+ lines.append(f"[Action] {s.action_text} → reward: {s.reward:+.2f}")
62
+ return "\n".join(lines)
63
+
64
+
65
+ SYSTEM_PROMPT = """You are an expert factory scheduling AI.
66
+ Goal: complete all jobs before deadlines, keep machines busy, repair broken machines.
67
+ Actions (one per step):
68
+ assign_job <job_id> <machine_id>
69
+ repair <machine_id>
70
+ wait
71
+ Tips: Fix broken machines first. Sort by earliest deadline. High-priority jobs give bonus reward."""
72
+
73
+
74
+ def obs_to_text(obs) -> str:
75
+ machines = ", ".join(f"{m.id}:{m.status}" + (f"({m.current_job})" if m.current_job else "") for m in obs.machines)
76
+ jobs = ", ".join(f"{j.id}[t={j.remaining_time},dl={j.deadline},p={j.priority}]" for j in obs.pending_jobs) or "none"
77
+ return f"t={obs.time} | machines: {machines} | pending: {jobs}"
78
+
79
+
80
+ def call_llm(messages: list, provider: str, client, model: str) -> str:
81
+ try:
82
+ if provider == "claude":
83
+ system = next((m["content"] for m in messages if m["role"] == "system"), "")
84
+ user_msgs = [m for m in messages if m["role"] != "system"]
85
+ resp = client.messages.create(model=model, max_tokens=60, system=system, messages=user_msgs)
86
+ return resp.content[0].text.strip().splitlines()[0]
87
+ else:
88
+ resp = client.chat.completions.create(model=model, messages=messages, temperature=0.2, max_tokens=60)
89
+ return (resp.choices[0].message.content or "wait").strip().splitlines()[0]
90
+ except Exception as e:
91
+ print(f" [LLM error] {e}")
92
+ return "wait"
93
+
94
+
95
+ def parse_action(text: str) -> Action:
96
+ try:
97
+ parts = text.strip().split()
98
+ if parts[0] == "assign_job" and len(parts) == 3:
99
+ return Action(action_type="assign_job", job_id=parts[1], machine_id=parts[2])
100
+ if parts[0] == "repair" and len(parts) == 2:
101
+ return Action(action_type="repair", machine_id=parts[1])
102
+ except Exception:
103
+ pass
104
+ return Action(action_type="wait")
105
+
106
+
107
+ def heuristic_action(obs) -> Tuple[Action, str]:
108
+ for m in obs.machines:
109
+ if m.status == "broken":
110
+ return Action(action_type="repair", machine_id=m.id), f"repair {m.id}"
111
+ for j in sorted(obs.pending_jobs, key=lambda x: (x.deadline, -x.priority)):
112
+ for m in obs.machines:
113
+ if m.status == "idle":
114
+ s = f"assign_job {j.id} {m.id}"
115
+ return Action(action_type="assign_job", job_id=j.id, machine_id=m.id), s
116
+ return Action(action_type="wait"), "wait"
117
+
118
+
119
+ def run_episode(task, episode_num, provider, client, model, best_episode, seed=42, verbose=True) -> Episode:
120
+ env = FactoryEnv(task=task, seed=seed)
121
+ obs = env.reset()
122
+ last_reward = 0.0
123
+ ep = Episode(episode_num=episode_num, task=task)
124
+
125
+ if verbose:
126
+ print(f"\n Episode {episode_num} | task={task} | seed={seed}")
127
+ print(f" {len(obs.machines)} machines, {len(obs.pending_jobs)} jobs, {obs.max_steps} steps")
128
+
129
+ for step in range(1, obs.max_steps + 1):
130
+ if obs.done:
131
+ break
132
+
133
+ obs_text = obs_to_text(obs)
134
+ few_shot = best_episode.to_few_shot() if best_episode and step == 1 else ""
135
+ user = f"{few_shot}\n\n---\n" if few_shot else ""
136
+ user += f"Step {step} | Last reward: {last_reward:+.2f}\n{obs_text}\n\nAction:"
137
+
138
+ messages = [{"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": user}]
139
+ action_text = call_llm(messages, provider, client, model)
140
+ action = parse_action(action_text)
141
+
142
+ if action.action_type == "wait" and (obs.pending_jobs or any(m.status == "broken" for m in obs.machines)):
143
+ action, action_text = heuristic_action(obs)
144
+
145
+ obs = env.step(action)
146
+ reward = obs.reward or 0.0
147
+ last_reward = reward
148
+ ep.steps.append(Step(step, obs_text, action_text, reward, obs.done))
149
+ ep.total_reward += reward
150
+
151
+ if verbose:
152
+ marker = "✓" if reward > 0.5 else ("✗" if reward < -0.05 else "·")
153
+ print(f" [{marker}] step={step:2d} {action_text:<30s} r={reward:+.2f}")
154
+
155
+ if obs.done:
156
+ break
157
+
158
+ ep.score = score_episode(env)
159
+ ep.completed = len(env.completed_jobs)
160
+ ep.late = env.late_jobs
161
+
162
+ if verbose:
163
+ print(f" → score={ep.score:.4f} completed={ep.completed} late={ep.late}")
164
+
165
+ return ep
166
+
167
+
168
+ def train(task, num_episodes, provider, model, save_dir="runs", verbose=True):
169
+ print(f"\n{'='*60}")
170
+ print(f" Smart Factory RL Training")
171
+ print(f" Task: {task} | Episodes: {num_episodes} | Provider: {provider} | Model: {model}")
172
+ print(f"{'='*60}")
173
+
174
+ client = get_claude_client() if provider == "claude" else get_openai_client()
175
+ Path(save_dir).mkdir(exist_ok=True)
176
+
177
+ scores = []
178
+ best_episode = None
179
+
180
+ for ep_num in range(1, num_episodes + 1):
181
+ ep = run_episode(task, ep_num, provider, client, model, best_episode, seed=42 + ep_num - 1, verbose=verbose)
182
+ scores.append(ep.score)
183
+ if best_episode is None or ep.score > best_episode.score:
184
+ best_episode = ep
185
+ print(f" ★ New best: score={ep.score:.4f}")
186
+ if ep_num < num_episodes:
187
+ time.sleep(1.0)
188
+
189
+ print(f"\n{'='*60}")
190
+ print(f" Training Complete — {num_episodes} episodes | Task: {task}")
191
+ print(f" First: {scores[0]:.4f} | Last: {scores[-1]:.4f} | Best: {max(scores):.4f}")
192
+ print(f"\n Score per episode:")
193
+ for i, s in enumerate(scores, 1):
194
+ print(f" ep{i:02d}: {s:.4f} {'█' * int(s * 20)}")
195
+
196
+ out = Path(save_dir) / f"{task}_{provider}_{num_episodes}ep.json"
197
+ out.write_text(json.dumps({"task": task, "provider": provider, "model": model, "num_episodes": num_episodes, "scores": scores, "best_score": max(scores), "final_score": scores[-1]}, indent=2))
198
+ print(f"\n Results saved → {out}")
199
+ return scores
200
+
201
+
202
+ def main():
203
+ parser = argparse.ArgumentParser()
204
+ parser.add_argument("--task", default="easy", choices=["easy", "medium", "hard"])
205
+ parser.add_argument("--episodes", type=int, default=5)
206
+ parser.add_argument("--provider", default="openai", choices=["openai", "claude"])
207
+ parser.add_argument("--model", default="")
208
+ parser.add_argument("--save-dir", default="runs")
209
+ parser.add_argument("--quiet", action="store_true")
210
+ args = parser.parse_args()
211
+ if not args.model:
212
+ args.model = "claude-sonnet-4-6" if args.provider == "claude" else "gpt-4o-mini"
213
+ train(args.task, args.episodes, args.provider, args.model, args.save_dir, not args.quiet)
214
+
215
+
216
+ if __name__ == "__main__":
217
+ main()