Update app.py
app.py
CHANGED
@@ -1,10 +1,10 @@
 # app.py
-# Gridworld RL (Q-learning) with
-#
-#
-#
-#
-#
+# Gridworld RL (Q-learning) with:
+# ✅ Original visualization + layout (as much as possible)
+# ✅ Non-flickering learning curve (always visible)
+# ✅ Option 1: Obstacle density slider (auto-generate more/less blocks)
+# ✅ Train uses epsilon decay (converges); Play shows deterministic route (epsilon=0)
+# ✅ Same obstacle layout is reused for Play (stored in state)
 
 import time
 import numpy as np
@@ -13,6 +13,7 @@ import matplotlib.pyplot as plt
 from matplotlib.patches import Rectangle, FancyBboxPatch
 from io import BytesIO
 from PIL import Image
+from collections import deque
 
 # -----------------------------
 # Gridworld Environment
@@ -25,13 +26,71 @@ ACTION_DELTAS = {
     3: (0, -1),  # left
 }
 
+def _neighbors(r, c, n):
+    if r > 0: yield (r - 1, c)
+    if r < n - 1: yield (r + 1, c)
+    if c > 0: yield (r, c - 1)
+    if c < n - 1: yield (r, c + 1)
+
+def _has_path(size, start, goal, blocked):
+    """BFS to ensure there's at least one safe path from start to goal."""
+    q = deque([start])
+    seen = {start}
+    while q:
+        cur = q.popleft()
+        if cur == goal:
+            return True
+        r, c = cur
+        for nr, nc in _neighbors(r, c, size):
+            nxt = (nr, nc)
+            if nxt in seen or nxt in blocked:
+                continue
+            seen.add(nxt)
+            q.append(nxt)
+    return False
+
+def generate_obstacles(size, start, goal, density, wall_ratio=0.7, max_tries=60, rng=None):
+    """
+    Generate walls + lava with a given density, retrying until there is a safe path.
+    Lava is treated as blocked (terminal negative), so we keep at least one safe route.
+    """
+    rng = rng or np.random.default_rng()
+    density = float(np.clip(density, 0.0, 0.60))
+
+    # If density is too high, repeatedly try; if impossible, gradually reduce density
+    cur_density = density
+    for _ in range(max_tries):
+        walls = set()
+        lava = set()
+
+        for r in range(size):
+            for c in range(size):
+                cell = (r, c)
+                if cell == start or cell == goal:
+                    continue
+                if rng.random() < cur_density:
+                    if rng.random() < wall_ratio:
+                        walls.add(cell)
+                    else:
+                        lava.add(cell)
+
+        blocked = walls | lava
+        if _has_path(size, start, goal, blocked):
+            return walls, lava
+
+        # If no path, soften the environment a bit and try again
+        cur_density = max(0.0, cur_density - 0.02)
+
+    # Fallback: empty obstacles (always solvable)
+    return set(), set()
+
 class GridWorld:
     def __init__(self, size=5, start=(0, 0), goal=None, lava=None, walls=None):
         self.size = int(size)
         self.start = start
         self.goal = goal if goal is not None else (self.size - 1, self.size - 1)
-        self.lava = set(lava or [
-        self.walls = set(walls or [
+        self.lava = set(lava or [])
+        self.walls = set(walls or [])
         self.reset()
 
     def reset(self):
@@ -59,7 +118,7 @@ class GridWorld:
         if self.pos in self.lava:
             return self.pos, -10.0, True
 
-        return self.pos, -0.1, False  # step penalty -> shortest path is optimal
+        return self.pos, -0.1, False  # small step penalty -> shortest safe path is optimal
 
 # -----------------------------
 # Q-Learning Agent
@@ -91,7 +150,7 @@ class QAgent:
         self.Q[r1, c1, a] += self.alpha * td_error
 
 # -----------------------------
-# Rendering helpers (
+# Rendering helpers (ORIGINAL look)
 # -----------------------------
 def fig_to_pil(fig):
     buf = BytesIO()
@@ -108,8 +167,10 @@ def draw_grid(env: GridWorld, agent: QAgent = None, show_q=False, episode=None,
     ax.set_aspect("equal")
     ax.axis("off")
 
+    # Background
     ax.add_patch(Rectangle((0, 0), n, n, facecolor="#0b1020"))
 
+    # Draw cells
     for r in range(n):
         for c in range(n):
             x, y = c, n - 1 - r  # invert y so (0,0) is top-left visually
@@ -133,6 +194,7 @@ def draw_grid(env: GridWorld, agent: QAgent = None, show_q=False, episode=None,
                 )
             )
 
+            # overlay Q hint (optional)
             if show_q and agent is not None and (r, c) not in env.walls:
                 best_a = int(np.argmax(agent.Q[r, c]))
                 qv = float(np.max(agent.Q[r, c]))
@@ -141,6 +203,7 @@ def draw_grid(env: GridWorld, agent: QAgent = None, show_q=False, episode=None,
                 ax.text(x + 0.5, y + 0.30, f"{qv:+.2f}", ha="center", va="center",
                         fontsize=9, color="#a9b7e6", alpha=0.55)
 
+    # Icons
     def put_icon(rc, icon, color="#ffffff"):
         r, c = rc
         x, y = c + 0.5, (n - 1 - r) + 0.5
@@ -151,8 +214,11 @@ def draw_grid(env: GridWorld, agent: QAgent = None, show_q=False, episode=None,
         put_icon(rc, "🔥")
     for rc in env.walls:
         put_icon(rc, "🧱")
+
+    # Agent
     put_icon(env.pos, "🤖")
 
+    # Header overlay
     title = "Gridworld RL • Q-learning"
     sub = []
     if episode is not None:
@@ -171,7 +237,7 @@ def draw_grid(env: GridWorld, agent: QAgent = None, show_q=False, episode=None,
 # -----------------------------
 # Learning curve chart (no flicker)
 # -----------------------------
-def moving_average(x, window=20):
+def moving_average(x, window=25):
     if len(x) < 2:
         return np.array(x, dtype=float)
     w = max(2, min(int(window), len(x)))
@@ -181,8 +247,6 @@ def moving_average(x, window=20):
 def draw_learning_curve(returns, successes, window=25):
     fig, ax = plt.subplots(figsize=(5.4, 4.6))
     ax.set_facecolor("#0b1020")
-
-    # dark-friendly axes styling
     for spine in ax.spines.values():
         spine.set_color("#2a355f")
     ax.tick_params(colors="#c9d6ff")
@@ -201,7 +265,8 @@ def draw_learning_curve(returns, successes, window=25):
     ma = moving_average(returns, window=window)
     if len(ma) > 0:
         xs_ma = np.arange(len(returns) - len(ma) + 1, len(returns) + 1)
-        ax.plot(xs_ma, ma, linewidth=2.5, alpha=0.95,
+        ax.plot(xs_ma, ma, linewidth=2.5, alpha=0.95,
+                label=f"Moving avg ({min(int(window), len(returns))})")
 
     ax2 = ax.twinx()
     ax2.tick_params(colors="#c9d6ff")
@@ -221,15 +286,31 @@ def draw_learning_curve(returns, successes, window=25):
     return fig_to_pil(fig)
 
 # -----------------------------
-# Training + Playback
+# Training + Playback (store env layout so Play matches Train)
 # -----------------------------
-def
-
+def make_env_and_agent(grid_size, obstacle_density, alpha, gamma):
+    size = int(grid_size)
+    start = (0, 0)
+    goal = (size - 1, size - 1)
+
+    rng = np.random.default_rng()  # new layout each train run
+    walls, lava = generate_obstacles(size, start, goal, density=float(obstacle_density), wall_ratio=0.7, rng=rng)
+
+    env = GridWorld(size=size, start=start, goal=goal, walls=walls, lava=lava)
     agent = QAgent(size=size, alpha=alpha, gamma=gamma)
-
+
+    env_state = {
+        "size": size,
+        "start": start,
+        "goal": goal,
+        "walls": sorted(list(walls)),
+        "lava": sorted(list(lava)),
+    }
+    return env, agent, env_state
 
 def train_stream(
     grid_size,
+    obstacle_density,
     alpha,
     gamma,
     eps_start,
@@ -241,7 +322,7 @@ def train_stream(
     show_q_overlay,
     curve_window,
 ):
-    env, agent =
+    env, agent, env_state = make_env_and_agent(grid_size, obstacle_density, alpha, gamma)
     eps = float(eps_start)
 
     returns = []
@@ -250,10 +331,9 @@ def train_stream(
     # initial
     frame = draw_grid(env, agent, show_q=show_q_overlay, episode=0, step_i=0, total_reward=0.0)
     last_curve = draw_learning_curve(returns, successes, window=int(curve_window))
-    status = "Klaar om te trainen.
-    yield frame, last_curve, agent.Q, status
+    status = f"Ready to train. Obstacle density={float(obstacle_density):.2f}. (Curve does not flicker.)"
+    yield frame, last_curve, agent.Q, env_state, status
 
-    # only redraw chart every N steps (but ALWAYS output the last image)
     CURVE_UPDATE_EVERY_STEPS = 8
 
     for ep in range(1, int(episodes) + 1):
@@ -278,7 +358,7 @@ def train_stream(
 
         frame = draw_grid(env, agent, show_q=show_q_overlay, episode=ep, step_i=t, total_reward=total_r)
         status = f"Train • ep {ep}/{episodes} • step {t}/{max_steps} • return {total_r:+.2f} • eps {eps:.3f}"
-        yield frame, last_curve, agent.Q, status
+        yield frame, last_curve, agent.Q, env_state, status
 
         if speed > 0:
             time.sleep(float(speed))
@@ -289,33 +369,40 @@ def train_stream(
         returns.append(total_r)
         successes.append(reached_goal_this_ep)
         last_curve = draw_learning_curve(returns, successes, window=int(curve_window))
-        yield frame, last_curve, agent.Q, f"Episode {ep} done • return {total_r:+.2f} • success={reached_goal_this_ep} • eps {eps:.3f}"
+        yield frame, last_curve, agent.Q, env_state, f"Episode {ep} done • return {total_r:+.2f} • success={reached_goal_this_ep} • eps {eps:.3f}"
 
         eps = max(float(eps_end), eps * float(eps_decay))
 
     frame = draw_grid(env, agent, show_q=show_q_overlay, episode=episodes, step_i=None, total_reward=None)
     last_curve = draw_learning_curve(returns, successes, window=int(curve_window))
     status = "Training done ✅ Now click ‘Play learned policy’ to see the clean shortest safe route (epsilon=0)."
-    yield frame, last_curve, agent.Q, status
+    yield frame, last_curve, agent.Q, env_state, status
 
-def play_stream(q_table, grid_size, max_steps, speed, show_q_overlay):
-    if q_table is None:
-
-
+def play_stream(q_table, env_state, speed, show_q_overlay, max_steps):
+    if q_table is None or env_state is None:
+        # show something reasonable
+        env = GridWorld(size=5, start=(0, 0), goal=(4, 4), walls=[], lava=[])
+        agent = QAgent(size=5)
         frame = draw_grid(env, agent, show_q=show_q_overlay, episode=None, step_i=None, total_reward=None)
         curve = draw_learning_curve([], [], window=25)
-        yield frame, curve, "Nog geen
+        yield frame, curve, "No training done yet. Click Train first."
         return
 
-
-
+    size = int(env_state["size"])
+    start = tuple(env_state["start"])
+    goal = tuple(env_state["goal"])
+    walls = [tuple(x) for x in env_state["walls"]]
+    lava = [tuple(x) for x in env_state["lava"]]
+
+    env = GridWorld(size=size, start=start, goal=goal, walls=walls, lava=lava)
+    agent = QAgent(size=size)
    agent.Q = np.array(q_table, dtype=np.float32)

    s = env.reset()
    total_r = 0.0

-    frame = draw_grid(env, agent, show_q=show_q_overlay, episode="PLAY", step_i=0, total_reward=total_r)
    curve = draw_learning_curve([], [], window=25)  # keep curve visible (static) during play
+    frame = draw_grid(env, agent, show_q=show_q_overlay, episode="PLAY", step_i=0, total_reward=total_r)
    yield frame, curve, "Play • epsilon=0.0 (deterministic) • shows the learned route"

    for t in range(1, int(max_steps) + 1):
@@ -335,34 +422,41 @@ def play_stream(q_table, grid_size, max_steps, speed, show_q_overlay):
     if env.pos == env.goal:
         end = f"🏁 Goal reached! return {total_r:+.2f} (shorter path = less step penalty)."
     elif env.pos in env.lava:
-        end =
+        end = "🔥 Ended up in lava. Tip: train longer or lower the density."
     else:
-        end =
+        end = "Play done. Tip: train longer or raise max_steps."
 
     frame = draw_grid(env, agent, show_q=show_q_overlay, episode="PLAY", step_i=None, total_reward=total_r)
     yield frame, curve, end
 
 # -----------------------------
-# Gradio UI (original layout)
+# Gradio UI (original layout) + obstacle density slider
 # -----------------------------
 with gr.Blocks(theme=gr.themes.Soft(), title="RL Gridworld (Q-learning)") as demo:
     gr.Markdown(
         """
 # 🤖 Reinforcement Learning in a Gridworld (real-time animation)
 
+- **Obstacle density**: how many 🧱/🔥 appear in the grid (more = harder)
 - **Train**: the agent learns (epsilon decays: explore first, exploit later)
 - **Play learned policy**: shows what it has learned (**epsilon=0**)
 
-Rechts zie je een **learning curve** (return + moving average + success rate) die
+On the right you see a **learning curve** (return + moving average + success rate) that **does not flicker**.
         """
     )
 
     q_state = gr.State(None)
+    env_state = gr.State(None)
 
     with gr.Row():
         with gr.Column(scale=1):
             grid_size = gr.Slider(4, 10, value=5, step=1, label="Grid size")
 
+            obstacle_density = gr.Slider(
+                0.0, 0.45, value=0.15, step=0.05,
+                label="Obstacle density (more blocks/danger)"
+            )
+
             with gr.Accordion("RL parameters (defaults = good convergence)", open=True):
                 alpha = gr.Slider(0.01, 1.0, value=0.45, step=0.01, label="Alpha (learning rate)")
                 gamma = gr.Slider(0.0, 0.999, value=0.97, step=0.001, label="Gamma (discount)")
@@ -373,7 +467,7 @@ Rechts zie je een **learning curve** (return + moving average + success rate) di
                 eps_decay = gr.Slider(0.90, 0.999, value=0.985, step=0.001, label="Epsilon decay per episode")
 
                 episodes = gr.Slider(1, 400, value=200, step=1, label="Episodes")
-
+                max_steps_train = gr.Slider(5, 200, value=60, step=1, label="Max steps per episode")
 
             with gr.Accordion("Visuals & speed", open=True):
                 speed = gr.Slider(0.0, 0.3, value=0.02, step=0.01, label="Animation delay (sec/frame)")
@@ -384,7 +478,7 @@ Rechts zie je een **learning curve** (return + moving average + success rate) di
             train_btn = gr.Button("🚀 Train (epsilon decay)", variant="primary")
             play_btn = gr.Button("▶️ Play learned policy (epsilon=0)")
 
-            status = gr.Textbox(label="Status", value="
+            status = gr.Textbox(label="Status", value="Pick a density and click Train.", interactive=False)
 
         with gr.Column(scale=1):
             frame_out = gr.Image(label="Live animation", type="pil", height=520)
@@ -393,17 +487,19 @@ Rechts zie je een **learning curve** (return + moving average + success rate) di
     train_btn.click(
         fn=train_stream,
         inputs=[
-            grid_size,
+            grid_size,
+            obstacle_density,
+            alpha, gamma,
             eps_start, eps_end, eps_decay,
-            episodes,
+            episodes, max_steps_train,
             speed, show_q_overlay, curve_window
         ],
-        outputs=[frame_out, curve_out, q_state, status],
+        outputs=[frame_out, curve_out, q_state, env_state, status],
     )
 
     play_btn.click(
         fn=play_stream,
-        inputs=[q_state,
+        inputs=[q_state, env_state, speed, show_q_overlay, max_steps_train],
         outputs=[frame_out, curve_out, status],
     )
 
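Three short sketches follow; they are not part of the diff, just illustrations of what the new code does.

First, a sanity check for the new obstacle generator: every layout that generate_obstacles returns should leave a safe start-to-goal path, which _has_path verifies with BFS over non-blocked cells. This assumes app.py imports without launching the UI (i.e. any demo.launch() call sits under an `if __name__ == "__main__":` guard).

import numpy as np
from app import generate_obstacles, _has_path

rng = np.random.default_rng(0)  # fixed seed so the check is reproducible
size, start, goal = 7, (0, 0), (6, 6)
for _ in range(100):
    # walls and lava are sets of (row, col) cells; lava counts as blocked too
    walls, lava = generate_obstacles(size, start, goal, density=0.35, rng=rng)
    assert _has_path(size, start, goal, walls | lava)
print("100/100 generated layouts are solvable")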
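Second, the Q-learning update itself only appears in the diff as the context line `self.Q[r1, c1, a] += self.alpha * td_error`. Here is that update worked through for a single transition with the UI's default alpha=0.45 and gamma=0.97; the td_target computation is an assumption, since most of QAgent.update sits outside the hunks.

import numpy as np

Q = np.zeros((5, 5, 4), dtype=np.float32)           # one value per (row, col, action)
alpha, gamma = 0.45, 0.97                           # the sliders' default values
s, a, r, s2, done = (0, 0), 2, -0.1, (0, 1), False  # one non-terminal step

td_target = r + (0.0 if done else gamma * float(np.max(Q[s2])))  # bootstrap from best next action
td_error = td_target - Q[s[0], s[1], a]
Q[s[0], s[1], a] += alpha * td_error
print(Q[0, 0, 2])  # -0.045: the step penalty scaled by the learning rate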
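Finally, one plausible reason play_stream rebuilds tuples (`walls = [tuple(x) for x in env_state["walls"]]`): if the stored layout ever round-trips through JSON (session persistence, a queue between workers), tuples come back as lists, and lists are unhashable, so membership tests like `self.pos in self.lava` would fail on the restored sets. A minimal demonstration:

import json

env_state = {"size": 5, "start": (0, 0), "goal": (4, 4),
             "walls": [(1, 1), (2, 3)], "lava": [(3, 0)]}
restored = json.loads(json.dumps(env_state))   # tuples come back as lists
walls = {tuple(x) for x in restored["walls"]}  # [[1, 1], [2, 3]] -> {(1, 1), (2, 3)}
assert (1, 1) in walls and tuple(restored["start"]) == (0, 0)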