"""
World Model Demo - Interactive AI Planning Visualization
Educational demonstration of model-based reinforcement learning concepts
"""
import gradio as gr
import random
# ============================================================================
# World Model Core Classes
# ============================================================================
class GridWorld:
"""Simple grid environment for world model demonstration"""
def __init__(self, size=6):
self.size = size
self.reset()
def reset(self):
self.agent_pos = [0, 0]
self.goal_pos = [self.size - 1, self.size - 1]
self.obstacles = self._generate_obstacles()
self.steps = 0
return self._get_state()
def _generate_obstacles(self):
obstacles = set()
num_obstacles = self.size - 1
attempts = 0
while len(obstacles) < num_obstacles and attempts < 100:
x, y = random.randint(0, self.size-1), random.randint(0, self.size-1)
if [x, y] != self.agent_pos and [x, y] != self.goal_pos:
                # Keep the two cells adjacent to the start open so the agent is never boxed in
if not (x == 0 and y == 1) and not (x == 1 and y == 0):
obstacles.add((x, y))
attempts += 1
return obstacles
def _get_state(self):
return {
'agent': self.agent_pos.copy(),
'goal': self.goal_pos,
'obstacles': list(self.obstacles),
'size': self.size,
'steps': self.steps
}
def step(self, action):
moves = {'up': (0, -1), 'down': (0, 1), 'left': (-1, 0), 'right': (1, 0)}
dx, dy = moves.get(action, (0, 0))
new_x = max(0, min(self.size - 1, self.agent_pos[0] + dx))
new_y = max(0, min(self.size - 1, self.agent_pos[1] + dy))
if (new_x, new_y) not in self.obstacles:
self.agent_pos = [new_x, new_y]
self.steps += 1
done = self.agent_pos == self.goal_pos
return self._get_state(), done
def copy(self):
new_world = GridWorld(self.size)
new_world.agent_pos = self.agent_pos.copy()
new_world.goal_pos = self.goal_pos.copy()
new_world.obstacles = self.obstacles.copy()
new_world.steps = self.steps
return new_world
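# Illustrative sketch of the state dict returned by GridWorld.reset() / GridWorld.step()
# (the values below are assumptions for one possible 6x6 layout; obstacles are random per reset):
#
#   {'agent': [0, 0], 'goal': [5, 5],
#    'obstacles': [(2, 3), (4, 1), (1, 4), (3, 2), (5, 0)],
#    'size': 6, 'steps': 0}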
class WorldModelAgent:
"""Agent that uses a world model to plan ahead"""
def __init__(self):
self.imagination_steps = []
self.best_path = []
self.action_values = {}
def imagine_action(self, world, action):
"""Use world model to predict outcome without actually taking action"""
imagined_world = world.copy()
imagined_state, done = imagined_world.step(action)
return imagined_state, done, imagined_world
def evaluate_position(self, pos, goal):
"""Simple heuristic: negative manhattan distance to goal"""
return -(abs(pos[0] - goal[0]) + abs(pos[1] - goal[1]))
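    # Worked example: agent at (2, 3), goal at (5, 5) gives -(|2-5| + |3-5|) = -5;
    # a less negative score means the position is closer to the goal.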
def plan(self, world, depth=3):
"""
Plan ahead by imagining future states.
This is what makes world models special - we can "think" before acting.
"""
self.imagination_steps = []
self.action_values = {}
actions = ['up', 'down', 'left', 'right']
for action in actions:
# Imagine taking this action
imagined_state, done, imagined_world = self.imagine_action(world, action)
# Record what we imagined
self.imagination_steps.append({
'action': action,
'predicted_pos': imagined_state['agent'].copy(),
'depth': 1
})
if done:
# Found goal!
self.action_values[action] = 100
continue
# Look deeper - imagine further into the future
value = self.evaluate_position(imagined_state['agent'], imagined_state['goal'])
# Plan 2 steps ahead
best_future_value = -999
for next_action in actions:
future_state, future_done, _ = self.imagine_action(imagined_world, next_action)
self.imagination_steps.append({
'action': f"{action}→{next_action}",
'predicted_pos': future_state['agent'].copy(),
'depth': 2
})
if future_done:
best_future_value = 100
break
future_value = self.evaluate_position(future_state['agent'], future_state['goal'])
best_future_value = max(best_future_value, future_value)
self.action_values[action] = value + 0.9 * best_future_value
# Return best action
best_action = max(self.action_values, key=self.action_values.get)
return best_action, self.action_values, self.imagination_steps
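# Minimal headless sketch of the plan-then-act loop that the UI below drives interactively.
# `_scripted_rollout` and its `max_steps` cap are illustrative additions, not used by the demo.
def _scripted_rollout(max_steps=50):
    """Run the planner to the goal without Gradio, e.g. for a quick sanity check."""
    env = GridWorld(6)
    planner = WorldModelAgent()
    state = env.reset()
    for _ in range(max_steps):
        best_action, _, _ = planner.plan(env)   # imagine futures and score them
        state, done = env.step(best_action)     # commit to the best action for real
        if done:
            break
    return state['steps'], state['agent'] == state['goal']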
# ============================================================================
# Visualization
# ============================================================================
def render_grid(state, phase="observe", imagined_positions=None, highlight_action=None):
"""Render the grid as HTML"""
agent = state['agent']
goal = state['goal']
obstacles = set(tuple(o) if isinstance(o, list) else o for o in state['obstacles'])
size = state['size']
phase_info = {
'observe': ('🔍 OBSERVE', '#3b82f6', 'Perceiving current state...'),
'imagine': ('💭 IMAGINE', '#f59e0b', 'Simulating possible futures...'),
'evaluate': ('⚖️ EVALUATE', '#8b5cf6', 'Scoring each path...'),
'act': ('⚡ ACT', '#10b981', 'Executing best action!'),
}
phase_name, phase_color, phase_desc = phase_info.get(phase, ('', '#6b7280', ''))
    html = f'''
    <div style="margin-bottom:10px; text-align:center;">
        <span style="color:{phase_color}; font-weight:bold; font-size:1.1em;">{phase_name}</span>
        <span style="color:#94a3b8; margin-left:10px;">{phase_desc}</span>
    </div>
    <table style="border-collapse:collapse; margin:auto;">
    '''
# Convert imagined positions to set for easy lookup
imagined_set = set()
if imagined_positions:
for pos in imagined_positions:
imagined_set.add(tuple(pos))
for y in range(size):
        html += '<tr>'
for x in range(size):
bg = '#334155'
content = ''
border = '2px solid #475569'
opacity = '1'
if (x, y) in obstacles:
bg = '#991b1b'
content = '🧱'
elif [x, y] == goal:
bg = '#166534'
content = '⭐'
elif [x, y] == agent:
bg = '#1d4ed8'
content = '🤖'
elif (x, y) in imagined_set:
# Show imagined positions as ghost agents
bg = '#475569'
content = '👻'
border = f'2px dashed {phase_color}'
            html += f'''
            <td style="width:48px; height:48px; background:{bg}; border:{border};
                       opacity:{opacity}; text-align:center; font-size:24px;">{content}</td>
            '''
        html += '</tr>'
    html += '''
    </table>
    <div style="margin-top:10px; color:#94a3b8; text-align:center; font-size:0.9em;">
        🤖 Agent &nbsp;|&nbsp; ⭐ Goal &nbsp;|&nbsp; 🧱 Wall &nbsp;|&nbsp; 👻 Imagined Position
    </div>
    '''
return html
def render_thinking(action_values, imagination_steps, best_action):
"""Render the agent's thinking process"""
if not action_values:
return "Click 'Think & Move' to see the agent plan!
"
    html = '''
    <h3 style="margin-top:0;">🧠 Agent's Reasoning</h3>
    <p style="color:#94a3b8;">The agent imagined taking each action and predicted the outcomes:</p>
    '''
action_symbols = {'up': '⬆️', 'down': '⬇️', 'left': '⬅️', 'right': '➡️'}
for action, value in sorted(action_values.items(), key=lambda x: -x[1]):
is_best = action == best_action
border_color = '#10b981' if is_best else '#475569'
bg = '#064e3b' if is_best else '#334155'
label = ' ✓ BEST' if is_best else ''
        html += f'''
        <div style="background:{bg}; border:2px solid {border_color}; border-radius:8px;
                    padding:8px 12px; margin:6px 0; color:#e2e8f0;">
            <span style="font-size:20px;">{action_symbols.get(action, '?')}</span>
            <b>{action.upper()}</b>{label}
            <span style="float:right;">Score: {value:.1f}</span>
        </div>
        '''
    html += '''
    <div style="margin-top:12px; padding:10px; background:#1e293b; border-radius:8px; color:#cbd5e1;">
        <b>💡 Why this works:</b>
        The agent imagined each possible action, predicted where it would end up,
        and evaluated how close that gets to the goal. It can even imagine 2 steps ahead!
        This is different from trial-and-error learning — the agent "thinks" before acting.
    </div>
    '''
return html
# ============================================================================
# Global State
# ============================================================================
world = GridWorld(6)
agent = WorldModelAgent()
current_state = world.reset()
def reset_game():
global world, agent, current_state
world = GridWorld(6)
agent = WorldModelAgent()
current_state = world.reset()
grid_html = render_grid(current_state, phase="observe")
thinking_html = "Click 'Think & Move' to watch the agent plan!
"
status = "🔄 New environment! Click 'Think & Move' to see the world model in action."
return grid_html, thinking_html, status
def think_and_move():
"""Main function: Agent thinks using world model, then acts"""
global current_state, world, agent
# Check if already at goal
if current_state['agent'] == current_state['goal']:
return reset_game()
# Phase 1: Observe (already done - we have current_state)
# Phase 2: Imagine & Evaluate - Plan using world model
best_action, action_values, imagination_steps = agent.plan(world)
# Get imagined positions for visualization
imagined_positions = [step['predicted_pos'] for step in imagination_steps if step['depth'] == 1]
    # The imagined (ghost) positions are overlaid on the final render below; a click
    # handler returns only once, so a separate intermediate "imagine" frame would never be shown.
thinking_html = render_thinking(action_values, imagination_steps, best_action)
# Phase 3: Act - Execute the best action
current_state, done = world.step(best_action)
    # Update grid to show the result, keeping the imagined positions visible as ghosts
    grid_html = render_grid(current_state, phase="act" if not done else "observe",
                            imagined_positions=imagined_positions)
if done:
status = f"🎉 Goal reached in {current_state['steps']} steps! Click 'Reset' for a new puzzle."
else:
status = f"Step {current_state['steps']}: Chose {best_action.upper()} (score: {action_values[best_action]:.1f})"
return grid_html, thinking_html, status
def manual_move(action):
"""Let user move manually to compare with agent"""
global current_state, world
if current_state['agent'] == current_state['goal']:
return reset_game()
current_state, done = world.step(action)
grid_html = render_grid(current_state, phase="observe")
thinking_html = "You moved manually. Click 'Think & Move' to see how the agent would plan!
"
if done:
status = f"🎉 You reached the goal in {current_state['steps']} steps!"
else:
status = f"You moved {action}. Steps: {current_state['steps']}"
return grid_html, thinking_html, status
# ============================================================================
# Gradio Interface
# ============================================================================
with gr.Blocks(title="World Model Demo", theme=gr.themes.Soft()) as demo:
gr.Markdown("""
# 🧠 World Model Demo
**Watch an AI agent "think" before it acts!**
Unlike reactive AI that just responds to inputs, this agent uses a **world model** to:
1. **Imagine** what would happen if it took each action
2. **Evaluate** which imagined future is best
3. **Act** based on its mental simulation
👉 **Click "Think & Move"** to watch the agent plan its path to the ⭐ goal!
""")
with gr.Row():
with gr.Column(scale=3):
grid_display = gr.HTML()
status_display = gr.Textbox(label="Status", interactive=False)
with gr.Column(scale=2):
thinking_display = gr.HTML()
gr.Markdown("### 🎮 Controls")
think_btn = gr.Button("🧠 Think & Move", variant="primary", size="lg")
reset_btn = gr.Button("🔄 Reset", variant="secondary")
gr.Markdown("---")
gr.Markdown("**Manual controls** (to compare with agent):")
with gr.Row():
up_btn = gr.Button("⬆️")
with gr.Row():
left_btn = gr.Button("⬅️")
down_btn = gr.Button("⬇️")
right_btn = gr.Button("➡️")
with gr.Accordion("📖 What makes this different from ChatGPT/Claude?", open=False):
gr.Markdown("""
| Aspect | Language Model (GPT, Claude) | World Model (This Demo) |
|--------|------------------------------|-------------------------|
| **Predicts** | Next *word* in text | Next *state* given action |
| **"Thinking"** | Generates plausible text | Simulates physical outcomes |
| **Planning** | Implicit (chain-of-thought) | Explicit (tree search) |
**The key insight:** This agent can "imagine" taking actions and see the results
*before* committing to them in the real world. It's like planning your route
on a map before driving.
**Real examples:** MuZero (mastered Chess/Go without knowing rules),
Dreamer (robot control), IRIS (Atari games)
""")
with gr.Accordion("🔬 Why does this matter for AI Safety?", open=False):
gr.Markdown("""
World models are important for AI safety because:
- **Predictability**: We can inspect what futures the agent is considering
- **Interpretability**: The agent's "reasoning" is explicit, not hidden
- **Control**: We can verify the agent isn't planning harmful actions
- **Corrigibility**: Planning agents can incorporate "avoid irreversible actions"
Understanding how AI systems model the world helps us build systems we can trust.
""")
# Connect buttons
think_btn.click(think_and_move, outputs=[grid_display, thinking_display, status_display])
reset_btn.click(reset_game, outputs=[grid_display, thinking_display, status_display])
up_btn.click(lambda: manual_move("up"), outputs=[grid_display, thinking_display, status_display])
down_btn.click(lambda: manual_move("down"), outputs=[grid_display, thinking_display, status_display])
left_btn.click(lambda: manual_move("left"), outputs=[grid_display, thinking_display, status_display])
right_btn.click(lambda: manual_move("right"), outputs=[grid_display, thinking_display, status_display])
# Initialize
demo.load(reset_game, outputs=[grid_display, thinking_display, status_display])
if __name__ == "__main__":
demo.launch()