anthonym21 committed on
Commit
88b4be3
·
verified ·
1 Parent(s): 38f1411

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +205 -521
app.py CHANGED
@@ -1,584 +1,268 @@
1
  """
2
- World Model Demo - Interactive Visualization
3
- A Hugging Face Space demonstrating the three phases of world model learning:
4
- 1. Exploration (Motor Babbling) - Random exploration to learn physics
5
- 2. Dreaming (Planning) - Using learned model to plan without acting
6
- 3. Execution - Following the plan in reality
7
-
8
- Based on the concept that intelligent agents build internal models of their world.
9
  """
10
 
11
  import gradio as gr
12
  import random
13
- import time
14
- from collections import deque
15
- from typing import Optional
16
  import json
17
 
18
- # ==========================================
19
- # 1. THE ENVIRONMENT (Reality)
20
- # ==========================================
21
class GridEnvironment:
    """Ground-truth grid world that the agent acts in.

    Positions are ``(x, y)`` tuples. The agent starts at ``(0, 0)`` and the
    goal is the opposite corner, ``(size - 1, size - 1)``.
    """

    # Wall layout used only when the caller supplies no obstacles at all.
    DEFAULT_OBSTACLES = frozenset({(1, 1), (1, 2), (2, 2)})

    # Action id -> (dx, dy): 0=Up, 1=Down, 2=Left, 3=Right.
    _MOVES = {0: (0, -1), 1: (0, 1), 2: (-1, 0), 3: (1, 0)}

    def __init__(self, size: int = 4, obstacles: Optional[set] = None):
        """Create a ``size`` x ``size`` grid.

        Args:
            size: Side length of the square grid.
            obstacles: Set of blocked ``(x, y)`` cells. ``None`` selects
                ``DEFAULT_OBSTACLES``; an empty set means "no walls".
        """
        self.size = size
        self.agent_pos = (0, 0)
        self.goal = (size - 1, size - 1)
        # Compare against None explicitly: an empty set is a legitimate
        # "no obstacles" request and must not fall back to the defaults.
        if obstacles is None:
            obstacles = self.DEFAULT_OBSTACLES
        self.obstacles = set(obstacles)

    def reset(self):
        """Move the agent back to ``(0, 0)`` and return that position."""
        self.agent_pos = (0, 0)
        return self.agent_pos

    def step(self, action: int):
        """Execute action: 0=Up, 1=Down, 2=Left, 3=Right.

        The move is applied only if the target cell is inside the grid and
        not an obstacle; otherwise the agent stays where it is. Unknown
        action ids leave the position unchanged.

        Returns:
            The agent position after the step.
        """
        dx, dy = self._MOVES.get(action, (0, 0))
        x = self.agent_pos[0] + dx
        y = self.agent_pos[1] + dy

        inside = 0 <= x < self.size and 0 <= y < self.size
        if inside and (x, y) not in self.obstacles:
            self.agent_pos = (x, y)

        return self.agent_pos
48
-
49
 
50
# ==========================================
# 2. THE WORLD MODEL (The Brain)
# ==========================================
class WorldModel:
    """Lookup table of observed transitions: ``(state, action) -> next_state``."""

    def __init__(self):
        # Maps (state, action) pairs to the state that followed them.
        self.transitions = {}

    def learn(self, state, action, next_state):
        """Record one observed transition, overwriting any earlier record."""
        key = (state, action)
        self.transitions[key] = next_state

    def predict(self, state, action):
        """Return the remembered outcome of ``action`` in ``state``, or ``None``."""
        key = (state, action)
        return self.transitions.get(key)

    def get_learned_states(self):
        """Return every state seen as the source or the result of a transition."""
        known = {source for source, _ in self.transitions}
        known.update(self.transitions.values())
        return known
