Spaces:
Sleeping
Sleeping
| """ | |
| RL Training Loop — Smart Factory Scheduling | |
| ============================================ | |
| Strategy: Online In-Context RL — best trajectory fed as few-shot example each episode. | |
| Usage: | |
| export OPENAI_API_KEY=sk-... # OpenAI | |
| export ANTHROPIC_API_KEY=sk-ant-... # Claude | |
| python train.py --task easy --episodes 10 --provider openai | |
| python train.py --task medium --episodes 10 --provider claude | |
| """ | |
| import argparse | |
| import json | |
| import os | |
| import time | |
| from dataclasses import dataclass, field | |
| from pathlib import Path | |
| from typing import List, Optional, Tuple | |
| from factory_env.env import FactoryEnv | |
| from factory_env.grader import score_episode | |
| from factory_env.models import FactoryAction as Action | |
def get_openai_client():
    """Create an OpenAI-compatible chat client configured from the environment.

    The API key is the first non-empty value among ``OPENAI_API_KEY``,
    ``HF_TOKEN`` and ``API_KEY`` (falling back to whatever ``API_KEY`` holds,
    possibly ``None``). The endpoint is ``API_BASE_URL`` when set and non-empty,
    otherwise the public OpenAI URL. ``openai`` is imported lazily so the
    package is only needed when this provider is used.
    """
    from openai import OpenAI

    key_vars = ("OPENAI_API_KEY", "HF_TOKEN", "API_KEY")
    values = [os.getenv(name) for name in key_vars]
    api_key = next((value for value in values if value), values[-1])
    base_url = os.getenv("API_BASE_URL") or "https://api.openai.com/v1"
    return OpenAI(api_key=api_key, base_url=base_url)
def get_claude_client():
    """Create an Anthropic client keyed by ``ANTHROPIC_API_KEY``.

    ``anthropic`` is imported lazily so the package is only needed when the
    Claude provider is selected.
    """
    from anthropic import Anthropic

    api_key = os.getenv("ANTHROPIC_API_KEY")
    return Anthropic(api_key=api_key)
@dataclass
class Step:
    """One environment transition recorded while rolling out an episode.

    Declared as a dataclass so it can be built positionally, e.g.
    ``Step(step, obs_text, action_text, reward, done)``.
    """

    step: int  # step index within the episode (run_episode counts from 1)
    obs_text: str  # observation text shown to the LLM before acting
    action_text: str  # text of the action taken at this step
    reward: float  # reward reported by the environment for this step
    done: bool  # whether the episode had ended after this step
@dataclass
class Episode:
    """Record of one rolled-out episode plus its summary statistics.

    Declared as a dataclass so ``field(default_factory=list)`` gives every
    episode its own ``steps`` list and keyword construction
    (``Episode(episode_num=..., task=...)``) works.
    """

    episode_num: int  # 1-based episode counter from the training loop
    task: str  # task name the environment was created with
    steps: List["Step"] = field(default_factory=list)  # transitions in order
    total_reward: float = 0.0  # sum of per-step rewards
    score: float = 0.0  # grader score for the finished episode
    completed: int = 0  # number of jobs completed by the end of the episode
    late: int = 0  # late-job count reported by the environment

    def to_few_shot(self, max_steps: int = 6) -> str:
        """Render this episode as a few-shot example for the LLM prompt.

        Args:
            max_steps: Include at most this many leading steps (slice bound,
                so ``None`` includes all of them).

        Returns:
            A header line with score and completed-job count, followed by an
            ``[Obs]`` / ``[Action]`` line pair per included step, joined by
            newlines.
        """
        lines = [f"# Best trajectory so far (score={self.score:.2f}, completed={self.completed} jobs)"]
        for s in self.steps[:max_steps]:
            lines.append(f"[Obs] {s.obs_text}")
            lines.append(f"[Action] {s.action_text} → reward: {s.reward:+.2f}")
        return "\n".join(lines)
