""" Parametric Edit Operations Engine Executes edit operations on palette tensors with pointer-based addressing. Operations preserve scope balance and support region-relative indexing. Key Features: - Parametric actions with arguments - Pointer arithmetic (region-relative → absolute) - Scope balance verification - Macro pattern transformations - Stateless execution (pure functions) """ import torch from dataclasses import dataclass from typing import Tuple, Optional, List from enum import IntEnum # Import RegionMetadata from scope_pooler from .scope_pooler import RegionMetadata class OpCode(IntEnum): """Operation codes for edit actions""" # Control (0-99) NO_OP = 0 MOVE_NEXT = 1 FOCUS_PARENT = 2 # Atomic edits (150-199) DELETE_RANGE = 150 INSERT_TOKEN = 151 REPLACE_TOKEN = 152 SWAP_TOKENS = 153 # Macro transforms (300-349) PY_FOR_TO_JS_FOR = 300 SYNC_TO_ASYNC = 301 ADD_TRY_EXCEPT = 302 # Control flow DONE = 499 @dataclass class EditAction: """ Parametric edit operation with arguments Fields: op_id: Operation code from OpCode enum region_id: Which semantic region to operate on [0, R) i_start: Token index within region (relative addressing) i_end: End token index (for range operations, -1 if unused) payload_idx: Palette index to insert/replace (0-4095) confidence: Model confidence in [0, 1] Invariants: - 0 <= region_id < num_regions - 0 <= i_start < region_length - i_end >= i_start or i_end == -1 - 0 <= payload_idx < 4096 - 0 <= confidence <= 1 """ op_id: int region_id: int i_start: int i_end: int payload_idx: int confidence: float = 1.0 def __post_init__(self): """Validate invariants""" assert self.op_id >= 0, f"Invalid op_id: {self.op_id}" assert self.region_id >= 0, f"Invalid region_id: {self.region_id}" assert self.i_start >= 0, f"Invalid i_start: {self.i_start}" assert self.i_end >= -1, f"Invalid i_end: {self.i_end}" if self.i_end != -1: assert self.i_end >= self.i_start, f"i_end ({self.i_end}) < i_start ({self.i_start})" assert 0 <= self.payload_idx < 4096, f"Invalid payload_idx: {self.payload_idx}" assert 0 <= self.confidence <= 1, f"Invalid confidence: {self.confidence}" # Custom exceptions class EditError(Exception): """Base class for edit errors""" pass class ScopeBalanceError(EditError): """Operation would break scope balance""" pass class InvalidPointerError(EditError): """Pointer out of bounds""" pass class RegionNotFoundError(EditError): """region_id invalid""" pass class PatternNotFoundError(EditError): """Macro pattern not found in region""" pass class PaletteEditOps: """ Stateless edit operation executor All methods are pure functions (no internal state). Thread-safe and deterministic. Constants: START_OF_SCOPE: 0 END_OF_SCOPE: 1 NOOP: 2 """ # Palette constants START_OF_SCOPE = 0 END_OF_SCOPE = 1 NOOP = 2 # Macro pattern definitions MACRO_PATTERNS = { OpCode.PY_FOR_TO_JS_FOR: { 'pattern': [20, 220, 220], # for x in collection 'target': [20, 201, 220], # for(let x of collection) 'name': 'Python for → JavaScript for' }, } @staticmethod def apply( palette_img: torch.Tensor, # (H, W) action: EditAction, metadata: RegionMetadata ) -> Tuple[torch.Tensor, bool]: """ Apply edit action to palette Args: palette_img: (H, W) palette indices action: Edit action with parameters metadata: Region metadata from ScopePooler Returns: new_palette: (H, W) modified palette success: bool - whether operation succeeded Guarantees: - Original palette unchanged (pure function) - If success=False, new_palette == palette_img - Scope balance preserved on success Failure cases: - Invalid region_id - Invalid pointer indices - Scope imbalance after edit """ # Validate inputs if action.region_id >= len(metadata.starts): return palette_img, False # Create working copy palette = palette_img.clone() try: # Pre-check balance if not PaletteEditOps.verify_scope_balance(palette): raise ScopeBalanceError("Input palette has unbalanced scopes") # Dispatch to operation if action.op_id == OpCode.NO_OP: new_palette = palette elif action.op_id == OpCode.DELETE_RANGE: new_palette = PaletteEditOps.delete_range( palette, action.region_id, action.i_start, action.i_end, metadata ) elif action.op_id == OpCode.INSERT_TOKEN: new_palette = PaletteEditOps.insert_token( palette, action.region_id, action.i_start, action.payload_idx, metadata ) elif action.op_id == OpCode.REPLACE_TOKEN: new_palette = PaletteEditOps.replace_token( palette, action.region_id, action.i_start, action.payload_idx, metadata ) elif action.op_id == OpCode.SWAP_TOKENS: new_palette = PaletteEditOps.swap_tokens( palette, action.region_id, action.i_start, action.i_end, metadata ) elif action.op_id in PaletteEditOps.MACRO_PATTERNS: new_palette = PaletteEditOps.macro_transform( palette, action.op_id, action.region_id, metadata ) else: # Unknown operation return palette_img, False # Post-check balance if not PaletteEditOps.verify_scope_balance(new_palette): raise ScopeBalanceError("Operation broke scope balance") return new_palette, True except EditError as e: # Operation failed - return original return palette_img, False @staticmethod def delete_range( palette: torch.Tensor, region_id: int, i_start: int, i_end: int, metadata: RegionMetadata ) -> torch.Tensor: """ Delete tokens in range [i_start, i_end] within region Algorithm: 1. Convert region-relative indices to absolute (h,w) 2. Create mask for tokens to delete 3. Check if deleting scope markers 4. Shift remaining tokens left 5. Pad with NOOP at end Returns: Modified palette Raises: InvalidPointerError: If indices out of bounds ScopeBalanceError: If would break balance """ H, W = palette.shape region_mask = metadata.masks[region_id] # (H, W) # Get region positions in flattened order (content tokens only) positions = PaletteEditOps._get_region_positions(region_mask, W, palette) # Validate indices if i_start < 0 or i_start >= len(positions): raise InvalidPointerError(f"i_start={i_start} out of bounds (region size={len(positions)})") if i_end < i_start or i_end >= len(positions): raise InvalidPointerError(f"i_end={i_end} out of bounds") # Convert to absolute positions abs_positions = [positions[i] for i in range(i_start, i_end + 1)] # Check if deleting scope markers palette_flat = palette.flatten() for pos in abs_positions: if palette_flat[pos] in [PaletteEditOps.START_OF_SCOPE, PaletteEditOps.END_OF_SCOPE]: # For now, disallow deleting scope markers # TODO: Implement smart deletion (delete matching pairs) raise ScopeBalanceError("Cannot delete scope markers") # Create deletion mask delete_mask = torch.zeros(H * W, dtype=torch.bool) for pos in abs_positions: delete_mask[pos] = True # Shift tokens left new_palette_flat = palette_flat[~delete_mask] # Pad with NOOP pad_size = H * W - len(new_palette_flat) if pad_size > 0: padding = torch.full((pad_size,), PaletteEditOps.NOOP, dtype=palette.dtype) new_palette_flat = torch.cat([new_palette_flat, padding]) return new_palette_flat.view(H, W) @staticmethod def insert_token( palette: torch.Tensor, region_id: int, i_start: int, payload_idx: int, metadata: RegionMetadata ) -> torch.Tensor: """ Insert token at position i_start within region Algorithm: 1. Convert to absolute position 2. Check if inserting scope marker (requires pair) 3. Shift tokens right from position 4. Insert payload 5. Handle overflow Returns: Modified palette Raises: InvalidPointerError: If index out of bounds ScopeBalanceError: If inserting unpaired scope marker """ H, W = palette.shape region_mask = metadata.masks[region_id] positions = PaletteEditOps._get_region_positions(region_mask, W, palette) # Validate index if i_start < 0 or i_start > len(positions): # Allow insertion at end raise InvalidPointerError(f"i_start={i_start} out of bounds") # Get absolute position if i_start < len(positions): abs_pos = positions[i_start] else: # Insert at end of region abs_pos = positions[-1] + 1 if positions else 0 # Check for scope marker insertion if payload_idx == PaletteEditOps.START_OF_SCOPE: # TODO: Also insert matching END raise ScopeBalanceError("Cannot insert unpaired START_OF_SCOPE") if payload_idx == PaletteEditOps.END_OF_SCOPE: raise ScopeBalanceError("Cannot insert unpaired END_OF_SCOPE") # Flatten and insert palette_flat = palette.flatten() # Create new palette with extra space new_palette_flat = torch.zeros(H * W, dtype=palette.dtype) # Copy before insertion point new_palette_flat[:abs_pos] = palette_flat[:abs_pos] # Insert payload new_palette_flat[abs_pos] = payload_idx # Copy after insertion point (shifted right) if abs_pos < H * W - 1: new_palette_flat[abs_pos + 1:] = palette_flat[abs_pos:H*W-1] return new_palette_flat.view(H, W) @staticmethod def replace_token( palette: torch.Tensor, region_id: int, i_start: int, payload_idx: int, metadata: RegionMetadata ) -> torch.Tensor: """ Replace token at i_start with payload_idx Simplest operation: direct assignment Restriction: Cannot replace scope markers with non-markers Returns: Modified palette Raises: InvalidPointerError: If index out of bounds ScopeBalanceError: If replacing scope marker with non-marker """ H, W = palette.shape region_mask = metadata.masks[region_id] positions = PaletteEditOps._get_region_positions(region_mask, W, palette) # Validate index if i_start < 0 or i_start >= len(positions): raise InvalidPointerError(f"i_start={i_start} out of bounds") abs_pos = positions[i_start] h, w = abs_pos // W, abs_pos % W # Check scope marker replacement old_value = palette[h, w].item() is_old_scope = old_value in [PaletteEditOps.START_OF_SCOPE, PaletteEditOps.END_OF_SCOPE] is_new_scope = payload_idx in [PaletteEditOps.START_OF_SCOPE, PaletteEditOps.END_OF_SCOPE] if is_old_scope and not is_new_scope: raise ScopeBalanceError("Cannot replace scope marker with non-marker") if is_old_scope and is_new_scope and old_value != payload_idx: raise ScopeBalanceError("Cannot replace START with END or vice versa") # Direct replacement new_palette = palette.clone() new_palette[h, w] = payload_idx return new_palette @staticmethod def swap_tokens( palette: torch.Tensor, region_id: int, i_start: int, i_end: int, metadata: RegionMetadata ) -> torch.Tensor: """ Swap tokens at i_start and i_end Restriction: Cannot swap START ↔ END (would break structure) Returns: Modified palette Raises: InvalidPointerError: If indices out of bounds ScopeBalanceError: If swapping START ↔ END """ H, W = palette.shape region_mask = metadata.masks[region_id] positions = PaletteEditOps._get_region_positions(region_mask, W, palette) # Validate indices if i_start < 0 or i_start >= len(positions): raise InvalidPointerError(f"i_start={i_start} out of bounds") if i_end < 0 or i_end >= len(positions): raise InvalidPointerError(f"i_end={i_end} out of bounds") abs_pos1 = positions[i_start] abs_pos2 = positions[i_end] h1, w1 = abs_pos1 // W, abs_pos1 % W h2, w2 = abs_pos2 // W, abs_pos2 % W # Check scope marker swap val1, val2 = palette[h1, w1].item(), palette[h2, w2].item() if (val1 == PaletteEditOps.START_OF_SCOPE and val2 == PaletteEditOps.END_OF_SCOPE) or \ (val1 == PaletteEditOps.END_OF_SCOPE and val2 == PaletteEditOps.START_OF_SCOPE): raise ScopeBalanceError("Cannot swap START ↔ END") # Swap new_palette = palette.clone() new_palette[h1, w1], new_palette[h2, w2] = palette[h2, w2], palette[h1, w1] return new_palette @staticmethod def macro_transform( palette: torch.Tensor, op_id: int, region_id: int, metadata: RegionMetadata ) -> torch.Tensor: """ Apply macro pattern transformation Searches for pattern in region and replaces with target. Returns: Modified palette Raises: PatternNotFoundError: If pattern not found in region """ if op_id not in PaletteEditOps.MACRO_PATTERNS: raise EditError(f"Unknown macro operation: {op_id}") macro = PaletteEditOps.MACRO_PATTERNS[op_id] pattern = macro['pattern'] target = macro['target'] # Extract region tokens region_mask = metadata.masks[region_id] positions = PaletteEditOps._get_region_positions(region_mask, palette.shape[1], palette) palette_flat = palette.flatten() region_tokens = [palette_flat[pos].item() for pos in positions] # Pattern matching (sliding window) pattern_len = len(pattern) found = False for i in range(len(region_tokens) - pattern_len + 1): window = region_tokens[i:i + pattern_len] if window == pattern: # Match found - apply transformation for j, target_token in enumerate(target): abs_pos = positions[i + j] h, w = abs_pos // palette.shape[1], abs_pos % palette.shape[1] palette[h, w] = target_token found = True break # Apply once per region if not found: raise PatternNotFoundError(f"Pattern {pattern} not found in region {region_id}") return palette @staticmethod def verify_scope_balance(palette: torch.Tensor) -> bool: """ Check if START_OF_SCOPE == END_OF_SCOPE count Returns: True if balanced, False otherwise """ num_starts = (palette == PaletteEditOps.START_OF_SCOPE).sum().item() num_ends = (palette == PaletteEditOps.END_OF_SCOPE).sum().item() return num_starts == num_ends @staticmethod def _get_region_positions(mask: torch.Tensor, W: int, palette: torch.Tensor = None) -> List[int]: """ Get flattened positions of region in scan order (content tokens only, excluding scope markers) Args: mask: (H, W) boolean mask W: width palette: (H, W) palette tensor (optional, used to filter scope markers) Returns: List of flattened indices where mask is True and token is not a scope marker, sorted """ positions = mask.nonzero(as_tuple=False) # (N, 2) of (h, w) flat_positions = (positions[:, 0] * W + positions[:, 1]).tolist() # If palette provided, filter out scope markers if palette is not None: filtered_positions = [] for pos in flat_positions: h, w = pos // W, pos % W token = palette[h, w].item() if token not in [PaletteEditOps.START_OF_SCOPE, PaletteEditOps.END_OF_SCOPE]: filtered_positions.append(pos) return sorted(filtered_positions) return sorted(flat_positions)