72
-
73
-
74
- # ==========================================
75
- # 3. THE AGENT (The Controller)
76
- # ==========================================
77
class Agent:
    """Controller that fills a world model by acting, then plans inside it."""

    def __init__(self):
        self.model = WorldModel()
        self.actions = [0, 1, 2, 3]
        self.action_names = ["↑ Up", "↓ Down", "← Left", "→ Right"]
        self.exploration_history = []

    def explore_step(self, env):
        """Take one random action in ``env`` and teach the model what happened.

        Returns:
            A dict describing the transition; ``bounced`` is true when the
            action left the agent where it was.
        """
        before = env.agent_pos
        chosen = random.choice(self.actions)
        after = env.step(chosen)
        self.model.learn(before, chosen, after)

        return dict(
            state=before,
            action=chosen,
            action_name=self.action_names[chosen],
            next_state=after,
            bounced=(before == after),
        )

    def dream_and_plan(self, start, goal):
        """Breadth-first search from ``start`` to ``goal`` using only the model.

        Returns:
            ``(plan, searched, found)`` where ``plan`` is the list of actions
            (``None`` when no route is known), ``searched`` lists the states in
            the order they were expanded, and ``found`` is a success flag.
        """
        # Each discovered state remembers (previous_state, action); the start
        # has no predecessor. The dict doubles as the visited set.
        came_from = {start: None}
        frontier = deque([start])
        searched = []

        while frontier:
            current = frontier.popleft()
            searched.append(current)

            if current == goal:
                # Walk the predecessor links back to the start.
                plan = []
                link = came_from[current]
                while link is not None:
                    previous, action = link
                    plan.append(action)
                    link = came_from[previous]
                plan.reverse()
                return plan, searched, True

            for action in self.actions:
                imagined = self.model.predict(current, action)
                if imagined is None or imagined in came_from:
                    continue
                came_from[imagined] = (current, action)
                frontier.append(imagined)

        return None, searched, False
 
123
 
 
 
 
124
 
125
- # ==========================================
126
- # VISUALIZATION HELPERS
127
- # ==========================================
128
def render_grid(env: "GridEnvironment", agent_pos: tuple, highlight_cells: dict = None,
                show_model_knowledge: set = None, plan_path: list = None) -> str:
    """Render the grid as an HTML fragment (inline ``<style>`` plus one div per cell).

    Args:
        env: Object exposing ``size``, ``goal`` and ``obstacles``.
        agent_pos: ``(x, y)`` cell that currently holds the agent.
        highlight_cells: Optional mapping ``(x, y) -> CSS class`` that
            overrides the explored/path background of a cell.
        show_model_knowledge: Optional set of states the model knows about.
        plan_path: Optional list of ``(x, y)`` positions on the planned path.

    Returns:
        The HTML string.
    """
    size = env.size
    cell_size = 60
    overrides = highlight_cells or {}
    path_cells = set(plan_path) if plan_path else set()

    def cell_markup(x, y):
        """Return the ``<div>`` for one cell; later rules override earlier ones."""
        pos = (x, y)
        css, glyph = "cell-empty", ""

        # Background layers: explored < planned path < explicit highlight.
        if show_model_knowledge and pos in show_model_knowledge:
            css = "cell-explored"
        if pos in path_cells and pos != env.goal:
            css = "cell-path"
        css = overrides.get(pos, css)

        # Fixed landmarks replace whatever background was chosen above.
        if pos in env.obstacles:
            css, glyph = "cell-obstacle", "🧱"
        elif pos == (0, 0) and pos != agent_pos:
            css, glyph = "cell-start", "🏁"
        elif pos == env.goal and pos != agent_pos:
            css, glyph = "cell-goal", "⭐"

        # The agent is drawn last so it is always visible.
        if pos == agent_pos:
            if pos == env.goal:
                css, glyph = "cell-agent-at-goal", "🤖⭐"
            else:
                css, glyph = "cell-agent", "🤖"

        return f'<div class="grid-cell {css}">{glyph}<span class="coord-label">{x},{y}</span></div>'

    pieces = [f'''
    <style>
        .grid-container {{
            display: inline-grid;
            grid-template-columns: repeat({size}, {cell_size}px);
            gap: 2px;
            background: #1a1a2e;
            padding: 10px;
            border-radius: 12px;
            box-shadow: 0 4px 20px rgba(0,0,0,0.3);
        }}
        .grid-cell {{
            width: {cell_size}px;
            height: {cell_size}px;
            display: flex;
            align-items: center;
            justify-content: center;
            font-size: 24px;
            border-radius: 8px;
            transition: all 0.3s ease;
            position: relative;
        }}
        .cell-empty {{ background: #16213e; }}
        .cell-agent {{ background: #4ecca3; animation: pulse 1s infinite; }}
        .cell-goal {{ background: #ffd369; }}
        .cell-obstacle {{ background: #e94560; }}
        .cell-start {{ background: #7b68ee; }}
        .cell-explored {{ background: #2d4263; border: 2px solid #4ecca3; }}
        .cell-path {{ background: #00adb5; }}
        .cell-search {{ background: #533483; border: 2px dashed #9d65c9; }}
        .cell-agent-at-goal {{ background: linear-gradient(135deg, #4ecca3, #ffd369); }}
        @keyframes pulse {{
            0%, 100% {{ transform: scale(1); }}
            50% {{ transform: scale(0.95); }}
        }}
        .coord-label {{
            position: absolute;
            bottom: 2px;
            right: 4px;
            font-size: 9px;
            color: rgba(255,255,255,0.4);
        }}
    </style>
    <div class="grid-container">
    ''']

    # Row-major order: y is the outer loop so each row is emitted left to right.
    pieces.extend(cell_markup(x, y) for y in range(size) for x in range(size))
    pieces.append('</div>')
    return ''.join(pieces)
