# palette-reasoner / models/edit_ops.py
# Jonttup's picture
# Upload models/edit_ops.py with huggingface_hub
# d17e9c3 verified
"""
Parametric Edit Operations Engine
Executes edit operations on palette tensors with pointer-based addressing.
Operations preserve scope balance and support region-relative indexing.
Key Features:
- Parametric actions with arguments
- Pointer arithmetic (region-relative → absolute)
- Scope balance verification
- Macro pattern transformations
- Stateless execution (pure functions)
"""
import torch
from dataclasses import dataclass
from typing import Tuple, Optional, List
from enum import IntEnum
# Import RegionMetadata from scope_pooler
from .scope_pooler import RegionMetadata
class OpCode(IntEnum):
    """
    Integer operation codes carried in the ``op_id`` field of an edit action.

    Members are plain ints (``IntEnum``), so they compare equal to the raw
    integer value. The values are grouped into numeric bands, noted in the
    comments below.
    """

    # --- Control: band 0-99 ---
    NO_OP = 0
    MOVE_NEXT = 1
    FOCUS_PARENT = 2

    # --- Atomic edits: band 150-199 ---
    DELETE_RANGE = 150
    INSERT_TOKEN = 151
    REPLACE_TOKEN = 152
    SWAP_TOKENS = 153

    # --- Macro transforms: band 300-349 ---
    PY_FOR_TO_JS_FOR = 300
    SYNC_TO_ASYNC = 301
    ADD_TRY_EXCEPT = 302

    # --- Control flow ---
    DONE = 499
@dataclass
class EditAction:
    """
    Parametric edit operation with arguments.

    Fields:
        op_id: Operation code (see the OpCode enum)
        region_id: Which semantic region to operate on, [0, R)
        i_start: Token index within the region (relative addressing)
        i_end: End token index (for range operations, -1 if unused)
        payload_idx: Palette index to insert/replace (0-4095)
        confidence: Model confidence in [0, 1]

    Invariants checked at construction time:
        - op_id >= 0
        - region_id >= 0
        - i_start >= 0
        - i_end >= i_start or i_end == -1
        - 0 <= payload_idx < 4096
        - 0 <= confidence <= 1

    The upper bounds ``region_id < num_regions`` and
    ``i_start < region_length`` depend on the palette being edited and
    cannot be checked here; they are left to the code that executes the
    action.

    Raises:
        AssertionError: If any of the invariants above is violated.
    """
    op_id: int
    region_id: int
    i_start: int
    i_end: int
    payload_idx: int
    confidence: float = 1.0

    def __post_init__(self):
        """
        Validate the field invariants.

        The checks raise ``AssertionError`` explicitly instead of using
        ``assert`` statements: ``assert`` is removed when Python runs with
        ``-O``, which would silently disable validation. The exception type
        is kept so existing callers that catch ``AssertionError`` still work.
        Conditions are written as ``not (ok)`` so that values such as NaN,
        for which every comparison is False, are rejected.
        """
        if not (self.op_id >= 0):
            raise AssertionError(f"Invalid op_id: {self.op_id}")
        if not (self.region_id >= 0):
            raise AssertionError(f"Invalid region_id: {self.region_id}")
        if not (self.i_start >= 0):
            raise AssertionError(f"Invalid i_start: {self.i_start}")
        if not (self.i_end >= -1):
            raise AssertionError(f"Invalid i_end: {self.i_end}")
        # -1 is the "unused" sentinel; any other value must not precede i_start.
        if self.i_end != -1 and not (self.i_end >= self.i_start):
            raise AssertionError(
                f"i_end ({self.i_end}) < i_start ({self.i_start})"
            )
        if not (0 <= self.payload_idx < 4096):
            raise AssertionError(f"Invalid payload_idx: {self.payload_idx}")
        if not (0 <= self.confidence <= 1):
            raise AssertionError(f"Invalid confidence: {self.confidence}")
# Custom exceptions
class EditError(Exception):
    """Common base class for every error raised by an edit operation."""
class ScopeBalanceError(EditError):
    """Raised when an operation would break scope balance."""
class InvalidPointerError(EditError):
    """Raised when a token pointer is out of bounds."""
class RegionNotFoundError(EditError):
    """Raised when a region_id is invalid."""
class PatternNotFoundError(EditError):
    """Raised when a macro pattern is not found in the region."""
