Spaces:
Configuration error
Configuration error
| """ | |
| HOLD Session - Arcade-Style Inference Interception | |
| ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| "Pause the machine. See what it sees. Choose what it chooses." | |
| The arcade layer of HOLD: | |
| - CausationHold: Session management with history | |
| - InferenceStep: Single crystallized moment | |
| - Time travel via state snapshots | |
| - Speed controls and combo tracking | |
| Controls: | |
| SPACE - Accept model's choice, advance | |
| 1-9 - Override with alternative | |
| β/β - Step back/forward through history | |
| +/- - Speed up/slow down auto-advance | |
| P - Pause/unpause auto-advance | |
| ESC - Exit hold mode | |
| """ | |
| import numpy as np | |
| import time | |
| import json | |
| import hashlib | |
| import threading | |
| from dataclasses import dataclass, field | |
| from typing import Dict, List, Optional, Any, Callable, Tuple | |
| from datetime import datetime | |
| from pathlib import Path | |
| from enum import Enum | |
class SessionState(Enum):
    """Current state of the hold session.

    Drives how `capture()` / `_wait_for_input()` behave: PAUSED waits for
    explicit user input, STEPPING auto-advances on a timer, REWINDING is
    set while navigating backwards through history.
    """
    IDLE = "idle"            # Not holding anything; no active session
    PAUSED = "paused"        # Frozen, waiting for input
    STEPPING = "stepping"    # Auto-advancing at the configured speed level
    REWINDING = "rewinding"  # Going backwards through history
@dataclass
class InferenceStep:
    """A single crystallized moment of inference.

    One checkpoint in a hold session: what the model saw, what it wanted
    to do, what was actually chosen, and (optionally) enough internal
    state to restore execution back to this decision point.

    NOTE: the @dataclass decorator is required — the class relies on
    `field(...)` defaults and is constructed with keyword arguments in
    `CausationHold.capture()`.
    """
    step_id: str
    step_index: int
    timestamp: float
    # What the model sees
    input_context: Dict[str, Any]
    # What the model wants to do
    candidates: List[Dict[str, Any]]  # [{value, probability, metadata}]
    top_choice: Any
    top_probability: float
    # Internal state snapshot (for true rewind)
    hidden_state: Optional[np.ndarray] = None
    attention_weights: Optional[Dict[str, float]] = None
    # What actually happened
    chosen_value: Any = None
    was_override: bool = False
    override_by: str = "model"  # "model" or "human"
    # Provenance (merkle-style chained hash, see _compute_step_hash)
    cascade_hash: Optional[str] = None
    # Private: full state snapshot for true rewind; excluded from repr
    # because it may be large (model tensors, caches, etc.)
    _state_snapshot: Optional[Dict[str, Any]] = field(default=None, repr=False)
@dataclass
class HoldSession:
    """A complete hold session with history.

    Holds the ordered list of InferenceSteps, the navigation cursor,
    arcade-style stats, and the auto-advance speed configuration.

    NOTE: the @dataclass decorator is required — the class relies on
    `field(default_factory=...)` and is constructed with keyword
    arguments in `CausationHold.begin_session()`.
    """
    session_id: str
    agent_id: str
    started_at: float
    # All steps in order; current_index is the navigation cursor
    steps: List[InferenceStep] = field(default_factory=list)
    current_index: int = 0
    # Arcade stats
    total_steps: int = 0
    human_overrides: int = 0
    correct_predictions: int = 0  # Human guessed what model would do
    combo: int = 0
    max_combo: int = 0
    # Speed control (level 0 = manual only; see speed_map for steps/sec)
    speed_level: int = 0  # 0=manual, 1=slow, 2=medium, 3=fast, 4=ludicrous
    speed_map: Dict[int, float] = field(default_factory=lambda: {
        0: 0.0,   # Manual
        1: 0.5,   # 2 sec per step
        2: 1.0,   # 1 sec per step
        3: 2.0,   # 0.5 sec per step
        4: 10.0,  # 0.1 sec per step (ludicrous speed)
    })
    # State
    state: SessionState = SessionState.IDLE
