Spaces:
Sleeping
Sleeping
| """Warehouse GridWorld - Gradio + Gymnasium navigation game with live RL training. | |
| Run: | |
| pip install -r requirements.txt | |
| python app.py | |
| """ | |
| from __future__ import annotations | |
| import time | |
| from collections import deque | |
| import gradio as gr | |
| import gymnasium as gym | |
| import numpy as np | |
| from gymnasium import spaces | |
| # ---------- Constants ---------- | |
# Default edge length of the square grid, the per-episode step budget, and the
# probability with which each cell is generated as an obstacle.
DEFAULT_GRID_SIZE = 9
MAX_STEPS = 100
OBSTACLE_DENSITY = 0.20

# Integer action ids.
UP, RIGHT, DOWN, LEFT = range(4)

# Display name and arrow glyph for each action id.
ACTION_NAMES = {UP: "UP", RIGHT: "RIGHT", DOWN: "DOWN", LEFT: "LEFT"}
ACTION_ARROWS = {UP: "↑", RIGHT: "→", DOWN: "↓", LEFT: "←"}

# (row delta, column delta) added to the agent position for each action id.
ACTION_DELTAS = {UP: (-1, 0), RIGHT: (0, 1), DOWN: (1, 0), LEFT: (0, -1)}
| # ---------- Environment ---------- | |
class WarehouseEnv(gym.Env):
    """Gymnasium env for a randomized square warehouse grid.

    Positions are ``(row, col)`` tuples; ``grid[row, col] == 1`` marks an
    obstacle and ``0`` a free cell.

    Observation: ``[agent_row, agent_col, goal_row, goal_col]``, each divided
    by ``max(grid_size - 1, 1)`` so every component lies in ``[0, 1]``.
    Action: 0=UP, 1=RIGHT, 2=DOWN, 3=LEFT (deltas come from ``ACTION_DELTAS``).
    """

    metadata = {"render_modes": ["html"]}

    def __init__(self, grid_size: int = DEFAULT_GRID_SIZE, max_steps: int = MAX_STEPS):
        """Create an env with no maze yet; call ``reset()`` to generate one.

        Args:
            grid_size: Edge length of the square grid (coerced to ``int``).
            max_steps: Steps allowed before the episode is truncated
                (coerced to ``int``).
        """
        super().__init__()
        self.grid_size = int(grid_size)
        self.max_steps = int(max_steps)
        self.action_space = spaces.Discrete(4)
        self.observation_space = spaces.Box(
            low=0.0, high=1.0, shape=(4,), dtype=np.float32
        )
        # ``grid`` stays None until the first ``reset()``.
        self.grid: np.ndarray | None = None
        self.agent_pos: tuple[int, int] = (0, 0)
        self.start_pos: tuple[int, int] = (0, 0)
        self.goal_pos: tuple[int, int] = (0, 0)
        # Per-episode bookkeeping, also read by the UI renderers.
        self.steps = 0
        self.total_score = 0.0
        self.last_reward = 0.0
        self.last_action: int | None = None
        self.last_rule = "New episode started. Agent begins on S."
        self.visited: set[tuple[int, int]] = set()
        self.terminated = False
        self.truncated = False

    # --- generation ---
    def _is_solvable(self, grid: np.ndarray, start: tuple[int, int], goal: tuple[int, int]) -> bool:
        """Return True if ``goal`` is reachable from ``start`` through free cells.

        Breadth-first search over the 4-connected neighbours of each cell.
        Returns False immediately when either endpoint is an obstacle.
        """
        n = self.grid_size
        if grid[start] == 1 or grid[goal] == 1:
            return False
        seen = {start}
        q = deque([start])
        while q:
            r, c = q.popleft()
            if (r, c) == goal:
                return True
            for dr, dc in ((-1, 0), (1, 0), (0, -1), (0, 1)):
                nr, nc = r + dr, c + dc
                if 0 <= nr < n and 0 <= nc < n and grid[nr, nc] == 0 and (nr, nc) not in seen:
                    seen.add((nr, nc))
                    q.append((nr, nc))
        return False

    def _generate_grid(self):
        """Sample a random solvable maze.

        Makes up to 300 attempts: draw distinct start/goal cells, mark each
        cell as an obstacle with probability ``OBSTACLE_DENSITY``, clear the
        two endpoints and keep the layout if it is solvable.

        Returns:
            ``(grid, start, goal)``. If every attempt fails, an obstacle-free
            grid with start ``(0, 0)`` and goal ``(n - 1, n - 1)``.
        """
        n = self.grid_size
        rng = self.np_random
        for _ in range(300):
            start = (int(rng.integers(0, n)), int(rng.integers(0, n)))
            goal = (int(rng.integers(0, n)), int(rng.integers(0, n)))
            if start == goal:
                continue
            grid = (rng.random((n, n)) < OBSTACLE_DENSITY).astype(np.int8)
            grid[start] = 0
            grid[goal] = 0
            if self._is_solvable(grid, start, goal):
                return grid, start, goal
        return (
            np.zeros((n, n), dtype=np.int8),
            (0, 0),
            (n - 1, n - 1),
        )

    # --- helpers ---
    def _get_obs(self) -> np.ndarray:
        """Return the normalized ``float32`` observation vector of length 4."""
        # max(..., 1) avoids a division by zero on a 1x1 grid.
        denom = max(self.grid_size - 1, 1)
        ax, ay = self.agent_pos
        gx, gy = self.goal_pos
        return np.array(
            [ax / denom, ay / denom, gx / denom, gy / denom], dtype=np.float32
        )

    # Must be a staticmethod: it is called both as ``self._manhattan(a, b)``
    # inside ``step`` and as ``WarehouseEnv._manhattan(a, b)`` from outside the
    # class. Without the decorator the bound call would pass ``self`` as an
    # extra positional argument and raise TypeError.
    @staticmethod
    def _manhattan(a: tuple[int, int], b: tuple[int, int]) -> int:
        """Return the Manhattan (L1) distance between two grid positions."""
        return abs(a[0] - b[0]) + abs(a[1] - b[1])

    # --- gym API ---
    def reset(self, seed: int | None = None, options: dict | None = None):
        """Generate a new maze and start a fresh episode.

        Args:
            seed: Forwarded to ``gym.Env.reset`` to seed ``self.np_random``.
            options: Accepted for API compatibility; not used here.

        Returns:
            ``(observation, info)`` where ``info`` is an empty dict.
        """
        super().reset(seed=seed)
        self.grid, self.start_pos, self.goal_pos = self._generate_grid()
        self.agent_pos = self.start_pos
        self._reset_episode_counters("New episode started. Agent begins on S.")
        return self._get_obs(), {}

    def soft_reset(self) -> np.ndarray:
        """Reset agent to start, keep the same maze. Used between RL episodes."""
        self.agent_pos = self.start_pos
        self._reset_episode_counters("New episode (same maze). Agent at S.")
        return self._get_obs()

    def _reset_episode_counters(self, rule: str) -> None:
        """Zero the per-episode bookkeeping and set ``last_rule`` to ``rule``."""
        self.steps = 0
        self.total_score = 0.0
        self.last_reward = 0.0
        self.last_action = None
        self.last_rule = rule
        # The start cell counts as visited, so returning to it earns no bonus.
        self.visited = {self.start_pos}
        self.terminated = False
        self.truncated = False

    def step(self, action: int):
        """Apply one move and return ``(obs, reward, terminated, truncated, info)``.

        Reward shaping for a single step (components are summed):
            * move out of bounds or into an obstacle: -5.0, agent stays put;
            * valid move closer to / farther from / equidistant to the goal
              (Manhattan distance): +1.0 / -0.5 / -0.1;
            * first visit to a cell: +0.3;
            * reaching the goal: +50.0 and ``terminated`` becomes True;
            * step count reaching ``max_steps`` without reaching the goal:
              -10.0 and ``truncated`` becomes True.

        Once the episode is terminated or truncated, further calls change
        nothing and return a reward of 0.0.

        Raises:
            KeyError: if ``int(action)`` is not a key of ``ACTION_DELTAS``.
        """
        if self.terminated or self.truncated:
            return self._get_obs(), 0.0, self.terminated, self.truncated, {}
        action = int(action)
        self.steps += 1
        self.last_action = action
        dr, dc = ACTION_DELTAS[action]
        nr, nc = self.agent_pos[0] + dr, self.agent_pos[1] + dc
        n = self.grid_size
        old_dist = self._manhattan(self.agent_pos, self.goal_pos)
        reward = 0.0
        rule_parts: list[str] = []
        out_of_bounds = not (0 <= nr < n and 0 <= nc < n)
        # Only index the grid when the target cell is inside it.
        is_obstacle = (not out_of_bounds) and self.grid[nr, nc] == 1
        if out_of_bounds or is_obstacle:
            reward += -5.0
            rule_parts.append(
                "Invalid move: " + ("out of bounds" if out_of_bounds else "obstacle")
                + " (-5.0)"
            )
        else:
            self.agent_pos = (nr, nc)
            new_dist = self._manhattan(self.agent_pos, self.goal_pos)
            if new_dist < old_dist:
                reward += 1.0
                rule_parts.append("Closer to goal (+1.0)")
            elif new_dist > old_dist:
                reward += -0.5
                rule_parts.append("Farther from goal (-0.5)")
            else:
                reward += -0.1
                rule_parts.append("Same Manhattan distance (-0.1)")
            if self.agent_pos not in self.visited:
                reward += 0.3
                rule_parts.append("New cell (+0.3)")
                self.visited.add(self.agent_pos)
            if self.agent_pos == self.goal_pos:
                reward += 50.0
                rule_parts.append("GOAL reached (+50.0)")
                self.terminated = True
        # The timeout applies to invalid moves too, but never on the goal step.
        if not self.terminated and self.steps >= self.max_steps:
            reward += -10.0
            rule_parts.append("Step limit timeout (-10.0)")
            self.truncated = True
        self.last_reward = reward
        self.total_score += reward
        self.last_rule = "; ".join(rule_parts) + "."
        return self._get_obs(), reward, self.terminated, self.truncated, {}
