---
datasets:
- ILSVRC/imagenet-21k
license: other
license_name: nvclv1
license_link: LICENSE
pipeline_tag: image-classification
library_name: transformers
---
# MambaVision: A Hybrid Mamba-Transformer Vision Backbone
[**MambaVision: A Hybrid Mamba-Transformer Vision Backbone**](https://arxiv.org/abs/2407.08083)
## Model Description
We propose a novel hybrid Mamba-Transformer backbone, denoted as MambaVision, which is specifically tailored for vision applications. Our core contribution is a redesign of the Mamba formulation that enhances its capability for efficient modeling of visual features. In addition, we conduct a comprehensive ablation study on the feasibility of integrating Vision Transformers (ViT) with Mamba. Our results demonstrate that equipping the Mamba architecture with several self-attention blocks at the final layers greatly improves its capacity to capture long-range spatial dependencies. Based on our findings, we introduce a family of MambaVision models with a hierarchical architecture to meet various design criteria. For image classification on the ImageNet-1K dataset, MambaVision model variants achieve new state-of-the-art (SOTA) performance in terms of Top-1 accuracy and image throughput. In downstream tasks such as object detection, instance segmentation and semantic segmentation on the MS COCO and ADE20K datasets, MambaVision outperforms comparably sized backbones. Code: https://github.com/NVlabs/MambaVision.
## Model Performance
MambaVision-L-21K is pretrained on the ImageNet-21K dataset and fine-tuned on ImageNet-1K.
<table>
<tr>
<th>Name</th>
<th>Acc@1(%)</th>
<th>Acc@5(%)</th>
<th>#Params(M)</th>
<th>FLOPs(G)</th>
<th>Resolution</th>
</tr>
<tr>
<td>MambaVision-L-21K</td>
<td>86.1</td>
<td>97.9</td>
<td>227.9</td>
<td>34.9</td>
<td>224x224</td>
</tr>
</table>
In addition, the MambaVision models achieve a new SOTA Pareto front in
terms of Top-1 accuracy and throughput.
<p align="center">
<img src="https://github.com/NVlabs/MambaVision/assets/26806394/79dcf841-3966-4b77-883d-76cd5e1d4320" width=70% height=70%
class="center">
</p>
## Model Usage
It is highly recommended to install the requirements for MambaVision by running the following:
```Bash
pip install mambavision
```
For each model, we offer two variants, one for image classification and one for feature extraction, each of which can be imported with a single line of code.
### Image Classification
In the following example, we demonstrate how MambaVision can be used for image classification.
Given the following image from the [COCO dataset](https://cocodataset.org/#home) validation set as input:
<p align="center">
<img src="https://cdn-uploads.huggingface.co/production/uploads/64414b62603214724ebd2636/4duSnqLf4lrNiAHczSmAN.jpeg" width=70% height=70%
class="center">
</p>
The following snippet can be used for image classification:
```Python
from transformers import AutoModelForImageClassification
from PIL import Image
from timm.data.transforms_factory import create_transform
import requests
model = AutoModelForImageClassification.from_pretrained("nvidia/MambaVision-L-21K", trust_remote_code=True)
# eval mode for inference
model.cuda().eval()
# prepare image for the model
url = 'http://images.cocodataset.org/val2017/000000020247.jpg'
image = Image.open(requests.get(url, stream=True).raw)
input_resolution = (3, 224, 224)  # MambaVision supports any input resolution
transform = create_transform(input_size=input_resolution,
                             is_training=False,
                             mean=model.config.mean,
                             std=model.config.std,
                             crop_mode=model.config.crop_mode,
                             crop_pct=model.config.crop_pct)
inputs = transform(image).unsqueeze(0).cuda()
# model inference
outputs = model(inputs)
logits = outputs['logits']
predicted_class_idx = logits.argmax(-1).item()
print("Predicted class:", model.config.id2label[predicted_class_idx])
```
The predicted label is `brown bear, bruin, Ursus arctos`.
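The snippet above reports only the single highest-scoring class. As a minimal sketch, the logits can also be turned into a ranked list of the most likely labels. The helper name `top_k_predictions` and the toy values below are hypothetical and are not part of the MambaVision package; in practice one would pass `logits[0].tolist()` and `model.config.id2label` from the example above.

```Python
import math


def top_k_predictions(logits, id2label, k=5):
    """Return the k most likely (label, probability) pairs, best first."""
    # Numerically stable softmax: subtract the max logit before exponentiating.
    max_logit = max(logits)
    exps = [math.exp(x - max_logit) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Rank class indices by probability, highest first.
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(id2label[i], probs[i]) for i in ranked[:k]]


# Toy stand-ins for logits[0].tolist() and model.config.id2label (hypothetical values).
toy_logits = [0.5, 3.0, 1.0]
toy_id2label = {0: "tabby", 1: "brown bear", 2: "ice bear"}

top2 = top_k_predictions(toy_logits, toy_id2label, k=2)
for label, prob in top2:
    print(f"{label}: {prob:.3f}")
```

Inspecting several candidates rather than only the argmax makes it easier to judge how confident the model is in a given prediction.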
### Feature Extraction
MambaVision can also be used as a generic feature extractor.
Specifically, we can extract the outputs of each of the model's four stages, as well as the final average-pooled features, which are flattened.
The following snippet can be used for feature extraction:
```Python
from transformers import AutoModel
from PIL import Image
from timm.data.transforms_factory import create_transform
import requests
model = AutoModel.from_pretrained("nvidia/MambaVision-L-21K", trust_remote_code=True)
# eval mode for inference
model.cuda().eval()
# prepare image for the model
url = 'http://images.cocodataset.org/val2017/000000020247.jpg'
image = Image.open(requests.get(url, stream=True).raw)
input_resolution = (3, 224, 224)  # MambaVision supports any input resolution
transform = create_transform(input_size=input_resolution,
                             is_training=False,
                             mean=model.config.mean,
                             std=model.config.std,
                             crop_mode=model.config.crop_mode,
                             crop_pct=model.config.crop_pct)
inputs = transform(image).unsqueeze(0).cuda()
# model inference
out_avg_pool, features = model(inputs)
print("Size of the averaged pool features:", out_avg_pool.size()) # torch.Size([1, 640])
print("Number of stages in extracted features:", len(features)) # 4 stages
print("Size of extracted features in stage 1:", features[0].size()) # torch.Size([1, 80, 56, 56])
print("Size of extracted features in stage 4:", features[3].size()) # torch.Size([1, 640, 7, 7])
```
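To reason about feature map sizes at other input resolutions, the following sketch computes the shapes one would expect from a four-stage hierarchical backbone. It assumes that the channel count doubles and the spatial resolution halves at every stage, starting from a stride of 4. That pattern matches the stage 1 and stage 4 sizes in the comments above, but the intermediate stages are an assumption, and `expected_stage_shapes` is a hypothetical helper rather than a MambaVision API. The value `base_channels=80` is taken from the stage 1 comment above.

```Python
def expected_stage_shapes(height, width, base_channels, num_stages=4, batch=1):
    """Estimate per-stage feature shapes as (batch, channels, height, width).

    Assumes stride 4 at stage 1, with channels doubling and resolution
    halving at each later stage (an assumption, not confirmed by the source).
    """
    shapes = []
    for stage in range(num_stages):
        stride = 4 * 2 ** stage             # 4, 8, 16, 32
        channels = base_channels * 2 ** stage
        shapes.append((batch, channels, height // stride, width // stride))
    return shapes


shapes = expected_stage_shapes(224, 224, base_channels=80)
print("Expected stage 1 shape:", shapes[0])
print("Expected stage 4 shape:", shapes[3])
```

Comparing these estimates against `features[i].size()` is a quick sanity check before wiring the backbone into a detection or segmentation head. For resolutions that are not divisible by 32, the actual sizes may differ from this integer-division estimate.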
### License
[NVIDIA Source Code License-NC](https://huggingface.co/nvidia/MambaVision-L-21K/blob/main/LICENSE)