---
license: apache-2.0
tags:
- vision
- self-supervised-learning
- masked-image-modeling
- knowledge-distillation
- vit
datasets:
- ILSVRC/imagenet-1k
- 1aurent/ADE20K
- detection-datasets/coco
metrics:
- accuracy
- mIoU
- mAP
pipeline_tag: image-classification
---
# MaskDistill ViT-Base/16
**The first open-source PyTorch implementation of MaskDistill with pre-trained weights.**
This model was trained using the [MaskDistill-PyTorch](https://github.com/drkostas/MaskDistill-PyTorch) codebase, reproducing the method from ["A Unified View of Masked Image Modeling"](https://arxiv.org/abs/2210.10615).
## Model Description
MaskDistill learns visual representations by distilling knowledge from a frozen CLIP ViT-B/16 teacher into a ViT-Base student through masked image modeling. The student learns to predict the teacher's features for masked patches using Smooth L1 loss.
- **Architecture**: ViT-Base/16 (86M params)
- **Teacher**: CLIP ViT-B/16 (frozen)
- **Pretraining**: 300 epochs on ImageNet-1K
- **Masking**: Block masking at 40%, dense encoding with shared relative position bias
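The objective and masking scheme above can be sketched as follows. This is a minimal illustration with toy shapes and a simplified block-mask sampler, not the repo's actual implementation; function names and the `max_block` parameter are placeholders:

```python
import torch
import torch.nn.functional as F

def block_mask(grid=14, ratio=0.4, max_block=5):
    """Sample rectangular blocks on the patch grid until roughly
    `ratio` of patches are masked (a simplified stand-in for
    BEiT-style block masking)."""
    m = torch.zeros(grid, grid, dtype=torch.bool)
    target = int(ratio * grid * grid)
    while m.sum() < target:
        h = torch.randint(1, max_block + 1, (1,)).item()
        w = torch.randint(1, max_block + 1, (1,)).item()
        top = torch.randint(0, grid - h + 1, (1,)).item()
        left = torch.randint(0, grid - w + 1, (1,)).item()
        m[top:top + h, left:left + w] = True
    return m.flatten()  # (196,) for a 14x14 grid

def maskdistill_loss(student_feats, teacher_feats, mask):
    """Smooth L1 between student predictions and frozen-teacher
    features, computed on masked patch positions only."""
    return F.smooth_l1_loss(student_feats[mask], teacher_feats[mask])

# Toy shapes: 196 patches (224/16 = 14 per side), ViT-Base dim 768
mask = block_mask()                  # (196,) bool, ~40% True
student = torch.randn(1, 196, 768)   # stand-in for student predictions
teacher = torch.randn(1, 196, 768)   # stand-in for frozen CLIP features
loss = maskdistill_loss(student, teacher, mask.unsqueeze(0))
```

Because blocks overlap and overshoot the target, the realized mask ratio lands at or slightly above 40%, matching the block-masking behavior described above.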
## Results
| Evaluation | Result |
|-----------|--------|
| Finetuning (ImageNet-1K) | **84.8%** top-1 |
| k-NN (k=10) | **75.6%** top-1 |
| Linear Probe | **76.3%** top-1 |
| Sem. Seg. (ADE20K, UPerNet) | **52.6** mIoU |
| Obj. Det. (COCO, Mask R-CNN) | **44.4** bbox mAP |
| Inst. Seg. (COCO, Mask R-CNN) | **40.1** segm mAP |
## Available Checkpoints
| File | Description |
|------|------------|
| `pretrain_vit_base_ep290.pth` | Pretrained ViT-Base (300 epochs) |
| `finetune_vit_base_ep100.pth` | Finetuned on ImageNet-1K (84.8% top-1) |
| `linprobe_vit_base_ep90.pth.tar` | Linear probe (90 epochs, 76.3% top-1) |
| `semseg_upernet_ade20k_160k.pth` | UPerNet on ADE20K (52.6 mIoU) |
| `detection_maskrcnn_coco_12ep.pth` | Mask R-CNN on COCO (44.4 mAP) |
## Usage
```python
import torch
from src.models.vision_transformer import VisionTransformerMIM

# Instantiate ViT-Base/16 with the same settings used for pretraining
model = VisionTransformerMIM(
    img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12,
    use_shared_rel_pos_bias=True, use_mask_tokens=True,
)

# Load the checkpoint and keep only the student weights, stripping the
# "module.student." prefix added by distributed training
ckpt = torch.load("pretrain_vit_base_ep290.pth", map_location="cpu")
state = {k.replace("module.student.", ""): v
         for k, v in ckpt["model"].items() if "student" in k}
model.load_state_dict(state, strict=False)
```
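The prefix-stripping step in the snippet above can be checked in isolation with a toy state dict. The key names below are hypothetical, mirroring the `module.student.` convention, with teacher weights dropped by the `"student" in k` filter:

```python
import torch

# Toy checkpoint: student weights carry a "module.student." prefix
# (from DistributedDataParallel wrapping); teacher weights are excluded.
toy_ckpt = {
    "model": {
        "module.student.patch_embed.proj.weight": torch.zeros(768, 3, 16, 16),
        "module.student.blocks.0.attn.qkv.weight": torch.zeros(2304, 768),
        "module.teacher.visual.proj": torch.zeros(768, 512),
    }
}

state = {
    k.replace("module.student.", ""): v
    for k, v in toy_ckpt["model"].items()
    if "student" in k
}
print(sorted(state))
# ['blocks.0.attn.qkv.weight', 'patch_embed.proj.weight']
```

If `load_state_dict` reports many missing or unexpected keys, printing a few entries of `ckpt["model"]` this way is a quick way to verify the prefix convention before mapping weights onto the model.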
See the [GitHub repo](https://github.com/drkostas/MaskDistill-PyTorch) for full training and evaluation code.
## Citation
```bibtex
@article{peng2022unified,
  title={A Unified View of Masked Image Modeling},
  author={Peng, Zhiliang and Dong, Li and Bao, Hangbo and Ye, Qixiang and Wei, Furu},
  journal={arXiv preprint arXiv:2210.10615},
  year={2022}
}
```