"""Warehouse GridWorld - Gradio + Gymnasium navigation game with live RL training.
Run:
pip install -r requirements.txt
python app.py
"""
from __future__ import annotations
import time
from collections import deque
import gradio as gr
import gymnasium as gym
import numpy as np
from gymnasium import spaces
# ---------- Constants ----------
DEFAULT_GRID_SIZE = 9  # side length of the square grid (cells per row/column)
MAX_STEPS = 100  # per-episode step budget before the episode is truncated
OBSTACLE_DENSITY = 0.20  # chance that a generated cell is an obstacle
# Action ids, clockwise starting from "up".
UP, RIGHT, DOWN, LEFT = range(4)
# Human-readable name and arrow glyph for each action id (keys 0..3 in order).
ACTION_NAMES = dict(enumerate(("UP", "RIGHT", "DOWN", "LEFT")))
ACTION_ARROWS = dict(enumerate("↑→↓←"))
# (row delta, column delta) applied by each action; row 0 is the top of the grid.
ACTION_DELTAS = {UP: (-1, 0), RIGHT: (0, 1), DOWN: (1, 0), LEFT: (0, -1)}
# ---------- Environment ----------
class WarehouseEnv(gym.Env):
    """Gymnasium env for a randomized warehouse grid.

    The grid is an ``n x n`` int8 array where ``1`` marks an obstacle and ``0``
    a free cell. All positions are ``(row, col)`` tuples, row 0 at the top.

    Observation: ``[agent_row_norm, agent_col_norm, goal_row_norm, goal_col_norm]``,
    each coordinate divided by ``grid_size - 1`` so it lies in ``[0, 1]``.
    Action: 0=UP, 1=RIGHT, 2=DOWN, 3=LEFT

    Reward components for one step (summed):
      * -5.0 for an invalid move (off-grid or into an obstacle); the agent stays put.
      * Otherwise +1.0 / -0.5 / -0.1 when the Manhattan distance to the goal
        shrinks / grows / is unchanged, plus +0.3 for entering a cell not yet
        visited this episode.
      * +50.0 for reaching the goal (``terminated``).
      * -10.0 when ``max_steps`` is hit without reaching the goal (``truncated``).
    """

    metadata = {"render_modes": ["html"]}

    def __init__(self, grid_size: int = DEFAULT_GRID_SIZE, max_steps: int = MAX_STEPS):
        """Create an un-initialised env; call ``reset`` before ``step``.

        Raises:
            ValueError: if ``grid_size`` is smaller than 2 (start and goal
                could never be distinct cells).
        """
        super().__init__()
        self.grid_size = int(grid_size)
        if self.grid_size < 2:
            # A 1x1 grid forces start == goal and makes every move out of bounds.
            raise ValueError(f"grid_size must be at least 2, got {self.grid_size}.")
        self.max_steps = int(max_steps)
        self.action_space = spaces.Discrete(4)
        self.observation_space = spaces.Box(
            low=0.0, high=1.0, shape=(4,), dtype=np.float32
        )
        self.grid: np.ndarray | None = None  # populated by reset()
        self.agent_pos: tuple[int, int] = (0, 0)
        self.start_pos: tuple[int, int] = (0, 0)
        self.goal_pos: tuple[int, int] = (0, 0)
        # Per-episode bookkeeping, also read by the UI renderers.
        self.steps = 0
        self.total_score = 0.0
        self.last_reward = 0.0
        self.last_action: int | None = None
        self.last_rule = "New episode started. Agent begins on S."
        self.visited: set[tuple[int, int]] = set()
        self.terminated = False
        self.truncated = False

    # --- generation ---
    def _is_solvable(self, grid: np.ndarray, start: tuple[int, int], goal: tuple[int, int]) -> bool:
        """Return True if ``goal`` is reachable from ``start`` via 4-neighbour moves over free cells (BFS)."""
        n = self.grid_size
        if grid[start] == 1 or grid[goal] == 1:
            return False
        seen = {start}
        q = deque([start])
        while q:
            r, c = q.popleft()
            if (r, c) == goal:
                return True
            for dr, dc in ((-1, 0), (1, 0), (0, -1), (0, 1)):
                nr, nc = r + dr, c + dc
                if 0 <= nr < n and 0 <= nc < n and grid[nr, nc] == 0 and (nr, nc) not in seen:
                    seen.add((nr, nc))
                    q.append((nr, nc))
        return False

    def _generate_grid(self):
        """Sample ``(grid, start, goal)`` with a guaranteed path from start to goal.

        Tries up to 300 random layouts (distinct start/goal, each cell an obstacle
        with probability ``OBSTACLE_DENSITY``); if none is solvable, falls back to
        an obstacle-free grid from the top-left to the bottom-right corner.
        """
        n = self.grid_size
        rng = self.np_random
        for _ in range(300):
            start = (int(rng.integers(0, n)), int(rng.integers(0, n)))
            goal = (int(rng.integers(0, n)), int(rng.integers(0, n)))
            if start == goal:
                continue
            grid = (rng.random((n, n)) < OBSTACLE_DENSITY).astype(np.int8)
            grid[start] = 0
            grid[goal] = 0
            if self._is_solvable(grid, start, goal):
                return grid, start, goal
        return (
            np.zeros((n, n), dtype=np.int8),
            (0, 0),
            (n - 1, n - 1),
        )

    # --- helpers ---
    def _get_obs(self) -> np.ndarray:
        """Return agent and goal ``(row, col)`` scaled by ``grid_size - 1`` into ``[0, 1]``."""
        denom = max(self.grid_size - 1, 1)
        agent_r, agent_c = self.agent_pos
        goal_r, goal_c = self.goal_pos
        return np.array(
            [agent_r / denom, agent_c / denom, goal_r / denom, goal_c / denom],
            dtype=np.float32,
        )

    @staticmethod
    def _manhattan(a: tuple[int, int], b: tuple[int, int]) -> int:
        """Manhattan (L1) distance between two grid positions."""
        return abs(a[0] - b[0]) + abs(a[1] - b[1])

    # --- gym API ---
    def reset(self, seed: int | None = None, options: dict | None = None):
        """Generate a brand-new maze, place the agent on S and return ``(obs, info)``."""
        super().reset(seed=seed)
        self.grid, self.start_pos, self.goal_pos = self._generate_grid()
        self.agent_pos = self.start_pos
        self._reset_episode_counters("New episode started. Agent begins on S.")
        return self._get_obs(), {}

    def soft_reset(self) -> np.ndarray:
        """Reset agent to start, keep the same maze. Used between RL episodes."""
        self.agent_pos = self.start_pos
        self._reset_episode_counters("New episode (same maze). Agent at S.")
        return self._get_obs()

    def _reset_episode_counters(self, rule: str) -> None:
        """Clear per-episode counters/flags; ``rule`` becomes the displayed status text."""
        self.steps = 0
        self.total_score = 0.0
        self.last_reward = 0.0
        self.last_action = None
        self.last_rule = rule
        self.visited = {self.start_pos}
        self.terminated = False
        self.truncated = False

    def step(self, action: int):
        """Apply one move and return ``(obs, reward, terminated, truncated, info)``.

        Calling ``step`` after the episode has ended is a no-op with zero reward.

        Raises:
            RuntimeError: if called before ``reset`` (no maze exists yet).
            ValueError: if ``action`` is not one of 0..3.
        """
        if self.grid is None:
            raise RuntimeError("WarehouseEnv.step() called before reset().")
        if self.terminated or self.truncated:
            return self._get_obs(), 0.0, self.terminated, self.truncated, {}
        action = int(action)
        if action not in ACTION_DELTAS:
            raise ValueError(
                f"Invalid action {action!r}; expected one of {sorted(ACTION_DELTAS)}."
            )
        self.steps += 1
        self.last_action = action
        dr, dc = ACTION_DELTAS[action]
        nr, nc = self.agent_pos[0] + dr, self.agent_pos[1] + dc
        n = self.grid_size
        old_dist = self._manhattan(self.agent_pos, self.goal_pos)
        reward = 0.0
        rule_parts: list[str] = []
        out_of_bounds = not (0 <= nr < n and 0 <= nc < n)
        # Short-circuit keeps us from indexing the grid with an off-grid cell.
        is_obstacle = (not out_of_bounds) and self.grid[nr, nc] == 1
        if out_of_bounds or is_obstacle:
            reward += -5.0
            rule_parts.append(
                "Invalid move: " + ("out of bounds" if out_of_bounds else "obstacle")
                + " (-5.0)"
            )
        else:
            self.agent_pos = (nr, nc)
            new_dist = self._manhattan(self.agent_pos, self.goal_pos)
            if new_dist < old_dist:
                reward += 1.0
                rule_parts.append("Closer to goal (+1.0)")
            elif new_dist > old_dist:
                reward += -0.5
                rule_parts.append("Farther from goal (-0.5)")
            else:
                reward += -0.1
                rule_parts.append("Same Manhattan distance (-0.1)")
            if self.agent_pos not in self.visited:
                reward += 0.3
                rule_parts.append("New cell (+0.3)")
                self.visited.add(self.agent_pos)
            if self.agent_pos == self.goal_pos:
                reward += 50.0
                rule_parts.append("GOAL reached (+50.0)")
                self.terminated = True
        if not self.terminated and self.steps >= self.max_steps:
            reward += -10.0
            rule_parts.append("Step limit timeout (-10.0)")
            self.truncated = True
        self.last_reward = reward
        self.total_score += reward
        self.last_rule = "; ".join(rule_parts) + "."
        return self._get_obs(), reward, self.terminated, self.truncated, {}