| # ---------- Q-learning Agent ---------- | |
class QLearningAgent:
    """Tabular Q-learner whose table is keyed by the agent's grid position.

    Rows of the table are created lazily (all zeros) the first time a state
    is looked up. Exploration is epsilon-greedy, with epsilon multiplied by
    ``eps_decay`` after every episode and floored at ``eps_end``.
    """

    def __init__(
        self,
        n_actions: int = 4,
        alpha: float = 0.2,
        gamma: float = 0.95,
        eps_start: float = 1.0,
        eps_end: float = 0.05,
        eps_decay: float = 0.97,
    ):
        # Learning hyper-parameters.
        self.n_actions = n_actions
        self.alpha = alpha
        self.gamma = gamma
        # Exploration schedule.
        self.eps = eps_start
        self.eps_end = eps_end
        self.eps_decay = eps_decay
        # state -> vector of action values.
        self.q: dict[tuple[int, int], np.ndarray] = {}
        # Running statistics shown in the training panel.
        self.episode = 0
        self.solves = 0
        self.last_episode_steps = 0
        self.last_episode_score = 0.0
        self.last_episode_solved = False
        self.best_score = -float("inf")
        self.best_episode = -1

    def get_q(self, s: tuple[int, int]) -> np.ndarray:
        """Return the action-value row for ``s``, creating a zero row if absent."""
        try:
            return self.q[s]
        except KeyError:
            row = self.q[s] = np.zeros(self.n_actions, dtype=np.float32)
            return row

    def select_action(self, s: tuple[int, int], rng: np.random.Generator) -> int:
        """Epsilon-greedy choice: uniform random with probability ``eps``."""
        exploit = not (rng.random() < self.eps)
        if exploit:
            return self.greedy_action(s, rng)
        return int(rng.integers(0, self.n_actions))

    def greedy_action(self, s: tuple[int, int], rng: np.random.Generator) -> int:
        """Return a highest-valued action for ``s``; ties are broken via ``rng``."""
        values = self.get_q(s)
        candidates = np.flatnonzero(values == values.max())
        return int(rng.choice(candidates))

    def update(
        self,
        s: tuple[int, int],
        a: int,
        r: float,
        s2: tuple[int, int],
        done: bool,
    ) -> None:
        """Apply one Q-learning update for the transition ``(s, a, r, s2)``.

        When ``done`` is true the target is just ``r``; otherwise it is
        ``r + gamma * max(Q[s2])``.
        """
        row = self.get_q(s)
        if done:
            target = r
        else:
            target = r + self.gamma * float(self.get_q(s2).max())
        row[a] += self.alpha * (target - row[a])

    def end_episode(self, score: float, steps: int, solved: bool) -> None:
        """Record a finished episode's statistics and decay epsilon."""
        self.episode += 1
        self.last_episode_solved = solved
        self.last_episode_steps = steps
        self.last_episode_score = score
        if solved:
            self.solves += 1
        if score > self.best_score:
            self.best_score, self.best_episode = score, self.episode
        decayed = self.eps * self.eps_decay
        self.eps = max(self.eps_end, decayed)

    def policy(self) -> dict[tuple[int, int], int]:
        """Map every state in the table to its argmax action."""
        return {state: int(values.argmax()) for state, values in self.q.items()}
