#!/usr/bin/env python3
"""
NSA Pupil Segmentation Gradio Demo - Native Sparse Attention Web Application

This Gradio application performs real-time pupil segmentation on webcam input
using the NSAPupilSeg model (Native Sparse Attention). It demonstrates eye
tracking and pupil detection capabilities for the VisionAssist medical
assistive technology project.

NSA Key Features:
- Token Compression: Global coarse-grained context
- Token Selection: Fine-grained focus on important regions (pupil)
- Sliding Window: Local context for precise boundaries
- Gated Aggregation: Learned combination of attention paths
"""

import cv2
import numpy as np
import torch
import gradio as gr
import mediapipe as mp

from nsa import create_nsa_pupil_seg

# =============================================================================
# Model Loading (at module startup)
# =============================================================================

print("Loading NSA Pupil Segmentation model...")

# Single-channel (grayscale) input, two output classes (background / pupil).
model = create_nsa_pupil_seg(size="pico", in_channels=1, num_classes=2)

# NOTE: weights_only=False lets torch.load unpickle arbitrary Python objects.
# Only ever point this at a checkpoint from a trusted source.
checkpoint = torch.load("best_model.pth", map_location="cpu", weights_only=False)

# A training checkpoint wraps the weights (plus metrics such as valid_iou)
# under "model_state_dict"; otherwise the file is treated as a bare state dict.
_is_wrapped_checkpoint = "model_state_dict" in checkpoint
model.load_state_dict(
    checkpoint["model_state_dict"] if _is_wrapped_checkpoint else checkpoint
)
if _is_wrapped_checkpoint:
    print(f"Loaded checkpoint with IoU: {checkpoint.get('valid_iou', 'N/A')}")

# Switch to evaluation mode for inference (matters for layers such as
# dropout / batch-norm, if the model contains any).
model.eval()
print("Model loaded successfully!")

# =============================================================================
# MediaPipe Face Mesh Setup
# =============================================================================

mp_face_mesh = mp.solutions.face_mesh
face_mesh = mp_face_mesh.FaceMesh(
    max_num_faces=1,  # only the first detected face is consumed downstream
    refine_landmarks=True,
    min_detection_confidence=0.5,
    min_tracking_confidence=0.5,
)
# =============================================================================
# Constants (from demo.py - MUST match training exactly)
# =============================================================================

# MediaPipe left eye landmark indices (12 points around the eye)
LEFT_EYE_INDICES = [362, 385, 387, 263, 373, 380, 374, 381, 382, 384, 398, 466]

# Model input/output dimensions
MODEL_WIDTH = 640
MODEL_HEIGHT = 400

# Target aspect ratio for eye region (width:height = 640:400 = 1.6:1).
# Derived from the model size so the crop shape and model input never disagree.
TARGET_ASPECT_RATIO = MODEL_WIDTH / MODEL_HEIGHT  # 1.6:1

# Preprocessing parameters (MUST match training exactly)
NORMALIZE_MEAN = 0.5
NORMALIZE_STD = 0.5

# Eye extraction settings
BBOX_PADDING = 0.2  # 20% padding on each side
MIN_EYE_REGION_SIZE = 50  # Minimum landmark bounding box size (pixels, pre-padding)

# Visualization settings
OVERLAY_ALPHA = 0.5


# =============================================================================
# Eye Region Extraction Function
# =============================================================================

def _expand_span(lo, hi, target, limit):
    """
    Grow the half-open span [lo, hi) to ``target`` pixels inside [0, limit).

    The extra length is split around the current span. If that would cross a
    frame edge, the span is shifted back inside rather than clipped, so the
    requested length is kept whenever the frame is large enough. The result
    always contains the original span.

    Args:
        lo, hi: Current bounds, expected to satisfy 0 <= lo < hi <= limit.
        target: Desired span length in pixels.
        limit: Frame size along this axis.

    Returns:
        tuple: (lo, hi) as Python ints. Unchanged if the span is already at
        least ``target`` long; shorter than ``target`` only if ``limit`` is.
    """
    lo, hi, limit = int(lo), int(hi), int(limit)
    target = min(int(target), limit)
    diff = target - (hi - lo)
    if diff <= 0:
        return lo, hi

    lo -= diff // 2
    hi = lo + target
    # Shift (not clip) back into the frame so the aspect ratio is preserved.
    if lo < 0:
        lo, hi = 0, target
    elif hi > limit:
        lo, hi = limit - target, limit
    return lo, hi


