anthonym21 committed on
Commit
5bcb831
·
verified ·
1 Parent(s): e345b60

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +273 -164
app.py CHANGED
@@ -5,7 +5,7 @@ Educational demonstration of model-based reinforcement learning concepts
5
 
6
  import gradio as gr
7
  import random
8
- import json
9
 
10
  # ============================================================================
11
  # World Model Core Classes
@@ -14,24 +14,28 @@ import json
14
  class GridWorld:
15
  """Simple grid environment for world model demonstration"""
16
 
17
- def __init__(self, size=8):
18
  self.size = size
19
  self.reset()
20
 
21
  def reset(self):
22
- self.agent_pos = [1, 1]
23
- self.goal_pos = [self.size - 2, self.size - 2]
24
  self.obstacles = self._generate_obstacles()
25
  self.steps = 0
26
  return self._get_state()
27
 
28
  def _generate_obstacles(self):
29
  obstacles = set()
30
- num_obstacles = self.size
31
- while len(obstacles) < num_obstacles:
 
32
  x, y = random.randint(0, self.size-1), random.randint(0, self.size-1)
33
  if [x, y] != self.agent_pos and [x, y] != self.goal_pos:
34
- obstacles.add((x, y))
 
 
 
35
  return obstacles
36
 
37
  def _get_state(self):
@@ -44,7 +48,8 @@ class GridWorld:
44
  }
45
 
46
  def step(self, action):
47
- dx, dy = {'up': (0, -1), 'down': (0, 1), 'left': (-1, 0), 'right': (1, 0)}.get(action, (0, 0))
 
48
  new_x = max(0, min(self.size - 1, self.agent_pos[0] + dx))
49
  new_y = max(0, min(self.size - 1, self.agent_pos[1] + dy))
50
 
@@ -53,86 +58,136 @@ class GridWorld:
53
 
54
  self.steps += 1
55
  done = self.agent_pos == self.goal_pos
56
- reward = 10 if done else -0.1
57
-
58
- return self._get_state(), reward, done
 
 
 
 
 
 
59
 
60
- class WorldModel:
61
- """Simple world model that learns to predict state transitions"""
62
 
63
  def __init__(self):
64
- self.transition_counts = {}
65
- self.prediction_accuracy = 0.5
66
- self.total_predictions = 0
67
- self.correct_predictions = 0
68
-
69
- def predict(self, state, action):
70
- """Predict next state given current state and action"""
71
- agent = tuple(state['agent'])
72
- key = (agent, action)
73
-
74
- if key in self.transition_counts:
75
- predicted = list(self.transition_counts[key])
76
- confidence = min(0.95, 0.5 + self.correct_predictions / max(1, self.total_predictions) * 0.5)
77
- else:
78
- dx, dy = {'up': (0, -1), 'down': (0, 1), 'left': (-1, 0), 'right': (1, 0)}.get(action, (0, 0))
79
- predicted = [agent[0] + dx, agent[1] + dy]
80
- confidence = 0.3
81
-
82
- return predicted, confidence
83
 
84
- def learn(self, state, action, next_state):
85
- """Learn from observed transition"""
86
- agent = tuple(state['agent'])
87
- next_agent = tuple(next_state['agent'])
88
- key = (agent, action)
 
 
 
 
 
 
 
89
 
90
- predicted, _ = self.predict(state, action)
91
- self.total_predictions += 1
92
- if tuple(predicted) == next_agent:
93
- self.correct_predictions += 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
- self.transition_counts[key] = next_agent
96
- self.prediction_accuracy = self.correct_predictions / max(1, self.total_predictions)
 
97
 
98
  # ============================================================================
99
  # Visualization
100
  # ============================================================================
101
 
102
- def render_grid_html(state, phase="observe", prediction=None):
103
- """Render the grid as HTML with phase-appropriate styling"""
104
  agent = state['agent']
105
  goal = state['goal']
106
  obstacles = set(tuple(o) if isinstance(o, list) else o for o in state['obstacles'])