# System prompt shared by every LLM call. The action grammar listed here is
# what parse_action accepts, and call_llm keeps only the first line of the
# reply, so the model is told explicitly to answer with one bare action line.
SYSTEM_PROMPT = """You are an expert factory scheduling AI.
Goal: complete all jobs before deadlines, keep machines busy, repair broken machines.
Actions (one per step):
assign_job <job_id> <machine_id>
repair <machine_id>
wait
Reply with exactly one line containing a single action in the format above, with no explanation, quotes, or markdown.
Tips: Fix broken machines first. Sort by earliest deadline. High-priority jobs give bonus reward."""
def obs_to_text(obs) -> str:
    """Summarise an observation as a single compact line for the LLM.

    Format: ``t=<time> | machines: <id>:<status>[(<job>)], ... | pending: <jobs>``
    where each pending job is ``<id>[t=<remaining>,dl=<deadline>,p=<priority>]``
    and ``none`` is shown when no jobs are pending. A machine's current job is
    appended in parentheses only when it is truthy.
    """
    machine_parts = []
    for machine in obs.machines:
        text = f"{machine.id}:{machine.status}"
        if machine.current_job:
            text += f"({machine.current_job})"
        machine_parts.append(text)
    job_parts = [
        f"{job.id}[t={job.remaining_time},dl={job.deadline},p={job.priority}]"
        for job in obs.pending_jobs
    ]
    pending = ", ".join(job_parts) if job_parts else "none"
    return f"t={obs.time} | machines: {', '.join(machine_parts)} | pending: {pending}"
def call_llm(messages: list, provider: str, client, model: str, *, max_tokens: int = 60, temperature: float = 0.2) -> str:
    """Send ``messages`` to the chosen provider and return the reply's first line.

    Args:
        messages: OpenAI-style chat messages (dicts with ``role``/``content``).
            For Claude, the first ``system`` message is passed via the
            dedicated ``system`` argument and the remaining messages are sent
            unchanged.
        provider: ``"claude"`` uses ``client.messages.create``; any other value
            uses ``client.chat.completions.create``.
        client: The provider client matching ``provider``.
        model: Model identifier forwarded to the provider.
        max_tokens: Completion length cap; a single action line is short.
        temperature: Sampling temperature, applied to both providers so runs
            with either backend sample comparably.

    Returns:
        The first line of the stripped reply, or ``"wait"`` when the reply is
        empty or the request raises. Request errors are printed rather than
        raised so a single failed call does not abort a training run.
    """
    try:
        if provider == "claude":
            system = next((m["content"] for m in messages if m["role"] == "system"), "")
            chat = [m for m in messages if m["role"] != "system"]
            resp = client.messages.create(
                model=model,
                max_tokens=max_tokens,
                temperature=temperature,
                system=system,
                messages=chat,
            )
            # Concatenate text blocks; non-text blocks (if any) contribute nothing.
            text = "".join(getattr(block, "text", "") or "" for block in resp.content)
        else:
            resp = client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
            )
            text = resp.choices[0].message.content or ""
    except Exception as e:  # deliberate best-effort: network/auth/rate-limit errors fall back to "wait"
        print(f" [LLM error] {e}")
        return "wait"
    # An empty or whitespace-only reply is a normal outcome, not an error.
    lines = text.strip().splitlines()
    return lines[0].strip() if lines else "wait"
def parse_action(text: str) -> Action:
    """Parse one line of LLM output into an environment action.

    Accepted commands (the command word is matched case-insensitively; ids
    keep their case):
        assign_job <job_id> <machine_id>
        repair <machine_id>
        wait
    Surrounding backticks/quotes, a trailing period, and a leading
    ``Action:`` label are ignored. The label is stripped because the user
    prompt ends with ``Action:`` and models sometimes echo it.

    Returns:
        The parsed action, or a ``wait`` action for empty or unrecognised text.
    """
    cleaned = (text or "").strip().strip("`'\".").strip()
    if cleaned.lower().startswith("action:"):
        cleaned = cleaned[len("action:"):].strip().strip("`'\".").strip()
    parts = cleaned.split()
    if parts:
        command = parts[0].lower()
        try:
            if command == "assign_job" and len(parts) == 3:
                return Action(action_type="assign_job", job_id=parts[1], machine_id=parts[2])
            if command == "repair" and len(parts) == 2:
                return Action(action_type="repair", machine_id=parts[1])
        except Exception:
            # NOTE(review): Action may validate its fields (exception type not
            # visible here); a rejected action is treated like unparseable
            # text and falls through to "wait".
            pass
    return Action(action_type="wait")
