# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
from typing import List

import torch.nn as nn

from depth_anything_3.model.dinov2.vision_transformer import (
    vit_base,
    vit_giant2,
    vit_large,
    vit_small,
)
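

# DinoV2 is a thin wrapper around the DINOv2 vision transformer backbones;
# its forward pass delegates to get_intermediate_layers to expose features
# from the selected transformer blocks.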
class DinoV2(nn.Module):
    def __init__(
        self,
        name: str,
        out_layers: List[int],
        alt_start: int = -1,
        qknorm_start: int = -1,
        rope_start: int = -1,
        cat_token: bool = True,
        **kwargs,
    ):
        super().__init__()
        assert name in {"vits", "vitb", "vitl", "vitg"}
        self.name = name
        self.out_layers = out_layers
        # alt_start / qknorm_start / rope_start appear to gate per-block
        # features (alternating attention, QK normalization, rotary position
        # embeddings) from a given block index onward, with -1 disabling
        # them; see vision_transformer.py for the exact semantics.
        self.alt_start = alt_start
        self.qknorm_start = qknorm_start
        self.rope_start = rope_start
        self.cat_token = cat_token
        # Map the short backbone names to their constructor functions.
        encoder_map = {
            "vits": vit_small,
            "vitb": vit_base,
            "vitl": vit_large,
            "vitg": vit_giant2,
        }
        encoder_fn = encoder_map[self.name]
        # ViT-giant uses the fused SwiGLU FFN; the other variants use a plain MLP.
        ffn_layer = "swiglufused" if self.name == "vitg" else "mlp"
        self.pretrained = encoder_fn(
            img_size=518,
            patch_size=14,
            ffn_layer=ffn_layer,
            alt_start=alt_start,
            qknorm_start=qknorm_start,
            rope_start=rope_start,
            cat_token=cat_token,
        )

    def forward(self, x, **kwargs):
        # Collect features from the transformer blocks listed in out_layers.
        return self.pretrained.get_intermediate_layers(
            x,
            self.out_layers,
            **kwargs,
        )
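

# ---------------------------------------------------------------------------
# Usage sketch (illustrative; not from the original file). The layer indices
# are hypothetical, though valid for ViT-L's 24 blocks. As written, the
# wrapper does not load pretrained weights itself, so this runs with random
# initialization; the exact return type of get_intermediate_layers depends on
# the vision_transformer implementation in this repository.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch

    model = DinoV2(name="vitl", out_layers=[4, 11, 17, 23]).eval()
    # 518 x 518 matches img_size above; H and W should be divisible by patch_size=14.
    x = torch.randn(1, 3, 518, 518)
    with torch.no_grad():
        feats = model(x)  # expected: one feature tensor per entry in out_layers
    print(len(feats))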