107
  size = state['size']
108
 
109
- phase_colors = {
110
- 'observe': '#3b82f6', # blue
111
- 'predict': '#f59e0b', # amber
112
- 'plan': '#8b5cf6', # purple
113
- 'act': '#10b981', # green
114
- 'learn': '#ec4899' # pink
115
  }
116
- phase_color = phase_colors.get(phase, '#6b7280')
 
117
 
118
  html = f'''
119
  <div style="text-align: center; font-family: system-ui, sans-serif;">
120
- <div style="display: inline-block; background: #1e293b; padding: 20px; border-radius: 12px; box-shadow: 0 4px 6px rgba(0,0,0,0.3);">
121
- <div style="margin-bottom: 10px; color: {phase_color}; font-weight: bold; font-size: 18px;">
122
- Phase: {phase.upper()}
 
 
 
 
 
123
  </div>
124
- <table style="border-collapse: collapse; margin: auto;">
125
  '''
126
 
 
 
 
 
 
 
127
  for y in range(size):
128
  html += '<tr>'
129
  for x in range(size):
130
  bg = '#334155'
131
  content = ''
132
- border = '1px solid #475569'
 
133
 
134
  if (x, y) in obstacles:
135
- bg = '#7f1d1d'
136
  content = '🧱'
137
  elif [x, y] == goal:
138
  bg = '#166534'
@@ -140,13 +195,16 @@ def render_grid_html(state, phase="observe", prediction=None):
140
  elif [x, y] == agent:
141
  bg = '#1d4ed8'
142
  content = '🤖'
143
-
144
- if prediction and [x, y] == prediction:
145
- border = f'3px solid {phase_color}'
 
 
146
 
147
  html += f'''
148
- <td style="width: 45px; height: 45px; background: {bg};
149
- border: {border}; text-align: center; font-size: 20px;">
 
150
  {content}
151
  </td>
152
  '''
@@ -154,8 +212,51 @@ def render_grid_html(state, phase="observe", prediction=None):
154
 