def extract_eye_region(frame, landmarks):
    """
    Extract left eye region from frame using MediaPipe landmarks.

    The landmark bounding box is padded by BBOX_PADDING on each side, clamped
    to the frame, then grown to TARGET_ASPECT_RATIO so that the later resize
    to MODEL_WIDTH x MODEL_HEIGHT does not distort the eye.

    Args:
        frame: Input BGR frame
        landmarks: MediaPipe face landmarks (normalized x/y coordinates)

    Returns:
        tuple: (eye_crop, bbox) where bbox is (x, y, w, h) in Python ints,
        or (None, None) if the eye is too small or lies outside the frame.
    """
    h, w = frame.shape[:2]

    # Landmark coordinates are normalized; scale them to pixels.
    eye_points = np.array([
        [int(landmarks.landmark[idx].x * w), int(landmarks.landmark[idx].y * h)]
        for idx in LEFT_EYE_INDICES
    ], dtype=np.int32)

    x_min, y_min = (int(v) for v in eye_points.min(axis=0))
    x_max, y_max = (int(v) for v in eye_points.max(axis=0))
    bbox_w = x_max - x_min
    bbox_h = y_max - y_min

    # Check if eye region is large enough
    if bbox_w < MIN_EYE_REGION_SIZE or bbox_h < MIN_EYE_REGION_SIZE:
        return None, None

    # Add padding (20% on each side), clamped to the frame
    pad_w = int(bbox_w * BBOX_PADDING)
    pad_h = int(bbox_h * BBOX_PADDING)
    x_min = max(0, x_min - pad_w)
    y_min = max(0, y_min - pad_h)
    x_max = min(w, x_max + pad_w)
    y_max = min(h, y_max + pad_h)

    # Landmarks can fall outside [0, 1] when the eye leaves the frame, so the
    # clamped box may be empty. Bail out rather than divide by zero below.
    if x_max <= x_min or y_max <= y_min:
        return None, None

    bbox_w = x_max - x_min
    bbox_h = y_max - y_min

    # Expand to 1.6:1 aspect ratio (640:400)
    if bbox_w / bbox_h < TARGET_ASPECT_RATIO:
        # Too narrow, expand width
        x_min, x_max = _expand_span(
            x_min, x_max, int(bbox_h * TARGET_ASPECT_RATIO), w
        )
    else:
        # Too short, expand height
        y_min, y_max = _expand_span(
            y_min, y_max, int(bbox_w / TARGET_ASPECT_RATIO), h
        )

    eye_crop = frame[y_min:y_max, x_min:x_max]

    # Validate the crop is not empty
    if eye_crop.size == 0:
        return None, None

    return eye_crop, (x_min, y_min, x_max - x_min, y_max - y_min)


# =============================================================================
# Preprocessing Function (CRITICAL - must match training exactly)
# =============================================================================

def preprocess(eye_crop):
    """
    Preprocess eye region for model inference.

    CRITICAL: Must match training preprocessing exactly.

    Args:
        eye_crop: BGR image of eye region

    Returns:
        torch.Tensor: Preprocessed float32 tensor of shape (1, 1, 640, 400)
    """
    # Step 1: Resize to model input size (640, 400)
    resized = cv2.resize(
        eye_crop, (MODEL_WIDTH, MODEL_HEIGHT), interpolation=cv2.INTER_LINEAR
    )

    # Step 2: Convert to grayscale
    gray = cv2.cvtColor(resized, cv2.COLOR_BGR2GRAY)

    # Step 3: Normalize to [-1, 1] range (mean=0.5, std=0.5)
    normalized = (gray.astype(np.float32) / 255.0 - NORMALIZE_MEAN) / NORMALIZE_STD

    # Step 4: Transpose to (1, 1, W, H) - model expects (B, C, W, H), NOT (B, C, H, W).
    # normalized is (H, W) = (400, 640); .T gives (W, H) = (640, 400). The
    # transpose is only a strided view, so make it contiguous before handing
    # it to torch (otherwise .view() calls inside the model can fail).
    input_tensor = np.ascontiguousarray(normalized.T)[np.newaxis, np.newaxis, :, :]

    return torch.from_numpy(input_tensor)


# =============================================================================
# Inference Function
# =============================================================================

def run_inference(input_tensor):
    """
    Run model inference on preprocessed input.

    Args:
        input_tensor: Preprocessed tensor of shape (1, 1, 640, 400)

    Returns:
        np.ndarray: Contiguous uint8 class mask of shape (400, 640), with
        values in {0, 1} for the two-class model.
    """
    with torch.no_grad():
        output = model(input_tensor)

    output_np = output.cpu().numpy()

    # Model outputs (B, C, W, H) = (1, 2, 640, 400); argmax over classes gives
    # (640, 400), transposed back to (H, W) = (400, 640) for visualization.
    mask = np.argmax(output_np[0], axis=0).T.astype(np.uint8)

    return np.ascontiguousarray(mask)


# =============================================================================
# Visualization Function
# =============================================================================

