---
license: apache-2.0
tags:
- sam2
- segment-anything
- onnx
- webgpu
- computer-vision
- image-segmentation
library_name: onnxruntime
---

# SAM2-HIERA-BASE-PLUS - ONNX Format for WebGPU

**Powered by [Segment Anything 2 (SAM2)](https://github.com/facebookresearch/segment-anything-2) from Meta Research**

This repository contains ONNX-converted models from [facebook/sam2-hiera-base-plus](https://huggingface.co/facebook/sam2-hiera-base-plus), optimized for WebGPU deployment in browsers.

## Model Information

- **Original Model**: [facebook/sam2-hiera-base-plus](https://huggingface.co/facebook/sam2-hiera-base-plus)
- **Version**: SAM 2.0
- **Size**: 80.8M parameters
- **Description**: Base Plus variant - high-quality segmentation (recommended)
- **Format**: ONNX (encoder + decoder)
- **Optimization**: Encoder optimized to `.ort` format for WebGPU

## Files

- `encoder.onnx` - Image encoder (ONNX format)
- `encoder.with_runtime_opt.ort` - Image encoder (optimized for WebGPU)
- `decoder.onnx` - Mask decoder (ONNX format)
- `config.json` - Model configuration

## Usage

### In Browser with ONNX Runtime Web

```javascript
import * as ort from 'onnxruntime-web/webgpu';

// Load encoder (use the optimized .ort version for WebGPU)
const encoderURL = 'https://huggingface.co/SharpAI/sam2-hiera-base-plus-onnx/resolve/main/encoder.with_runtime_opt.ort';
const encoderSession = await ort.InferenceSession.create(encoderURL, {
  executionProviders: ['webgpu'],
  graphOptimizationLevel: 'disabled' // .ort files are already optimized
});

// Load decoder
const decoderURL = 'https://huggingface.co/SharpAI/sam2-hiera-base-plus-onnx/resolve/main/decoder.onnx';
const decoderSession = await ort.InferenceSession.create(decoderURL, {
  executionProviders: ['webgpu']
});

// Run encoder
const imageData = preprocessImage(image); // Your preprocessing (Float32[1, 3, 1024, 1024])
const encoderOutputs = await encoderSession.run({ image: imageData });

// Run decoder with a single foreground point (x, y) plus one padding point
const point_coords = new ort.Tensor('float32', [x, y, 0, 0], [1, 2, 2]);
const point_labels = new ort.Tensor('float32', [1, -1], [1, 2]);
const mask_input = new ort.Tensor('float32', new Float32Array(256 * 256).fill(0), [1, 1, 256, 256]);
const has_mask_input = new ort.Tensor('float32', [0], [1]);

const decoderOutputs = await decoderSession.run({
  image_embed: encoderOutputs.image_embed,
  high_res_feats_0: encoderOutputs.high_res_feats_0,
  high_res_feats_1: encoderOutputs.high_res_feats_1,
  point_coords: point_coords,
  point_labels: point_labels,
  mask_input: mask_input,
  has_mask_input: has_mask_input
});

// Get masks
const masks = decoderOutputs.masks; // Shape: [1, 3, 256, 256]
```

### In Python with ONNX Runtime

```python
import onnxruntime as ort
import numpy as np

# Load models
encoder_session = ort.InferenceSession("encoder.onnx")
decoder_session = ort.InferenceSession("decoder.onnx")

# Run encoder
encoder_outputs = encoder_session.run(None, {"image": image_tensor})

# Run decoder
decoder_outputs = decoder_session.run(None, {
    "image_embed": encoder_outputs[0],
    "high_res_feats_0": encoder_outputs[1],
    "high_res_feats_1": encoder_outputs[2],
    "point_coords": point_coords,
    "point_labels": point_labels,
    "mask_input": mask_input,
    "has_mask_input": has_mask_input,
})

masks = decoder_outputs[0]
```
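
The prompt tensors above (`point_coords`, `point_labels`, `mask_input`, `has_mask_input`) are left undefined in the snippet. A minimal NumPy sketch that builds them for a single foreground click, following the shapes in the decoder spec below (the helper name is illustrative):

```python
import numpy as np

def make_point_prompt(x, y):
    """Build decoder prompt tensors for one foreground click at (x, y).

    Coordinates are expected in the 1024x1024 input-image space. A single
    padding point (label -1) fills the second slot, matching the
    [1, 2, 2] / [1, 2] shapes in the decoder spec.
    """
    point_coords = np.array([[[x, y], [0.0, 0.0]]], dtype=np.float32)  # [1, 2, 2]
    point_labels = np.array([[1.0, -1.0]], dtype=np.float32)           # [1, 2]
    mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)          # no prior mask
    has_mask_input = np.zeros(1, dtype=np.float32)                     # flag off
    return point_coords, point_labels, mask_input, has_mask_input
```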

## Input/Output Specifications

### Encoder

**Input:**
- `image`: Float32[1, 3, 1024, 1024] - Normalized RGB image

**Outputs:**
- `image_embed`: Float32[1, 256, 64, 64] - Image embeddings
- `high_res_feats_0`: Float32[1, 32, 256, 256] - High-res features (level 0)
- `high_res_feats_1`: Float32[1, 64, 128, 128] - High-res features (level 1)
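
A preprocessing sketch that produces this input, assuming the ImageNet mean/std normalization used by the original SAM2 pipeline (verify these constants against your own conversion settings):

```python
import numpy as np

# Assumed normalization constants (ImageNet mean/std) -- confirm they
# match the preprocessing your encoder was exported with.
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def preprocess(rgb_uint8):
    """Take a 1024x1024x3 uint8 RGB array (already resized) and return
    the Float32[1, 3, 1024, 1024] encoder input."""
    x = rgb_uint8.astype(np.float32) / 255.0   # scale to [0, 1]
    x = (x - MEAN) / STD                       # per-channel normalize
    x = x.transpose(2, 0, 1)[np.newaxis, ...]  # HWC -> NCHW
    return np.ascontiguousarray(x)
```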

### Decoder

**Inputs:**
- `image_embed`: Float32[1, 256, 64, 64] - From encoder
- `high_res_feats_0`: Float32[1, 32, 256, 256] - From encoder
- `high_res_feats_1`: Float32[1, 64, 128, 128] - From encoder
- `point_coords`: Float32[1, 2, 2] - Point coordinates [[x, y], [0, 0]]
- `point_labels`: Float32[1, 2] - Point labels [1, -1] (1 = foreground, -1 = padding)
- `mask_input`: Float32[1, 1, 256, 256] - Previous mask (zeros if none)
- `has_mask_input`: Float32[1] - Flag [0] or [1]

**Outputs:**
- `masks`: Float32[1, 3, 256, 256] - Generated masks (3 candidates)
- `iou_predictions`: Float32[1, 3] - IoU scores for each mask
- `low_res_masks`: Float32[1, 3, 256, 256] - Low-resolution masks
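
Since the decoder returns three candidate masks with predicted IoU scores, a typical post-processing step is to keep the highest-scoring candidate and binarize it. A minimal sketch, assuming the usual convention of thresholding mask logits at 0:

```python
import numpy as np

def pick_best_mask(masks, iou_predictions, threshold=0.0):
    """Select the candidate with the highest predicted IoU and binarize it.

    masks: Float32[1, 3, 256, 256] decoder output (logits)
    iou_predictions: Float32[1, 3]
    Returns a boolean [256, 256] mask. Thresholding logits at 0 is the
    common convention; adjust if your export differs.
    """
    best = int(np.argmax(iou_predictions[0]))  # index of top-scoring mask
    return masks[0, best] > threshold
```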

## Browser Requirements

- Chrome 113+ (WebGPU is enabled by default on recent desktop builds; on older ones, enable `chrome://flags/#enable-unsafe-webgpu`)
- Firefox Nightly with WebGPU enabled
- Safari Technology Preview with WebGPU enabled

## Performance

Typical inference times on Chrome with WebGPU:
- **Encoder**: ~4-6s (base-plus variant)
- **Decoder**: 0.1-0.5s per point

## License

This model is released under the Apache 2.0 license, following the original SAM2 model.

## Citation

```bibtex
@article{ravi2024sam2,
  title={SAM 2: Segment Anything in Images and Videos},
  author={Ravi, Nikhila and Gabeur, Valentin and Hu, Yuan-Ting and Hu, Ronghang and Ryali, Chaitanya and Ma, Tengyu and Khedr, Haitham and R{\"a}dle, Roman and Rolland, Chloe and Gustafson, Laura and Mintun, Eric and Pan, Junting and Alwala, Kalyan Vasudev and Carion, Nicolas and Wu, Chao-Yuan and Girshick, Ross and Doll{\'a}r, Piotr and Feichtenhofer, Christoph},
  journal={arXiv preprint arXiv:2408.00714},
  year={2024}
}
```

## Related Resources

- **Original SAM2**: [facebookresearch/segment-anything-2](https://github.com/facebookresearch/segment-anything-2)
- **WebGPU Demo**: [Aegis AI SAM2 WebGPU Demo](https://github.com/yourusername/Aegis-AI/tree/main/tools/sam2-webgpu)
- **Conversion Tool**: [SAM2 ONNX Converter](https://github.com/yourusername/Aegis-AI/tree/main/tools/sam2-converter)

## Acknowledgments

- **Meta Research** for the original SAM2 model
- **Microsoft** for ONNX Runtime
- **SamExporter** for conversion tools

---

*Converted and optimized by [Aegis AI](https://github.com/yourusername/Aegis-AI)*