155
  html += '''
156
  </table>
157
- <div style="margin-top: 15px; color: #94a3b8; font-size: 14px;">
158
- 🤖 Agent | ⭐ Goal | 🧱 Obstacle
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
  </div>
160
  </div>
161
  </div>
@@ -163,147 +264,155 @@ def render_grid_html(state, phase="observe", prediction=None):
163
  return html
164
 
165
  # ============================================================================
166
- # Gradio Interface
167
  # ============================================================================
168
 
169
- world = GridWorld()
170
- model = WorldModel()
171
  current_state = world.reset()
172
- current_phase = "observe"
173
 
174
- def get_display():
175
- global current_state, current_phase
176
- html = render_grid_html(current_state, phase=current_phase)
177
- stats = f"Steps: {current_state['steps']} | Model Accuracy: {model.prediction_accuracy:.1%}"
178
- return html, stats
 
 
 
 
 
 
179
 
180
- def do_action(action):
181
- global current_state, current_phase, world, model
 
 
 
 
 
 
 
 
 
 
182
 
183
- current_phase = "predict"
184
- prediction, confidence = model.predict(current_state, action)
185
 
186
- current_phase = "act"
187
- old_state = current_state.copy()
188
- current_state, reward, done = world.step(action)
189
 
190
- current_phase = "learn"
191
- model.learn(old_state, action, current_state)
 
 
 
192
 
193
  if done:
194
- current_phase = "observe"
195
- current_state = world.reset()
196
- message = "🎉 Goal reached! Environment reset."
197
  else:
198
- current_phase = "observe"
199
- message = f"Moved {action}. Prediction confidence: {confidence:.1%}"
200
 
201
- html, stats = get_display()
202
- return html, stats, message
203
 
204
- def reset_env():
205
- global current_state, current_phase, world, model
206
- world = GridWorld() # Create fresh world
207
- model = WorldModel() # Create fresh model
208
- current_state = world.reset()
209
- current_phase = "observe"
210
- html, stats = get_display()
211
- return html, stats, "Environment reset!"
 
 
 
 
 
 
 
 
 
 
 
 
 
212
 
213
- # Build the interface
214
  with gr.Blocks(title="World Model Demo", theme=gr.themes.Soft()) as demo:
215
  gr.Markdown("""
216
  # 🧠 World Model Demo
217
 
218
- **What is this?** An interactive demonstration of how AI agents can build internal "mental models"
219
- of the world to plan and reason, rather than just reacting to inputs.
 
 
 
 
 
 
220
  """)
221
 
222
  with gr.Row():
223
- with gr.Column(scale=2):
224
- grid_display = gr.HTML(label="Environment")
225
- stats_display = gr.Textbox(label="Statistics", interactive=False)
226
- message_display = gr.Textbox(label="Status", interactive=False)
227
 
228
- with gr.Column(scale=1):
229
- gr.Markdown("### Controls")
230
- with gr.Row():
231
- gr.Button("", visible=False, min_width=1)
232
- up_btn = gr.Button("⬆️ Up")
233
- gr.Button("", visible=False, min_width=1)
234
- with gr.Row():
235
- left_btn = gr.Button("⬅️ Left")
236
- down_btn = gr.Button("⬇️ Down")
237
- right_btn = gr.Button("➡️ Right")
238
 
 
 
 
239
  reset_btn = gr.Button("🔄 Reset", variant="secondary")
240
 
241
- gr.Markdown("""
242
- ---
243
- **The Learning Cycle:**
244
- 1. 🔍 **Observe** - Perceive state
245
- 2. 💭 **Predict** - Imagine outcomes
246
- 3. **Act** - Execute action
247
- 4. 📚 **Learn** - Update model
248
- """)
249
-
250
- # Educational content in collapsible sections
251
- with gr.Accordion("📖 What is a World Model?", open=False):
252
- gr.Markdown("""
253
- A **world model** is an internal representation that an AI agent uses to *simulate* the
254
- environment without actually interacting with it. Think of it as the agent's "imagination."
255
-
256
- **Instead of pure trial-and-error, an agent with a world model can:**
257
- - 🎯 **Imagine** possible futures ("what if I do X?")
258
- - ⚖️ **Evaluate** which imagined future looks best
259
- - 🗺️ **Plan** a sequence of actions to reach that future
260
- - ✅ **Act** with confidence, having already "seen" the outcome
261
-
262
- **Real examples:** MuZero (mastered Go/Chess without knowing rules), Dreamer (robot control),
263
- IRIS (Atari from pixels)
264
- """)
265
 
266
- with gr.Accordion("🤔 How is this different from ChatGPT/Claude?", open=False):
267
  gr.Markdown("""
268
  | Aspect | Language Model (GPT, Claude) | World Model (This Demo) |
269
  |--------|------------------------------|-------------------------|
270
- | **Predicts** | Next *word* in a sequence | Next *state* given an action |
271
- | **Training** | Text prediction | Reward from environment |
272
  | **"Thinking"** | Generates plausible text | Simulates physical outcomes |
273
  | **Planning** | Implicit (chain-of-thought) | Explicit (tree search) |
274
- | **Grounding** | Statistical text patterns | Causal dynamics |
275
 
276
- **Example:**
277
- - **LLM**: "If I push a ball off a table..." generates plausible *text*
278
- - **World Model**: state(ball on table) + action(push) → predicts actual *trajectory*
279
 
280
- Language models learn *what sounds right*. World models learn *what actually happens*.
 
281
  """)
282
 
283
  with gr.Accordion("🔬 Why does this matter for AI Safety?", open=False):
284
  gr.Markdown("""
285
- World models are crucial for AI safety research because:
286
-
287
- - **Predictability**: Agents that plan can be analyzed - we can inspect what futures they're considering
288
- - **Corrigibility**: Planning agents can incorporate "avoid irreversible actions" into their search
289
- - **Interpretability**: The model's predictions can be examined for accuracy and bias
290
- - **Scalable Oversight**: Humans can audit the agent's "reasoning" by inspecting simulated futures
291
 
292
- Understanding how AI systems model the world helps us build systems we can trust and verify.
 
 
 
293
 
294
- ---
295
- *Created by [Anthony Maio](https://huggingface.co/anthonym21) as an educational resource*
296
  """)
297
 
298
  # Connect buttons
299
- up_btn.click(lambda: do_action("up"), outputs=[grid_display, stats_display, message_display])
300
- down_btn.click(lambda: do_action("down"), outputs=[grid_display, stats_display, message_display])
301
- left_btn.click(lambda: do_action("left"), outputs=[grid_display, stats_display, message_display])
302
- right_btn.click(lambda: do_action("right"), outputs=[grid_display, stats_display, message_display])
303
- reset_btn.click(reset_env, outputs=[grid_display, stats_display, message_display])
304
-
305
- # Initial display
306
- demo.load(get_display, outputs=[grid_display, stats_display])
 
 
307
 
308
  if __name__ == "__main__":
309
  demo.launch()
 
5
 
6
  import gradio as gr
7
  import random
8
+ import time
9
 
10
  # ============================================================================
11
  # World Model Core Classes
 
14
  class GridWorld:
15
  """Simple grid environment for world model demonstration"""
16
 
17
+ def __init__(self, size=6):
18
  self.size = size
19
  self.reset()
20
 
21
  def reset(self):
22
+ self.agent_pos = [0, 0]
23
+ self.goal_pos = [self.size - 1, self.size - 1]
24
  self.obstacles = self._generate_obstacles()
25
  self.steps = 0
26
  return self._get_state()
27
 
28
  def _generate_obstacles(self):
29
  obstacles = set()
30
+ num_obstacles = self.size - 1
31
+ attempts = 0
32
+ while len(obstacles) < num_obstacles and attempts < 100:
33
  x, y = random.randint(0, self.size-1), random.randint(0, self.size-1)
34
  if [x, y] != self.agent_pos and [x, y] != self.goal_pos:
35
+ # Don't block the only path
36
+ if not (x == 0 and y == 1) and not (x == 1 and y == 0):
37
+ obstacles.add((x, y))
38
+ attempts += 1
39
  return obstacles
40
 
41
  def _get_state(self):
 
48
  }
49
 
50
  def step(self, action):
51
+ moves = {'up': (0, -1), 'down': (0, 1), 'left': (-1, 0), 'right': (1, 0)}
52
+ dx, dy = moves.get(action, (0, 0))
53
  new_x = max(0, min(self.size - 1, self.agent_pos[0] + dx))
54
  new_y = max(0, min(self.size - 1, self.agent_pos[1] + dy))
55
 
 
58
 
59
  self.steps += 1
60
  done = self.agent_pos == self.goal_pos
61
+ return self._get_state(), done
62
+
63
+ def copy(self):
64
+ new_world = GridWorld(self.size)
65
+ new_world.agent_pos = self.agent_pos.copy()
66
+ new_world.goal_pos = self.goal_pos.copy()
67
+ new_world.obstacles = self.obstacles.copy()
68
+ new_world.steps = self.steps
69
+ return new_world
70
 
71
class WorldModelAgent:
    """Planning agent that rehearses moves on copies of the world before acting."""

    def __init__(self):
        # Scratch data; plan() replaces imagination_steps and action_values
        # with fresh containers on every call.
        self.imagination_steps = []
        self.best_path = []
        self.action_values = {}

    def imagine_action(self, world, action):
        """Apply ``action`` to a copy of ``world`` and report the outcome.

        Returns:
            tuple: ``(state, done, simulated_world)`` where ``state`` and
            ``done`` come from ``simulated_world.step(action)`` and
            ``simulated_world`` is the object returned by ``world.copy()``.
        """
        sandbox = world.copy()
        outcome, reached = sandbox.step(action)
        return outcome, reached, sandbox

    def evaluate_position(self, pos, goal):
        """Score a cell as the negated Manhattan distance from ``pos`` to ``goal``."""
        horizontal = abs(pos[0] - goal[0])
        vertical = abs(pos[1] - goal[1])
        return -(horizontal + vertical)

    def _best_followup(self, world_after_first, first, moves):
        """Imagine every second move after ``first`` and return the best score.

        Each imagined second move is appended to ``self.imagination_steps``
        with depth 2. Reaching the goal short-circuits with a score of 100.
        """
        best = -999
        for second in moves:
            later_state, later_done, _ = self.imagine_action(world_after_first, second)
            self.imagination_steps.append({
                'action': f"{first}→{second}",
                'predicted_pos': later_state['agent'].copy(),
                'depth': 2,
            })
            if later_done:
                return 100
            best = max(best, self.evaluate_position(later_state['agent'], later_state['goal']))
        return best

    def plan(self, world, depth=3):
        """Choose an action by simulating one move plus one follow-up move.

        Every first move is scored as its own heuristic value plus 0.9 times
        the best follow-up value; a first move that reaches the goal scores
        100 outright. ``depth`` is accepted but not consulted: the lookahead
        is fixed at two moves.

        Returns:
            tuple: ``(best_action, action_values, imagination_steps)``; the
            last two are the same objects stored on the agent.
        """
        trace = []
        scores = {}
        self.imagination_steps = trace
        self.action_values = scores
        moves = ['up', 'down', 'left', 'right']

        for first in moves:
            near_state, near_done, near_world = self.imagine_action(world, first)
            trace.append({
                'action': first,
                'predicted_pos': near_state['agent'].copy(),
                'depth': 1,
            })

            if near_done:
                # The goal is one move away; nothing further to imagine.
                scores[first] = 100
            else:
                immediate = self.evaluate_position(near_state['agent'], near_state['goal'])
                scores[first] = immediate + 0.9 * self._best_followup(near_world, first, moves)

        # max() keeps the first of any tied actions, in the order of `moves`.
        return max(scores, key=scores.get), scores, trace
140
 
141
  # ============================================================================
142
  # Visualization
143
  # ============================================================================
144
 
145
+ def render_grid(state, phase="observe", imagined_positions=None, highlight_action=None):
146
+ """Render the grid as HTML"""
147
  agent = state['agent']
148
  goal = state['goal']
149
  obstacles = set(tuple(o) if isinstance(o, list) else o for o in state['obstacles'])
150
  size = state['size']
151
 
152
+ phase_info = {
153
+ 'observe': ('🔍 OBSERVE', '#3b82f6', 'Perceiving current state...'),
154
+ 'imagine': ('💭 IMAGINE', '#f59e0b', 'Simulating possible futures...'),
155
+ 'evaluate': ('⚖️ EVALUATE', '#8b5cf6', 'Scoring each path...'),
156
+ 'act': ('⚡ ACT', '#10b981', 'Executing best action!'),
 
157
  }
158
+
159
+ phase_name, phase_color, phase_desc = phase_info.get(phase, ('', '#6b7280', ''))
160
 
161
  html = f'''
