Cascade / cascade /hold /session.py
tostido's picture
Initial commit - cascade-lattice 0.5.4
77bcbf1
"""
HOLD Session - Arcade-Style Inference Interception
══════════════════════════════════════════════════════════
"Pause the machine. See what it sees. Choose what it chooses."
The arcade layer of HOLD:
- CausationHold: Session management with history
- InferenceStep: Single crystallized moment
- Time travel via state snapshots
- Speed controls and combo tracking
Controls:
SPACE - Accept model's choice, advance
1-9 - Override with alternative
←/→ - Step back/forward through history
+/- - Speed up/slow down auto-advance
P - Pause/unpause auto-advance
ESC - Exit hold mode
"""
import numpy as np
import time
import json
import hashlib
import threading
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Any, Callable, Tuple
from datetime import datetime
from pathlib import Path
from enum import Enum
class SessionState(Enum):
    """Current state of the hold session.

    Each member's value is a lowercase string label.
    """
    # Default state of a freshly constructed HoldSession (not holding anything).
    IDLE = "idle"
    # Frozen, waiting for input; set by begin_session() and pause().
    PAUSED = "paused"
    # Auto-advancing at the configured speed; set by unpause().
    STEPPING = "stepping"
    # Going backwards through history; set by rewind() when the index moves.
    REWINDING = "rewinding"
@dataclass
class InferenceStep:
    """A single crystallized moment of inference.

    One instance is created per capture() call and appended to the
    session history; the "what actually happened" fields are filled in
    after the user (or the auto-advance timer) has decided.
    """
    # Identity / ordering. capture() builds step_id as "step_<index>".
    step_id: str
    step_index: int
    timestamp: float  # time.time() at capture
    # What the model sees
    input_context: Dict[str, Any]
    # What the model wants to do
    # capture() stores the candidates sorted by 'probability', highest first.
    candidates: List[Dict[str, Any]]  # [{value, probability, metadata}]
    top_choice: Any  # 'value' of the highest-probability candidate
    top_probability: float  # capture() uses 1.0 when the candidate has no 'probability'
    # Internal state snapshot (for true rewind)
    hidden_state: Optional[np.ndarray] = None
    attention_weights: Optional[Dict[str, float]] = None
    # What actually happened (written by capture() once input has arrived)
    chosen_value: Any = None
    was_override: bool = False  # True when chosen_value != top_choice
    override_by: str = "model"  # "model" or "human"
    # Provenance: truncated SHA-256 hex digest chained to the previous step
    cascade_hash: Optional[str] = None
    # Private: full state snapshot for true rewind (excluded from repr)
    _state_snapshot: Optional[Dict[str, Any]] = field(default=None, repr=False)
@dataclass
class HoldSession:
    """A complete hold session with history.

    Holds the ordered step history, the cursor into it, the arcade
    counters and the auto-advance speed configuration.
    """
    session_id: str
    agent_id: str
    started_at: float  # time.time() when the session was opened
    # All steps in order
    steps: List[InferenceStep] = field(default_factory=list)
    # Cursor into `steps`; moved by capture(), rewind(), forward(), jump_to()
    current_index: int = 0
    # Arcade stats
    total_steps: int = 0
    human_overrides: int = 0
    # NOTE(review): nothing in the visible code increments this counter, so
    # the "accuracy" reported by end_session() stays 0 - confirm intent.
    correct_predictions: int = 0  # Human guessed what model would do
    combo: int = 0
    max_combo: int = 0
    # Speed control (steps per second, 0 = manual only)
    speed_level: int = 0  # 0=manual, 1=slow, 2=medium, 3=fast, 4=ludicrous
    # Maps speed_level -> steps per second; the auto-advance wait is 1/value.
    speed_map: Dict[int, float] = field(default_factory=lambda: {
        0: 0.0,  # Manual
        1: 0.5,  # 2 sec per step
        2: 1.0,  # 1 sec per step
        3: 2.0,  # 0.5 sec per step
        4: 10.0,  # 0.1 sec per step (ludicrous speed)
    })
    # State
    state: SessionState = SessionState.IDLE
