"""

Scope-Aware Pooler



Extracts semantic regions from palette using scope markers (0=START, 1=END).

Implements exact scope matching via stack-based algorithm.

"""

import logging
import torch
import torch.nn as nn
from typing import List, Tuple, NamedTuple


class RegionMetadata(NamedTuple):
    """

    Metadata about detected semantic regions



    Fields:

    - masks: BoolTensor[R, H, W] - spatial masks for each region

    - starts: List[int] - flattened start indices

    - ends: List[int] - flattened end indices

    - depths: List[int] - nesting depth of each region

    - types: List[str] - region type hints

    """
    masks: torch.Tensor
    starts: List[int]
    ends: List[int]
    depths: List[int]
    types: List[str]


class ScopeImbalanceError(Exception):
    """Raised when scope markers are critically unbalanced"""
    pass


class ScopePooler(nn.Module):
    """

    Extract semantic regions from palette using scope markers



    This module identifies code scopes (functions, loops, classes, etc.)

    by matching START_OF_SCOPE (0) and END_OF_SCOPE (1) tokens.



    Algorithm:

    1. Flatten palette to 1D sequence

    2. Stack-based matching of scope markers

    3. Extract features for each matched region

    4. Pool features via mean+max aggregation



    Edge Cases Handled:

    - Unbalanced scopes (warning + best-effort matching)

    - Nested scopes (via stack depth tracking)

    - No scopes found (fallback to uniform grid)

    - Empty regions (skip + warning)

    """

    def __init__(
        self,
        hidden_dim: int = 768,
        min_region_size: int = 2,
        fallback_grid_size: int = 4
    ):
        """

        Args:

            hidden_dim: Feature dimension

            min_region_size: Minimum tokens per region

            fallback_grid_size: Grid size when no scopes found

        """
        super().__init__()

        self.hidden_dim = hidden_dim
        self.min_region_size = min_region_size
        self.fallback_grid_size = fallback_grid_size

        # Learned pooling projection
        # Concat [mean, max] then project back to hidden_dim
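        # (mean summarizes the whole region; max keeps its most salient activations)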
        self.pool_proj = nn.Linear(hidden_dim * 2, hidden_dim)

    def forward(
        self,
        features: torch.Tensor,  # (B, H, W, D)
        palette: torch.Tensor    # (B, H, W)
    ) -> Tuple[torch.Tensor, List[RegionMetadata]]:
        """

        Extract semantic regions and pool features



        Args:

            features: (B, H, W, D) - ViT output features

            palette: (B, H, W) - palette indices



        Returns:

            regions: (B, R, D) - per-region pooled features

            metadata: List[RegionMetadata] - one per batch item



        Guarantees:

            - R >= 1 always (at least one region)

            - All regions non-empty

            - Features normalized (unit norm)

        """
        B, H, W, D = features.shape
        assert palette.shape == (B, H, W), f"Shape mismatch: features{features.shape} vs palette{palette.shape}"
        assert D == self.hidden_dim, f"Hidden dim mismatch: {D} != {self.hidden_dim}"

        all_regions = []
        all_metadata = []

        for b in range(B):
            feat_b = features[b]  # (H, W, D)
            pal_b = palette[b]    # (H, W)

            # Extract regions for this sample
            regions_b, meta_b = self._extract_regions_single(feat_b, pal_b, H, W)

            all_regions.append(regions_b)  # (R_b, D)
            all_metadata.append(meta_b)

        # Pad to max number of regions in batch
        max_regions = max(r.shape[0] for r in all_regions)
        padded_regions = []

        for regions_b in all_regions:
            R_b = regions_b.shape[0]
            if R_b < max_regions:
                # Pad with zeros
                padding = torch.zeros(
                    max_regions - R_b, D,
                    device=regions_b.device,
                    dtype=regions_b.dtype
                )
                regions_b = torch.cat([regions_b, padding], dim=0)
            padded_regions.append(regions_b)

        batched_regions = torch.stack(padded_regions, dim=0)  # (B, R_max, D)

        return batched_regions, all_metadata

    def _extract_regions_single(
        self,
        features: torch.Tensor,  # (H, W, D)
        palette: torch.Tensor,   # (H, W)
        H: int,
        W: int
    ) -> Tuple[torch.Tensor, RegionMetadata]:
        """

        Extract regions from a single sample



        Returns:

            regions: (R, D) - pooled features

            metadata: RegionMetadata

        """
        # 1. Flatten to sequence
        seq = palette.flatten()  # (H*W,)
        features_flat = features.view(-1, self.hidden_dim)  # (H*W, D)

        # 2. Match scopes
        try:
            scope_pairs, depths = self._match_scopes(seq)
        except ScopeImbalanceError as e:
            # Critical error - scopes too broken to recover
            logging.warning(f"{e}. Using fallback uniform grid.")
            scope_pairs, depths = self._fallback_uniform_grid(H, W)

        # 3. Filter invalid regions
        valid_pairs = []
        valid_depths = []
        for (start, end), depth in zip(scope_pairs, depths):
            if (end - start + 1) >= self.min_region_size:
                valid_pairs.append((start, end))
                valid_depths.append(depth)

        if not valid_pairs:
            # No valid regions - use full sequence
            valid_pairs = [(0, H*W - 1)]
            valid_depths = [0]

        # 4. Extract features for each region
        region_features = []
        region_masks = []
        starts = []
        ends = []

        for (start, end) in valid_pairs:
            # Extract features in range
            region_feat = features_flat[start:end+1]  # (L, D)

            # Pool: mean + max
            mean_pool = region_feat.mean(dim=0)  # (D,)
            max_pool = region_feat.max(dim=0)[0]  # (D,)

            # Concatenate and project
            combined = torch.cat([mean_pool, max_pool], dim=0)  # (2D,)
            pooled = self.pool_proj(combined)  # (D,)

            # Normalize
            pooled = torch.nn.functional.normalize(pooled, dim=0)

            region_features.append(pooled)

            # Create mask
            mask = torch.zeros(H * W, dtype=torch.bool, device=palette.device)
            mask[start:end+1] = True
            mask_2d = mask.view(H, W)
            region_masks.append(mask_2d)

            starts.append(start)
            ends.append(end)

        # Stack regions
        regions = torch.stack(region_features, dim=0)  # (R, D)
        masks = torch.stack(region_masks, dim=0)  # (R, H, W)

        # Create metadata
        types = ['scope'] * len(valid_pairs)  # Generic type for now
        metadata = RegionMetadata(
            masks=masks,
            starts=starts,
            ends=ends,
            depths=valid_depths,
            types=types
        )

        return regions, metadata

    def _match_scopes(
        self,
        seq: torch.Tensor  # (N,)
    ) -> Tuple[List[Tuple[int, int]], List[int]]:
        """

        Stack-based scope matching



        Returns:

            pairs: List of (start_idx, end_idx) tuples

            depths: List of nesting depths



        Algorithm:

        - Maintain stack of open scope indices

        - When seeing START (0), push index

        - When seeing END (1), pop and create pair

        - Track depth = current stack size



        Edge Cases:

        - Unmatched START: close at sequence end

        - Unmatched END: skip with warning

        - No scopes: return empty list (caller handles)

        """
        START_OF_SCOPE = 0
        END_OF_SCOPE = 1
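        # Worked example (illustrative token values): seq = [0, 3, 0, 4, 1, 2, 1]
        #   i=0 START -> push (0, depth=0)
        #   i=2 START -> push (2, depth=1)
        #   i=4 END   -> pop (2, 1) -> pair (2, 4)
        #   i=6 END   -> pop (0, 0) -> pair (0, 6)
        #   => pairs=[(2, 4), (0, 6)], depths=[1, 0]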

        stack = []  # Stack of (index, depth)
        pairs = []
        depths = []

        seq_np = seq.cpu().numpy()  # Faster iteration

        for i, token in enumerate(seq_np):
            if token == START_OF_SCOPE:
                # Open new scope
                depth = len(stack)
                stack.append((i, depth))

            elif token == END_OF_SCOPE:
                # Close scope
                if stack:
                    start_idx, depth = stack.pop()
                    pairs.append((start_idx, i))
                    depths.append(depth)
                else:
                    # Unmatched END - skip
                    logging.warning(f"Unmatched END_OF_SCOPE at position {i}")

        # Handle unmatched STARTs
        if stack:
            logging.warning(f"{len(stack)} unmatched START_OF_SCOPE tokens")
            # Close them at sequence end
            seq_len = len(seq_np)
            for start_idx, depth in stack:
                pairs.append((start_idx, seq_len - 1))
                depths.append(depth)

        # Validate: check for severe imbalance
        num_starts = (seq == START_OF_SCOPE).sum().item()
        num_ends = (seq == END_OF_SCOPE).sum().item()
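        # Illustrative threshold (assumed counts): 10 starts vs 2 ends
        #   -> |10 - 2| = 8 > 0.5 * max(10, 2) = 5 -> ScopeImbalanceError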

        if abs(num_starts - num_ends) > max(num_starts, num_ends) * 0.5:
            # More than 50% imbalance - critical error
            raise ScopeImbalanceError(
                f"Severe scope imbalance: {num_starts} starts vs {num_ends} ends"
            )

        return pairs, depths

    def _fallback_uniform_grid(
        self,
        H: int,
        W: int
    ) -> Tuple[List[Tuple[int, int]], List[int]]:
        """

        Fallback when scope matching fails



        Returns uniform grid of regions



        Args:

            H, W: palette dimensions



        Returns:

            pairs: List of (start, end) for grid cells

            depths: All depth=0 (flat)

        """
        total = H * W
        grid_size = self.fallback_grid_size
        region_size = total // grid_size
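        # Illustrative split (assumed sizes): total = 64, grid_size = 4
        #   -> region_size = 16, pairs = [(0, 15), (16, 31), (32, 47), (48, 63)]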

        pairs = []
        for i in range(grid_size):
            start = i * region_size
            end = (i + 1) * region_size - 1 if i < grid_size - 1 else total - 1
            pairs.append((start, end))

        depths = [0] * grid_size

        return pairs, depths

    def visualize_regions(
        self,
        palette: torch.Tensor,  # (H, W)
        metadata: RegionMetadata
    ) -> str:
        """

        Generate human-readable visualization of regions



        Returns: String representation

        """
        H, W = palette.shape
        output = []
        output.append(f"Detected {len(metadata.starts)} regions:")

        for i, (start, end, depth) in enumerate(zip(
            metadata.starts,
            metadata.ends,
            metadata.depths
        )):
            region_size = end - start + 1
            indent = "  " * depth
            output.append(
                f"{indent}Region {i}: [{start:4d}, {end:4d}] "
                f"(size={region_size:3d}, depth={depth})"
            )

        return "\n".join(output)