162
  <div style="text-align: center; font-family: system-ui, sans-serif;">
163
+ <div style="display: inline-block; background: linear-gradient(135deg, #1e293b 0%, #0f172a 100%);
164
+ padding: 24px; border-radius: 16px; box-shadow: 0 8px 32px rgba(0,0,0,0.4);">
165
+ <div style="margin-bottom: 8px; color: {phase_color}; font-weight: bold; font-size: 22px;
166
+ text-shadow: 0 0 20px {phase_color}40;">
167
+ {phase_name}
168
+ </div>
169
+ <div style="margin-bottom: 16px; color: #94a3b8; font-size: 14px;">
170
+ {phase_desc}
171
  </div>
172
+ <table style="border-collapse: collapse; margin: auto; border-radius: 8px; overflow: hidden;">
173
  '''
174
 
175
+ # Convert imagined positions to set for easy lookup
176
+ imagined_set = set()
177
+ if imagined_positions:
178
+ for pos in imagined_positions:
179
+ imagined_set.add(tuple(pos))
180
+
181
  for y in range(size):
182
  html += '<tr>'
183
  for x in range(size):
184
  bg = '#334155'
185
  content = ''
186
+ border = '2px solid #475569'
187
+ opacity = '1'
188
 
189
  if (x, y) in obstacles:
190
+ bg = '#991b1b'
191
  content = '🧱'
192
  elif [x, y] == goal:
193
  bg = '#166534'
 
195
  elif [x, y] == agent:
196
  bg = '#1d4ed8'
197
  content = '🤖'
198
+ elif (x, y) in imagined_set:
199
+ # Show imagined positions as ghost agents
200
+ bg = '#475569'
201
+ content = '👻'
202
+ border = f'2px dashed {phase_color}'
203
 
204
  html += f'''