# ---------- Q-learning Agent ----------
class QLearningAgent:
    """Tabular Q-learning keyed on agent position. Fits this small grid perfectly.

    The Q-table maps a ``(row, col)`` state to a float32 vector of action values,
    created lazily as all zeros. Exploration is epsilon-greedy; epsilon shrinks
    multiplicatively by ``eps_decay`` after every finished episode but never
    below ``eps_end``.
    """

    def __init__(
        self,
        n_actions: int = 4,
        alpha: float = 0.2,
        gamma: float = 0.95,
        eps_start: float = 1.0,
        eps_end: float = 0.05,
        eps_decay: float = 0.97,
    ):
        # Learning hyper-parameters.
        self.n_actions = n_actions
        self.alpha = alpha
        self.gamma = gamma
        # Exploration schedule.
        self.eps = eps_start
        self.eps_end = eps_end
        self.eps_decay = eps_decay
        self.q: dict[tuple[int, int], np.ndarray] = {}
        # Episode statistics surfaced in the UI.
        self.episode = 0
        self.last_episode_steps = 0
        self.last_episode_score = 0.0
        self.last_episode_solved = False
        self.best_score = -float("inf")
        self.best_episode = -1
        self.solves = 0

    def get_q(self, s: tuple[int, int]) -> np.ndarray:
        """Return the mutable action-value row for ``s``, inserting zeros on first use."""
        row = self.q.get(s)
        if row is None:
            row = np.zeros(self.n_actions, dtype=np.float32)
            self.q[s] = row
        return row

    def select_action(self, s: tuple[int, int], rng: np.random.Generator) -> int:
        """Epsilon-greedy choice: uniform random with probability ``eps``, else greedy."""
        if rng.random() >= self.eps:
            return self.greedy_action(s, rng)
        return int(rng.integers(0, self.n_actions))

    def greedy_action(self, s: tuple[int, int], rng: np.random.Generator) -> int:
        """Pick a highest-valued action for ``s``, breaking ties uniformly via ``rng``."""
        values = self.get_q(s)
        (ties,) = np.nonzero(values == values.max())
        return int(rng.choice(ties))

    def update(
        self,
        s: tuple[int, int],
        a: int,
        r: float,
        s2: tuple[int, int],
        done: bool,
    ) -> None:
        """One-step Q-learning backup of ``Q[s, a]`` toward ``r (+ gamma * max Q[s2])``.

        When ``done`` is true the successor's value is not bootstrapped.
        """
        row = self.get_q(s)
        if done:
            target = r
        else:
            target = r + self.gamma * float(self.get_q(s2).max())
        row[a] += self.alpha * (target - row[a])

    def end_episode(self, score: float, steps: int, solved: bool) -> None:
        """Record a finished episode's stats and decay epsilon once."""
        self.episode += 1
        self.last_episode_score = score
        self.last_episode_steps = steps
        self.last_episode_solved = solved
        if solved:
            self.solves += 1
        if score > self.best_score:
            self.best_score, self.best_episode = score, self.episode
        self.eps = max(self.eps_end, self.eps * self.eps_decay)

    def policy(self) -> dict[tuple[int, int], int]:
        """Greedy action per learned state (first maximum wins on ties)."""
        return {state: int(values.argmax()) for state, values in self.q.items()}
