Update app.py
app.py (CHANGED)
@@ -1,6 +1,13 @@
# app.py
# Gridworld RL (Q-learning) with:
# ✅ Original visualization + layout for the demo section (unchanged)
# ✅ Non-flickering learning curve (always visible)
# ✅ Obstacle density slider (auto-generates more/fewer blocks)
# ✅ Train uses epsilon decay (converges); Play shows a deterministic route (epsilon=0)
# ✅ The same obstacle layout is reused for Play (stored in state)
# ✅ Styling (Option 2): dark background + calmer amber/orange accents
# ✅ Header: text LEFT, photo RIGHT
# ✅ Removed the extra RL-description block (as requested earlier)

import time
import numpy as np

@@ -12,7 +19,7 @@ from PIL import Image
from collections import deque

# =========================================================
# 🎨 CUSTOM CSS (Option 2: calm industrial robotics)
# =========================================================
CUSTOM_CSS = """
body {

@@ -34,24 +41,14 @@ body {
/* Headings */
h1, h2, h3 {
  color: #ffd27d;
  letter-spacing: 0.04em;
}

/* Text */
p, li, .md p {
  color: #d6e6ff;
}

/* Labels */
label {
  color: #ffddaa !important;

@@ -83,69 +80,104 @@ button.primary {
  background: linear-gradient(90deg, #ffb347, #ffcc80) !important;
  color: #1a0f02 !important;
  border: none !important;
  box-shadow: 0 0 16px rgba(255, 180, 80, 0.40);
}

button.secondary {
  background: rgba(12, 20, 40, 0.92) !important;
  color: #ffd9a0 !important;
  border: 1px solid rgba(255, 200, 120, 0.35) !important;
}

/* Accordions / panels - keep subtle */
.gr-accordion, .gr-box, .gr-panel, .gr-group {
  background: radial-gradient(circle at top left,
              rgba(255, 200, 120, 0.06),
              rgba(4, 9, 29, 0.98)) !important;
  border: 1px solid rgba(255, 200, 120, 0.18) !important;
  border-radius: 18px !important;
  box-shadow: 0 0 18px rgba(255, 180, 80, 0.10);
}

/* Image containers - do not affect the pixels */
img {
  border-radius: 16px;
}
"""

# -----------------------------
# Gridworld Environment
# -----------------------------
ACTIONS = ["↑", "→", "↓", "←"]
ACTION_DELTAS = {
    0: (-1, 0),  # up
    1: (0, 1),   # right
    2: (1, 0),   # down
    3: (0, -1),  # left
}

def _neighbors(r, c, n):
    """Yield the in-bounds 4-neighbours of cell (r, c) on an n x n grid."""
    if r > 0: yield (r - 1, c)
    if r < n - 1: yield (r + 1, c)
    if c > 0: yield (r, c - 1)
    if c < n - 1: yield (r, c + 1)

def _has_path(size, start, goal, blocked):
    """BFS to ensure there's at least one safe path from start to goal."""
    q = deque([start])
    seen = {start}
    while q:
        cur = q.popleft()
        if cur == goal:
            return True
        r, c = cur
        for nr, nc in _neighbors(r, c, size):
            nxt = (nr, nc)
            if nxt in seen or nxt in blocked:
                continue
            seen.add(nxt)
            q.append(nxt)
    return False
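# Illustrative check (not from the original file): on a 3x3 grid,
# _has_path(3, (0, 0), (2, 2), {(1, 1)}) is True because the route can bend
# around the block, while _has_path(3, (0, 0), (2, 2), {(0, 1), (1, 0)}) is
# False because both exits from the start cell are blocked.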
def generate_obstacles(size, start, goal, density, wall_ratio=0.7, max_tries=60, rng=None):
    """
    Generate walls + lava with a given density, retrying until there is a safe path.
    Lava is treated as blocked (terminal negative), so we keep at least one safe route.
    """
    rng = rng or np.random.default_rng()
    density = float(np.clip(density, 0.0, 0.60))

    cur_density = density
    for _ in range(max_tries):
        walls = set()
        lava = set()

        for r in range(size):
            for c in range(size):
                cell = (r, c)
                if cell == start or cell == goal:
                    continue
                if rng.random() < cur_density:
                    if rng.random() < wall_ratio:
                        walls.add(cell)
                    else:
                        lava.add(cell)

        blocked = walls | lava
        if _has_path(size, start, goal, blocked):
            return walls, lava

        # no safe path found: lower the density a little and retry
        cur_density = max(0.0, cur_density - 0.02)

    return set(), set()
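# Rough sizing (illustration, not from the original file): each non-start,
# non-goal cell becomes an obstacle with probability `density`, so a 5x5 grid
# at density 0.15 yields about 0.15 * 23 ≈ 3-4 obstacle cells, roughly 70% of
# them walls at the default wall_ratio.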
class GridWorld:
    def __init__(self, size=5, start=(0, 0), goal=None, lava=None, walls=None):
        self.size = int(size)
        self.start = start
        self.goal = goal if goal is not None else (self.size - 1, self.size - 1)
        self.lava = set(lava or [])
        self.walls = set(walls or [])
        self.reset()

    def reset(self):

@@ -155,42 +187,58 @@ class GridWorld:
    def step(self, action):
        dr, dc = ACTION_DELTAS[action]
        r, c = self.pos
        nr, nc = r + dr, c + dc

        # bounds check
        if nr < 0 or nr >= self.size or nc < 0 or nc >= self.size:
            nr, nc = r, c

        # wall check
        if (nr, nc) in self.walls:
            nr, nc = r, c

        self.pos = (nr, nc)

        # rewards
        if self.pos == self.goal:
            return self.pos, +10.0, True
        if self.pos in self.lava:
            return self.pos, -10.0, True

        return self.pos, -0.1, False  # step penalty -> shortest safe path is optimal
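# Worked example (not from the original file): step() returns
# (next_state, reward, done). On an empty 5x5 grid the shortest route from
# (0, 0) to (4, 4) takes 8 moves, so the best achievable return is
# 7 * (-0.1) + 10.0 = +9.3.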
# -----------------------------
# Q-Learning Agent
# -----------------------------
class QAgent:
    def __init__(self, size=5, n_actions=4, alpha=0.3, gamma=0.95):
        self.size = int(size)
        self.n_actions = n_actions
        self.alpha = float(alpha)
        self.gamma = float(gamma)
        self.Q = np.zeros((self.size, self.size, n_actions), dtype=np.float32)

    def act(self, state, epsilon):
        # epsilon-greedy: explore with probability epsilon, else exploit
        r, c = state
        if np.random.rand() < float(epsilon):
            return np.random.randint(self.n_actions)
        return int(np.argmax(self.Q[r, c]))

    def act_greedy(self, state):
        r, c = state
        return int(np.argmax(self.Q[r, c]))

    def update(self, s, a, r, s2, done):
        r1, c1 = s
        r2, c2 = s2
        best_next = 0.0 if done else float(np.max(self.Q[r2, c2]))
        td_target = r + self.gamma * best_next
        td_error = td_target - float(self.Q[r1, c1, a])
        self.Q[r1, c1, a] += self.alpha * td_error
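# The update above is the standard tabular Q-learning rule:
#   Q(s, a) <- Q(s, a) + alpha * (r + gamma * max_a' Q(s', a') - Q(s, a))
# with the bootstrap term dropped (best_next = 0.0) on terminal transitions.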
# -----------------------------
# Rendering helpers (ORIGINAL look)
# -----------------------------
def fig_to_pil(fig):
    buf = BytesIO()
    fig.savefig(buf, format="png", dpi=160, bbox_inches="tight")

@@ -198,116 +246,366 @@ def fig_to_pil(fig):
    buf.seek(0)
    return Image.open(buf)

def draw_grid(env: GridWorld, agent: QAgent = None, show_q=False, episode=None, step_i=None, total_reward=None):
    n = env.size
    fig, ax = plt.subplots(figsize=(5.4, 5.4))
    ax.set_xlim(0, n)
    ax.set_ylim(0, n)
    ax.set_aspect("equal")
    ax.axis("off")

    # Background (keep original)
    ax.add_patch(Rectangle((0, 0), n, n, facecolor="#0b1020"))

    # Draw cells
    for r in range(n):
        for c in range(n):
            x, y = c, n - 1 - r  # invert y so (0,0) is top-left visually

            tile_color = "#121a33"
            if (r, c) == env.goal:
                tile_color = "#0f2f1f"
            if (r, c) in env.lava:
                tile_color = "#3a1414"
            if (r, c) in env.walls:
                tile_color = "#1b1b1b"

            ax.add_patch(
                FancyBboxPatch(
                    (x + 0.05, y + 0.05), 0.9, 0.9,
                    boxstyle="round,pad=0.02,rounding_size=0.08",
                    linewidth=1.0,
                    edgecolor="#2a355f",
                    facecolor=tile_color,
                    alpha=0.95
                )
            )

            # overlay Q hint (optional)
            if show_q and agent is not None and (r, c) not in env.walls:
                best_a = int(np.argmax(agent.Q[r, c]))
                qv = float(np.max(agent.Q[r, c]))
                ax.text(x + 0.5, y + 0.55, ACTIONS[best_a], ha="center", va="center",
                        fontsize=14, color="#d7e3ff", alpha=0.65)
                ax.text(x + 0.5, y + 0.30, f"{qv:+.2f}", ha="center", va="center",
                        fontsize=9, color="#a9b7e6", alpha=0.55)

    # Icons
    def put_icon(rc, icon, color="#ffffff"):
        r, c = rc
        x, y = c + 0.5, (n - 1 - r) + 0.5
        ax.text(x, y, icon, ha="center", va="center", fontsize=22, color=color)

    put_icon(env.goal, "🏁")
    for rc in env.lava:
        put_icon(rc, "🔥")
    for rc in env.walls:
        put_icon(rc, "🧱")

    # Agent
    put_icon(env.pos, "🤖")

    # Header overlay
    title = "Gridworld RL • Q-learning"
    sub = []
    if episode is not None:
        sub.append(f"Episode: {episode}")
    if step_i is not None:
        sub.append(f"Step: {step_i}")
    if total_reward is not None:
        sub.append(f"Return: {total_reward:+.2f}")
    subtitle = " • ".join(sub)

    ax.text(0, n + 0.35, title, fontsize=14, color="#eaf0ff", weight="bold")
    ax.text(0, n + 0.08, subtitle, fontsize=10, color="#b8c6ff", alpha=0.9)

    return fig_to_pil(fig)

# -----------------------------
# Learning curve chart (no flicker)
# -----------------------------
def moving_average(x, window=25):
    if len(x) < 2:
        return np.array(x, dtype=float)
    w = max(2, min(int(window), len(x)))
    kernel = np.ones(w) / w
    return np.convolve(np.array(x, dtype=float), kernel, mode="valid")
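# Example (not from the original file): moving_average([1, 2, 3, 4], window=2)
# -> array([1.5, 2.5, 3.5]); the window is clamped to len(x), so short series
# still produce a sensible smoothed curve.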
def draw_learning_curve(returns, successes, window=25):
    fig, ax = plt.subplots(figsize=(5.4, 4.6))
    ax.set_facecolor("#0b1020")
    for spine in ax.spines.values():
        spine.set_color("#2a355f")
    ax.tick_params(colors="#c9d6ff")
    ax.yaxis.label.set_color("#c9d6ff")
    ax.xaxis.label.set_color("#c9d6ff")
    ax.title.set_color("#eaf0ff")

    ax.set_title("Learning curve")
    ax.set_xlabel("Episode")
    ax.set_ylabel("Return")

    if len(returns) > 0:
        xs = np.arange(1, len(returns) + 1)
        ax.plot(xs, returns, linewidth=1.5, alpha=0.9, label="Return")

        ma = moving_average(returns, window=window)
        if len(ma) > 0:
            xs_ma = np.arange(len(returns) - len(ma) + 1, len(returns) + 1)
            ax.plot(xs_ma, ma, linewidth=2.5, alpha=0.95,
                    label=f"Moving avg ({min(int(window), len(returns))})")

    ax2 = ax.twinx()
    ax2.tick_params(colors="#c9d6ff")
    ax2.spines["right"].set_color("#2a355f")
    ax2.set_ylabel("Success rate", color="#c9d6ff")

    if len(successes) > 0:
        xs = np.arange(1, len(successes) + 1)
        sr = np.cumsum(np.array(successes, dtype=float)) / xs
        ax2.plot(xs, sr, linewidth=2.0, alpha=0.8, label="Success rate")

    lines, labels = ax.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax.legend(lines + lines2, labels + labels2, loc="lower right", framealpha=0.2)

    ax.grid(True, alpha=0.15)
    return fig_to_pil(fig)

# -----------------------------
# Training + Playback (store env layout so Play matches Train)
# -----------------------------
def make_env_and_agent(grid_size, obstacle_density, alpha, gamma):
    size = int(grid_size)
    start = (0, 0)
    goal = (size - 1, size - 1)

    rng = np.random.default_rng()
    walls, lava = generate_obstacles(size, start, goal, density=float(obstacle_density), wall_ratio=0.7, rng=rng)

    env = GridWorld(size=size, start=start, goal=goal, walls=walls, lava=lava)
    agent = QAgent(size=size, alpha=alpha, gamma=gamma)

    env_state = {
        "size": size,
        "start": start,
        "goal": goal,
        "walls": sorted(list(walls)),
        "lava": sorted(list(lava)),
    }
    return env, agent, env_state

def train_stream(
    grid_size,
    obstacle_density,
    alpha,
    gamma,
    eps_start,
    eps_end,
    eps_decay,
    episodes,
    max_steps,
    speed,
    show_q_overlay,
    curve_window,
):
    env, agent, env_state = make_env_and_agent(grid_size, obstacle_density, alpha, gamma)
    eps = float(eps_start)

    returns = []
    successes = []

    # initial
    frame = draw_grid(env, agent, show_q=show_q_overlay, episode=0, step_i=0, total_reward=0.0)
    last_curve = draw_learning_curve(returns, successes, window=int(curve_window))
    status = f"Kies density en klik Train. (Obstacle density={float(obstacle_density):.2f})"
    yield frame, last_curve, agent.Q, env_state, status

    CURVE_UPDATE_EVERY_STEPS = 8

    for ep in range(1, int(episodes) + 1):
        s = env.reset()
        total_r = 0.0
        reached_goal_this_ep = 0

        for t in range(1, int(max_steps) + 1):
            a = agent.act(s, epsilon=eps)
            s2, r, done = env.step(a)
            agent.update(s, a, r, s2, done)
            s = s2
            total_r += r

            if done and env.pos == env.goal:
                reached_goal_this_ep = 1

            if (t % CURVE_UPDATE_EVERY_STEPS == 0) or done:
                preview_returns = returns + [total_r]
                preview_successes = successes + [reached_goal_this_ep]
                last_curve = draw_learning_curve(preview_returns, preview_successes, window=int(curve_window))

            frame = draw_grid(env, agent, show_q=show_q_overlay, episode=ep, step_i=t, total_reward=total_r)
            status = f"Train • ep {ep}/{episodes} • step {t}/{max_steps} • return {total_r:+.2f} • eps {eps:.3f}"
            yield frame, last_curve, agent.Q, env_state, status

            if speed > 0:
                time.sleep(float(speed))

            if done:
                break

        returns.append(total_r)
        successes.append(reached_goal_this_ep)
        last_curve = draw_learning_curve(returns, successes, window=int(curve_window))
        yield frame, last_curve, agent.Q, env_state, f"Episode {ep} klaar • return {total_r:+.2f} • success={reached_goal_this_ep} • eps {eps:.3f}"

        eps = max(float(eps_end), eps * float(eps_decay))
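        # Epsilon decays geometrically per episode with a floor at eps_end;
        # with the UI defaults (start 0.90, decay 0.985) it is roughly
        # 0.90 * 0.985**200 ≈ 0.04 after 200 episodes, so late training is
        # nearly greedy.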
    frame = draw_grid(env, agent, show_q=show_q_overlay, episode=episodes, step_i=None, total_reward=None)
    last_curve = draw_learning_curve(returns, successes, window=int(curve_window))
    status = "Training klaar ✅ Klik nu op ‘Play learned policy’."
    yield frame, last_curve, agent.Q, env_state, status

def play_stream(q_table, env_state, speed, show_q_overlay, max_steps):
    if q_table is None or env_state is None:
        env = GridWorld(size=5, start=(0, 0), goal=(4, 4), walls=[], lava=[])
        agent = QAgent(size=5)
        frame = draw_grid(env, agent, show_q=show_q_overlay, episode=None, step_i=None, total_reward=None)
        curve = draw_learning_curve([], [], window=25)
        yield frame, curve, "Nog geen training gedaan. Klik eerst op Train."
        return

    size = int(env_state["size"])
    start = tuple(env_state["start"])
    goal = tuple(env_state["goal"])
    walls = [tuple(x) for x in env_state["walls"]]
    lava = [tuple(x) for x in env_state["lava"]]

    env = GridWorld(size=size, start=start, goal=goal, walls=walls, lava=lava)
    agent = QAgent(size=size)
    agent.Q = np.array(q_table, dtype=np.float32)

    s = env.reset()
    total_r = 0.0

    curve = draw_learning_curve([], [], window=25)  # keep curve visible (static) during play
    frame = draw_grid(env, agent, show_q=show_q_overlay, episode="PLAY", step_i=0, total_reward=total_r)
    yield frame, curve, "Play • epsilon=0.0 (deterministisch)"

    for t in range(1, int(max_steps) + 1):
        a = agent.act_greedy(s)
        s2, r, done = env.step(a)
        s = s2
        total_r += r

        frame = draw_grid(env, agent, show_q=show_q_overlay, episode="PLAY", step_i=t, total_reward=total_r)
        yield frame, curve, f"Play • step {t}/{max_steps} • return {total_r:+.2f}"

        if speed > 0:
            time.sleep(float(speed))
        if done:
            break

    if env.pos == env.goal:
        end = f"🏁 Goal bereikt! return {total_r:+.2f}"
    elif env.pos in env.lava:
        end = "🔥 In lava beland. Tip: train langer of zet density lager."
    else:
        end = "Play klaar. Tip: train langer of max_steps omhoog."

    frame = draw_grid(env, agent, show_q=show_q_overlay, episode="PLAY", step_i=None, total_reward=total_r)
    yield frame, curve, end

# -----------------------------
# Gradio UI (layout stays the same)
# -----------------------------
with gr.Blocks(css=CUSTOM_CSS, theme=gr.themes.Soft(), title="Warehouse Robot RL Demo") as demo:
    # Header: text LEFT, image RIGHT (as requested)
    with gr.Row():
        with gr.Column(scale=3):
            gr.Markdown(
                """
### 🤖 Een robot in het magazijn

Stel je voor: je werkt in een groot magazijn.
Tussen de stellingen rijdt een robot rond die bestellingen moet ophalen en naar het inpakstation brengen.
Die robot krijgt geen kaart, geen regels en geen instructies over wat de snelste route is.
In deze demo zie je hoe zo’n robot zelf leert wat slim gedrag is.

In het begin rijdt hij willekeurig rond en maakt hij fouten.
Maar naarmate hij meer ervaring opdoet, ontdekt hij vanzelf hoe hij veilig, efficiënt en zo snel mogelijk door het magazijn kan bewegen.

Boven zie je de robot rijden tussen stellingen en gevaarlijke zones.
Onder zie je hoe zijn prestaties verbeteren naarmate hij leert.

👉 Probeer het zelf: maak het magazijn makkelijker of moeilijker, train de robot,
en laat daarna zien wat hij geleerd heeft.
                """
            )
        with gr.Column(scale=2):
            gr.Image(
                value="humanoid-robot-apptronic-1024x684.jpg.webp",
                show_label=False,
                height=340,
            )

    # ---- Demo section (unchanged) ----
    q_state = gr.State(None)
    env_state = gr.State(None)

    with gr.Row():
        with gr.Column(scale=1):
            grid_size = gr.Slider(4, 10, value=5, step=1, label="Grid size")

            obstacle_density = gr.Slider(
                0.0, 0.45, value=0.15, step=0.05,
                label="Obstacle density (meer blokken/gevaar)"
            )

            with gr.Accordion("RL parameters (defaults = goede convergentie)", open=True):
                alpha = gr.Slider(0.01, 1.0, value=0.45, step=0.01, label="Alpha (learning rate)")
                gamma = gr.Slider(0.0, 0.999, value=0.97, step=0.001, label="Gamma (discount)")

            with gr.Accordion("Exploration (epsilon decay)", open=True):
                eps_start = gr.Slider(0.0, 1.0, value=0.90, step=0.01, label="Epsilon start (veel explore)")
                eps_end = gr.Slider(0.0, 0.2, value=0.02, step=0.005, label="Epsilon end (bijna greedy)")
                eps_decay = gr.Slider(0.90, 0.999, value=0.985, step=0.001, label="Epsilon decay per episode")

            episodes = gr.Slider(1, 400, value=200, step=1, label="Episodes")
            max_steps_train = gr.Slider(5, 200, value=60, step=1, label="Max steps per episode")

            with gr.Accordion("Visuals & snelheid", open=True):
                speed = gr.Slider(0.0, 0.3, value=0.02, step=0.01, label="Animatie vertraging (sec/frame)")
                show_q_overlay = gr.Checkbox(value=True, label="Toon beste actie & Q-waarde per vakje (overlay)")
                curve_window = gr.Slider(5, 80, value=25, step=1, label="Moving average window (episodes)")

            with gr.Row():
                train_btn = gr.Button("🚀 Train (epsilon decay)", variant="primary")
                play_btn = gr.Button("▶️ Play learned policy (epsilon=0)")

            status = gr.Textbox(label="Status", value="Kies density en klik Train.", interactive=False)

        with gr.Column(scale=1):
            frame_out = gr.Image(label="Live animatie", type="pil", height=520)
            curve_out = gr.Image(label="Learning curve (live)", type="pil", height=420)
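    # Wiring note: train_stream is a generator; each yielded 5-tuple
    # (frame, curve, Q-table, env layout, status) streams positionally into
    # the outputs list below, which is what makes the animation update live.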
    train_btn.click(
        fn=train_stream,
        inputs=[
            grid_size,
            obstacle_density,
            alpha, gamma,
            eps_start, eps_end, eps_decay,
            episodes, max_steps_train,
            speed, show_q_overlay, curve_window
        ],
        outputs=[frame_out, curve_out, q_state, env_state, status],
    )

    play_btn.click(
        fn=play_stream,
        inputs=[q_state, env_state, speed, show_q_overlay, max_steps_train],
        outputs=[frame_out, curve_out, status],
    )

if __name__ == "__main__":
    demo.launch()
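The train/play split above (epsilon-decay learning followed by a greedy rollout) can be sanity-checked without the Gradio UI. A minimal headless sketch, assuming GridWorld, QAgent and generate_obstacles from this file are already in scope; the seed and hyperparameters here are illustrative, not part of the commit:

import numpy as np

size = 5
start, goal = (0, 0), (size - 1, size - 1)
walls, lava = generate_obstacles(size, start, goal, density=0.15,
                                 rng=np.random.default_rng(0))
env = GridWorld(size=size, start=start, goal=goal, walls=walls, lava=lava)
agent = QAgent(size=size, alpha=0.45, gamma=0.97)

eps = 0.90
for ep in range(200):                      # train with epsilon decay
    s = env.reset()
    for _ in range(60):
        a = agent.act(s, epsilon=eps)
        s2, r, done = env.step(a)
        agent.update(s, a, r, s2, done)
        s = s2
        if done:
            break
    eps = max(0.02, eps * 0.985)

s = env.reset()                            # greedy rollout (epsilon = 0)
for _ in range(60):
    s, r, done = env.step(agent.act_greedy(s))
    if done:
        break
print("reached goal:", s == env.goal)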