226
-
227
-
228
def create_stats_html(rules_learned: int, states_explored: int, plan_length: int = 0,
                      phase: str = "Ready") -> str:
    """Build the four-tile statistics panel as an HTML string.

    Args:
        rules_learned: Number shown in the "Physics Rules Learned" tile.
        states_explored: Number shown in the "States Explored" tile.
        plan_length: Number shown in the "Plan Length" tile.
        phase: Label for the "Current Phase" tile; phases without a
            dedicated colour are drawn in grey (``#888``).

    Returns:
        The HTML string.
    """
    fallback_color = "#888"
    color = {
        "Ready": fallback_color,
        "Exploring": "#4ecca3",
        "Dreaming": "#9d65c9",
        "Executing": "#00adb5",
        "Complete": "#ffd369",
    }.get(phase, fallback_color)

    panel = f'''
    <div style="
        background: linear-gradient(135deg, #1a1a2e, #16213e);
        padding: 20px;
        border-radius: 12px;
        color: white;
        font-family: 'Segoe UI', sans-serif;
        display: grid;
        grid-template-columns: repeat(2, 1fr);
        gap: 15px;
        max-width: 400px;
    ">
        <div style="text-align: center; padding: 10px; background: rgba(255,255,255,0.1); border-radius: 8px;">
            <div style="font-size: 28px; font-weight: bold; color: #4ecca3;">{rules_learned}</div>
            <div style="font-size: 12px; opacity: 0.8;">Physics Rules Learned</div>
        </div>
        <div style="text-align: center; padding: 10px; background: rgba(255,255,255,0.1); border-radius: 8px;">
            <div style="font-size: 28px; font-weight: bold; color: #7b68ee;">{states_explored}</div>
            <div style="font-size: 12px; opacity: 0.8;">States Explored</div>
        </div>
        <div style="text-align: center; padding: 10px; background: rgba(255,255,255,0.1); border-radius: 8px;">
            <div style="font-size: 28px; font-weight: bold; color: #00adb5;">{plan_length}</div>
            <div style="font-size: 12px; opacity: 0.8;">Plan Length</div>
        </div>
        <div style="text-align: center; padding: 10px; background: rgba(255,255,255,0.1); border-radius: 8px;">
            <div style="font-size: 16px; font-weight: bold; color: {color};">● {phase}</div>
            <div style="font-size: 12px; opacity: 0.8;">Current Phase</div>
        </div>
    </div>
    '''
    return panel
 
270
 
 
 
 
271
 
