File size: 12,301 Bytes
a69fe43
 
 
1777497
a69fe43
 
 
 
 
 
 
 
 
 
 
 
 
 
1777497
a69fe43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1777497
a69fe43
 
01d8faa
a69fe43
1777497
a69fe43
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
#!/usr/bin/env python3
"""
NSA Pupil Segmentation Gradio Demo - Native Sparse Attention Web Application

This Gradio application performs real-time pupil segmentation on webcam input
using the NSAPupilSeg model (Native Sparse Attention). It demonstrates eye tracking
and pupil detection capabilities for the VisionAssist medical assistive technology project.

NSA Key Features:
- Token Compression: Global coarse-grained context
- Token Selection: Fine-grained focus on important regions (pupil)
- Sliding Window: Local context for precise boundaries
- Gated Aggregation: Learned combination of attention paths
"""

import cv2
import numpy as np
import torch
import gradio as gr
import mediapipe as mp

from nsa import create_nsa_pupil_seg

# =============================================================================
# Model Loading (at module startup)
# =============================================================================

print("Loading NSA Pupil Segmentation model...")

# Smallest ("pico") NSA variant: single grayscale input channel,
# two output classes (background vs. pupil).
model = create_nsa_pupil_seg(size="pico", in_channels=1, num_classes=2)
# NOTE(review): weights_only=False unpickles arbitrary Python objects — only
# acceptable because "best_model.pth" is a trusted local checkpoint.
checkpoint = torch.load("best_model.pth", map_location="cpu", weights_only=False)
if "model_state_dict" in checkpoint:
    # Full training checkpoint: weights stored under "model_state_dict",
    # with extra metadata such as the validation IoU.
    model.load_state_dict(checkpoint["model_state_dict"])
    print(f"Loaded checkpoint with IoU: {checkpoint.get('valid_iou', 'N/A')}")
else:
    # Bare state-dict checkpoint (weights only).
    model.load_state_dict(checkpoint)
model.eval()  # inference mode: disables dropout / batch-norm updates

print("Model loaded successfully!")

# =============================================================================
# MediaPipe Face Mesh Setup
# =============================================================================

mp_face_mesh = mp.solutions.face_mesh
# Single-face tracking; refine_landmarks=True enables the extra iris/eye
# landmarks needed for precise eye-region extraction.
face_mesh = mp_face_mesh.FaceMesh(
    max_num_faces=1,
    refine_landmarks=True,
    min_detection_confidence=0.5,
    min_tracking_confidence=0.5,
)

# =============================================================================
# Constants (from demo.py - MUST match training exactly)
# =============================================================================

# MediaPipe left eye landmark indices (12 points around the eye)
LEFT_EYE_INDICES = [362, 385, 387, 263, 373, 380, 374, 381, 382, 384, 398, 466]

# Target aspect ratio for eye region (width:height = 640:400 = 1.6:1)
TARGET_ASPECT_RATIO = 640 / 400  # 1.6:1

# Model input/output dimensions (width x height, matching the NSA model's
# (B, C, W, H) input convention used throughout this file)
MODEL_WIDTH = 640
MODEL_HEIGHT = 400

# Preprocessing parameters (MUST match training exactly):
# pixel / 255 -> [0, 1], then (x - mean) / std -> [-1, 1]
NORMALIZE_MEAN = 0.5
NORMALIZE_STD = 0.5

# Eye extraction settings
BBOX_PADDING = 0.2  # 20% padding on each side
MIN_EYE_REGION_SIZE = 50  # Minimum bounding box size in pixels; smaller -> "Move Closer"

# Visualization settings
OVERLAY_ALPHA = 0.5  # blend weight of the green pupil overlay


# =============================================================================
# Eye Region Extraction Function
# =============================================================================

