| """
|
| Parametric Edit Operations Engine
|
|
|
| Executes edit operations on palette tensors with pointer-based addressing.
|
| Operations preserve scope balance and support region-relative indexing.
|
|
|
| Key Features:
|
| - Parametric actions with arguments
|
| - Pointer arithmetic (region-relative → absolute)
|
| - Scope balance verification
|
| - Macro pattern transformations
|
| - Stateless execution (pure functions)
|
| """
|
|
|
| import torch
|
| from dataclasses import dataclass
|
| from typing import Tuple, Optional, List
|
| from enum import IntEnum
|
|
|
|
|
| from .scope_pooler import RegionMetadata
|
|
|
|
|
class OpCode(IntEnum):
    """Operation codes for edit actions.

    Codes are grouped by numeric range:
      * 0-2      basic control / cursor codes
      * 150-153  token-level edits addressed by region-relative pointers
      * 300-302  macro pattern transformations
      * 499      terminal DONE code
    """

    # --- basic control / cursor codes ---
    NO_OP = 0
    MOVE_NEXT = 1
    FOCUS_PARENT = 2

    # --- token-level edits ---
    DELETE_RANGE = 150
    INSERT_TOKEN = 151
    REPLACE_TOKEN = 152
    SWAP_TOKENS = 153

    # --- macro pattern transformations ---
    # NOTE: only PY_FOR_TO_JS_FOR has an entry in PaletteEditOps.MACRO_PATTERNS;
    # the other macro codes are not executable by PaletteEditOps.apply.
    PY_FOR_TO_JS_FOR = 300
    SYNC_TO_ASYNC = 301
    ADD_TRY_EXCEPT = 302

    # --- terminal code ---
    DONE = 499
|
|
|
|
|
# Number of palette entries; valid payload indices are [0, _PALETTE_SIZE).
_PALETTE_SIZE = 4096


class InvalidEditActionError(ValueError, AssertionError):
    """Raised when an EditAction violates one of its field invariants.

    Also subclasses AssertionError so code that catches AssertionError from
    EditAction construction keeps working.
    """


@dataclass
class EditAction:
    """
    Parametric edit operation with arguments

    Fields:
        op_id: Operation code from OpCode enum
        region_id: Which semantic region to operate on [0, R)
        i_start: Token index within region (relative addressing)
        i_end: End token index (for range operations, -1 if unused)
        payload_idx: Palette index to insert/replace (0-4095)
        confidence: Model confidence in [0, 1]

    Invariants (checked on construction, also under ``python -O``):
        - op_id >= 0
        - region_id >= 0  (upper bound depends on the metadata; checked at apply time)
        - i_start >= 0    (upper bound depends on the region; checked at apply time)
        - i_end >= i_start or i_end == -1
        - 0 <= payload_idx < 4096
        - 0 <= confidence <= 1  (NaN is rejected)

    Raises:
        InvalidEditActionError: (a ValueError and AssertionError) on violation.
    """

    op_id: int
    region_id: int
    i_start: int
    i_end: int
    payload_idx: int
    confidence: float = 1.0

    def __post_init__(self):
        """Validate invariants with explicit raises (asserts vanish under -O)."""

        def require(ok: bool, message: str) -> None:
            # Written as "not ok" so NaN comparisons (always False) are rejected.
            if not ok:
                raise InvalidEditActionError(message)

        require(self.op_id >= 0, f"Invalid op_id: {self.op_id}")
        require(self.region_id >= 0, f"Invalid region_id: {self.region_id}")
        require(self.i_start >= 0, f"Invalid i_start: {self.i_start}")
        require(self.i_end >= -1, f"Invalid i_end: {self.i_end}")
        if self.i_end != -1:
            require(
                self.i_end >= self.i_start,
                f"i_end ({self.i_end}) < i_start ({self.i_start})",
            )
        require(
            0 <= self.payload_idx < _PALETTE_SIZE,
            f"Invalid payload_idx: {self.payload_idx}",
        )
        require(0 <= self.confidence <= 1, f"Invalid confidence: {self.confidence}")