205
+ <td style="width: 50px; height: 50px; background: {bg};
206
+ border: {border}; text-align: center; font-size: 24px;
207
+ transition: all 0.3s ease;">
208
  {content}
209
  </td>
210
  '''
 
212
 
213
  html += '''
214
  </table>
215
+ <div style="margin-top: 16px; color: #64748b; font-size: 13px;">
216
+ 🤖 Agent | ⭐ Goal | 🧱 Wall | 👻 Imagined Position
217
+ </div>
218
+ </div>
219
+ </div>
220
+ '''
221
+ return html
222
+
223
+ def render_thinking(action_values, imagination_steps, best_action):
224
+ """Render the agent's thinking process"""
225
+ if not action_values:
226
+ return "<div style='color: #64748b; text-align: center; padding: 20px;'>Click 'Think & Move' to see the agent plan!</div>"
227
+
228
+ html = '''
229
+ <div style="font-family: system-ui, sans-serif; padding: 16px; background: #1e293b; border-radius: 12px;">
230
+ <h3 style="color: #f59e0b; margin-top: 0;">🧠 Agent's Reasoning</h3>
231
+ <p style="color: #94a3b8; font-size: 14px;">The agent imagined taking each action and predicted the outcomes:</p>
232
+ <div style="display: grid; grid-template-columns: repeat(2, 1fr); gap: 12px; margin-top: 12px;">
233
+ '''
234
+
235
+ action_symbols = {'up': '⬆️', 'down': '⬇️', 'left': '⬅️', 'right': '➡️'}
236
+
237
+ for action, value in sorted(action_values.items(), key=lambda x: -x[1]):
238
+ is_best = action == best_action
239
+ border_color = '#10b981' if is_best else '#475569'
240
+ bg = '#064e3b' if is_best else '#334155'
241
+ label = ' ✓ BEST' if is_best else ''
242
+
243
+ html += f'''
244
+ <div style="background: {bg}; border: 2px solid {border_color}; border-radius: 8px; padding: 12px; text-align: center;">
245
+ <div style="font-size: 24px;">{action_symbols.get(action, '?')}</div>
246
+ <div style="color: #e2e8f0; font-weight: bold; margin-top: 4px;">{action.upper()}{label}</div>
247
+ <div style="color: #94a3b8; font-size: 13px;">Score: {value:.1f}</div>
248
+ </div>
249
+ '''
250
+
251
+ html += '''
252
+ </div>
253
+ <div style="margin-top: 16px; padding: 12px; background: #0f172a; border-radius: 8px;">
254
+ <div style="color: #10b981; font-weight: bold;">💡 Why this works:</div>
255
+ <div style="color: #94a3b8; font-size: 13px; margin-top: 8px;">
256
+ The agent <b>imagined</b> each possible action, <b>predicted</b> where it would end up,
257
+ and <b>evaluated</b> how close that gets to the goal. It can even imagine 2 steps ahead!
258
+ <br><br>
259
+ This is different from trial-and-error learning — the agent "thinks" before acting.
260
  </div>