272
- # ==========================================
273
- # GRADIO INTERFACE
274
- # ==========================================
275
class WorldModelDemo:
    """Main demo controller.

    Owns the environment, the agent and the activity log, and exposes one
    method per UI action. Every public method returns the
    ``(grid_html, stats_html, log_text)`` triple produced by
    ``_render_state``.
    """

    # Obstacle layouts keyed by preset name, then by grid size. The ``None``
    # size key is the layout used for any size without its own entry.
    _OBSTACLE_LAYOUTS = {
        "Default": {
            4: {(1, 1), (1, 2), (2, 2)},
            5: {(1, 1), (1, 2), (2, 2), (3, 1)},
            None: {(1, 1), (2, 2), (3, 3)},
        },
        "Maze": {
            4: {(1, 0), (1, 1), (1, 2), (2, 2)},
            5: {(1, 0), (1, 1), (1, 2), (3, 2), (3, 3), (3, 4)},
            None: {(1, 0), (1, 1), (2, 3), (2, 4), (4, 1), (4, 2)},
        },
        "Scattered": {
            4: {(0, 2), (2, 0), (2, 3)},
            5: {(0, 2), (2, 0), (2, 3), (4, 1)},
            None: {(0, 2), (2, 0), (2, 4), (4, 2), (5, 0)},
        },
    }

    def __init__(self):
        self.reset()

    def reset(self, grid_size: int = 4, obstacle_preset: str = "Default"):
        """Reset the demo state and return the rendered outputs."""
        # Remember the configuration so run_full_demo can rebuild the same
        # world instead of silently falling back to the 4x4 default.
        self.grid_size = grid_size
        self.obstacle_preset = obstacle_preset

        obstacles = self._get_obstacles(grid_size, obstacle_preset)
        self.env = GridEnvironment(size=grid_size, obstacles=obstacles)
        self.agent = Agent()
        self.plan = None
        self.plan_positions = []
        self.search_states = []
        self.current_step = 0
        self.phase = "Ready"
        self.log = []

        return self._render_state()

    def _get_obstacles(self, size: int, preset: str) -> set:
        """Return a fresh obstacle set for ``preset`` on a ``size`` grid.

        The ``"None"`` preset and any unrecognised preset yield an empty set.
        """
        layouts = self._OBSTACLE_LAYOUTS.get(preset)
        if layouts is None:
            return set()
        # Copy so callers can never mutate the shared class-level table.
        return set(layouts.get(size, layouts[None]))

    def _render_state(self, highlight: dict = None) -> tuple:
        """Render current state as ``(grid_html, stats_html, log_text)``."""
        known_states = self.agent.model.get_learned_states()

        grid_html = render_grid(
            self.env,
            self.env.agent_pos,
            highlight_cells=highlight,
            # Nothing has been learned yet in the "Ready" phase.
            show_model_knowledge=known_states if self.phase != "Ready" else None,
            plan_path=self.plan_positions if self.plan_positions else None,
        )

        stats_html = create_stats_html(
            rules_learned=len(self.agent.model.transitions),
            states_explored=len(known_states),
            plan_length=len(self.plan) if self.plan else 0,
            phase=self.phase,
        )

        log_text = "\n".join(self.log[-20:])  # Last 20 log entries

        return grid_html, stats_html, log_text

    def explore(self, steps: int = 100) -> tuple:
        """Run the exploration phase: ``steps`` random actions from the start."""
        self.phase = "Exploring"
        self.env.reset()
        self.log.append(f"═══ PHASE 1: EXPLORATION ({steps} steps) ═══")

        for i in range(steps):
            result = self.agent.explore_step(self.env)

            if i < 10 or i % 50 == 0:  # Log first 10 and every 50th
                bounce_str = " (BOUNCE!)" if result['bounced'] else ""
                self.log.append(f"Step {i+1}: {result['state']} → {result['action_name']} → {result['next_state']}{bounce_str}")

            # Restart from the origin once the goal is hit so exploration
            # keeps sampling transitions.
            if self.env.agent_pos == self.env.goal:
                self.env.reset()

        self.log.append(f"✓ Learned {len(self.agent.model.transitions)} physics rules")
        self.log.append(f"✓ Explored {len(self.agent.model.get_learned_states())} unique states")

        return self._render_state()

    def dream(self) -> tuple:
        """Run the planning phase without moving in the real world."""
        self.phase = "Dreaming"
        self.env.reset()

        self.log.append("═══ PHASE 2: DREAMING ═══")
        self.log.append(f"Planning from (0,0) to {self.env.goal}...")
        self.log.append("(No real-world movement - pure simulation!)")

        start = (0, 0)
        goal = self.env.goal

        self.plan, self.search_states, success = self.agent.dream_and_plan(start, goal)

        if success:
            # Replay the plan through the model to get the visited positions.
            self.plan_positions = [start]
            pos = start
            for action in self.plan:
                predicted = self.agent.model.predict(pos, action)
                if predicted is not None:
                    self.plan_positions.append(predicted)
                    pos = predicted

            path_str = " → ".join([self.agent.action_names[a] for a in self.plan])
            self.log.append(f"✓ Plan found! Length: {len(self.plan)}")
            self.log.append(f" Path: {path_str}")
            self.log.append(f" Positions: {' → '.join(str(p) for p in self.plan_positions)}")
        else:
            self.log.append("✗ No path found - need more exploration!")
            self.plan = None
            self.plan_positions = []

        # Highlight searched states
        highlight = {s: "cell-search" for s in self.search_states}

        return self._render_state(highlight)

    def execute(self) -> tuple:
        """Execute the stored plan in the real environment."""
        if not self.plan:
            self.log.append("⚠ No plan to execute! Run 'Dream' first.")
            return self._render_state()

        self.phase = "Executing"
        self.env.reset()

        self.log.append("═══ PHASE 3: EXECUTION ═══")
        self.log.append(f"Start: {self.env.agent_pos}")

        for action in self.plan:
            state = self.env.step(action)
            self.log.append(f" {self.agent.action_names[action]} → {state}")

        if self.env.agent_pos == self.env.goal:
            self.phase = "Complete"
            self.log.append("🎉 SUCCESS! Goal reached!")
        else:
            self.log.append("⚠ FAILURE: Plan didn't reach goal")

        return self._render_state()

    def run_full_demo(self, steps: int = 200) -> tuple:
        """Run all three phases automatically on the currently configured grid."""
        # Rebuild with the remembered configuration rather than the defaults,
        # so the grid matches what the user selected.
        self.reset(self.grid_size, self.obstacle_preset)

        # Phase 1
        self.explore(steps)

        # Phase 2
        self.dream()

        # Phase 3
        if self.plan:
            self.execute()

        return self._render_state()
