---
license: apache-2.0
tags:
- vision
- self-supervised-learning
- masked-image-modeling
- knowledge-distillation
- vit
datasets:
- ILSVRC/imagenet-1k
- 1aurent/ADE20K
- detection-datasets/coco
metrics:
- accuracy
- mIoU
- mAP
pipeline_tag: image-classification
---
MaskDistill ViT-Base/16
The first open-source PyTorch implementation of MaskDistill with pre-trained weights.
This model was trained using the MaskDistill-PyTorch codebase, reproducing the method from "A Unified View of Masked Image Modeling".
Model Description
MaskDistill learns visual representations by distilling knowledge from a frozen CLIP ViT-B/16 teacher into a ViT-Base student through masked image modeling. The student learns to predict the teacher's features for masked patches using Smooth L1 loss.
- Architecture: ViT-Base/16 (86M params)
- Teacher: CLIP ViT-B/16 (frozen)
- Pretraining: 300 epochs on ImageNet-1K
- Masking: Block masking at 40%, dense encoding with shared relative position bias
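The objective described above can be sketched in a few lines of PyTorch. This is a minimal illustration, not the repository's training code: the function name `masked_distill_loss`, the layer normalization of the teacher targets, and the i.i.d. random mask (used here in place of block masking for brevity) are all assumptions, and random tensors stand in for the real student and CLIP outputs.

```python
import torch
import torch.nn.functional as F


def masked_distill_loss(student_pred, teacher_feat, mask):
    """Smooth L1 between student predictions and teacher features on masked patches.

    student_pred, teacher_feat: (batch, num_patches, dim)
    mask: (batch, num_patches) bool, True where the patch was masked
    """
    # Assumption: teacher targets are layer-normalized per patch (no learnable affine).
    target = F.layer_norm(teacher_feat, teacher_feat.shape[-1:])
    # Boolean indexing keeps only masked positions, giving (num_masked, dim) tensors.
    return F.smooth_l1_loss(student_pred[mask], target[mask])


torch.manual_seed(0)
batch, num_patches, dim = 2, 196, 768  # 224x224 image, 16x16 patches -> 14*14 = 196

student_pred = torch.randn(batch, num_patches, dim, requires_grad=True)
teacher_feat = torch.randn(batch, num_patches, dim)  # stand-in for frozen CLIP features
mask = torch.rand(batch, num_patches) < 0.4  # i.i.d. mask for brevity, not block masking

loss = masked_distill_loss(student_pred, teacher_feat, mask)
loss.backward()  # gradients reach the student only; the teacher tensor needs no grad
```

Because the loss is computed only at masked positions, unmasked patches contribute no gradient to the student's predictions.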
Results
| Evaluation | Result |
|---|---|
| Finetuning (ImageNet-1K) | 84.8% top-1 |
| k-NN (k=10) | 75.6% top-1 |
| Linear Probe | 76.3% top-1 |
| Sem. Seg. (ADE20K, UPerNet) | 52.6 mIoU |
| Obj. Det. (COCO, Mask R-CNN) | 44.4 bbox mAP |
| Inst. Seg. (COCO, Mask R-CNN) | 40.1 segm mAP |
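For readers unfamiliar with the k-NN protocol in the table, the sketch below shows one common way to classify frozen features with k = 10 nearest neighbours. It is an illustration under stated assumptions, not the evaluation script behind the reported number: the function `knn_predict`, the use of cosine similarity, and the unweighted majority vote are all choices made here, and the toy 2-D features stand in for real ViT embeddings.

```python
import torch
import torch.nn.functional as F


def knn_predict(train_feats, train_labels, test_feats, k=10):
    """Classify test features by majority vote over the k most similar train features."""
    train = F.normalize(train_feats, dim=1)
    test = F.normalize(test_feats, dim=1)
    sims = test @ train.T                  # cosine similarity, (n_test, n_train)
    _, idx = sims.topk(k, dim=1)           # indices of the k nearest neighbours
    neighbor_labels = train_labels[idx]    # (n_test, k)
    num_classes = int(train_labels.max()) + 1
    votes = F.one_hot(neighbor_labels, num_classes).sum(dim=1)  # (n_test, num_classes)
    return votes.argmax(dim=1)


torch.manual_seed(0)
# Toy data: class 0 clusters near (1, 0), class 1 near (0, 1).
class0 = torch.tensor([1.0, 0.0]) + 0.05 * torch.randn(20, 2)
class1 = torch.tensor([0.0, 1.0]) + 0.05 * torch.randn(20, 2)
train_feats = torch.cat([class0, class1])
train_labels = torch.cat([torch.zeros(20, dtype=torch.long),
                          torch.ones(20, dtype=torch.long)])
test_feats = torch.tensor([[1.0, 0.1], [0.1, 1.0]])

preds = knn_predict(train_feats, train_labels, test_feats, k=10)
```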
Available Checkpoints
| File | Description |
|---|---|
| pretrain_vit_base_ep290.pth | Pretrained ViT-Base (300 epochs) |
| finetune_vit_base_ep100.pth | Finetuned on ImageNet-1K (84.8% top-1) |
| linprobe_vit_base_ep90.pth.tar | Linear probe (90 epochs, 76.3% top-1) |
| semseg_upernet_ade20k_160k.pth | UPerNet on ADE20K (52.6 mIoU) |
| detection_maskrcnn_coco_12ep.pth | Mask R-CNN on COCO (44.4 mAP) |
Usage
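import torch
from src.models.vision_transformer import VisionTransformerMIM

# Load pretrained model: build the ViT-Base/16 student first
model = VisionTransformerMIM(
    img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12,
    use_shared_rel_pos_bias=True, use_mask_tokens=True,
)

# Keep only the student weights and strip the "module.student." prefix
ckpt = torch.load("pretrain_vit_base_ep290.pth", map_location="cpu")
state = {k.replace("module.student.", ""): v for k, v in ckpt["model"].items() if "student" in k}

# strict=False tolerates key mismatches, so report them rather than ignore them
result = model.load_state_dict(state, strict=False)
print("missing:", result.missing_keys, "unexpected:", result.unexpected_keys)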
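The key filtering in the snippet above relies on substring matching, which would leave a key such as `student.blocks...` (saved without the `module.` wrapper) unrenamed. A slightly more defensive variant is sketched below. The helper name `extract_student_weights`, the second prefix `student.`, and the toy key names are hypothetical; the actual checkpoint layout should be checked by printing a few keys of `ckpt["model"]`.

```python
def extract_student_weights(checkpoint_state, prefixes=("module.student.", "student.")):
    """Return only the entries whose key starts with one of `prefixes`, prefix removed."""
    student_state = {}
    for key, value in checkpoint_state.items():
        for prefix in prefixes:
            if key.startswith(prefix):
                # Strip only the leading prefix, never a match in the middle of the key.
                student_state[key[len(prefix):]] = value
                break
    return student_state


# Toy stand-in for ckpt["model"]; the key names are made up for illustration.
toy_state = {
    "module.student.patch_embed.proj.weight": 1,
    "module.student.blocks.0.attn.qkv.weight": 2,
    "module.teacher.visual.conv1.weight": 3,
}
student = extract_student_weights(toy_state)
print(sorted(student))
```

The returned dictionary can be passed to `model.load_state_dict(..., strict=False)` in place of `state`.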
See the GitHub repo for full training and evaluation code.
Citation
@article{peng2022unified,
  title={A Unified View of Masked Image Modeling},
  author={Peng, Zhiliang and Dong, Li and Bao, Hangbo and Ye, Qixiang and Wei, Furu},
  journal={arXiv preprint arXiv:2210.10615},
  year={2022}
}