---
license: apache-2.0
tags:
- vision
- self-supervised-learning
- masked-image-modeling
- knowledge-distillation
- vit
datasets:
- ILSVRC/imagenet-1k
- 1aurent/ADE20K
- detection-datasets/coco
metrics:
- accuracy
- mIoU
- mAP
pipeline_tag: image-classification
---

# MaskDistill ViT-Base/16

**The first open-source PyTorch implementation of MaskDistill with pretrained weights.**

This model was trained using the [MaskDistill-PyTorch](https://github.com/drkostas/MaskDistill-PyTorch) codebase, reproducing the method from ["A Unified View of Masked Image Modeling"](https://arxiv.org/abs/2210.10615).

## Model Description

MaskDistill learns visual representations by distilling knowledge from a frozen CLIP ViT-B/16 teacher into a ViT-Base student through masked image modeling. The student is trained to predict the teacher's features for the masked patches, using a Smooth L1 loss.

- **Architecture**: ViT-Base/16 (86M params)
- **Teacher**: CLIP ViT-B/16 (frozen)
- **Pretraining**: 300 epochs on ImageNet-1K
- **Masking**: Block masking at a 40% ratio, dense encoding with shared relative position bias

## Results

| Evaluation | Result |
|-----------|--------|
| Finetuning (ImageNet-1K) | **84.8%** top-1 |
| k-NN (k=10) | **75.6%** top-1 |
| Linear Probe | **76.3%** top-1 |
| Sem. Seg. (ADE20K, UPerNet) | **52.6** mIoU |
| Obj. Det. (COCO, Mask R-CNN) | **44.4** bbox mAP |
| Inst. Seg. (COCO, Mask R-CNN) | **40.1** segm mAP |

## Available Checkpoints

| File | Description |
|------|------------|
| `pretrain_vit_base_ep290.pth` | Pretrained ViT-Base (300 epochs) |
| `finetune_vit_base_ep100.pth` | Finetuned on ImageNet-1K (84.8% top-1) |
| `linprobe_vit_base_ep90.pth.tar` | Linear probe (90 epochs, 76.3% top-1) |
| `semseg_upernet_ade20k_160k.pth` | UPerNet on ADE20K (52.6 mIoU) |
| `detection_maskrcnn_coco_12ep.pth` | Mask R-CNN on COCO (44.4 mAP) |

## Usage

```python
import torch

# `src` is the package from the MaskDistill-PyTorch repository;
# the repository root must be on the Python path for this import to work.
from src.models.vision_transformer import VisionTransformerMIM

# Build the ViT-Base/16 student with the pretraining configuration
model = VisionTransformerMIM(
    img_size=224,
    patch_size=16,
    embed_dim=768,
    depth=12,
    num_heads=12,
    use_shared_rel_pos_bias=True,
    use_mask_tokens=True,
)

# Load the checkpoint on CPU and keep only the student weights,
# stripping the "module.student." wrapper prefix from each key
ckpt = torch.load("pretrain_vit_base_ep290.pth", map_location="cpu")
state = {
    k.replace("module.student.", ""): v
    for k, v in ckpt["model"].items()
    if "student" in k
}

# strict=False tolerates keys that do not match the model exactly
model.load_state_dict(state, strict=False)
```

See the [GitHub repo](https://github.com/drkostas/MaskDistill-PyTorch) for full training and evaluation code.

## Citation

```bibtex
@article{peng2022unified,
  title={A Unified View of Masked Image Modeling},
  author={Peng, Zhiliang and Dong, Li and Bao, Hangbo and Ye, Qixiang and Wei, Furu},
  journal={arXiv preprint arXiv:2210.10615},
  year={2022}
}
```
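
## Appendix: Distillation Loss Sketch

The objective summarized under Model Description (Smooth L1 between student predictions and frozen teacher features, computed on masked patches only) can be illustrated roughly as follows. This is a minimal sketch, not the repository's training code: the function name `masked_distill_loss`, the tensor shapes, and the random mask are assumptions made for illustration, and any target normalization the codebase may apply to the teacher features is omitted.

```python
import torch
import torch.nn.functional as F


def masked_distill_loss(student_feats, teacher_feats, mask):
    """Smooth L1 loss between student and teacher features on masked patches.

    student_feats, teacher_feats: (batch, num_patches, dim)
    mask: (batch, num_patches) bool tensor, True where the patch was masked
    (Hypothetical helper; names and shapes are illustrative assumptions.)
    """
    # The teacher is frozen, so no gradient should flow through its features.
    teacher_feats = teacher_feats.detach()
    # Boolean indexing keeps only the masked positions: (num_masked, dim)
    return F.smooth_l1_loss(student_feats[mask], teacher_feats[mask])


# Toy inputs: 2 images, 196 patches (224 / 16 = 14 per side), 768-dim features
torch.manual_seed(0)
student = torch.randn(2, 196, 768, requires_grad=True)
teacher = torch.randn(2, 196, 768)
# Random mask at roughly 40% for illustration only; the model uses block masking
mask = torch.rand(2, 196) < 0.4

loss = masked_distill_loss(student, teacher, mask)
loss.backward()
```

Because the loss indexes only masked positions, visible patches receive no gradient from this term, which is what restricts the distillation signal to the masked regions.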