@dataclass
class ArcadeFeedback:
    """Visual/audio feedback cues returned alongside each chosen value."""
    message: str  # Text to flash, e.g. "OVERRIDE" or "COMBO x10!"; may be empty
    intensity: float  # 0-1, for glow/shake/etc
    # NOTE(review): the visible code only emits "accept", "override" and
    # "combo" (plus "" in passthrough); the other cues are listed but unused.
    sound_cue: str  # "accept", "override", "combo", "combo_break", "rewind"
    color: Tuple[int, int, int] = (255, 255, 255)  # RGB, defaults to white
class CausationHold:
"""
The arcade-layer hold system. Wraps any inference function.
Features:
- Session management with full history
- True state restoration for time travel
- Speed controls (manual to ludicrous)
- Combo tracking and high scores
Usage:
hold = CausationHold()
# Start a session
hold.begin_session(agent_id="agent_123")
# In inference loop:
for step in inference_steps:
choice, feedback = hold.capture(
input_context={"tokens": tokens},
candidates=[{"value": "A", "probability": 0.8}, ...]
) # Pauses here until user input!
# Time travel
hold.rewind(steps=3)
hold.branch_from(step_index=5, choice_index=2)
stats = hold.end_session()
"""
def __init__(self, cascade_bus=None, high_scores_path=None):
    """
    Create a hold controller with no active session.

    Args:
        cascade_bus: Optional CASCADE event bus for provenance. Only its
            ``emit(event_type, payload)`` method is used.
        high_scores_path: Optional location of the high-score JSON file
            (str or Path). When omitted, the original default
            ``data/hold_high_scores.json`` (relative to the current
            working directory) is used.
    """
    self.bus = cascade_bus
    self.session: Optional[HoldSession] = None
    # Event name -> registered listeners; on() only accepts these names.
    self.callbacks: Dict[str, List[Callable]] = {
        'on_step': [],
        'on_override': [],
        'on_combo': [],
        'on_combo_break': [],
        'on_rewind': [],
        'on_state_restore': [],
    }
    # Thread safety
    self._lock = threading.Lock()
    # Set by the UI thread to wake a capture() that is waiting for input.
    self._input_event = threading.Event()
    self._user_choice: Optional[Any] = None
    # High scores (persisted)
    if high_scores_path is None:
        high_scores_path = "data/hold_high_scores.json"
    self.high_scores_path = Path(high_scores_path)
    self.high_scores = self._load_high_scores()
# ========================================================================
# SESSION MANAGEMENT
# ========================================================================
def begin_session(self, agent_id: str) -> HoldSession:
    """
    Open a new hold session and make it the active one.

    The session id is "hold_<agent_id>_<epoch milliseconds>". The session
    starts in the PAUSED state and a "hold_session_start" event is sent
    to the CASCADE bus (if one is attached).

    Args:
        agent_id: Identifier of the agent being held.
    Returns:
        The newly created session.
    """
    new_id = f"hold_{agent_id}_{int(time.time()*1000)}"
    opened = HoldSession(
        session_id=new_id,
        agent_id=agent_id,
        started_at=time.time(),
        state=SessionState.PAUSED,
    )
    self.session = opened
    announcement = {
        "session_id": new_id,
        "agent_id": agent_id,
    }
    self._emit_cascade("hold_session_start", announcement)
    return self.session
def end_session(self) -> Dict[str, Any]:
    """
    Close the active session and report its statistics.

    Returns:
        An empty dict when no session is active. Otherwise a dict with
        session_id, agent_id, duration (seconds), total_steps,
        human_overrides, correct_predictions, max_combo and accuracy
        (correct_predictions divided by total_steps, with a minimum
        divisor of 1).
    """
    active = self.session
    if not active:
        return {}
    steps_taken = active.total_steps
    hits = active.correct_predictions
    summary = {
        "session_id": active.session_id,
        "agent_id": active.agent_id,
        "duration": time.time() - active.started_at,
        "total_steps": steps_taken,
        "human_overrides": active.human_overrides,
        "correct_predictions": hits,
        "max_combo": active.max_combo,
        # max(1, ...) avoids dividing by zero for a session with no steps
        "accuracy": hits / max(1, steps_taken),
    }
    # High scores are updated before the end event goes out.
    self._check_high_score(summary)
    self._emit_cascade("hold_session_end", summary)
    self.session = None
    return summary
