"""
Scope-Aware Pooler

Extracts semantic regions from a palette using scope markers
(0 = START_OF_SCOPE, 1 = END_OF_SCOPE).
Scope pairs are matched with a stack-based algorithm.
"""
import logging
import torch
import torch.nn as nn
from typing import List, Tuple, NamedTuple
from dataclasses import dataclass
class RegionMetadata(NamedTuple):
    """
    Description of the regions detected in one palette sample.

    Positions refer to the row-major flattened H*W palette, and both
    ``starts`` and ``ends`` are inclusive. R is the number of regions.

    Fields:
        masks:  bool tensor (R, H, W), True on the cells covered by a region.
        starts: first flattened position of each region.
        ends:   last flattened position of each region.
        depths: nesting depth of each region (0 = outermost).
        types:  short type tag for each region.
    """

    masks: torch.Tensor  # BoolTensor[R, H, W]
    starts: List[int]  # inclusive flattened start index, one per region
    ends: List[int]  # inclusive flattened end index, one per region
    depths: List[int]  # nesting depth, one per region
    types: List[str]  # region type hint, one per region
class ScopeImbalanceError(Exception):
    """Signal that START/END scope markers are too unbalanced to match reliably."""
class ScopePooler(nn.Module):
    """
    Extract semantic regions from a palette using scope markers.

    This module identifies code scopes (functions, loops, classes, etc.)
    by matching START_OF_SCOPE (0) and END_OF_SCOPE (1) tokens.

    Algorithm:
        1. Flatten the palette to a 1D (row-major) sequence.
        2. Stack-based matching of scope markers.
        3. Gather the features covered by each matched region.
        4. Pool features via mean+max aggregation and a learned projection.

    Edge cases handled:
        - Unbalanced scopes: warning + best-effort matching; a severe
          imbalance (> 50%) falls back to the uniform grid.
        - Nested scopes: via stack depth tracking.
        - No scope markers found: fallback to the uniform grid.
        - Regions shorter than ``min_region_size`` are dropped; if none
          survive, the whole sequence becomes a single region.
    """

    def __init__(
        self,
        hidden_dim: int = 768,
        min_region_size: int = 2,
        fallback_grid_size: int = 4,
    ):
        """
        Args:
            hidden_dim: Feature dimension D.
            min_region_size: Minimum number of flattened positions per region.
            fallback_grid_size: Number of grid cells used when no scopes are
                found or the markers are severely unbalanced.
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        self.min_region_size = min_region_size
        self.fallback_grid_size = fallback_grid_size
        # Learned pooling projection: concat [mean, max] (2D) -> D.
        self.pool_proj = nn.Linear(hidden_dim * 2, hidden_dim)

    def forward(
        self,
        features: torch.Tensor,  # (B, H, W, D)
        palette: torch.Tensor,  # (B, H, W)
    ) -> Tuple[torch.Tensor, List[RegionMetadata]]:
        """
        Extract semantic regions for every sample and pool their features.

        Args:
            features: (B, H, W, D) - ViT output features; D == hidden_dim.
            palette: (B, H, W) - palette indices (0/1 are scope markers).

        Returns:
            regions: (B, R_max, D) - per-region pooled features. Sample b's
                real regions are rows ``[:len(metadata[b].starts)]``; any
                later rows are zero padding.
            metadata: List[RegionMetadata] - one per batch item.

        Guarantees:
            - R_max >= 1 (every sample yields at least one region; an empty
              batch returns a (0, 1, D) tensor).
            - All real regions are non-empty.
            - Real (non-padding) rows are L2-normalized.
        """
        B, H, W, D = features.shape
        assert palette.shape == (B, H, W), f"Shape mismatch: features{features.shape} vs palette{palette.shape}"
        assert D == self.hidden_dim, f"Hidden dim mismatch: {D} != {self.hidden_dim}"

        if B == 0:
            # Nothing to pool; keep the documented (B, R, D) layout with R = 1.
            return features.new_zeros((0, 1, D)), []

        all_regions = []
        all_metadata = []
        for feat_b, pal_b in zip(features, palette):
            regions_b, meta_b = self._extract_regions_single(feat_b, pal_b, H, W)
            all_regions.append(regions_b)  # (R_b, D)
            all_metadata.append(meta_b)

        # Zero-pad every sample to the largest region count in the batch.
        batched_regions = torch.nn.utils.rnn.pad_sequence(
            all_regions, batch_first=True
        )  # (B, R_max, D)
        return batched_regions, all_metadata

    def _extract_regions_single(
        self,
        features: torch.Tensor,  # (H, W, D)
        palette: torch.Tensor,  # (H, W)
        H: int,
        W: int,
    ) -> Tuple[torch.Tensor, RegionMetadata]:
        """
        Extract and pool the regions of a single sample.

        Args:
            features: (H, W, D) features of one sample.
            palette: (H, W) palette of one sample.
            H, W: spatial size of the sample.

        Returns:
            regions: (R, D) - L2-normalized pooled features, R >= 1.
            metadata: RegionMetadata describing the R regions.
        """
        # 1. Flatten to a sequence. reshape (not view): a per-sample slice of
        #    a permuted/strided batch tensor need not be contiguous.
        seq = palette.reshape(-1)  # (H*W,)
        features_flat = features.reshape(-1, self.hidden_dim)  # (H*W, D)

        # 2. Match scopes, falling back to the uniform grid when matching is
        #    impossible (severe imbalance) or there are no markers at all.
        try:
            scope_pairs, depths = self._match_scopes(seq)
        except ScopeImbalanceError as e:
            logging.warning("%s. Using fallback uniform grid.", e)
            scope_pairs, depths = self._fallback_uniform_grid(H, W)
        else:
            if not scope_pairs:
                scope_pairs, depths = self._fallback_uniform_grid(H, W)

        # 3. Drop regions that are too small (inclusive bounds).
        kept = [
            (pair, depth)
            for pair, depth in zip(scope_pairs, depths)
            if pair[1] - pair[0] + 1 >= self.min_region_size
        ]
        if not kept:
            # No valid regions - use the full sequence.
            kept = [((0, H * W - 1), 0)]
        valid_pairs = [pair for pair, _ in kept]
        valid_depths = [depth for _, depth in kept]
        num_regions = len(valid_pairs)

        # 4. Mean+max statistics and a flattened mask for each region.
        stats = []
        masks = torch.zeros(
            (num_regions, H * W), dtype=torch.bool, device=palette.device
        )
        for r, (start, end) in enumerate(valid_pairs):
            region_feat = features_flat[start:end + 1]  # (L, D)
            mean_pool = region_feat.mean(dim=0)  # (D,)
            max_pool = region_feat.max(dim=0).values  # (D,)
            stats.append(torch.cat([mean_pool, max_pool], dim=0))  # (2D,)
            masks[r, start:end + 1] = True

        # Project all regions in one call, then L2-normalize each row.
        pooled = self.pool_proj(torch.stack(stats, dim=0))  # (R, D)
        regions = torch.nn.functional.normalize(pooled, dim=-1)

        metadata = RegionMetadata(
            masks=masks.view(num_regions, H, W),
            starts=[start for start, _ in valid_pairs],
            ends=[end for _, end in valid_pairs],
            depths=valid_depths,
            types=['scope'] * num_regions,  # Generic type for now
        )
        return regions, metadata

    def _match_scopes(
        self,
        seq: torch.Tensor,  # (N,)
    ) -> Tuple[List[Tuple[int, int]], List[int]]:
        """
        Stack-based matching of START (0) / END (1) scope markers.

        Args:
            seq: 1-D tensor of palette tokens.

        Returns:
            pairs: (start_idx, end_idx) tuples, both inclusive. Matched pairs
                come in closing order, followed by any auto-closed unmatched
                STARTs (outermost first).
            depths: nesting depth of each pair (0 = outermost).

        Raises:
            ScopeImbalanceError: if the START and END counts differ by more
                than 50% of the larger count. Checked before matching, so no
                per-token warnings are logged for a result that is discarded.

        Edge cases:
            - Unmatched END: skipped with a warning.
            - Unmatched START: closed at the last sequence position.
            - No markers: empty lists (caller handles the fallback).
        """
        START_OF_SCOPE = 0
        END_OF_SCOPE = 1

        is_start = seq == START_OF_SCOPE
        is_end = seq == END_OF_SCOPE
        num_starts = int(is_start.sum().item())
        num_ends = int(is_end.sum().item())

        if abs(num_starts - num_ends) > max(num_starts, num_ends) * 0.5:
            # More than 50% imbalance - critical error
            raise ScopeImbalanceError(
                f"Severe scope imbalance: {num_starts} starts vs {num_ends} ends"
            )

        # Only marker positions matter; visit them in sequence order instead
        # of iterating over every token in Python.
        marker_pos = torch.nonzero(is_start | is_end, as_tuple=True)[0]
        marker_opens = is_start[marker_pos]

        stack: List[Tuple[int, int]] = []  # (index, depth) of open scopes
        pairs: List[Tuple[int, int]] = []
        depths: List[int] = []
        for i, opens in zip(marker_pos.tolist(), marker_opens.tolist()):
            if opens:
                # Depth = number of scopes already open.
                stack.append((i, len(stack)))
            elif stack:
                start_idx, depth = stack.pop()
                pairs.append((start_idx, i))
                depths.append(depth)
            else:
                logging.warning("Unmatched END_OF_SCOPE at position %d", i)

        if stack:
            logging.warning("%d unmatched START_OF_SCOPE tokens", len(stack))
            last = seq.numel() - 1
            for start_idx, depth in stack:
                pairs.append((start_idx, last))
                depths.append(depth)

        return pairs, depths

    def _fallback_uniform_grid(
        self,
        H: int,
        W: int,
    ) -> Tuple[List[Tuple[int, int]], List[int]]:
        """
        Split the flattened H*W sequence into equal contiguous chunks.

        The last chunk absorbs the remainder. The cell count is clamped to
        ``[1, H*W]`` so no zero-length cells are produced.

        Args:
            H, W: palette dimensions.

        Returns:
            pairs: inclusive (start, end) per chunk.
            depths: all 0 (flat).
        """
        total = H * W
        grid_size = max(1, min(self.fallback_grid_size, total))
        region_size = total // grid_size
        pairs = [
            (
                i * region_size,
                (i + 1) * region_size - 1 if i < grid_size - 1 else total - 1,
            )
            for i in range(grid_size)
        ]
        return pairs, [0] * grid_size

    def visualize_regions(
        self,
        palette: torch.Tensor,  # (H, W)
        metadata: RegionMetadata,
    ) -> str:
        """
        Generate a human-readable, one-line-per-region summary.

        Args:
            palette: (H, W) palette; only its 2-D shape is checked.
            metadata: regions to describe.

        Returns:
            Multi-line string; each region is indented by its depth.
        """
        _h, _w = palette.shape  # raises if palette is not 2-D
        lines = [f"Detected {len(metadata.starts)} regions:"]
        for i, (start, end, depth) in enumerate(
            zip(metadata.starts, metadata.ends, metadata.depths)
        ):
            region_size = end - start + 1
            indent = " " * depth
            lines.append(
                f"{indent}Region {i}: [{start:4d}, {end:4d}] "
                f"(size={region_size:3d}, depth={depth})"
            )
        return "\n".join(lines)
|