| """ | |
| Native Sparse Attention (NSA) Model for Pupil Segmentation. | |
| Implementation based on DeepSeek's NSA paper: | |
| "Native Sparse Attention: Hardware-Aligned and Natively Trainable Sparse Attention" | |
| Adapted for 2D vision/segmentation tasks with domain-specific optimizations for | |
| pupil segmentation where: | |
| - Intense pixel localization is required | |
| - The pupil is only found on the eye (spatial locality) | |
| - OpenEDS provides multi-class data beyond pupil | |
| Architecture: | |
| - Encoder with NSA blocks for hierarchical feature extraction | |
| - Decoder with skip connections for precise segmentation | |
| - NSA combines: Compression (global), Selection (important), Sliding Window (local) | |
| """ | |
| import math | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| # ============================================================================= | |
| # Core Building Blocks | |
| # ============================================================================= | |
class ConvBNReLU(nn.Module):
    """Convolution + BatchNorm + activation block (GELU by default, despite the name)."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        padding: int = 1,
        groups: int = 1,
        bias: bool = False,
        activation: bool = True,
    ):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=groups,
            bias=bias,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.act = nn.GELU() if activation else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.act(self.bn(self.conv(x)))


class PatchEmbedding(nn.Module):
    """
    Embed image patches into tokens for attention processing.
    Uses strided convolutions to reduce spatial resolution.
    """

    def __init__(
        self,
        in_channels: int = 1,
        embed_dim: int = 32,
        patch_size: int = 4,
    ):
        super().__init__()
        self.patch_size = patch_size
        mid_dim = embed_dim // 2
        # Two-stage downsampling for a smoother feature transition
        self.conv1 = ConvBNReLU(in_channels, mid_dim, kernel_size=3, stride=2, padding=1)
        self.conv2 = ConvBNReLU(mid_dim, embed_dim, kernel_size=3, stride=2, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input image (B, C, H, W)
        Returns:
            Embedded patches (B, embed_dim, H//4, W//4)
        """
        x = self.conv1(x)
        x = self.conv2(x)
        return x
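

# Illustrative shape check (a sketch, not part of the model; nothing here
# runs on import): PatchEmbedding's two stride-2 convolutions reduce each
# spatial dimension by 4x on an OpenEDS-sized grayscale batch.
def _demo_patch_embedding() -> None:
    embed = PatchEmbedding(in_channels=1, embed_dim=32)
    x = torch.randn(2, 1, 400, 640)
    out = embed(x)
    assert out.shape == (2, 32, 100, 160)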


# =============================================================================
# Token Compression Module
# =============================================================================
class TokenCompression(nn.Module):
    """
    Compress spatial blocks into single tokens for coarse-grained attention.

    From NSA paper Eq. 7:
        K_cmp = {φ(k_{id+1:id+l}) | 0 ≤ i ≤ ⌊(t-l)/d⌋}

    Adapted for 2D: compress spatial blocks into representative tokens.
    """

    def __init__(
        self,
        dim: int,
        block_size: int = 4,
        stride: int = 2,
    ):
        super().__init__()
        self.block_size = block_size
        self.stride = stride
        # Learnable compression MLPs (the keys additionally receive an
        # intra-block position encoding before compression)
        self.compress_k = nn.Sequential(
            nn.Linear(dim * block_size * block_size, dim * 2),
            nn.GELU(),
            nn.Linear(dim * 2, dim),
        )
        self.compress_v = nn.Sequential(
            nn.Linear(dim * block_size * block_size, dim * 2),
            nn.GELU(),
            nn.Linear(dim * 2, dim),
        )
        # Intra-block position encoding
        self.pos_embed = nn.Parameter(
            torch.randn(1, block_size * block_size, dim) * 0.02
        )

    def forward(
        self,
        k: torch.Tensor,
        v: torch.Tensor,
        spatial_size: tuple[int, int],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Compress keys and values into block-level representations.

        Args:
            k: Keys (B, N, dim) where N = H * W
            v: Values (B, N, dim)
            spatial_size: (H, W) tuple for non-square inputs
        Returns:
            k_cmp: Compressed keys (B, N_cmp, dim)
            v_cmp: Compressed values (B, N_cmp, dim)
        """
        B, N, dim = k.shape
        H, W = spatial_size
        bs = self.block_size
        stride = self.stride

        # Reshape sequences back to 2D feature maps. Use reshape instead of
        # view to handle non-contiguous tensors.
        k_2d = k.reshape(B, H, W, dim).permute(0, 3, 1, 2).contiguous()  # (B, dim, H, W)
        v_2d = v.reshape(B, H, W, dim).permute(0, 3, 1, 2).contiguous()

        # Extract overlapping blocks: (B, dim*bs*bs, n_blocks), where
        # n_blocks = ((H - bs) // stride + 1) * ((W - bs) // stride + 1)
        k_blocks = F.unfold(k_2d, kernel_size=bs, stride=stride)
        v_blocks = F.unfold(v_2d, kernel_size=bs, stride=stride)

        # Reshape for compression: (B, n_blocks, dim*bs*bs)
        n_blocks = k_blocks.shape[2]
        k_blocks = k_blocks.permute(0, 2, 1).contiguous()
        v_blocks = v_blocks.permute(0, 2, 1).contiguous()

        # Add the intra-block position encoding to the keys. Unfold lays each
        # column out channel-major, so split the last axis as (dim, bs*bs) and
        # transpose so that pos_embed indexes spatial positions within a block.
        k_blocks_pos = (
            k_blocks.reshape(B, n_blocks, dim, bs * bs).transpose(-2, -1)
            + self.pos_embed
        )
        k_blocks_pos = k_blocks_pos.transpose(-2, -1).reshape(
            B, n_blocks, dim * bs * bs
        )

        # Compress each block to a single token
        k_cmp = self.compress_k(k_blocks_pos)
        v_cmp = self.compress_v(v_blocks)
        return k_cmp, v_cmp
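

# Illustrative shape check (a sketch, not part of the model): with
# block_size=4 and stride=2, a 16x16 token grid compresses to
# ((16 - 4) // 2 + 1) ** 2 = 49 block tokens.
def _demo_token_compression() -> None:
    comp = TokenCompression(dim=32, block_size=4, stride=2)
    k = torch.randn(2, 16 * 16, 32)
    v = torch.randn(2, 16 * 16, 32)
    k_cmp, v_cmp = comp(k, v, (16, 16))
    assert k_cmp.shape == (2, 49, 32)
    assert v_cmp.shape == (2, 49, 32)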


# =============================================================================
# Token Selection Module
# =============================================================================
class TokenSelection(nn.Module):
    """
    Select important token blocks based on attention scores.

    From NSA paper Eq. 8-12:
    - Compute importance from compressed attention scores
    - Select top-n blocks for fine-grained attention

    For pupil segmentation: identifies the most relevant spatial regions.
    """

    def __init__(
        self,
        dim: int,
        block_size: int = 4,
        num_select: int = 4,
    ):
        super().__init__()
        self.block_size = block_size
        self.num_select = num_select
        self.dim = dim

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        attn_scores_cmp: torch.Tensor,
        spatial_size: tuple[int, int],
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Select important blocks based on compressed attention scores.

        Args:
            q: Queries (B, num_heads, N, head_dim); currently unused, kept
               for interface symmetry with the other branches
            k: Keys (B, N, dim)
            v: Values (B, N, dim)
            attn_scores_cmp: Attention from compression (B, num_heads, N, N_cmp)
            spatial_size: (height, width) of feature map
        Returns:
            k_slc: Selected keys (B, num_select * block_size**2, dim)
            v_slc: Selected values (B, num_select * block_size**2, dim)
            indices: Selected block indices (B, num_select)
        """
        B, num_heads, N, N_cmp = attn_scores_cmp.shape
        H, W = spatial_size
        bs = self.block_size

        # Sum attention across heads for shared selection (GQA-style)
        importance = attn_scores_cmp.sum(dim=1)  # (B, N, N_cmp)
        # Average importance across queries to get per-block scores
        block_importance = importance.mean(dim=1)  # (B, N_cmp)

        # Select top-n blocks
        num_select = min(self.num_select, N_cmp)
        _, indices = torch.topk(block_importance, num_select, dim=-1)  # (B, num_select)

        # Reshape k, v to 2D and extract all non-overlapping blocks
        k_2d = k.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        v_2d = v.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        k_blocks = F.unfold(k_2d, kernel_size=bs, stride=bs)  # (B, dim*bs*bs, n_blocks)
        v_blocks = F.unfold(v_2d, kernel_size=bs, stride=bs)
        n_blocks = k_blocks.shape[2]

        # Unfold returns columns in channel-major order; split the channel
        # axis as (dim, bs*bs) and permute so each block is a (bs*bs, dim)
        # set of per-position token vectors.
        k_blocks = (
            k_blocks.reshape(B, -1, bs * bs, n_blocks)
            .permute(0, 3, 2, 1)
            .contiguous()
        )  # (B, n_blocks, bs*bs, dim)
        v_blocks = (
            v_blocks.reshape(B, -1, bs * bs, n_blocks)
            .permute(0, 3, 2, 1)
            .contiguous()
        )

        # The compressed blocks overlap while the selection blocks do not, so
        # the two index spaces differ. This is a simplified mapping: clamp the
        # compressed indices into the valid selection-block range rather than
        # remapping them exactly.
        indices = indices.clamp(0, n_blocks - 1)

        # Gather selected blocks
        indices_expanded = (
            indices.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, bs * bs, k.shape[-1])
        )
        k_slc = torch.gather(k_blocks, 1, indices_expanded)  # (B, num_select, bs*bs, dim)
        v_slc = torch.gather(v_blocks, 1, indices_expanded)

        # Flatten selected blocks into a token sequence
        k_slc = k_slc.view(B, num_select * bs * bs, -1)
        v_slc = v_slc.view(B, num_select * bs * bs, -1)
        return k_slc, v_slc, indices
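

# Illustrative shape check (a sketch, not part of the model): selecting
# num_select=2 blocks of size 4x4 from a 16x16 grid yields 2 * 16 = 32
# fine-grained tokens. N_cmp = 49 matches the compression demo above.
def _demo_token_selection() -> None:
    sel = TokenSelection(dim=32, block_size=4, num_select=2)
    k = torch.randn(2, 16 * 16, 32)
    v = torch.randn(2, 16 * 16, 32)
    q = torch.randn(2, 2, 16 * 16, 16)  # (B, heads, N, head_dim); unused
    attn = torch.rand(2, 2, 16 * 16, 49).softmax(dim=-1)
    k_slc, v_slc, indices = sel(q, k, v, attn, (16, 16))
    assert k_slc.shape == (2, 32, 32) and v_slc.shape == (2, 32, 32)
    assert indices.shape == (2, 2)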


# =============================================================================
# Sliding Window Attention
# =============================================================================
class SlidingWindowAttention(nn.Module):
    """
    Local windowed attention for fine-grained local context.

    From NSA paper Section 3.3.3: maintains recent tokens in a window for
    local pattern recognition. Here the 1D sliding window becomes 2D
    non-overlapping windows with relative position bias (Swin-style).

    For pupil segmentation: critical for precise boundary delineation.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 2,
        window_size: int = 7,
        qkv_bias: bool = True,
    ):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

        # Relative position bias table, one entry per (dy, dx) offset
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads)
        )
        nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)

        # Precompute pairwise relative position indices within a window
        coords_h = torch.arange(window_size)
        coords_w = torch.arange(window_size)
        coords = torch.stack(
            torch.meshgrid(coords_h, coords_w, indexing="ij")
        )  # (2, ws, ws)
        coords_flatten = coords.flatten(1)  # (2, ws*ws)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        relative_coords[:, :, 0] += window_size - 1  # shift to non-negative
        relative_coords[:, :, 1] += window_size - 1
        relative_coords[:, :, 0] *= 2 * window_size - 1
        relative_position_index = relative_coords.sum(-1)
        self.register_buffer(
            "relative_position_index", relative_position_index
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply windowed attention.

        Args:
            x: Input features (B, C, H, W)
        Returns:
            Output features (B, C, H, W)
        """
        B, C, H, W = x.shape
        ws = self.window_size

        # Pad to a multiple of the window size
        pad_h = (ws - H % ws) % ws
        pad_w = (ws - W % ws) % ws
        if pad_h > 0 or pad_w > 0:
            x = F.pad(x, (0, pad_w, 0, pad_h))
        _, _, Hp, Wp = x.shape

        # Partition into windows: (B * num_windows, ws*ws, C)
        x = x.view(B, C, Hp // ws, ws, Wp // ws, ws)
        x = x.permute(0, 2, 4, 3, 5, 1).contiguous()
        x = x.view(-1, ws * ws, C)

        # Compute QKV per window
        B_win = x.shape[0]
        qkv = self.qkv(x).reshape(B_win, ws * ws, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Scaled dot-product attention with relative position bias
        attn = (q @ k.transpose(-2, -1)) * self.scale
        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)
        ].view(ws * ws, ws * ws, -1)
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
        attn = attn + relative_position_bias.unsqueeze(0)
        attn = attn.softmax(dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(B_win, ws * ws, C)
        x = self.proj(x)

        # Merge windows back into a feature map
        num_windows_h = Hp // ws
        num_windows_w = Wp // ws
        x = x.view(B, num_windows_h, num_windows_w, ws, ws, C)
        x = x.permute(0, 5, 1, 3, 2, 4).contiguous()
        x = x.view(B, C, Hp, Wp)

        # Remove padding
        if pad_h > 0 or pad_w > 0:
            x = x[:, :, :H, :W]
        return x
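

# Illustrative shape check (a sketch, not part of the model): the module
# pads internally, so spatial sizes need not be multiples of the window size.
def _demo_sliding_window_attention() -> None:
    attn = SlidingWindowAttention(dim=32, num_heads=2, window_size=7)
    x = torch.randn(2, 32, 25, 40)  # deliberately not a multiple of 7
    assert attn(x).shape == x.shape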


# =============================================================================
# Native Sparse Attention (NSA) - Core Module
# =============================================================================
class SpatialNSA(nn.Module):
    """
    Native Sparse Attention adapted for 2D spatial features.

    Combines three attention paths (NSA paper Eq. 5):
        o* = Σ g_c · Attn(q, K̃_c, Ṽ_c)  for c ∈ {cmp, slc, win}

    Components:
    1. Compressed attention: global coarse-grained context
    2. Selected attention: fine-grained important regions
    3. Sliding window: local context for precise boundaries
    4. Gated aggregation: learned combination
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 2,
        compress_block_size: int = 4,
        compress_stride: int = 2,
        select_block_size: int = 4,
        num_select: int = 4,
        window_size: int = 7,
        qkv_bias: bool = True,
    ):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5

        # Separate QKV for each branch (prevents shortcut learning)
        self.qkv_cmp = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.qkv_slc = nn.Linear(dim, dim * 3, bias=qkv_bias)

        # Token compression module
        self.compression = TokenCompression(
            dim=dim, block_size=compress_block_size, stride=compress_stride
        )
        # Token selection module
        self.selection = TokenSelection(
            dim=dim, block_size=select_block_size, num_select=num_select
        )
        # Sliding window attention
        self.window_attn = SlidingWindowAttention(
            dim=dim,
            num_heads=num_heads,
            window_size=window_size,
            qkv_bias=qkv_bias,
        )

        # Output projections
        self.proj_cmp = nn.Linear(dim, dim)
        self.proj_slc = nn.Linear(dim, dim)

        # Gating mechanism (NSA paper Eq. 5)
        self.gate = nn.Sequential(
            nn.Linear(dim, dim // 4),
            nn.GELU(),
            nn.Linear(dim // 4, 3),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply Native Sparse Attention.

        Args:
            x: Input features (B, C, H, W)
        Returns:
            Output features (B, C, H, W)
        """
        B, C, H, W = x.shape
        N = H * W

        # Reshape to a token sequence
        x_seq = x.flatten(2).transpose(1, 2)  # (B, N, C)

        # =================================================================
        # Branch 1: Compressed Attention (Global Coarse-Grained)
        # =================================================================
        qkv_cmp = self.qkv_cmp(x_seq)
        qkv_cmp = qkv_cmp.reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv_cmp = qkv_cmp.permute(2, 0, 3, 1, 4)
        q_cmp, k_cmp_raw, v_cmp_raw = qkv_cmp[0], qkv_cmp[1], qkv_cmp[2]

        # Reshape k, v for compression
        k_for_cmp = k_cmp_raw.transpose(1, 2).reshape(B, N, C)
        v_for_cmp = v_cmp_raw.transpose(1, 2).reshape(B, N, C)

        # Compress tokens
        k_cmp, v_cmp = self.compression(k_for_cmp, v_for_cmp, (H, W))
        N_cmp = k_cmp.shape[1]

        # Reshape for multi-head attention
        k_cmp = k_cmp.view(B, N_cmp, self.num_heads, self.head_dim).transpose(1, 2)
        v_cmp = v_cmp.view(B, N_cmp, self.num_heads, self.head_dim).transpose(1, 2)

        # Attend over the compressed tokens
        attn_cmp = (q_cmp @ k_cmp.transpose(-2, -1)) * self.scale  # (B, heads, N, N_cmp)
        attn_cmp_softmax = attn_cmp.softmax(dim=-1)
        o_cmp = attn_cmp_softmax @ v_cmp
        o_cmp = o_cmp.transpose(1, 2).reshape(B, N, C)
        o_cmp = self.proj_cmp(o_cmp)

        # =================================================================
        # Branch 2: Selected Attention (Fine-Grained Important)
        # =================================================================
        qkv_slc = self.qkv_slc(x_seq)
        qkv_slc = qkv_slc.reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv_slc = qkv_slc.permute(2, 0, 3, 1, 4)
        q_slc, k_slc_raw, v_slc_raw = qkv_slc[0], qkv_slc[1], qkv_slc[2]

        k_for_slc = k_slc_raw.transpose(1, 2).reshape(B, N, C)
        v_for_slc = v_slc_raw.transpose(1, 2).reshape(B, N, C)

        # Select important blocks based on compressed attention scores
        k_slc, v_slc, _ = self.selection(
            q_slc, k_for_slc, v_for_slc, attn_cmp_softmax, (H, W)
        )
        N_slc = k_slc.shape[1]
        k_slc = k_slc.view(B, N_slc, self.num_heads, self.head_dim).transpose(1, 2)
        v_slc = v_slc.view(B, N_slc, self.num_heads, self.head_dim).transpose(1, 2)

        # Attend over the selected tokens
        attn_slc = (q_slc @ k_slc.transpose(-2, -1)) * self.scale  # (B, heads, N, N_slc)
        attn_slc = attn_slc.softmax(dim=-1)
        o_slc = attn_slc @ v_slc
        o_slc = o_slc.transpose(1, 2).reshape(B, N, C)
        o_slc = self.proj_slc(o_slc)

        # =================================================================
        # Branch 3: Sliding Window Attention (Local Context)
        # =================================================================
        o_win = self.window_attn(x)
        o_win = o_win.flatten(2).transpose(1, 2)  # (B, N, C)

        # =================================================================
        # Gated Aggregation
        # =================================================================
        # Per-token gates for the three branches
        gates = self.gate(x_seq)  # (B, N, 3)
        g_cmp = gates[:, :, 0:1]
        g_slc = gates[:, :, 1:2]
        g_win = gates[:, :, 2:3]

        # Weighted combination
        out = g_cmp * o_cmp + g_slc * o_slc + g_win * o_win

        # Reshape back to spatial
        out = out.transpose(1, 2).reshape(B, C, H, W)
        return out
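

# Illustrative shape check (a sketch, not part of the model): all three
# branches map (B, C, H, W) -> (B, C, H, W), so SpatialNSA is
# shape-preserving and can drop into a residual connection.
def _demo_spatial_nsa() -> None:
    nsa = SpatialNSA(dim=32, num_heads=2)
    x = torch.randn(2, 32, 16, 16)
    assert nsa(x).shape == x.shape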


# =============================================================================
# NSA Block (Attention + FFN)
# =============================================================================
class NSABlock(nn.Module):
    """
    Complete NSA block with attention, normalization, and FFN.

    Structure:
    - Depthwise conv for local features (like EfficientViT)
    - Native Sparse Attention for global/selective features
    - FFN for channel mixing
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 2,
        mlp_ratio: float = 2.0,
        compress_block_size: int = 4,
        compress_stride: int = 2,
        select_block_size: int = 4,
        num_select: int = 4,
        window_size: int = 7,
    ):
        super().__init__()
        # Local feature extraction (depthwise conv)
        self.norm1 = nn.BatchNorm2d(dim)
        self.dw_conv = nn.Conv2d(dim, dim, kernel_size=3, padding=1, groups=dim)

        # NSA attention
        self.norm2 = nn.BatchNorm2d(dim)
        self.nsa = SpatialNSA(
            dim=dim,
            num_heads=num_heads,
            compress_block_size=compress_block_size,
            compress_stride=compress_stride,
            select_block_size=select_block_size,
            num_select=num_select,
            window_size=window_size,
        )

        # FFN
        self.norm3 = nn.LayerNorm(dim)
        hidden_dim = int(dim * mlp_ratio)
        self.ffn = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input features (B, C, H, W)
        Returns:
            Output features (B, C, H, W)
        """
        # Local features
        x = x + self.dw_conv(self.norm1(x))
        # NSA attention
        x = x + self.nsa(self.norm2(x))
        # FFN
        B, C, H, W = x.shape
        x_flat = x.flatten(2).transpose(1, 2)  # (B, N, C)
        x_flat = x_flat + self.ffn(self.norm3(x_flat))
        x = x_flat.transpose(1, 2).reshape(B, C, H, W)
        return x
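

# Illustrative shape check (a sketch, not part of the model): the full
# block (depthwise conv + NSA + FFN, each residual) is shape-preserving.
def _demo_nsa_block() -> None:
    block = NSABlock(dim=32, num_heads=2)
    x = torch.randn(2, 32, 16, 16)
    assert block(x).shape == x.shape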


# =============================================================================
# NSA Stage (Multiple Blocks + Optional Downsampling)
# =============================================================================
class NSAStage(nn.Module):
    """Stage containing multiple NSA blocks with optional downsampling."""

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        depth: int = 1,
        num_heads: int = 2,
        mlp_ratio: float = 2.0,
        compress_block_size: int = 4,
        compress_stride: int = 2,
        select_block_size: int = 4,
        num_select: int = 4,
        window_size: int = 7,
        downsample: bool = True,
    ):
        super().__init__()
        # Downsampling (2x strided conv) or 1x1 channel projection
        self.downsample = None
        if downsample:
            self.downsample = nn.Sequential(
                ConvBNReLU(in_dim, out_dim, kernel_size=3, stride=2, padding=1),
            )
        elif in_dim != out_dim:
            self.downsample = ConvBNReLU(
                in_dim, out_dim, kernel_size=1, stride=1, padding=0
            )

        # NSA blocks
        self.blocks = nn.ModuleList(
            [
                NSABlock(
                    dim=out_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    compress_block_size=compress_block_size,
                    compress_stride=compress_stride,
                    select_block_size=select_block_size,
                    num_select=num_select,
                    window_size=window_size,
                )
                for _ in range(depth)
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.downsample is not None:
            x = self.downsample(x)
        for block in self.blocks:
            x = block(x)
        return x


# =============================================================================
# NSA Encoder
# =============================================================================
class NSAEncoder(nn.Module):
    """
    NSA-based encoder for hierarchical feature extraction.
    Produces multi-scale features for the segmentation decoder.
    """

    def __init__(
        self,
        in_channels: int = 1,
        embed_dims: tuple = (32, 64, 96),
        depths: tuple = (1, 1, 1),
        num_heads: tuple = (2, 2, 4),
        mlp_ratios: tuple = (2, 2, 2),
        compress_block_sizes: tuple = (4, 4, 4),
        compress_strides: tuple = (2, 2, 2),
        select_block_sizes: tuple = (4, 4, 4),
        num_selects: tuple = (4, 4, 4),
        window_sizes: tuple = (7, 7, 7),
    ):
        super().__init__()
        # Patch embedding
        self.patch_embed = PatchEmbedding(
            in_channels=in_channels, embed_dim=embed_dims[0]
        )
        # Stage 1: no downsampling (already done in patch embed)
        self.stage1 = NSAStage(
            in_dim=embed_dims[0],
            out_dim=embed_dims[0],
            depth=depths[0],
            num_heads=num_heads[0],
            mlp_ratio=mlp_ratios[0],
            compress_block_size=compress_block_sizes[0],
            compress_stride=compress_strides[0],
            select_block_size=select_block_sizes[0],
            num_select=num_selects[0],
            window_size=window_sizes[0],
            downsample=False,
        )
        # Stage 2: downsample 2x
        self.stage2 = NSAStage(
            in_dim=embed_dims[0],
            out_dim=embed_dims[1],
            depth=depths[1],
            num_heads=num_heads[1],
            mlp_ratio=mlp_ratios[1],
            compress_block_size=compress_block_sizes[1],
            compress_stride=compress_strides[1],
            select_block_size=select_block_sizes[1],
            num_select=num_selects[1],
            window_size=window_sizes[1],
            downsample=True,
        )
        # Stage 3: downsample 2x
        self.stage3 = NSAStage(
            in_dim=embed_dims[1],
            out_dim=embed_dims[2],
            depth=depths[2],
            num_heads=num_heads[2],
            mlp_ratio=mlp_ratios[2],
            compress_block_size=compress_block_sizes[2],
            compress_stride=compress_strides[2],
            select_block_size=select_block_sizes[2],
            num_select=num_selects[2],
            window_size=window_sizes[2],
            downsample=True,
        )

    def forward(self, x: torch.Tensor) -> tuple:
        """
        Args:
            x: Input image (B, C, H, W)
        Returns:
            Multi-scale features (f1, f2, f3) at 1/4, 1/8, and 1/16 resolution
        """
        x = self.patch_embed(x)
        f1 = self.stage1(x)  # 1/4 resolution
        f2 = self.stage2(f1)  # 1/8 resolution
        f3 = self.stage3(f2)  # 1/16 resolution
        return f1, f2, f3
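

# Illustrative shape check (a sketch, not part of the model): with the
# default dims (32, 64, 96), a 64x64 input yields features at 1/4, 1/8,
# and 1/16 resolution.
def _demo_nsa_encoder() -> None:
    enc = NSAEncoder(in_channels=1)
    f1, f2, f3 = enc(torch.randn(2, 1, 64, 64))
    assert f1.shape == (2, 32, 16, 16)
    assert f2.shape == (2, 64, 8, 8)
    assert f3.shape == (2, 96, 4, 4)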


# =============================================================================
# Segmentation Decoder
# =============================================================================
class SegmentationDecoder(nn.Module):
    """
    FPN-style decoder with skip connections for precise segmentation.
    Progressively upsamples features to the input resolution.
    """

    def __init__(
        self,
        encoder_dims: tuple = (32, 64, 96),
        decoder_dim: int = 32,
        num_classes: int = 2,
    ):
        super().__init__()
        # Lateral connections
        self.lateral3 = nn.Conv2d(encoder_dims[2], decoder_dim, kernel_size=1)
        self.lateral2 = nn.Conv2d(encoder_dims[1], decoder_dim, kernel_size=1)
        self.lateral1 = nn.Conv2d(encoder_dims[0], decoder_dim, kernel_size=1)

        # Smoothing convolutions (depthwise)
        self.smooth3 = nn.Sequential(
            nn.Conv2d(decoder_dim, decoder_dim, kernel_size=3, padding=1, groups=decoder_dim),
            nn.BatchNorm2d(decoder_dim),
            nn.GELU(),
        )
        self.smooth2 = nn.Sequential(
            nn.Conv2d(decoder_dim, decoder_dim, kernel_size=3, padding=1, groups=decoder_dim),
            nn.BatchNorm2d(decoder_dim),
            nn.GELU(),
        )
        self.smooth1 = nn.Sequential(
            nn.Conv2d(decoder_dim, decoder_dim, kernel_size=3, padding=1, groups=decoder_dim),
            nn.BatchNorm2d(decoder_dim),
            nn.GELU(),
        )

        # Segmentation head
        self.head = nn.Conv2d(decoder_dim, num_classes, kernel_size=1)

    def forward(
        self,
        f1: torch.Tensor,
        f2: torch.Tensor,
        f3: torch.Tensor,
        target_size: tuple,
    ) -> torch.Tensor:
        """
        Args:
            f1, f2, f3: Multi-scale encoder features
            target_size: (H, W) of output
        Returns:
            Segmentation logits (B, num_classes, H, W)
        """
        # Top-down path with lateral connections
        p3 = self.lateral3(f3)
        p3 = self.smooth3(p3)
        p2 = self.lateral2(f2) + F.interpolate(
            p3, size=f2.shape[2:], mode="bilinear", align_corners=False
        )
        p2 = self.smooth2(p2)
        p1 = self.lateral1(f1) + F.interpolate(
            p2, size=f1.shape[2:], mode="bilinear", align_corners=False
        )
        p1 = self.smooth1(p1)

        # Segmentation output at full input resolution
        out = self.head(p1)
        out = F.interpolate(
            out, size=target_size, mode="bilinear", align_corners=False
        )
        return out
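

# Illustrative shape check (a sketch, not part of the model): the decoder
# fuses the three encoder scales and upsamples to an arbitrary target size.
def _demo_segmentation_decoder() -> None:
    dec = SegmentationDecoder(encoder_dims=(32, 64, 96), decoder_dim=32, num_classes=2)
    f1 = torch.randn(2, 32, 16, 16)
    f2 = torch.randn(2, 64, 8, 8)
    f3 = torch.randn(2, 96, 4, 4)
    out = dec(f1, f2, f3, target_size=(64, 64))
    assert out.shape == (2, 2, 64, 64)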


# =============================================================================
# Complete NSA Pupil Segmentation Model
# =============================================================================
class NSAPupilSeg(nn.Module):
    """
    Native Sparse Attention model for pupil segmentation.

    Architecture:
    - NSA encoder: hierarchical feature extraction with sparse attention
    - FPN decoder: multi-scale feature fusion for precise segmentation

    Key NSA components for pupil segmentation:
    - Compression: captures global eye context (is this an eye? rough pupil location)
    - Selection: focuses on the pupil region with fine-grained attention
    - Sliding window: precise local boundaries for pixel-accurate segmentation
    """

    def __init__(
        self,
        in_channels: int = 1,
        num_classes: int = 2,
        embed_dims: tuple = (32, 64, 96),
        depths: tuple = (1, 1, 1),
        num_heads: tuple = (2, 2, 4),
        mlp_ratios: tuple = (2, 2, 2),
        compress_block_sizes: tuple = (4, 4, 4),
        compress_strides: tuple = (2, 2, 2),
        select_block_sizes: tuple = (4, 4, 4),
        num_selects: tuple = (4, 4, 4),
        window_sizes: tuple = (7, 7, 7),
        decoder_dim: int = 32,
    ):
        super().__init__()
        self.encoder = NSAEncoder(
            in_channels=in_channels,
            embed_dims=embed_dims,
            depths=depths,
            num_heads=num_heads,
            mlp_ratios=mlp_ratios,
            compress_block_sizes=compress_block_sizes,
            compress_strides=compress_strides,
            select_block_sizes=select_block_sizes,
            num_selects=num_selects,
            window_sizes=window_sizes,
        )
        self.decoder = SegmentationDecoder(
            encoder_dims=embed_dims,
            decoder_dim=decoder_dim,
            num_classes=num_classes,
        )
        self._initialize_weights()

    def _initialize_weights(self):
        """Initialize model weights."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode="fan_out", nonlinearity="relu"
                )
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.LayerNorm):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input image (B, C, H, W)
        Returns:
            Segmentation logits (B, num_classes, H, W)
        """
        target_size = (x.shape[2], x.shape[3])
        f1, f2, f3 = self.encoder(x)
        out = self.decoder(f1, f2, f3, target_size)
        return out
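

# Illustrative end-to-end sketch (not part of the model): logits come back
# at the input resolution. eval() is used so BatchNorm does not need batch
# statistics for a batch of one.
def _demo_nsa_pupil_seg() -> None:
    model = NSAPupilSeg(num_classes=2)
    model.eval()
    with torch.no_grad():
        out = model(torch.randn(1, 1, 64, 64))
    assert out.shape == (1, 2, 64, 64)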


# =============================================================================
# Loss Functions (same as src/ for compatibility)
# =============================================================================
def focal_surface_loss(
    probs: torch.Tensor,
    dist_map: torch.Tensor,
    gamma: float = 2.0,
) -> torch.Tensor:
    """Surface loss with focal weighting for hard boundary pixels.

    Args:
        probs: Predicted probabilities (B, C, H, W)
        dist_map: Distance transform (B, 2, H, W)
        gamma: Focal weighting exponent
    Returns:
        Focal-weighted surface loss scalar
    """
    focal_weight = (1 - probs) ** gamma
    return (
        (focal_weight * probs * dist_map)
        .flatten(start_dim=2)
        .mean(dim=2)
        .mean(dim=1)
        .mean()
    )


def boundary_dice_loss(
    probs: torch.Tensor,
    target: torch.Tensor,
    kernel_size: int = 3,
    epsilon: float = 1e-5,
) -> torch.Tensor:
    """Dice loss computed only on boundary pixels.

    Args:
        probs: Predicted probabilities (B, C, H, W)
        target: Ground truth labels (B, H, W)
        kernel_size: Size of kernel for boundary extraction
        epsilon: Small constant for numerical stability
    Returns:
        Boundary dice loss scalar
    """
    # Extract the boundary via a morphological gradient (dilation - erosion)
    target_float = target.float().unsqueeze(1)
    padding = kernel_size // 2
    dilated = F.max_pool2d(target_float, kernel_size, stride=1, padding=padding)
    eroded = -F.max_pool2d(-target_float, kernel_size, stride=1, padding=padding)
    boundary = (dilated - eroded).squeeze(1)  # (B, H, W)

    # Compute Dice only on boundary pixels
    probs_pupil = probs[:, 1]  # pupil class probabilities (B, H, W)
    probs_boundary = probs_pupil * boundary
    target_boundary = target.float() * boundary
    intersection = (probs_boundary * target_boundary).sum(dim=(1, 2))
    union = probs_boundary.sum(dim=(1, 2)) + target_boundary.sum(dim=(1, 2))
    dice = (2.0 * intersection + epsilon) / (union + epsilon)
    return (1.0 - dice).mean()
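

# Illustrative sketch (not part of the training code): the morphological
# gradient restricts the Dice computation to a thin ring around the mask,
# so the loss only reflects boundary agreement.
def _demo_boundary_dice_loss() -> None:
    probs = torch.rand(2, 2, 32, 32).softmax(dim=1)
    target = torch.zeros(2, 32, 32, dtype=torch.long)
    target[:, 12:20, 12:20] = 1  # a square "pupil"
    loss = boundary_dice_loss(probs, target)
    assert loss.ndim == 0 and 0.0 <= loss.item() <= 1.0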


class CombinedLoss(nn.Module):
    """
    Combined loss for pupil segmentation:
    - Weighted cross entropy: handles class imbalance
    - Dice loss: better for small regions like pupils
    - Focal surface loss: boundary-aware optimization with focal weighting
    - Boundary dice loss: explicit optimization for edge pixels
    """

    def __init__(
        self,
        epsilon: float = 1e-5,
        focal_gamma: float = 2.0,
        boundary_weight: float = 0.3,
        boundary_kernel_size: int = 3,
    ):
        super().__init__()
        self.epsilon = epsilon
        self.focal_gamma = focal_gamma
        self.boundary_weight = boundary_weight
        self.boundary_kernel_size = boundary_kernel_size
        self.nll = nn.NLLLoss(reduction="none")

    def forward(
        self,
        logits: torch.Tensor,
        target: torch.Tensor,
        spatial_weights: torch.Tensor,
        dist_map: torch.Tensor,
        alpha: float,
        eye_weight: torch.Tensor = None,
    ) -> tuple:
        """
        Args:
            logits: Model output (B, C, H, W)
            target: Ground truth (B, H, W)
            spatial_weights: Spatial weighting map (B, H, W)
            dist_map: Distance map for surface loss (B, 2, H, W)
            alpha: Balance between dice and surface loss
            eye_weight: Optional soft distance weighting from eye region (B, H, W)
        Returns:
            (total_loss, ce_loss, dice_loss, surface_loss, boundary_loss)
        """
        probs = F.softmax(logits, dim=1)
        log_probs = F.log_softmax(logits, dim=1)

        # Weighted cross entropy with spatial weights and optional eye weight
        ce_loss = self.nll(log_probs, target)
        weight_factor = 1.0 + spatial_weights
        if eye_weight is not None:
            weight_factor = weight_factor * eye_weight
        weighted_ce = (ce_loss * weight_factor).mean()

        # Generalized Dice loss (classes weighted by inverse squared volume)
        target_onehot = F.one_hot(target, num_classes=2).permute(0, 3, 1, 2).float()
        probs_flat = probs.flatten(start_dim=2)
        target_flat = target_onehot.flatten(start_dim=2)
        intersection = (probs_flat * target_flat).sum(dim=2)
        cardinality = (probs_flat + target_flat).sum(dim=2)
        class_weights = 1.0 / (target_flat.sum(dim=2) ** 2).clamp(min=self.epsilon)
        dice = (
            2.0
            * (class_weights * intersection).sum(dim=1)
            / (class_weights * cardinality).sum(dim=1)
        )
        dice_loss = (1.0 - dice.clamp(min=self.epsilon)).mean()

        # Focal surface loss (replaces the standard surface loss)
        surface_loss = focal_surface_loss(probs, dist_map, gamma=self.focal_gamma)

        # Boundary dice loss
        bdice_loss = boundary_dice_loss(
            probs,
            target,
            kernel_size=self.boundary_kernel_size,
            epsilon=self.epsilon,
        )

        # Total loss; the surface-loss weight is floored at 0.2 so boundary
        # supervision never vanishes as alpha approaches 1
        surface_weight = max(1.0 - alpha, 0.2)
        total_loss = (
            weighted_ce
            + alpha * dice_loss
            + surface_weight * surface_loss
            + self.boundary_weight * bdice_loss
        )
        return total_loss, weighted_ce, dice_loss, surface_loss, bdice_loss
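

# Illustrative sketch (not part of the training code): the dist_map and
# spatial_weights tensors here are random stand-ins for the precomputed
# distance transform and weighting maps the real pipeline supplies.
def _demo_combined_loss() -> None:
    criterion = CombinedLoss()
    logits = torch.randn(2, 2, 32, 32)
    target = torch.randint(0, 2, (2, 32, 32))
    spatial_weights = torch.rand(2, 32, 32)
    dist_map = torch.rand(2, 2, 32, 32)
    total, ce, dice, surface, bdice = criterion(
        logits, target, spatial_weights, dist_map, alpha=0.5
    )
    assert total.ndim == 0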


# =============================================================================
# Factory function for easy model creation
# =============================================================================
def create_nsa_pupil_seg(
    size: str = "small",
    in_channels: int = 1,
    num_classes: int = 2,
) -> NSAPupilSeg:
    """
    Create an NSA pupil segmentation model with a predefined configuration.

    Args:
        size: Model size ('pico', 'nano', 'tiny', 'small', 'medium')
        in_channels: Number of input channels
        num_classes: Number of output classes
    Returns:
        Configured NSAPupilSeg model
    """
    configs = {
        "pico": {
            "embed_dims": (4, 4, 4),
            "depths": (1, 1, 1),
            "num_heads": (1, 1, 1),
            "mlp_ratios": (1.0, 1.0, 1.0),
            "compress_block_sizes": (4, 4, 4),
            "compress_strides": (4, 4, 4),
            "select_block_sizes": (4, 4, 4),
            "num_selects": (1, 1, 1),
            "window_sizes": (3, 3, 3),
            "decoder_dim": 4,
        },
        "nano": {
            "embed_dims": (4, 8, 12),
            "depths": (1, 1, 1),
            "num_heads": (1, 1, 1),
            "mlp_ratios": (1.0, 1.0, 1.0),
            "compress_block_sizes": (4, 4, 4),
            "compress_strides": (4, 4, 4),
            "select_block_sizes": (4, 4, 4),
            "num_selects": (1, 1, 1),
            "window_sizes": (3, 3, 3),
            "decoder_dim": 4,
        },
        "tiny": {
            "embed_dims": (8, 12, 16),
            "depths": (1, 1, 1),
            "num_heads": (1, 1, 1),
            "mlp_ratios": (1.5, 1.5, 1.5),
            "compress_block_sizes": (4, 4, 4),
            "compress_strides": (4, 4, 4),
            "select_block_sizes": (4, 4, 4),
            "num_selects": (1, 1, 1),
            "window_sizes": (3, 3, 3),
            "decoder_dim": 8,
        },
        "small": {
            "embed_dims": (12, 24, 32),
            "depths": (1, 1, 1),
            "num_heads": (1, 1, 2),
            "mlp_ratios": (1.5, 1.5, 1.5),
            "compress_block_sizes": (4, 4, 4),
            "compress_strides": (4, 4, 4),
            "select_block_sizes": (4, 4, 4),
            "num_selects": (1, 1, 1),
            "window_sizes": (3, 3, 3),
            "decoder_dim": 12,
        },
        "medium": {
            "embed_dims": (16, 32, 48),
            "depths": (1, 1, 1),
            "num_heads": (1, 2, 2),
            "mlp_ratios": (1.5, 1.5, 1.5),
            "compress_block_sizes": (4, 4, 4),
            "compress_strides": (3, 3, 3),
            "select_block_sizes": (4, 4, 4),
            "num_selects": (2, 2, 2),
            "window_sizes": (3, 3, 3),
            "decoder_dim": 16,
        },
    }
    if size not in configs:
        raise ValueError(
            f"Unknown size: {size}. Choose from {list(configs.keys())}"
        )
    return NSAPupilSeg(
        in_channels=in_channels,
        num_classes=num_classes,
        **configs[size],
    )


# =============================================================================
# Testing / Verification
# =============================================================================
if __name__ == "__main__":
    # Test model creation and forward pass
    print("Testing NSA Pupil Segmentation Model")
    print("=" * 60)

    # Create models of different sizes
    for size in ["pico", "nano", "tiny", "small", "medium"]:
        model = create_nsa_pupil_seg(size=size)

        # Count parameters
        n_params = sum(p.numel() for p in model.parameters())

        # Test forward pass
        x = torch.randn(2, 1, 400, 640)  # OpenEDS image size
        model.eval()
        with torch.no_grad():
            out = model(x)

        print(f"\n{size.upper()} Model:")
        print(f"  Parameters: {n_params:,}")
        print(f"  Input shape: {x.shape}")
        print(f"  Output shape: {out.shape}")

    print("\n" + "=" * 60)
    print("All tests passed!")