| # ---------- Rendering ---------- | |
def render_grid_html(env: WarehouseEnv, policy: dict | None = None) -> str:
    """Render the maze as an HTML fragment (inline stylesheet + CSS grid).

    Each cell is classified in this priority order: obstacle ("X"), start
    ("S"), goal ("G"), policy arrow (only when ``policy`` is non-empty and
    contains the cell), otherwise empty ("."). The cell holding the agent
    keeps its class but shows a dot instead of its label.

    Args:
        env: Environment providing ``grid_size``, ``grid``, ``start_pos``,
            ``goal_pos`` and ``agent_pos``.
        policy: Optional mapping from ``(row, col)`` to an action id that is
            looked up in ``ACTION_ARROWS``.
    """
    n = env.grid_size
    # Cell edge in pixels: 520 // n, clamped to the range [26, 56].
    cell_px = max(26, min(56, 520 // n))
    dot_px = int(cell_px * 0.6)
    font_px = int(cell_px * 0.42)

    stylesheet = f"""
    <style>
    .wh-wrap {{ display: inline-block; }}
    .wh-grid {{
        display: grid;
        grid-template-columns: repeat({n}, {cell_px}px);
        grid-template-rows: repeat({n}, {cell_px}px);
        gap: 1px;
        background: #333;
        padding: 1px;
        border: 2px solid #222;
        width: fit-content;
    }}
    .wh-cell {{
        width: {cell_px}px;
        height: {cell_px}px;
        display: flex;
        align-items: center;
        justify-content: center;
        font-family: ui-monospace, SFMono-Regular, Menlo, monospace;
        font-weight: 700;
        font-size: {font_px}px;
    }}
    .wh-empty {{ background: #f3f3f3; color: #cfcfcf; }}
    .wh-arrow {{ background: #f3f3f3; color: #6b8fbf; }}
    .wh-obstacle {{ background: #2b3a55; color: #2b3a55; }}
    .wh-start {{ background: #79b6ff; color: #003f8a; }}
    .wh-goal {{ background: #6ee08a; color: #0a5022; }}
    .wh-dot {{
        width: {dot_px}px;
        height: {dot_px}px;
        border-radius: 50%;
        background: #e63946;
        border: 2px solid #7a1018;
        box-shadow: 0 0 4px rgba(0,0,0,0.35);
    }}
    </style>
    """

    def classify(pos):
        """Return the (css class, text label) pair for one cell."""
        if env.grid[pos] == 1:
            return "wh-obstacle", "X"
        if pos == env.start_pos:
            return "wh-start", "S"
        if pos == env.goal_pos:
            return "wh-goal", "G"
        if policy and pos in policy:
            return "wh-arrow", ACTION_ARROWS[policy[pos]]
        return "wh-empty", "."

    agent_marker = '<div class="wh-dot"></div>'
    pieces: list[str] = []
    # Row-major order matches the CSS grid's fill order.
    for pos in ((row, col) for row in range(n) for col in range(n)):
        css_class, text = classify(pos)
        content = agent_marker if pos == env.agent_pos else text
        pieces.append(f'<div class="wh-cell {css_class}">{content}</div>')

    board = "".join(pieces)
    return stylesheet + f'<div class="wh-wrap"><div class="wh-grid">{board}</div></div>'
def render_scoreboard_md(env: WarehouseEnv) -> str:
    """Build the Markdown score board (heading + two-column table) for ``env``.

    The status row reports a reached goal first, then a timeout, otherwise
    "Playing". The remaining rows echo the env's score, step counter,
    positions, Manhattan distance to the goal, last action and last rule text.
    """
    status = "🎮 Playing"
    if env.terminated:
        status = "🏁 Goal reached!"
    elif env.truncated:
        status = "⏱️ Timed out"

    if env.last_action is None:
        last_action = "None"
    else:
        last_action = ACTION_NAMES[env.last_action]

    dist = WarehouseEnv._manhattan(env.agent_pos, env.goal_pos)
    agent_row, agent_col = env.agent_pos[0], env.agent_pos[1]
    goal_row, goal_col = env.goal_pos[0], env.goal_pos[1]

    rows = [
        "### Score Board",
        "| Field | Value |",
        "|---|---|",
        f"| **Total Score** | `{env.total_score:+.2f}` |",
        f"| **Last Reward** | `{env.last_reward:+.2f}` |",
        f"| **Steps** | `{env.steps} / {env.max_steps}` |",
        f"| **Agent Position** | `({agent_row}, {agent_col})` |",
        f"| **Goal Position** | `({goal_row}, {goal_col})` |",
        f"| **Manhattan Distance** | `{dist}` |",
        f"| **Status** | {status} |",
        f"| **Last Action** | `{last_action}` |",
        f"| **Rule Fired** | {env.last_rule} |",
    ]
    # Trailing newline after the last table row.
    return "\n".join(rows) + "\n"
def render_training_stats(agent: QLearningAgent | None, target: int) -> str:
    """Build the Markdown panel describing training progress.

    Args:
        agent: The learner to summarize, or ``None`` when nothing has been
            trained yet (a short hint is returned instead of a table).
        target: Planned number of episodes, shown next to the episode count.

    Fields with no data yet (no finished episode, or no best episode
    recorded) are rendered as an em dash.
    """
    if agent is None:
        return (
            "### 🤖 RL Agent\n\n"
            "*No agent trained yet. Click **Train Agent** to start "
            "tabular Q-learning on the current maze.*"
        )

    placeholder = "—"
    if agent.episode:
        last_score = f"{agent.last_episode_score:+.2f}"
        last_steps = agent.last_episode_steps
        solved_pct = f"{100.0 * agent.solves / max(agent.episode, 1):.1f}%"
    else:
        last_score = last_steps = solved_pct = placeholder

    if agent.best_episode > 0:
        best = f"{agent.best_score:+.2f} (ep {agent.best_episode})"
    else:
        best = placeholder

    rows = [
        "### 🤖 RL Agent (Q-learning)",
        "| Field | Value |",
        "|---|---|",
        f"| **Episode** | `{agent.episode} / {target}` |",
        f"| **ε exploration** | `{agent.eps:.3f}` |",
        f"| **States seen** | `{len(agent.q)}` |",
        f"| **Solve rate** | `{solved_pct}` |",
        f"| **Last episode** | score `{last_score}`, steps `{last_steps}` |",
        f"| **Best episode** | `{best}` |",
    ]
    # Trailing newline after the last table row.
    return "\n".join(rows) + "\n"
| # ---------- Training generators ---------- | |
def train_stream(
    env: WarehouseEnv | None,
    n_episodes: float,
    speed_ms: float,
):
    """Train a fresh Q-learning agent on the current maze, streaming UI frames.

    Generator yielding ``(env, agent, grid_html, scoreboard_md,
    training_stats_md, env.steps)``. One frame is emitted before training
    starts and one after every episode. When ``speed_ms`` is positive a frame
    is also emitted after every step, followed by a sleep of that duration;
    with 0 ms only the episode-boundary frames are produced (fast-forward).

    Args:
        env: Environment to train on. If it is ``None`` or has no grid yet, a
            new default ``WarehouseEnv`` is created and reset.
        n_episodes: Number of episodes to run (truncated to ``int``).
        speed_ms: Per-step pacing in milliseconds; negative values act as 0.
    """
    if env is None or env.grid is None:
        env = WarehouseEnv()
        env.reset()

    target = int(n_episodes)
    delay = max(0.0, float(speed_ms) / 1000.0)
    paced = delay > 0.0
    rng = np.random.default_rng()
    agent = QLearningAgent()

    def frame():
        """Snapshot the current env/agent state as one UI update tuple."""
        return (
            env,
            agent,
            render_grid_html(env, policy=agent.policy()),
            render_scoreboard_md(env),
            render_training_stats(agent, target),
            env.steps,
        )

    # Initial frame: agent back on the start cell, nothing learned yet.
    env.soft_reset()
    yield frame()

    for _ in range(target):
        env.soft_reset()
        state = env.agent_pos
        ep_score = 0.0
        ep_steps = 0
        finished = False
        while not finished:
            action = agent.select_action(state, rng)
            _, reward, term, trunc, _ = env.step(action)
            next_state = env.agent_pos
            finished = term or trunc
            agent.update(state, action, reward, next_state, finished)
            state = next_state
            ep_score += reward
            ep_steps += 1
            if paced:
                yield frame()
                time.sleep(delay)
        # An episode counts as solved only if it ended by reaching the goal.
        agent.end_episode(ep_score, ep_steps, term)
        # Always emit a frame per episode so the UI refreshes even when unpaced.
        yield frame()
def greedy_stream(
    env: WarehouseEnv | None,
    agent: QLearningAgent | None,
    speed_ms: float,
):
    """Play one episode with the agent's greedy policy, streaming UI frames.

    Generator yielding ``(env, grid_html, scoreboard_md, env.steps)``: one
    frame after resetting the agent to the start cell and one after every
    step until the episode terminates or is truncated. Produces nothing when
    the env, its grid, or the agent is missing. Each frame is followed by a
    sleep of ``speed_ms`` milliseconds, with a minimum of 50 ms.
    """
    ready = env is not None and env.grid is not None and agent is not None
    if not ready:
        return

    delay = max(0.05, float(speed_ms) / 1000.0)
    rng = np.random.default_rng()

    def frame():
        """Snapshot the current env state as one UI update tuple."""
        return (
            env,
            render_grid_html(env, policy=agent.policy()),
            render_scoreboard_md(env),
            env.steps,
        )

    env.soft_reset()
    yield frame()
    time.sleep(delay)

    finished = False
    while not finished:
        action = agent.greedy_action(env.agent_pos, rng)
        _, _, term, trunc, _ = env.step(action)
        yield frame()
        time.sleep(delay)
        finished = term or trunc
| # ---------- Gradio app ---------- | |
# JavaScript snippet passed to ``demo.load(js=...)``. It registers a single
# document-level ``keydown`` listener (guarded by ``window.__wh_kb_bound`` so
# it is bound only once) that maps the four arrow keys to the movement
# buttons via their element ids and clicks the matching button. Key presses
# originating from INPUT, TEXTAREA or SELECT elements are ignored.
KEYBOARD_JS = """
() => {
    if (window.__wh_kb_bound) return;
    window.__wh_kb_bound = true;
    document.addEventListener('keydown', (e) => {
        const tag = (e.target && e.target.tagName) || '';
        if (tag === 'INPUT' || tag === 'TEXTAREA' || tag === 'SELECT') return;
        const map = {
            'ArrowUp': 'wh-btn-up',
            'ArrowRight': 'wh-btn-right',
            'ArrowDown': 'wh-btn-down',
            'ArrowLeft': 'wh-btn-left',
        };
        const id = map[e.key];
        if (!id) return;
        e.preventDefault();
        const wrapper = document.getElementById(id);
        if (!wrapper) return;
        const btn = wrapper.querySelector('button') || wrapper;
        btn.click();
    });
}
"""
def build_app() -> gr.Blocks:
    """Assemble the Gradio UI and wire every event handler.

    Creates an initial environment (reset with seed 42), lays out the grid,
    manual-move buttons, training controls and stats panels, then connects:
    manual moves, reset/resize, streaming training, streaming greedy rollout,
    and the cancellation rules between them.

    Returns:
        The constructed ``gr.Blocks`` app (not yet launched).
    """
    initial_env = WarehouseEnv()
    initial_env.reset(seed=42)
    with gr.Blocks(title="Warehouse GridWorld") as demo:
        gr.Markdown(
            "# 📦 Warehouse GridWorld — Play & Train\n"
            "Use the **arrow keys** (or buttons) to move the red agent from **S** to **G**. "
            "Or click **Train Agent** and watch a tabular Q-learning agent learn the maze in real time."
        )
        # State holders: the environment and the (optional) trained agent.
        env_state = gr.State(initial_env)
        agent_state: gr.State = gr.State(None)
        with gr.Row():
            # ---------- LEFT: grid + manual controls ----------
            with gr.Column(scale=3):
                grid_html = gr.HTML(render_grid_html(initial_env))
                gr.Markdown("**Manual play**")
                # The elem_id values are the ids KEYBOARD_JS looks up.
                with gr.Row():
                    up_btn = gr.Button("↑ Up", elem_id="wh-btn-up")
                with gr.Row():
                    left_btn = gr.Button("← Left", elem_id="wh-btn-left")
                    down_btn = gr.Button("↓ Down", elem_id="wh-btn-down")
                    right_btn = gr.Button("→ Right", elem_id="wh-btn-right")
                gr.Markdown("**RL training**")
                with gr.Row():
                    n_episodes_slider = gr.Slider(
                        minimum=10,
                        maximum=1000,
                        value=200,
                        step=10,
                        label="Training episodes",
                    )
                    speed_slider = gr.Slider(
                        minimum=0,
                        maximum=300,
                        value=20,
                        step=5,
                        label="Speed (ms / step, 0 = fast)",
                    )
                with gr.Row():
                    train_btn = gr.Button("🤖 Train Agent", variant="primary")
                    stop_btn = gr.Button("⏹ Stop")
                    greedy_btn = gr.Button("▶ Run Greedy")
            # ---------- RIGHT: settings + stats ----------
            with gr.Column(scale=2):
                grid_size_slider = gr.Slider(
                    minimum=3,
                    maximum=25,
                    value=DEFAULT_GRID_SIZE,
                    step=1,
                    label="Grid Size (resets on change)",
                )
                # Non-interactive slider used as a read-only step counter.
                steps_progress = gr.Slider(
                    minimum=0,
                    maximum=MAX_STEPS,
                    value=0,
                    step=1,
                    label=f"Steps (0 / {MAX_STEPS})",
                    interactive=False,
                )
                reset_btn = gr.Button(
                    "🔁 Reset / Randomize Grid", variant="primary"
                )
                training_stats = gr.Markdown(render_training_stats(None, 0))
                scoreboard = gr.Markdown(render_scoreboard_md(initial_env))

        # ---------- handlers ----------
        def do_step(state: WarehouseEnv, agent: QLearningAgent | None, action: int):
            """Apply one manual move and return the values for ``step_outputs``."""
            state.step(action)
            # Keep showing the learned policy arrows if an agent exists.
            policy = agent.policy() if agent else None
            return (
                state,
                render_grid_html(state, policy=policy),
                render_scoreboard_md(state),
                state.steps,
            )

        def do_reset(state: WarehouseEnv, _agent, new_size: float):
            """Generate a new maze and return the values for ``reset_outputs``.

            A new env is built only when there is none or the requested size
            differs from the current one; otherwise the existing env is
            re-randomized in place.
            """
            new_size = int(new_size)
            if state is None or new_size != state.grid_size:
                state = WarehouseEnv(grid_size=new_size)
            state.reset()
            return (
                state,
                None,  # clear trained agent — new maze invalidates it
                render_grid_html(state),
                render_scoreboard_md(state),
                render_training_stats(None, 0),
                state.steps,
            )

        # Output lists; their order must match the tuples returned/yielded by
        # do_step / greedy_stream (4 items) and do_reset / train_stream (6 items).
        step_outputs = [env_state, grid_html, scoreboard, steps_progress]
        reset_outputs = [
            env_state,
            agent_state,
            grid_html,
            scoreboard,
            training_stats,
            steps_progress,
        ]
        train_outputs = reset_outputs  # same shape
        greedy_outputs = step_outputs

        # Manual play
        up_evt = up_btn.click(
            lambda s, a: do_step(s, a, UP),
            inputs=[env_state, agent_state], outputs=step_outputs,
        )
        right_evt = right_btn.click(
            lambda s, a: do_step(s, a, RIGHT),
            inputs=[env_state, agent_state], outputs=step_outputs,
        )
        down_evt = down_btn.click(
            lambda s, a: do_step(s, a, DOWN),
            inputs=[env_state, agent_state], outputs=step_outputs,
        )
        left_evt = left_btn.click(
            lambda s, a: do_step(s, a, LEFT),
            inputs=[env_state, agent_state], outputs=step_outputs,
        )

        # Reset / grid resize
        reset_evt = reset_btn.click(
            do_reset,
            inputs=[env_state, agent_state, grid_size_slider],
            outputs=reset_outputs,
        )
        # ``release`` rather than ``change``: do_reset is attached to the
        # slider's release event only.
        resize_evt = grid_size_slider.release(
            do_reset,
            inputs=[env_state, agent_state, grid_size_slider],
            outputs=reset_outputs,
        )

        # RL training (streaming generator)
        train_evt = train_btn.click(
            train_stream,
            inputs=[env_state, n_episodes_slider, speed_slider],
            outputs=train_outputs,
        )

        # Cancel training on Stop, Reset, resize, or any manual move.
        stop_btn.click(fn=None, inputs=None, outputs=None, cancels=[train_evt])
        for evt in (reset_evt, resize_evt, up_evt, right_evt, down_evt, left_evt):
            evt.then(fn=None, inputs=None, outputs=None, cancels=[train_evt])

        # Greedy rollout of the trained policy
        greedy_evt = greedy_btn.click(
            greedy_stream,
            inputs=[env_state, agent_state, speed_slider],
            outputs=greedy_outputs,
        )
        # Same cancellation rule for the greedy rollout; starting a training
        # run cancels it as well. NOTE(review): the Stop button only cancels
        # training, not a greedy rollout — confirm this is intended.
        for evt in (reset_evt, resize_evt, up_evt, right_evt, down_evt, left_evt, train_evt):
            evt.then(fn=None, inputs=None, outputs=None, cancels=[greedy_evt])

        # Install the arrow-key bindings when the page loads.
        demo.load(fn=None, inputs=None, outputs=None, js=KEYBOARD_JS)
    return demo
# Script entry point: build the UI and start the Gradio server.
if __name__ == "__main__":
    app = build_app()
    app.launch()