Marcel0123 committed
Commit 635c19c · verified · 1 Parent(s): 9acd31a

Update app.py

Files changed (1): app.py +464 -166
app.py CHANGED
@@ -1,6 +1,13 @@
 # app.py
-# RL Gridworld (Q-learning) – Warehouse Robot Demo
-# Styling: dark industrial robotics theme (option 2)
+# Gridworld RL (Q-learning) with:
+# Original visualization + layout for the demo section (unchanged)
+# ✅ Non-flickering learning curve (always visible)
+# ✅ Obstacle density slider (auto-generate more/less blocks)
+# ✅ Train uses epsilon decay (converges); Play shows deterministic route (epsilon=0)
+# ✅ Same obstacle layout is reused for Play (stored in state)
+# ✅ Styling (Option 2): dark background + calmer amber/orange accents
+# ✅ Header: text LEFT, photo RIGHT
+# ✅ Removed the extra RL-description block (as you requested earlier)
 
 import time
 import numpy as np
@@ -12,7 +19,7 @@ from PIL import Image
 from collections import deque
 
 # =========================================================
-# 🎨 CUSTOM CSS (industrial, calm, dark)
+# 🎨 CUSTOM CSS (Option 2: calm industrial robotics)
 # =========================================================
 CUSTOM_CSS = """
 body {
@@ -34,24 +41,14 @@ body {
 /* Headings */
 h1, h2, h3 {
   color: #ffd27d;
-  letter-spacing: 0.06em;
+  letter-spacing: 0.04em;
 }
 
 /* Text */
-p, li {
+p, li, .md p {
   color: #d6e6ff;
 }
 
-/* Panels / cards */
-.gr-group, .gr-box, .gr-panel {
-  background: radial-gradient(circle at top left,
-              rgba(255, 200, 120, 0.06),
-              rgba(4, 9, 29, 0.98));
-  border-radius: 22px;
-  border: 1px solid rgba(255, 200, 120, 0.28);
-  box-shadow: 0 0 22px rgba(255, 180, 80, 0.12);
-}
-
 /* Labels */
 label {
   color: #ffddaa !important;
@@ -83,69 +80,104 @@ button.primary {
   background: linear-gradient(90deg, #ffb347, #ffcc80) !important;
   color: #1a0f02 !important;
   border: none !important;
-  box-shadow: 0 0 16px rgba(255, 180, 80, 0.45);
+  box-shadow: 0 0 16px rgba(255, 180, 80, 0.40);
 }
 
 button.secondary {
-  background: rgba(12, 20, 40, 0.9) !important;
+  background: rgba(12, 20, 40, 0.92) !important;
   color: #ffd9a0 !important;
   border: 1px solid rgba(255, 200, 120, 0.35) !important;
 }
+
+/* Accordions / panels - keep subtle */
+.gr-accordion, .gr-box, .gr-panel, .gr-group {
+  background: radial-gradient(circle at top left,
+              rgba(255, 200, 120, 0.06),
+              rgba(4, 9, 29, 0.98)) !important;
+  border: 1px solid rgba(255, 200, 120, 0.18) !important;
+  border-radius: 18px !important;
+  box-shadow: 0 0 18px rgba(255, 180, 80, 0.10);
+}
+
+/* Image containers - do not affect the pixels */
+img {
+  border-radius: 16px;
+}
 """
 
-# =========================================================
-# 🤖 GRIDWORLD ENVIRONMENT
-# =========================================================
+# -----------------------------
+# Gridworld Environment
+# -----------------------------
 ACTIONS = ["↑", "→", "↓", "←"]
 ACTION_DELTAS = {
-    0: (-1, 0),
-    1: (0, 1),
-    2: (1, 0),
-    3: (0, -1),
+    0: (-1, 0),  # up
+    1: (0, 1),   # right
+    2: (1, 0),   # down
+    3: (0, -1),  # left
 }
 
-def neighbors(r, c, n):
-    for dr, dc in [(-1,0),(1,0),(0,-1),(0,1)]:
-        nr, nc = r+dr, c+dc
-        if 0 <= nr < n and 0 <= nc < n:
-            yield nr, nc
+def _neighbors(r, c, n):
+    if r > 0: yield (r - 1, c)
+    if r < n - 1: yield (r + 1, c)
+    if c > 0: yield (r, c - 1)
+    if c < n - 1: yield (r, c + 1)
 
-def has_path(size, start, goal, blocked):
+def _has_path(size, start, goal, blocked):
+    """BFS to ensure there's at least one safe path from start to goal."""
     q = deque([start])
     seen = {start}
     while q:
         cur = q.popleft()
         if cur == goal:
             return True
-        for nxt in neighbors(*cur, size):
-            if nxt not in seen and nxt not in blocked:
-                seen.add(nxt)
-                q.append(nxt)
+        r, c = cur
+        for nr, nc in _neighbors(r, c, size):
+            nxt = (nr, nc)
+            if nxt in seen or nxt in blocked:
+                continue
+            seen.add(nxt)
+            q.append(nxt)
     return False
 
-def generate_obstacles(size, start, goal, density, rng):
-    walls, lava = set(), set()
-    for _ in range(60):
-        walls.clear()
-        lava.clear()
+def generate_obstacles(size, start, goal, density, wall_ratio=0.7, max_tries=60, rng=None):
+    """
+    Generate walls + lava with a given density, retrying until there is a safe path.
+    Lava is treated as blocked (terminal negative), so we keep at least one safe route.
+    """
+    rng = rng or np.random.default_rng()
+    density = float(np.clip(density, 0.0, 0.60))
+
+    cur_density = density
+    for _ in range(max_tries):
+        walls = set()
+        lava = set()
+
         for r in range(size):
             for c in range(size):
-                if (r,c) in (start, goal):
+                cell = (r, c)
+                if cell == start or cell == goal:
                     continue
-                if rng.random() < density:
-                    (walls if rng.random() < 0.7 else lava).add((r,c))
-        if has_path(size, start, goal, walls | lava):
+                if rng.random() < cur_density:
+                    if rng.random() < wall_ratio:
+                        walls.add(cell)
+                    else:
+                        lava.add(cell)
+
+        blocked = walls | lava
+        if _has_path(size, start, goal, blocked):
             return walls, lava
-        density = max(0, density - 0.02)
+
+        cur_density = max(0.0, cur_density - 0.02)
+
     return set(), set()
 
 class GridWorld:
-    def __init__(self, size, walls, lava):
-        self.size = size
-        self.start = (0,0)
-        self.goal = (size-1, size-1)
-        self.walls = walls
-        self.lava = lava
+    def __init__(self, size=5, start=(0, 0), goal=None, lava=None, walls=None):
+        self.size = int(size)
+        self.start = start
+        self.goal = goal if goal is not None else (self.size - 1, self.size - 1)
+        self.lava = set(lava or [])
+        self.walls = set(walls or [])
         self.reset()
 
     def reset(self):
@@ -155,42 +187,58 @@ class GridWorld:
     def step(self, action):
         dr, dc = ACTION_DELTAS[action]
         r, c = self.pos
-        nr, nc = r+dr, c+dc
-        if (nr < 0 or nr >= self.size or
-            nc < 0 or nc >= self.size or
-            (nr,nc) in self.walls):
+        nr, nc = r + dr, c + dc
+
+        # bounds check
+        if nr < 0 or nr >= self.size or nc < 0 or nc >= self.size:
             nr, nc = r, c
-        self.pos = (nr,nc)
+
+        # wall check
+        if (nr, nc) in self.walls:
+            nr, nc = r, c
+
+        self.pos = (nr, nc)
+
+        # rewards
         if self.pos == self.goal:
-            return self.pos, 10.0, True
+            return self.pos, +10.0, True
         if self.pos in self.lava:
             return self.pos, -10.0, True
-        return self.pos, -0.1, False
 
+        return self.pos, -0.1, False  # step penalty -> shortest safe path is optimal
+
-# =========================================================
-# 🧠 Q-LEARNING AGENT
-# =========================================================
-class QAgent:
-    def __init__(self, size, alpha, gamma):
-        self.Q = np.zeros((size, size, 4), dtype=np.float32)
-        self.alpha = alpha
-        self.gamma = gamma
+# -----------------------------
+# Q-Learning Agent
+# -----------------------------
+class QAgent:
+    def __init__(self, size=5, n_actions=4, alpha=0.3, gamma=0.95):
+        self.size = int(size)
+        self.n_actions = n_actions
+        self.alpha = float(alpha)
+        self.gamma = float(gamma)
+        self.Q = np.zeros((self.size, self.size, n_actions), dtype=np.float32)
 
-    def act(self, s, eps):
-        if np.random.rand() < eps:
-            return np.random.randint(4)
-        return int(np.argmax(self.Q[s]))
-
-    def act_greedy(self, s):
-        return int(np.argmax(self.Q[s]))
+    def act(self, state, epsilon):
+        r, c = state
+        if np.random.rand() < float(epsilon):
+            return np.random.randint(self.n_actions)
+        return int(np.argmax(self.Q[r, c]))
+
+    def act_greedy(self, state):
+        r, c = state
+        return int(np.argmax(self.Q[r, c]))
 
     def update(self, s, a, r, s2, done):
-        target = r if done else r + self.gamma * np.max(self.Q[s2])
-        self.Q[s + (a,)] += self.alpha * (target - self.Q[s + (a,)])
+        r1, c1 = s
+        r2, c2 = s2
+        best_next = 0.0 if done else float(np.max(self.Q[r2, c2]))
+        td_target = r + self.gamma * best_next
+        td_error = td_target - float(self.Q[r1, c1, a])
+        self.Q[r1, c1, a] += self.alpha * td_error
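+        # i.e. Q(s,a) += alpha * (r + gamma * max_a' Q(s',a') - Q(s,a)),
+        # the standard tabular Q-learning update; the bootstrap term is
+        # dropped on terminal transitions, so the target is just r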
-
-# =========================================================
-# 🎨 RENDERING
-# =========================================================
+
+# -----------------------------
+# Rendering helpers (ORIGINAL look)
+# -----------------------------
 def fig_to_pil(fig):
     buf = BytesIO()
     fig.savefig(buf, format="png", dpi=160, bbox_inches="tight")
@@ -198,116 +246,366 @@ def fig_to_pil(fig):
     buf.seek(0)
     return Image.open(buf)
 
-def draw(env, agent=None, episode=None, step=None, ret=None):
+def draw_grid(env: GridWorld, agent: QAgent = None, show_q=False, episode=None, step_i=None, total_reward=None):
     n = env.size
-    fig, ax = plt.subplots(figsize=(5.4,5.4))
-    ax.set_xlim(0,n); ax.set_ylim(0,n)
+    fig, ax = plt.subplots(figsize=(5.4, 5.4))
+    ax.set_xlim(0, n)
+    ax.set_ylim(0, n)
+    ax.set_aspect("equal")
     ax.axis("off")
-    ax.add_patch(Rectangle((0,0), n,n, facecolor="#0b1020"))
+
+    # Background (keep original)
+    ax.add_patch(Rectangle((0, 0), n, n, facecolor="#0b1020"))
+
+    # Draw cells
     for r in range(n):
         for c in range(n):
-            x,y = c,n-1-r
-            color = "#121a33"
-            if (r,c) == env.goal: color="#0f2f1f"
-            if (r,c) in env.lava: color="#3a1414"
-            if (r,c) in env.walls: color="#1b1b1b"
-            ax.add_patch(FancyBboxPatch(
-                (x+0.05,y+0.05),0.9,0.9,
-                boxstyle="round,pad=0.02",
-                facecolor=color, edgecolor="#2a355f"
-            ))
-    def icon(rc, txt):
-        r,c = rc
-        ax.text(c+0.5, n-1-r+0.5, txt, ha="center", va="center", fontsize=22)
-    icon(env.goal,"🏁")
-    for p in env.lava: icon(p,"🔥")
-    for p in env.walls: icon(p,"🧱")
-    icon(env.pos,"🤖")
-    title = f"Episode {episode} | Step {step} | Return {ret:+.2f}" if episode else ""
-    ax.text(0, n+0.2, title, color="#ffd27d")
+            x, y = c, n - 1 - r  # invert y so (0,0) is top-left visually
+
+            tile_color = "#121a33"
+            if (r, c) == env.goal:
+                tile_color = "#0f2f1f"
+            if (r, c) in env.lava:
+                tile_color = "#3a1414"
+            if (r, c) in env.walls:
+                tile_color = "#1b1b1b"
+
+            ax.add_patch(
+                FancyBboxPatch(
+                    (x + 0.05, y + 0.05), 0.9, 0.9,
+                    boxstyle="round,pad=0.02,rounding_size=0.08",
+                    linewidth=1.0,
+                    edgecolor="#2a355f",
+                    facecolor=tile_color,
+                    alpha=0.95
+                )
+            )
+
+            # overlay Q hint (optional)
+            if show_q and agent is not None and (r, c) not in env.walls:
+                best_a = int(np.argmax(agent.Q[r, c]))
+                qv = float(np.max(agent.Q[r, c]))
+                ax.text(x + 0.5, y + 0.55, ACTIONS[best_a], ha="center", va="center",
+                        fontsize=14, color="#d7e3ff", alpha=0.65)
+                ax.text(x + 0.5, y + 0.30, f"{qv:+.2f}", ha="center", va="center",
+                        fontsize=9, color="#a9b7e6", alpha=0.55)
+
+    # Icons
+    def put_icon(rc, icon, color="#ffffff"):
+        r, c = rc
+        x, y = c + 0.5, (n - 1 - r) + 0.5
+        ax.text(x, y, icon, ha="center", va="center", fontsize=22, color=color)
+
+    put_icon(env.goal, "🏁")
+    for rc in env.lava:
+        put_icon(rc, "🔥")
+    for rc in env.walls:
+        put_icon(rc, "🧱")
+
+    # Agent
+    put_icon(env.pos, "🤖")
+
+    # Header overlay
+    title = "Gridworld RL • Q-learning"
+    sub = []
+    if episode is not None:
+        sub.append(f"Episode: {episode}")
+    if step_i is not None:
+        sub.append(f"Step: {step_i}")
+    if total_reward is not None:
+        sub.append(f"Return: {total_reward:+.2f}")
+    subtitle = " • ".join(sub)
+
+    ax.text(0, n + 0.35, title, fontsize=14, color="#eaf0ff", weight="bold")
+    ax.text(0, n + 0.08, subtitle, fontsize=10, color="#b8c6ff", alpha=0.9)
+
     return fig_to_pil(fig)
 
-# =========================================================
-# 🚀 TRAIN / PLAY
-# =========================================================
-def train(grid, density, alpha, gamma, eps_s, eps_e, eps_d, episodes, max_steps, speed):
+# -----------------------------
+# Learning curve chart (no flicker)
+# -----------------------------
+def moving_average(x, window=25):
+    if len(x) < 2:
+        return np.array(x, dtype=float)
+    w = max(2, min(int(window), len(x)))
+    kernel = np.ones(w) / w
+    return np.convolve(np.array(x, dtype=float), kernel, mode="valid")
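+    # mode="valid" yields len(x) - w + 1 points (no edge padding), so the
+    # smoothed curve starts w - 1 episodes in; the caller offsets its x-axis to match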
+
+def draw_learning_curve(returns, successes, window=25):
+    fig, ax = plt.subplots(figsize=(5.4, 4.6))
+    ax.set_facecolor("#0b1020")
+    for spine in ax.spines.values():
+        spine.set_color("#2a355f")
+    ax.tick_params(colors="#c9d6ff")
+    ax.yaxis.label.set_color("#c9d6ff")
+    ax.xaxis.label.set_color("#c9d6ff")
+    ax.title.set_color("#eaf0ff")
+
+    ax.set_title("Learning curve")
+    ax.set_xlabel("Episode")
+    ax.set_ylabel("Return")
+
+    if len(returns) > 0:
+        xs = np.arange(1, len(returns) + 1)
+        ax.plot(xs, returns, linewidth=1.5, alpha=0.9, label="Return")
+
+        ma = moving_average(returns, window=window)
+        if len(ma) > 0:
+            xs_ma = np.arange(len(returns) - len(ma) + 1, len(returns) + 1)
+            ax.plot(xs_ma, ma, linewidth=2.5, alpha=0.95,
+                    label=f"Moving avg ({min(int(window), len(returns))})")
+
+    ax2 = ax.twinx()
+    ax2.tick_params(colors="#c9d6ff")
+    ax2.spines["right"].set_color("#2a355f")
+    ax2.set_ylabel("Success rate", color="#c9d6ff")
+
+    if len(successes) > 0:
+        xs = np.arange(1, len(successes) + 1)
+        sr = np.cumsum(np.array(successes, dtype=float)) / xs
+        ax2.plot(xs, sr, linewidth=2.0, alpha=0.8, label="Success rate")
+
+    lines, labels = ax.get_legend_handles_labels()
+    lines2, labels2 = ax2.get_legend_handles_labels()
+    ax.legend(lines + lines2, labels + labels2, loc="lower right", framealpha=0.2)
+
+    ax.grid(True, alpha=0.15)
+    return fig_to_pil(fig)
+
+# -----------------------------
+# Training + Playback (store env layout so Play matches Train)
+# -----------------------------
+def make_env_and_agent(grid_size, obstacle_density, alpha, gamma):
+    size = int(grid_size)
+    start = (0, 0)
+    goal = (size - 1, size - 1)
+
     rng = np.random.default_rng()
-    walls,lava = generate_obstacles(grid,(0,0),(grid-1,grid-1),density,rng)
-    env = GridWorld(grid,walls,lava)
-    agent = QAgent(grid,alpha,gamma)
-    eps = eps_s
-    for ep in range(1, episodes+1):
+    walls, lava = generate_obstacles(size, start, goal, density=float(obstacle_density), wall_ratio=0.7, rng=rng)
+
+    env = GridWorld(size=size, start=start, goal=goal, walls=walls, lava=lava)
+    agent = QAgent(size=size, alpha=alpha, gamma=gamma)
+
+    env_state = {
+        "size": size,
+        "start": start,
+        "goal": goal,
+        "walls": sorted(list(walls)),
+        "lava": sorted(list(lava)),
+    }
+    return env, agent, env_state
+
+def train_stream(
+    grid_size,
+    obstacle_density,
+    alpha,
+    gamma,
+    eps_start,
+    eps_end,
+    eps_decay,
+    episodes,
+    max_steps,
+    speed,
+    show_q_overlay,
+    curve_window,
+):
+    env, agent, env_state = make_env_and_agent(grid_size, obstacle_density, alpha, gamma)
+    eps = float(eps_start)
+
+    returns = []
+    successes = []
+
+    # initial frame
+    frame = draw_grid(env, agent, show_q=show_q_overlay, episode=0, step_i=0, total_reward=0.0)
+    last_curve = draw_learning_curve(returns, successes, window=int(curve_window))
+    status = f"Choose a density and click Train. (Obstacle density={float(obstacle_density):.2f})"
+    yield frame, last_curve, agent.Q, env_state, status
+
+    CURVE_UPDATE_EVERY_STEPS = 8
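+    # redrawing the curve is comparatively expensive, so refresh it only every
+    # few steps (and on episode end) to keep the stream smooth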
+
+    for ep in range(1, int(episodes) + 1):
         s = env.reset()
-        ret = 0
-        for t in range(1, max_steps+1):
-            a = agent.act(s, eps)
-            s2,r,d = env.step(a)
-            agent.update(s,a,r,s2,d)
-            s = s2; ret += r
-            yield draw(env,agent,ep,t,ret), agent.Q, (walls,lava)
-            time.sleep(speed)
-            if d: break
-        eps = max(eps_e, eps*eps_d)
-
-def play(Q, env_state, grid, max_steps, speed):
-    walls,lava = env_state
-    env = GridWorld(grid,walls,lava)
-    agent = QAgent(grid,0,0); agent.Q = Q
-    s = env.reset(); ret=0
-    for t in range(1, max_steps+1):
+        total_r = 0.0
+        reached_goal_this_ep = 0
+
+        for t in range(1, int(max_steps) + 1):
+            a = agent.act(s, epsilon=eps)
+            s2, r, done = env.step(a)
+            agent.update(s, a, r, s2, done)
+            s = s2
+            total_r += r
+
+            if done and env.pos == env.goal:
+                reached_goal_this_ep = 1
+
+            if (t % CURVE_UPDATE_EVERY_STEPS == 0) or done:
+                preview_returns = returns + [total_r]
+                preview_successes = successes + [reached_goal_this_ep]
+                last_curve = draw_learning_curve(preview_returns, preview_successes, window=int(curve_window))
+
+            frame = draw_grid(env, agent, show_q=show_q_overlay, episode=ep, step_i=t, total_reward=total_r)
+            status = f"Train • ep {ep}/{episodes} • step {t}/{max_steps} • return {total_r:+.2f} • eps {eps:.3f}"
+            yield frame, last_curve, agent.Q, env_state, status
+
+            if speed > 0:
+                time.sleep(float(speed))
+
+            if done:
+                break
+
+        returns.append(total_r)
+        successes.append(reached_goal_this_ep)
+        last_curve = draw_learning_curve(returns, successes, window=int(curve_window))
+        yield frame, last_curve, agent.Q, env_state, f"Episode {ep} done • return {total_r:+.2f} • success={reached_goal_this_ep} • eps {eps:.3f}"
+
+        eps = max(float(eps_end), eps * float(eps_decay))
+
+    frame = draw_grid(env, agent, show_q=show_q_overlay, episode=episodes, step_i=None, total_reward=None)
+    last_curve = draw_learning_curve(returns, successes, window=int(curve_window))
+    status = "Training done ✅ Now click 'Play learned policy'."
+    yield frame, last_curve, agent.Q, env_state, status
+
+def play_stream(q_table, env_state, speed, show_q_overlay, max_steps):
+    if q_table is None or env_state is None:
+        env = GridWorld(size=5, start=(0, 0), goal=(4, 4), walls=[], lava=[])
+        agent = QAgent(size=5)
+        frame = draw_grid(env, agent, show_q=show_q_overlay, episode=None, step_i=None, total_reward=None)
+        curve = draw_learning_curve([], [], window=25)
+        yield frame, curve, "No training yet. Click Train first."
+        return
+
+    size = int(env_state["size"])
+    start = tuple(env_state["start"])
+    goal = tuple(env_state["goal"])
+    walls = [tuple(x) for x in env_state["walls"]]
+    lava = [tuple(x) for x in env_state["lava"]]
+
+    env = GridWorld(size=size, start=start, goal=goal, walls=walls, lava=lava)
+    agent = QAgent(size=size)
+    agent.Q = np.array(q_table, dtype=np.float32)
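+    # rebuild a greedy agent from the stored Q-table; alpha/gamma are unused
+    # here because playback only calls act_greedy and never updates Q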
+
+    s = env.reset()
+    total_r = 0.0
+
+    curve = draw_learning_curve([], [], window=25)  # keep curve visible (static) during play
+    frame = draw_grid(env, agent, show_q=show_q_overlay, episode="PLAY", step_i=0, total_reward=total_r)
+    yield frame, curve, "Play • epsilon=0.0 (deterministic)"
+
+    for t in range(1, int(max_steps) + 1):
         a = agent.act_greedy(s)
-        s,r,d = env.step(a)
-        ret+=r
-        yield draw(env,agent,"PLAY",t,ret)
-        time.sleep(speed)
-        if d: break
-
-# =========================================================
-# 🖥️ UI
-# =========================================================
-with gr.Blocks(css=CUSTOM_CSS, theme=gr.themes.Soft()) as demo:
-
+        s2, r, done = env.step(a)
+        s = s2
+        total_r += r
+
+        frame = draw_grid(env, agent, show_q=show_q_overlay, episode="PLAY", step_i=t, total_reward=total_r)
+        yield frame, curve, f"Play • step {t}/{max_steps} • return {total_r:+.2f}"
+
+        if speed > 0:
+            time.sleep(float(speed))
+        if done:
+            break
+
+    if env.pos == env.goal:
+        end = f"🏁 Goal reached! return {total_r:+.2f}"
+    elif env.pos in env.lava:
+        end = "🔥 Ended up in lava. Tip: train longer or lower the density."
+    else:
+        end = "Play finished. Tip: train longer or raise max_steps."
+
+    frame = draw_grid(env, agent, show_q=show_q_overlay, episode="PLAY", step_i=None, total_reward=total_r)
+    yield frame, curve, end
+
+# -----------------------------
+# Gradio UI (layout stays the same)
+# -----------------------------
+with gr.Blocks(css=CUSTOM_CSS, theme=gr.themes.Soft(), title="Warehouse Robot RL Demo") as demo:
+    # Header: text LEFT, image RIGHT (as requested)
     with gr.Row():
         with gr.Column(scale=3):
-            gr.Markdown("""
+            gr.Markdown(
+                """
             ### 🤖 A robot in the warehouse
-            The robot learns by itself how to move safely and efficiently through the warehouse,
-            without rules or a map.
-            """)
+
+            Imagine: you work in a large warehouse.
+            Between the shelves, a robot drives around picking up orders and bringing them to the packing station.
+            That robot gets no map, no rules, and no instructions about the fastest route.
+            In this demo you see how such a robot learns smart behavior on its own.
+
+            At first it drives around randomly and makes mistakes.
+            But as it gains experience, it discovers by itself how to move through the warehouse safely, efficiently, and as quickly as possible.
+
+            Above, you see the robot driving between shelves and hazardous zones.
+            Below, you see its performance improve as it learns.
+
+            👉 Try it yourself: make the warehouse easier or harder, train the robot,
+            and then show what it has learned.
+            """
+            )
         with gr.Column(scale=2):
-            gr.Image("humanoid-robot-apptronic-1024x684.jpg.webp", show_label=False)
+            gr.Image(
+                value="humanoid-robot-apptronic-1024x684.jpg.webp",
+                show_label=False,
+                height=340,
+            )
 
+    # ---- Demo section (unchanged) ----
     q_state = gr.State(None)
     env_state = gr.State(None)
 
     with gr.Row():
-        with gr.Column():
-            grid = gr.Slider(4,10,5,label="Grid size")
-            density = gr.Slider(0,0.45,0.15,label="Obstacle density")
-            alpha = gr.Slider(0.01,1,0.45,label="Alpha")
-            gamma = gr.Slider(0,0.999,0.97,label="Gamma")
-            eps_s = gr.Slider(0,1,0.9,label="Epsilon start")
-            eps_e = gr.Slider(0,0.2,0.02,label="Epsilon end")
-            eps_d = gr.Slider(0.9,0.999,0.985,label="Epsilon decay")
-            episodes = gr.Slider(1,300,200,label="Episodes")
-            max_steps = gr.Slider(5,200,60,label="Max steps")
-            speed = gr.Slider(0,0.1,0.02,label="Speed")
-            train_btn = gr.Button("🚀 Train", variant="primary")
-            play_btn = gr.Button("▶️ Play", variant="secondary")
-
-        with gr.Column():
-            frame = gr.Image(height=520)
+        with gr.Column(scale=1):
+            grid_size = gr.Slider(4, 10, value=5, step=1, label="Grid size")
+
+            obstacle_density = gr.Slider(
+                0.0, 0.45, value=0.15, step=0.05,
+                label="Obstacle density (more blocks/hazards)"
+            )
+
+            with gr.Accordion("RL parameters (defaults = good convergence)", open=True):
+                alpha = gr.Slider(0.01, 1.0, value=0.45, step=0.01, label="Alpha (learning rate)")
+                gamma = gr.Slider(0.0, 0.999, value=0.97, step=0.001, label="Gamma (discount)")
+
+            with gr.Accordion("Exploration (epsilon decay)", open=True):
+                eps_start = gr.Slider(0.0, 1.0, value=0.90, step=0.01, label="Epsilon start (mostly exploring)")
+                eps_end = gr.Slider(0.0, 0.2, value=0.02, step=0.005, label="Epsilon end (almost greedy)")
+                eps_decay = gr.Slider(0.90, 0.999, value=0.985, step=0.001, label="Epsilon decay per episode")
+
+            episodes = gr.Slider(1, 400, value=200, step=1, label="Episodes")
+            max_steps_train = gr.Slider(5, 200, value=60, step=1, label="Max steps per episode")
+
+            with gr.Accordion("Visuals & speed", open=True):
+                speed = gr.Slider(0.0, 0.3, value=0.02, step=0.01, label="Animation delay (sec/frame)")
+                show_q_overlay = gr.Checkbox(value=True, label="Show best action & Q-value per cell (overlay)")
+                curve_window = gr.Slider(5, 80, value=25, step=1, label="Moving average window (episodes)")
+
+            with gr.Row():
+                train_btn = gr.Button("🚀 Train (epsilon decay)", variant="primary")
+                play_btn = gr.Button("▶️ Play learned policy (epsilon=0)")
+
+            status = gr.Textbox(label="Status", value="Choose a density and click Train.", interactive=False)
+
+        with gr.Column(scale=1):
+            frame_out = gr.Image(label="Live animation", type="pil", height=520)
+            curve_out = gr.Image(label="Learning curve (live)", type="pil", height=420)
 
     train_btn.click(
-        train,
-        inputs=[grid,density,alpha,gamma,eps_s,eps_e,eps_d,episodes,max_steps,speed],
-        outputs=[frame,q_state,env_state],
+        fn=train_stream,
+        inputs=[
+            grid_size,
+            obstacle_density,
+            alpha, gamma,
+            eps_start, eps_end, eps_decay,
+            episodes, max_steps_train,
+            speed, show_q_overlay, curve_window
+        ],
+        outputs=[frame_out, curve_out, q_state, env_state, status],
     )
 
     play_btn.click(
-        play,
-        inputs=[q_state,env_state,grid,max_steps,speed],
-        outputs=frame,
+        fn=play_stream,
+        inputs=[q_state, env_state, speed, show_q_overlay, max_steps_train],
+        outputs=[frame_out, curve_out, status],
     )
 
-demo.launch()
+if __name__ == "__main__":
+    demo.launch()