Marcel0123 committed
Commit b337ada · verified · 1 Parent(s): 94ac4ed

Update app.py

Files changed (1)
  app.py  +139 -43
app.py CHANGED
@@ -1,10 +1,10 @@
 # app.py
-# Gridworld RL (Q-learning) with real-time animation + NON-FLICKERING learning curve
-# Keeps your original visualization + layout as much as possible.
-#
-# Buttons:
-#   - Train (epsilon decay): learns (explore -> exploit)
-#   - Play learned policy (epsilon=0): shows shortest safe route deterministically
+# Gridworld RL (Q-learning) with:
+#   Original visualization + layout (as much as possible)
+#   ✅ Non-flickering learning curve (always visible)
+#   ✅ Option 1: Obstacle density slider (auto-generate more/less blocks)
+# Train uses epsilon decay (converges); Play shows deterministic route (epsilon=0)
+# Same obstacle layout is reused for Play (stored in state)
 
 import time
 import numpy as np
@@ -13,6 +13,7 @@ import matplotlib.pyplot as plt
 from matplotlib.patches import Rectangle, FancyBboxPatch
 from io import BytesIO
 from PIL import Image
+from collections import deque
 
 # -----------------------------
 # Gridworld Environment
@@ -25,13 +26,71 @@ ACTION_DELTAS = {
     3: (0, -1),  # left
 }
 
+def _neighbors(r, c, n):
+    if r > 0: yield (r - 1, c)
+    if r < n - 1: yield (r + 1, c)
+    if c > 0: yield (r, c - 1)
+    if c < n - 1: yield (r, c + 1)
+
+def _has_path(size, start, goal, blocked):
+    """BFS to ensure there's at least one safe path from start to goal."""
+    q = deque([start])
+    seen = {start}
+    while q:
+        cur = q.popleft()
+        if cur == goal:
+            return True
+        r, c = cur
+        for nr, nc in _neighbors(r, c, size):
+            nxt = (nr, nc)
+            if nxt in seen or nxt in blocked:
+                continue
+            seen.add(nxt)
+            q.append(nxt)
+    return False
+
+def generate_obstacles(size, start, goal, density, wall_ratio=0.7, max_tries=60, rng=None):
+    """
+    Generate walls + lava with a given density, retrying until there is a safe path.
+    Lava is treated as blocked (terminal negative), so we keep at least one safe route.
+    """
+    rng = rng or np.random.default_rng()
+    density = float(np.clip(density, 0.0, 0.60))
+
+    # If density is too high, repeatedly try; if impossible, gradually reduce density
+    cur_density = density
+    for _ in range(max_tries):
+        walls = set()
+        lava = set()
+
+        for r in range(size):
+            for c in range(size):
+                cell = (r, c)
+                if cell == start or cell == goal:
+                    continue
+                if rng.random() < cur_density:
+                    if rng.random() < wall_ratio:
+                        walls.add(cell)
+                    else:
+                        lava.add(cell)
+
+        blocked = walls | lava
+        if _has_path(size, start, goal, blocked):
+            return walls, lava
+
+        # If no path, soften the environment a bit and try again
+        cur_density = max(0.0, cur_density - 0.02)
+
+    # Fallback: empty obstacles (always solvable)
+    return set(), set()
+
 class GridWorld:
     def __init__(self, size=5, start=(0, 0), goal=None, lava=None, walls=None):
         self.size = int(size)
         self.start = start
         self.goal = goal if goal is not None else (self.size - 1, self.size - 1)
-        self.lava = set(lava or [(1, 3), (2, 3), (3, 1)])
-        self.walls = set(walls or [(1, 1), (2, 1)])
+        self.lava = set(lava or [])
+        self.walls = set(walls or [])
         self.reset()
 
     def reset(self):
@@ -59,7 +118,7 @@ class GridWorld:
         if self.pos in self.lava:
             return self.pos, -10.0, True
 
-        return self.pos, -0.1, False  # step penalty -> shortest path is optimal
+        return self.pos, -0.1, False  # small step penalty -> shortest safe path is optimal
 
 # -----------------------------
 # Q-Learning Agent
@@ -91,7 +150,7 @@ class QAgent:
         self.Q[r1, c1, a] += self.alpha * td_error
 
 # -----------------------------
-# Rendering helpers (original look)
+# Rendering helpers (ORIGINAL look)
 # -----------------------------
 def fig_to_pil(fig):
     buf = BytesIO()
@@ -108,8 +167,10 @@ def draw_grid(env: GridWorld, agent: QAgent = None, show_q=False, episode=None,
     ax.set_aspect("equal")
     ax.axis("off")
 
+    # Background
     ax.add_patch(Rectangle((0, 0), n, n, facecolor="#0b1020"))
 
+    # Draw cells
     for r in range(n):
         for c in range(n):
             x, y = c, n - 1 - r  # invert y so (0,0) is top-left visually
@@ -133,6 +194,7 @@ def draw_grid(env: GridWorld, agent: QAgent = None, show_q=False, episode=None,
                 )
             )
 
+            # overlay Q hint (optional)
            if show_q and agent is not None and (r, c) not in env.walls:
                 best_a = int(np.argmax(agent.Q[r, c]))
                 qv = float(np.max(agent.Q[r, c]))
@@ -141,6 +203,7 @@ def draw_grid(env: GridWorld, agent: QAgent = None, show_q=False, episode=None,
                 ax.text(x + 0.5, y + 0.30, f"{qv:+.2f}", ha="center", va="center",
                         fontsize=9, color="#a9b7e6", alpha=0.55)
 
+    # Icons
     def put_icon(rc, icon, color="#ffffff"):
         r, c = rc
         x, y = c + 0.5, (n - 1 - r) + 0.5
@@ -151,8 +214,11 @@ def draw_grid(env: GridWorld, agent: QAgent = None, show_q=False, episode=None,
         put_icon(rc, "🔥")
     for rc in env.walls:
         put_icon(rc, "🧱")
+
+    # Agent
     put_icon(env.pos, "🤖")
 
+    # Header overlay
     title = "Gridworld RL • Q-learning"
     sub = []
     if episode is not None:
@@ -171,7 +237,7 @@
 # -----------------------------
 # Learning curve chart (no flicker)
 # -----------------------------
-def moving_average(x, window=20):
+def moving_average(x, window=25):
     if len(x) < 2:
         return np.array(x, dtype=float)
     w = max(2, min(int(window), len(x)))
@@ -181,8 +247,6 @@ def moving_average(x, window=20):
 def draw_learning_curve(returns, successes, window=25):
     fig, ax = plt.subplots(figsize=(5.4, 4.6))
     ax.set_facecolor("#0b1020")
-
-    # dark-friendly axes styling
     for spine in ax.spines.values():
         spine.set_color("#2a355f")
     ax.tick_params(colors="#c9d6ff")
@@ -201,7 +265,8 @@ def draw_learning_curve(returns, successes, window=25):
     ma = moving_average(returns, window=window)
     if len(ma) > 0:
         xs_ma = np.arange(len(returns) - len(ma) + 1, len(returns) + 1)
-        ax.plot(xs_ma, ma, linewidth=2.5, alpha=0.95, label=f"Moving avg ({min(int(window), len(returns))})")
+        ax.plot(xs_ma, ma, linewidth=2.5, alpha=0.95,
+                label=f"Moving avg ({min(int(window), len(returns))})")
 
     ax2 = ax.twinx()
     ax2.tick_params(colors="#c9d6ff")
@@ -221,15 +286,31 @@
     return fig_to_pil(fig)
 
 # -----------------------------
-# Training + Playback
+# Training + Playback (store env layout so Play matches Train)
 # -----------------------------
-def init_env_agent(size, alpha, gamma):
-    env = GridWorld(size=size, start=(0, 0), goal=(size - 1, size - 1))
+def make_env_and_agent(grid_size, obstacle_density, alpha, gamma):
+    size = int(grid_size)
+    start = (0, 0)
+    goal = (size - 1, size - 1)
+
+    rng = np.random.default_rng()  # new layout each train run
+    walls, lava = generate_obstacles(size, start, goal, density=float(obstacle_density), wall_ratio=0.7, rng=rng)
+
+    env = GridWorld(size=size, start=start, goal=goal, walls=walls, lava=lava)
     agent = QAgent(size=size, alpha=alpha, gamma=gamma)
-    return env, agent
+
+    env_state = {
+        "size": size,
+        "start": start,
+        "goal": goal,
+        "walls": sorted(list(walls)),
+        "lava": sorted(list(lava)),
+    }
+    return env, agent, env_state
 
 def train_stream(
     grid_size,
+    obstacle_density,
     alpha,
     gamma,
     eps_start,
@@ -241,7 +322,7 @@ def train_stream(
     show_q_overlay,
     curve_window,
 ):
-    env, agent = init_env_agent(grid_size, alpha, gamma)
+    env, agent, env_state = make_env_and_agent(grid_size, obstacle_density, alpha, gamma)
     eps = float(eps_start)
 
     returns = []
@@ -250,10 +331,9 @@
     # initial
     frame = draw_grid(env, agent, show_q=show_q_overlay, episode=0, step_i=0, total_reward=0.0)
     last_curve = draw_learning_curve(returns, successes, window=int(curve_window))
-    status = "Klaar om te trainen. De learning curve blijft nu stabiel in beeld (geen knipperen)."
-    yield frame, last_curve, agent.Q, status
+    status = f"Klaar om te trainen. Obstacle density={float(obstacle_density):.2f}. (Curve knippert niet.)"
+    yield frame, last_curve, agent.Q, env_state, status
 
-    # only redraw chart every N steps (but ALWAYS output the last image)
     CURVE_UPDATE_EVERY_STEPS = 8
 
     for ep in range(1, int(episodes) + 1):
@@ -278,7 +358,7 @@ def train_stream(
 
             frame = draw_grid(env, agent, show_q=show_q_overlay, episode=ep, step_i=t, total_reward=total_r)
             status = f"Train • ep {ep}/{episodes} • step {t}/{max_steps} • return {total_r:+.2f} • eps {eps:.3f}"
-            yield frame, last_curve, agent.Q, status
+            yield frame, last_curve, agent.Q, env_state, status
 
             if speed > 0:
                 time.sleep(float(speed))
@@ -289,33 +369,40 @@ def train_stream(
         returns.append(total_r)
         successes.append(reached_goal_this_ep)
         last_curve = draw_learning_curve(returns, successes, window=int(curve_window))
-        yield frame, last_curve, agent.Q, f"Episode {ep} klaar • return {total_r:+.2f} • success={reached_goal_this_ep} • eps {eps:.3f}"
+        yield frame, last_curve, agent.Q, env_state, f"Episode {ep} klaar • return {total_r:+.2f} • success={reached_goal_this_ep} • eps {eps:.3f}"
 
         eps = max(float(eps_end), eps * float(eps_decay))
 
     frame = draw_grid(env, agent, show_q=show_q_overlay, episode=episodes, step_i=None, total_reward=None)
     last_curve = draw_learning_curve(returns, successes, window=int(curve_window))
     status = "Training klaar ✅ Klik nu op ‘Play learned policy’ om de strakke kortste veilige route te zien (epsilon=0)."
-    yield frame, last_curve, agent.Q, status
+    yield frame, last_curve, agent.Q, env_state, status
 
-def play_stream(q_table, grid_size, max_steps, speed, show_q_overlay):
-    if q_table is None:
-        env = GridWorld(size=grid_size)
-        agent = QAgent(size=grid_size)
+def play_stream(q_table, env_state, speed, show_q_overlay, max_steps):
+    if q_table is None or env_state is None:
+        # show something reasonable
+        env = GridWorld(size=5, start=(0, 0), goal=(4, 4), walls=[], lava=[])
+        agent = QAgent(size=5)
         frame = draw_grid(env, agent, show_q=show_q_overlay, episode=None, step_i=None, total_reward=None)
         curve = draw_learning_curve([], [], window=25)
-        yield frame, curve, "Nog geen Q-table. Klik eerst op Train."
+        yield frame, curve, "Nog geen training gedaan. Klik eerst op Train."
        return
 
-    env = GridWorld(size=grid_size)
-    agent = QAgent(size=grid_size)
+    size = int(env_state["size"])
+    start = tuple(env_state["start"])
+    goal = tuple(env_state["goal"])
+    walls = [tuple(x) for x in env_state["walls"]]
+    lava = [tuple(x) for x in env_state["lava"]]
+
+    env = GridWorld(size=size, start=start, goal=goal, walls=walls, lava=lava)
+    agent = QAgent(size=size)
     agent.Q = np.array(q_table, dtype=np.float32)
 
     s = env.reset()
     total_r = 0.0
 
-    frame = draw_grid(env, agent, show_q=show_q_overlay, episode="PLAY", step_i=0, total_reward=total_r)
     curve = draw_learning_curve([], [], window=25)  # keep curve visible (static) during play
+    frame = draw_grid(env, agent, show_q=show_q_overlay, episode="PLAY", step_i=0, total_reward=total_r)
     yield frame, curve, "Play • epsilon=0.0 (deterministisch) • toont de geleerde route"
 
     for t in range(1, int(max_steps) + 1):
@@ -335,34 +422,41 @@ def play_stream(q_table, grid_size, max_steps, speed, show_q_overlay):
     if env.pos == env.goal:
         end = f"🏁 Goal bereikt! return {total_r:+.2f} (korter pad = minder step-penalty)."
     elif env.pos in env.lava:
-        end = f"🔥 In lava beland. Tip: train langer of eps_start hoger."
+        end = "🔥 In lava beland. Tip: train langer of zet density lager."
     else:
-        end = f"Play klaar. Tip: train langer of max_steps omhoog."
+        end = "Play klaar. Tip: train langer of max_steps omhoog."
 
     frame = draw_grid(env, agent, show_q=show_q_overlay, episode="PLAY", step_i=None, total_reward=total_r)
     yield frame, curve, end
 
 # -----------------------------
-# Gradio UI (original layout)
+# Gradio UI (original layout) + obstacle density slider
 # -----------------------------
 with gr.Blocks(theme=gr.themes.Soft(), title="RL Gridworld (Q-learning)") as demo:
     gr.Markdown(
         """
 # 🤖 Reinforcement Learning in een Gridworld (real-time animatie)
 
+- **Obstacle density**: hoeveel 🧱/🔥 er in het grid staan (meer = moeilijker)
 - **Train**: agent leert (epsilon decays: eerst ontdekken, later benutten)
 - **Play learned policy**: toont wat hij geleerd heeft (**epsilon=0**)
 
-Rechts zie je een **learning curve** (return + moving average + success rate) die nu **niet knippert**.
+Rechts zie je een **learning curve** (return + moving average + success rate) die **niet knippert**.
 """
     )
 
     q_state = gr.State(None)
+    env_state = gr.State(None)
 
     with gr.Row():
         with gr.Column(scale=1):
             grid_size = gr.Slider(4, 10, value=5, step=1, label="Grid size")
 
+            obstacle_density = gr.Slider(
+                0.0, 0.45, value=0.15, step=0.05,
+                label="Obstacle density (meer blokken/gevaar)"
+            )
+
             with gr.Accordion("RL parameters (defaults = goede convergentie)", open=True):
                 alpha = gr.Slider(0.01, 1.0, value=0.45, step=0.01, label="Alpha (learning rate)")
                 gamma = gr.Slider(0.0, 0.999, value=0.97, step=0.001, label="Gamma (discount)")
@@ -373,7 +467,7 @@ Rechts zie je een **learning curve** (return + moving average + success rate) di
                 eps_decay = gr.Slider(0.90, 0.999, value=0.985, step=0.001, label="Epsilon decay per episode")
 
             episodes = gr.Slider(1, 400, value=200, step=1, label="Episodes")
-            max_steps = gr.Slider(5, 200, value=60, step=1, label="Max steps per episode")
+            max_steps_train = gr.Slider(5, 200, value=60, step=1, label="Max steps per episode")
 
             with gr.Accordion("Visuals & snelheid", open=True):
                 speed = gr.Slider(0.0, 0.3, value=0.02, step=0.01, label="Animatie vertraging (sec/frame)")
@@ -384,7 +478,7 @@ Rechts zie je een **learning curve** (return + moving average + success rate) di
             train_btn = gr.Button("🚀 Train (epsilon decay)", variant="primary")
             play_btn = gr.Button("▶️ Play learned policy (epsilon=0)")
 
-            status = gr.Textbox(label="Status", value="Klik Train om te zien hoe de agent leert.", interactive=False)
+            status = gr.Textbox(label="Status", value="Kies density en klik Train.", interactive=False)
 
         with gr.Column(scale=1):
             frame_out = gr.Image(label="Live animatie", type="pil", height=520)
@@ -393,17 +487,19 @@ Rechts zie je een **learning curve** (return + moving average + success rate) di
     train_btn.click(
         fn=train_stream,
         inputs=[
-            grid_size, alpha, gamma,
+            grid_size,
+            obstacle_density,
+            alpha, gamma,
             eps_start, eps_end, eps_decay,
-            episodes, max_steps,
+            episodes, max_steps_train,
             speed, show_q_overlay, curve_window
         ],
-        outputs=[frame_out, curve_out, q_state, status],
+        outputs=[frame_out, curve_out, q_state, env_state, status],
     )
 
     play_btn.click(
         fn=play_stream,
-        inputs=[q_state, grid_size, max_steps, speed, show_q_overlay],
+        inputs=[q_state, env_state, speed, show_q_overlay, max_steps_train],
         outputs=[frame_out, curve_out, status],
     )
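A quick way to sanity-check the two properties this commit adds: generate_obstacles only returns layouts with at least one safe start-to-goal route, and the env_state dict that train_stream stores in gr.State rebuilds an identical grid for play_stream. A minimal sketch, assuming app.py is importable without side effects (no demo.launch() at import time); the seeded rng and the final print are illustrative additions, not part of the commit:

import numpy as np
from app import GridWorld, generate_obstacles, _has_path

size, start, goal = 7, (0, 0), (6, 6)
rng = np.random.default_rng(0)  # seeded only to make this check reproducible

for density in (0.0, 0.15, 0.30, 0.45):
    walls, lava = generate_obstacles(size, start, goal, density=density, rng=rng)
    # generate_obstacles retries (lowering density) until BFS finds a safe route
    assert _has_path(size, start, goal, walls | lava)

    # Round-trip the layout through the same JSON-friendly dict shape
    # that train_stream yields into gr.State for play_stream to consume.
    env_state = {"size": size, "start": start, "goal": goal,
                 "walls": sorted(walls), "lava": sorted(lava)}
    env = GridWorld(size=int(env_state["size"]),
                    start=tuple(env_state["start"]),
                    goal=tuple(env_state["goal"]),
                    walls=[tuple(x) for x in env_state["walls"]],
                    lava=[tuple(x) for x in env_state["lava"]])
    assert env.walls == walls and env.lava == lava

print("all sampled layouts are solvable and round-trip intact")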