443
-
444
-
445
- # Create global demo instance
446
# Single controller instance that every UI callback operates on.
demo = WorldModelDemo()


def reset_demo(grid_size, obstacle_preset):
    """UI callback: rebuild the demo for the chosen grid size and preset."""
    size = int(grid_size)
    return demo.reset(size, obstacle_preset)


def run_explore(steps):
    """UI callback: run the exploration phase for ``steps`` steps."""
    step_count = int(steps)
    return demo.explore(step_count)


def run_dream():
    """UI callback: run the planning phase."""
    return demo.dream()


def run_execute():
    """UI callback: execute the current plan."""
    return demo.execute()


def run_full(steps):
    """UI callback: run all three phases with ``steps`` exploration steps."""
    step_count = int(steps)
    return demo.run_full_demo(step_count)
463
-
464
-
465
- # ==========================================
466
- # GRADIO UI
467
- # ==========================================
468
# Build the page. Components appear in the order they are created inside the
# nested Row/Column contexts, so statement order here is the layout.
# NOTE(review): nesting below is reconstructed from a whitespace-stripped diff
# view - confirm against the repository copy.
with gr.Blocks(
    title="World Model Demo",
    theme=gr.themes.Soft(
        primary_hue="teal",
        secondary_hue="purple",
    ),
    css="""
    .main-title {
        text-align: center;
        margin-bottom: 10px;
        background: linear-gradient(90deg, #4ecca3, #7b68ee);
        -webkit-background-clip: text;
        -webkit-text-fill-color: transparent;
    }
    .phase-btn { min-width: 120px; }
    footer { display: none !important; }
    """
) as interface:

    # Page header and three-phase overview.
    gr.Markdown("""
    # 🧠 World Model Demo
    ### How Intelligent Agents Learn to Dream and Plan

    This interactive demo shows how an AI agent builds an internal model of its world through three phases:

    | Phase | Description |
    |-------|-------------|
    | 🔍 **Exploration** | Random movement to discover physics rules ("motor babbling") |
    | 💭 **Dreaming** | Planning a path using *only* the internal model (no real movement!) |
    | 🚀 **Execution** | Following the imagined plan in the real world |

    ---
    """)

    # Top row: grid on the left, statistics tiles on the right.
    with gr.Row():
        with gr.Column(scale=2):
            grid_display = gr.HTML(label="Grid World")

        with gr.Column(scale=1):
            stats_display = gr.HTML(label="Statistics")

    # Bottom row: configuration and controls on the left, log on the right.
    with gr.Row():
        with gr.Column():
            gr.Markdown("### ⚙️ Configuration")
            with gr.Row():
                # Dropdown values are strings; reset_demo converts the size with int().
                grid_size = gr.Dropdown(
                    choices=["4", "5", "6"],
                    value="4",
                    label="Grid Size"
                )
                obstacle_preset = gr.Dropdown(
                    choices=["None", "Default", "Maze", "Scattered"],
                    value="Default",
                    label="Obstacles"
                )
            exploration_steps = gr.Slider(
                minimum=50, maximum=500, value=200, step=50,
                label="Exploration Steps"
            )

            gr.Markdown("### 🎮 Controls")
            with gr.Row():
                reset_btn = gr.Button("🔄 Reset", variant="secondary")
                full_btn = gr.Button("▶️ Run All Phases", variant="primary")

            gr.Markdown("### 📍 Step-by-Step")
            with gr.Row():
                explore_btn = gr.Button("1️⃣ Explore", elem_classes="phase-btn")
                dream_btn = gr.Button("2️⃣ Dream", elem_classes="phase-btn")
                execute_btn = gr.Button("3️⃣ Execute", elem_classes="phase-btn")

        with gr.Column():
            log_display = gr.Textbox(
                label="📋 Activity Log",
                lines=15,
                max_lines=20,
                interactive=False
            )

    # Explanatory footer and legend for the cell colours.
    gr.Markdown("""
    ---
    ### 📚 How It Works

    **The Key Insight:** The agent's "brain" (World Model) is a simple dictionary that maps
    `(state, action) → next_state`. During **Dreaming**, the agent searches through this
    dictionary using BFS - it never calls the real environment!

    **Why This Matters:** This is the foundation of how advanced AI systems (like MuZero,
    Dreamer, and world models in robotics) learn to plan. Instead of trial-and-error in
    reality (expensive, dangerous), they simulate futures in their head.

    **Legend:**
    - 🤖 Agent | ⭐ Goal | 🏁 Start | 🧱 Wall
    - 🟢 Border = Explored states | 🟣 Dashed = States searched during planning | 🔵 = Planned path

    ---
    *Built with ❤️ using Gradio • Concept: World Models for Intelligent Agents*
    """)

    # Event handlers
    # Every callback returns (grid_html, stats_html, log_text) in this order.
    outputs = [grid_display, stats_display, log_display]

    # Changing either dropdown rebuilds the demo, exactly like pressing Reset.
    reset_btn.click(reset_demo, inputs=[grid_size, obstacle_preset], outputs=outputs)
    grid_size.change(reset_demo, inputs=[grid_size, obstacle_preset], outputs=outputs)
    obstacle_preset.change(reset_demo, inputs=[grid_size, obstacle_preset], outputs=outputs)

    explore_btn.click(run_explore, inputs=[exploration_steps], outputs=outputs)
    dream_btn.click(run_dream, outputs=outputs)
    execute_btn.click(run_execute, outputs=outputs)
    full_btn.click(run_full, inputs=[exploration_steps], outputs=outputs)

    # Initialize on load
    # demo.reset() is called without arguments, so the page opens on the
    # default grid regardless of any earlier state.
    interface.load(lambda: demo.reset(), outputs=outputs)