261
  </div>
262
  </div>
 
264
  return html
265
 
266
  # ============================================================================
267
+ # Global State
268
  # ============================================================================
269
 
270
+ world = GridWorld(6)
271
+ agent = WorldModelAgent()
272
  current_state = world.reset()
 
273
 
274
def reset_game():
    """Replace the shared world and agent with fresh ones and render the start.

    Rebinds the module-level ``world``, ``agent`` and ``current_state``.

    Returns:
        tuple[str, str, str]: grid HTML, placeholder HTML for the reasoning
        panel, and a status message.
    """
    global world, agent, current_state

    world = GridWorld(6)
    agent = WorldModelAgent()
    current_state = world.reset()

    return (
        render_grid(current_state, phase="observe"),
        "<div style='color: #64748b; text-align: center; padding: 20px;'>Click <b>'Think & Move'</b> to watch the agent plan!</div>",
        "🔄 New environment! Click 'Think & Move' to see the world model in action.",
    )
285
 
286
def think_and_move():
    """Let the agent plan one move, execute it, and render the outcome.

    If the agent already stands on the goal, the game is reset instead.
    Otherwise the agent plans on the current world, the chosen action is
    applied to the real world, and the grid is rendered together with the
    first-step positions the agent imagined, so the ghost markers from the
    legend are actually shown.

    Reads and rebinds the module-level ``current_state``; mutates ``world``.
    NOTE(review): this state lives in module globals, so it is presumably
    shared by every browser session of the app; confirm whether that is
    intended.

    Returns:
        tuple[str, str, str]: grid HTML, reasoning-panel HTML, status text.
    """
    global current_state, world, agent

    # Already at the goal: start a new puzzle rather than planning.
    if current_state['agent'] == current_state['goal']:
        return reset_game()

    # Imagine and evaluate: plan on copies of the world.
    best_action, action_values, imagination_steps = agent.plan(world)
    imagined_positions = [
        step['predicted_pos'] for step in imagination_steps if step['depth'] == 1
    ]
    thinking_html = render_thinking(action_values, imagination_steps, best_action)

    # Act: apply the chosen action to the real world.
    current_state, done = world.step(best_action)

    # Render once, after acting. The imagined cells are passed along so the
    # considered alternatives stay visible; the agent and goal cells take
    # precedence over ghost markers inside render_grid.
    grid_html = render_grid(
        current_state,
        phase="observe" if done else "act",
        imagined_positions=imagined_positions,
    )

    if done:
        status = f"🎉 Goal reached in {current_state['steps']} steps! Click 'Reset' for a new puzzle."
    else:
        status = f"Step {current_state['steps']}: Chose {best_action.upper()} (score: {action_values[best_action]:.1f})"

    return grid_html, thinking_html, status
 