# ---------- Rendering ----------
def render_grid_html(env: WarehouseEnv, policy: dict | None = None) -> str:
    """Render ``env`` as a self-contained HTML snippet (inline ``<style>`` + CSS grid).

    Cell precedence: obstacle ``X`` > start ``S`` > goal ``G`` > policy arrow
    (when ``policy`` maps that position to an action) > empty ``.``. The cell
    currently holding the agent shows a red dot instead of its label.

    Args:
        env: Environment to draw; ``env.grid`` must already be populated.
        policy: Optional mapping ``(row, col) -> action id`` whose arrows are
            drawn on cells that are not obstacles, start or goal.

    Returns:
        HTML string for a ``gr.HTML`` component.
    """
    n = env.grid_size
    # Shrink cells on large grids so the board stays roughly within 520px.
    cell_size = max(26, min(56, 520 // n))
    dot = int(cell_size * 0.6)
    font_size = max(12, cell_size // 2)
    css = f"""
<style>
.wh-grid {{
  display: grid;
  grid-template-columns: repeat({n}, {cell_size}px);
  grid-auto-rows: {cell_size}px;
  gap: 2px;
  justify-content: center;
  font-family: ui-monospace, monospace;
  font-weight: 700;
  font-size: {font_size}px;
}}
.wh-cell {{
  display: flex;
  align-items: center;
  justify-content: center;
  border-radius: 4px;
  user-select: none;
}}
.wh-empty {{ background: #f3f4f6; color: #9ca3af; }}
.wh-obstacle {{ background: #374151; color: #f9fafb; }}
.wh-start {{ background: #bfdbfe; color: #1e3a8a; }}
.wh-goal {{ background: #bbf7d0; color: #14532d; }}
.wh-arrow {{ background: #fef3c7; color: #92400e; }}
.wh-agent {{
  width: {dot}px;
  height: {dot}px;
  border-radius: 50%;
  background: #dc2626;
  box-shadow: 0 0 4px rgba(0, 0, 0, 0.4);
}}
</style>
"""
    cells: list[str] = []
    for r in range(n):
        for c in range(n):
            pos = (r, c)
            if env.grid[r, c] == 1:
                cls, label = "wh-obstacle", "X"
            elif pos == env.start_pos:
                cls, label = "wh-start", "S"
            elif pos == env.goal_pos:
                cls, label = "wh-goal", "G"
            elif policy and pos in policy:
                cls, label = "wh-arrow", ACTION_ARROWS[policy[pos]]
            else:
                cls, label = "wh-empty", "."
            inner = '<div class="wh-agent"></div>' if pos == env.agent_pos else label
            cells.append(f'<div class="wh-cell {cls}">{inner}</div>')
    body = "".join(cells)
    return css + f'<div class="wh-grid">{body}</div>'
def render_scoreboard_md(env: WarehouseEnv) -> str:
    """Build the Markdown score-board table for the current episode of ``env``.

    Rows: cumulative score, last reward, step counter, agent/goal positions,
    Manhattan distance, episode status, last action name and the reward
    rule(s) that fired on the most recent step.
    """
    status = (
        "🏁 Goal reached!" if env.terminated
        else "⏱️ Timed out" if env.truncated
        else "🎮 Playing"
    )
    action_label = "None" if env.last_action is None else ACTION_NAMES[env.last_action]
    distance = WarehouseEnv._manhattan(env.agent_pos, env.goal_pos)
    agent_r, agent_c = env.agent_pos[0], env.agent_pos[1]
    goal_r, goal_c = env.goal_pos[0], env.goal_pos[1]
    rows = [
        ("Total Score", f"`{env.total_score:+.2f}`"),
        ("Last Reward", f"`{env.last_reward:+.2f}`"),
        ("Steps", f"`{env.steps} / {env.max_steps}`"),
        ("Agent Position", f"`({agent_r}, {agent_c})`"),
        ("Goal Position", f"`({goal_r}, {goal_c})`"),
        ("Manhattan Distance", f"`{distance}`"),
        ("Status", status),
        ("Last Action", f"`{action_label}`"),
        ("Rule Fired", env.last_rule),
    ]
    lines = ["### Score Board", "| Field | Value |", "|---|---|"]
    lines.extend(f"| **{field}** | {value} |" for field, value in rows)
    return "\n".join(lines) + "\n"
def render_training_stats(agent: QLearningAgent | None, target: int) -> str:
    """Summarise Q-learning progress as Markdown.

    ``target`` is the number of episodes requested for the current run. With no
    agent, a short placeholder inviting the user to train is returned instead
    of the stats table. Fields that need at least one finished episode show "—".
    """
    if agent is None:
        return (
            "### 🤖 RL Agent\n\n"
            "*No agent trained yet. Click **Train Agent** to start "
            "tabular Q-learning on the current maze.*"
        )
    placeholder = "—"
    if agent.episode:
        last_score = f"{agent.last_episode_score:+.2f}"
        last_steps = agent.last_episode_steps
        solved_pct = f"{100.0 * agent.solves / max(agent.episode, 1):.1f}%"
    else:
        last_score = last_steps = solved_pct = placeholder
    if agent.best_episode > 0:
        best = f"{agent.best_score:+.2f} (ep {agent.best_episode})"
    else:
        best = placeholder
    lines = [
        "### 🤖 RL Agent (Q-learning)",
        "| Field | Value |",
        "|---|---|",
        f"| **Episode** | `{agent.episode} / {target}` |",
        f"| **ε exploration** | `{agent.eps:.3f}` |",
        f"| **States seen** | `{len(agent.q)}` |",
        f"| **Solve rate** | `{solved_pct}` |",
        f"| **Last episode** | score `{last_score}`, steps `{last_steps}` |",
        f"| **Best episode** | `{best}` |",
    ]
    return "\n".join(lines) + "\n"
# ---------- Training generators ----------
def train_stream(
    env: WarehouseEnv | None,
    n_episodes: float,
    speed_ms: float,
):
    """Train a fresh Q-learning agent on the current maze, yielding UI updates.

    Every yielded frame is ``(env, agent, grid_html, scoreboard_md, stats_md, steps)``.
    A speed of 0 ms means we yield only at episode boundaries (fast-forward training);
    any positive value yields after every step and sleeps ``speed_ms`` between steps.

    If ``env`` is missing or has never been reset, a new default environment is
    created. The maze stays fixed across episodes (``soft_reset``).
    """
    if env is None or env.grid is None:
        env = WarehouseEnv()
        env.reset()
    target = int(n_episodes)
    delay = max(0.0, float(speed_ms) / 1000.0)
    rng = np.random.default_rng()
    agent = QLearningAgent()

    def frame():
        """Snapshot env + agent as one tuple of UI outputs."""
        return (
            env,
            agent,
            render_grid_html(env, policy=agent.policy()),
            render_scoreboard_md(env),
            render_training_stats(agent, target),
            env.steps,
        )

    # Initial frame: agent back on S, empty policy overlay.
    env.soft_reset()
    yield frame()
    for _ in range(target):
        env.soft_reset()
        s = env.agent_pos
        ep_score = 0.0
        ep_steps = 0
        while True:
            a = agent.select_action(s, rng)
            _, r, term, trunc, _ = env.step(a)
            s2 = env.agent_pos
            # Only reaching the goal is a true terminal state. A step-limit
            # truncation is not visible in the position-only state, so the
            # successor's value must still be bootstrapped.
            agent.update(s, a, r, s2, term)
            s = s2
            ep_score += r
            ep_steps += 1
            if delay > 0.0:
                yield frame()
                time.sleep(delay)
            if term or trunc:
                break
        agent.end_episode(ep_score, ep_steps, solved=term)
        # Always yield at end-of-episode so the UI refreshes even when delay == 0.
        yield frame()
def greedy_stream(
    env: WarehouseEnv | None,
    agent: QLearningAgent | None,
    speed_ms: float,
):
    """Replay one episode on the current maze with the agent acting greedily.

    Yields ``(env, grid_html, scoreboard_md, steps)``: one frame at the start
    position, then one after every move, pausing at least 50 ms after each.
    Yields nothing when there is no env, no generated maze, or no agent.
    """
    ready = env is not None and env.grid is not None and agent is not None
    if not ready:
        return
    pause = max(0.05, float(speed_ms) / 1000.0)
    tie_rng = np.random.default_rng()

    def snapshot():
        return (
            env,
            render_grid_html(env, policy=agent.policy()),
            render_scoreboard_md(env),
            env.steps,
        )

    env.soft_reset()
    yield snapshot()
    time.sleep(pause)
    finished = False
    while not finished:
        move = agent.greedy_action(env.agent_pos, tie_rng)
        _, _, terminated, truncated, _ = env.step(move)
        yield snapshot()
        time.sleep(pause)
        finished = terminated or truncated
# ---------- Gradio app ----------
# Client-side script run once on page load via ``demo.load(js=KEYBOARD_JS)``.
# It turns the arrow keys into clicks on the manual-play buttons, located by
# their elem_id ("wh-btn-up", "wh-btn-right", "wh-btn-down", "wh-btn-left").
# Key presses while focus is in an INPUT/TEXTAREA/SELECT are ignored so form
# fields keep working. The ``window.__wh_kb_bound`` flag prevents binding the
# listener twice. If the elem_id element contains an inner <button>, that button
# is clicked; otherwise the element itself is clicked.
KEYBOARD_JS = """
() => {
if (window.__wh_kb_bound) return;
window.__wh_kb_bound = true;
document.addEventListener('keydown', (e) => {
const tag = (e.target && e.target.tagName) || '';
if (tag === 'INPUT' || tag === 'TEXTAREA' || tag === 'SELECT') return;
const map = {
'ArrowUp': 'wh-btn-up',
'ArrowRight': 'wh-btn-right',
'ArrowDown': 'wh-btn-down',
'ArrowLeft': 'wh-btn-left',
};
const id = map[e.key];
if (!id) return;
e.preventDefault();
const wrapper = document.getElementById(id);
if (!wrapper) return;
const btn = wrapper.querySelector('button') || wrapper;
btn.click();
});
}
"""
def build_app() -> gr.Blocks:
    """Assemble the Gradio UI: grid view, manual controls, RL training and stats.

    Per-session state lives in two ``gr.State`` objects: the ``WarehouseEnv``
    (initial page seeded with 42) and the trained ``QLearningAgent`` (``None``
    until training starts).

    The two streaming jobs (training and greedy rollout) are cancelled by
    listeners attached directly to the triggering components, so cancellation
    happens on click rather than after the triggering handler has finished.
    """
    initial_env = WarehouseEnv()
    initial_env.reset(seed=42)
    with gr.Blocks(title="Warehouse GridWorld") as demo:
        gr.Markdown(
            "# 📦 Warehouse GridWorld — Play & Train\n"
            "Use the **arrow keys** (or buttons) to move the red agent from **S** to **G**. "
            "Or click **Train Agent** and watch a tabular Q-learning agent learn the maze in real time."
        )
        env_state = gr.State(initial_env)
        agent_state: gr.State = gr.State(None)
        with gr.Row():
            # ---------- LEFT: grid + manual controls ----------
            with gr.Column(scale=3):
                grid_html = gr.HTML(render_grid_html(initial_env))
                gr.Markdown("**Manual play**")
                with gr.Row():
                    up_btn = gr.Button("↑ Up", elem_id="wh-btn-up")
                with gr.Row():
                    left_btn = gr.Button("← Left", elem_id="wh-btn-left")
                    down_btn = gr.Button("↓ Down", elem_id="wh-btn-down")
                    right_btn = gr.Button("→ Right", elem_id="wh-btn-right")
                gr.Markdown("**RL training**")
                with gr.Row():
                    n_episodes_slider = gr.Slider(
                        minimum=10,
                        maximum=1000,
                        value=200,
                        step=10,
                        label="Training episodes",
                    )
                    speed_slider = gr.Slider(
                        minimum=0,
                        maximum=300,
                        value=20,
                        step=5,
                        label="Speed (ms / step, 0 = fast)",
                    )
                with gr.Row():
                    train_btn = gr.Button("🤖 Train Agent", variant="primary")
                    stop_btn = gr.Button("⏹ Stop")
                    greedy_btn = gr.Button("▶ Run Greedy")
            # ---------- RIGHT: settings + stats ----------
            with gr.Column(scale=2):
                grid_size_slider = gr.Slider(
                    minimum=3,
                    maximum=25,
                    value=DEFAULT_GRID_SIZE,
                    step=1,
                    label="Grid Size (resets on change)",
                )
                steps_progress = gr.Slider(
                    minimum=0,
                    maximum=MAX_STEPS,
                    value=0,
                    step=1,
                    label=f"Steps (0 / {MAX_STEPS})",
                    interactive=False,
                )
                reset_btn = gr.Button(
                    "🔁 Reset / Randomize Grid", variant="primary"
                )
                training_stats = gr.Markdown(render_training_stats(None, 0))
                scoreboard = gr.Markdown(render_scoreboard_md(initial_env))

        # ---------- handlers ----------
        def do_step(state: WarehouseEnv, agent: QLearningAgent | None, action: int):
            """Apply one manual move; overlay the agent's policy arrows if trained."""
            state.step(action)
            policy = agent.policy() if agent else None
            return (
                state,
                render_grid_html(state, policy=policy),
                render_scoreboard_md(state),
                state.steps,
            )

        def do_reset(state: WarehouseEnv, _agent, new_size: float):
            """Randomize a new maze (new env if the size changed) and drop the agent."""
            new_size = int(new_size)
            if state is None or new_size != state.grid_size:
                state = WarehouseEnv(grid_size=new_size)
            state.reset()
            return (
                state,
                None,  # clear trained agent — new maze invalidates it
                render_grid_html(state),
                render_scoreboard_md(state),
                render_training_stats(None, 0),
                state.steps,
            )

        step_outputs = [env_state, grid_html, scoreboard, steps_progress]
        reset_outputs = [
            env_state,
            agent_state,
            grid_html,
            scoreboard,
            training_stats,
            steps_progress,
        ]
        train_outputs = reset_outputs  # same shape
        greedy_outputs = step_outputs

        # Manual play
        up_btn.click(
            lambda s, a: do_step(s, a, UP),
            inputs=[env_state, agent_state], outputs=step_outputs,
        )
        right_btn.click(
            lambda s, a: do_step(s, a, RIGHT),
            inputs=[env_state, agent_state], outputs=step_outputs,
        )
        down_btn.click(
            lambda s, a: do_step(s, a, DOWN),
            inputs=[env_state, agent_state], outputs=step_outputs,
        )
        left_btn.click(
            lambda s, a: do_step(s, a, LEFT),
            inputs=[env_state, agent_state], outputs=step_outputs,
        )

        # Reset / grid resize
        reset_btn.click(
            do_reset,
            inputs=[env_state, agent_state, grid_size_slider],
            outputs=reset_outputs,
        )
        grid_size_slider.release(
            do_reset,
            inputs=[env_state, agent_state, grid_size_slider],
            outputs=reset_outputs,
        )

        # RL training (streaming generator)
        train_evt = train_btn.click(
            train_stream,
            inputs=[env_state, n_episodes_slider, speed_slider],
            outputs=train_outputs,
        )
        # Greedy rollout of the trained policy
        greedy_evt = greedy_btn.click(
            greedy_stream,
            inputs=[env_state, agent_state, speed_slider],
            outputs=greedy_outputs,
        )

        # ---------- cancellation ----------
        # Cancel listeners hang off the triggering components themselves instead
        # of ``.then`` chains: a ``.then`` only runs after its parent event
        # finishes. Under that scheme, starting training left a running greedy
        # rollout going until training completed, and Stop could not interrupt
        # a greedy rollout at all.
        streams = [train_evt, greedy_evt]
        stop_btn.click(fn=None, inputs=None, outputs=None, cancels=streams)
        for button in (reset_btn, up_btn, right_btn, down_btn, left_btn):
            button.click(fn=None, inputs=None, outputs=None, cancels=streams)
        grid_size_slider.release(fn=None, inputs=None, outputs=None, cancels=streams)
        # Starting one stream stops the other so they never step the same env at once.
        train_btn.click(fn=None, inputs=None, outputs=None, cancels=[greedy_evt])
        greedy_btn.click(fn=None, inputs=None, outputs=None, cancels=[train_evt])

        demo.load(fn=None, inputs=None, outputs=None, js=KEYBOARD_JS)
    return demo
if __name__ == "__main__":
    # Script entry point: build the Blocks UI and serve it with Gradio's
    # default launch() settings.
    app = build_app()
    app.launch()