# NOTE: Hugging Face Spaces page header ("Spaces: Sleeping") was scraped into
# this file; converted to a comment so the module remains valid Python.
| #!/usr/bin/env python3 | |
| """ | |
| NSA Pupil Segmentation Gradio Demo - Native Sparse Attention Web Application | |
| This Gradio application performs real-time pupil segmentation on webcam input | |
| using the NSAPupilSeg model (Native Sparse Attention). It demonstrates eye tracking | |
| and pupil detection capabilities for the VisionAssist medical assistive technology project. | |
| NSA Key Features: | |
| - Token Compression: Global coarse-grained context | |
| - Token Selection: Fine-grained focus on important regions (pupil) | |
| - Sliding Window: Local context for precise boundaries | |
| - Gated Aggregation: Learned combination of attention paths | |
| """ | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| import gradio as gr | |
| import mediapipe as mp | |
| from nsa import create_nsa_pupil_seg | |
# =============================================================================
# Model Loading (at module startup)
# =============================================================================
print("Loading NSA Pupil Segmentation model...")
# Smallest ("pico") NSA segmentation model: 1-channel grayscale input,
# 2 output classes (background vs. pupil).
model = create_nsa_pupil_seg(size="pico", in_channels=1, num_classes=2)
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects —
# acceptable only because best_model.pth ships with this app; do not point this
# at untrusted checkpoint files.
checkpoint = torch.load("best_model.pth", map_location="cpu", weights_only=False)
if "model_state_dict" in checkpoint:
    # Full training checkpoint: pull out the weights and report validation IoU.
    model.load_state_dict(checkpoint["model_state_dict"])
    print(f"Loaded checkpoint with IoU: {checkpoint.get('valid_iou', 'N/A')}")
else:
    # Bare state-dict checkpoint.
    model.load_state_dict(checkpoint)