def visualize(frame, eye_crop, mask, bbox, face_detected):
    """
    Visualize segmentation results on frame.

    Args:
        frame: Original BGR frame (not modified)
        eye_crop: Eye region crop (unused; kept for interface compatibility)
        mask: Binary segmentation mask (400, 640), or None if no inference ran
        bbox: Bounding box (x, y, w, h) of the eye crop, or None
        face_detected: Whether face was detected

    Returns:
        np.ndarray: Annotated BGR frame
    """
    annotated = frame.copy()

    # Status banner: darken the top strip to half brightness. Integer halving
    # equals the float "* 0.5 then truncate to uint8" approach.
    banner_height = 50
    banner_w = annotated.shape[1]
    annotated[0:banner_height, 0:banner_w] //= 2

    # Status text
    if not face_detected:
        status_text = "No Face Detected"
        status_color = (0, 255, 255)  # Yellow (BGR)
    elif mask is None:
        status_text = "Move Closer"
        status_color = (0, 255, 255)  # Yellow
    else:
        status_text = "Face Detected"
        status_color = (0, 255, 0)  # Green

    text_size = cv2.getTextSize(status_text, cv2.FONT_HERSHEY_SIMPLEX, 1.0, 2)[0]
    text_x = (banner_w - text_size[0]) // 2
    text_y = (banner_height + text_size[1]) // 2
    cv2.putText(
        annotated,
        status_text,
        (text_x, text_y),
        cv2.FONT_HERSHEY_SIMPLEX,
        1.0,
        status_color,
        2,
    )

    # If we have a valid mask, overlay it on the eye region
    if mask is not None and bbox is not None:
        x, y, w, h = bbox

        # Resize mask to match eye crop size (nearest keeps it binary)
        mask_resized = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)
        pupil = mask_resized == 1

        # Basic slicing returns a view, so writes below land in `annotated`.
        eye_region = annotated[y:y + h, x:x + w]
        if pupil.any():
            green_overlay = np.zeros_like(eye_region)
            green_overlay[pupil] = (0, 255, 0)  # Green in BGR
            tinted = cv2.addWeighted(
                eye_region, 1 - OVERLAY_ALPHA, green_overlay, OVERLAY_ALPHA, 0
            )
            # Tint only pupil pixels; blending the whole box against a black
            # overlay would also darken every non-pupil pixel.
            eye_region[pupil] = tinted[pupil]

        # Draw bounding box
        cv2.rectangle(annotated, (x, y), (x + w, y + h), (0, 255, 0), 3)

    # Draw model info (bottom-left)
    cv2.putText(
        annotated,
        "NSA-pico",
        (10, annotated.shape[0] - 20),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.7,
        (0, 255, 0),
        2,
    )

    return annotated


# =============================================================================
# Main Process Function
# =============================================================================

def process_frame(image):
    """
    Process a single frame from webcam for pupil segmentation.

    Args:
        image: Input RGB image from Gradio (numpy array), or None

    Returns:
        np.ndarray: Annotated RGB image for Gradio output, or None if no image
    """
    if image is None:
        return None

    # Gradio provides RGB, convert to BGR for OpenCV
    frame_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

    # Run MediaPipe face detection on the RGB image (MediaPipe expects RGB)
    results = face_mesh.process(image)
    face_detected = results.multi_face_landmarks is not None

    eye_crop = None
    bbox = None
    mask = None

    if face_detected:
        landmarks = results.multi_face_landmarks[0]

        # Extract eye region (from BGR frame)
        eye_crop, bbox = extract_eye_region(frame_bgr, landmarks)

        if eye_crop is not None:
            mask = run_inference(preprocess(eye_crop))

    # Visualize (on BGR frame), then convert back to RGB for Gradio output
    annotated_bgr = visualize(frame_bgr, eye_crop, mask, bbox, face_detected)
    return cv2.cvtColor(annotated_bgr, cv2.COLOR_BGR2RGB)


# =============================================================================
# Gradio Interface
# =============================================================================

demo = gr.Interface(
    fn=process_frame,
    inputs=gr.Image(sources=["webcam"], streaming=True, label="Webcam Input"),
    outputs=gr.Image(label="Pupil Segmentation"),
    live=True,
    title="NSA Pupil Segmentation Demo",
    description="""
Real-time pupil segmentation using Native Sparse Attention (NSA).

This demo uses the NSAPupilSeg model from the VisionAssist project to detect
and segment the pupil region in real-time from your webcam feed.

**How it works:**
1. MediaPipe Face Mesh detects your face and eye landmarks
2. The left eye region is extracted and preprocessed
3. The NSA model performs semantic segmentation to identify the pupil
4. Results are overlaid on the video feed with a green highlight

**Tips for best results:**
- Ensure good lighting on your face
- Look directly at the camera
- Keep your face within the frame
- Move closer if the eye region is too small

**Model:** NSA-pico (Native Sparse Attention)
""",
    flagging_mode="never",
)

if __name__ == "__main__":
    demo.launch()