connerohnesorge committed on
Commit
a69fe43
·
1 Parent(s): 1777497
__pycache__/app.cpython-313.pyc ADDED
Binary file (11 kB). View file
 
app.py CHANGED
@@ -1,8 +1,393 @@
 
 
 
1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
- def greet(name):
5
- return "Hello " + name + "!!"
 
 
6
 
7
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
8
- demo.launch()
 
1
#!/usr/bin/env python3
"""
NSA Pupil Segmentation Gradio Demo - Native Sparse Attention Web Application

This Gradio application performs real-time pupil segmentation on webcam input
using the NSAPupilSeg model (Native Sparse Attention). It demonstrates eye tracking
and pupil detection capabilities for the VisionAssist medical assistive technology project.

NSA Key Features:
- Token Compression: Global coarse-grained context
- Token Selection: Fine-grained focus on important regions (pupil)
- Sliding Window: Local context for precise boundaries
- Gated Aggregation: Learned combination of attention paths
"""

import cv2
import numpy as np
import torch
import gradio as gr
import mediapipe as mp

from nsa import create_nsa_pupil_seg

# =============================================================================
# Model Loading (at module startup)
# =============================================================================

print("Loading NSA Pupil Segmentation model...")

# Smallest ("pico") NSA variant: single grayscale input channel, 2 output
# classes (background vs. pupil).
model = create_nsa_pupil_seg(size="pico", in_channels=1, num_classes=2)
# NOTE(review): weights_only=False un-pickles arbitrary objects from the
# checkpoint, which can execute code — acceptable only because
# best_model.pth ships with this app; never point this at untrusted files.
checkpoint = torch.load("best_model.pth", map_location="cpu", weights_only=False)
if "model_state_dict" in checkpoint:
    # Full training checkpoint: extract just the weights (and report IoU).
    model.load_state_dict(checkpoint["model_state_dict"])
    print(f"Loaded checkpoint with IoU: {checkpoint.get('valid_iou', 'N/A')}")
else:
    # Bare state-dict checkpoint.
    model.load_state_dict(checkpoint)
model.eval()  # inference mode: freezes dropout / batch-norm statistics

print("Model loaded successfully!")
40
+
41
# =============================================================================
# MediaPipe Face Mesh Setup
# =============================================================================

mp_face_mesh = mp.solutions.face_mesh
# Single-face tracker; refine_landmarks=True enables the refined eye/iris
# landmark set used below.
face_mesh = mp_face_mesh.FaceMesh(
    max_num_faces=1,
    refine_landmarks=True,
    min_detection_confidence=0.5,
    min_tracking_confidence=0.5,
)

# =============================================================================
# Constants (from demo.py - MUST match training exactly)
# =============================================================================

# MediaPipe left eye landmark indices (12 points around the eye)
LEFT_EYE_INDICES = [362, 385, 387, 263, 373, 380, 374, 381, 382, 384, 398, 466]

# Target aspect ratio for eye region (width:height = 640:400 = 1.6:1)
TARGET_ASPECT_RATIO = 640 / 400  # 1.6:1

# Model input/output dimensions
MODEL_WIDTH = 640
MODEL_HEIGHT = 400

# Preprocessing parameters (MUST match training exactly)
NORMALIZE_MEAN = 0.5
NORMALIZE_STD = 0.5

# Eye extraction settings
BBOX_PADDING = 0.2  # 20% padding on each side
MIN_EYE_REGION_SIZE = 50  # Minimum bounding box size

