"""Warehouse GridWorld - Gradio + Gymnasium navigation game with live RL training. Run: pip install -r requirements.txt python app.py """ from __future__ import annotations import time from collections import deque import gradio as gr import gymnasium as gym import numpy as np from gymnasium import spaces # ---------- Constants ---------- DEFAULT_GRID_SIZE = 9 MAX_STEPS = 100 OBSTACLE_DENSITY = 0.20 UP, RIGHT, DOWN, LEFT = 0, 1, 2, 3 ACTION_NAMES = {0: "UP", 1: "RIGHT", 2: "DOWN", 3: "LEFT"} ACTION_ARROWS = {0: "↑", 1: "→", 2: "↓", 3: "←"} ACTION_DELTAS = { UP: (-1, 0), RIGHT: (0, 1), DOWN: (1, 0), LEFT: (0, -1), } # ---------- Environment ---------- class WarehouseEnv(gym.Env): """Gymnasium env for a randomized warehouse grid. Observation: [agent_x_norm, agent_y_norm, goal_x_norm, goal_y_norm] Action: 0=UP, 1=RIGHT, 2=DOWN, 3=LEFT """ metadata = {"render_modes": ["html"]} def __init__(self, grid_size: int = DEFAULT_GRID_SIZE, max_steps: int = MAX_STEPS): super().__init__() self.grid_size = int(grid_size) self.max_steps = int(max_steps) self.action_space = spaces.Discrete(4) self.observation_space = spaces.Box( low=0.0, high=1.0, shape=(4,), dtype=np.float32 ) self.grid: np.ndarray | None = None self.agent_pos: tuple[int, int] = (0, 0) self.start_pos: tuple[int, int] = (0, 0) self.goal_pos: tuple[int, int] = (0, 0) self.steps = 0 self.total_score = 0.0 self.last_reward = 0.0 self.last_action: int | None = None self.last_rule = "New episode started. Agent begins on S." self.visited: set[tuple[int, int]] = set() self.terminated = False self.truncated = False # --- generation --- def _is_solvable(self, grid: np.ndarray, start: tuple[int, int], goal: tuple[int, int]) -> bool: n = self.grid_size if grid[start] == 1 or grid[goal] == 1: return False seen = {start} q = deque([start]) while q: r, c = q.popleft() if (r, c) == goal: return True for dr, dc in ((-1, 0), (1, 0), (0, -1), (0, 1)): nr, nc = r + dr, c + dc if 0 <= nr < n and 0 <= nc < n and grid[nr, nc] == 0 and (nr, nc) not in seen: seen.add((nr, nc)) q.append((nr, nc)) return False def _generate_grid(self): n = self.grid_size rng = self.np_random for _ in range(300): start = (int(rng.integers(0, n)), int(rng.integers(0, n))) goal = (int(rng.integers(0, n)), int(rng.integers(0, n))) if start == goal: continue grid = (rng.random((n, n)) < OBSTACLE_DENSITY).astype(np.int8) grid[start] = 0 grid[goal] = 0 if self._is_solvable(grid, start, goal): return grid, start, goal return ( np.zeros((n, n), dtype=np.int8), (0, 0), (n - 1, n - 1), ) # --- helpers --- def _get_obs(self) -> np.ndarray: denom = max(self.grid_size - 1, 1) ax, ay = self.agent_pos gx, gy = self.goal_pos return np.array( [ax / denom, ay / denom, gx / denom, gy / denom], dtype=np.float32 ) @staticmethod def _manhattan(a: tuple[int, int], b: tuple[int, int]) -> int: return abs(a[0] - b[0]) + abs(a[1] - b[1]) # --- gym API --- def reset(self, seed: int | None = None, options: dict | None = None): super().reset(seed=seed) self.grid, self.start_pos, self.goal_pos = self._generate_grid() self.agent_pos = self.start_pos self._reset_episode_counters("New episode started. Agent begins on S.") return self._get_obs(), {} def soft_reset(self) -> np.ndarray: """Reset agent to start, keep the same maze. Used between RL episodes.""" self.agent_pos = self.start_pos self._reset_episode_counters("New episode (same maze). Agent at S.") return self._get_obs() def _reset_episode_counters(self, rule: str) -> None: self.steps = 0 self.total_score = 0.0 self.last_reward = 0.0 self.last_action = None self.last_rule = rule self.visited = {self.start_pos} self.terminated = False self.truncated = False def step(self, action: int): if self.terminated or self.truncated: return self._get_obs(), 0.0, self.terminated, self.truncated, {} action = int(action) self.steps += 1 self.last_action = action dr, dc = ACTION_DELTAS[action] nr, nc = self.agent_pos[0] + dr, self.agent_pos[1] + dc n = self.grid_size old_dist = self._manhattan(self.agent_pos, self.goal_pos) reward = 0.0 rule_parts: list[str] = [] out_of_bounds = not (0 <= nr < n and 0 <= nc < n) is_obstacle = (not out_of_bounds) and self.grid[nr, nc] == 1 if out_of_bounds or is_obstacle: reward += -5.0 rule_parts.append( "Invalid move: " + ("out of bounds" if out_of_bounds else "obstacle") + " (-5.0)" ) else: self.agent_pos = (nr, nc) new_dist = self._manhattan(self.agent_pos, self.goal_pos) if new_dist < old_dist: reward += 1.0 rule_parts.append("Closer to goal (+1.0)") elif new_dist > old_dist: reward += -0.5 rule_parts.append("Farther from goal (-0.5)") else: reward += -0.1 rule_parts.append("Same Manhattan distance (-0.1)") if self.agent_pos not in self.visited: reward += 0.3 rule_parts.append("New cell (+0.3)") self.visited.add(self.agent_pos) if self.agent_pos == self.goal_pos: reward += 50.0 rule_parts.append("GOAL reached (+50.0)") self.terminated = True if not self.terminated and self.steps >= self.max_steps: reward += -10.0 rule_parts.append("Step limit timeout (-10.0)") self.truncated = True self.last_reward = reward self.total_score += reward self.last_rule = "; ".join(rule_parts) + "." return self._get_obs(), reward, self.terminated, self.truncated, {} # ---------- Q-learning Agent ---------- class QLearningAgent: """Tabular Q-learning keyed on agent position. Fits this small grid perfectly.""" def __init__( self, n_actions: int = 4, alpha: float = 0.2, gamma: float = 0.95, eps_start: float = 1.0, eps_end: float = 0.05, eps_decay: float = 0.97, ): self.n_actions = n_actions self.alpha = alpha self.gamma = gamma self.eps = eps_start self.eps_end = eps_end self.eps_decay = eps_decay self.q: dict[tuple[int, int], np.ndarray] = {} self.episode = 0 self.last_episode_steps = 0 self.last_episode_score = 0.0 self.last_episode_solved = False self.best_score = -float("inf") self.best_episode = -1 self.solves = 0 def get_q(self, s: tuple[int, int]) -> np.ndarray: if s not in self.q: self.q[s] = np.zeros(self.n_actions, dtype=np.float32) return self.q[s] def select_action(self, s: tuple[int, int], rng: np.random.Generator) -> int: if rng.random() < self.eps: return int(rng.integers(0, self.n_actions)) q = self.get_q(s) best = np.flatnonzero(q == q.max()) return int(rng.choice(best)) def greedy_action(self, s: tuple[int, int], rng: np.random.Generator) -> int: q = self.get_q(s) best = np.flatnonzero(q == q.max()) return int(rng.choice(best)) def update( self, s: tuple[int, int], a: int, r: float, s2: tuple[int, int], done: bool, ) -> None: q_s = self.get_q(s) target = r if done else r + self.gamma * float(self.get_q(s2).max()) q_s[a] += self.alpha * (target - q_s[a]) def end_episode(self, score: float, steps: int, solved: bool) -> None: self.episode += 1 self.last_episode_score = score self.last_episode_steps = steps self.last_episode_solved = solved if solved: self.solves += 1 if score > self.best_score: self.best_score = score self.best_episode = self.episode self.eps = max(self.eps_end, self.eps * self.eps_decay) def policy(self) -> dict[tuple[int, int], int]: """Greedy action per learned state.""" return {s: int(np.argmax(q)) for s, q in self.q.items()} # ---------- Rendering ---------- def render_grid_html(env: WarehouseEnv, policy: dict | None = None) -> str: n = env.grid_size cell_size = max(26, min(56, 520 // n)) dot = int(cell_size * 0.6) css = f""" """ cells: list[str] = [] for r in range(n): for c in range(n): pos = (r, c) if env.grid[r, c] == 1: cls, label = "wh-obstacle", "X" elif pos == env.start_pos: cls, label = "wh-start", "S" elif pos == env.goal_pos: cls, label = "wh-goal", "G" elif policy and pos in policy: cls, label = "wh-arrow", ACTION_ARROWS[policy[pos]] else: cls, label = "wh-empty", "." inner = '
' if pos == env.agent_pos else label cells.append(f'
{inner}
') return css + f'
{"".join(cells)}
' def render_scoreboard_md(env: WarehouseEnv) -> str: if env.terminated: status = "🏁 Goal reached!" elif env.truncated: status = "⏱️ Timed out" else: status = "🎮 Playing" last_action = ( ACTION_NAMES[env.last_action] if env.last_action is not None else "None" ) dist = WarehouseEnv._manhattan(env.agent_pos, env.goal_pos) return f"""### Score Board | Field | Value | |---|---| | **Total Score** | `{env.total_score:+.2f}` | | **Last Reward** | `{env.last_reward:+.2f}` | | **Steps** | `{env.steps} / {env.max_steps}` | | **Agent Position** | `({env.agent_pos[0]}, {env.agent_pos[1]})` | | **Goal Position** | `({env.goal_pos[0]}, {env.goal_pos[1]})` | | **Manhattan Distance** | `{dist}` | | **Status** | {status} | | **Last Action** | `{last_action}` | | **Rule Fired** | {env.last_rule} | """ def render_training_stats(agent: QLearningAgent | None, target: int) -> str: if agent is None: return ( "### 🤖 RL Agent\n\n" "*No agent trained yet. Click **Train Agent** to start " "tabular Q-learning on the current maze.*" ) last_score = f"{agent.last_episode_score:+.2f}" if agent.episode else "—" last_steps = agent.last_episode_steps if agent.episode else "—" best = ( f"{agent.best_score:+.2f} (ep {agent.best_episode})" if agent.best_episode > 0 else "—" ) solved_pct = ( f"{100.0 * agent.solves / max(agent.episode, 1):.1f}%" if agent.episode else "—" ) return f"""### 🤖 RL Agent (Q-learning) | Field | Value | |---|---| | **Episode** | `{agent.episode} / {target}` | | **ε exploration** | `{agent.eps:.3f}` | | **States seen** | `{len(agent.q)}` | | **Solve rate** | `{solved_pct}` | | **Last episode** | score `{last_score}`, steps `{last_steps}` | | **Best episode** | `{best}` | """ # ---------- Training generators ---------- def train_stream( env: WarehouseEnv | None, n_episodes: float, speed_ms: float, ): """Train Q-learning on the current maze, yielding UI updates per step. A speed of 0 ms means we yield only at episode boundaries (fast-forward training); any positive value yields after every step at the requested pace. """ if env is None or env.grid is None: env = WarehouseEnv() env.reset() target = int(n_episodes) delay = max(0.0, float(speed_ms) / 1000.0) rng = np.random.default_rng() agent = QLearningAgent() # initial frame: clear the agent dot at start, show empty policy env.soft_reset() yield ( env, agent, render_grid_html(env, policy=agent.policy()), render_scoreboard_md(env), render_training_stats(agent, target), env.steps, ) for ep in range(target): env.soft_reset() s = env.agent_pos ep_score = 0.0 ep_steps = 0 solved = False while True: a = agent.select_action(s, rng) _, r, term, trunc, _ = env.step(a) s2 = env.agent_pos done = term or trunc agent.update(s, a, r, s2, done) s = s2 ep_score += r ep_steps += 1 if delay > 0.0: yield ( env, agent, render_grid_html(env, policy=agent.policy()), render_scoreboard_md(env), render_training_stats(agent, target), env.steps, ) time.sleep(delay) if done: solved = term break agent.end_episode(ep_score, ep_steps, solved) # Always yield at end-of-episode so the UI refreshes even when delay==0. yield ( env, agent, render_grid_html(env, policy=agent.policy()), render_scoreboard_md(env), render_training_stats(agent, target), env.steps, ) def greedy_stream( env: WarehouseEnv | None, agent: QLearningAgent | None, speed_ms: float, ): """Run one episode using the greedy policy of the trained agent.""" if env is None or env.grid is None or agent is None: return delay = max(0.05, float(speed_ms) / 1000.0) rng = np.random.default_rng() env.soft_reset() yield ( env, render_grid_html(env, policy=agent.policy()), render_scoreboard_md(env), env.steps, ) time.sleep(delay) while True: a = agent.greedy_action(env.agent_pos, rng) _, _, term, trunc, _ = env.step(a) yield ( env, render_grid_html(env, policy=agent.policy()), render_scoreboard_md(env), env.steps, ) time.sleep(delay) if term or trunc: break # ---------- Gradio app ---------- KEYBOARD_JS = """ () => { if (window.__wh_kb_bound) return; window.__wh_kb_bound = true; document.addEventListener('keydown', (e) => { const tag = (e.target && e.target.tagName) || ''; if (tag === 'INPUT' || tag === 'TEXTAREA' || tag === 'SELECT') return; const map = { 'ArrowUp': 'wh-btn-up', 'ArrowRight': 'wh-btn-right', 'ArrowDown': 'wh-btn-down', 'ArrowLeft': 'wh-btn-left', }; const id = map[e.key]; if (!id) return; e.preventDefault(); const wrapper = document.getElementById(id); if (!wrapper) return; const btn = wrapper.querySelector('button') || wrapper; btn.click(); }); } """ def build_app() -> gr.Blocks: initial_env = WarehouseEnv() initial_env.reset(seed=42) with gr.Blocks(title="Warehouse GridWorld") as demo: gr.Markdown( "# 📦 Warehouse GridWorld — Play & Train\n" "Use the **arrow keys** (or buttons) to move the red agent from **S** to **G**. " "Or click **Train Agent** and watch a tabular Q-learning agent learn the maze in real time." ) env_state = gr.State(initial_env) agent_state: gr.State = gr.State(None) with gr.Row(): # ---------- LEFT: grid + manual controls ---------- with gr.Column(scale=3): grid_html = gr.HTML(render_grid_html(initial_env)) gr.Markdown("**Manual play**") with gr.Row(): up_btn = gr.Button("↑ Up", elem_id="wh-btn-up") with gr.Row(): left_btn = gr.Button("← Left", elem_id="wh-btn-left") down_btn = gr.Button("↓ Down", elem_id="wh-btn-down") right_btn = gr.Button("→ Right", elem_id="wh-btn-right") gr.Markdown("**RL training**") with gr.Row(): n_episodes_slider = gr.Slider( minimum=10, maximum=1000, value=200, step=10, label="Training episodes", ) speed_slider = gr.Slider( minimum=0, maximum=300, value=20, step=5, label="Speed (ms / step, 0 = fast)", ) with gr.Row(): train_btn = gr.Button("🤖 Train Agent", variant="primary") stop_btn = gr.Button("⏹ Stop") greedy_btn = gr.Button("▶ Run Greedy") # ---------- RIGHT: settings + stats ---------- with gr.Column(scale=2): grid_size_slider = gr.Slider( minimum=3, maximum=25, value=DEFAULT_GRID_SIZE, step=1, label="Grid Size (resets on change)", ) steps_progress = gr.Slider( minimum=0, maximum=MAX_STEPS, value=0, step=1, label=f"Steps (0 / {MAX_STEPS})", interactive=False, ) reset_btn = gr.Button( "🔁 Reset / Randomize Grid", variant="primary" ) training_stats = gr.Markdown(render_training_stats(None, 0)) scoreboard = gr.Markdown(render_scoreboard_md(initial_env)) # ---------- handlers ---------- def do_step(state: WarehouseEnv, agent: QLearningAgent | None, action: int): state.step(action) policy = agent.policy() if agent else None return ( state, render_grid_html(state, policy=policy), render_scoreboard_md(state), state.steps, ) def do_reset(state: WarehouseEnv, _agent, new_size: float): new_size = int(new_size) if state is None or new_size != state.grid_size: state = WarehouseEnv(grid_size=new_size) state.reset() return ( state, None, # clear trained agent — new maze invalidates it render_grid_html(state), render_scoreboard_md(state), render_training_stats(None, 0), state.steps, ) step_outputs = [env_state, grid_html, scoreboard, steps_progress] reset_outputs = [ env_state, agent_state, grid_html, scoreboard, training_stats, steps_progress, ] train_outputs = reset_outputs # same shape greedy_outputs = step_outputs # Manual play up_evt = up_btn.click( lambda s, a: do_step(s, a, UP), inputs=[env_state, agent_state], outputs=step_outputs, ) right_evt = right_btn.click( lambda s, a: do_step(s, a, RIGHT), inputs=[env_state, agent_state], outputs=step_outputs, ) down_evt = down_btn.click( lambda s, a: do_step(s, a, DOWN), inputs=[env_state, agent_state], outputs=step_outputs, ) left_evt = left_btn.click( lambda s, a: do_step(s, a, LEFT), inputs=[env_state, agent_state], outputs=step_outputs, ) # Reset / grid resize reset_evt = reset_btn.click( do_reset, inputs=[env_state, agent_state, grid_size_slider], outputs=reset_outputs, ) resize_evt = grid_size_slider.release( do_reset, inputs=[env_state, agent_state, grid_size_slider], outputs=reset_outputs, ) # RL training (streaming generator) train_evt = train_btn.click( train_stream, inputs=[env_state, n_episodes_slider, speed_slider], outputs=train_outputs, ) # Cancel training on Stop, Reset, resize, or any manual move. stop_btn.click(fn=None, inputs=None, outputs=None, cancels=[train_evt]) for evt in (reset_evt, resize_evt, up_evt, right_evt, down_evt, left_evt): evt.then(fn=None, inputs=None, outputs=None, cancels=[train_evt]) # Greedy rollout of the trained policy greedy_evt = greedy_btn.click( greedy_stream, inputs=[env_state, agent_state, speed_slider], outputs=greedy_outputs, ) for evt in (reset_evt, resize_evt, up_evt, right_evt, down_evt, left_evt, train_evt): evt.then(fn=None, inputs=None, outputs=None, cancels=[greedy_evt]) demo.load(fn=None, inputs=None, outputs=None, js=KEYBOARD_JS) return demo if __name__ == "__main__": app = build_app() app.launch()