def heuristic_action(obs) -> Tuple[Action, str]:
    """Choose a simple rule-based action for ``obs``.

    Rules, in order:
      1. Repair the first machine whose status is ``"broken"``.
      2. Otherwise assign the pending job with the earliest deadline (ties go
         to the higher priority, then to original order) to the first
         machine whose status is ``"idle"``.
      3. Otherwise wait.

    Returns:
        ``(action, action_text)``, where ``action_text`` uses the command
        syntax the LLM is prompted with.
    """
    broken = next((m for m in obs.machines if m.status == "broken"), None)
    if broken is not None:
        return Action(action_type="repair", machine_id=broken.id), f"repair {broken.id}"

    idle = next((m for m in obs.machines if m.status == "idle"), None)
    if idle is None or not obs.pending_jobs:
        return Action(action_type="wait"), "wait"

    ranked = sorted(obs.pending_jobs, key=lambda job: (job.deadline, -job.priority))
    chosen = ranked[0]
    label = f"assign_job {chosen.id} {idle.id}"
    return Action(action_type="assign_job", job_id=chosen.id, machine_id=idle.id), label
def _format_action(action) -> str:
    """Render ``action`` in the command syntax used by SYSTEM_PROMPT."""
    if action.action_type == "assign_job":
        return f"assign_job {action.job_id} {action.machine_id}"
    if action.action_type == "repair":
        return f"repair {action.machine_id}"
    return "wait"


def _build_user_prompt(few_shot: str, step: int, last_reward: float, obs_text: str) -> str:
    """Compose one step's user message: optional few-shot block, then the current state."""
    prefix = f"{few_shot}\n\n---\n" if few_shot else ""
    return prefix + f"Step {step} | Last reward: {last_reward:+.2f}\n{obs_text}\n\nAction:"


def run_episode(task, episode_num, provider, client, model, best_episode, seed=42, verbose=True) -> Episode:
    """Roll out one episode with the LLM policy and return its record.

    A fresh ``FactoryEnv(task=task, seed=seed)`` is reset and stepped until
    it reports ``done`` or ``max_steps`` (read from the initial observation)
    is reached. Each step sends a single-turn prompt to the LLM. If the
    parsed action is ``wait`` while jobs are pending or a machine is broken,
    ``heuristic_action`` is used instead.

    Args:
        task: Task name passed to ``FactoryEnv``.
        episode_num: Counter stored on the returned Episode and shown in logs.
        provider, client, model: Forwarded to ``call_llm``.
        best_episode: Best Episode so far, or ``None``. When given, its
            ``to_few_shot()`` text is prepended to every step's prompt.
        seed: Environment seed.
        verbose: Print an episode header, one line per step and the final score.

    Returns:
        Episode with per-step records, total reward, grader score, completed
        job count and late-job count.
    """
    env = FactoryEnv(task=task, seed=seed)
    obs = env.reset()
    ep = Episode(episode_num=episode_num, task=task)
    # Each call_llm request below is an independent single-turn conversation,
    # so the example has to be included on every step. Sending it only on the
    # first step would leave every later decision without it.
    few_shot = best_episode.to_few_shot() if best_episode is not None else ""
    last_reward = 0.0

    if verbose:
        print(f"\n Episode {episode_num} | task={task} | seed={seed}")
        print(f" {len(obs.machines)} machines, {len(obs.pending_jobs)} jobs, {obs.max_steps} steps")

    for step in range(1, obs.max_steps + 1):
        if obs.done:
            break
        obs_text = obs_to_text(obs)
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": _build_user_prompt(few_shot, step, last_reward, obs_text)},
        ]
        action = parse_action(call_llm(messages, provider, client, model))
        # A "wait" (chosen, or the fallback for an unparseable reply) while
        # there is work to do wastes a step, so defer to the heuristic.
        if action.action_type == "wait" and (obs.pending_jobs or any(m.status == "broken" for m in obs.machines)):
            action, _ = heuristic_action(obs)
        # Record the action actually executed rather than the raw LLM text, so
        # logs are accurate and few-shot trajectories only show valid commands.
        action_text = _format_action(action)

        obs = env.step(action)
        reward = obs.reward or 0.0
        last_reward = reward
        ep.steps.append(Step(step, obs_text, action_text, reward, obs.done))
        ep.total_reward += reward
        if verbose:
            marker = "✓" if reward > 0.5 else ("✗" if reward < -0.05 else "·")
            print(f" [{marker}] step={step:2d} {action_text:<30s} r={reward:+.2f}")

    ep.score = score_episode(env)
    ep.completed = len(env.completed_jobs)
    ep.late = env.late_jobs
    if verbose:
        print(f" → score={ep.score:.4f} completed={ep.completed} late={ep.late}")
    return ep