# ========================================================================
# CAPTURE & ADVANCE - WITH STATE SNAPSHOT FOR TRUE REWIND
# ========================================================================
def capture(
    self,
    input_context: Dict[str, Any],
    candidates: List[Dict[str, Any]],
    hidden_state: Optional[np.ndarray] = None,
    attention: Optional[Dict[str, float]] = None,
    state_snapshot: Optional[Dict[str, Any]] = None,
) -> Tuple[Any, ArcadeFeedback]:
    """
    Capture an inference step. BLOCKS until user input or auto-advance.

    IMPORTANT: Pass state_snapshot for true rewind capability.
    This should be a complete snapshot of the model's internal state
    that can be restored to allow execution from this decision point
    with a different choice.

    Without an active session this is a passthrough: the
    highest-probability candidate is returned immediately and nothing is
    recorded.

    Args:
        input_context: What the model is looking at
        candidates: Non-empty list of {value, probability, ...} options.
            A missing 'probability' counts as 0 for ordering.
        hidden_state: Optional internal state snapshot (deprecated, use state_snapshot)
        attention: Optional attention weights
        state_snapshot: Complete model state for TRUE rewind capability.
            The caller's dict is not modified.
    Returns:
        (chosen_value, feedback) - The value to use and arcade feedback
    Raises:
        IndexError: If candidates is empty.
    """
    if not candidates:
        # Same exception type the bare candidates[0] lookup used to raise,
        # but with a message that names the actual problem.
        raise IndexError("capture() requires at least one candidate")
    # Sort by probability, highest first. The sort is stable, so ties keep
    # the caller's order.
    candidates = sorted(candidates, key=lambda c: c.get('probability', 0), reverse=True)
    top = candidates[0]
    if not self.session:
        # No session = passthrough, just return top choice
        return top['value'], ArcadeFeedback("", 0, "")
    session = self.session
    # Merge hidden_state into the snapshot on a copy, so the dict the
    # caller handed in is left untouched.
    if hidden_state is not None:
        state_snapshot = {**(state_snapshot or {}), 'hidden_state': hidden_state}
    # Create step - this is a CHECKPOINT for true rewind
    step = InferenceStep(
        step_id=f"step_{session.total_steps}",
        step_index=session.total_steps,
        timestamp=time.time(),
        input_context=input_context,
        candidates=candidates,
        top_choice=top['value'],
        top_probability=top.get('probability', 1.0),
        hidden_state=hidden_state,
        attention_weights=attention,
    )
    # Store state snapshot for TRUE rewind (not just history navigation)
    if state_snapshot is not None:
        step._state_snapshot = state_snapshot
    # Hash before appending: the parent link is the current last step.
    step.cascade_hash = self._compute_step_hash(step)
    # Add to history
    with self._lock:
        session.steps.append(step)
        session.current_index = len(session.steps) - 1
        session.total_steps += 1
    # Emit step event
    self._emit_callback('on_step', step)
    self._emit_cascade("hold_step", {
        "step_index": step.step_index,
        "top_choice": str(top['value']),
        "top_prob": top.get('probability', 1.0),
        "num_candidates": len(candidates),
        "has_snapshot": state_snapshot is not None,
        "merkle": step.cascade_hash,
    })
    # Wait for input
    choice, feedback = self._wait_for_input(step)
    # Record what happened
    step.chosen_value = choice
    step.was_override = (choice != top['value'])
    step.override_by = "human" if step.was_override else "model"
    if step.was_override:
        session.human_overrides += 1
        self._emit_callback('on_override', step, choice)
    return choice, feedback
def _wait_for_input(self, step: InferenceStep) -> Tuple[Any, ArcadeFeedback]:
    """
    Block until the user decides, or until the auto-advance timer fires.

    At speed level 0 the wait has no timeout and whatever input() stored
    is used. At any other level the wait lasts 1/steps-per-second; if no
    usable choice arrived by then, the model's top choice is accepted.

    Args:
        step: The step currently being decided.
    Returns:
        (choice, feedback) for this step.
    """
    signal = self._input_event
    level = self.session.speed_level
    if level == 0:
        # Manual mode: wait indefinitely for input()
        signal.clear()
        signal.wait()
        picked = self._user_choice
        self._user_choice = None
        return picked, self._generate_feedback(step, picked)

    # Auto-advance mode
    rate = self.session.speed_map[level]
    delay = 1.0 / rate if rate > 0 else float('inf')
    signal.clear()
    signalled = signal.wait(timeout=delay)
    if not signalled or self._user_choice is None:
        # Timer expired (or nothing usable was supplied): auto-accept
        picked = step.top_choice
    else:
        picked = self._user_choice
        self._user_choice = None
    return picked, self._generate_feedback(step, picked)