318
 
319
def manual_move(action):
    """Apply a user-chosen move to the shared world and render the result.

    When the agent is already on the goal the game is reset instead of
    moving. Rebinds the module-level ``current_state``.

    Args:
        action: the move name handed to ``world.step``.

    Returns:
        tuple[str, str, str]: grid HTML, reasoning-panel HTML, status text.
    """
    global current_state, world

    already_solved = current_state['agent'] == current_state['goal']
    if already_solved:
        return reset_game()

    current_state, reached_goal = world.step(action)

    board = render_grid(current_state, phase="observe")
    hint = "<div style='color: #64748b; text-align: center; padding: 20px;'>You moved manually. Click 'Think & Move' to see how the agent would plan!</div>"
    status = (
        f"🎉 You reached the goal in {current_state['steps']} steps!"
        if reached_goal
        else f"You moved {action}. Steps: {current_state['steps']}"
    )
    return board, hint, status
336
+
337
+ # ============================================================================
338
+ # Gradio Interface
339
+ # ============================================================================
340
 
 
341
  with gr.Blocks(title="World Model Demo", theme=gr.themes.Soft()) as demo:
342
  gr.Markdown("""
343
  # 🧠 World Model Demo
344
 
345
+ **Watch an AI agent "think" before it acts!**
346
+
347
+ Unlike reactive AI that just responds to inputs, this agent uses a **world model** to:
348
+ 1. **Imagine** what would happen if it took each action
349
+ 2. **Evaluate** which imagined future is best
350
+ 3. **Act** based on its mental simulation
351
+
352
+ 👉 **Click "Think & Move"** to watch the agent plan its path to the ⭐ goal!
353
  """)