# Visualization settings
OVERLAY_ALPHA = 0.5  # blend weight of the green pupil overlay
77
+
78
+
79
+ # =============================================================================
80
+ # Eye Region Extraction Function
81
+ # =============================================================================
82
+
83
def extract_eye_region(frame, landmarks):
    """
    Extract the left eye region from a frame using MediaPipe landmarks.

    The tight landmark bounding box is padded by BBOX_PADDING on each side,
    then the shorter dimension is expanded toward TARGET_ASPECT_RATIO
    (1.6:1), with all coordinates clamped to the frame borders.

    Args:
        frame: Input BGR frame of shape (H, W, 3).
        landmarks: MediaPipe face landmarks (normalized [0, 1] coordinates).

    Returns:
        tuple: (eye_crop, bbox) where bbox is (x, y, w, h) in pixels, or
        (None, None) if the eye region is too small or the crop is empty.
    """
    h, w = frame.shape[:2]

    # Landmark coordinates are normalized; scale to pixel space.
    eye_points = np.array([
        [int(landmarks.landmark[idx].x * w), int(landmarks.landmark[idx].y * h)]
        for idx in LEFT_EYE_INDICES
    ], dtype=np.int32)

    # Tight bounding box around the eye landmarks.
    x_min, y_min = eye_points.min(axis=0)
    x_max, y_max = eye_points.max(axis=0)

    bbox_w = x_max - x_min
    bbox_h = y_max - y_min

    # Reject regions too small for reliable segmentation.
    if bbox_w < MIN_EYE_REGION_SIZE or bbox_h < MIN_EYE_REGION_SIZE:
        return None, None

    # Add padding (20% on each side), clamped to the frame.
    pad_w = int(bbox_w * BBOX_PADDING)
    pad_h = int(bbox_h * BBOX_PADDING)

    x_min = max(0, x_min - pad_w)
    y_min = max(0, y_min - pad_h)
    x_max = min(w, x_max + pad_w)
    y_max = min(h, y_max + pad_h)

    bbox_w = x_max - x_min
    bbox_h = y_max - y_min

    # Expand to 1.6:1 aspect ratio (640:400).
    # BUGFIX: an odd `diff` is now split floor/ceil instead of floor/floor,
    # so the full difference is distributed and the target ratio is reached.
    current_ratio = bbox_w / bbox_h
    if current_ratio < TARGET_ASPECT_RATIO:
        # Too narrow, expand width
        target_w = int(bbox_h * TARGET_ASPECT_RATIO)
        diff = target_w - bbox_w
        x_min = max(0, x_min - diff // 2)
        x_max = min(w, x_max + (diff - diff // 2))
        bbox_w = x_max - x_min
    else:
        # Too short, expand height
        target_h = int(bbox_w / TARGET_ASPECT_RATIO)
        diff = target_h - bbox_h
        y_min = max(0, y_min - diff // 2)
        y_max = min(h, y_max + (diff - diff // 2))
        bbox_h = y_max - y_min

    # Extract region
    eye_crop = frame[y_min:y_max, x_min:x_max]

    # Clamping at the frame edge can still yield a degenerate crop.
    if eye_crop.size == 0:
        return None, None

    return eye_crop, (x_min, y_min, bbox_w, bbox_h)
150
+
151
+
152
+ # =============================================================================
153
+ # Preprocessing Function (CRITICAL - must match training exactly)
154
+ # =============================================================================
155
+
156
def preprocess(eye_crop):
    """
    Preprocess an eye region for model inference.

    CRITICAL: Must match the training pipeline exactly — resize, then
    grayscale, then normalize to [-1, 1], then transpose into the model's
    (B, C, W, H) layout.

    Args:
        eye_crop: BGR image of the eye region.

    Returns:
        torch.Tensor: Preprocessed tensor of shape (1, 1, 640, 400).
    """
    # Resize first (training order matters), then drop color.
    scaled = cv2.resize(
        eye_crop, (MODEL_WIDTH, MODEL_HEIGHT), interpolation=cv2.INTER_LINEAR
    )
    gray = cv2.cvtColor(scaled, cv2.COLOR_BGR2GRAY)

    # Map [0, 255] -> [-1, 1] using the training mean/std of 0.5.
    unit = gray.astype(np.float32) / 255.0
    normalized = (unit - NORMALIZE_MEAN) / NORMALIZE_STD

    # The model consumes (B, C, W, H), NOT the usual (B, C, H, W):
    # transpose (400, 640) -> (640, 400) and prepend batch/channel axes.
    transposed = np.transpose(normalized)
    batched = transposed[None, None, :, :]
    return torch.from_numpy(batched)
185
+
186
+
187
+ # =============================================================================
188
+ # Inference Function
189
+ # =============================================================================
190
+
191
def run_inference(input_tensor):
    """
    Run model inference on a preprocessed input tensor.

    Args:
        input_tensor: Preprocessed tensor of shape (1, 1, 640, 400).

    Returns:
        np.ndarray: Binary segmentation mask of shape (400, 640), uint8.
    """
    # No gradients needed for inference.
    with torch.no_grad():
        logits = model(input_tensor)

    # Model emits (B, C, W, H) = (1, 2, 640, 400): pick the highest-scoring
    # class per pixel, then transpose (W, H) -> (H, W) for visualization.
    class_map = np.argmax(logits.cpu().numpy()[0], axis=0)
    return class_map.T.astype(np.uint8)
213
+
214
+
215
+ # =============================================================================
216
+ # Visualization Function
217
+ # =============================================================================
218
+
219
def visualize(frame, eye_crop, mask, bbox, face_detected):
    """
    Visualize segmentation results on frame.

    Draws a semi-transparent status banner, overlays the pupil mask in green
    on the eye region, outlines the eye bounding box, and stamps the model
    name in the bottom-left corner.

    Args:
        frame: Original BGR frame
        eye_crop: Eye region crop (currently unused; kept for interface symmetry)
        mask: Binary segmentation mask (400, 640), or None
        bbox: Bounding box (x, y, w, h), or None
        face_detected: Whether face was detected

    Returns:
        np.ndarray: Annotated BGR frame (a copy; the input is not modified)
    """
    annotated = frame.copy()

    # Draw status banner at top center
    banner_height = 50
    banner_w = annotated.shape[1]

    # Semi-transparent black background for banner (darken pixels by 50%)
    banner_region = annotated[0:banner_height, 0:banner_w].astype(np.float32)
    banner_region *= 0.5
    annotated[0:banner_height, 0:banner_w] = banner_region.astype(np.uint8)

    # Status text: yellow for "needs user action" states, green when tracking
    if not face_detected:
        status_text = "No Face Detected"
        status_color = (0, 255, 255)  # Yellow (BGR)
    elif mask is None:
        status_text = "Move Closer"
        status_color = (0, 255, 255)  # Yellow
    else:
        status_text = "Face Detected"
        status_color = (0, 255, 0)  # Green

    # Center the status text horizontally and vertically in the banner
    text_size = cv2.getTextSize(status_text, cv2.FONT_HERSHEY_SIMPLEX, 1.0, 2)[0]
    text_x = (banner_w - text_size[0]) // 2
    text_y = (banner_height + text_size[1]) // 2
    cv2.putText(
        annotated,
        status_text,
        (text_x, text_y),
        cv2.FONT_HERSHEY_SIMPLEX,
        1.0,
        status_color,
        2,
    )

    # If we have a valid mask, overlay it on the eye region
    if mask is not None and bbox is not None:
        x, y, w, h = bbox

        # Resize mask to match eye crop size (nearest-neighbor keeps it binary)
        mask_resized = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)

        # Create green overlay where mask==1 (pupil detected)
        green_overlay = np.zeros((h, w, 3), dtype=np.uint8)
        green_overlay[mask_resized == 1] = (0, 255, 0)  # Green in BGR

        # Blend with original eye region
        eye_region = annotated[y:y + h, x:x + w]
        blended = cv2.addWeighted(
            eye_region,
            1 - OVERLAY_ALPHA,
            green_overlay,
            OVERLAY_ALPHA,
            0
        )
        annotated[y:y + h, x:x + w] = blended

        # Draw bounding box
        cv2.rectangle(annotated, (x, y), (x + w, y + h), (0, 255, 0), 3)

    # Draw model info (bottom-left)
    cv2.putText(
        annotated,
        "NSA-pico",
        (10, annotated.shape[0] - 20),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.7,
        (0, 255, 0),
        2,
    )

    return annotated
305
+
306
+
307
+ # =============================================================================
308
+ # Main Process Function
309
+ # =============================================================================
310
+
311
def process_frame(image):
    """
    Process a single frame from webcam for pupil segmentation.

    Pipeline: RGB->BGR conversion, MediaPipe face detection, left-eye
    extraction, NSA model inference, overlay visualization, BGR->RGB.

    Args:
        image: Input RGB image from Gradio (numpy array), or None

    Returns:
        np.ndarray: Annotated RGB image for Gradio output, or None when no
        input frame was provided.
    """
    if image is None:
        return None

    # Gradio provides RGB, convert to BGR for OpenCV
    frame_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

    # Run MediaPipe face detection on RGB image
    results = face_mesh.process(image)  # MediaPipe expects RGB
    face_detected = results.multi_face_landmarks is not None

    # Initialize variables (so visualize() sees None when steps are skipped)
    eye_crop = None
    bbox = None
    mask = None

    # Process if face detected
    if face_detected:
        landmarks = results.multi_face_landmarks[0]

        # Extract eye region (from BGR frame)
        eye_crop, bbox = extract_eye_region(frame_bgr, landmarks)

        if eye_crop is not None:
            # Preprocess
            input_tensor = preprocess(eye_crop)

            # Run inference
            mask = run_inference(input_tensor)

    # Visualize (on BGR frame)
    annotated_bgr = visualize(frame_bgr, eye_crop, mask, bbox, face_detected)

    # Convert back to RGB for Gradio output
    annotated_rgb = cv2.cvtColor(annotated_bgr, cv2.COLOR_BGR2RGB)

    return annotated_rgb
357
+
358
+
359
# =============================================================================
# Gradio Interface
# =============================================================================

# Live streaming interface: every webcam frame is pushed through
# process_frame and the annotated result is shown next to the input.
demo = gr.Interface(
    fn=process_frame,
    inputs=gr.Image(sources=["webcam"], streaming=True, label="Webcam Input"),
    outputs=gr.Image(label="Pupil Segmentation"),
    live=True,
    title="NSA Pupil Segmentation Demo",
    description="""
    Real-time pupil segmentation using Native Sparse Attention (NSA).

    This demo uses the NSAPupilSeg model from the VisionAssist project to detect
    and segment the pupil region in real-time from your webcam feed.

    **How it works:**
    1. MediaPipe Face Mesh detects your face and eye landmarks
    2. The left eye region is extracted and preprocessed
    3. The NSA model performs semantic segmentation to identify the pupil
    4. Results are overlaid on the video feed with a green highlight

    **Tips for best results:**
    - Ensure good lighting on your face
    - Look directly at the camera
    - Keep your face within the frame
    - Move closer if the eye region is too small

    **Model:** NSA-pico (Native Sparse Attention)
    """,
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch()
best_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b9de8d19344d1567ba49dc011a0d149f557c734f26ed70beaaa033568c774b8f
3
+ size 253744
nsa/__init__.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""
NSA (Native Sparse Attention) for Pupil Segmentation.

This module implements a Native Sparse Attention mechanism adapted from
DeepSeek's NSA paper for efficient pupil segmentation in eye images.

Key components:
- Token Compression: Coarse-grained global context
- Token Selection: Fine-grained important region focus
- Sliding Window: Local context for precise boundaries
- Gated Aggregation: Learned combination of all attention paths

Adapted for 2D vision tasks (segmentation) from the original 1D NLP formulation.
"""

# Re-export the public API from the implementation module.
from .model import (
    NSAPupilSeg,
    NSABlock,
    SpatialNSA,
    TokenCompression,
    TokenSelection,
    SlidingWindowAttention,
    CombinedLoss,
    create_nsa_pupil_seg,
)

# Explicit public surface for `from nsa import *`.
__all__ = [
    "NSAPupilSeg",
    "NSABlock",
    "SpatialNSA",
    "TokenCompression",
    "TokenSelection",
    "SlidingWindowAttention",
    "CombinedLoss",
    "create_nsa_pupil_seg",
]
nsa/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (978 Bytes). View file
 
nsa/__pycache__/model.cpython-313.pyc ADDED
Binary file (46.6 kB). View file
 
nsa/model.py ADDED
@@ -0,0 +1,1921 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Native Sparse Attention (NSA) Model for Pupil Segmentation.
3
+
4
+ Implementation based on DeepSeek's NSA paper:
5
+ "Native Sparse Attention: Hardware-Aligned and Natively Trainable Sparse Attention"
6
+
7
+ Adapted for 2D vision/segmentation tasks with domain-specific optimizations for
8
+ pupil segmentation where:
9
+ - Intense pixel localization is required
10
+ - The pupil is only found on the eye (spatial locality)
11
+ - OpenEDS provides multi-class data beyond pupil
12
+
13
+ Architecture:
14
+ - Encoder with NSA blocks for hierarchical feature extraction
15
+ - Decoder with skip connections for precise segmentation
16
+ - NSA combines: Compression (global), Selection (important), Sliding Window (local)
17
+ """
18
+
19
+ import math
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+
24
+
25
+ # =============================================================================
26
+ # Core Building Blocks
27
+ # =============================================================================
28
+
29
+
30
class ConvBNReLU(nn.Module):
    """Conv2d -> BatchNorm2d -> activation, packaged as one reusable block.

    Despite the historical name, the activation is GELU (or identity when
    ``activation=False``); the conv defaults to a bias-free 3x3.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        padding: int = 1,
        groups: int = 1,
        bias: bool = False,
        activation: bool = True,
    ):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=groups,
            bias=bias,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.act = nn.GELU() if activation else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply convolution, then batch normalization, then the activation."""
        y = self.conv(x)
        y = self.bn(y)
        return self.act(y)
69
+
70
+
71
class PatchEmbedding(nn.Module):
    """
    Embed image patches into tokens for attention processing.

    Two stride-2 convolutions reduce spatial resolution 4x overall, which
    gives a smoother feature transition than one stride-4 projection.
    """

    def __init__(
        self,
        in_channels: int = 1,
        embed_dim: int = 32,
        patch_size: int = 4,
    ):
        super().__init__()
        self.patch_size = patch_size
        mid_dim = embed_dim // 2

        # Stage 1: halve resolution, lift channels to embed_dim // 2.
        self.conv1 = ConvBNReLU(
            in_channels, mid_dim, kernel_size=3, stride=2, padding=1
        )
        # Stage 2: halve again, reaching the full embedding width.
        self.conv2 = ConvBNReLU(
            mid_dim, embed_dim, kernel_size=3, stride=2, padding=1
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input image (B, C, H, W)
        Returns:
            Embedded patches (B, embed_dim, H//4, W//4)
        """
        return self.conv2(self.conv1(x))
115
+
116
+
117
+ # =============================================================================
118
+ # Token Compression Module
119
+ # =============================================================================
120
+
121
+
122
class TokenCompression(nn.Module):
    """
    Compress spatial blocks into single tokens for coarse-grained attention.

    From NSA paper Eq. 7:
        K_cmp = {φ(k_{id+1:id+l}) | 0 <= i <= ⌊(t-l)/d⌋}

    Adapted for 2D: overlapping ``block_size`` x ``block_size`` spatial
    blocks are flattened and projected to one token each by a learnable MLP.
    """

    def __init__(
        self,
        dim: int,
        block_size: int = 4,
        stride: int = 2,
    ):
        super().__init__()
        self.block_size = block_size
        self.stride = stride

        # Learnable compression MLPs, (dim * bs * bs) -> dim, one each for
        # keys and values.
        self.compress_k = nn.Sequential(
            nn.Linear(dim * block_size * block_size, dim * 2),
            nn.GELU(),
            nn.Linear(dim * 2, dim),
        )
        self.compress_v = nn.Sequential(
            nn.Linear(dim * block_size * block_size, dim * 2),
            nn.GELU(),
            nn.Linear(dim * 2, dim),
        )

        # Intra-block position encoding, shape (1, bs*bs, dim); the leading
        # singleton dim broadcasts over (B, n_blocks, ...) during addition.
        self.pos_embed = nn.Parameter(
            torch.randn(1, block_size * block_size, dim) * 0.02
        )

    def forward(
        self,
        k: torch.Tensor,
        v: torch.Tensor,
        spatial_size: tuple[int, int],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Compress keys and values into block-level representations.

        Args:
            k: Keys (B, N, dim) where N = H * W
            v: Values (B, N, dim)
            spatial_size: (H, W) tuple for non-square inputs
        Returns:
            k_cmp: Compressed keys (B, N_cmp, dim)
            v_cmp: Compressed values (B, N_cmp, dim), where
            N_cmp = ((H - bs) // stride + 1) * ((W - bs) // stride + 1)
        """
        B, N, dim = k.shape
        H, W = spatial_size
        bs = self.block_size
        stride = self.stride

        # Tokens -> 2D maps (B, dim, H, W); reshape (not view) because the
        # inputs may be non-contiguous.
        k_2d = k.reshape(B, H, W, dim).permute(0, 3, 1, 2).contiguous()
        v_2d = v.reshape(B, H, W, dim).permute(0, 3, 1, 2).contiguous()

        # Extract overlapping blocks: (B, dim*bs*bs, n_blocks).
        k_blocks = F.unfold(k_2d, kernel_size=bs, stride=stride)
        v_blocks = F.unfold(v_2d, kernel_size=bs, stride=stride)

        # Reshape for compression: (B, n_blocks, dim*bs*bs).
        n_blocks = k_blocks.shape[2]
        k_blocks = k_blocks.permute(0, 2, 1).contiguous()
        v_blocks = v_blocks.permute(0, 2, 1).contiguous()

        # Add intra-block position encoding to the keys before compression.
        # NOTE(review): values are compressed WITHOUT position encoding —
        # confirm this asymmetry is intended.
        k_blocks_pos = (
            k_blocks.reshape(B, n_blocks, bs * bs, dim) + self.pos_embed
        ).reshape(B, n_blocks, bs * bs * dim)

        # Project each block to a single token.
        k_cmp = self.compress_k(k_blocks_pos)
        v_cmp = self.compress_v(v_blocks)

        return k_cmp, v_cmp
275
+
276
+
277
+ # =============================================================================
278
+ # Token Selection Module
279
+ # =============================================================================
280
+
281
+
282
class TokenSelection(nn.Module):
    """
    Select important token blocks based on attention scores.

    From NSA paper Eq. 8-12:
    - Compute importance from compressed attention scores
    - Select top-n blocks for fine-grained attention

    For pupil segmentation: identifies the most relevant spatial regions.

    NOTE: the mapping from compressed (overlapping, strided) block indices
    back to the non-overlapping blocks gathered here is approximate —
    indices are clamped into range rather than remapped exactly.
    """

    def __init__(
        self,
        dim: int,
        block_size: int = 4,
        num_select: int = 4,
    ):
        super().__init__()
        self.block_size = block_size
        self.num_select = num_select
        self.dim = dim

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        attn_scores_cmp: torch.Tensor,
        spatial_size: tuple[int, int],
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Select important blocks based on compressed attention scores.

        Args:
            q: Queries (B, H, N, dim) — currently unused; kept for
               interface compatibility with the NSA formulation.
            k: Keys (B, N, dim)
            v: Values (B, N, dim)
            attn_scores_cmp: Attention from compression (B, H, N, N_cmp)
            spatial_size: (height, width) of feature map
        Returns:
            k_slc: Selected keys (B, num_select * bs * bs, dim)
            v_slc: Selected values (B, num_select * bs * bs, dim)
            indices: Selected block indices (B, num_select)
        """
        B, num_heads, N, N_cmp = attn_scores_cmp.shape
        H, W = spatial_size
        bs = self.block_size

        # Sum attention across heads for shared selection (GQA-style),
        # then average across queries to score each compressed block.
        importance = attn_scores_cmp.sum(dim=1)  # (B, N, N_cmp)
        block_importance = importance.mean(dim=1)  # (B, N_cmp)

        # Pick the top-n most attended blocks.
        num_select = min(self.num_select, N_cmp)
        _, indices = torch.topk(
            block_importance, num_select, dim=-1
        )  # (B, num_select)

        # Re-tile k/v into non-overlapping bs x bs blocks.
        k_2d = k.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        v_2d = v.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()

        # (B, dim*bs*bs, n_blocks) -> (B, n_blocks, bs*bs, dim)
        k_blocks = F.unfold(k_2d, kernel_size=bs, stride=bs)
        v_blocks = F.unfold(v_2d, kernel_size=bs, stride=bs)

        n_blocks = k_blocks.shape[2]
        k_blocks = (
            k_blocks.permute(0, 2, 1).contiguous().reshape(B, n_blocks, bs * bs, -1)
        )
        v_blocks = (
            v_blocks.permute(0, 2, 1).contiguous().reshape(B, n_blocks, bs * bs, -1)
        )

        # Compressed blocks can outnumber non-overlapping blocks (overlap
        # stride < bs), so clamp into the valid range instead of remapping.
        indices = indices.clamp(0, n_blocks - 1)

        # Gather the selected blocks: (B, num_select, bs*bs, dim).
        indices_expanded = (
            indices.unsqueeze(-1)
            .unsqueeze(-1)
            .expand(-1, -1, bs * bs, k.shape[-1])
        )
        k_slc = torch.gather(k_blocks, 1, indices_expanded)
        v_slc = torch.gather(v_blocks, 1, indices_expanded)

        # Flatten the block tokens into one sequence per batch element.
        k_slc = k_slc.view(B, num_select * bs * bs, -1)
        v_slc = v_slc.view(B, num_select * bs * bs, -1)

        return k_slc, v_slc, indices
440
+
441
+
442
+ # =============================================================================
443
+ # Sliding Window Attention
444
+ # =============================================================================
445
+
446
+
447
class SlidingWindowAttention(nn.Module):
    """
    Local sliding window attention for fine-grained local context.

    From NSA paper Section 3.3.3:
    Maintains recent tokens in a window for local pattern recognition.

    For pupil segmentation: critical for precise boundary delineation.

    Implementation note: this is non-overlapping window attention in the
    style of Swin Transformer — the feature map is partitioned into
    ``window_size x window_size`` tiles and full self-attention is computed
    independently inside each tile, with a learned relative position bias.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 2,
        window_size: int = 7,
        qkv_bias: bool = True,
    ):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5  # 1/sqrt(d_head) attention scaling

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

        # Relative position bias: one learned scalar per head for each of the
        # (2*ws-1)^2 possible (dy, dx) offsets between two tokens in a window.
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros(
                (2 * window_size - 1) * (2 * window_size - 1),
                num_heads,
            )
        )
        nn.init.trunc_normal_(
            self.relative_position_bias_table,
            std=0.02,
        )

        # Create position index: maps each (query, key) pair inside a window
        # to a row of the bias table (standard Swin-style construction).
        coords_h = torch.arange(window_size)
        coords_w = torch.arange(window_size)
        coords = torch.stack(
            torch.meshgrid(coords_h, coords_w, indexing="ij")
        )
        coords_flatten = coords.flatten(1)  # (2, ws*ws)
        # Pairwise (dy, dx) offsets between all positions in the window.
        relative_coords = (
            coords_flatten[:, :, None] - coords_flatten[:, None, :]
        )
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        # Shift offsets to be non-negative, then fold (dy, dx) into a single
        # flat index into the bias table.
        relative_coords[:, :, 0] += window_size - 1
        relative_coords[:, :, 1] += window_size - 1
        relative_coords[:, :, 0] *= 2 * window_size - 1
        relative_position_index = relative_coords.sum(-1)
        # Buffer (not a parameter): moves with the module across devices but
        # receives no gradients.
        self.register_buffer(
            "relative_position_index",
            relative_position_index,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply sliding window attention.

        Args:
            x: Input features (B, C, H, W)
        Returns:
            Output features (B, C, H, W)
        """
        B, C, H, W = x.shape
        ws = self.window_size

        # Pad bottom/right so H and W become multiples of the window size.
        pad_h = (ws - H % ws) % ws
        pad_w = (ws - W % ws) % ws
        if pad_h > 0 or pad_w > 0:
            x = F.pad(x, (0, pad_w, 0, pad_h))

        _, _, Hp, Wp = x.shape

        # Reshape to windows: (B*num_windows, ws*ws, C)
        x = x.view(B, C, Hp // ws, ws, Wp // ws, ws)
        x = x.permute(0, 2, 4, 3, 5, 1).contiguous()
        x = x.view(-1, ws * ws, C)

        # Compute QKV
        B_win = x.shape[0]
        qkv = self.qkv(x).reshape(
            B_win, ws * ws, 3, self.num_heads, self.head_dim
        )
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B_win, heads, ws*ws, d_head)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Attention
        attn = (q @ k.transpose(-2, -1)) * self.scale

        # Add relative position bias (broadcast over all windows in the batch)
        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)
        ].view(ws * ws, ws * ws, -1)
        relative_position_bias = relative_position_bias.permute(
            2, 0, 1
        ).contiguous()
        attn = attn + relative_position_bias.unsqueeze(0)

        attn = attn.softmax(dim=-1)
        x = (
            (attn @ v)
            .transpose(1, 2)
            .reshape(B_win, ws * ws, C)
        )
        x = self.proj(x)

        # Reshape back to the padded spatial layout (inverse of partitioning).
        num_windows_h = Hp // ws
        num_windows_w = Wp // ws
        x = x.view(B, num_windows_h, num_windows_w, ws, ws, C)
        x = x.permute(0, 5, 1, 3, 2, 4).contiguous()
        x = x.view(B, C, Hp, Wp)

        # Remove padding
        if pad_h > 0 or pad_w > 0:
            x = x[:, :, :H, :W]

        return x
634
+
635
+
636
+ # =============================================================================
637
+ # Native Sparse Attention (NSA) - Core Module
638
+ # =============================================================================
639
+
640
+
641
class SpatialNSA(nn.Module):
    """
    Native Sparse Attention adapted for 2D spatial features.

    Combines three attention paths (NSA paper Eq. 5):
        o* = Σ g_c · Attn(q, K̃_c, Ṽ_c)  for c ∈ {cmp, slc, win}

    Components:
        1. Compressed Attention: Global coarse-grained context
        2. Selected Attention: Fine-grained important regions
        3. Sliding Window: Local context for precise boundaries
        4. Gated Aggregation: Learned combination
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 2,
        compress_block_size: int = 4,
        compress_stride: int = 2,
        select_block_size: int = 4,
        num_select: int = 4,
        window_size: int = 7,
        qkv_bias: bool = True,
    ):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5

        # Separate QKV for each branch (prevents shortcut learning)
        self.qkv_cmp = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.qkv_slc = nn.Linear(dim, dim * 3, bias=qkv_bias)

        # Token compression module (defined earlier in this file)
        self.compression = TokenCompression(
            dim=dim,
            block_size=compress_block_size,
            stride=compress_stride,
        )

        # Token selection module (defined earlier in this file)
        self.selection = TokenSelection(
            dim=dim,
            block_size=select_block_size,
            num_select=num_select,
        )

        # Sliding window attention (branch 3 runs directly on the 2D map)
        self.window_attn = SlidingWindowAttention(
            dim=dim,
            num_heads=num_heads,
            window_size=window_size,
            qkv_bias=qkv_bias,
        )

        # Output projections for the two sequence-form branches
        self.proj_cmp = nn.Linear(dim, dim)
        self.proj_slc = nn.Linear(dim, dim)

        # Gating mechanism (NSA paper Eq. 5): per-token sigmoid gate for each
        # branch. NOTE: gates are independent sigmoids, not softmax-normalized,
        # so branch contributions do not sum to 1 by construction.
        self.gate = nn.Sequential(
            nn.Linear(dim, dim // 4),
            nn.GELU(),
            nn.Linear(dim // 4, 3),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply Native Sparse Attention.

        Args:
            x: Input features (B, C, H, W)
        Returns:
            Output features (B, C, H, W)
        """
        B, C, H, W = x.shape
        N = H * W

        # Reshape to sequence
        x_seq = x.flatten(2).transpose(1, 2)  # (B, N, C)

        # =================================================================
        # Branch 1: Compressed Attention (Global Coarse-Grained)
        # =================================================================
        qkv_cmp = self.qkv_cmp(x_seq)
        qkv_cmp = qkv_cmp.reshape(
            B, N, 3, self.num_heads, self.head_dim
        )
        qkv_cmp = qkv_cmp.permute(2, 0, 3, 1, 4)
        q_cmp, k_cmp_raw, v_cmp_raw = (
            qkv_cmp[0],
            qkv_cmp[1],
            qkv_cmp[2],
        )

        # Reshape k, v back to (B, N, C) so the compression module can pool
        # them over spatial blocks.
        k_for_cmp = k_cmp_raw.transpose(1, 2).reshape(B, N, C)
        v_for_cmp = v_cmp_raw.transpose(1, 2).reshape(B, N, C)

        # Compress tokens
        k_cmp, v_cmp = self.compression(
            k_for_cmp, v_for_cmp, (H, W)
        )
        N_cmp = k_cmp.shape[1]

        # Reshape for multi-head attention
        k_cmp = k_cmp.view(
            B, N_cmp, self.num_heads, self.head_dim
        ).transpose(1, 2)
        v_cmp = v_cmp.view(
            B, N_cmp, self.num_heads, self.head_dim
        ).transpose(1, 2)

        # Compute compressed attention: full queries attend to the (much
        # shorter) compressed key sequence.
        attn_cmp = (
            q_cmp @ k_cmp.transpose(-2, -1)
        ) * self.scale
        # Softmax is kept around (not just the output) because the selection
        # branch reuses these scores to rank spatial blocks.
        attn_cmp_softmax = attn_cmp.softmax(dim=-1)
        o_cmp = attn_cmp_softmax @ v_cmp
        o_cmp = o_cmp.transpose(1, 2).reshape(B, N, C)
        o_cmp = self.proj_cmp(o_cmp)

        # =================================================================
        # Branch 2: Selected Attention (Fine-Grained Important)
        # =================================================================
        qkv_slc = self.qkv_slc(x_seq)
        qkv_slc = qkv_slc.reshape(
            B, N, 3, self.num_heads, self.head_dim
        )
        qkv_slc = qkv_slc.permute(2, 0, 3, 1, 4)
        q_slc, k_slc_raw, v_slc_raw = (
            qkv_slc[0],
            qkv_slc[1],
            qkv_slc[2],
        )

        k_for_slc = k_slc_raw.transpose(1, 2).reshape(B, N, C)
        v_for_slc = v_slc_raw.transpose(1, 2).reshape(B, N, C)

        # Select important blocks based on compressed attention scores
        k_slc, v_slc, _ = self.selection(
            q_slc,
            k_for_slc,
            v_for_slc,
            attn_cmp_softmax,
            (H, W),
        )

        N_slc = k_slc.shape[1]
        k_slc = k_slc.view(
            B, N_slc, self.num_heads, self.head_dim
        ).transpose(1, 2)
        v_slc = v_slc.view(
            B, N_slc, self.num_heads, self.head_dim
        ).transpose(1, 2)

        # Compute selected attention (queries attend only to selected tokens)
        attn_slc = (
            q_slc @ k_slc.transpose(-2, -1)
        ) * self.scale
        attn_slc = attn_slc.softmax(dim=-1)
        o_slc = attn_slc @ v_slc
        o_slc = o_slc.transpose(1, 2).reshape(B, N, C)
        o_slc = self.proj_slc(o_slc)

        # =================================================================
        # Branch 3: Sliding Window Attention (Local Context)
        # =================================================================
        o_win = self.window_attn(x)
        o_win = o_win.flatten(2).transpose(1, 2)  # (B, N, C)

        # =================================================================
        # Gated Aggregation
        # =================================================================
        # Compute per-token gates
        gates = self.gate(x_seq)  # (B, N, 3)
        g_cmp = gates[:, :, 0:1]
        g_slc = gates[:, :, 1:2]
        g_win = gates[:, :, 2:3]

        # Weighted combination
        out = (
            g_cmp * o_cmp
            + g_slc * o_slc
            + g_win * o_win
        )

        # Reshape back to spatial
        out = out.transpose(1, 2).view(B, C, H, W)

        return out
901
+
902
+
903
+ # =============================================================================
904
+ # NSA Block (Attention + FFN)
905
+ # =============================================================================
906
+
907
+
908
class NSABlock(nn.Module):
    """
    Complete NSA block with attention, normalization, and FFN.

    Three residual sub-layers, applied in order:
      1. depthwise 3x3 conv for local features (like EfficientViT),
      2. Native Sparse Attention for global/selective features,
      3. LayerNorm + MLP (FFN) for channel mixing on the token sequence.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 2,
        mlp_ratio: float = 2.0,
        compress_block_size: int = 4,
        compress_stride: int = 2,
        select_block_size: int = 4,
        num_select: int = 4,
        window_size: int = 7,
    ):
        super().__init__()

        # Sub-layer 1: local feature extraction (depthwise conv).
        self.norm1 = nn.BatchNorm2d(dim)
        self.dw_conv = nn.Conv2d(
            dim, dim, kernel_size=3, padding=1, groups=dim
        )

        # Sub-layer 2: NSA attention.
        self.norm2 = nn.BatchNorm2d(dim)
        self.nsa = SpatialNSA(
            dim=dim,
            num_heads=num_heads,
            compress_block_size=compress_block_size,
            compress_stride=compress_stride,
            select_block_size=select_block_size,
            num_select=num_select,
            window_size=window_size,
        )

        # Sub-layer 3: FFN over the flattened token sequence.
        self.norm3 = nn.LayerNorm(dim)
        hidden_dim = int(dim * mlp_ratio)
        self.ffn = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input features (B, C, H, W)
        Returns:
            Output features (B, C, H, W)
        """
        # Residual depthwise conv, then residual sparse attention.
        x = x + self.dw_conv(self.norm1(x))
        x = x + self.nsa(self.norm2(x))

        # FFN operates on tokens, so round-trip through (B, N, C).
        batch, channels, height, width = x.shape
        tokens = x.flatten(2).transpose(1, 2)  # (B, N, C)
        tokens = tokens + self.ffn(self.norm3(tokens))
        return tokens.transpose(1, 2).view(batch, channels, height, width)
994
+
995
+
996
+ # =============================================================================
997
+ # NSA Stage (Multiple Blocks + Optional Downsampling)
998
+ # =============================================================================
999
+
1000
+
1001
class NSAStage(nn.Module):
    """
    Stage containing multiple NSA blocks with optional downsampling.

    The stage first adapts the channel count / resolution (strided conv when
    ``downsample`` is True, a 1x1 projection when only the width changes,
    identity otherwise), then runs ``depth`` NSABlocks at ``out_dim``.
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        depth: int = 1,
        num_heads: int = 2,
        mlp_ratio: float = 2.0,
        compress_block_size: int = 4,
        compress_stride: int = 2,
        select_block_size: int = 4,
        num_select: int = 4,
        window_size: int = 7,
        downsample: bool = True,
    ):
        super().__init__()

        # Input adapter: strided conv halves resolution; a 1x1 conv only
        # changes channels; None means pass-through.
        if downsample:
            self.downsample = nn.Sequential(
                ConvBNReLU(
                    in_dim,
                    out_dim,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                ),
            )
        elif in_dim != out_dim:
            self.downsample = ConvBNReLU(
                in_dim,
                out_dim,
                kernel_size=1,
                stride=1,
                padding=0,
            )
        else:
            self.downsample = None

        # NSA blocks, all operating at out_dim channels.
        self.blocks = nn.ModuleList(
            NSABlock(
                dim=out_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                compress_block_size=compress_block_size,
                compress_stride=compress_stride,
                select_block_size=select_block_size,
                num_select=num_select,
                window_size=window_size,
            )
            for _ in range(depth)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.downsample is not None:
            x = self.downsample(x)
        for blk in self.blocks:
            x = blk(x)
        return x
1072
+
1073
+
1074
+ # =============================================================================
1075
+ # NSA Encoder
1076
+ # =============================================================================
1077
+
1078
+
1079
class NSAEncoder(nn.Module):
    """
    NSA-based encoder for hierarchical feature extraction.
    Produces multi-scale features for segmentation decoder.

    Layout: patch embedding (1/4 resolution), then three NSAStages named
    ``stage1``..``stage3``. Stage 1 keeps resolution; stages 2 and 3 each
    downsample 2x, yielding 1/4, 1/8 and 1/16 feature maps.
    """

    def __init__(
        self,
        in_channels: int = 1,
        embed_dims: tuple = (32, 64, 96),
        depths: tuple = (1, 1, 1),
        num_heads: tuple = (2, 2, 4),
        mlp_ratios: tuple = (2, 2, 2),
        compress_block_sizes: tuple = (4, 4, 4),
        compress_strides: tuple = (2, 2, 2),
        select_block_sizes: tuple = (4, 4, 4),
        num_selects: tuple = (4, 4, 4),
        window_sizes: tuple = (7, 7, 7),
    ):
        super().__init__()

        # Patch embedding (already downsamples to 1/4 resolution).
        self.patch_embed = PatchEmbedding(
            in_channels=in_channels,
            embed_dim=embed_dims[0],
        )

        # Input dims per stage: stage1 keeps embed_dims[0]; later stages take
        # the previous stage's output width. Only stages 2+ downsample.
        stage_in_dims = (embed_dims[0],) + tuple(embed_dims[:2])
        for idx, (in_dim, out_dim) in enumerate(zip(stage_in_dims, embed_dims)):
            stage = NSAStage(
                in_dim=in_dim,
                out_dim=out_dim,
                depth=depths[idx],
                num_heads=num_heads[idx],
                mlp_ratio=mlp_ratios[idx],
                compress_block_size=compress_block_sizes[idx],
                compress_stride=compress_strides[idx],
                select_block_size=select_block_sizes[idx],
                num_select=num_selects[idx],
                window_size=window_sizes[idx],
                downsample=idx > 0,
            )
            # Keep the historical attribute names (stage1/stage2/stage3) so
            # state dicts and external references stay compatible.
            setattr(self, f"stage{idx + 1}", stage)

    def forward(self, x: torch.Tensor) -> tuple:
        """
        Args:
            x: Input image (B, C, H, W)
        Returns:
            Multi-scale features (f1, f2, f3)
        """
        x = self.patch_embed(x)
        f1 = self.stage1(x)   # 1/4 resolution
        f2 = self.stage2(f1)  # 1/8 resolution
        f3 = self.stage3(f2)  # 1/16 resolution
        return f1, f2, f3
1207
+
1208
+
1209
+ # =============================================================================
1210
+ # Segmentation Decoder
1211
+ # =============================================================================
1212
+
1213
+
1214
class SegmentationDecoder(nn.Module):
    """
    FPN-style decoder with skip connections for precise segmentation.
    Progressively upsamples features to input resolution.

    Each encoder scale gets a 1x1 lateral projection to ``decoder_dim``;
    coarser levels are bilinearly upsampled and summed into finer ones,
    each fusion followed by a depthwise "smooth" conv. A final 1x1 head
    produces per-class logits, upsampled to the requested output size.
    """

    def __init__(
        self,
        encoder_dims: tuple = (32, 64, 96),
        decoder_dim: int = 32,
        num_classes: int = 2,
    ):
        super().__init__()

        def _lateral(in_dim: int) -> nn.Module:
            # 1x1 projection onto the shared decoder width.
            return nn.Conv2d(in_dim, decoder_dim, kernel_size=1)

        def _smooth() -> nn.Module:
            # Depthwise 3x3 conv + BN + GELU to blend fused features.
            return nn.Sequential(
                nn.Conv2d(
                    decoder_dim,
                    decoder_dim,
                    kernel_size=3,
                    padding=1,
                    groups=decoder_dim,
                ),
                nn.BatchNorm2d(decoder_dim),
                nn.GELU(),
            )

        # Lateral connections (coarsest first, matching the top-down pass).
        self.lateral3 = _lateral(encoder_dims[2])
        self.lateral2 = _lateral(encoder_dims[1])
        self.lateral1 = _lateral(encoder_dims[0])

        # Smoothing convolutions, one per pyramid level.
        self.smooth3 = _smooth()
        self.smooth2 = _smooth()
        self.smooth1 = _smooth()

        # Segmentation head
        self.head = nn.Conv2d(decoder_dim, num_classes, kernel_size=1)

    def forward(
        self,
        f1: torch.Tensor,
        f2: torch.Tensor,
        f3: torch.Tensor,
        target_size: tuple,
    ) -> torch.Tensor:
        """
        Args:
            f1, f2, f3: Multi-scale encoder features
            target_size: (H, W) of output
        Returns:
            Segmentation logits (B, num_classes, H, W)
        """

        def _up_to(src: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
            # Bilinear upsample `src` to `ref`'s spatial size.
            return F.interpolate(
                src,
                size=ref.shape[2:],
                mode="bilinear",
                align_corners=False,
            )

        # Top-down path with lateral connections.
        p3 = self.smooth3(self.lateral3(f3))
        p2 = self.smooth2(self.lateral2(f2) + _up_to(p3, f2))
        p1 = self.smooth1(self.lateral1(f1) + _up_to(p2, f1))

        # Per-class logits at full target resolution.
        logits = self.head(p1)
        return F.interpolate(
            logits,
            size=target_size,
            mode="bilinear",
            align_corners=False,
        )
1339
+
1340
+
1341
+ # =============================================================================
1342
+ # Complete NSA Pupil Segmentation Model
1343
+ # =============================================================================
1344
+
1345
+
1346
class NSAPupilSeg(nn.Module):
    """
    Native Sparse Attention model for Pupil Segmentation.

    Architecture:
    - NSA Encoder: Hierarchical feature extraction with sparse attention
    - FPN Decoder: Multi-scale feature fusion for precise segmentation

    Key NSA components for pupil segmentation:
    - Compression: Captures global eye context (is this an eye? rough pupil location)
    - Selection: Focuses on pupil region with fine-grained attention
    - Sliding Window: Precise local boundaries for pixel-accurate segmentation
    """

    def __init__(
        self,
        in_channels: int = 1,
        num_classes: int = 2,
        embed_dims: tuple = (32, 64, 96),
        depths: tuple = (1, 1, 1),
        num_heads: tuple = (2, 2, 4),
        mlp_ratios: tuple = (2, 2, 2),
        compress_block_sizes: tuple = (4, 4, 4),
        compress_strides: tuple = (2, 2, 2),
        select_block_sizes: tuple = (4, 4, 4),
        num_selects: tuple = (4, 4, 4),
        window_sizes: tuple = (7, 7, 7),
        decoder_dim: int = 32,
    ):
        super().__init__()

        self.encoder = NSAEncoder(
            in_channels=in_channels,
            embed_dims=embed_dims,
            depths=depths,
            num_heads=num_heads,
            mlp_ratios=mlp_ratios,
            compress_block_sizes=compress_block_sizes,
            compress_strides=compress_strides,
            select_block_sizes=select_block_sizes,
            num_selects=num_selects,
            window_sizes=window_sizes,
        )

        self.decoder = SegmentationDecoder(
            encoder_dims=embed_dims,
            decoder_dim=decoder_dim,
            num_classes=num_classes,
        )

        self._initialize_weights()

    def _initialize_weights(self) -> None:
        """Initialize model weights (Kaiming for convs, trunc-normal for linears)."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(
                    module.weight,
                    mode="fan_out",
                    nonlinearity="relu",
                )
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, (nn.BatchNorm2d, nn.LayerNorm)):
                # Affine norms start as identity transforms.
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Linear):
                nn.init.trunc_normal_(module.weight, std=0.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input image (B, C, H, W)
        Returns:
            Segmentation logits (B, num_classes, H, W)
        """
        # Decoder upsamples back to the original input resolution.
        target_size = (x.shape[2], x.shape[3])
        f1, f2, f3 = self.encoder(x)
        return self.decoder(f1, f2, f3, target_size)
1468
+
1469
+
1470
+ # =============================================================================
1471
+ # Loss Function (same as src/ for compatibility)
1472
+ # =============================================================================
1473
+
1474
+
1475
def focal_surface_loss(
    probs: torch.Tensor,
    dist_map: torch.Tensor,
    gamma: float = 2.0,
) -> torch.Tensor:
    """Surface loss with focal weighting for hard boundary pixels.

    Pixels the model is already confident about (probs near 1) receive a
    near-zero weight via the (1 - p)^gamma factor, concentrating the
    distance-map penalty on uncertain boundary pixels.

    Args:
        probs: Predicted probabilities (B, C, H, W)
        dist_map: Distance transform (B, 2, H, W)
        gamma: Focal weighting exponent

    Returns:
        Focal-weighted surface loss scalar
    """
    weighted = ((1.0 - probs) ** gamma) * probs * dist_map
    # Average over spatial positions, then classes, then the batch.
    per_class = weighted.flatten(start_dim=2).mean(dim=2)
    return per_class.mean(dim=1).mean()
1498
+
1499
+
1500
def boundary_dice_loss(
    probs: torch.Tensor,
    target: torch.Tensor,
    kernel_size: int = 3,
    epsilon: float = 1e-5,
) -> torch.Tensor:
    """Dice loss computed only on boundary pixels.

    Args:
        probs: Predicted probabilities (B, C, H, W)
        target: Ground truth labels (B, H, W)
        kernel_size: Size of kernel for boundary extraction
        epsilon: Small constant for numerical stability

    Returns:
        Boundary dice loss scalar
    """
    # Morphological gradient (dilation minus erosion) marks the boundary
    # band of the ground-truth mask. Erosion is implemented as a max-pool
    # of the negated mask.
    mask = target.float().unsqueeze(1)
    pad = kernel_size // 2
    dilation = F.max_pool2d(mask, kernel_size, stride=1, padding=pad)
    erosion = -F.max_pool2d(-mask, kernel_size, stride=1, padding=pad)
    boundary = (dilation - erosion).squeeze(1)  # (B, H, W)

    # Restrict both prediction (pupil class, index 1) and target to the
    # boundary band before computing Dice.
    pred_boundary = probs[:, 1] * boundary
    true_boundary = target.float() * boundary

    inter = (pred_boundary * true_boundary).sum(dim=(1, 2))
    denom = pred_boundary.sum(dim=(1, 2)) + true_boundary.sum(dim=(1, 2))

    dice = (2.0 * inter + epsilon) / (denom + epsilon)
    return (1.0 - dice).mean()
1550
+
1551
+
1552
class CombinedLoss(nn.Module):
    """
    Combined loss for pupil segmentation:
    - Weighted Cross Entropy: Handles class imbalance
    - Dice Loss: Better for small regions like pupils
    - Focal Surface Loss: Boundary-aware optimization with focal weighting
    - Boundary Dice Loss: Explicit optimization for edge pixels
    """

    def __init__(
        self,
        epsilon: float = 1e-5,
        focal_gamma: float = 2.0,
        boundary_weight: float = 0.3,
        boundary_kernel_size: int = 3,
    ):
        super().__init__()
        # epsilon: numerical-stability floor used by the dice terms
        self.epsilon = epsilon
        # focal_gamma: exponent forwarded to focal_surface_loss
        self.focal_gamma = focal_gamma
        # boundary_weight: fixed weight of the boundary-dice term
        self.boundary_weight = boundary_weight
        # boundary_kernel_size: morphological kernel for boundary extraction
        self.boundary_kernel_size = boundary_kernel_size
        # reduction="none" keeps the per-pixel CE map so spatial weights can
        # be applied before averaging.
        self.nll = nn.NLLLoss(reduction="none")

    def forward(
        self,
        logits: torch.Tensor,
        target: torch.Tensor,
        spatial_weights: torch.Tensor,
        dist_map: torch.Tensor,
        alpha: float,
        eye_weight: torch.Tensor = None,  # optional; None disables eye weighting
    ) -> tuple:
        """
        Args:
            logits: Model output (B, C, H, W)
            target: Ground truth (B, H, W)
            spatial_weights: Spatial weighting map (B, H, W)
            dist_map: Distance map for surface loss (B, 2, H, W)
            alpha: Balance between dice and surface loss
            eye_weight: Soft distance weighting from eye region (B, H, W)
        Returns:
            (total_loss, ce_loss, dice_loss, surface_loss, boundary_loss)
        """
        probs = F.softmax(logits, dim=1)
        log_probs = F.log_softmax(logits, dim=1)

        # Weighted Cross Entropy (per-pixel, reduced after weighting)
        ce_loss = self.nll(log_probs, target)
        # Apply spatial weights and optional eye weight; the +1.0 keeps a
        # baseline contribution for pixels with zero spatial weight.
        weight_factor = 1.0 + spatial_weights
        if eye_weight is not None:
            weight_factor = weight_factor * eye_weight
        weighted_ce = (ce_loss * weight_factor).mean()

        # Dice Loss (generalized: per-class weights ~ 1 / |class|^2, which
        # boosts the small pupil class relative to background)
        target_onehot = (
            F.one_hot(target, num_classes=2)
            .permute(0, 3, 1, 2)
            .float()
        )
        probs_flat = probs.flatten(start_dim=2)
        target_flat = target_onehot.flatten(start_dim=2)

        intersection = (probs_flat * target_flat).sum(dim=2)
        cardinality = (probs_flat + target_flat).sum(dim=2)
        class_weights = 1.0 / (
            target_flat.sum(dim=2) ** 2
        ).clamp(min=self.epsilon)

        dice = (
            2.0
            * (class_weights * intersection).sum(dim=1)
            / (class_weights * cardinality).sum(dim=1)
        )
        dice_loss = (
            1.0 - dice.clamp(min=self.epsilon)
        ).mean()

        # Focal Surface Loss (replaces standard surface loss)
        surface_loss = focal_surface_loss(
            probs,
            dist_map,
            gamma=self.focal_gamma,
        )

        # Boundary Dice Loss
        bdice_loss = boundary_dice_loss(
            probs,
            target,
            kernel_size=self.boundary_kernel_size,
            epsilon=self.epsilon,
        )

        # Total loss with updated weighting.
        # Use max(1 - alpha, 0.2) so the surface term never vanishes even
        # when alpha (the dice weight) approaches 1 during training.
        surface_weight = max(1.0 - alpha, 0.2)
        total_loss = (
            weighted_ce
            + alpha * dice_loss
            + surface_weight * surface_loss
            + self.boundary_weight * bdice_loss
        )

        return (
            total_loss,
            weighted_ce,
            dice_loss,
            surface_loss,
            bdice_loss,
        )
1691
+
1692
+
1693
+ # =============================================================================
1694
+ # Factory function for easy model creation
1695
+ # =============================================================================
1696
+
1697
+
1698
def create_nsa_pupil_seg(
    size: str = "small",
    in_channels: int = 1,
    num_classes: int = 2,
) -> NSAPupilSeg:
    """
    Create NSA Pupil Segmentation model with predefined configurations.

    Args:
        size: Model size ('pico', 'nano', 'tiny', 'small', 'medium')
        in_channels: Number of input channels
        num_classes: Number of output classes
    Returns:
        Configured NSAPupilSeg model
    Raises:
        ValueError: If ``size`` is not one of the predefined names.
    """
    configs = {
        "pico": dict(
            embed_dims=(4, 4, 4),
            depths=(1, 1, 1),
            num_heads=(1, 1, 1),
            mlp_ratios=(1.0, 1.0, 1.0),
            compress_block_sizes=(4, 4, 4),
            compress_strides=(4, 4, 4),
            select_block_sizes=(4, 4, 4),
            num_selects=(1, 1, 1),
            window_sizes=(3, 3, 3),
            decoder_dim=4,
        ),
        "nano": dict(
            embed_dims=(4, 8, 12),
            depths=(1, 1, 1),
            num_heads=(1, 1, 1),
            mlp_ratios=(1.0, 1.0, 1.0),
            compress_block_sizes=(4, 4, 4),
            compress_strides=(4, 4, 4),
            select_block_sizes=(4, 4, 4),
            num_selects=(1, 1, 1),
            window_sizes=(3, 3, 3),
            decoder_dim=4,
        ),
        "tiny": dict(
            embed_dims=(8, 12, 16),
            depths=(1, 1, 1),
            num_heads=(1, 1, 1),
            mlp_ratios=(1.5, 1.5, 1.5),
            compress_block_sizes=(4, 4, 4),
            compress_strides=(4, 4, 4),
            select_block_sizes=(4, 4, 4),
            num_selects=(1, 1, 1),
            window_sizes=(3, 3, 3),
            decoder_dim=8,
        ),
        "small": dict(
            embed_dims=(12, 24, 32),
            depths=(1, 1, 1),
            num_heads=(1, 1, 2),
            mlp_ratios=(1.5, 1.5, 1.5),
            compress_block_sizes=(4, 4, 4),
            compress_strides=(4, 4, 4),
            select_block_sizes=(4, 4, 4),
            num_selects=(1, 1, 1),
            window_sizes=(3, 3, 3),
            decoder_dim=12,
        ),
        "medium": dict(
            embed_dims=(16, 32, 48),
            depths=(1, 1, 1),
            num_heads=(1, 2, 2),
            mlp_ratios=(1.5, 1.5, 1.5),
            compress_block_sizes=(4, 4, 4),
            # Medium uses a denser compression stride and more selected
            # blocks than the smaller presets.
            compress_strides=(3, 3, 3),
            select_block_sizes=(4, 4, 4),
            num_selects=(2, 2, 2),
            window_sizes=(3, 3, 3),
            decoder_dim=16,
        ),
    }

    if size not in configs:
        raise ValueError(
            f"Unknown size: {size}. Choose from {list(configs.keys())}"
        )

    return NSAPupilSeg(
        in_channels=in_channels,
        num_classes=num_classes,
        **configs[size],
    )
1866
+
1867
+
1868
+ # =============================================================================
1869
+ # Testing / Verification
1870
+ # =============================================================================
1871
+
1872
+
1873
if __name__ == "__main__":
    # Smoke-test: build every predefined model size and run one forward pass
    # on a dummy OpenEDS-shaped batch, printing parameter counts and shapes.
    print("Testing NSA Pupil Segmentation Model")
    print("=" * 60)

    for size in ["pico", "nano", "tiny", "small", "medium"]:
        model = create_nsa_pupil_seg(size=size)

        # Total trainable + non-trainable parameter count.
        n_params = sum(p.numel() for p in model.parameters())

        # Dummy batch at the OpenEDS image size.
        x = torch.randn(2, 1, 400, 640)

        model.eval()
        with torch.no_grad():
            out = model(x)

        print(f"\n{size.upper()} Model:")
        print(f" Parameters: {n_params:,}")
        print(f" Input shape: {x.shape}")
        print(f" Output shape: {out.shape}")

    print("\n" + "=" * 60)
    print("All tests passed!")
packages.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ libgl1-mesa-glx
2
+ libglib2.0-0
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch>=2.0.0
2
+ numpy>=1.21.0
3
+ opencv-python-headless>=4.5.0
4
+ mediapipe>=0.10.21
5
+ gradio==6.1.0
6
+ Pillow>=8.3.0