---
license: mit
datasets:
- ShijianDeng/gazefollow
language:
- en
base_model:
- facebook/dinov2-large
pipeline_tag: keypoint-detection
---
# GazeMoE: Gaze Estimation with Mixture-of-Experts

GazeMoE is a lightweight gaze estimation model (14MB decoder) built on top of a frozen DINOv2 ViT-L/14 backbone. It uses a Mixture-of-Experts (MoE) transformer decoder to predict whether a person's gaze target is inside or outside the camera frame and to generate a heatmap of the gaze location.

Paper: https://arxiv.org/abs/2603.06256
## Quick Start

### 1. Requirements

```shell
pip install torch torchvision timm huggingface_hub numpy Pillow
```

### 2. Hello World Example
This script downloads the weights, initializes the model (including the auto-download of DINOv2), and runs inference on a single image and bounding box.
```python
import torch
import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download

from gazemoe_builder import get_gazemoe_model

# --- 1. Load Model & Weights ---
device = "cuda" if torch.cuda.is_available() else "cpu"
model, transform = get_gazemoe_model()

# Download the custom 14MB weights from Hugging Face
weights_path = hf_hub_download(repo_id="zdai257/GazeMoE", filename="GazeMoE.pt")
state_dict = torch.load(weights_path, map_location=device)
model.load_gazemoe_state_dict(state_dict)
model.to(device).eval()

# --- 2. Prepare Input ---
# GazeMoE expects:
#   - images: [B, 3, 448, 448] tensor
#   - bboxes: a list of lists of [xmin, ymin, xmax, ymax], normalized to 0-1
raw_image = Image.open("example.jpg").convert("RGB")
w, h = raw_image.size

# Example: one person with a head bounding box (normalized).
# For multiple people, append more boxes to the inner list.
example_bbox = [0.4, 0.2, 0.55, 0.4]  # [xmin, ymin, xmax, ymax]

inputs = {
    "images": transform(raw_image).unsqueeze(dim=0).to(device),
    "bboxes": [[example_bbox]],
}

# --- 3. Inference ---
with torch.no_grad():
    preds = model(inputs)

# --- 4. Process Outputs ---
# 'inout' predicts whether the gaze is Inside (IFT) or Outside (OFT) the frame
inout_prob = preds['inout'][0][0].item()
if inout_prob < 0.5:
    print(f"Gaze is OUT-OF-FRAME (Prob: {inout_prob:.2f})")
else:
    print(f"Gaze is IN-FRAME (Prob: {inout_prob:.2f})")

# The heatmap is 64x64; take the argmax to find the peak (x, y)
heatmap = preds['heatmap'][0][0].cpu().numpy()
pred_y, pred_x = np.unravel_index(heatmap.argmax(), (64, 64))

# Normalize coordinates to 0-1, then scale to the original image size
x_norm, y_norm = pred_x / 64.0, pred_y / 64.0
print(f"Estimated Gaze Target (Normalized): x={x_norm:.2f}, y={y_norm:.2f}")
print(f"Pixel Coordinates: X={x_norm * w:.1f}, Y={y_norm * h:.1f}")
```
## Model Pipeline Details

### Input Format
The model consumes a dictionary:
- `images`: A `torch.Tensor` of shape `(Batch, 3, 448, 448)`. Use the `transform` provided by the factory function to ensure correct normalization and resizing.
- `bboxes`: A list of lists. Each sub-list corresponds to an image in the batch and contains the head bounding box proposals in normalized coordinates.
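Head boxes from a typical face/head detector come in pixel coordinates, so they must be normalized before being passed to the model. A minimal sketch (the helper name `normalize_bboxes` is our own, not part of the GazeMoE API):

```python
def normalize_bboxes(pixel_boxes, image_size):
    """Convert pixel-space [xmin, ymin, xmax, ymax] boxes to 0-1 coordinates."""
    w, h = image_size
    return [[xmin / w, ymin / h, xmax / w, ymax / h]
            for xmin, ymin, xmax, ymax in pixel_boxes]

# One image with two detected heads (pixel coords) in a 640x480 frame
boxes = normalize_bboxes([[256, 96, 352, 192], [64, 120, 140, 210]], (640, 480))

# The model's "bboxes" entry wraps each image's box list in a batch-level list:
bboxes = [boxes]
```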
### Output Decoding
The model returns a dictionary with two keys:
- `inout`: A sigmoid output. Values below 0.5 indicate the person is looking at something outside the image boundaries.
- `heatmap`: A spatial map. The gaze target is typically identified by taking the `argmax` of this map to find the peak intensity coordinate.
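The argmax decoding can be wrapped in a small helper that maps the peak of the 64x64 heatmap back to pixel coordinates in the original image (a sketch; `decode_heatmap` is not part of the released API):

```python
import numpy as np

def decode_heatmap(heatmap, image_size):
    """Map the peak of a 2D gaze heatmap to pixel coordinates (x, y)."""
    rows, cols = heatmap.shape
    row, col = np.unravel_index(int(heatmap.argmax()), heatmap.shape)
    x_norm, y_norm = col / cols, row / rows  # normalize to 0-1
    w, h = image_size
    return x_norm * w, y_norm * h
```

For example, `decode_heatmap(preds['heatmap'][0][0].cpu().numpy(), raw_image.size)` would yield the gaze target in the source image's pixel space.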
## Citation
If you use this model in your research, please link to the original GitHub repository: zdai257/DisengageNet