""" World Model Demo - Interactive AI Planning Visualization Educational demonstration of model-based reinforcement learning concepts """ import gradio as gr import random import time # ============================================================================ # World Model Core Classes # ============================================================================ class GridWorld: """Simple grid environment for world model demonstration""" def __init__(self, size=6): self.size = size self.reset() def reset(self): self.agent_pos = [0, 0] self.goal_pos = [self.size - 1, self.size - 1] self.obstacles = self._generate_obstacles() self.steps = 0 return self._get_state() def _generate_obstacles(self): obstacles = set() num_obstacles = self.size - 1 attempts = 0 while len(obstacles) < num_obstacles and attempts < 100: x, y = random.randint(0, self.size-1), random.randint(0, self.size-1) if [x, y] != self.agent_pos and [x, y] != self.goal_pos: # Don't block the only path if not (x == 0 and y == 1) and not (x == 1 and y == 0): obstacles.add((x, y)) attempts += 1 return obstacles def _get_state(self): return { 'agent': self.agent_pos.copy(), 'goal': self.goal_pos, 'obstacles': list(self.obstacles), 'size': self.size, 'steps': self.steps } def step(self, action): moves = {'up': (0, -1), 'down': (0, 1), 'left': (-1, 0), 'right': (1, 0)} dx, dy = moves.get(action, (0, 0)) new_x = max(0, min(self.size - 1, self.agent_pos[0] + dx)) new_y = max(0, min(self.size - 1, self.agent_pos[1] + dy)) if (new_x, new_y) not in self.obstacles: self.agent_pos = [new_x, new_y] self.steps += 1 done = self.agent_pos == self.goal_pos return self._get_state(), done def copy(self): new_world = GridWorld(self.size) new_world.agent_pos = self.agent_pos.copy() new_world.goal_pos = self.goal_pos.copy() new_world.obstacles = self.obstacles.copy() new_world.steps = self.steps return new_world class WorldModelAgent: """Agent that uses a world model to plan ahead""" def __init__(self): self.imagination_steps = [] self.best_path = [] self.action_values = {} def imagine_action(self, world, action): """Use world model to predict outcome without actually taking action""" imagined_world = world.copy() imagined_state, done = imagined_world.step(action) return imagined_state, done, imagined_world def evaluate_position(self, pos, goal): """Simple heuristic: negative manhattan distance to goal""" return -(abs(pos[0] - goal[0]) + abs(pos[1] - goal[1])) def plan(self, world, depth=3): """ Plan ahead by imagining future states. This is what makes world models special - we can "think" before acting. """ self.imagination_steps = [] self.action_values = {} actions = ['up', 'down', 'left', 'right'] for action in actions: # Imagine taking this action imagined_state, done, imagined_world = self.imagine_action(world, action) # Record what we imagined self.imagination_steps.append({ 'action': action, 'predicted_pos': imagined_state['agent'].copy(), 'depth': 1 }) if done: # Found goal! 
class WorldModelAgent:
    """Agent that uses a world model to plan ahead."""

    def __init__(self):
        self.imagination_steps = []
        self.best_path = []
        self.action_values = {}

    def imagine_action(self, world, action):
        """Use the world model to predict an outcome without actually taking the action."""
        imagined_world = world.copy()
        imagined_state, done = imagined_world.step(action)
        return imagined_state, done, imagined_world

    def evaluate_position(self, pos, goal):
        """Simple heuristic: negative Manhattan distance to the goal."""
        return -(abs(pos[0] - goal[0]) + abs(pos[1] - goal[1]))

    def plan(self, world, depth=3):
        """
        Plan ahead by imagining future states.

        This is what makes world models special - the agent can "think"
        before acting.
        """
        self.imagination_steps = []
        self.action_values = {}
        actions = ['up', 'down', 'left', 'right']

        for action in actions:
            # Imagine taking this action
            imagined_state, done, imagined_world = self.imagine_action(world, action)

            # Record what we imagined
            self.imagination_steps.append({
                'action': action,
                'predicted_pos': imagined_state['agent'].copy(),
                'depth': 1,
            })

            if done:  # Found the goal!
                self.action_values[action] = 100
                continue

            # Look deeper - imagine further into the future
            value = self.evaluate_position(imagined_state['agent'], imagined_state['goal'])

            # Plan 2 steps ahead
            best_future_value = -999
            for next_action in actions:
                future_state, future_done, _ = self.imagine_action(imagined_world, next_action)
                self.imagination_steps.append({
                    'action': f"{action}→{next_action}",
                    'predicted_pos': future_state['agent'].copy(),
                    'depth': 2,
                })
                if future_done:
                    best_future_value = 100
                    break
                future_value = self.evaluate_position(future_state['agent'], future_state['goal'])
                best_future_value = max(best_future_value, future_value)

            # Combine the immediate value with the discounted best imagined follow-up
            self.action_values[action] = value + 0.9 * best_future_value

        # Return the best action
        best_action = max(self.action_values, key=self.action_values.get)
        return best_action, self.action_values, self.imagination_steps
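# Sketch of using the planner in isolation (illustrative; scores depend on the random layout):
#
#   env = GridWorld(6)
#   planner = WorldModelAgent()
#   best, values, imagined = planner.plan(env)
#   # `values` maps each action to a score (less negative = closer to the goal after
#   # two imagined steps); `env` itself is untouched, since planning only steps copies.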
# ============================================================================
# Visualization
# ============================================================================


def render_grid(state, phase="observe", imagined_positions=None, highlight_action=None):
    """Render the grid as HTML."""
    agent = state['agent']
    goal = state['goal']
    obstacles = set(tuple(o) if isinstance(o, list) else o for o in state['obstacles'])
    size = state['size']

    phase_info = {
        'observe': ('🔍 OBSERVE', '#3b82f6', 'Perceiving current state...'),
        'imagine': ('💭 IMAGINE', '#f59e0b', 'Simulating possible futures...'),
        'evaluate': ('⚖️ EVALUATE', '#8b5cf6', 'Scoring each path...'),
        'act': ('⚡ ACT', '#10b981', 'Executing best action!'),
    }
    phase_name, phase_color, phase_desc = phase_info.get(phase, ('', '#6b7280', ''))

    # Phase banner and the opening of the grid table
    html = f'''
    <div style="font-family: sans-serif;">
        <div style="color: {phase_color}; font-size: 18px; font-weight: bold;">{phase_name}</div>
        <div style="color: #94a3b8; margin-bottom: 8px;">{phase_desc}</div>
        <table style="border-collapse: collapse;">
    '''

    # Convert imagined positions to a set for easy lookup
    imagined_set = set()
    if imagined_positions:
        for pos in imagined_positions:
            imagined_set.add(tuple(pos))

    for y in range(size):
        html += '<tr>'
        for x in range(size):
            bg = '#334155'
            content = ''
            border = '2px solid #475569'
            opacity = '1'

            if (x, y) in obstacles:
                bg = '#991b1b'
                content = '🧱'
            elif [x, y] == goal:
                bg = '#166534'
                content = '⭐'
            elif [x, y] == agent:
                bg = '#1d4ed8'
                content = '🤖'
            elif (x, y) in imagined_set:
                # Show imagined positions as ghost agents
                bg = '#475569'
                content = '👻'
                border = f'2px dashed {phase_color}'

            html += f'''
                <td style="width: 48px; height: 48px; text-align: center; font-size: 24px;
                           background: {bg}; border: {border}; opacity: {opacity};">
                    {content}
                </td>
            '''
        html += '</tr>'

    html += '''
        </table>
        <div style="color: #94a3b8; margin-top: 8px;">
            🤖 Agent | ⭐ Goal | 🧱 Wall | 👻 Imagined Position
        </div>
    </div>
    '''
    return html
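# To eyeball the renderer without launching the Gradio app, one could dump a single
# frame to disk (illustrative sketch, not part of the demo itself):
#
#   with open("grid_preview.html", "w", encoding="utf-8") as f:
#       f.write(render_grid(GridWorld(6).reset(), phase="observe"))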
def render_thinking(action_values, imagination_steps, best_action):
    """Render the agent's thinking process."""
    if not action_values:
        return ("<div style='color: #94a3b8; padding: 16px;'>"
                "Click 'Think & Move' to see the agent plan!</div>")

    html = '''
    <div style="font-family: sans-serif;">
        <div style="font-size: 18px; font-weight: bold;">🧠 Agent's Reasoning</div>
        <div style="color: #94a3b8; margin-bottom: 8px;">
            The agent imagined taking each action and predicted the outcomes:
        </div>
    '''

    action_symbols = {'up': '⬆️', 'down': '⬇️', 'left': '⬅️', 'right': '➡️'}

    # One card per action, best first
    for action, value in sorted(action_values.items(), key=lambda x: -x[1]):
        is_best = action == best_action
        border_color = '#10b981' if is_best else '#475569'
        bg = '#064e3b' if is_best else '#334155'
        label = ' ✓ BEST' if is_best else ''
        html += f'''
        <div style="border: 2px solid {border_color}; background: {bg};
                    border-radius: 8px; padding: 8px; margin: 4px 0;">
            <span style="font-size: 20px;">{action_symbols.get(action, '?')}</span>
            <b>{action.upper()}{label}</b>
            <span style="float: right;">Score: {value:.1f}</span>
        </div>
        '''

    html += '''
        <div style="color: #94a3b8; margin-top: 12px;">
            <b>💡 Why this works:</b>
            The agent imagined each possible action, predicted where it would end up,
            and evaluated how close that gets to the goal. It can even imagine 2 steps ahead!
            <br><br>
            This is different from trial-and-error learning — the agent "thinks" before acting.
        </div>
    </div>
    '''
    return html
# ============================================================================
# Global State
# ============================================================================

world = GridWorld(6)
agent = WorldModelAgent()
current_state = world.reset()


def reset_game():
    global world, agent, current_state
    world = GridWorld(6)
    agent = WorldModelAgent()
    current_state = world.reset()
    grid_html = render_grid(current_state, phase="observe")
    thinking_html = ("<div style='color: #94a3b8; padding: 16px;'>"
                     "Click 'Think & Move' to watch the agent plan!</div>")
    status = "🔄 New environment! Click 'Think & Move' to see the world model in action."
    return grid_html, thinking_html, status


def think_and_move():
    """Main function: the agent thinks using its world model, then acts."""
    global current_state, world, agent

    # Check if we are already at the goal
    if current_state['agent'] == current_state['goal']:
        return reset_game()

    # Phase 1: Observe (already done - we have current_state)

    # Phase 2: Imagine & Evaluate - plan using the world model
    best_action, action_values, imagination_steps = agent.plan(world)

    # Get imagined positions for visualization
    imagined_positions = [step['predicted_pos'] for step in imagination_steps if step['depth'] == 1]

    # Show the imagination phase
    grid_html = render_grid(current_state, phase="imagine", imagined_positions=imagined_positions)
    thinking_html = render_thinking(action_values, imagination_steps, best_action)

    # Phase 3: Act - execute the best action
    current_state, done = world.step(best_action)

    # Update the grid to show the result
    grid_html = render_grid(current_state, phase="act" if not done else "observe")

    if done:
        status = f"🎉 Goal reached in {current_state['steps']} steps! Click 'Reset' for a new puzzle."
    else:
        status = (f"Step {current_state['steps']}: Chose {best_action.upper()} "
                  f"(score: {action_values[best_action]:.1f})")

    return grid_html, thinking_html, status
def manual_move(action):
    """Let the user move manually to compare with the agent."""
    global current_state, world

    if current_state['agent'] == current_state['goal']:
        return reset_game()

    current_state, done = world.step(action)
    grid_html = render_grid(current_state, phase="observe")
    thinking_html = ("<div style='color: #94a3b8; padding: 16px;'>You moved manually. "
                     "Click 'Think & Move' to see how the agent would plan!</div>")

    if done:
        status = f"🎉 You reached the goal in {current_state['steps']} steps!"
    else:
        status = f"You moved {action}. Steps: {current_state['steps']}"

    return grid_html, thinking_html, status


# ============================================================================
# Gradio Interface
# ============================================================================

with gr.Blocks(title="World Model Demo", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🧠 World Model Demo

    **Watch an AI agent "think" before it acts!**

    Unlike reactive AI that just responds to inputs, this agent uses a **world model** to:

    1. **Imagine** what would happen if it took each action
    2. **Evaluate** which imagined future is best
    3. **Act** based on its mental simulation

    👉 **Click "Think & Move"** to watch the agent plan its path to the ⭐ goal!
    """)

    with gr.Row():
        with gr.Column(scale=3):
            grid_display = gr.HTML()
            status_display = gr.Textbox(label="Status", interactive=False)
        with gr.Column(scale=2):
            thinking_display = gr.HTML()
            gr.Markdown("### 🎮 Controls")
            think_btn = gr.Button("🧠 Think & Move", variant="primary", size="lg")
            reset_btn = gr.Button("🔄 Reset", variant="secondary")
            gr.Markdown("---")
            gr.Markdown("**Manual controls** (to compare with the agent):")
            with gr.Row():
                up_btn = gr.Button("⬆️")
            with gr.Row():
                left_btn = gr.Button("⬅️")
                down_btn = gr.Button("⬇️")
                right_btn = gr.Button("➡️")

    with gr.Accordion("📖 What makes this different from ChatGPT/Claude?", open=False):
        gr.Markdown("""
        | Aspect | Language Model (GPT, Claude) | World Model (This Demo) |
        |--------|------------------------------|-------------------------|
        | **Predicts** | Next *word* in text | Next *state* given an action |
        | **"Thinking"** | Generates plausible text | Simulates physical outcomes |
        | **Planning** | Implicit (chain-of-thought) | Explicit (tree search) |

        **The key insight:** This agent can "imagine" taking actions and see the results
        *before* committing to them in the real world. It's like planning your route on a
        map before driving.

        **Real examples:** MuZero (mastered chess and Go without being given the rules),
        Dreamer (robot control), IRIS (Atari games)
        """)

    with gr.Accordion("🔬 Why does this matter for AI Safety?", open=False):
        gr.Markdown("""
        World models are important for AI safety because:

        - **Predictability**: We can inspect which futures the agent is considering
        - **Interpretability**: The agent's "reasoning" is explicit, not hidden
        - **Control**: We can verify that the agent isn't planning harmful actions
        - **Corrigibility**: Planning agents can incorporate constraints such as "avoid irreversible actions"

        Understanding how AI systems model the world helps us build systems we can trust.
        """)

    # Connect buttons
    think_btn.click(think_and_move, outputs=[grid_display, thinking_display, status_display])
    reset_btn.click(reset_game, outputs=[grid_display, thinking_display, status_display])
    up_btn.click(lambda: manual_move("up"), outputs=[grid_display, thinking_display, status_display])
    down_btn.click(lambda: manual_move("down"), outputs=[grid_display, thinking_display, status_display])
    left_btn.click(lambda: manual_move("left"), outputs=[grid_display, thinking_display, status_display])
    right_btn.click(lambda: manual_move("right"), outputs=[grid_display, thinking_display, status_display])

    # Initialize the display on page load
    demo.load(reset_game, outputs=[grid_display, thinking_display, status_display])


if __name__ == "__main__":
    demo.launch()