#!/usr/bin/env python3
"""
NSA Pupil Segmentation Gradio Demo - Native Sparse Attention Web Application
This Gradio application performs real-time pupil segmentation on webcam input
using the NSAPupilSeg model (Native Sparse Attention). It demonstrates eye tracking
and pupil detection capabilities for the VisionAssist medical assistive technology project.
NSA Key Features:
- Token Compression: Global coarse-grained context
- Token Selection: Fine-grained focus on important regions (pupil)
- Sliding Window: Local context for precise boundaries
- Gated Aggregation: Learned combination of attention paths
"""
import cv2
import numpy as np
import torch
import gradio as gr
import mediapipe as mp
from nsa import create_nsa_pupil_seg
# =============================================================================
# Model Loading (at module startup)
# =============================================================================
print("Loading NSA Pupil Segmentation model...")
# "pico" size variant; grayscale input (1 channel), 2 classes (background/pupil).
model = create_nsa_pupil_seg(size="pico", in_channels=1, num_classes=2)
# NOTE(review): weights_only=False unpickles arbitrary objects — acceptable only
# because best_model.pth is a trusted local checkpoint, never user-supplied.
checkpoint = torch.load("best_model.pth", map_location="cpu", weights_only=False)
if "model_state_dict" in checkpoint:
    # Full training checkpoint: state dict plus metadata (e.g. valid_iou).
    model.load_state_dict(checkpoint["model_state_dict"])
    print(f"Loaded checkpoint with IoU: {checkpoint.get('valid_iou', 'N/A')}")
else:
    # Bare state dict saved directly.
    model.load_state_dict(checkpoint)
