# sddec25-01 / app.py — commit 01d8faa (connerohnesorge, latest)
#!/usr/bin/env python3
"""
NSA Pupil Segmentation Gradio Demo - Native Sparse Attention Web Application
This Gradio application performs real-time pupil segmentation on webcam input
using the NSAPupilSeg model (Native Sparse Attention). It demonstrates eye tracking
and pupil detection capabilities for the VisionAssist medical assistive technology project.
NSA Key Features:
- Token Compression: Global coarse-grained context
- Token Selection: Fine-grained focus on important regions (pupil)
- Sliding Window: Local context for precise boundaries
- Gated Aggregation: Learned combination of attention paths
"""
import cv2
import numpy as np
import torch
import gradio as gr
import mediapipe as mp
from nsa import create_nsa_pupil_seg
# =============================================================================
# Model Loading (at module startup)
# =============================================================================
print("Loading NSA Pupil Segmentation model...")
# Instantiate the smallest ("pico") NSA variant: 1-channel (grayscale) input,
# 2 output classes (background vs. pupil).
model = create_nsa_pupil_seg(size="pico", in_channels=1, num_classes=2)
# NOTE(review): weights_only=False deserializes arbitrary pickled objects from
# the checkpoint -- acceptable only because "best_model.pth" ships with the app;
# never load untrusted checkpoints this way.
checkpoint = torch.load("best_model.pth", map_location="cpu", weights_only=False)
if "model_state_dict" in checkpoint:
    # Training-style checkpoint: weights nested under "model_state_dict",
    # with metrics (e.g. validation IoU) stored alongside.
    model.load_state_dict(checkpoint["model_state_dict"])
    print(f"Loaded checkpoint with IoU: {checkpoint.get('valid_iou', 'N/A')}")
else:
    # Bare state-dict checkpoint.
    model.load_state_dict(checkpoint)