354
 
355
  with gr.Row():
356
+ with gr.Column(scale=3):
357
+ grid_display = gr.HTML()
358
+ status_display = gr.Textbox(label="Status", interactive=False)
 
359
 
360
+ with gr.Column(scale=2):
361
+ thinking_display = gr.HTML()
 
 
 
 
 
 
 
 
362
 
363
+ gr.Markdown("### 🎮 Controls")
364
+
365
+ think_btn = gr.Button("🧠 Think & Move", variant="primary", size="lg")
366
  reset_btn = gr.Button("🔄 Reset", variant="secondary")
367
 
368
+ gr.Markdown("---")
369
+ gr.Markdown("**Manual controls** (to compare with agent):")
370
+ with gr.Row():
371
+ up_btn = gr.Button("⬆️")
372
+ with gr.Row():
373
+ left_btn = gr.Button("⬅️")
374
+ down_btn = gr.Button("⬇️")
375
+ right_btn = gr.Button("➡️")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
376
 
377
+ with gr.Accordion("📖 What makes this different from ChatGPT/Claude?", open=False):
378
  gr.Markdown("""
379
  | Aspect | Language Model (GPT, Claude) | World Model (This Demo) |
380
  |--------|------------------------------|-------------------------|
381
+ | **Predicts** | Next *word* in text | Next *state* given action |
 
382
  | **"Thinking"** | Generates plausible text | Simulates physical outcomes |
383
  | **Planning** | Implicit (chain-of-thought) | Explicit (tree search) |
 
384
 
385
+ **The key insight:** This agent can "imagine" taking actions and see the results
386
+ *before* committing to them in the real world. It's like planning your route
387
+ on a map before driving.
388
 
389
+ **Real examples:** MuZero (mastered Chess/Go without knowing rules),
390
+ Dreamer (robot control), IRIS (Atari games)
391
  """)
392
 
393
  with gr.Accordion("🔬 Why does this matter for AI Safety?", open=False):
394
  gr.Markdown("""
395
+ World models are important for AI safety because:
 
 
 
 
 
396
 
397
+ - **Predictability**: We can inspect what futures the agent is considering
398
+ - **Interpretability**: The agent's "reasoning" is explicit, not hidden
399
+ - **Control**: We can verify the agent isn't planning harmful actions
400
+ - **Corrigibility**: Planning agents can incorporate "avoid irreversible actions"
401
 
402
+ Understanding how AI systems model the world helps us build systems we can trust.
 
403
  """)
404
 
405
  # Connect buttons
406
+ think_btn.click(think_and_move, outputs=[grid_display, thinking_display, status_display])
407
+ reset_btn.click(reset_game, outputs=[grid_display, thinking_display, status_display])
408
+
409
+ up_btn.click(lambda: manual_move("up"), outputs=[grid_display, thinking_display, status_display])
410
+ down_btn.click(lambda: manual_move("down"), outputs=[grid_display, thinking_display, status_display])
411
+ left_btn.click(lambda: manual_move("left"), outputs=[grid_display, thinking_display, status_display])
412
+ right_btn.click(lambda: manual_move("right"), outputs=[grid_display, thinking_display, status_display])
413
+
414
+ # Initialize
415
+ demo.load(reset_game, outputs=[grid_display, thinking_display, status_display])
416
 
417
  if __name__ == "__main__":
418
  demo.launch()