def input(self, choice: Any):
    """
    Provide user input. Call from UI thread.

    Does nothing when there is no active session or no step has been
    captured yet (there is nothing to decide in either case).

    Args:
        choice: The value to use, or a 0-based index into the current
            step's candidates. An int that is a valid index is always
            treated as an index; any other value (including bools and
            out-of-range ints) is delivered as-is.
    """
    if not self.session or not self.session.steps:
        return
    current_step = self.session.steps[self.session.current_index]
    # Handle index input (1-9 keys). bool is a subclass of int, so it is
    # excluded explicitly: True must not silently mean "candidate 1".
    is_index = isinstance(choice, int) and not isinstance(choice, bool)
    if is_index and 0 <= choice < len(current_step.candidates):
        choice = current_step.candidates[choice]['value']
    self._user_choice = choice
    # Wake the capture() call waiting on this step.
    self._input_event.set()
def accept(self):
    """
    Accept model's top choice (SPACE key).

    Does nothing when there is no active session or no captured step.
    """
    if not self.session or not self.session.steps:
        return
    current = self.session.steps[self.session.current_index]
    # Deliver the value directly instead of going through input(): input()
    # treats an in-range int as a candidate index, so an integer
    # top_choice (e.g. a token id) could be swapped for another candidate.
    self._user_choice = current.top_choice
    self._input_event.set()
def override(self, index: int):
    """
    Override with candidate at index (1-9 keys).

    Args:
        index: 0-based position in the current step's candidates (which
            capture() stores highest-probability first). An index outside
            that list is ignored rather than being delivered as a
            literal value.
    """
    if not self.session or not self.session.steps:
        return
    current = self.session.steps[self.session.current_index]
    if isinstance(index, bool) or not isinstance(index, int):
        return
    if not 0 <= index < len(current.candidates):
        # e.g. key "9" pressed while only three candidates exist
        return
    self._user_choice = current.candidates[index]['value']
    # Wake the capture() call waiting on this step.
    self._input_event.set()
# ========================================================================
# NAVIGATION (TIME TRAVEL) - TRUE STATE RESTORATION
# ========================================================================
def rewind(self, steps: int = 1, restore_state: bool = True) -> Optional[InferenceStep]:
    """
    Go back in history with optional state restoration.

    Moves the cursor back by `steps` (clamped to the recorded history),
    marks the session as REWINDING and, if requested, triggers state
    restoration for the step landed on. Restoration is triggered when the
    step carries either a hidden_state or a full state snapshot.

    Args:
        steps: Number of steps to go back
        restore_state: If True, trigger restoration of the step's saved state
    Returns:
        The step we rewound to, or None if there is no session or the
        cursor did not move.
    """
    if not self.session:
        return None
    with self._lock:
        last_index = len(self.session.steps) - 1
        # Clamp at both ends so a negative `steps` cannot run past the end.
        new_index = max(0, min(last_index, self.session.current_index - steps))
        if new_index == self.session.current_index:
            return None
        self.session.current_index = new_index
        self.session.state = SessionState.REWINDING
        step = self.session.steps[new_index]
    # Listeners run after the lock is released: the lock is not reentrant,
    # so a listener calling back into navigation would otherwise deadlock.
    # TRUE STATE RESTORATION - steps captured with only state_snapshot
    # (the recommended form) must be restored too, not just hidden_state.
    has_saved_state = step.hidden_state is not None or step._state_snapshot is not None
    if restore_state and has_saved_state:
        self._restore_state(step)
    self._emit_callback('on_rewind', step, -steps)
    return step