model.eval()  # inference mode: freezes dropout/batch-norm behavior
print("Model loaded successfully!")
# =============================================================================
# MediaPipe Face Mesh Setup
# =============================================================================
mp_face_mesh = mp.solutions.face_mesh
# Track a single face; refine_landmarks=True enables MediaPipe's extra
# iris/eye-contour landmarks, which the eye-region extraction relies on.
face_mesh = mp_face_mesh.FaceMesh(
    max_num_faces=1,
    refine_landmarks=True,
    min_detection_confidence=0.5,
    min_tracking_confidence=0.5,
)
# =============================================================================
# Constants (from demo.py - MUST match training exactly)
# =============================================================================
# MediaPipe left eye landmark indices (12 points around the eye)
LEFT_EYE_INDICES = [362, 385, 387, 263, 373, 380, 374, 381, 382, 384, 398, 466]
# Target aspect ratio for eye region (width:height = 640:400 = 1.6:1),
# chosen so the crop maps onto the model input without distortion.
TARGET_ASPECT_RATIO = 640 / 400  # 1.6:1
# Model input/output dimensions
MODEL_WIDTH = 640
MODEL_HEIGHT = 400
# Preprocessing parameters (MUST match training exactly);
# mean/std of 0.5 map pixel values from [0, 1] to [-1, 1].
NORMALIZE_MEAN = 0.5
NORMALIZE_STD = 0.5
# Eye extraction settings
BBOX_PADDING = 0.2  # 20% padding on each side
MIN_EYE_REGION_SIZE = 50  # Minimum bounding box size in pixels
# Visualization settings
OVERLAY_ALPHA = 0.5  # blend weight of the green pupil overlay
# =============================================================================
# Eye Region Extraction Function
# =============================================================================
def extract_eye_region(frame, landmarks):
    """
    Crop the left-eye region out of a frame using MediaPipe landmarks.

    The raw landmark bounding box is padded by BBOX_PADDING on every side,
    then symmetrically widened or heightened (clamped to the frame) until it
    reaches TARGET_ASPECT_RATIO, matching the model's 640x400 geometry.

    Args:
        frame: Input BGR frame
        landmarks: MediaPipe face landmarks

    Returns:
        tuple: (eye_crop, bbox) where bbox is (x, y, w, h), or (None, None)
        when the eye is too small or the crop came out empty.
    """
    frame_h, frame_w = frame.shape[:2]
    # Landmark coordinates are normalized to [0, 1]; scale to pixel space.
    pts = np.array(
        [
            (int(landmarks.landmark[i].x * frame_w),
             int(landmarks.landmark[i].y * frame_h))
            for i in LEFT_EYE_INDICES
        ],
        dtype=np.int32,
    )
    x0, y0 = pts.min(axis=0)
    x1, y1 = pts.max(axis=0)
    box_w = x1 - x0
    box_h = y1 - y0
    # Reject eye regions too small for a useful segmentation.
    if box_w < MIN_EYE_REGION_SIZE or box_h < MIN_EYE_REGION_SIZE:
        return None, None
    # Pad by BBOX_PADDING on each side, clamped to the frame bounds.
    pad_x = int(box_w * BBOX_PADDING)
    pad_y = int(box_h * BBOX_PADDING)
    x0 = max(0, x0 - pad_x)
    y0 = max(0, y0 - pad_y)
    x1 = min(frame_w, x1 + pad_x)
    y1 = min(frame_h, y1 + pad_y)
    box_w = x1 - x0
    box_h = y1 - y0
    # Grow the shorter dimension until the box hits the 1.6:1 target ratio.
    if box_w / box_h < TARGET_ASPECT_RATIO:
        # Too narrow: widen symmetrically.
        grow = int(box_h * TARGET_ASPECT_RATIO) - box_w
        x0 = max(0, x0 - grow // 2)
        x1 = min(frame_w, x1 + grow // 2)
        box_w = x1 - x0
    else:
        # Too wide: grow the height symmetrically.
        grow = int(box_w / TARGET_ASPECT_RATIO) - box_h
        y0 = max(0, y0 - grow // 2)
        y1 = min(frame_h, y1 + grow // 2)
        box_h = y1 - y0
    crop = frame[y0:y1, x0:x1]
    # Guard against a degenerate (empty) slice.
    if crop.size == 0:
        return None, None
    return crop, (x0, y0, box_w, box_h)
# =============================================================================
# Preprocessing Function (CRITICAL - must match training exactly)
# =============================================================================
def preprocess(eye_crop):
    """
    Convert a BGR eye crop into the tensor layout the model was trained on.

    CRITICAL: mirrors the training pipeline exactly -- resize, grayscale,
    normalize to [-1, 1], then transpose into (W, H) order.

    Args:
        eye_crop: BGR image of eye region

    Returns:
        torch.Tensor: Preprocessed tensor of shape (1, 1, 640, 400)
    """
    # Resize first so all later steps operate at model resolution.
    scaled = cv2.resize(
        eye_crop,
        (MODEL_WIDTH, MODEL_HEIGHT),
        interpolation=cv2.INTER_LINEAR,
    )
    gray = cv2.cvtColor(scaled, cv2.COLOR_BGR2GRAY)
    # Map [0, 255] -> [-1, 1] using the training mean/std of 0.5/0.5.
    norm = (gray.astype(np.float32) / 255.0 - NORMALIZE_MEAN) / NORMALIZE_STD
    # Model expects (B, C, W, H), NOT (B, C, H, W): transpose the (400, 640)
    # grayscale image to (640, 400), then add batch and channel axes.
    batched = norm.T[None, None, ...]
    return torch.from_numpy(batched)
# =============================================================================
# Inference Function
# =============================================================================
def run_inference(input_tensor):
    """
    Run the NSA model on a preprocessed tensor and return a binary mask.

    Args:
        input_tensor: Preprocessed tensor of shape (1, 1, 640, 400)

    Returns:
        np.ndarray: Binary segmentation mask of shape (400, 640); value 1
        marks predicted pupil pixels.
    """
    # Gradient tracking is unnecessary (and wasteful) for inference.
    with torch.no_grad():
        logits = model(input_tensor)
    # Output is (B, C, W, H) = (1, 2, 640, 400). Argmax over the class axis
    # yields a (640, 400) label map; transpose back to image order (400, 640).
    class_map = logits.cpu().numpy()[0].argmax(axis=0)
    return class_map.T.astype(np.uint8)
# =============================================================================
# Visualization Function
# =============================================================================
def visualize(frame, eye_crop, mask, bbox, face_detected):
    """
    Draw the status banner, pupil overlay, and model tag onto a frame.

    Args:
        frame: Original BGR frame
        eye_crop: Eye region crop (unused here; kept for interface parity)
        mask: Binary segmentation mask (400, 640), or None
        bbox: Bounding box (x, y, w, h), or None
        face_detected: Whether a face was detected this frame

    Returns:
        np.ndarray: Annotated BGR frame
    """
    out = frame.copy()

    # --- status banner: darken a strip across the top of the frame ---
    banner_h = 50
    frame_w = out.shape[1]
    strip = out[0:banner_h, 0:frame_w].astype(np.float32)
    strip *= 0.5
    out[0:banner_h, 0:frame_w] = strip.astype(np.uint8)

    # Pick the status message (yellow = action needed, green = OK; BGR colors).
    if not face_detected:
        status_text, status_color = "No Face Detected", (0, 255, 255)
    elif mask is None:
        status_text, status_color = "Move Closer", (0, 255, 255)
    else:
        status_text, status_color = "Face Detected", (0, 255, 0)

    # Center the text horizontally and vertically inside the banner.
    (text_w, text_h), _ = cv2.getTextSize(
        status_text, cv2.FONT_HERSHEY_SIMPLEX, 1.0, 2
    )
    cv2.putText(
        out,
        status_text,
        ((frame_w - text_w) // 2, (banner_h + text_h) // 2),
        cv2.FONT_HERSHEY_SIMPLEX,
        1.0,
        status_color,
        2,
    )

    # --- pupil overlay and bounding box, when a valid mask exists ---
    if mask is not None and bbox is not None:
        x, y, w, h = bbox
        # Scale the model-resolution mask to the on-screen eye box.
        scaled_mask = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)
        # Solid green wherever the model predicted pupil (class 1).
        highlight = np.zeros((h, w, 3), dtype=np.uint8)
        highlight[scaled_mask == 1] = (0, 255, 0)
        # Alpha-blend the highlight over the eye region in place.
        roi = out[y:y + h, x:x + w]
        out[y:y + h, x:x + w] = cv2.addWeighted(
            roi,
            1 - OVERLAY_ALPHA,
            highlight,
            OVERLAY_ALPHA,
            0,
        )
        cv2.rectangle(out, (x, y), (x + w, y + h), (0, 255, 0), 3)

    # Model tag in the bottom-left corner.
    cv2.putText(
        out,
        "NSA-pico",
        (10, out.shape[0] - 20),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.7,
        (0, 255, 0),
        2,
    )
    return out
# =============================================================================
# Main Process Function
# =============================================================================
def process_frame(image):
    """
    Full per-frame pipeline: detect face, crop eye, segment pupil, annotate.

    Args:
        image: Input RGB image from Gradio (numpy array), or None

    Returns:
        np.ndarray: Annotated RGB image for Gradio output, or None when no
        frame was supplied.
    """
    if image is None:
        return None

    # OpenCV drawing works in BGR; MediaPipe consumes the RGB frame directly.
    frame_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    detection = face_mesh.process(image)
    has_face = detection.multi_face_landmarks is not None

    eye_crop, bbox, mask = None, None, None
    if has_face:
        # Only one face is tracked (max_num_faces=1), so take the first.
        eye_crop, bbox = extract_eye_region(
            frame_bgr, detection.multi_face_landmarks[0]
        )
        if eye_crop is not None:
            mask = run_inference(preprocess(eye_crop))

    annotated = visualize(frame_bgr, eye_crop, mask, bbox, has_face)
    # Gradio expects RGB output.
    return cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB)
# =============================================================================
# Gradio Interface
# =============================================================================
# Streaming webcam interface: every captured frame is pushed through
# process_frame and the annotated result is shown live.
demo = gr.Interface(
    fn=process_frame,
    inputs=gr.Image(sources=["webcam"], streaming=True, label="Webcam Input"),
    outputs=gr.Image(label="Pupil Segmentation"),
    live=True,
    title="NSA Pupil Segmentation Demo",
    description="""
Real-time pupil segmentation using Native Sparse Attention (NSA).
This demo uses the NSAPupilSeg model from the VisionAssist project to detect
and segment the pupil region in real-time from your webcam feed.
**How it works:**
1. MediaPipe Face Mesh detects your face and eye landmarks
2. The left eye region is extracted and preprocessed
3. The NSA model performs semantic segmentation to identify the pupil
4. Results are overlaid on the video feed with a green highlight
**Tips for best results:**
- Ensure good lighting on your face
- Look directly at the camera
- Keep your face within the frame
- Move closer if the eye region is too small
**Model:** NSA-pico (Native Sparse Attention)
""",
    flagging_mode="never",
)
if __name__ == "__main__":
    # Launch the Gradio server only when run as a script (not on import).
    demo.launch()