---
license: apache-2.0
tags:
  - vision
  - self-supervised-learning
  - masked-image-modeling
  - knowledge-distillation
  - vit
datasets:
  - ILSVRC/imagenet-1k
  - 1aurent/ADE20K
  - detection-datasets/coco
metrics:
  - accuracy
  - mIoU
  - mAP
pipeline_tag: image-classification
---

# MaskDistill ViT-Base/16

**The first open-source PyTorch implementation of MaskDistill with pre-trained weights.**

This model was trained using the [MaskDistill-PyTorch](https://github.com/drkostas/MaskDistill-PyTorch) codebase, reproducing the method from ["A Unified View of Masked Image Modeling"](https://arxiv.org/abs/2210.10615).

## Model Description

MaskDistill learns visual representations by distilling knowledge from a frozen CLIP ViT-B/16 teacher into a ViT-Base student through masked image modeling. The student learns to predict the teacher's features for masked patches using Smooth L1 loss.

- **Architecture**: ViT-Base/16 (86M params)
- **Teacher**: CLIP ViT-B/16 (frozen)
- **Pretraining**: 300 epochs on ImageNet-1K
- **Masking**: Block masking at 40%, dense encoding with shared relative position bias
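
The objective can be sketched in a few lines of PyTorch. This is a minimal illustration, not the repo's API: the function name and tensor layout (`B` images, `N` patches, `D` feature dims) are illustrative.

```python
import torch
import torch.nn.functional as F

def maskdistill_loss(student_feats, teacher_feats, mask):
    """Smooth L1 between student predictions and frozen-teacher features,
    computed only on masked patches.

    student_feats, teacher_feats: (B, N, D) patch features
    mask: (B, N) bool, True = masked patch
    """
    # The teacher is frozen: no gradients flow through its features.
    teacher_feats = teacher_feats.detach()
    loss = F.smooth_l1_loss(student_feats, teacher_feats, reduction="none")
    loss = loss.mean(dim=-1)  # per-patch loss, shape (B, N)
    # Average over masked patches only (clamp guards against an empty mask).
    return (loss * mask).sum() / mask.sum().clamp(min=1)
```

Visible patches contribute nothing to the loss; the student only has to reconstruct teacher semantics where the image content was hidden.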

## Results

| Evaluation | Result |
|-----------|--------|
| Finetuning (ImageNet-1K) | **84.8%** top-1 |
| k-NN (k=10) | **75.6%** top-1 |
| Linear Probe | **76.3%** top-1 |
| Sem. Seg. (ADE20K, UPerNet) | **52.6** mIoU |
| Obj. Det. (COCO, Mask R-CNN) | **44.4** bbox mAP |
| Inst. Seg. (COCO, Mask R-CNN) | **40.1** segm mAP |
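
The k-NN row follows the standard frozen-feature protocol: each test image is classified by similarity-weighted votes from its nearest training features. A minimal sketch (function name and weighting scheme are illustrative, not the repo's evaluation code):

```python
import torch
import torch.nn.functional as F

def knn_top1(train_feats, train_labels, test_feats, test_labels, k=10):
    """Similarity-weighted k-NN classification on L2-normalised features."""
    train_feats = F.normalize(train_feats, dim=-1)
    test_feats = F.normalize(test_feats, dim=-1)
    sims = test_feats @ train_feats.T                # cosine similarities
    topk_sims, topk_idx = sims.topk(k, dim=-1)
    topk_labels = train_labels[topk_idx]             # (n_test, k)
    num_classes = int(train_labels.max()) + 1
    # Accumulate similarity-weighted votes per class.
    votes = torch.zeros(test_feats.size(0), num_classes)
    votes.scatter_add_(1, topk_labels, topk_sims)
    pred = votes.argmax(dim=-1)
    return (pred == test_labels).float().mean().item()
```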

## Available Checkpoints

| File | Description |
|------|------------|
| `pretrain_vit_base_ep290.pth` | Pretrained ViT-Base (300 epochs) |
| `finetune_vit_base_ep100.pth` | Finetuned on ImageNet-1K (84.8% top-1) |
| `linprobe_vit_base_ep90.pth.tar` | Linear probe (90 epochs, 76.3% top-1) |
| `semseg_upernet_ade20k_160k.pth` | UPerNet on ADE20K (52.6 mIoU) |
| `detection_maskrcnn_coco_12ep.pth` | Mask R-CNN on COCO (44.4 mAP) |

## Usage

```python
import torch
from src.models.vision_transformer import VisionTransformerMIM

# Build a ViT-Base/16 with the same configuration used for pretraining
model = VisionTransformerMIM(
    img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12,
    use_shared_rel_pos_bias=True, use_mask_tokens=True,
)

# Extract the student weights from the pretraining checkpoint
ckpt = torch.load("pretrain_vit_base_ep290.pth", map_location="cpu")
state = {k.replace("module.student.", ""): v for k, v in ckpt["model"].items() if "student" in k}
# strict=False: checkpoint-only keys (e.g. the regression head) are skipped
model.load_state_dict(state, strict=False)
```
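
Pretraining-style inputs additionally require the 40% block mask mentioned above. The repo contains the exact sampler; the sketch below is a generic BEiT-style block-mask generator shown only to illustrate the idea:

```python
import math
import random
import torch

def block_mask(grid=14, mask_ratio=0.4, min_block=4, rng=None):
    """Greedily place rectangular blocks until ~mask_ratio of the
    grid*grid patch positions are masked (BEiT-style block masking)."""
    rng = rng or random.Random(0)
    mask = torch.zeros(grid, grid, dtype=torch.bool)
    target = int(grid * grid * mask_ratio)
    for _ in range(10_000):  # guard against pathological sampling
        if int(mask.sum()) >= target:
            break
        # Block area between min_block and the remaining budget.
        area = rng.randint(min_block, max(min_block, target - int(mask.sum())))
        aspect = math.exp(rng.uniform(math.log(0.3), math.log(1 / 0.3)))
        h = min(grid, max(1, int(round(math.sqrt(area * aspect)))))
        w = min(grid, max(1, int(round(math.sqrt(area / aspect)))))
        top = rng.randint(0, grid - h)
        left = rng.randint(0, grid - w)
        mask[top:top + h, left:left + w] = True
    return mask.flatten()  # (grid*grid,) bool, True = masked
```

For a 224x224 image with 16x16 patches this yields a length-196 boolean mask that can be passed alongside the image during masked pretraining.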

See the [GitHub repo](https://github.com/drkostas/MaskDistill-PyTorch) for full training and evaluation code.

## Citation

```bibtex
@article{peng2022unified,
  title={A Unified View of Masked Image Modeling},
  author={Peng, Zhiliang and Dong, Li and Bao, Hangbo and Ye, Qixiang and Wei, Furu},
  journal={arXiv preprint arXiv:2210.10615},
  year={2022}
}
```