| """ |
| ARC-AGI-3 Agent: Full interactive agent combining JEPA, RSSM, and planning. |
| |
| Core loop: |
| 1. Encode observation via Grid-JEPA |
| 2. Update RSSM world model with (obs, action) history |
| 3. Use imagination rollouts to evaluate candidate actions |
| 4. Execute best action in environment |
| 5. Persist RSSM state across levels within an environment |
| 6. TTT LoRA fine-tune on collected demos |
| 7. Goal-inference from state transitions |
| """ |
|
|
| import random |
| from typing import List, Tuple, Optional, Dict |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from encoder import GridPatchEmbed, ViTEncoder |
| from predictor import DiscreteActionEmbed, ActionConditionedPredictor |
| from grid_jepa import GridJEPA |
| from rssm import RSSM |
|
|
|
|
class GoalInferenceModule(nn.Module):
    """Goal detector with a memory of terminal observations.

    Two small MLPs are stacked: ``goal_encoder`` maps an observation
    representation to a goal embedding, and ``goal_classifier`` maps that
    embedding to a single "is this a goal state?" logit.  Independently of
    the networks, representations passed to :meth:`register_terminal` are
    remembered so that :meth:`get_goal_target` can return their mean.
    """

    def __init__(self, obs_dim: int, hidden_dim: int = 128):
        """Build both MLPs.

        Args:
            obs_dim: Size of the last dimension of the incoming representation.
            hidden_dim: Width of the goal embedding.
        """
        super().__init__()
        self.obs_dim = obs_dim
        self.hidden_dim = hidden_dim

        # Observation representation -> goal embedding.
        self.goal_encoder = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        # Goal embedding -> one unnormalised logit.
        self.goal_classifier = nn.Sequential(
            nn.Linear(hidden_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )

        # Plain Python list (not a buffer): it is not part of state_dict().
        self.observed_goals: List[torch.Tensor] = []

    def forward(self, obs_repr: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(goal_embedding, is_goal_logit)`` for ``obs_repr``."""
        embedding = self.goal_encoder(obs_repr)
        return embedding, self.goal_classifier(embedding)

    def register_terminal(self, obs_repr: torch.Tensor):
        """Remember a detached CPU copy of a terminal-state representation."""
        snapshot = obs_repr.detach().cpu()
        self.observed_goals.append(snapshot)

    def get_goal_target(self) -> Optional[torch.Tensor]:
        """Return the mean of all remembered terminal representations.

        Returns:
            ``None`` when nothing has been registered yet, otherwise the
            element-wise mean over the registered tensors (on CPU).
        """
        if not self.observed_goals:
            return None
        stacked = torch.stack(self.observed_goals)
        return stacked.mean(dim=0)
|
|
|
|
class UncertaintyTracker:
    """Sliding record of world-model prediction errors.

    Each call to :meth:`record_prediction_error` stores the L2 norm of
    ``predicted - actual``.  When the most recent ``revision_threshold``
    errors all exceed ``error_threshold``, :meth:`should_revise_hypothesis`
    reports that the current hypothesis should be revised.

    The retained history is ``max(window_size, revision_threshold)`` entries
    long.  Capping it at ``window_size`` alone would make revision impossible
    whenever ``revision_threshold > window_size``, because the history could
    never grow long enough to satisfy the threshold.
    """

    def __init__(
        self,
        window_size: int = 5,
        error_threshold: float = 2.0,
        revision_threshold: int = 3,
    ):
        """
        Args:
            window_size: Number of latest errors averaged into ``"recent"``.
            error_threshold: An error strictly above this counts as "high".
            revision_threshold: Number of consecutive latest errors that must
                be high before a revision is suggested.
        """
        self.window_size = window_size
        self.error_threshold = error_threshold
        self.revision_threshold = revision_threshold
        self.prediction_errors: List[float] = []
        self.revision_count: int = 0
        self.last_revision_step: int = 0

    def _history_limit(self) -> int:
        """Number of errors to retain so both consumers have enough data."""
        return max(self.window_size, self.revision_threshold)

    def record_prediction_error(self, predicted_obs: torch.Tensor, actual_obs: torch.Tensor):
        """Append ``||predicted_obs - actual_obs||`` and trim old entries."""
        error = torch.norm(predicted_obs - actual_obs).item()
        self.prediction_errors.append(error)
        overflow = len(self.prediction_errors) - self._history_limit()
        if overflow > 0:
            # Drop the oldest entries in one slice delete.
            del self.prediction_errors[:overflow]

    def should_revise_hypothesis(self) -> bool:
        """Return True when the last ``revision_threshold`` errors are all high."""
        needed = self.revision_threshold
        if len(self.prediction_errors) < needed:
            return False
        latest = self.prediction_errors[-needed:]
        high_error_count = sum(1 for err in latest if err > self.error_threshold)
        return high_error_count >= needed

    def get_error_stats(self) -> Dict[str, float]:
        """Summarise the retained errors.

        Returns:
            Dict with ``"mean"`` and ``"max"`` over the retained history,
            ``"recent"`` (mean over the last ``window_size`` errors) and
            ``"revision_count"``.  All error figures are ``0.0`` when no
            error has been recorded.
        """
        if not self.prediction_errors:
            return {"mean": 0.0, "max": 0.0, "recent": 0.0, "revision_count": self.revision_count}
        recent = self.prediction_errors[-self.window_size:]
        return {
            "mean": sum(self.prediction_errors) / len(self.prediction_errors),
            "max": max(self.prediction_errors),
            "recent": sum(recent) / len(recent),
            "revision_count": self.revision_count,
        }

    def mark_revision(self, step: int):
        """Record that a revision happened at ``step`` and forget old errors."""
        self.revision_count += 1
        self.last_revision_step = step
        self.prediction_errors.clear()
|
|
|
|
class ExplorationPolicy:
    """Random exploration biased towards non-background cells.

    An action is a pair ``(action_key, action_pos)`` where ``action_pos`` is
    a flat cell index ``row * grid_size + col``.  The policy also records
    which grids it has been asked about and the actions it has returned.
    """

    def __init__(self, num_actions: int, grid_size: int = 64):
        """
        Args:
            num_actions: Number of distinct action keys (keys are
                ``0 .. num_actions - 1``).
            grid_size: Row stride used to flatten ``(row, col)`` positions;
                ``grid_size ** 2`` positions are addressable.
        """
        self.num_actions = num_actions
        self.grid_size = grid_size
        self.num_positions = grid_size * grid_size
        self.visited_states: set = set()
        self.action_history: List[Tuple[int, int]] = []

    def hash_state(self, grid: torch.Tensor) -> int:
        """Hash a grid by shape and raw contents.

        The shape is part of the key so that grids with identical bytes but
        different layouts (e.g. 2x8 vs 4x4) are not treated as one state.
        The value is only meaningful within the current process.
        """
        cpu_grid = grid.detach().cpu()
        return hash((tuple(cpu_grid.shape), cpu_grid.numpy().tobytes()))

    def select_action(self, grid: torch.Tensor, novelty_bonus: bool = True, avoid_undo: bool = True) -> Tuple[int, int]:
        """Sample an action for ``grid`` and append it to ``action_history``.

        With probability 0.7 (when any non-zero cell exists) the position is
        a uniformly chosen non-zero cell; otherwise it is uniform over all
        ``num_positions`` cells.  The key is always uniform.

        Args:
            grid: 2-D grid tensor; cells equal to 0 are treated as background.
            novelty_bonus: Currently unused; accepted for interface
                compatibility.
            avoid_undo: When True, an action identical to the immediately
                preceding one has its key shifted by one (mod ``num_actions``).

        Returns:
            ``(action_key, action_pos)`` as plain Python ints.
        """
        self.visited_states.add(self.hash_state(grid))

        action_key = random.randint(0, self.num_actions - 1)

        # Row-major list of [row, col] pairs, already converted to Python
        # ints so no numpy scalar types leak into the returned action.
        non_bg = torch.nonzero(grid != 0).tolist()
        if non_bg and random.random() < 0.7:
            r, c = random.choice(non_bg)
            action_pos = r * self.grid_size + c
        else:
            action_pos = random.randint(0, self.num_positions - 1)

        if avoid_undo and self.action_history:
            if (action_key, action_pos) == self.action_history[-1]:
                action_key = (action_key + 1) % self.num_actions

        self.action_history.append((action_key, action_pos))
        return action_key, action_pos

    def reset(self):
        """Forget visited states and the action history."""
        self.visited_states.clear()
        self.action_history.clear()
|
|
|
|
class PlanningModule:
    """Random-shooting planner over RSSM imagination rollouts.

    ``num_candidates`` random action sequences of length ``horizon`` are
    rolled out with ``rssm.imagine``; the first action of the best-scoring
    sequence is returned.
    """

    def __init__(
        self,
        rssm: RSSM,
        goal_module: GoalInferenceModule,
        jepa_encoder: GridJEPA,
        horizon: int = 10,
        num_candidates: int = 16,
    ):
        """
        Args:
            rssm: World model providing ``imagine`` and ``predict_continue``.
            goal_module: Source of the goal target (``get_goal_target``).
            jepa_encoder: Stored for callers; not used during planning.
            horizon: Number of imagined steps per candidate.
            num_candidates: Number of random action sequences evaluated.
        """
        self.rssm = rssm
        self.goal_module = goal_module
        self.jepa_encoder = jepa_encoder
        self.horizon = horizon
        self.num_candidates = num_candidates

    @torch.no_grad()  # planning is pure inference; never build autograd graphs
    def plan_action(
        self, grid: torch.Tensor, h_state: torch.Tensor, z_state: torch.Tensor,
        num_actions: int, device: torch.device,
    ) -> Tuple[int, int]:
        """Pick the first action of the best imagined rollout.

        Scoring per imagined step:
          * with a goal target: negative L2 distance between the imagined
            ``z`` state and the target;
          * without one: negative continue-probability from
            ``rssm.predict_continue``.
        A candidate's score is the mean over its steps; ties keep the
        earliest candidate.

        Args:
            grid: Current grid; only its last two dimensions are read, to
                size the position space.
            h_state: Current RSSM ``h`` state (cloned per candidate).
            z_state: Current RSSM ``z`` state (cloned per candidate).
            num_actions: Number of action keys.
            device: Device on which candidate actions are sampled.

        Returns:
            ``(action_key, action_pos)`` decoded from the flat action index
            as ``index // num_positions`` and ``index % num_positions``.
        """
        # The grid is not encoded here: rollouts start from the RSSM state
        # supplied by the caller, so an encoder pass would be wasted work.
        goal_target = self.goal_module.get_goal_target()
        if goal_target is not None:
            # Move once, not on every imagined step.
            # NOTE(review): the target is averaged from registered observation
            # representations while z_roll is an RSSM state; this subtraction
            # assumes both broadcast to a common shape - TODO confirm.
            goal_target = goal_target.to(device)

        # NOTE(review): positions are sized from the actual grid here; callers
        # that flatten actions with a fixed grid_size must use matching
        # dimensions - TODO confirm.
        num_positions = grid.shape[-1] * grid.shape[-2]
        total_actions = num_actions * num_positions

        candidate_actions = torch.randint(
            0, total_actions, (self.num_candidates, self.horizon), device=device
        )

        best_score = float("-inf")
        best_idx = 0
        for idx in range(self.num_candidates):
            sequence = candidate_actions[idx]
            h_roll, z_roll = h_state.clone(), z_state.clone()
            step_scores = []

            for t in range(self.horizon):
                action = sequence[t:t + 1]  # keep a length-1 batch dimension
                h_roll, z_roll, _ = self.rssm.imagine(h_roll, z_roll, action)
                if goal_target is not None:
                    step_scores.append(-torch.norm(z_roll - goal_target).item())
                else:
                    continue_logits = self.rssm.predict_continue(h_roll, z_roll)
                    step_scores.append(-torch.sigmoid(continue_logits).item())

            # Every rollout runs the full horizon, so the mean needs no
            # length correction.
            score = sum(step_scores) / len(step_scores) if step_scores else 0.0
            if score > best_score:
                best_score = score
                best_idx = idx

        best_action = candidate_actions[best_idx, 0].item()
        return best_action // num_positions, best_action % num_positions
|
|
|
|
class ARCAgent(nn.Module):
    """Interactive agent combining a JEPA encoder, an RSSM and a planner.

    The recurrent RSSM state (``persistent_h`` / ``persistent_z``) is kept
    across levels of one environment and only cleared by
    :meth:`reset_for_new_environment`.
    """

    def __init__(
        self,
        jepa: GridJEPA,
        rssm: RSSM,
        num_actions: int = 6,
        grid_size: int = 64,
        exploration_ratio: float = 0.3,
        device: str = "cuda",
    ):
        """
        Args:
            jepa: Encoder exposing ``embed_dim`` and ``encode(grid)``.
            rssm: World model exposing ``init_state``, ``observe``,
                ``imagine`` and ``predict_continue``.
            num_actions: Number of action keys.
            grid_size: Side length used to flatten ``(key, pos)`` into one
                action index (``key * grid_size**2 + pos``).
            exploration_ratio: Probability of exploring instead of planning
                outside the forced exploration phase.
            device: Requested device; falls back to CPU when CUDA is not
                available.
        """
        super().__init__()
        self.jepa = jepa
        self.rssm = rssm
        self.num_actions = num_actions
        self.grid_size = grid_size
        self.num_positions = grid_size * grid_size
        self.exploration_ratio = exploration_ratio
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")

        obs_dim = jepa.embed_dim
        self.goal_module = GoalInferenceModule(obs_dim)
        self.exploration = ExplorationPolicy(num_actions, grid_size)
        self.planning = PlanningModule(rssm, self.goal_module, jepa)
        self.uncertainty_tracker = UncertaintyTracker()

        self.persistent_h: Optional[torch.Tensor] = None
        self.persistent_z: Optional[torch.Tensor] = None
        self.demo_buffer: List[Dict] = []
        self.step_counter: int = 0
        # Last action returned by step(), whether explored or planned.  The
        # exploration history cannot serve this purpose: it never records
        # planned actions and is wiped on hypothesis revision.
        self.last_action: Optional[Tuple[int, int]] = None

    def reset_for_new_environment(self):
        """Reset ALL state when starting a completely new environment/game."""
        self.persistent_h = None
        self.persistent_z = None
        self.last_action = None
        self.exploration.reset()
        self.goal_module.observed_goals.clear()
        self.demo_buffer.clear()
        self.step_counter = 0
        self.uncertainty_tracker = UncertaintyTracker()

    def reset_for_new_level(self):
        """Reset level-specific state but PERSIST world model knowledge."""
        self.exploration.reset()
        # The first observation of a level is not the result of any action.
        self.last_action = None

    def encode_observation(self, grid: torch.Tensor) -> torch.Tensor:
        """Encode ``grid`` with the JEPA encoder."""
        return self.jepa.encode(grid)

    @torch.no_grad()  # acting is inference only; keeps the persistent state graph-free
    def step(
        self,
        grid: torch.Tensor,
        reward: Optional[float] = None,
        done: bool = False,
        is_exploration_phase: bool = False,
    ) -> Tuple[int, int]:
        """Consume one observation and choose the next action.

        Args:
            grid: Batched grid with a leading batch dimension of size 1
                (``grid[0]`` is handed to the exploration policy).
            reward: Currently unused; accepted for interface compatibility.
            done: When True, the pooled representation is registered as a
                terminal (goal) example.
            is_exploration_phase: Force exploration for this step.

        Returns:
            ``(action_key, action_pos)`` as plain ints.
        """
        grid = grid.to(self.device)
        # Mean-pool over dimension 1 of the encoder output.
        obs_repr_pooled = self.encode_observation(grid).mean(dim=1)

        if self.persistent_h is None:
            self.persistent_h, self.persistent_z = self.rssm.init_state(1, self.device)

        # Action that led to this observation; index 0 when there is none.
        if self.last_action is None:
            prev_action = torch.zeros(1, dtype=torch.long, device=self.device)
        else:
            last_key, last_pos = self.last_action
            prev_action = torch.tensor(
                [last_key * self.num_positions + last_pos],
                dtype=torch.long,
                device=self.device,
            )

        self.persistent_h, self.persistent_z, _, _ = self.rssm.observe(
            obs_repr_pooled, prev_action, self.persistent_h, self.persistent_z
        )

        if done:
            self.goal_module.register_terminal(obs_repr_pooled)

        # NOTE(review): nothing in this class calls
        # uncertainty_tracker.record_prediction_error, so this branch cannot
        # fire until a prediction-vs-observation error is recorded somewhere.
        if self.uncertainty_tracker.should_revise_hypothesis():
            self.exploration.reset()
            self.uncertainty_tracker.mark_revision(self.step_counter)

        if is_exploration_phase or random.random() < self.exploration_ratio:
            action_key, action_pos = self.exploration.select_action(grid[0])
        else:
            # NOTE(review): the planner sizes positions from the grid's own
            # shape; this matches num_positions only when the grid is
            # grid_size x grid_size - TODO confirm for smaller grids.
            action_key, action_pos = self.planning.plan_action(
                grid, self.persistent_h, self.persistent_z, self.num_actions, self.device
            )
        action_key, action_pos = int(action_key), int(action_pos)
        self.last_action = (action_key, action_pos)

        self.demo_buffer.append({
            "grid": grid[0].cpu().clone(),
            "action_key": action_key,
            "action_pos": action_pos,
            "obs_repr": obs_repr_pooled.detach().cpu().clone(),
            "h_state": self.persistent_h.detach().cpu().clone(),
            "z_state": self.persistent_z.detach().cpu().clone(),
        })

        self.step_counter += 1
        return action_key, action_pos

    def run_level(self, env, max_steps: int = 100, exploration_steps: int = 10) -> Dict:
        """Interact with ``env`` until it reports done or ``max_steps`` elapse.

        ``env`` must provide ``get_observation()``, ``get_reward()``,
        ``is_done()`` and ``step(action_key, action_pos)``.  The first
        ``exploration_steps`` steps are forced exploration.

        Returns:
            Dict with ``"trajectory"`` (one record per step), ``"num_steps"``
            and ``"success"`` (whether the environment reported done).
        """
        trajectory = []
        done = False  # defined even when max_steps <= 0
        for step_idx in range(max_steps):
            grid = env.get_observation().unsqueeze(0)
            reward, done = env.get_reward(), env.is_done()
            is_exploration = step_idx < exploration_steps
            action_key, action_pos = self.step(grid, reward, done, is_exploration)
            # As before, the chosen action is executed even on the step that
            # observed done=True.
            env.step(action_key, action_pos)
            trajectory.append({
                "step": step_idx,
                "action_key": action_key,
                "action_pos": action_pos,
                "reward": reward,
                "done": done,
            })
            if done:
                break

        if not done:
            # The effect of the final action has not been observed yet;
            # re-query so a level finished on the last step counts.
            done = bool(env.is_done())
        return {"trajectory": trajectory, "num_steps": len(trajectory), "success": done}
|
|
|
|
def create_agent(num_colors: int = 16, embed_dim: int = 384, grid_size: int = 64,
                 num_actions: int = 6, device: str = "cuda") -> ARCAgent:
    """Build a GridJEPA + RSSM pair and wrap them in an :class:`ARCAgent`.

    Args:
        num_colors: Number of grid colours passed to the JEPA encoder.
        embed_dim: Embedding width shared by the encoder and the RSSM.
        grid_size: Maximum grid side length; also sizes the flat action
            space (``num_actions * grid_size ** 2``).
        num_actions: Number of action keys.
        device: Requested device string.

    Returns:
        The agent, moved to the device it resolved for itself.
    """
    jepa = GridJEPA(
        num_colors=num_colors,
        embed_dim=embed_dim,
        encoder_depth=12,
        predictor_depth=12,
        num_heads=6,
        max_grid_size=grid_size,
    )
    rssm = RSSM(
        embed_dim=embed_dim,
        latent_dim=32,
        latent_classes=32,
        hidden_dim=256,
        action_dim=64,
        num_actions=num_actions * grid_size * grid_size,
        obs_dim=embed_dim,
    )
    agent = ARCAgent(jepa=jepa, rssm=rssm, num_actions=num_actions, grid_size=grid_size, device=device)
    # Move to agent.device rather than the raw `device` string: the agent
    # falls back to CPU when CUDA is unavailable, and moving the parameters
    # to "cuda" on such a host would raise.
    return agent.to(agent.device)
|
|
|
|
if __name__ == "__main__":
    # Smoke test: one manual step, one full level, then a second level to
    # check that the recurrent state survives a level reset.
    import numpy as np

    device = "cuda" if torch.cuda.is_available() else "cpu"
    agent = create_agent(num_colors=10, embed_dim=192, grid_size=10, num_actions=6, device=device)

    class MockEnv:
        """Tiny paint-a-cell environment that ends after ``max_steps`` steps."""

        def __init__(self, size=10):
            self.grid = torch.zeros(size, size, dtype=torch.long)
            self.grid[size // 2, size // 2] = 1
            self.step_count = 0
            self.max_steps = 20

        def get_observation(self):
            return self.grid

        def get_reward(self):
            return 0.0

        def is_done(self):
            return self.step_count >= self.max_steps

        def step(self, action_key, action_pos):
            height, width = self.grid.shape
            # Row-major decoding uses the row width as the stride.
            r, c = divmod(action_pos, width)
            if 0 <= r < height and 0 <= c < width:
                self.grid[r, c] = action_key
            self.step_count += 1

    env = MockEnv(size=10)
    grid = env.get_observation().unsqueeze(0).to(device)
    action_key, action_pos = agent.step(grid)
    print(f"Action: key={action_key}, pos={action_pos}")

    agent.reset_for_new_environment()
    result = agent.run_level(env, max_steps=15, exploration_steps=5)
    print(f"Level result: {result['num_steps']} steps, success={result['success']}")

    h_before = agent.persistent_h.clone() if agent.persistent_h is not None else None
    env2 = MockEnv(size=10)
    agent.reset_for_new_level()
    # Persistence means the level reset left the recurrent state untouched,
    # so compare the state right after the reset with the earlier snapshot.
    h_after_reset = agent.persistent_h
    state_persisted = (
        h_before is not None
        and h_after_reset is not None
        and torch.equal(h_before, h_after_reset)
    )
    result2 = agent.run_level(env2, max_steps=10, exploration_steps=3)
    h_after = agent.persistent_h.clone() if agent.persistent_h is not None else None

    if h_before is not None and h_after is not None:
        print(f"State persisted across levels: {state_persisted}")

    print("\nAgent tests passed!")
|
|