if __name__ == "__main__":
    interface.launch()
 
1
  """
2
+ World Model Demo - Interactive AI Planning Visualization
3
+ Educational demonstration of model-based reinforcement learning concepts
 
 
 
 
 
4
  """
5
 
6
  import gradio as gr
7
  import random
 
 
 
8
  import json
9
 
10
+ # ============================================================================
11
+ # World Model Core Classes
12
+ # ============================================================================
13
+
14
class GridWorld:
    """Simple grid environment for world model demonstration.

    Positions are ``[x, y]`` lists. The agent starts at ``[1, 1]`` and the
    goal sits at ``[size - 2, size - 2]``. ``size`` obstacles are placed at
    random on every reset.
    """

    # Action name -> (dx, dy). Unknown actions are treated as "stay put".
    _MOVES = {'up': (0, -1), 'down': (0, 1), 'left': (-1, 0), 'right': (1, 0)}

    # Upper bound on obstacle re-rolls, so degenerate grid sizes where the
    # goal can never be reached still terminate.
    _MAX_LAYOUT_ATTEMPTS = 100

    def __init__(self, size=8):
        self.size = size
        self.reset()

    def reset(self):
        """Start a new episode with fresh obstacles and return its state dict."""
        self.agent_pos = [1, 1]
        self.goal_pos = [self.size - 2, self.size - 2]
        self.obstacles = self._generate_obstacles()
        self.steps = 0
        return self._get_state()

    def _generate_obstacles(self):
        """Return a random set of obstacle cells that avoids agent and goal.

        Layouts are re-rolled (up to ``_MAX_LAYOUT_ATTEMPTS`` times) until the
        goal is reachable from the agent, so an episode is not unwinnable by
        construction. If no reachable layout is found, the last one is used.
        """
        start = tuple(self.agent_pos)
        goal = tuple(self.goal_pos)
        candidates = [
            (x, y)
            for x in range(self.size)
            for y in range(self.size)
            if (x, y) != start and (x, y) != goal
        ]
        count = min(self.size, len(candidates))

        obstacles = set()
        for _ in range(self._MAX_LAYOUT_ATTEMPTS):
            obstacles = set(random.sample(candidates, count))
            if self._is_reachable(start, goal, obstacles):
                break
        return obstacles

    def _is_reachable(self, start, goal, obstacles):
        """Flood-fill from ``start`` and report whether ``goal`` is connected."""
        seen = {start}
        frontier = [start]
        while frontier:
            x, y = frontier.pop()
            if (x, y) == goal:
                return True
            for dx, dy in self._MOVES.values():
                cell = (x + dx, y + dy)
                inside = 0 <= cell[0] < self.size and 0 <= cell[1] < self.size
                if inside and cell not in obstacles and cell not in seen:
                    seen.add(cell)
                    frontier.append(cell)
        return False

    def _get_state(self):
        """Return a snapshot of the environment as a plain dict."""
        return {
            'agent': self.agent_pos.copy(),
            # Copy so callers cannot move the goal by mutating the snapshot.
            'goal': list(self.goal_pos),
            'obstacles': list(self.obstacles),
            'size': self.size,
            'steps': self.steps
        }

    def step(self, action):
        """Apply ``action`` and return ``(state, reward, done)``.

        The target cell is clamped to the grid; a move into an obstacle
        leaves the agent in place. Every call counts as a step. The reward is
        10 on reaching the goal and -0.1 otherwise.
        """
        dx, dy = self._MOVES.get(action, (0, 0))
        new_x = max(0, min(self.size - 1, self.agent_pos[0] + dx))
        new_y = max(0, min(self.size - 1, self.agent_pos[1] + dy))

        if (new_x, new_y) not in self.obstacles:
            self.agent_pos = [new_x, new_y]

        self.steps += 1
        done = self.agent_pos == self.goal_pos
        reward = 10 if done else -0.1

        return self._get_state(), reward, done
 