@dataclass
class ArcadeFeedback:
    """Visual/audio feedback cues returned from each captured step.

    NOTE: the @dataclass decorator is required — instances are created
    positionally (e.g. `ArcadeFeedback("", 0, "")`) in CausationHold.
    """
    message: str
    intensity: float            # 0-1, for glow/shake/etc
    sound_cue: str              # "accept", "override", "combo", "combo_break", "rewind"
    color: Tuple[int, int, int] = (255, 255, 255)
class CausationHold:
    """
    The arcade-layer hold system. Wraps any inference function.

    Features:
    - Session management with full history
    - True state restoration for time travel
    - Speed controls (manual to ludicrous)
    - Combo tracking and high scores

    Usage:
        hold = CausationHold()

        # Start a session
        hold.begin_session(agent_id="agent_123")

        # In inference loop:
        for step in inference_steps:
            choice, feedback = hold.capture(
                input_context={"tokens": tokens},
                candidates=[{"value": "A", "probability": 0.8}, ...]
            )  # Pauses here until user input!

        # Time travel
        hold.rewind(steps=3)
        hold.branch_from(step_index=5, choice_index=2)

        stats = hold.end_session()
    """

    def __init__(self, cascade_bus=None):
        """
        Args:
            cascade_bus: Optional CASCADE event bus for provenance.
                Must expose .emit(event_type, payload) if provided.
        """
        self.bus = cascade_bus
        self.session: Optional[HoldSession] = None
        self.callbacks: Dict[str, List[Callable]] = {
            'on_step': [],
            'on_override': [],
            'on_combo': [],
            'on_combo_break': [],
            'on_rewind': [],
            'on_state_restore': [],
        }
        # Thread safety: _lock guards session history mutation;
        # _input_event is the handshake between the UI thread (input/accept)
        # and the inference thread blocked in _wait_for_input.
        self._lock = threading.Lock()
        self._input_event = threading.Event()
        self._user_choice: Optional[Any] = None
        # High scores (persisted as JSON on disk)
        self.high_scores_path = Path("data/hold_high_scores.json")
        self.high_scores = self._load_high_scores()

    # ========================================================================
    # SESSION MANAGEMENT
    # ========================================================================
    def begin_session(self, agent_id: str) -> HoldSession:
        """Start a new hold session and leave it in the PAUSED state."""
        session_id = f"hold_{agent_id}_{int(time.time()*1000)}"
        self.session = HoldSession(
            session_id=session_id,
            agent_id=agent_id,
            started_at=time.time(),
        )
        self.session.state = SessionState.PAUSED
        self._emit_cascade("hold_session_start", {
            "session_id": session_id,
            "agent_id": agent_id,
        })
        return self.session

    def end_session(self) -> Dict[str, Any]:
        """End the current session and return its stats.

        Returns an empty dict if no session is active. Side effects:
        updates/persists high scores and clears self.session.
        """
        if not self.session:
            return {}
        stats = {
            "session_id": self.session.session_id,
            "agent_id": self.session.agent_id,
            "duration": time.time() - self.session.started_at,
            "total_steps": self.session.total_steps,
            "human_overrides": self.session.human_overrides,
            "correct_predictions": self.session.correct_predictions,
            "max_combo": self.session.max_combo,
            "accuracy": (
                self.session.correct_predictions / max(1, self.session.total_steps)
            ),
        }
        # Check for high score (also bumps the persisted session counter)
        self._check_high_score(stats)
        self._emit_cascade("hold_session_end", stats)
        self.session = None
        return stats

    # ========================================================================
    # CAPTURE & ADVANCE - WITH STATE SNAPSHOT FOR TRUE REWIND
    # ========================================================================
    def capture(
        self,
        input_context: Dict[str, Any],
        candidates: List[Dict[str, Any]],
        hidden_state: Optional[np.ndarray] = None,
        attention: Optional[Dict[str, float]] = None,
        state_snapshot: Optional[Dict[str, Any]] = None,
    ) -> Tuple[Any, ArcadeFeedback]:
        """
        Capture an inference step. BLOCKS until user input or auto-advance.

        IMPORTANT: Pass state_snapshot for true rewind capability.
        This should be a complete snapshot of the model's internal state
        that can be restored to allow execution from this decision point
        with a different choice.

        This is NOT prediction - you will ACTUALLY execute the choice and
        see REAL outcomes. If you don't like them, rewind and try again.

        Args:
            input_context: What the model is looking at
            candidates: List of {value, probability, ...} options
            hidden_state: Optional internal state snapshot (deprecated, use state_snapshot)
            attention: Optional attention weights
            state_snapshot: Complete model state for TRUE rewind capability

        Returns:
            (chosen_value, feedback) - The value to use and arcade feedback

        Raises:
            IndexError: if `candidates` is empty.
        """
        # Sort candidates by probability first, so the no-session
        # passthrough also returns the true top choice (BUGFIX: the old
        # code returned candidates[0] unsorted when no session was active).
        candidates = sorted(candidates, key=lambda x: x.get('probability', 0), reverse=True)
        if not self.session:
            # No session = passthrough, just return top choice
            return candidates[0]['value'], ArcadeFeedback("", 0, "")

        top = candidates[0]

        # Merge hidden_state into state_snapshot if provided separately
        if state_snapshot is None and hidden_state is not None:
            state_snapshot = {'hidden_state': hidden_state}
        elif state_snapshot is not None and hidden_state is not None:
            state_snapshot['hidden_state'] = hidden_state

        # Create step - this is a CHECKPOINT for true rewind
        step = InferenceStep(
            step_id=f"step_{self.session.total_steps}",
            step_index=self.session.total_steps,
            timestamp=time.time(),
            input_context=input_context,
            candidates=candidates,
            top_choice=top['value'],
            top_probability=top.get('probability', 1.0),
            hidden_state=hidden_state,
            attention_weights=attention,
        )
        # Store state snapshot for TRUE rewind (not just history navigation)
        if state_snapshot is not None:
            step._state_snapshot = state_snapshot

        # Compute merkle hash for provenance (must happen BEFORE the step
        # is appended: the parent hash is read from steps[-1])
        step.cascade_hash = self._compute_step_hash(step)

        # Add to history
        with self._lock:
            self.session.steps.append(step)
            self.session.current_index = len(self.session.steps) - 1
            self.session.total_steps += 1

        # Emit step event
        self._emit_callback('on_step', step)
        self._emit_cascade("hold_step", {
            "step_index": step.step_index,
            "top_choice": str(top['value']),
            "top_prob": top.get('probability', 1.0),
            "num_candidates": len(candidates),
            "has_snapshot": state_snapshot is not None,
            "merkle": step.cascade_hash,
        })

        # Wait for input (blocks; may auto-accept in timed modes)
        choice, feedback = self._wait_for_input(step)

        # Record what happened
        step.chosen_value = choice
        step.was_override = (choice != top['value'])
        step.override_by = "human" if step.was_override else "model"
        if step.was_override:
            self.session.human_overrides += 1
            self._emit_callback('on_override', step, choice)

        return choice, feedback

    def _wait_for_input(self, step: InferenceStep) -> Tuple[Any, ArcadeFeedback]:
        """Wait for user input or the auto-advance timer, then build feedback."""
        if self.session.speed_level == 0:
            # Manual mode = wait indefinitely for input()
            self._input_event.clear()
            self._input_event.wait()
            choice = self._user_choice
            self._user_choice = None
        else:
            # Auto-advance mode: wait up to 1/speed seconds for input,
            # otherwise auto-accept the model's top choice.
            speed = self.session.speed_map[self.session.speed_level]
            wait_time = 1.0 / speed if speed > 0 else float('inf')
            self._input_event.clear()
            got_input = self._input_event.wait(timeout=wait_time)
            if got_input and self._user_choice is not None:
                choice = self._user_choice
                self._user_choice = None
            else:
                # Auto-accepted
                choice = step.top_choice
        return choice, self._generate_feedback(step, choice)

    def input(self, choice: Any):
        """
        Provide user input. Call from UI thread.

        Args:
            choice: The value to use, or an int index into the current
                step's candidates (1-9 keys map to indices).
        """
        if not self.session:
            return
        current_step = self.session.steps[self.session.current_index]
        # Handle index input (1-9 keys). NOTE(review): an int choice that
        # is a valid index is always treated as an index, so candidate
        # VALUES that are small ints cannot be passed directly here.
        if isinstance(choice, int) and 0 <= choice < len(current_step.candidates):
            choice = current_step.candidates[choice]['value']
        self._user_choice = choice
        self._input_event.set()

    def accept(self):
        """Accept model's top choice (SPACE key)."""
        if not self.session or not self.session.steps:
            return
        current = self.session.steps[self.session.current_index]
        self.input(current.top_choice)

    def override(self, index: int):
        """Override with candidate at index (1-9 keys)."""
        self.input(index)

    # ========================================================================
    # NAVIGATION (TIME TRAVEL) - TRUE STATE RESTORATION
    # ========================================================================
    def rewind(self, steps: int = 1, restore_state: bool = True) -> Optional[InferenceStep]:
        """
        Go back in history with optional state restoration.

        This is NOT simulation - we actually restore the model's internal state
        to the snapshot taken at that decision point. From there, you can
        execute a different branch and see REAL outcomes.

        Args:
            steps: Number of steps to go back
            restore_state: If True, actually restore model state via the
                step's hidden_state or full _state_snapshot

        Returns:
            The step we rewound to, or None if no session / already at start.
        """
        if not self.session:
            return None
        with self._lock:
            new_index = max(0, self.session.current_index - steps)
            if new_index != self.session.current_index:
                self.session.current_index = new_index
                self.session.state = SessionState.REWINDING
                step = self.session.steps[new_index]
                # TRUE STATE RESTORATION (BUGFIX: also trigger when only a
                # full _state_snapshot is available — the preferred API —
                # not just when the deprecated hidden_state was passed)
                if restore_state and (
                    step.hidden_state is not None
                    or step._state_snapshot is not None
                ):
                    self._restore_state(step)
                self._emit_callback('on_rewind', step, -steps)
                return step
        return None

    def _restore_state(self, step: InferenceStep):
        """
        Restore model state from a snapshot.

        This is the key that makes execution + rewind possible. The model's
        internal state is set back to exactly what it was at this decision
        point, allowing you to branch differently. Restoration itself is
        delegated to 'on_state_restore' listeners — hooked components
        restore themselves from the step's snapshot.
        """
        if step.hidden_state is None and step._state_snapshot is None:
            return
        # Emit state restoration event - hooked components can restore themselves
        self._emit_callback('on_state_restore', step)
        self._emit_cascade("state_restored", {
            "step_index": step.step_index,
            "merkle": step.cascade_hash,
            "had_hidden_state": step.hidden_state is not None,
            "had_snapshot": step._state_snapshot is not None,
        })

    def branch_from(self, step_index: int, choice_index: int) -> Optional[InferenceStep]:
        """
        Rewind to a step and immediately choose a different branch.

        This is the core gameplay loop:
        1. Rewind to decision point
        2. Choose different option
        3. Execute and see what happens
        4. Repeat until satisfied

        Args:
            step_index: Which decision point to branch from
            choice_index: Which candidate to choose (0 = model's choice);
                out-of-range indices fall back to accepting the top choice

        Returns:
            The step after branching (with state restored), or None.
        """
        step = self.jump_to(step_index)
        if step is None:
            return None
        # Restore state
        self._restore_state(step)
        # Set up the override
        if choice_index < len(step.candidates):
            self.override(choice_index)
        else:
            self.accept()
        return step

    def forward(self, steps: int = 1) -> Optional[InferenceStep]:
        """Go forward in history (if we've rewound).

        Emits 'on_rewind' with a positive delta to signal forward motion.
        """
        if not self.session:
            return None
        with self._lock:
            max_index = len(self.session.steps) - 1
            new_index = min(max_index, self.session.current_index + steps)
            if new_index != self.session.current_index:
                self.session.current_index = new_index
                step = self.session.steps[new_index]
                self._emit_callback('on_rewind', step, steps)
                return step
        return None

    def jump_to(self, index: int) -> Optional[InferenceStep]:
        """Jump to a specific step (index is clamped to valid range)."""
        if not self.session:
            return None
        with self._lock:
            index = max(0, min(index, len(self.session.steps) - 1))
            self.session.current_index = index
            return self.session.steps[index]

    # ========================================================================
    # SPEED CONTROL
    # ========================================================================
    def speed_up(self):
        """Increase auto-advance speed (caps at level 4)."""
        if self.session:
            self.session.speed_level = min(4, self.session.speed_level + 1)

    def speed_down(self):
        """Decrease auto-advance speed (floors at 0 = manual)."""
        if self.session:
            self.session.speed_level = max(0, self.session.speed_level - 1)

    def set_speed(self, level: int):
        """Set speed level directly (clamped to 0-4)."""
        if self.session:
            self.session.speed_level = max(0, min(4, level))

    def pause(self):
        """Pause auto-advance."""
        if self.session:
            self.session.state = SessionState.PAUSED

    def unpause(self):
        """Resume auto-advance."""
        if self.session:
            self.session.state = SessionState.STEPPING

    # ========================================================================
    # PROVENANCE HASHING
    # ========================================================================
    def _compute_step_hash(self, step: InferenceStep) -> str:
        """
        Compute merkle hash for a step.

        This hash uniquely identifies this decision point and allows
        verification that rewind is restoring to the exact right state.
        Must be called BEFORE the step is appended to history: the parent
        hash is taken from the current last step in session.steps.
        """
        # Include parent hash for chain integrity
        parent_hash = ""
        if self.session and len(self.session.steps) > 0:
            prev_step = self.session.steps[-1]
            parent_hash = prev_step.cascade_hash or ""
        content = json.dumps({
            'step_index': step.step_index,
            'timestamp': step.timestamp,
            'top_choice': str(step.top_choice),
            'top_prob': step.top_probability,
            'num_candidates': len(step.candidates),
            'parent_hash': parent_hash,
        }, sort_keys=True)
        return hashlib.sha256(content.encode()).hexdigest()[:16]

    # ========================================================================
    # ARCADE FEEDBACK
    # ========================================================================
    def _generate_feedback(self, step: InferenceStep, choice: Any) -> ArcadeFeedback:
        """Generate arcade-style feedback for a step.

        Overrides break the combo; accepting the model's choice extends it,
        with special feedback at milestone combo counts.
        """
        is_override = (choice != step.top_choice)
        if is_override:
            # Combo break!
            if self.session.combo > 0:
                self._emit_callback('on_combo_break', self.session.combo)
            self.session.combo = 0
            return ArcadeFeedback(
                message="OVERRIDE",
                intensity=0.8,
                sound_cue="override",
                color=(255, 165, 0),  # Orange
            )
        # Accepted model choice
        self.session.combo += 1
        self.session.max_combo = max(self.session.max_combo, self.session.combo)
        # Combo milestones
        if self.session.combo in [10, 25, 50, 100]:
            self._emit_callback('on_combo', self.session.combo)
            return ArcadeFeedback(
                message=f"COMBO x{self.session.combo}!",
                intensity=1.0,
                sound_cue="combo",
                color=(0, 255, 255),  # Cyan
            )
        # Regular accept: glow intensity grows slightly with the combo
        return ArcadeFeedback(
            message="",
            intensity=0.3 + min(0.5, self.session.combo * 0.02),
            sound_cue="accept",
            color=(0, 255, 0),  # Green
        )

    # ========================================================================
    # CALLBACKS
    # ========================================================================
    def on(self, event: str, callback: Callable):
        """Register a callback for one of the known events (see __init__)."""
        if event in self.callbacks:
            self.callbacks[event].append(callback)

    def _emit_callback(self, event: str, *args):
        """Emit event to callbacks. Listener errors are printed, not raised."""
        for cb in self.callbacks.get(event, []):
            try:
                cb(*args)
            except Exception as e:
                print(f"Callback error: {e}")

    # ========================================================================
    # CASCADE PROVENANCE
    # ========================================================================
    def _emit_cascade(self, event_type: str, data: Dict[str, Any]):
        """Emit event to CASCADE bus if available (best-effort, never raises)."""
        if self.bus:
            try:
                self.bus.emit(event_type, {
                    **data,
                    "source": "causation_hold",
                    "timestamp": time.time(),
                })
            except Exception:
                pass

    # ========================================================================
    # HIGH SCORES
    # ========================================================================
    def _load_high_scores(self) -> Dict[str, Any]:
        """Load high scores from disk, falling back to defaults on any error."""
        if self.high_scores_path.exists():
            try:
                return json.loads(self.high_scores_path.read_text())
            except Exception:
                pass
        return {"max_combo": 0, "best_accuracy": 0.0, "total_sessions": 0}

    def _save_high_scores(self):
        """Save high scores to disk, creating the parent directory if needed."""
        self.high_scores_path.parent.mkdir(parents=True, exist_ok=True)
        self.high_scores_path.write_text(json.dumps(self.high_scores, indent=2))

    def _check_high_score(self, stats: Dict[str, Any]):
        """Check and update high scores; always persists.

        BUGFIX: the old code only saved when a record was beaten, so the
        total_sessions counter increment was silently lost between runs.
        """
        if stats['max_combo'] > self.high_scores['max_combo']:
            self.high_scores['max_combo'] = stats['max_combo']
        if stats['accuracy'] > self.high_scores['best_accuracy']:
            self.high_scores['best_accuracy'] = stats['accuracy']
        self.high_scores['total_sessions'] += 1
        # total_sessions always changed, so always persist
        self._save_high_scores()

    # ========================================================================
    # DECORATOR FOR EASY WRAPPING
    # ========================================================================
    def intercept(self, granularity: str = "step"):
        """
        Decorator to intercept a function's inference.

        Args:
            granularity: "step" (each call) or "token" (if function yields)
        """
        import functools  # local import keeps this block self-contained

        def decorator(func):
            @functools.wraps(func)  # preserve name/docstring of wrapped func
            def wrapper(*args, **kwargs):
                # If no session, passthrough
                if not self.session:
                    return func(*args, **kwargs)
                # Capture the input (truncated string repr for display only)
                input_context = {
                    "args": str(args)[:200],
                    "kwargs": {k: str(v)[:100] for k, v in kwargs.items()},
                }
                # Get result
                result = func(*args, **kwargs)
                # Create candidates from result
                if isinstance(result, np.ndarray):
                    # For embeddings, show top dimensions by magnitude
                    top_dims = np.argsort(np.abs(result.flatten()))[-5:][::-1]
                    candidates = [
                        {"value": f"dim_{d}", "probability": float(np.abs(result.flatten()[d]))}
                        for d in top_dims
                    ]
                else:
                    candidates = [{"value": result, "probability": 1.0}]
                # Capture (may block).
                # NOTE(review): the user's choice is intentionally discarded
                # here and the original result is always returned — the hold
                # is observational at this granularity. Confirm intended.
                choice, feedback = self.capture(input_context, candidates)
                return result
            return wrapper
        return decorator