def train(task, num_episodes, provider, model, save_dir="runs", verbose=True, *, sleep_seconds=1.0):
    """Run online in-context RL for ``num_episodes`` episodes and save the scores.

    Episode ``k`` (1-based) runs with seed ``42 + k - 1``. After each episode,
    the highest-scoring episode so far is handed to the next episode as its
    few-shot example. On a tie, the earlier episode is kept.

    Args:
        task: Task name forwarded to ``run_episode``.
        num_episodes: Number of episodes to run; must be at least 1.
        provider: ``"claude"`` selects the Anthropic client; anything else
            uses the OpenAI-compatible client.
        model: Model identifier forwarded to the provider.
        save_dir: Directory for the JSON summary; created (with parents) if missing.
        verbose: Forwarded to ``run_episode`` for per-episode/per-step logging.
        sleep_seconds: Pause between consecutive episodes (e.g. to ease API
            rate limits); ``0`` disables it.

    Returns:
        List of per-episode grader scores, in episode order.

    Raises:
        ValueError: If ``num_episodes`` is less than 1.
    """
    # Validate before building an API client. With zero episodes the summary
    # below (scores[0], max(scores)) would otherwise crash after the fact.
    if num_episodes < 1:
        raise ValueError(f"num_episodes must be >= 1, got {num_episodes}")

    rule = "=" * 60
    print(f"\n{rule}")
    print(" Smart Factory RL Training")
    print(f" Task: {task} | Episodes: {num_episodes} | Provider: {provider} | Model: {model}")
    print(rule)

    client = get_claude_client() if provider == "claude" else get_openai_client()
    out_dir = Path(save_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    scores = []
    best_episode = None
    for ep_num in range(1, num_episodes + 1):
        ep = run_episode(task, ep_num, provider, client, model, best_episode, seed=42 + ep_num - 1, verbose=verbose)
        scores.append(ep.score)
        if best_episode is None or ep.score > best_episode.score:
            best_episode = ep
            print(f" ★ New best: score={ep.score:.4f}")
        if ep_num < num_episodes and sleep_seconds > 0:
            time.sleep(sleep_seconds)

    print(f"\n{rule}")
    print(f" Training Complete — {num_episodes} episodes | Task: {task}")
    print(f" First: {scores[0]:.4f} | Last: {scores[-1]:.4f} | Best: {max(scores):.4f}")
    print("\n Score per episode:")
    for i, s in enumerate(scores, 1):
        # Text bar: 20 blocks per 1.0 of score (negative scores draw no bar).
        print(f" ep{i:02d}: {s:.4f} {'█' * int(s * 20)}")

    out = out_dir / f"{task}_{provider}_{num_episodes}ep.json"
    summary = {
        "task": task,
        "provider": provider,
        "model": model,
        "num_episodes": num_episodes,
        "scores": scores,
        "best_score": max(scores),
        "final_score": scores[-1],
    }
    out.write_text(json.dumps(summary, indent=2), encoding="utf-8")
    print(f"\n Results saved → {out}")
    return scores
| def main(): | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument("--task", default="easy", choices=["easy", "medium", "hard"]) | |
| parser.add_argument("--episodes", type=int, default=5) | |
| parser.add_argument("--provider", default="openai", choices=["openai", "claude"]) | |
| parser.add_argument("--model", default="") | |
| parser.add_argument("--save-dir", default="runs") | |
| parser.add_argument("--quiet", action="store_true") | |
| args = parser.parse_args() | |
| if not args.model: | |
| args.model = "claude-sonnet-4-6" if args.provider == "claude" else "gpt-4o-mini" | |
| train(args.task, args.episodes, args.provider, args.model, args.save_dir, not args.quiet) | |
| if __name__ == "__main__": | |
| main() | |