class PaletteEditOps:
    """
    Stateless edit operation executor.

    Every method is a ``@staticmethod``; the class holds only constants and
    keeps no instance or class-level mutable state. No method writes into
    the palette tensor it is given: each returns a new tensor.

    Addressing:
        Region-relative token indices (``i_start`` / ``i_end``) count only
        the *content* tokens of a region: cells selected by the region mask
        whose value is not a scope marker. They are translated to absolute
        row-major positions ``h * W + w`` by ``_get_region_positions``.

    Constants:
        START_OF_SCOPE: 0
        END_OF_SCOPE: 1
        NOOP: 2 (used as padding after a deletion)
    """

    # Palette constants
    START_OF_SCOPE = 0
    END_OF_SCOPE = 1
    NOOP = 2

    # Macro pattern definitions: token sequence to find -> sequence to write.
    # Pattern and target are expected to have the same length, because the
    # target is written cell-for-cell over the matched window.
    MACRO_PATTERNS = {
        OpCode.PY_FOR_TO_JS_FOR: {
            'pattern': [20, 220, 220],  # for x in collection
            'target': [20, 201, 220],   # for(let x of collection)
            'name': 'Python for → JavaScript for'
        },
    }

    @staticmethod
    def apply(
        palette_img: torch.Tensor,  # (H, W)
        action: EditAction,
        metadata: RegionMetadata
    ) -> Tuple[torch.Tensor, bool]:
        """
        Apply an edit action to a palette.

        Args:
            palette_img: (H, W) palette indices
            action: Edit action with parameters
            metadata: Region metadata; ``metadata.starts`` is used for the
                region count and ``metadata.masks[region_id]`` for the
                region's (H, W) mask

        Returns:
            new_palette: (H, W) modified palette
            success: bool - whether the operation succeeded

        Guarantees:
            - ``palette_img`` is never modified; edits run on a clone
            - If success is False, the returned tensor is ``palette_img``
              itself (the same object, not a copy)
            - START/END marker counts are equal on success

        Failure cases (all return ``(palette_img, False)``):
            - ``action.region_id`` not below ``len(metadata.starts)``
            - The input palette is already unbalanced
            - Any ``EditError`` raised by the operation (bad pointer,
              scope violation, pattern not found, ...)
            - Scope imbalance after the edit
            - An ``op_id`` with no handler here; this includes op codes
              other than NO_OP, the four atomic edits, and the macros
              listed in ``MACRO_PATTERNS``
        """
        ops = PaletteEditOps

        # Validate inputs
        if action.region_id >= len(metadata.starts):
            return palette_img, False

        # Work on a private copy so the caller's tensor is never touched.
        palette = palette_img.clone()

        try:
            # Pre-check balance
            if not ops.verify_scope_balance(palette):
                raise ScopeBalanceError("Input palette has unbalanced scopes")

            # Dispatch to operation
            if action.op_id == OpCode.NO_OP:
                new_palette = palette
            elif action.op_id == OpCode.DELETE_RANGE:
                new_palette = ops.delete_range(
                    palette, action.region_id, action.i_start, action.i_end, metadata
                )
            elif action.op_id == OpCode.INSERT_TOKEN:
                new_palette = ops.insert_token(
                    palette, action.region_id, action.i_start, action.payload_idx, metadata
                )
            elif action.op_id == OpCode.REPLACE_TOKEN:
                new_palette = ops.replace_token(
                    palette, action.region_id, action.i_start, action.payload_idx, metadata
                )
            elif action.op_id == OpCode.SWAP_TOKENS:
                new_palette = ops.swap_tokens(
                    palette, action.region_id, action.i_start, action.i_end, metadata
                )
            elif action.op_id in ops.MACRO_PATTERNS:
                new_palette = ops.macro_transform(
                    palette, action.op_id, action.region_id, metadata
                )
            else:
                # Unknown operation
                return palette_img, False

            # Post-check balance
            if not ops.verify_scope_balance(new_palette):
                raise ScopeBalanceError("Operation broke scope balance")

            return new_palette, True
        except EditError:
            # Operation failed - return original
            return palette_img, False

    @staticmethod
    def delete_range(
        palette: torch.Tensor,
        region_id: int,
        i_start: int,
        i_end: int,
        metadata: RegionMetadata
    ) -> torch.Tensor:
        """
        Delete the content tokens in range [i_start, i_end] within a region.

        Algorithm:
            1. Convert region-relative indices to absolute flat positions
            2. Refuse if any selected cell holds a scope marker
            3. Drop the selected cells from the flattened palette; every
               later cell of the *whole palette* (not just this region)
               shifts left
            4. Pad the tail with NOOP so the result is (H, W) again

        Note that ``i_end == -1`` (the "unused" sentinel of EditAction) is
        rejected here as out of bounds, since it is below ``i_start``.

        Returns:
            Modified palette (new tensor, same dtype and device as input)

        Raises:
            RegionNotFoundError: If region_id is not a valid region index
            InvalidPointerError: If indices are out of bounds
            ScopeBalanceError: If the range contains a scope marker
        """
        ops = PaletteEditOps
        H, W = palette.shape
        region_mask = ops._region_mask(metadata, region_id)  # (H, W)

        # Region positions in flattened order (content tokens only)
        positions = ops._get_region_positions(region_mask, W, palette)

        # Validate indices
        if i_start < 0 or i_start >= len(positions):
            raise InvalidPointerError(f"i_start={i_start} out of bounds (region size={len(positions)})")
        if i_end < i_start or i_end >= len(positions):
            raise InvalidPointerError(f"i_end={i_end} out of bounds")

        # Absolute positions to remove, as an index tensor on the palette's
        # device so all tensors below live on the same device.
        doomed = torch.tensor(
            positions[i_start:i_end + 1], dtype=torch.long, device=palette.device
        )

        palette_flat = palette.flatten()

        # Defensive check: positions already exclude scope markers, but
        # deleting one would unbalance the palette, so verify anyway.
        # TODO: Implement smart deletion (delete matching pairs)
        doomed_tokens = palette_flat[doomed]
        hits_marker = (doomed_tokens == ops.START_OF_SCOPE) | (doomed_tokens == ops.END_OF_SCOPE)
        if hits_marker.any().item():
            raise ScopeBalanceError("Cannot delete scope markers")

        # Keep everything except the doomed cells (this shifts tokens left).
        keep_mask = torch.ones(H * W, dtype=torch.bool, device=palette.device)
        keep_mask[doomed] = False
        survivors = palette_flat[keep_mask]

        # Pad with NOOP back to the original size.
        padding = torch.full(
            (H * W - survivors.numel(),),
            ops.NOOP,
            dtype=palette.dtype,
            device=palette.device,
        )
        return torch.cat([survivors, padding]).view(H, W)

    @staticmethod
    def insert_token(
        palette: torch.Tensor,
        region_id: int,
        i_start: int,
        payload_idx: int,
        metadata: RegionMetadata
    ) -> torch.Tensor:
        """
        Insert a token at position i_start within a region.

        Algorithm:
            1. Convert i_start to an absolute flat position. i_start equal
               to the region's content length means "right after the last
               content token"; for a region with no content tokens the
               absolute position falls back to 0
            2. Refuse scope markers as payload (they would be unpaired)
            3. Shift every cell from that position onwards one step right
               and write the payload
            4. Overflow: the palette keeps its size, so the last cell of
               the flattened palette is dropped. If that cell held a scope
               marker, the resulting imbalance is visible to
               ``verify_scope_balance``

        Returns:
            Modified palette (new tensor, same dtype and device as input)

        Raises:
            RegionNotFoundError: If region_id is not a valid region index
            InvalidPointerError: If the index is out of bounds, or the
                insertion point lies past the last palette cell
            ScopeBalanceError: If inserting an unpaired scope marker
        """
        ops = PaletteEditOps
        H, W = palette.shape
        total = H * W
        region_mask = ops._region_mask(metadata, region_id)
        positions = ops._get_region_positions(region_mask, W, palette)

        # Validate index
        if i_start < 0 or i_start > len(positions):  # Allow insertion at end
            raise InvalidPointerError(f"i_start={i_start} out of bounds")

        # Get absolute position
        if i_start < len(positions):
            abs_pos = positions[i_start]
        else:
            # Insert at end of region
            abs_pos = positions[-1] + 1 if positions else 0

        # Check for scope marker insertion
        if payload_idx == ops.START_OF_SCOPE:
            # TODO: Also insert matching END
            raise ScopeBalanceError("Cannot insert unpaired START_OF_SCOPE")
        if payload_idx == ops.END_OF_SCOPE:
            raise ScopeBalanceError("Cannot insert unpaired END_OF_SCOPE")

        # Appending after a region that ends on the very last cell (or
        # inserting into an empty palette) has no cell to write into.
        # Report it as a pointer error instead of letting an IndexError
        # escape the EditError handling in apply().
        if abs_pos >= total:
            raise InvalidPointerError(
                f"i_start={i_start} out of bounds (insertion point {abs_pos} is past the end of the palette)"
            )

        palette_flat = palette.flatten()
        # new_full keeps the palette's dtype and device for the new cell.
        inserted = palette_flat.new_full((1,), payload_idx)
        # prefix + payload + remainder minus the final cell (dropped to
        # keep the total length at H * W).
        new_palette_flat = torch.cat(
            [palette_flat[:abs_pos], inserted, palette_flat[abs_pos:total - 1]]
        )
        return new_palette_flat.view(H, W)

    @staticmethod
    def replace_token(
        palette: torch.Tensor,
        region_id: int,
        i_start: int,
        payload_idx: int,
        metadata: RegionMetadata
    ) -> torch.Tensor:
        """
        Replace the content token at i_start with payload_idx.

        Restrictions:
            - A scope marker may not be replaced by a non-marker, nor by
              the opposite marker
            - A non-marker may not be replaced by a scope marker (it would
              be unpaired, like an inserted marker)

        Returns:
            Modified palette (clone of the input with one cell changed)

        Raises:
            RegionNotFoundError: If region_id is not a valid region index
            InvalidPointerError: If the index is out of bounds
            ScopeBalanceError: If the replacement would change the number
                of START or END markers
        """
        ops = PaletteEditOps
        _, W = palette.shape
        region_mask = ops._region_mask(metadata, region_id)
        positions = ops._get_region_positions(region_mask, W, palette)

        # Validate index
        if i_start < 0 or i_start >= len(positions):
            raise InvalidPointerError(f"i_start={i_start} out of bounds")

        h, w = divmod(positions[i_start], W)

        # Check scope marker replacement
        scope_markers = (ops.START_OF_SCOPE, ops.END_OF_SCOPE)
        old_value = palette[h, w].item()
        is_old_scope = old_value in scope_markers
        is_new_scope = payload_idx in scope_markers

        if is_old_scope and not is_new_scope:
            raise ScopeBalanceError("Cannot replace scope marker with non-marker")
        if is_old_scope and is_new_scope and old_value != payload_idx:
            raise ScopeBalanceError("Cannot replace START with END or vice versa")
        if is_new_scope and not is_old_scope:
            # Positions only ever point at content tokens, so this is the
            # case that actually occurs: writing a marker over content
            # would leave it without a partner.
            raise ScopeBalanceError("Cannot replace token with unpaired scope marker")

        # Direct replacement
        new_palette = palette.clone()
        new_palette[h, w] = payload_idx
        return new_palette

    @staticmethod
    def swap_tokens(
        palette: torch.Tensor,
        region_id: int,
        i_start: int,
        i_end: int,
        metadata: RegionMetadata
    ) -> torch.Tensor:
        """
        Swap the content tokens at i_start and i_end.

        Both indices are region-relative. ``i_end`` may be smaller than
        ``i_start``, but must be a valid index (so -1 is rejected).

        Restriction: Cannot swap START ↔ END (would break structure)

        Returns:
            Modified palette (clone of the input with two cells exchanged)

        Raises:
            RegionNotFoundError: If region_id is not a valid region index
            InvalidPointerError: If indices are out of bounds
            ScopeBalanceError: If swapping START ↔ END
        """
        ops = PaletteEditOps
        _, W = palette.shape
        region_mask = ops._region_mask(metadata, region_id)
        positions = ops._get_region_positions(region_mask, W, palette)

        # Validate indices
        if i_start < 0 or i_start >= len(positions):
            raise InvalidPointerError(f"i_start={i_start} out of bounds")
        if i_end < 0 or i_end >= len(positions):
            raise InvalidPointerError(f"i_end={i_end} out of bounds")

        h1, w1 = divmod(positions[i_start], W)
        h2, w2 = divmod(positions[i_end], W)

        # Check scope marker swap
        val1, val2 = palette[h1, w1].item(), palette[h2, w2].item()
        start, end = ops.START_OF_SCOPE, ops.END_OF_SCOPE
        if (val1 == start and val2 == end) or (val1 == end and val2 == start):
            raise ScopeBalanceError("Cannot swap START ↔ END")

        # Swap: read both values from the untouched input, write to a clone.
        new_palette = palette.clone()
        new_palette[h1, w1] = palette[h2, w2]
        new_palette[h2, w2] = palette[h1, w1]
        return new_palette

    @staticmethod
    def macro_transform(
        palette: torch.Tensor,
        op_id: int,
        region_id: int,
        metadata: RegionMetadata
    ) -> torch.Tensor:
        """
        Apply a macro pattern transformation.

        Scans the region's content tokens with a sliding window, and
        overwrites the first window equal to the macro's ``pattern`` with
        the macro's ``target``, cell for cell. Only the first match is
        transformed.

        Returns:
            Modified palette (clone of the input; the input is not changed)

        Raises:
            EditError: If op_id has no entry in MACRO_PATTERNS
            RegionNotFoundError: If region_id is not a valid region index
            PatternNotFoundError: If the pattern is not found in the region
        """
        ops = PaletteEditOps
        if op_id not in ops.MACRO_PATTERNS:
            raise EditError(f"Unknown macro operation: {op_id}")

        macro = ops.MACRO_PATTERNS[op_id]
        pattern = macro['pattern']
        target = macro['target']

        # Extract region tokens
        W = palette.shape[1]
        region_mask = ops._region_mask(metadata, region_id)
        positions = ops._get_region_positions(region_mask, W, palette)
        if positions:
            index = torch.tensor(positions, dtype=torch.long, device=palette.device)
            region_tokens = palette.flatten()[index].tolist()
        else:
            region_tokens = []

        # Pattern matching (sliding window)
        pattern_len = len(pattern)
        for i in range(len(region_tokens) - pattern_len + 1):
            if region_tokens[i:i + pattern_len] != pattern:
                continue
            # Match found - write the target into a copy so the caller's
            # tensor stays untouched.
            new_palette = palette.clone()
            for j, target_token in enumerate(target):
                h, w = divmod(positions[i + j], W)
                new_palette[h, w] = target_token
            return new_palette  # Apply once per region

        raise PatternNotFoundError(f"Pattern {pattern} not found in region {region_id}")

    @staticmethod
    def verify_scope_balance(palette: torch.Tensor) -> bool:
        """
        Check that the START_OF_SCOPE count equals the END_OF_SCOPE count.

        Only the two counts are compared; the order or nesting of the
        markers is not inspected.

        Returns:
            True if balanced, False otherwise
        """
        num_starts = (palette == PaletteEditOps.START_OF_SCOPE).sum().item()
        num_ends = (palette == PaletteEditOps.END_OF_SCOPE).sum().item()
        return num_starts == num_ends

    @staticmethod
    def _region_mask(metadata: RegionMetadata, region_id: int) -> torch.Tensor:
        """
        Return ``metadata.masks[region_id]`` after bounds-checking the id.

        Raises:
            RegionNotFoundError: If region_id is negative or not below
                ``len(metadata.masks)``
        """
        num_regions = len(metadata.masks)
        if not 0 <= region_id < num_regions:
            raise RegionNotFoundError(
                f"region_id={region_id} out of range (num regions={num_regions})"
            )
        return metadata.masks[region_id]

    @staticmethod
    def _get_region_positions(
        mask: torch.Tensor,
        W: int,
        palette: Optional[torch.Tensor] = None
    ) -> List[int]:
        """
        Get the flattened positions of a region in scan order.

        Args:
            mask: (H, W) mask; every nonzero cell belongs to the region
            W: row width used to flatten (h, w) into ``h * W + w``
            palette: (H, W) palette tensor (optional). When given, cells
                holding START_OF_SCOPE or END_OF_SCOPE are left out, so the
                result lists content tokens only. Expected to have the same
                shape as ``mask``.

        Returns:
            Sorted list of flattened indices where the mask is set (and, if
            a palette is given, the token is not a scope marker)
        """
        coords = mask.nonzero(as_tuple=False)  # (N, 2) rows of (h, w)
        rows, cols = coords[:, 0], coords[:, 1]
        flat_positions = (rows * W + cols).tolist()

        if palette is None:
            return sorted(flat_positions)

        # Fetch all region tokens in one gather instead of one .item() call
        # per cell; the indices are moved to the palette's device first.
        tokens = palette[rows.to(palette.device), cols.to(palette.device)].tolist()
        scope_markers = (PaletteEditOps.START_OF_SCOPE, PaletteEditOps.END_OF_SCOPE)
        return sorted(
            pos for pos, token in zip(flat_positions, tokens)
            if token not in scope_markers
        )