def _restore_state(self, step: InferenceStep):
    """
    Announce that the model state saved with `step` should be restored.

    This method does not touch any model itself: it notifies
    'on_state_restore' listeners (which are expected to do the actual
    restoring) and records a "state_restored" event on the CASCADE bus.
    Steps that carry neither a hidden_state nor a snapshot are skipped.

    Args:
        step: The decision point whose saved state should be restored.
    """
    if step.hidden_state is not None or step._state_snapshot is not None:
        # Hooked components restore themselves from the step they receive
        self._emit_callback('on_state_restore', step)
        record = {
            "step_index": step.step_index,
            "merkle": step.cascade_hash,
            "had_hidden_state": step.hidden_state is not None,
            "had_snapshot": step._state_snapshot is not None,
        }
        self._emit_cascade("state_restored", record)
def branch_from(self, step_index: int, choice_index: int) -> Optional[InferenceStep]:
    """
    Rewind to a step and immediately choose a different branch.

    This is the core gameplay loop:
    1. Rewind to decision point
    2. Choose different option
    3. Execute and see what happens
    4. Repeat until satisfied

    Args:
        step_index: Which decision point to branch from
        choice_index: Which candidate to choose (0 = model's choice).
            An index outside the step's candidates (negative or too
            large) falls back to accepting the model's choice.
    Returns:
        The step that was branched from, or None when the jump could not
        be made.
    """
    step = self.jump_to(step_index)
    if step is None:
        return None
    # Restore state
    self._restore_state(step)
    # Set up the override. The lower bound matters: a negative index is
    # not a valid candidate position.
    if 0 <= choice_index < len(step.candidates):
        self.override(choice_index)
    else:
        self.accept()
    return step
def forward(self, steps: int = 1) -> Optional[InferenceStep]:
    """
    Go forward in history (if we've rewound).

    The target index is clamped to the recorded history, so overshooting
    stops at the newest step and a negative `steps` stops at the oldest.

    Args:
        steps: Number of steps to advance.
    Returns:
        The step now under the cursor, or None if there is no session,
        no history, or the cursor did not move.
    """
    if not self.session:
        return None
    with self._lock:
        max_index = len(self.session.steps) - 1
        if max_index < 0:
            # Empty history: nothing to move to (and no -1 cursor).
            return None
        new_index = max(0, min(max_index, self.session.current_index + steps))
        if new_index == self.session.current_index:
            return None
        self.session.current_index = new_index
        step = self.session.steps[new_index]
    # Notify after releasing the (non-reentrant) lock so listeners may
    # call navigation methods themselves.
    self._emit_callback('on_rewind', step, steps)
    return step
def jump_to(self, index: int) -> Optional[InferenceStep]:
    """
    Jump to specific step.

    Args:
        index: Target position in the history; clamped to the valid range.
    Returns:
        The step now under the cursor, or None when there is no session
        or no step has been recorded yet.
    """
    if not self.session:
        return None
    with self._lock:
        if not self.session.steps:
            # Nothing recorded: clamping would still index an empty list.
            return None
        index = max(0, min(index, len(self.session.steps) - 1))
        self.session.current_index = index
        return self.session.steps[index]
# ========================================================================
# SPEED CONTROL
# ========================================================================
def speed_up(self):
    """Raise the auto-advance speed by one level, stopping at level 4."""
    session = self.session
    if not session:
        return
    next_level = session.speed_level + 1
    session.speed_level = min(4, next_level)
def speed_down(self):
    """Lower the auto-advance speed by one level, stopping at 0 (manual)."""
    session = self.session
    if not session:
        return
    previous_level = session.speed_level - 1
    session.speed_level = max(0, previous_level)
def set_speed(self, level: int):
    """
    Set speed level directly (0-4).

    Args:
        level: Requested level; values outside 0-4 are clamped.
    """
    session = self.session
    if not session:
        return
    capped = min(4, level)
    session.speed_level = max(0, capped)
def pause(self):
    """
    Pause auto-advance by marking the active session as PAUSED.

    NOTE(review): in this file `_wait_for_input` picks between blocking
    and timed waits from `speed_level` alone, so this flag does not by
    itself stop the auto-advance timer - confirm intended behavior.
    """
    session = self.session
    if not session:
        return
    session.state = SessionState.PAUSED
def unpause(self):
    """Resume auto-advance by marking the active session as STEPPING."""
    session = self.session
    if not session:
        return
    session.state = SessionState.STEPPING