model.eval()  # switch to inference mode before serving requests
print("Model loaded successfully!")
# =============================================================================
# MediaPipe Face Mesh Setup
# =============================================================================
mp_face_mesh = mp.solutions.face_mesh
# Single-face tracker; refine_landmarks=True enables the refined eye/iris
# landmarks used for eye-region extraction below.
face_mesh = mp_face_mesh.FaceMesh(
    max_num_faces=1,
    refine_landmarks=True,
    min_detection_confidence=0.5,
    min_tracking_confidence=0.5,
)
# =============================================================================
# Constants (from demo.py - MUST match training exactly)
# =============================================================================
# MediaPipe left eye landmark indices (12 points around the eye)
LEFT_EYE_INDICES = [362, 385, 387, 263, 373, 380, 374, 381, 382, 384, 398, 466]
# Target aspect ratio for eye region (width:height = 640:400 = 1.6:1)
TARGET_ASPECT_RATIO = 640 / 400  # 1.6:1
# Model input/output dimensions (the model works in (B, C, W, H) layout)
MODEL_WIDTH = 640
MODEL_HEIGHT = 400
# Preprocessing parameters (MUST match training exactly)
NORMALIZE_MEAN = 0.5
NORMALIZE_STD = 0.5
# Eye extraction settings
BBOX_PADDING = 0.2  # 20% padding on each side
MIN_EYE_REGION_SIZE = 50  # Minimum bounding box size (pixels) before giving up
# Visualization settings
OVERLAY_ALPHA = 0.5  # blend weight of the green pupil overlay
# =============================================================================
# Eye Region Extraction Function
# =============================================================================
def extract_eye_region(frame, landmarks):
    """
    Extract the left-eye region from a frame using MediaPipe landmarks.

    The tight landmark bounding box is padded by 20% per side, then expanded
    to the model's 1.6:1 (640:400) aspect ratio, clamped to the frame borders.

    Args:
        frame: Input BGR frame (numpy array of shape (H, W, 3)).
        landmarks: MediaPipe face landmarks (normalized [0, 1] coordinates).

    Returns:
        tuple: (eye_crop, bbox) where bbox is (x, y, w, h) in pixels, or
        (None, None) if the eye region is too small or the crop is empty.
    """
    h, w = frame.shape[:2]

    # Convert normalized landmark coordinates to integer pixel positions.
    eye_points = np.array([
        [int(landmarks.landmark[idx].x * w), int(landmarks.landmark[idx].y * h)]
        for idx in LEFT_EYE_INDICES
    ], dtype=np.int32)

    # Tight bounding box around the eye landmarks.
    x_min, y_min = eye_points.min(axis=0)
    x_max, y_max = eye_points.max(axis=0)
    bbox_w = x_max - x_min
    bbox_h = y_max - y_min

    # Reject regions too small to segment reliably ("Move Closer" case).
    if bbox_w < MIN_EYE_REGION_SIZE or bbox_h < MIN_EYE_REGION_SIZE:
        return None, None

    # Add 20% padding on each side, clamped to the frame.
    pad_w = int(bbox_w * BBOX_PADDING)
    pad_h = int(bbox_h * BBOX_PADDING)
    x_min = max(0, x_min - pad_w)
    y_min = max(0, y_min - pad_h)
    x_max = min(w, x_max + pad_w)
    y_max = min(h, y_max + pad_h)
    bbox_w = x_max - x_min
    bbox_h = y_max - y_min

    # Expand the shorter dimension toward the 1.6:1 model aspect ratio.
    # BUGFIX: the original added diff // 2 to BOTH sides, silently losing one
    # pixel whenever diff is odd; put the remainder on the max side instead.
    current_ratio = bbox_w / bbox_h
    if current_ratio < TARGET_ASPECT_RATIO:
        # Too narrow: widen (roughly) symmetrically.
        diff = int(bbox_h * TARGET_ASPECT_RATIO) - bbox_w
        x_min = max(0, x_min - diff // 2)
        x_max = min(w, x_max + (diff - diff // 2))
        bbox_w = x_max - x_min
    else:
        # Too short: grow height (roughly) symmetrically.
        diff = int(bbox_w / TARGET_ASPECT_RATIO) - bbox_h
        y_min = max(0, y_min - diff // 2)
        y_max = min(h, y_max + (diff - diff // 2))
        bbox_h = y_max - y_min

    eye_crop = frame[y_min:y_max, x_min:x_max]
    # Guard against degenerate slices (crop clamped down to nothing).
    if eye_crop.size == 0:
        return None, None
    return eye_crop, (x_min, y_min, bbox_w, bbox_h)
# =============================================================================
# Preprocessing Function (CRITICAL - must match training exactly)
# =============================================================================
def preprocess(eye_crop):
    """
    Convert a BGR eye crop into the model's input tensor.

    The pipeline mirrors training preprocessing exactly:
    resize -> grayscale -> scale to [-1, 1] -> transpose to (W, H) ->
    prepend batch/channel axes.

    Args:
        eye_crop: BGR image of the eye region.

    Returns:
        torch.Tensor: Tensor of shape (1, 1, 640, 400), laid out as
        (B, C, W, H) — NOT the usual (B, C, H, W).
    """
    # Resize to the fixed model resolution before any further processing.
    scaled = cv2.resize(
        eye_crop,
        (MODEL_WIDTH, MODEL_HEIGHT),
        interpolation=cv2.INTER_LINEAR,
    )
    mono = cv2.cvtColor(scaled, cv2.COLOR_BGR2GRAY)
    # Map pixel values from [0, 255] into [-1, 1] (mean=0.5, std=0.5).
    normed = (mono.astype(np.float32) / 255.0 - NORMALIZE_MEAN) / NORMALIZE_STD
    # mono/normed are (H, W) = (400, 640); the model wants (W, H) = (640, 400),
    # so transpose, then add the batch and channel dimensions in front.
    batched = normed.T[None, None, :, :]
    return torch.from_numpy(batched)
# =============================================================================
# Inference Function
# =============================================================================
def run_inference(input_tensor):
    """
    Run the segmentation model on a preprocessed input tensor.

    Args:
        input_tensor: Tensor of shape (1, 1, 640, 400), i.e. (B, C, W, H).

    Returns:
        np.ndarray: uint8 binary mask of shape (400, 640); 1 marks pupil pixels.
    """
    # Inference only — no gradient bookkeeping needed.
    with torch.no_grad():
        logits = model(input_tensor)
    # Model output is (B, C, W, H) = (1, 2, 640, 400). Drop the batch axis,
    # argmax over the class axis -> (640, 400), then transpose back to
    # image-space (H, W) = (400, 640) for visualization.
    per_class = logits.cpu().numpy()[0]
    return per_class.argmax(axis=0).T.astype(np.uint8)
# =============================================================================
# Visualization Function
# =============================================================================
def visualize(frame, eye_crop, mask, bbox, face_detected):
    """
    Draw the segmentation result and a status banner onto a copy of the frame.

    Args:
        frame: Original BGR frame.
        eye_crop: Eye region crop (unused here; kept for interface stability).
        mask: Binary segmentation mask (400, 640), or None.
        bbox: Eye bounding box (x, y, w, h), or None.
        face_detected: Whether MediaPipe found a face.

    Returns:
        np.ndarray: Annotated BGR frame.
    """
    out = frame.copy()

    # --- Status banner: darken a strip across the top of the frame ---
    banner_height = 50
    banner_w = out.shape[1]
    strip = out[0:banner_height, 0:banner_w].astype(np.float32)
    strip *= 0.5  # 50% darker, semi-transparent black effect
    out[0:banner_height, 0:banner_w] = strip.astype(np.uint8)

    # Pick the status message and color from the detection state.
    if not face_detected:
        status_text, status_color = "No Face Detected", (0, 255, 255)  # yellow (BGR)
    elif mask is None:
        status_text, status_color = "Move Closer", (0, 255, 255)  # yellow
    else:
        status_text, status_color = "Face Detected", (0, 255, 0)  # green

    # Center the status text inside the banner.
    (text_w, text_h), _ = cv2.getTextSize(status_text, cv2.FONT_HERSHEY_SIMPLEX, 1.0, 2)
    cv2.putText(
        out,
        status_text,
        ((banner_w - text_w) // 2, (banner_height + text_h) // 2),
        cv2.FONT_HERSHEY_SIMPLEX,
        1.0,
        status_color,
        2,
    )

    # --- Pupil overlay on the eye region, when we have a valid result ---
    if mask is not None and bbox is not None:
        x, y, w, h = bbox
        # Scale the model-resolution mask to the on-screen eye box.
        box_mask = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)
        # Solid green wherever the model predicted pupil (class 1).
        highlight = np.zeros((h, w, 3), dtype=np.uint8)
        highlight[box_mask == 1] = (0, 255, 0)  # green in BGR
        # Alpha-blend the highlight over the eye region.
        roi = out[y:y + h, x:x + w]
        out[y:y + h, x:x + w] = cv2.addWeighted(
            roi, 1 - OVERLAY_ALPHA, highlight, OVERLAY_ALPHA, 0
        )
        # Outline the eye bounding box.
        cv2.rectangle(out, (x, y), (x + w, y + h), (0, 255, 0), 3)

    # Model name tag in the bottom-left corner.
    cv2.putText(
        out,
        "NSA-pico",
        (10, out.shape[0] - 20),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.7,
        (0, 255, 0),
        2,
    )
    return out
# =============================================================================
# Main Process Function
# =============================================================================
def process_frame(image):
    """
    Run the full pupil-segmentation pipeline on one webcam frame.

    Args:
        image: RGB frame from Gradio (numpy array), or None.

    Returns:
        Annotated RGB frame (numpy array), or None when no frame was given.
    """
    if image is None:
        return None

    # OpenCV works in BGR, MediaPipe wants RGB — keep both representations.
    frame_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    results = face_mesh.process(image)  # MediaPipe expects RGB
    face_detected = results.multi_face_landmarks is not None

    eye_crop, bbox, mask = None, None, None
    if face_detected:
        # Use the first (only) face's landmarks to crop the left eye.
        landmarks = results.multi_face_landmarks[0]
        eye_crop, bbox = extract_eye_region(frame_bgr, landmarks)
        if eye_crop is not None:
            # Preprocess and segment only when the crop is usable.
            mask = run_inference(preprocess(eye_crop))

    annotated_bgr = visualize(frame_bgr, eye_crop, mask, bbox, face_detected)
    # Gradio displays RGB, so convert back before returning.
    return cv2.cvtColor(annotated_bgr, cv2.COLOR_BGR2RGB)
# =============================================================================
# Gradio Interface
# =============================================================================
# =============================================================================
# Gradio Interface
# =============================================================================
# Streaming webcam in, annotated frames out. live=True re-runs process_frame
# on each new frame without a submit button; flagging is disabled.
demo = gr.Interface(
    fn=process_frame,
    inputs=gr.Image(sources=["webcam"], streaming=True, label="Webcam Input"),
    outputs=gr.Image(label="Pupil Segmentation"),
    live=True,
    title="NSA Pupil Segmentation Demo",
    description="""
Real-time pupil segmentation using Native Sparse Attention (NSA).
This demo uses the NSAPupilSeg model from the VisionAssist project to detect
and segment the pupil region in real-time from your webcam feed.
**How it works:**
1. MediaPipe Face Mesh detects your face and eye landmarks
2. The left eye region is extracted and preprocessed
3. The NSA model performs semantic segmentation to identify the pupil
4. Results are overlaid on the video feed with a green highlight
**Tips for best results:**
- Ensure good lighting on your face
- Look directly at the camera
- Keep your face within the frame
- Move closer if the eye region is too small
**Model:** NSA-pico (Native Sparse Attention)
""",
    flagging_mode="never",
)
if __name__ == "__main__":
    # Launch the local Gradio server when run as a script.
    demo.launch()