model.eval()  # inference mode: freezes dropout / batch-norm statistics
print("Model loaded successfully!")
# =============================================================================
# MediaPipe Face Mesh Setup
# =============================================================================
mp_face_mesh = mp.solutions.face_mesh
# Single-face tracker. refine_landmarks=True enables the refined eye/iris
# landmarks, which the eye-region extraction below relies on.
face_mesh = mp_face_mesh.FaceMesh(
    max_num_faces=1,
    refine_landmarks=True,
    min_detection_confidence=0.5,
    min_tracking_confidence=0.5,
)
# =============================================================================
# Constants (from demo.py - MUST match training exactly)
# =============================================================================
# MediaPipe left eye landmark indices (12 points around the eye)
LEFT_EYE_INDICES = [362, 385, 387, 263, 373, 380, 374, 381, 382, 384, 398, 466]
# Target aspect ratio for eye region (width:height = 640:400 = 1.6:1);
# matching the model input ratio avoids distortion when resizing the crop.
TARGET_ASPECT_RATIO = 640 / 400  # 1.6:1
# Model input/output dimensions (width x height)
MODEL_WIDTH = 640
MODEL_HEIGHT = 400
# Preprocessing parameters (MUST match training exactly) — maps [0, 1] pixel
# values to [-1, 1] via (x - mean) / std.
NORMALIZE_MEAN = 0.5
NORMALIZE_STD = 0.5
# Eye extraction settings
BBOX_PADDING = 0.2  # 20% padding on each side of the landmark bounding box
MIN_EYE_REGION_SIZE = 50  # Minimum bounding box size (px); smaller -> "Move Closer"
# Visualization settings
OVERLAY_ALPHA = 0.5  # blend weight of the green pupil overlay
| # ============================================================================= | |
| # Eye Region Extraction Function | |
| # ============================================================================= | |
def extract_eye_region(frame, landmarks):
    """
    Extract the left-eye region from a frame using MediaPipe landmarks.

    The landmark bounding box is padded by BBOX_PADDING on each side, then
    grown to TARGET_ASPECT_RATIO (1.6:1) so the later resize to the model
    input introduces minimal distortion.

    Args:
        frame: Input BGR frame of shape (H, W, 3).
        landmarks: MediaPipe face landmarks (normalized [0, 1] coordinates).

    Returns:
        tuple: (eye_crop, bbox) where bbox is (x, y, w, h) in frame pixels,
        or (None, None) if the eye region is too small or the crop is empty.
    """
    h, w = frame.shape[:2]
    # Convert the normalized left-eye landmark coordinates to pixel positions.
    eye_points = np.array([
        [int(landmarks.landmark[idx].x * w), int(landmarks.landmark[idx].y * h)]
        for idx in LEFT_EYE_INDICES
    ], dtype=np.int32)
    # Tight bounding box around the eye landmarks.
    x_min, y_min = eye_points.min(axis=0)
    x_max, y_max = eye_points.max(axis=0)
    bbox_w = x_max - x_min
    bbox_h = y_max - y_min
    # Reject regions too small to segment reliably (caller shows "Move Closer").
    if bbox_w < MIN_EYE_REGION_SIZE or bbox_h < MIN_EYE_REGION_SIZE:
        return None, None
    # Add padding (BBOX_PADDING fraction on each side), clamped to the frame.
    pad_w = int(bbox_w * BBOX_PADDING)
    pad_h = int(bbox_h * BBOX_PADDING)
    x_min = max(0, x_min - pad_w)
    y_min = max(0, y_min - pad_h)
    x_max = min(w, x_max + pad_w)
    y_max = min(h, y_max + pad_h)
    bbox_w = x_max - x_min
    bbox_h = y_max - y_min
    # Grow the box to the 1.6:1 target aspect ratio. FIX: the original added
    # `diff // 2` to BOTH sides, so an odd `diff` left the box one pixel short
    # of the target dimension even without clamping; the far side now takes
    # the remainder (diff - diff // 2) so the full difference is distributed.
    current_ratio = bbox_w / bbox_h
    if current_ratio < TARGET_ASPECT_RATIO:
        # Too narrow: expand width.
        target_w = int(bbox_h * TARGET_ASPECT_RATIO)
        diff = target_w - bbox_w
        x_min = max(0, x_min - diff // 2)
        x_max = min(w, x_max + (diff - diff // 2))
        bbox_w = x_max - x_min
    else:
        # Too short: expand height.
        target_h = int(bbox_w / TARGET_ASPECT_RATIO)
        diff = target_h - bbox_h
        y_min = max(0, y_min - diff // 2)
        y_max = min(h, y_max + (diff - diff // 2))
        bbox_h = y_max - y_min
    # Extract the crop; (bbox_w, bbox_h) always matches the crop's shape.
    eye_crop = frame[y_min:y_max, x_min:x_max]
    # Defensive: clamping could in principle collapse the box to zero area.
    if eye_crop.size == 0:
        return None, None
    return eye_crop, (x_min, y_min, bbox_w, bbox_h)
| # ============================================================================= | |
| # Preprocessing Function (CRITICAL - must match training exactly) | |
| # ============================================================================= | |
def preprocess(eye_crop):
    """
    Convert a BGR eye crop into the model's input tensor.

    Mirrors the training pipeline exactly:
    resize -> grayscale -> normalize to [-1, 1] -> transpose to (W, H).

    Args:
        eye_crop: BGR image of the eye region.

    Returns:
        torch.Tensor: Float tensor of shape (1, 1, 640, 400), i.e.
        (batch, channel, width, height) — the model expects W before H.
    """
    # Resize to the model resolution; cv2.resize takes (width, height).
    resized = cv2.resize(
        eye_crop,
        (MODEL_WIDTH, MODEL_HEIGHT),
        interpolation=cv2.INTER_LINEAR,
    )
    # Single grayscale channel, as during training.
    gray = cv2.cvtColor(resized, cv2.COLOR_BGR2GRAY)
    # Scale to [0, 1], then normalize with mean=0.5 / std=0.5 -> [-1, 1].
    scaled = gray.astype(np.float32) / 255.0
    normalized = (scaled - NORMALIZE_MEAN) / NORMALIZE_STD
    # (H, W) = (400, 640) -> (W, H) = (640, 400), then add batch/channel axes.
    transposed = normalized.T
    return torch.from_numpy(transposed[None, None, :, :])
| # ============================================================================= | |
| # Inference Function | |
| # ============================================================================= | |
def run_inference(input_tensor):
    """
    Run the segmentation model on a preprocessed input tensor.

    Args:
        input_tensor: Tensor of shape (1, 1, 640, 400) — (B, C, W, H).

    Returns:
        np.ndarray: uint8 binary mask of shape (400, 640) — (H, W) — where
        1 marks pupil pixels.
    """
    # No gradients needed for inference.
    with torch.no_grad():
        logits = model(input_tensor)
    # logits: (1, 2, 640, 400). Per-pixel argmax over the class axis yields a
    # (640, 400) = (W, H) class map.
    logits_np = logits.cpu().numpy()
    class_map = np.argmax(logits_np[0], axis=0)
    # Transpose back to (H, W) for OpenCV-style visualization.
    return class_map.T.astype(np.uint8)
| # ============================================================================= | |
| # Visualization Function | |
| # ============================================================================= | |
def visualize(frame, eye_crop, mask, bbox, face_detected):
    """
    Draw the segmentation results onto a copy of the frame.

    Args:
        frame: Original BGR frame.
        eye_crop: Eye region crop (kept for interface symmetry; not drawn).
        mask: Binary segmentation mask of shape (400, 640), or None.
        bbox: Eye bounding box (x, y, w, h), or None.
        face_detected: Whether MediaPipe found a face.

    Returns:
        np.ndarray: Annotated BGR frame.
    """
    out = frame.copy()

    # --- Status banner: darken a strip across the top of the frame ---------
    banner_h = 50
    frame_w = out.shape[1]
    darkened = out[0:banner_h, 0:frame_w].astype(np.float32) * 0.5
    out[0:banner_h, 0:frame_w] = darkened.astype(np.uint8)

    # Pick the status message and its color (BGR).
    if not face_detected:
        status, color = "No Face Detected", (0, 255, 255)  # yellow
    elif mask is None:
        status, color = "Move Closer", (0, 255, 255)       # yellow
    else:
        status, color = "Face Detected", (0, 255, 0)       # green

    # Center the status text inside the banner.
    (text_w, text_h), _ = cv2.getTextSize(status, cv2.FONT_HERSHEY_SIMPLEX, 1.0, 2)
    cv2.putText(
        out,
        status,
        ((frame_w - text_w) // 2, (banner_h + text_h) // 2),
        cv2.FONT_HERSHEY_SIMPLEX,
        1.0,
        color,
        2,
    )

    # --- Pupil overlay on the eye region ------------------------------------
    if mask is not None and bbox is not None:
        x, y, w, h = bbox
        # Scale the model-resolution mask to the on-screen eye box; nearest
        # neighbor keeps the mask strictly binary.
        scaled_mask = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)
        # Green layer wherever the mask marks pupil pixels.
        green_layer = np.zeros((h, w, 3), dtype=np.uint8)
        green_layer[scaled_mask == 1] = (0, 255, 0)  # green in BGR
        # Alpha-blend the layer into the eye region.
        roi = out[y:y + h, x:x + w]
        out[y:y + h, x:x + w] = cv2.addWeighted(
            roi,
            1 - OVERLAY_ALPHA,
            green_layer,
            OVERLAY_ALPHA,
            0,
        )
        # Outline the eye bounding box.
        cv2.rectangle(out, (x, y), (x + w, y + h), (0, 255, 0), 3)

    # --- Model label in the bottom-left corner ------------------------------
    cv2.putText(
        out,
        "NSA-pico",
        (10, out.shape[0] - 20),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.7,
        (0, 255, 0),
        2,
    )
    return out
| # ============================================================================= | |
| # Main Process Function | |
| # ============================================================================= | |
def process_frame(image):
    """
    Run the full pupil-segmentation pipeline on one webcam frame.

    Pipeline: face-mesh detection -> left-eye crop -> preprocess ->
    NSA inference -> overlay visualization.

    Args:
        image: RGB frame from Gradio (numpy array), or None.

    Returns:
        np.ndarray: Annotated RGB frame, or None when there is no input.
    """
    if image is None:
        return None

    # OpenCV drawing works in BGR; MediaPipe consumes the RGB frame as-is.
    frame_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    results = face_mesh.process(image)  # MediaPipe expects RGB
    face_found = results.multi_face_landmarks is not None

    eye_crop, bbox, mask = None, None, None
    if face_found:
        # Crop the left eye (from the BGR frame) using the first face found.
        eye_crop, bbox = extract_eye_region(
            frame_bgr, results.multi_face_landmarks[0]
        )
        if eye_crop is not None:
            # Preprocess -> NSA model -> binary pupil mask.
            mask = run_inference(preprocess(eye_crop))

    # Draw status banner, overlay, and bounding box on the BGR frame.
    annotated = visualize(frame_bgr, eye_crop, mask, bbox, face_found)
    # Gradio expects RGB output.
    return cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB)
# =============================================================================
# Gradio Interface
# =============================================================================
# Streaming webcam in, annotated frame out; live=True re-runs process_frame
# on each new frame without a submit button.
demo = gr.Interface(
    fn=process_frame,
    inputs=gr.Image(sources=["webcam"], streaming=True, label="Webcam Input"),
    outputs=gr.Image(label="Pupil Segmentation"),
    live=True,
    title="NSA Pupil Segmentation Demo",
    description="""
    Real-time pupil segmentation using Native Sparse Attention (NSA).
    This demo uses the NSAPupilSeg model from the VisionAssist project to detect
    and segment the pupil region in real-time from your webcam feed.
    **How it works:**
    1. MediaPipe Face Mesh detects your face and eye landmarks
    2. The left eye region is extracted and preprocessed
    3. The NSA model performs semantic segmentation to identify the pupil
    4. Results are overlaid on the video feed with a green highlight
    **Tips for best results:**
    - Ensure good lighting on your face
    - Look directly at the camera
    - Keep your face within the frame
    - Move closer if the eye region is too small
    **Model:** NSA-pico (Native Sparse Attention)
    """,
    flagging_mode="never",  # disable the flag/report button
)
if __name__ == "__main__":
    demo.launch()