|
|
|
|
|
|
|
class EditError(Exception):
    """Root of the palette-edit exception hierarchy.

    PaletteEditOps.apply catches this type and reports failure via its
    boolean result instead of propagating.
    """


class ScopeBalanceError(EditError):
    """Scope markers are (or would become) unbalanced, or a marker edit is disallowed."""


class InvalidPointerError(EditError):
    """A region-relative index falls outside the region's content tokens."""


class RegionNotFoundError(EditError):
    """The requested region_id does not exist.

    Not raised by PaletteEditOps itself: apply reports an out-of-range
    region_id through its success flag.
    """


class PatternNotFoundError(EditError):
    """A macro's search pattern does not occur in the target region."""
|
|
|
|
|
class PaletteEditOps:
    """
    Stateless edit operation executor.

    Every method is static and never mutates its input tensors; each
    operation returns a new palette. No method writes to shared state.

    Addressing: operations take region-relative indices that count only the
    *content* cells of a region (cells selected by the region mask whose
    token is not a scope marker), in row-major scan order.
    ``_get_region_positions`` converts them to flat indices into
    ``palette.flatten()``.

    Note: ``delete_range`` and ``insert_token`` shift the whole flattened
    palette after the edit point, not only the cells of the target region.

    Constants:
        START_OF_SCOPE: 0  (opening scope marker)
        END_OF_SCOPE: 1    (closing scope marker)
        NOOP: 2            (padding appended after deletions)
    """

    START_OF_SCOPE = 0
    END_OF_SCOPE = 1
    NOOP = 2

    # Macro op code -> first occurrence of 'pattern' (in a region's content
    # tokens) is overwritten cell-for-cell with 'target'.
    MACRO_PATTERNS = {
        OpCode.PY_FOR_TO_JS_FOR: {
            'pattern': [20, 220, 220],
            'target': [20, 201, 220],
            'name': 'Python for → JavaScript for'
        },
    }

    @staticmethod
    def apply(
        palette_img: torch.Tensor,
        action: EditAction,
        metadata: RegionMetadata
    ) -> Tuple[torch.Tensor, bool]:
        """
        Apply edit action to palette

        Args:
            palette_img: (H, W) palette indices
            action: Edit action with parameters
            metadata: Region metadata from ScopePooler. Only ``starts`` (its
                length is the region count) and ``masks`` (indexed by
                region_id) are read.

        Returns:
            new_palette: (H, W) modified palette
            success: bool - whether operation succeeded

        Guarantees:
            - palette_img is never modified
            - If success=False, the returned tensor is palette_img itself
            - START/END marker counts are equal on success

        Failure cases (reported via success=False, not raised):
            - region_id outside [0, num_regions)
            - op_id with no implementation here
            - Invalid pointer indices
            - Unbalanced input, or scope imbalance after edit
            - Macro pattern not found
        """
        # Reject negative ids too: they would silently index masks from the end.
        if not 0 <= action.region_id < len(metadata.starts):
            return palette_img, False

        palette = palette_img.clone()

        try:
            if not PaletteEditOps.verify_scope_balance(palette):
                raise ScopeBalanceError("Input palette has unbalanced scopes")

            op = action.op_id
            if op == OpCode.NO_OP:
                new_palette = palette
            elif op == OpCode.DELETE_RANGE:
                new_palette = PaletteEditOps.delete_range(
                    palette, action.region_id, action.i_start, action.i_end, metadata
                )
            elif op == OpCode.INSERT_TOKEN:
                new_palette = PaletteEditOps.insert_token(
                    palette, action.region_id, action.i_start, action.payload_idx, metadata
                )
            elif op == OpCode.REPLACE_TOKEN:
                new_palette = PaletteEditOps.replace_token(
                    palette, action.region_id, action.i_start, action.payload_idx, metadata
                )
            elif op == OpCode.SWAP_TOKENS:
                new_palette = PaletteEditOps.swap_tokens(
                    palette, action.region_id, action.i_start, action.i_end, metadata
                )
            elif op in PaletteEditOps.MACRO_PATTERNS:
                new_palette = PaletteEditOps.macro_transform(
                    palette, op, action.region_id, metadata
                )
            else:
                # Navigation codes, DONE, and macros without a pattern entry.
                return palette_img, False

            if not PaletteEditOps.verify_scope_balance(new_palette):
                raise ScopeBalanceError("Operation broke scope balance")

            return new_palette, True

        except EditError:
            return palette_img, False

    @staticmethod
    def delete_range(
        palette: torch.Tensor,
        region_id: int,
        i_start: int,
        i_end: int,
        metadata: RegionMetadata
    ) -> torch.Tensor:
        """
        Delete tokens in range [i_start, i_end] (inclusive) within region

        Algorithm:
            1. Convert region-relative indices to flat palette positions
            2. Refuse to delete scope markers (defensive: the region
               positions already exclude them)
            3. Drop the deleted cells, shifting every later cell of the
               flattened palette left
            4. Pad the end with NOOP so the shape stays (H, W)

        Returns:
            Modified palette (new tensor on palette's device)

        Raises:
            InvalidPointerError: If indices out of bounds (including i_end == -1)
            ScopeBalanceError: If a selected cell holds a scope marker
        """
        H, W = palette.shape
        positions = PaletteEditOps._get_region_positions(metadata.masks[region_id], W, palette)
        size = len(positions)

        if i_start < 0 or i_start >= size:
            raise InvalidPointerError(f"i_start={i_start} out of bounds (region size={size})")
        if i_end < i_start or i_end >= size:
            raise InvalidPointerError(f"i_end={i_end} out of bounds")

        palette_flat = palette.flatten()
        doomed = torch.as_tensor(
            positions[i_start:i_end + 1], dtype=torch.long, device=palette.device
        )

        doomed_tokens = palette_flat[doomed]
        if ((doomed_tokens == PaletteEditOps.START_OF_SCOPE)
                | (doomed_tokens == PaletteEditOps.END_OF_SCOPE)).any():
            raise ScopeBalanceError("Cannot delete scope markers")

        # Build the mask on the palette's device so CUDA palettes work too.
        delete_mask = torch.zeros(H * W, dtype=torch.bool, device=palette.device)
        delete_mask[doomed] = True
        kept = palette_flat[~delete_mask]

        padding = torch.full(
            (H * W - kept.numel(),), PaletteEditOps.NOOP,
            dtype=palette.dtype, device=palette.device,
        )
        return torch.cat([kept, padding]).view(H, W)

    @staticmethod
    def insert_token(
        palette: torch.Tensor,
        region_id: int,
        i_start: int,
        payload_idx: int,
        metadata: RegionMetadata
    ) -> torch.Tensor:
        """
        Insert token before the region's i_start-th content token

        i_start == region size appends right after the region's last content
        token. For a region with no content tokens the insertion point is the
        region's first masked cell, moved one cell forward when that cell is a
        START_OF_SCOPE marker (so the token lands inside the scope).

        Every later cell of the flattened palette shifts right by one and the
        final cell of the palette is discarded (it is not checked; if it was a
        scope marker, apply's balance check rejects the result).

        Returns:
            Modified palette (new tensor on palette's device)

        Raises:
            InvalidPointerError: If index out of bounds, the region has no
                cells, or the insertion point lies past the end of the palette
            ScopeBalanceError: If inserting unpaired scope marker
        """
        H, W = palette.shape
        region_mask = metadata.masks[region_id]
        positions = PaletteEditOps._get_region_positions(region_mask, W, palette)
        palette_flat = palette.flatten()

        if i_start < 0 or i_start > len(positions):
            raise InvalidPointerError(f"i_start={i_start} out of bounds")

        if i_start < len(positions):
            abs_pos = positions[i_start]
        elif positions:
            abs_pos = positions[-1] + 1
        else:
            # No content tokens: anchor on the region mask itself rather than
            # flat index 0, which is the start of the whole palette.
            region_cells = PaletteEditOps._get_region_positions(region_mask, W)
            if not region_cells:
                raise InvalidPointerError(f"region {region_id} has no cells")
            abs_pos = region_cells[0]
            if palette_flat[abs_pos].item() == PaletteEditOps.START_OF_SCOPE:
                abs_pos += 1

        if payload_idx == PaletteEditOps.START_OF_SCOPE:
            raise ScopeBalanceError("Cannot insert unpaired START_OF_SCOPE")
        if payload_idx == PaletteEditOps.END_OF_SCOPE:
            raise ScopeBalanceError("Cannot insert unpaired END_OF_SCOPE")

        # Appending after a token in the very last cell has nowhere to go.
        if abs_pos >= H * W:
            raise InvalidPointerError(
                f"insertion point {abs_pos} is past the end of the palette ({H * W} cells)"
            )

        payload = torch.tensor([payload_idx], dtype=palette.dtype, device=palette.device)
        new_palette_flat = torch.cat([
            palette_flat[:abs_pos],
            payload,
            palette_flat[abs_pos:H * W - 1],  # shift right; last cell falls off
        ])
        return new_palette_flat.view(H, W)

    @staticmethod
    def replace_token(
        palette: torch.Tensor,
        region_id: int,
        i_start: int,
        payload_idx: int,
        metadata: RegionMetadata
    ) -> torch.Tensor:
        """
        Replace the region's i_start-th content token with payload_idx

        Restrictions (scope markers must stay paired):
            - a scope marker cannot become a non-marker or the opposite marker
              (defensive: region positions already exclude marker cells)
            - a content token cannot become a scope marker

        Returns:
            Modified palette (new tensor)

        Raises:
            InvalidPointerError: If index out of bounds
            ScopeBalanceError: If the replacement would change marker counts
        """
        H, W = palette.shape
        positions = PaletteEditOps._get_region_positions(metadata.masks[region_id], W, palette)

        if i_start < 0 or i_start >= len(positions):
            raise InvalidPointerError(f"i_start={i_start} out of bounds")

        h, w = divmod(positions[i_start], W)

        markers = (PaletteEditOps.START_OF_SCOPE, PaletteEditOps.END_OF_SCOPE)
        old_value = palette[h, w].item()
        is_old_scope = old_value in markers
        is_new_scope = payload_idx in markers

        if is_old_scope and not is_new_scope:
            raise ScopeBalanceError("Cannot replace scope marker with non-marker")
        if is_old_scope and is_new_scope and old_value != payload_idx:
            raise ScopeBalanceError("Cannot replace START with END or vice versa")
        # Mirrors insert_token: a lone marker can never keep the palette balanced.
        if is_new_scope and not is_old_scope:
            raise ScopeBalanceError("Cannot replace non-marker with scope marker")

        new_palette = palette.clone()
        new_palette[h, w] = payload_idx
        return new_palette

    @staticmethod
    def swap_tokens(
        palette: torch.Tensor,
        region_id: int,
        i_start: int,
        i_end: int,
        metadata: RegionMetadata
    ) -> torch.Tensor:
        """
        Swap the region's i_start-th and i_end-th content tokens

        Restriction: Cannot swap START ↔ END (defensive: region positions
        already exclude marker cells)

        Returns:
            Modified palette (new tensor)

        Raises:
            InvalidPointerError: If indices out of bounds (including i_end == -1)
            ScopeBalanceError: If swapping START ↔ END
        """
        H, W = palette.shape
        positions = PaletteEditOps._get_region_positions(metadata.masks[region_id], W, palette)

        if i_start < 0 or i_start >= len(positions):
            raise InvalidPointerError(f"i_start={i_start} out of bounds")
        if i_end < 0 or i_end >= len(positions):
            raise InvalidPointerError(f"i_end={i_end} out of bounds")

        h1, w1 = divmod(positions[i_start], W)
        h2, w2 = divmod(positions[i_end], W)

        val1, val2 = palette[h1, w1].item(), palette[h2, w2].item()
        if {val1, val2} == {PaletteEditOps.START_OF_SCOPE, PaletteEditOps.END_OF_SCOPE}:
            raise ScopeBalanceError("Cannot swap START ↔ END")

        new_palette = palette.clone()
        # Read from the untouched original so the swap is order-independent.
        new_palette[h1, w1], new_palette[h2, w2] = palette[h2, w2], palette[h1, w1]
        return new_palette

    @staticmethod
    def macro_transform(
        palette: torch.Tensor,
        op_id: int,
        region_id: int,
        metadata: RegionMetadata
    ) -> torch.Tensor:
        """
        Apply macro pattern transformation

        Finds the first occurrence of the macro's pattern in the region's
        content tokens (scope markers skipped) and overwrites those cells
        with the macro's target tokens. Works on a copy; ``palette`` is not
        modified.

        Returns:
            Modified palette (new tensor)

        Raises:
            EditError: If op_id has no macro definition, or its target length
                differs from its pattern length (overwrite needs equal size)
            PatternNotFoundError: If pattern not found in region
        """
        macro = PaletteEditOps.MACRO_PATTERNS.get(op_id)
        if macro is None:
            raise EditError(f"Unknown macro operation: {op_id}")

        pattern = macro['pattern']
        target = macro['target']
        if len(target) != len(pattern):
            raise EditError(
                f"Macro {op_id}: target length {len(target)} != pattern length {len(pattern)}"
            )

        W = palette.shape[1]
        positions = PaletteEditOps._get_region_positions(metadata.masks[region_id], W, palette)
        index = torch.as_tensor(positions, dtype=torch.long, device=palette.device)
        region_tokens = palette.flatten()[index].tolist()

        span = len(pattern)
        for start in range(len(region_tokens) - span + 1):
            if region_tokens[start:start + span] == pattern:
                new_palette = palette.clone()
                for pos, token in zip(positions[start:start + span], target):
                    h, w = divmod(pos, W)
                    new_palette[h, w] = token
                return new_palette

        raise PatternNotFoundError(f"Pattern {pattern} not found in region {region_id}")

    @staticmethod
    def verify_scope_balance(palette: torch.Tensor) -> bool:
        """
        Check if START_OF_SCOPE == END_OF_SCOPE count

        Only counts are compared; marker ordering/nesting is not checked.

        Returns:
            True if balanced, False otherwise
        """
        num_starts = int((palette == PaletteEditOps.START_OF_SCOPE).sum())
        num_ends = int((palette == PaletteEditOps.END_OF_SCOPE).sum())
        return num_starts == num_ends

    @staticmethod
    def _get_region_positions(
        mask: torch.Tensor, W: int, palette: Optional[torch.Tensor] = None
    ) -> List[int]:
        """
        Get flattened positions of region in scan order

        Args:
            mask: (H, W) mask; nonzero cells belong to the region
            W: width
            palette: (H, W) palette tensor (optional). When given, cells holding
                a scope marker are excluded so only content tokens remain.

        Returns:
            Sorted list of flat indices (row * W + col) of the selected cells
        """
        coords = mask.nonzero(as_tuple=False)
        rows, cols = coords[:, 0], coords[:, 1]
        flat = rows * W + cols

        if palette is not None:
            # Vectorized marker filter (replaces a per-cell .item() loop).
            tokens = palette[rows.to(palette.device), cols.to(palette.device)]
            keep = (tokens != PaletteEditOps.START_OF_SCOPE) & (tokens != PaletteEditOps.END_OF_SCOPE)
            flat = flat[keep.to(flat.device)]

        return sorted(flat.tolist())
|
|
|