# ========================================================================
# PROVENANCE HASHING
# ========================================================================
def _compute_step_hash(self, step: InferenceStep) -> str:
    """
    Compute the chained provenance hash for a step.

    The digest covers the step's index, timestamp, top choice (as a
    string), top probability, candidate count, and the hash of the step
    currently last in the session history (empty string when there is
    none), which links each step to its predecessor.

    Args:
        step: The step to fingerprint.
    Returns:
        The first 16 hex characters of the SHA-256 digest.
    """
    # Include parent hash for chain integrity
    parent_hash = ""
    session = self.session
    if session and session.steps:
        parent_hash = session.steps[-1].cascade_hash or ""
    fingerprint = {
        'step_index': step.step_index,
        'timestamp': step.timestamp,
        'top_choice': str(step.top_choice),
        'top_prob': step.top_probability,
        'num_candidates': len(step.candidates),
        'parent_hash': parent_hash,
    }
    # sort_keys makes the serialization independent of insertion order
    serialized = json.dumps(fingerprint, sort_keys=True)
    digest = hashlib.sha256(serialized.encode())
    return digest.hexdigest()[:16]
# ========================================================================
# ARCADE FEEDBACK
# ========================================================================
def _generate_feedback(self, step: InferenceStep, choice: Any) -> ArcadeFeedback:
    """
    Build the arcade feedback for a decided step and update the combo.

    Accepting the model's top choice extends the combo (and the session's
    max combo); combos of 10, 25, 50 and 100 fire 'on_combo' and return a
    full-intensity cyan cue. Any other choice resets the combo, firing
    'on_combo_break' with the lost streak if there was one.

    Args:
        step: The step that was just decided.
        choice: The value that was chosen for it.
    Returns:
        The feedback cue for the UI.
    """
    milestones = (10, 25, 50, 100)
    overridden = (choice != step.top_choice)
    if not overridden:
        # Accepted model choice
        self.session.combo += 1
        self.session.max_combo = max(self.session.max_combo, self.session.combo)
        if self.session.combo not in milestones:
            # Regular accept: the glow grows with the streak, capped at 0.8
            glow = 0.3 + min(0.5, self.session.combo * 0.02)
            return ArcadeFeedback(
                message="",
                intensity=glow,
                sound_cue="accept",
                color=(0, 255, 0),  # Green
            )
        # Combo milestone
        self._emit_callback('on_combo', self.session.combo)
        return ArcadeFeedback(
            message=f"COMBO x{self.session.combo}!",
            intensity=1.0,
            sound_cue="combo",
            color=(0, 255, 255),  # Cyan
        )

    # Override: combo break (listeners are only told if a streak existed)
    if self.session.combo > 0:
        self._emit_callback('on_combo_break', self.session.combo)
    self.session.combo = 0
    return ArcadeFeedback(
        message="OVERRIDE",
        intensity=0.8,
        sound_cue="override",
        color=(255, 165, 0),  # Orange
    )
# ========================================================================
# CALLBACKS
# ========================================================================
def on(self, event: str, callback: Callable):
    """
    Register a callback for one of the known events.

    Args:
        event: Event name; must already be a key of `self.callbacks`.
            Unknown names are ignored without error.
        callback: Callable invoked with the event's arguments.
    """
    listeners = self.callbacks.get(event)
    if listeners is None:
        return
    listeners.append(callback)
def _emit_callback(self, event: str, *args):
    """
    Invoke every callback registered for `event` with `*args`.

    A callback that raises does not stop the others: the error is
    printed and the loop continues. Unknown events are a no-op.
    """
    if event not in self.callbacks:
        return
    for listener in self.callbacks[event]:
        try:
            listener(*args)
        except Exception as e:
            print(f"Callback error: {e}")
# ========================================================================
# CASCADE PROVENANCE
# ========================================================================
def _emit_cascade(self, event_type: str, data: Dict[str, Any]):
    """
    Send an event to the CASCADE bus, if one is attached.

    The payload is `data` plus "source" ("causation_hold") and
    "timestamp" (current time); those two keys override same-named keys
    in `data`. Emission is best-effort: any exception raised while
    building or sending the payload is suppressed.
    """
    bus = self.bus
    if not bus:
        return
    try:
        payload = {
            **data,
            "source": "causation_hold",
            "timestamp": time.time(),
        }
        bus.emit(event_type, payload)
    except Exception:
        # Provenance must never break the inference loop.
        pass