def extract_eye_region(frame, landmarks):
    """
    Crop the left-eye region out of *frame* using MediaPipe landmarks.

    Args:
        frame: Input BGR frame (numpy array, shape (H, W, 3)).
        landmarks: MediaPipe face landmarks (normalized [0, 1] coordinates).

    Returns:
        tuple: (eye_crop, bbox) with bbox = (x, y, w, h) in pixels,
        or (None, None) when the eye is too small or the crop is empty.
    """
    frame_h, frame_w = frame.shape[:2]

    # Landmark coordinates are normalized; scale to pixel positions.
    xs = [int(landmarks.landmark[i].x * frame_w) for i in LEFT_EYE_INDICES]
    ys = [int(landmarks.landmark[i].y * frame_h) for i in LEFT_EYE_INDICES]

    x_min, x_max = min(xs), max(xs)
    y_min, y_max = min(ys), max(ys)

    box_w = x_max - x_min
    box_h = y_max - y_min

    # Reject regions too small to segment reliably.
    if box_w < MIN_EYE_REGION_SIZE or box_h < MIN_EYE_REGION_SIZE:
        return None, None

    # Pad by BBOX_PADDING (20%) on every side, clamped to the frame.
    x_min = max(0, x_min - int(box_w * BBOX_PADDING))
    y_min = max(0, y_min - int(box_h * BBOX_PADDING))
    x_max = min(frame_w, x_max + int(box_w * BBOX_PADDING))
    y_max = min(frame_h, y_max + int(box_h * BBOX_PADDING))

    box_w = x_max - x_min
    box_h = y_max - y_min

    # Grow the shorter dimension until the box matches 640:400 (1.6:1).
    if box_w / box_h < TARGET_ASPECT_RATIO:
        # Too narrow: widen symmetrically (clamped at the frame edges).
        diff = int(box_h * TARGET_ASPECT_RATIO) - box_w
        x_min = max(0, x_min - diff // 2)
        x_max = min(frame_w, x_max + diff // 2)
        box_w = x_max - x_min
    else:
        # Too wide: grow the height symmetrically instead.
        diff = int(box_w / TARGET_ASPECT_RATIO) - box_h
        y_min = max(0, y_min - diff // 2)
        y_max = min(frame_h, y_max + diff // 2)
        box_h = y_max - y_min

    eye_crop = frame[y_min:y_max, x_min:x_max]

    # Guard against degenerate (empty) crops at the frame border.
    if eye_crop.size == 0:
        return None, None

    return eye_crop, (x_min, y_min, box_w, box_h)


# =============================================================================
# Preprocessing Function (CRITICAL - must match training exactly)
# =============================================================================

def preprocess(eye_crop):
    """
    Convert a BGR eye crop into the model's input tensor.
    CRITICAL: Must match training preprocessing exactly.

    Args:
        eye_crop: BGR image of the eye region (numpy array).

    Returns:
        torch.Tensor: Float tensor of shape (1, 1, 640, 400).
    """
    # Resize to the fixed model resolution (width x height = 640 x 400).
    resized = cv2.resize(
        eye_crop,
        (MODEL_WIDTH, MODEL_HEIGHT),
        interpolation=cv2.INTER_LINEAR,
    )

    # Single-channel grayscale, as during training.
    gray = cv2.cvtColor(resized, cv2.COLOR_BGR2GRAY)

    # Scale to [0, 1], then normalize with mean=0.5, std=0.5 -> [-1, 1].
    scaled = gray.astype(np.float32) / 255.0
    normalized = (scaled - NORMALIZE_MEAN) / NORMALIZE_STD

    # The model expects (B, C, W, H), NOT the usual (B, C, H, W):
    # transpose the (400, 640) image to (640, 400), then add batch and
    # channel axes in front.
    return torch.from_numpy(normalized.T[None, None, :, :])


# =============================================================================
# Inference Function
# =============================================================================

def run_inference(input_tensor):
    """
    Run the segmentation model on a preprocessed input tensor.

    Args:
        input_tensor: Tensor of shape (1, 1, 640, 400).

    Returns:
        np.ndarray: Binary mask of shape (400, 640), values in {0, 1}.
    """
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(input_tensor)

    # Model output is (B, C, W, H) = (1, 2, 640, 400).  Argmax over the
    # class axis yields (640, 400); transpose back to (H, W) = (400, 640)
    # for visualization.
    scores = logits.cpu().numpy()[0]
    return scores.argmax(axis=0).T.astype(np.uint8)


# =============================================================================
# Visualization Function
# =============================================================================

def visualize(frame, eye_crop, mask, bbox, face_detected):
    """
    Draw the segmentation result and a status banner onto a copy of *frame*.

    Args:
        frame: Original BGR frame.
        eye_crop: Eye region crop (not drawn directly; kept for API parity).
        mask: Binary segmentation mask (400, 640), or None.
        bbox: Bounding box (x, y, w, h), or None.
        face_detected: Whether MediaPipe found a face.

    Returns:
        np.ndarray: Annotated BGR frame.
    """
    out = frame.copy()

    # --- Status banner: semi-transparent dark strip across the top ---
    banner_h = 50
    frame_w = out.shape[1]
    strip = out[0:banner_h, 0:frame_w].astype(np.float32) * 0.5
    out[0:banner_h, 0:frame_w] = strip.astype(np.uint8)

    # Pick status text and color (BGR).
    if not face_detected:
        status_text, status_color = "No Face Detected", (0, 255, 255)  # yellow
    elif mask is None:
        status_text, status_color = "Move Closer", (0, 255, 255)  # yellow
    else:
        status_text, status_color = "Face Detected", (0, 255, 0)  # green

    # Center the text horizontally and vertically inside the banner.
    (text_w, text_h), _ = cv2.getTextSize(
        status_text, cv2.FONT_HERSHEY_SIMPLEX, 1.0, 2
    )
    cv2.putText(
        out,
        status_text,
        ((frame_w - text_w) // 2, (banner_h + text_h) // 2),
        cv2.FONT_HERSHEY_SIMPLEX,
        1.0,
        status_color,
        2,
    )

    # --- Pupil overlay: blend a green layer over the eye region ---
    if mask is not None and bbox is not None:
        x, y, w, h = bbox

        # Scale the (400, 640) model mask to the on-frame eye box size.
        scaled_mask = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)

        # Green wherever the pupil class (1) was predicted.
        overlay = np.zeros((h, w, 3), dtype=np.uint8)
        overlay[scaled_mask == 1] = (0, 255, 0)

        # Alpha-blend the overlay onto the eye region.
        region = out[y:y + h, x:x + w]
        out[y:y + h, x:x + w] = cv2.addWeighted(
            region, 1 - OVERLAY_ALPHA, overlay, OVERLAY_ALPHA, 0
        )

        # Outline the tracked eye region.
        cv2.rectangle(out, (x, y), (x + w, y + h), (0, 255, 0), 3)

    # --- Model tag in the bottom-left corner ---
    cv2.putText(
        out,
        "NSA-pico",
        (10, out.shape[0] - 20),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.7,
        (0, 255, 0),
        2,
    )

    return out


# =============================================================================
# Main Process Function
# =============================================================================

def process_frame(image):
    """
    Run the full pupil-segmentation pipeline on one webcam frame.

    Pipeline: face detection -> eye-region crop -> preprocessing ->
    NSA inference -> overlay visualization.

    Args:
        image: RGB frame from Gradio (numpy array), or None.

    Returns:
        np.ndarray: Annotated RGB frame, or None when no input was given.
    """
    if image is None:
        return None

    # OpenCV drawing works in BGR; MediaPipe consumes the RGB original.
    frame_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

    results = face_mesh.process(image)  # MediaPipe expects RGB
    face_detected = results.multi_face_landmarks is not None

    eye_crop, bbox, mask = None, None, None

    if face_detected:
        # max_num_faces=1, so the first entry is the only face.
        eye_crop, bbox = extract_eye_region(
            frame_bgr, results.multi_face_landmarks[0]
        )
        if eye_crop is not None:
            mask = run_inference(preprocess(eye_crop))

    # Draw status banner and (when available) the pupil overlay.
    annotated_bgr = visualize(frame_bgr, eye_crop, mask, bbox, face_detected)

    # Gradio expects RGB output.
    return cv2.cvtColor(annotated_bgr, cv2.COLOR_BGR2RGB)


# =============================================================================
# Gradio Interface
# =============================================================================

# live=True with a streaming webcam source: process_frame is invoked
# continuously on incoming frames instead of waiting for a submit click.
demo = gr.Interface(
    fn=process_frame,
    inputs=gr.Image(sources=["webcam"], streaming=True, label="Webcam Input"),
    outputs=gr.Image(label="Pupil Segmentation"),
    live=True,
    title="NSA Pupil Segmentation Demo",
    description="""
    Real-time pupil segmentation using Native Sparse Attention (NSA).

    This demo uses the NSAPupilSeg model from the VisionAssist project to detect
    and segment the pupil region in real-time from your webcam feed.

    **How it works:**
    1. MediaPipe Face Mesh detects your face and eye landmarks
    2. The left eye region is extracted and preprocessed
    3. The NSA model performs semantic segmentation to identify the pupil
    4. Results are overlaid on the video feed with a green highlight

    **Tips for best results:**
    - Ensure good lighting on your face
    - Look directly at the camera
    - Keep your face within the frame
    - Move closer if the eye region is too small

    **Model:** NSA-pico (Native Sparse Attention)
    """,
    flagging_mode="never",  # hide Gradio's "flag" button for this demo
)

if __name__ == "__main__":
    demo.launch()