---
license: apache-2.0
tags:
- sam2
- segment-anything
- onnx
- webgpu
- computer-vision
- image-segmentation
library_name: onnxruntime
---
# SAM2-HIERA-BASE-PLUS - ONNX Format for WebGPU
**Powered by [Segment Anything 2 (SAM2)](https://github.com/facebookresearch/segment-anything-2) from Meta Research**
This repository contains ONNX-converted models from [facebook/sam2-hiera-base-plus](https://huggingface.co/facebook/sam2-hiera-base-plus), optimized for WebGPU deployment in browsers.
## Model Information
- **Original Model**: [facebook/sam2-hiera-base-plus](https://huggingface.co/facebook/sam2-hiera-base-plus)
- **Version**: SAM 2.0
- **Size**: 80.8M parameters
- **Description**: Base Plus variant with high-quality segmentation (recommended)
- **Format**: ONNX (encoder + decoder)
- **Optimization**: Encoder optimized to .ort format for WebGPU
## Files
- `encoder.onnx` - Image encoder (ONNX format)
- `encoder.with_runtime_opt.ort` - Image encoder (optimized for WebGPU)
- `decoder.onnx` - Mask decoder (ONNX format)
- `config.json` - Model configuration
## Usage
### In Browser with ONNX Runtime Web
```javascript
import * as ort from 'onnxruntime-web/webgpu';

// Load encoder (use the optimized .ort version for WebGPU)
const encoderURL = 'https://huggingface.co/SharpAI/sam2-hiera-base-plus-onnx/resolve/main/encoder.with_runtime_opt.ort';
const encoderSession = await ort.InferenceSession.create(encoderURL, {
  executionProviders: ['webgpu'],
  graphOptimizationLevel: 'disabled' // the .ort file is already optimized
});

// Load decoder
const decoderURL = 'https://huggingface.co/SharpAI/sam2-hiera-base-plus-onnx/resolve/main/decoder.onnx';
const decoderSession = await ort.InferenceSession.create(decoderURL, {
  executionProviders: ['webgpu']
});

// Run encoder
const imageData = preprocessImage(image); // your preprocessing (see sketch below)
const encoderOutputs = await encoderSession.run({ image: imageData });

// Run decoder with a single foreground point (x, y) in 1024x1024 input
// coordinates; the second point is padding (label -1)
const point_coords = new ort.Tensor('float32', new Float32Array([x, y, 0, 0]), [1, 2, 2]);
const point_labels = new ort.Tensor('float32', new Float32Array([1, -1]), [1, 2]);
const mask_input = new ort.Tensor('float32', new Float32Array(256 * 256), [1, 1, 256, 256]); // zero-initialized
const has_mask_input = new ort.Tensor('float32', new Float32Array([0]), [1]);

const decoderOutputs = await decoderSession.run({
  image_embed: encoderOutputs.image_embed,
  high_res_feats_0: encoderOutputs.high_res_feats_0,
  high_res_feats_1: encoderOutputs.high_res_feats_1,
  point_coords: point_coords,
  point_labels: point_labels,
  mask_input: mask_input,
  has_mask_input: has_mask_input
});

// Get masks
const masks = decoderOutputs.masks; // shape: [1, 3, 256, 256]
```
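The `preprocessImage` helper above is left to the caller. A minimal sketch follows; it assumes an `HTMLImageElement` (or anything drawable to a canvas), stretches the image to 1024x1024 without letterboxing, and uses the ImageNet mean/std for normalization. Verify the normalization constants against your conversion pipeline before relying on this.

```javascript
// Minimal preprocessing sketch: resize to 1024x1024, convert RGBA bytes to
// normalized CHW float32, and wrap the result in an ort.Tensor.
function preprocessImage(image) {
  const size = 1024;
  const canvas = document.createElement('canvas');
  canvas.width = size;
  canvas.height = size;
  const ctx = canvas.getContext('2d');
  ctx.drawImage(image, 0, 0, size, size);               // stretch; add letterboxing if needed
  const { data } = ctx.getImageData(0, 0, size, size);  // RGBA, uint8

  const mean = [0.485, 0.456, 0.406];                   // assumed ImageNet normalization
  const std = [0.229, 0.224, 0.225];
  const chw = new Float32Array(3 * size * size);
  for (let i = 0; i < size * size; i++) {
    for (let c = 0; c < 3; c++) {
      chw[c * size * size + i] = (data[i * 4 + c] / 255 - mean[c]) / std[c];
    }
  }
  return new ort.Tensor('float32', chw, [1, 3, size, size]);
}
```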
### In Python with ONNX Runtime
```python
import onnxruntime as ort
import numpy as np

# Load models
encoder_session = ort.InferenceSession("encoder.onnx")
decoder_session = ort.InferenceSession("decoder.onnx")

# Run encoder (image_tensor: float32 [1, 3, 1024, 1024], normalized RGB)
encoder_outputs = encoder_session.run(None, {"image": image_tensor})

# Prompt: one foreground point (x, y) plus one padding point, no prior mask
point_coords = np.array([[[x, y], [0.0, 0.0]]], dtype=np.float32)  # [1, 2, 2]
point_labels = np.array([[1.0, -1.0]], dtype=np.float32)           # [1, 2]
mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
has_mask_input = np.zeros((1,), dtype=np.float32)

# Run decoder
decoder_outputs = decoder_session.run(None, {
    "image_embed": encoder_outputs[0],
    "high_res_feats_0": encoder_outputs[1],
    "high_res_feats_1": encoder_outputs[2],
    "point_coords": point_coords,
    "point_labels": point_labels,
    "mask_input": mask_input,
    "has_mask_input": has_mask_input,
})

masks = decoder_outputs[0]  # [1, 3, 256, 256]
```
## Input/Output Specifications
### Encoder
**Input:**
- `image`: Float32[1, 3, 1024, 1024] - Normalized RGB image
**Outputs:**
- `image_embed`: Float32[1, 256, 64, 64] - Image embeddings
- `high_res_feats_0`: Float32[1, 32, 256, 256] - High-res features (level 0)
- `high_res_feats_1`: Float32[1, 64, 128, 128] - High-res features (level 1)
### Decoder
**Inputs:**
- `image_embed`: Float32[1, 256, 64, 64] - From encoder
- `high_res_feats_0`: Float32[1, 32, 256, 256] - From encoder
- `high_res_feats_1`: Float32[1, 64, 128, 128] - From encoder
- `point_coords`: Float32[1, 2, 2] - Point coordinates [[x, y], [0, 0]]
- `point_labels`: Float32[1, 2] - Point labels [1, -1] (1=foreground, -1=padding)
- `mask_input`: Float32[1, 1, 256, 256] - Previous mask (zeros if none)
- `has_mask_input`: Float32[1] - Flag [0] or [1]
**Outputs:**
- `masks`: Float32[1, 3, 256, 256] - Generated masks (3 candidates)
- `iou_predictions`: Float32[1, 3] - IoU scores for each mask
- `low_res_masks`: Float32[1, 3, 256, 256] - Low-resolution masks
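Since the decoder returns three candidate masks, a common pattern (not prescribed by this repo) is to keep the candidate with the highest predicted IoU and binarize its logits at 0. A sketch, reusing `decoderOutputs` from the browser example above and assuming the mask values are logits:

```javascript
// Pick the candidate with the highest predicted IoU, then threshold its
// logits at 0 to get a binary 256x256 mask (255 = foreground).
const iou = decoderOutputs.iou_predictions.data;      // Float32Array, length 3
const best = iou.indexOf(Math.max(...iou));
const H = 256, W = 256;
const logits = decoderOutputs.masks.data;             // Float32Array, 3 * 256 * 256
const binary = new Uint8Array(H * W);
for (let i = 0; i < H * W; i++) {
  binary[i] = logits[best * H * W + i] > 0 ? 255 : 0;
}
```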
## Browser Requirements
- Chrome/Edge 113+ (WebGPU is enabled by default on supported desktop platforms; elsewhere, enable `chrome://flags/#enable-unsafe-webgpu`)
- Firefox Nightly with WebGPU enabled (`dom.webgpu.enabled` in `about:config`)
- Safari Technology Preview with WebGPU enabled
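You can feature-detect WebGPU before creating a session and fall back to the WASM backend when it is unavailable. A minimal sketch (`modelURL` is a placeholder for one of the URLs above):

```javascript
// Prefer WebGPU when the browser exposes it, otherwise fall back to WASM.
const providers = navigator.gpu ? ['webgpu'] : ['wasm'];
const session = await ort.InferenceSession.create(modelURL, {
  executionProviders: providers
});
```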
## Performance
Typical inference times on Chrome with WebGPU:
- **Encoder**: 4-6s
- **Decoder**: 0.1-0.5s per point
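These figures vary with hardware. A quick way to measure on your own machine, reusing `encoderSession` and `imageData` from the browser example (the first run includes shader compilation, so time a warm run):

```javascript
// Warm up once, then time a representative encoder pass.
await encoderSession.run({ image: imageData });
const t0 = performance.now();
await encoderSession.run({ image: imageData });
console.log(`encoder: ${(performance.now() - t0).toFixed(0)} ms`);
```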
## License
This model is released under the Apache 2.0 license, following the original SAM2 model.
## Citation
```bibtex
@article{ravi2024sam2,
  title={SAM 2: Segment Anything in Images and Videos},
  author={Ravi, Nikhila and Gabeur, Valentin and Hu, Yuan-Ting and Hu, Ronghang and Ryali, Chaitanya and Ma, Tengyu and Khedr, Haitham and R{\"a}dle, Roman and Rolland, Chloe and Gustafson, Laura and Mintun, Eric and Pan, Junting and Alwala, Kalyan Vasudev and Carion, Nicolas and Wu, Chao-Yuan and Girshick, Ross and Doll{\'a}r, Piotr and Feichtenhofer, Christoph},
  journal={arXiv preprint arXiv:2408.00714},
  year={2024}
}
```
## Related Resources
- **Original SAM2**: [facebookresearch/segment-anything-2](https://github.com/facebookresearch/segment-anything-2)
- **WebGPU Demo**: [Aegis AI SAM2 WebGPU Demo](https://github.com/yourusername/Aegis-AI/tree/main/tools/sam2-webgpu)
- **Conversion Tool**: [SAM2 ONNX Converter](https://github.com/yourusername/Aegis-AI/tree/main/tools/sam2-converter)
## Acknowledgments
- **Meta Research** for the original SAM2 model
- **Microsoft** for ONNX Runtime
- **SamExporter** for conversion tools
---
*Converted and optimized by [Aegis AI](https://github.com/yourusername/Aegis-AI)*