59
 
 
 
 
class WorldModel:
    """Simple world model that learns to predict state transitions.

    Observed transitions are memorised per ``(agent position, action)``;
    unseen pairs are predicted by applying the action's nominal offset.
    """

    def __init__(self):
        # Despite the name this stores the last observed next position for
        # each (position, action) key, not a count.
        self.transition_counts = {}
        self.prediction_accuracy = 0.5
        self.total_predictions = 0
        self.correct_predictions = 0

    def predict(self, state, action):
        """Predict next state given current state and action.

        Returns:
            ``(predicted_position, confidence)`` where the position is a
            two-element list and the confidence is a float in ``[0.3, 0.95]``.
        """
        position = tuple(state['agent'])
        remembered = self.transition_counts.get((position, action))

        if remembered is None:
            # Never tried from here: assume the move simply succeeds.
            offsets = {'up': (0, -1), 'down': (0, 1), 'left': (-1, 0), 'right': (1, 0)}
            dx, dy = offsets.get(action, (0, 0))
            return [position[0] + dx, position[1] + dy], 0.3

        # Confidence grows with the running hit rate, capped at 0.95.
        hit_rate = self.correct_predictions / max(1, self.total_predictions)
        return list(remembered), min(0.95, 0.5 + hit_rate * 0.5)

    def learn(self, state, action, next_state):
        """Learn from observed transition.

        Scores the current prediction against what actually happened, then
        stores the observed outcome and refreshes ``prediction_accuracy``.
        """
        origin = tuple(state['agent'])
        outcome = tuple(next_state['agent'])

        # Score before storing, so the model is judged on what it believed
        # prior to seeing this outcome.
        guess, _ = self.predict(state, action)
        self.total_predictions += 1
        if tuple(guess) == outcome:
            self.correct_predictions += 1

        self.transition_counts[(origin, action)] = outcome
        self.prediction_accuracy = self.correct_predictions / max(1, self.total_predictions)
97
 
98
+ # ============================================================================
99
+ # Visualization
100
+ # ============================================================================
101
 
102
def render_grid_html(state, prediction=None, phase="observe"):
    """Render the grid as an HTML table.

    Args:
        state: Dict with ``size``, ``agent``, ``goal`` and ``obstacles``.
            Coordinates may be lists or tuples.
        prediction: Optional ``(x, y)`` cell to outline in the phase colour.
        phase: One of ``observe``, ``predict``, ``plan``, ``act``, ``learn``;
            any other value is drawn in grey.

    Returns:
        The HTML string.
    """
    size = state['size']
    # Normalise every coordinate to a tuple: comparing a list with a tuple is
    # always False, which would silently hide the marker.
    agent = tuple(state['agent'])
    goal = tuple(state['goal'])
    obstacles = {tuple(cell) for cell in state['obstacles']}
    predicted = tuple(prediction) if prediction else None

    colors = {
        'observe': '#3b82f6',
        'predict': '#8b5cf6',
        'plan': '#f59e0b',
        'act': '#22c55e',
        'learn': '#ec4899'
    }
    phase_color = colors.get(phase, '#6b7280')

    # Collect fragments and join once instead of growing a string in a loop.
    parts = [f'''
    <div style="text-align: center; font-family: system-ui, sans-serif;">
    <div style="display: inline-block; background: #1e293b; padding: 20px; border-radius: 12px; box-shadow: 0 4px 6px rgba(0,0,0,0.3);">
    <div style="margin-bottom: 10px; color: {phase_color}; font-weight: bold; font-size: 18px;">
    Phase: {phase.upper()}
    </div>
    <table style="border-collapse: collapse; margin: auto;">
    ''']

    for y in range(size):
        parts.append('<tr>')
        for x in range(size):
            cell = (x, y)
            bg = '#334155'
            content = ''
            border = '1px solid #475569'

            # Priority: obstacle, then goal, then agent.
            if cell in obstacles:
                bg = '#7f1d1d'
                content = '🧱'
            elif cell == goal:
                bg = '#166534'
                content = '⭐'
            elif cell == agent:
                bg = '#1d4ed8'
                content = '🤖'

            if predicted is not None and cell == predicted:
                border = f'3px solid {phase_color}'

            parts.append(f'''
            <td style="width: 45px; height: 45px; background: {bg};
            border: {border}; text-align: center; font-size: 20px;">
            {content}
            </td>
            ''')
        parts.append('</tr>')

    parts.append('''
    </table>
    <div style="margin-top: 15px; color: #94a3b8; font-size: 14px;">
    🤖 Agent | ⭐ Goal | 🧱 Obstacle
    </div>
    </div>
    </div>
    ''')
    return ''.join(parts)
