GazeMoE: Gaze Estimation with Mixture-of-Experts

GazeMoE is a lightweight gaze estimation model (14MB decoder) built on top of a frozen DINOv2 ViT-L/14 backbone. It uses a Mixture-of-Experts (MoE) transformer decoder to predict whether a person's gaze target is inside or outside the camera frame and to generate a heatmap of the gaze location.
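The card does not describe how the MoE decoder routes its inputs (number of experts, top-k, expert architecture). As a rough illustration of the Mixture-of-Experts idea only, here is the common sparse top-k gating pattern in NumPy: a router scores every expert for each token, keeps the k best, and mixes their outputs with softmax weights. The names `moe_forward`, `gate_w`, `expert_ws`, the linear experts and `top_k=2` are all hypothetical and are not taken from the GazeMoE code.

```python
import numpy as np

def moe_forward(x, gate_w, expert_ws, top_k=2):
    """Generic top-k MoE layer (hypothetical sketch, not GazeMoE's implementation).

    x:         (num_tokens, d) token features
    gate_w:    (d, num_experts) router weights
    expert_ws: list of (d, d) matrices, each acting as one linear "expert"
    """
    logits = x @ gate_w                                   # router score per (token, expert)
    top_idx = np.argsort(logits, axis=1)[:, -top_k:]      # indices of the top_k experts per token
    top_logits = np.take_along_axis(logits, top_idx, axis=1)
    # Softmax over the selected experts only (shifted by the max for stability)
    top_logits = top_logits - top_logits.max(axis=1, keepdims=True)
    weights = np.exp(top_logits)
    weights /= weights.sum(axis=1, keepdims=True)

    out = np.zeros_like(x)
    for t in range(x.shape[0]):
        for slot in range(top_k):
            e = top_idx[t, slot]
            # Only the chosen experts run for this token; their outputs are weighted and summed
            out[t] += weights[t, slot] * (x[t] @ expert_ws[e])
    return out, top_idx

rng = np.random.default_rng(0)
d, num_experts, num_tokens = 8, 4, 5
x = rng.standard_normal((num_tokens, d))
gate_w = rng.standard_normal((d, num_experts))
expert_ws = [rng.standard_normal((d, d)) for _ in range(num_experts)]
out, chosen = moe_forward(x, gate_w, expert_ws, top_k=2)
print(out.shape, chosen.shape)  # (5, 8) (5, 2)
```

Sparse routing is what keeps an MoE decoder cheap: each token pays for only k experts, even though the layer holds more parameters in total.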


Paper: https://arxiv.org/abs/2603.06256

Quick Start

1. Requirements

pip install torch torchvision timm huggingface_hub numpy Pillow

2. Hello World Example

This script downloads the weights, initializes the model (including the automatic download of DINOv2), and runs inference on a single image with one head bounding box.

import torch
import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download
from gazemoe_builder import get_gazemoe_model

# --- 1. Load Model & Weights ---
device = "cuda" if torch.cuda.is_available() else "cpu"
model, transform = get_gazemoe_model()

# Download the 14MB GazeMoE decoder weights from Hugging Face
weights_path = hf_hub_download(repo_id="zdai257/GazeMoE", filename="GazeMoE.pt")
state_dict = torch.load(weights_path, map_location=device)
model.load_gazemoe_state_dict(state_dict)
model.to(device).eval()

# --- 2. Prepare Input ---
# GazeMoE expects: 
# - images: [B, 3, 448, 448] tensor
# - bboxes: A list of lists containing [xmin, ymin, xmax, ymax] normalized (0-1)
raw_image = Image.open("example.jpg").convert("RGB")
w, h = raw_image.size

# Example: one person's head bounding box (normalized). For several people in the same image, add one box per person to that image's list.
# Format: [xmin, ymin, xmax, ymax]
example_bbox = [0.4, 0.2, 0.55, 0.4] 

inputs = {
    "images": transform(raw_image).unsqueeze(dim=0).to(device),
    "bboxes": [[example_bbox]] 
}

# --- 3. Inference ---
with torch.no_grad():
    preds = model(inputs)

# --- 4. Process Outputs ---
# 'inout' is the probability that the gaze target is Inside (IFT) rather than Outside (OFT) the frame
inout_prob = preds['inout'][0][0].item() 

if inout_prob < 0.5:
    print(f"Gaze is OUT-OF-FRAME (Prob: {inout_prob:.2f})")
else:
    print(f"Gaze is IN-FRAME (Prob: {inout_prob:.2f})")
    
    # Heatmap is 64x64. Get the (x, y) via argmax
    heatmap = preds['heatmap'][0][0].cpu().numpy()
    
    argmax = heatmap.flatten().argmax()
    pred_y, pred_x = np.unravel_index(argmax, (64, 64))
    
    # Normalize coordinates to 0-1
    x_norm, y_norm = pred_x / 64.0, pred_y / 64.0
    
    print(f"Estimated Gaze Target (Normalized): x={x_norm:.2f}, y={y_norm:.2f}")
    print(f"Pixel Coordinates: X={x_norm * w:.1f}, Y={y_norm * h:.1f}")

Model Pipeline Details

Input Format

The model consumes a dictionary:

  • images: A torch.Tensor of shape (Batch, 3, 448, 448). Use the transform returned by get_gazemoe_model() so that resizing and normalization match what the model expects.
  • bboxes: A list of lists. Each sub-list corresponds to one image in the batch and contains one [xmin, ymin, xmax, ymax] head bounding box per person, in normalized (0-1) coordinates.
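Most head detectors return boxes in pixels, while the model expects normalized coordinates. A minimal sketch of building the nested bboxes structure from pixel boxes, assuming [xmin, ymin, xmax, ymax] order and (width, height) image sizes as given by PIL's Image.size; `normalize_bbox` and `build_bboxes` are hypothetical helpers, not part of gazemoe_builder.

```python
def normalize_bbox(bbox_px, width, height):
    """Convert a pixel-space [xmin, ymin, xmax, ymax] box to 0-1 coordinates (hypothetical helper)."""
    xmin, ymin, xmax, ymax = bbox_px
    # Clamp to the image so detector boxes that spill over the border stay within 0-1
    xmin, xmax = max(0.0, xmin), min(float(width), xmax)
    ymin, ymax = max(0.0, ymin), min(float(height), ymax)
    if xmax <= xmin or ymax <= ymin:
        raise ValueError(f"degenerate box after clamping: {bbox_px}")
    return [xmin / width, ymin / height, xmax / width, ymax / height]

def build_bboxes(per_image_boxes_px, image_sizes):
    """Build the nested `bboxes` list: one sub-list per image, one box per person."""
    return [
        [normalize_bbox(b, w, h) for b in boxes]
        for boxes, (w, h) in zip(per_image_boxes_px, image_sizes)
    ]

# Two images: the first has two detected heads, the second has one.
boxes_px = [
    [[256, 90, 352, 180], [600, 120, 680, 210]],
    [[100, 50, 200, 150]],
]
sizes = [(640, 450), (400, 300)]  # (width, height)
bboxes = build_bboxes(boxes_px, sizes)
print(bboxes[1][0])  # [0.25, 0.16666666666666666, 0.5, 0.5]
```

The first box maps to [0.4, 0.2, 0.55, 0.4], the same values used in the Hello World example. Clamping instead of rejecting out-of-bounds boxes is a design choice; drop it if you would rather surface detector errors.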

Output Decoding

The model returns a dictionary with two keys:

  1. inout: A sigmoid output (a probability in 0-1). Values below 0.5 indicate that the person is looking at something outside the image boundaries; values of 0.5 or above indicate an in-frame target.
  2. heatmap: A 64x64 spatial map. The gaze target is typically located by taking the argmax of this map, i.e. the coordinate of peak intensity.
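The decoding steps above can be wrapped in a small function that works for any heatmap size and can be called once per person. This is a sketch that mirrors the Hello World script (0.5 threshold, argmax, division by the map size). `decode_gaze` is a hypothetical helper. It assumes the heatmap has already been converted to a NumPy array, e.g. `preds['heatmap'][i][j].cpu().numpy()`, where [i][j] is assumed to index [image][person] by analogy with the [0][0] indexing in the example.

```python
import numpy as np

def decode_gaze(inout_prob, heatmap, image_size, threshold=0.5):
    """Turn one person's (inout probability, heatmap) pair into a gaze point (hypothetical helper).

    Returns None when the target is judged out of frame, otherwise a dict with
    normalized (0-1) and pixel coordinates of the heatmap peak.
    """
    if inout_prob < threshold:
        return None
    hm_h, hm_w = heatmap.shape
    # Argmax over the flattened map, then recover the (row, col) of the peak
    pred_y, pred_x = np.unravel_index(np.argmax(heatmap), heatmap.shape)
    # Same convention as the example: cell index divided by the map size
    x_norm, y_norm = float(pred_x) / hm_w, float(pred_y) / hm_h
    w, h = image_size  # (width, height), as returned by PIL's Image.size
    return {"norm": (x_norm, y_norm), "pixel": (x_norm * w, y_norm * h)}

# Synthetic example: a 64x64 map with its peak at row 16, column 48.
hm = np.zeros((64, 64), dtype=np.float32)
hm[16, 48] = 1.0
print(decode_gaze(0.9, hm, image_size=(640, 480)))  # {'norm': (0.75, 0.25), 'pixel': (480.0, 120.0)}
print(decode_gaze(0.2, hm, image_size=(640, 480)))  # None (out of frame)
```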

Citation

If you use this model in your research, please link to the original GitHub repository: zdai257/DisengageNet