# ========================================================================
# HIGH SCORES
# ========================================================================
def _load_high_scores(self) -> Dict[str, Any]:
    """
    Load high scores from disk.

    The stored values are merged over the defaults, so the result always
    contains "max_combo", "best_accuracy" and "total_sessions" even when
    the file is missing, unreadable, not valid JSON, not a JSON object,
    or only partially filled in. Non-numeric values for those three keys
    are discarded in favor of the default.

    Returns:
        The high-score dict.
    """
    scores: Dict[str, Any] = {"max_combo": 0, "best_accuracy": 0.0, "total_sessions": 0}
    try:
        loaded = json.loads(self.high_scores_path.read_text())
    except (OSError, ValueError):
        # OSError: missing/unreadable file. ValueError: malformed JSON or
        # undecodable bytes (JSONDecodeError and UnicodeDecodeError both
        # derive from it).
        return scores
    if not isinstance(loaded, dict):
        return scores
    for key, value in loaded.items():
        numeric = isinstance(value, (int, float)) and not isinstance(value, bool)
        if key in scores and not numeric:
            # A corrupt known field would break later comparisons.
            continue
        scores[key] = value
    return scores
def _save_high_scores(self):
    """
    Write the current high scores to `high_scores_path` as indented JSON.

    Missing parent directories are created first.
    """
    target = self.high_scores_path
    target.parent.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(self.high_scores, indent=2)
    target.write_text(serialized)
def _check_high_score(self, stats: Dict[str, Any]):
    """
    Fold a finished session's stats into the high scores and persist them.

    Raises "max_combo" / "best_accuracy" when the session beat them and
    always counts the session in "total_sessions". Because the session
    count changes on every call, the scores are saved on every call (not
    only when a record falls), so the count survives a restart.

    Args:
        stats: Session stats; must contain 'max_combo' and 'accuracy'.
    """
    scores = self.high_scores
    # .get() tolerates a score file that was missing some keys.
    if stats['max_combo'] > scores.get('max_combo', 0):
        scores['max_combo'] = stats['max_combo']
    if stats['accuracy'] > scores.get('best_accuracy', 0.0):
        scores['best_accuracy'] = stats['accuracy']
    scores['total_sessions'] = scores.get('total_sessions', 0) + 1
    try:
        self._save_high_scores()
    except OSError as e:
        # Persistence is cosmetic; a read-only disk must not stop the
        # session from being closed.
        print(f"High score save error: {e}")
# ========================================================================
# DECORATOR FOR EASY WRAPPING
# ========================================================================
def intercept(self, granularity: str = "step"):
    """
    Decorator to intercept a function's inference.

    While a session is active, each call to the wrapped function is
    executed, summarized as candidates and passed to capture() (which may
    block). The wrapped function's own result is always what is returned;
    the choice made during the hold does not replace it. Without a
    session the function is called straight through.

    Args:
        granularity: "step" (each call) or "token" (if function yields).
            NOTE(review): accepted but not used by the visible code.
    Returns:
        A decorator that preserves the wrapped function's metadata.
    """
    # Local import: this block cannot rely on a file-level functools import.
    import functools

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # If no session, passthrough
            if not self.session:
                return func(*args, **kwargs)
            # Capture the input (truncated so huge arguments stay cheap)
            input_context = {
                "args": str(args)[:200],
                "kwargs": {k: str(v)[:100] for k, v in kwargs.items()},
            }
            # Get result
            result = func(*args, **kwargs)
            # Create candidates from result
            if isinstance(result, np.ndarray):
                # For embeddings, show the (up to) five largest-magnitude
                # dimensions, largest first. Flatten once, not per entry.
                magnitudes = np.abs(result.flatten())
                top_dims = np.argsort(magnitudes)[-5:][::-1]
                candidates = [
                    {"value": f"dim_{d}", "probability": float(magnitudes[d])}
                    for d in top_dims
                ]
            else:
                candidates = [{"value": result, "probability": 1.0}]
            # Capture (may block). An empty array yields no candidates and
            # there is nothing to hold on, so capture is skipped.
            if candidates:
                self.capture(input_context, candidates)
            return result
        return wrapper
    return decorator