164
 
165
+ # ============================================================================
166
+ # Gradio Interface
167
+ # ============================================================================
168
 
169
# Module-level demo state shared by every callback below.
world = GridWorld()
model = WorldModel()
current_state = world.reset()
current_phase = "observe"


def get_display():
    """Return ``(grid_html, stats_text)`` for the current module-level state."""
    grid_markup = render_grid_html(current_state, phase=current_phase)
    step_count = current_state['steps']
    summary = f"Steps: {step_count} | Model Accuracy: {model.prediction_accuracy:.1%}"
    return grid_markup, summary
179
+
180
def do_action(action):
    """Run one predict / act / learn cycle for ``action``.

    Returns:
        ``(grid_html, stats_text, message)`` for the UI outputs.
    """
    global current_state, current_phase, world, model

    # Ask the model what it expects before touching the environment.
    current_phase = "predict"
    _predicted, confidence = model.predict(current_state, action)

    # Keep the pre-move snapshot so the model can learn from the transition.
    current_phase = "act"
    previous_state = current_state.copy()
    current_state, _reward, reached_goal = world.step(action)

    current_phase = "learn"
    model.learn(previous_state, action, current_state)

    # Both outcomes end the cycle back in the "observe" phase.
    current_phase = "observe"
    if reached_goal:
        current_state = world.reset()
        message = "🎉 Goal reached! Environment reset."
    else:
        message = f"Moved {action}. Prediction confidence: {confidence:.1%}"

    grid_markup, summary = get_display()
    return grid_markup, summary, message
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203
 
204
def reset_env():
    """Start a fresh episode with an untrained model.

    Returns:
        ``(grid_html, stats_text, message)`` for the UI outputs.
    """
    global current_state, current_phase, world, model

    current_state = world.reset()
    model = WorldModel()  # discard everything learned so far
    current_phase = "observe"

    return (*get_display(), "Environment reset!")
211
 
212
+ # Build the interface
213
def _move_handler(direction):
    """Return a zero-argument click callback that moves the agent in *direction*."""
    def handler():
        return do_action(direction)
    return handler


# Build the interface. Components appear in creation order inside the nested
# Row/Column contexts, so statement order here is the layout.
with gr.Blocks(title="World Model Demo", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🧠 World Model Demo

    Interactive demonstration of how AI agents build internal models of the world.

    **The Learning Cycle:**
    1. **Observe** - Agent perceives current state
    2. **Predict** - World model predicts action outcomes
    3. **Plan** - Agent evaluates possible futures
    4. **Act** - Execute chosen action
    5. **Learn** - Update model from observed outcome
    """)

    with gr.Row():
        with gr.Column(scale=2):
            grid_display = gr.HTML(label="Environment")
            stats_display = gr.Textbox(label="Statistics", interactive=False)
            message_display = gr.Textbox(label="Status", interactive=False)

        with gr.Column(scale=1):
            gr.Markdown("### Controls")
            with gr.Row():
                # Blank buttons act as spacers so "Up" sits in the middle column.
                gr.Button("").click(lambda: None)
                up_btn = gr.Button("⬆️ Up")
                gr.Button("").click(lambda: None)
            with gr.Row():
                left_btn = gr.Button("⬅️ Left")
                down_btn = gr.Button("⬇️ Down")
                right_btn = gr.Button("➡️ Right")

            reset_btn = gr.Button("🔄 Reset", variant="secondary")

            gr.Markdown("""
            ### About World Models

            World models are internal representations that AI agents use to:
            - Simulate possible futures
            - Plan without trial-and-error
            - Learn efficiently from experience

            Used in: MuZero, Dreamer, PlaNet
            """)

    # Connect buttons: every action refreshes the grid, the stats and the status line.
    action_outputs = [grid_display, stats_display, message_display]
    for button, direction in (
        (up_btn, "up"),
        (down_btn, "down"),
        (left_btn, "left"),
        (right_btn, "right"),
    ):
        button.click(_move_handler(direction), outputs=action_outputs)
    reset_btn.click(reset_env, outputs=action_outputs)

    # Initial display
    demo.load(get_display, outputs=[grid_display, stats_display])


if __name__ == "__main__":
    demo.launch()