Tenbatsu24 commited on
Commit ·
a10ce46
1
Parent(s): 94c37e9
add: missing files
Browse files- configuration_vitv2.py +28 -0
- hf_src/__init__.py +0 -0
- hf_src/layers/__init__.py +31 -0
- hf_src/layers/attention.py +105 -0
- hf_src/layers/block.py +331 -0
- hf_src/layers/cva_head.py +184 -0
- hf_src/layers/dino_head.py +76 -0
- hf_src/layers/drop_path.py +31 -0
- hf_src/layers/fp8_linear.py +144 -0
- hf_src/layers/layer_scale.py +28 -0
- hf_src/layers/mlp.py +49 -0
- hf_src/layers/patch_embed.py +96 -0
- hf_src/layers/rms_norm.py +24 -0
- hf_src/layers/rope_attention.py +182 -0
- hf_src/layers/rope_block.py +299 -0
- hf_src/layers/rope_position_encoding.py +184 -0
- hf_src/layers/sparse_linear.py +94 -0
- hf_src/layers/swiglu_ffn.py +64 -0
- hf_src/model/__init__.py +0 -0
- hf_src/model/image/__init__.py +0 -0
- hf_src/model/image/vitv2/__init__.py +0 -0
- hf_src/model/image/vitv2/transformer.py +475 -0
- hf_src/utils/__init__.py +16 -0
- hf_src/utils/download.py +99 -0
- hf_src/utils/dtype.py +37 -0
- hf_src/utils/masking.py +113 -0
- hf_src/utils/seedlet_masking.py +0 -0
- hf_src/utils/utils.py +136 -0
- modelling_vitv2.py +32 -0
- requirements.txt +2 -0
configuration_vitv2.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import PretrainedConfig
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class ViTv2Config(PretrainedConfig):
    """Configuration object holding the ViTv2 hyper-parameters.

    Each named argument is stored unchanged as an attribute of the same
    name. Any other keyword arguments are passed on to ``PretrainedConfig``.
    """

    model_type = "vitv2"

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        num_register_tokens=0,
        init_values=None,
        **ignored_kwargs,
    ):
        # The base class consumes the generic options first.
        super().__init__(**ignored_kwargs)

        # Input geometry.
        self.img_size = img_size
        self.patch_size = patch_size

        # Transformer sizing.
        self.embed_dim = embed_dim
        self.depth = depth
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio

        # Optional extras.
        self.num_register_tokens = num_register_tokens
        self.init_values = init_values
|
hf_src/__init__.py
ADDED
|
File without changes
|
hf_src/layers/__init__.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .mlp import Mlp
|
| 2 |
+
from .block import Block # noqa: F401
|
| 3 |
+
from .rms_norm import RMSNorm
|
| 4 |
+
from .drop_path import DropPath
|
| 5 |
+
from .dino_head import DINOHead
|
| 6 |
+
from .layer_scale import LayerScale
|
| 7 |
+
from .patch_embed import PatchEmbed
|
| 8 |
+
from .block import NestedTensorBlock
|
| 9 |
+
from .attention import MemEffAttention
|
| 10 |
+
from .rope_block import SelfAttentionBlock
|
| 11 |
+
from .cva_head import CVAHead, IdentityHead
|
| 12 |
+
from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
|
| 13 |
+
from .rope_position_encoding import RopePositionEmbedding
|
| 14 |
+
|
| 15 |
+
__all__ = [
|
| 16 |
+
"CVAHead",
|
| 17 |
+
"RMSNorm",
|
| 18 |
+
"IdentityHead",
|
| 19 |
+
"DINOHead",
|
| 20 |
+
"DropPath",
|
| 21 |
+
"Block",
|
| 22 |
+
"Mlp",
|
| 23 |
+
"PatchEmbed",
|
| 24 |
+
"LayerScale",
|
| 25 |
+
"SwiGLUFFN",
|
| 26 |
+
"SwiGLUFFNFused",
|
| 27 |
+
"NestedTensorBlock",
|
| 28 |
+
"MemEffAttention",
|
| 29 |
+
"SelfAttentionBlock",
|
| 30 |
+
"RopePositionEmbedding",
|
| 31 |
+
]
|
hf_src/layers/attention.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# References:
|
| 2 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 3 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
from torch import Tensor
|
| 8 |
+
from torch import nn
|
| 9 |
+
|
| 10 |
+
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
|
| 11 |
+
try:
|
| 12 |
+
if XFORMERS_ENABLED:
|
| 13 |
+
from xformers.ops import memory_efficient_attention, unbind
|
| 14 |
+
|
| 15 |
+
XFORMERS_AVAILABLE = True
|
| 16 |
+
else:
|
| 17 |
+
raise ImportError
|
| 18 |
+
except ImportError:
|
| 19 |
+
XFORMERS_AVAILABLE = False
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class Attention(nn.Module):
    """Dense multi-head self-attention over a ``(B, N, C)`` token tensor.

    Args:
        dim: Channel width ``C`` of the input tokens.
        num_heads: Number of attention heads; must divide ``dim``.
        qkv_bias: Whether the fused q/k/v projection has a bias term.
        proj_bias: Whether the output projection has a bias term.
        attn_drop: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the output projection.

    Raises:
        ValueError: If ``dim`` is not a multiple of ``num_heads``.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        super().__init__()
        # Fail early: a non-divisible width would otherwise only surface as
        # an opaque reshape error inside forward().
        if dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: Tensor, return_attn=False) -> Tensor:
        """
        Adapted from https://gitlab.com/ziegleto-machine-learning/dino/-/tree/main/

        Args:
            x: Token tensor of shape ``(B, N, C)``.
            return_attn: When true, return the attention weights (after
                attention dropout) instead of the projected tokens.

        Returns:
            ``(B, N, C)`` tokens, or ``(B, num_heads, N, N)`` attention
            weights when ``return_attn`` is true.
        """
        B, N, C = x.shape
        # (3, B, heads, N, head_dim): q, k and v stacked along dim 0.
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )

        q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
        attn = q @ k.transpose(-2, -1)

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # Adaptation for returning attentions: the value aggregation and
        # output projection are not needed, so skip them.
        if return_attn:
            return attn

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class MemEffAttention(Attention):
    """
    Adapted from https://gitlab.com/ziegleto-machine-learning/dino/-/tree/main/

    Self-attention that uses the xFormers ``memory_efficient_attention``
    kernel when xFormers is importable and falls back to the dense parent
    implementation otherwise.
    """

    def forward(self, x: Tensor, attn_bias=None, return_attn=False) -> Tensor:
        """Run attention on ``x`` of shape ``(B, N, C)``.

        Args:
            x: Token tensor of shape ``(B, N, C)``.
            attn_bias: Optional xFormers attention bias; only usable when
                xFormers is available.
            return_attn: When true, return attention weights instead of the
                projected tokens.

        Returns:
            ``(B, N, C)`` tokens, or a ``(B, num_heads, N, N)`` tensor when
            ``return_attn`` is true.
        """
        if not XFORMERS_AVAILABLE:
            assert attn_bias is None, "xFormers is required for nested tensors usage"
            # Adaptation for returning attentions
            return super().forward(x, return_attn)

        if return_attn and attn_bias is None:
            # The fused kernel never materialises the softmax weights, so use
            # the dense parent path to obtain the real attention map. This
            # also keeps the result identical to the non-xFormers branch.
            return super().forward(x, return_attn=True)

        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)

        q, k, v = unbind(qkv, 2)

        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
        if return_attn:
            # NOTE(review): with an attn_bias the dense path is unavailable,
            # so the legacy proxy (output @ v^T) is kept for backward
            # compatibility; it is not a softmax attention map.
            # Adapted from https://github.com/facebookresearch/dinov2/issues/90#issuecomment-1574001076
            attn = x.permute(0, 2, 1, 3) @ v.permute(0, 2, 3, 1)
            return attn
        x = x.reshape([B, N, C])

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
if __name__ == "__main__":
    import torch

    # Smoke test on a CUDA device: print the output shape once with
    # attention weights requested and once for the regular forward pass.
    _att = MemEffAttention(dim=32, num_heads=4).to("cuda")
    for _kwargs in ({"return_attn": True}, {}):
        print(_att(torch.randn(4, 16, 32, device="cuda"), **_kwargs).shape)
|
hf_src/layers/block.py
ADDED
|
@@ -0,0 +1,331 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# References:
|
| 2 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 3 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
from typing import Callable, List, Any, Tuple, Dict
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
from torch import nn, Tensor
|
| 12 |
+
|
| 13 |
+
from .attention import Attention, MemEffAttention
|
| 14 |
+
from .drop_path import DropPath
|
| 15 |
+
from .layer_scale import LayerScale
|
| 16 |
+
from .mlp import Mlp
|
| 17 |
+
|
| 18 |
+
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
|
| 19 |
+
try:
|
| 20 |
+
if XFORMERS_ENABLED:
|
| 21 |
+
from xformers.ops import fmha
|
| 22 |
+
|
| 23 |
+
XFORMERS_AVAILABLE = True
|
| 24 |
+
else:
|
| 25 |
+
raise ImportError
|
| 26 |
+
except ImportError:
|
| 27 |
+
XFORMERS_AVAILABLE = False
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class Block(nn.Module):
    """Pre-norm transformer block with an attention and a feed-forward branch.

    Each branch is ``x + drop_path(layer_scale(f(norm(x))))``. Layer scale is
    only instantiated when ``init_values`` is truthy, and drop path only when
    ``drop_path > 0``; otherwise ``nn.Identity`` is used in their place.

    Args:
        dim: Token channel width.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the feed-forward layer relative to ``dim``.
        qkv_bias: Bias flag for the q/k/v projection.
        proj_bias: Bias flag for the attention output projection.
        ffn_bias: Bias flag forwarded to the feed-forward layer.
        drop: Dropout probability for the projection and feed-forward layers.
        attn_drop: Dropout probability for the attention weights.
        init_values: Initial layer-scale value, or a falsy value to disable it.
        drop_path: Stochastic-depth rate.
        act_layer: Activation factory forwarded to the feed-forward layer.
        norm_layer: Normalisation factory called with ``dim``.
        attn_class: Attention module factory.
        ffn_layer: Feed-forward module factory.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
    ) -> None:
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.ls1 = (
            LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        )
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = (
            LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        )
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor, return_attention=False) -> Tensor:
        """
        Adapted from https://gitlab.com/ziegleto-machine-learning/dino/-/tree/main/

        Args:
            x: Token tensor.
            return_attention: When true, also return the attention computed
                by ``self.attn`` on the normalised input.

        Returns:
            The updated tokens, or ``(tokens, attention)`` when
            ``return_attention`` is true.
        """

        def attn_residual_func(x: Tensor) -> Tensor:
            return self.ls1(self.attn(self.norm1(x)))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        # Adaptation for returning attentions: computed on the block input,
        # before the residual branches modify x.
        if return_attention:
            attn = self.attn(self.norm1(x), return_attn=True)

        if self.training and self.sample_drop_ratio > 0.1:
            # the overhead is compensated only for a drop path rate larger than 0.1
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x))
            # Each branch uses its own drop-path module (was drop_path1 twice).
            x = x + self.drop_path2(ffn_residual_func(x))
        else:
            x = x + attn_residual_func(x)
            x = x + ffn_residual_func(x)

        # Adaptation for returning attentions
        if return_attention:
            return x, attn

        return x
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def drop_add_residual_stochastic_depth(
    x: Tensor,
    residual_func: Callable[[Tensor], Tensor],
    sample_drop_ratio: float = 0.0,
) -> Tensor:
    """Add a residual computed on a random subset of the batch.

    A random subset of rows of the 3-D tensor ``x`` is passed through
    ``residual_func``; the result is scaled by ``batch / subset_size`` and
    added back onto those rows. Rows outside the subset are returned
    unchanged. At least one row is always kept.

    Args:
        x: Input tensor with three dimensions, batch first.
        residual_func: Function producing the residual for the kept rows.
        sample_drop_ratio: Fraction of the batch to drop.

    Returns:
        A new tensor shaped like ``x``.
    """
    batch, _, _ = x.shape

    # Choose which rows keep their residual branch.
    keep = max(int(batch * (1 - sample_drop_ratio)), 1)
    kept_rows = torch.randperm(batch, device=x.device)[:keep]

    # Evaluate the branch only on the kept rows.
    branch_out = residual_func(x[kept_rows]).flatten(1)

    # Rescale so the expected contribution matches the full batch.
    scale = batch / keep
    updated = torch.index_add(
        x.flatten(1), 0, kept_rows, branch_out.to(dtype=x.dtype), alpha=scale
    )
    return updated.view_as(x)
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def get_branges_scales(x, sample_drop_ratio=0.0):
    """Draw the batch rows to keep and the matching residual scale.

    Args:
        x: Three-dimensional tensor, batch first.
        sample_drop_ratio: Fraction of the batch to drop.

    Returns:
        ``(indices, scale)``: a random selection of row indices (at least
        one) on ``x``'s device, and ``batch / number_of_kept_rows``.
    """
    batch, _, _ = x.shape
    keep = max(int(batch * (1 - sample_drop_ratio)), 1)
    kept_rows = torch.randperm(batch, device=x.device)[:keep]
    return kept_rows, batch / keep
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def add_residual(x, brange, residual, residual_scale_factor, ls=None):
    """Scatter-add a scaled residual onto the selected rows of ``x`` in place.

    Args:
        x: Destination tensor; rows are indexed along dim 0.
        brange: Indices of the rows that receive the residual.
        residual: Residual values for those rows.
        residual_scale_factor: Multiplier applied to the residual.
        ls: Optional callable applied to the residual before it is added.

    Returns:
        Without ``ls``: the result of the in-place add on ``x.flatten(1)``
        (two-dimensional). With ``ls``: ``x`` itself after the in-place add.
    """
    if ls is not None:
        # The scaled path works on the unflattened tensor.
        return x.index_add_(
            dim=0,
            index=brange,
            source=ls(residual.to(dtype=x.dtype)),
            alpha=residual_scale_factor,
        )

    flat_residual = residual.flatten(1)
    return x.flatten(1).index_add_(
        dim=0,
        index=brange,
        source=flat_residual.to(dtype=x.dtype),
        alpha=residual_scale_factor,
    )
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
# Block-diagonal masks keyed by the (batch_size, seq_len) pairs they describe.
attn_bias_cache: Dict[Tuple, Any] = {}


def get_attn_bias_and_cat(x_list, branges=None):
    """
    this will perform the index select, cat the tensors, and provide the attn_bias from cache

    Args:
        x_list: Tensors to pack together, batch first.
        branges: Optional per-tensor row indices; when given, only those
            rows of each tensor are packed.

    Returns:
        ``(attn_bias, cat_tensors)``: the cached block-diagonal mask for this
        shape signature and the packed tensor with a leading batch dim of 1.
    """
    if branges is None:
        batch_sizes = [x.shape[0] for x in x_list]
    else:
        batch_sizes = [b.shape[0] for b in branges]
    all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))

    if all_shapes not in attn_bias_cache:
        # One sequence-length entry per packed sample.
        seqlens = [
            x.shape[1] for b, x in zip(batch_sizes, x_list) for _ in range(b)
        ]
        attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
        attn_bias._batch_sizes = batch_sizes
        attn_bias_cache[all_shapes] = attn_bias

    if branges is None:
        # Fold every batch into the sequence axis and concatenate.
        folded = [x.reshape([1, -1, *x.shape[2:]]) for x in x_list]
        cat_tensors = torch.cat(folded, dim=1)
    else:
        # Keep only the selected rows of each tensor, then pack them.
        picked = [
            x.flatten(1).index_select(0, rows).reshape(-1)
            for x, rows in zip(x_list, branges)
        ]
        cat_tensors = torch.cat(picked, dim=0).view(1, -1, x_list[0].shape[-1])

    return attn_bias_cache[all_shapes], cat_tensors
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def drop_add_residual_stochastic_depth_list(
    x_list: List[Tensor],
    residual_func: Callable[[Tensor, Any], Tensor],
    sample_drop_ratio: float = 0.0,
    scaling_vector=None,
) -> List[Tensor]:
    """Apply a stochastic-depth residual to several tensors in one packed call.

    Args:
        x_list: Tensors to update, batch first.
        residual_func: Called as ``residual_func(packed, attn_bias=attn_bias)``
            on the packed kept rows.
        sample_drop_ratio: Fraction of each batch to drop.
        scaling_vector: Optional callable forwarded to ``add_residual`` as
            its ``ls`` argument.

    Returns:
        One updated tensor per input, each shaped like its input.
    """
    # 1) generate random set of indices for dropping samples in the batch
    branges_scales = [
        get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list
    ]
    branges = [rows for rows, _ in branges_scales]
    residual_scale_factors = [scale for _, scale in branges_scales]

    # 2) get attention bias and index+concat the tensors
    attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)

    # 3) apply residual_func to get residual, and split the result
    residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))  # type: ignore

    # 4) add each residual back onto its source tensor
    return [
        add_residual(x, brange, residual, scale, scaling_vector).view_as(x)
        for x, brange, residual, scale in zip(
            x_list, branges, residual_list, residual_scale_factors
        )
    ]
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
class NestedTensorBlock(Block):
    """``Block`` that also accepts a list of tensors processed as one packed batch."""

    def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
        """
        x_list contains a list of tensors to nest together and run
        """
        assert isinstance(self.attn, MemEffAttention)

        if self.training and self.sample_drop_ratio > 0.0:
            # Layer scale is applied inside add_residual via scaling_vector,
            # so the residual functions here deliberately leave it out.

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.attn(self.norm1(x), attn_bias=attn_bias)

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.mlp(self.norm2(x))

            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls1 if isinstance(self.ls1, LayerScale) else None,
            )
            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                # Gate on ls2 itself (the original checked ls1 here).
                scaling_vector=self.ls2 if isinstance(self.ls2, LayerScale) else None,
            )
            return x_list

        def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
            return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))

        def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        attn_bias, x = get_attn_bias_and_cat(x_list)
        x = x + attn_residual_func(x, attn_bias=attn_bias)
        x = x + ffn_residual_func(x)
        return attn_bias.split(x)

    def forward(self, x_or_x_list, return_attention=False):
        """
        Adapted from https://gitlab.com/ziegleto-machine-learning/dino/-/tree/main/

        Args:
            x_or_x_list: A single tensor, or a list of tensors to run nested.
            return_attention: Only supported for a single tensor input.

        Raises:
            NotImplementedError: If ``return_attention`` is set for a list.
            AssertionError: If xFormers is unavailable for a list input, or
                the input is neither a tensor nor a list.
        """
        if isinstance(x_or_x_list, Tensor):
            return super().forward(x_or_x_list, return_attention)
        elif isinstance(x_or_x_list, list):
            if return_attention:
                raise NotImplementedError(
                    "return_attention not supported for nested tensors"
                )
            assert (
                XFORMERS_AVAILABLE
            ), "Please install xFormers for nested tensors usage"
            return self.forward_nested(x_or_x_list)
        else:
            raise AssertionError(
                f"expected a Tensor or a list of Tensors, got {type(x_or_x_list).__name__}"
            )
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
if __name__ == "__main__":
    # Pick an accelerator when one is available, otherwise use the CPU.
    if torch.cuda.is_available():
        _device = "cuda"
    elif torch.backends.mps.is_available():
        _device = "mps"
    else:
        _device = "cpu"

    # Example usage: batch size 10, sequence length 16, feature dimension 64.
    block = Block(dim=64, num_heads=8, drop_path=0.3).to(_device)
    x = torch.randn(10, 16, 64, device=_device)
    output = block(x)
    print(output.shape)  # Should be (10, 16, 64)

    # Nested usage: a list of tensors processed together.
    nested_block = NestedTensorBlock(
        dim=64, num_heads=8, attn_class=MemEffAttention
    ).to(_device)
    nested_x = [
        torch.randn(10, 16, 64, device=_device),
        torch.randn(10, 16, 64, device=_device),
    ]
    nested_output = nested_block(nested_x)
    # Should print shapes of tensors in the list
    print([o.shape for o in nested_output])
|
hf_src/layers/cva_head.py
ADDED
|
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from functools import partial
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
|
| 6 |
+
from einops import rearrange
|
| 7 |
+
from torch.nn.init import trunc_normal_
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def _make_lna_block(input_dim, output_dim, bias, norm_op, activation):
|
| 11 |
+
layers = [nn.Linear(input_dim, output_dim, bias=bias)]
|
| 12 |
+
if norm_op is not None:
|
| 13 |
+
layers.append(norm_op(output_dim))
|
| 14 |
+
if activation is not None:
|
| 15 |
+
layers.append(activation())
|
| 16 |
+
return nn.Sequential(*layers)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def _build_projector(n_layers, in_dim, out_dim, hidden_dim, activation=nn.GELU):
    """Build the projector MLP as one flat ``nn.Sequential``.

    With ``n_layers > 1`` the stack is ``n_layers - 1`` hidden
    Linear/BatchNorm/activation stages followed by a bias-free Linear and a
    BatchNorm. Otherwise it is a single bias-free Linear plus BatchNorm.

    Args:
        n_layers: Number of linear layers.
        in_dim: Input width.
        out_dim: Output width.
        hidden_dim: Width of the hidden stages.
        activation: Zero-argument activation factory.
    """
    norm_op = partial(nn.BatchNorm1d, track_running_stats=False)

    if n_layers <= 1:
        return nn.Sequential(nn.Linear(in_dim, out_dim, bias=False), norm_op(out_dim))

    # Collect modules in a plain list rather than relying on in-place
    # concatenation of nn.Sequential objects; the flat layout is unchanged.
    layers = list(_make_lna_block(in_dim, hidden_dim, True, norm_op, activation))
    for _ in range(n_layers - 2):
        layers.extend(
            _make_lna_block(hidden_dim, hidden_dim, True, norm_op, activation)
        )
    layers.append(nn.Linear(hidden_dim, out_dim, bias=False))
    layers.append(norm_op(out_dim))
    return nn.Sequential(*layers)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def _build_predictor(n_layers, in_out_dim, bottleneck_dim, activation=nn.GELU):
    """Build the bottleneck predictor MLP.

    The first Linear/BatchNorm/activation stage is kept as a nested
    ``nn.Sequential`` entry, while the modules of every later stage are added
    individually. This layout is reproduced exactly because module indices
    determine parameter names.

    Args:
        n_layers: Number of bottleneck stages.
        in_out_dim: Input and output width.
        bottleneck_dim: Width of the bottleneck stages.
        activation: Zero-argument activation factory.
    """
    norm_op = partial(nn.BatchNorm1d, track_running_stats=False)

    # Entry stage: stored as a single nested module.
    stem = _make_lna_block(in_out_dim, bottleneck_dim, True, norm_op, activation)
    modules = [stem]

    # Extra bottleneck stages: their modules are appended one by one.
    for _ in range(n_layers - 1):
        modules.extend(
            _make_lna_block(bottleneck_dim, bottleneck_dim, True, norm_op, activation)
        )

    # Output layer: bias-free linear back to the input width.
    modules.extend(_make_lna_block(bottleneck_dim, in_out_dim, False, None, None))
    return nn.Sequential(*modules)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class CVAHead(nn.Module):
    """Projector head with an optional bottleneck predictor.

    Accepts 2-D ``(B, C)``, 3-D ``(B, N, C)`` or 4-D ``(B, C, H, W)``
    latents. Higher-rank inputs are flattened to rows of channels before the
    MLPs run and reshaped back afterwards.

    Args:
        in_dim: Channel width of the input latent.
        out_dim: Channel width produced by the projector (and predictor).
        projector_layers: Number of projector layers (clamped to at least 1).
        predictor_layers: Number of predictor bottleneck stages.
        hidden_dim: Hidden width of the projector.
        bottleneck_dim: Hidden width of the predictor.
        act_op: Zero-argument activation factory.
        use_predictor: When false no predictor is created, so ``predict``
            and ``project_predict`` cannot be used.
    """

    def __init__(
        self,
        in_dim,
        out_dim=1024,
        projector_layers=3,
        predictor_layers=1,
        hidden_dim=2048,
        bottleneck_dim=256,
        act_op=nn.GELU,
        use_predictor=True,
    ):
        super().__init__()
        projector_layers = max(projector_layers, 1)

        self.projector = _build_projector(
            projector_layers,
            in_dim,
            out_dim,
            hidden_dim=hidden_dim,
            activation=act_op,
        )

        if use_predictor:
            self.predictor = _build_predictor(
                predictor_layers,
                out_dim,
                bottleneck_dim,
                activation=act_op,
            )

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal weights and zero biases for every linear layer."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    @staticmethod
    def _as_rows(latent):
        """Return ``(rows, restore)`` for a 2-D, 3-D or 4-D latent.

        ``rows`` is the latent viewed as ``(num_rows, C)`` and ``restore``
        maps a ``(num_rows, C')`` result back to the input layout.
        """
        if latent.ndim == 2:
            return latent, lambda t: t

        if latent.ndim == 3:
            # (B, N, C)
            b, n, _ = latent.shape
            return latent.flatten(0, 1), lambda t: t.unflatten(0, (b, n))

        if latent.ndim == 4:
            # spatial_latent: (B, C, H, W)
            b, _, h, w = latent.shape
            rows = rearrange(latent, "b c h w -> (b h w) c").contiguous()
            return rows, lambda t: rearrange(
                t, "(b h w) c -> b c h w", b=b, h=h, w=w
            ).contiguous()

        raise ValueError(f"{latent.ndim=}D latent input is not supported")

    def project(self, latent):
        """Run the projector only and return it in the input layout."""
        rows, restore = self._as_rows(latent)
        return restore(self.projector(rows))

    def predict(self, latent):
        """Run projector then predictor and return it in the input layout."""
        rows, restore = self._as_rows(latent)
        return restore(self.predictor(self.projector(rows)))

    def project_predict(self, latent):
        """Return ``(projection, prediction)``, both in the input layout.

        The predictor is applied to the flattened projection, so 3-D and 4-D
        latents are handled the same way as in ``predict``.
        """
        rows, restore = self._as_rows(latent)
        projected = self.projector(rows)
        predicted = self.predictor(projected)
        return restore(projected), restore(predicted)

    def forward(self, latent, project_only=False):
        """Return the projection if ``project_only``, else the prediction."""
        if project_only:
            return self.project(latent)

        return self.predict(latent)
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
class IdentityHead(torch.nn.Module):
    """Parameter-free head whose methods all return the input unchanged."""

    def __init__(self):
        super().__init__()

    def project(self, x):
        """Return ``x`` as the projection."""
        return x

    def predict(self, x):
        """Return ``x`` as the prediction."""
        return x

    def project_predict(self, x):
        """Return ``x`` twice, as ``(projection, prediction)``."""
        projected = predicted = x
        return projected, predicted

    def forward(self, x, **kwargs):
        """Return ``x``; extra keyword arguments are accepted and ignored."""
        return x
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
class CVAHeadList(torch.nn.Module):
    """A set of independently parameterised ``CVAHead`` modules, one per scale."""

    def __init__(self, num_scales=2, **params):
        super().__init__()
        heads = torch.nn.ModuleList()
        for _ in range(num_scales):
            # Every scale gets its own head built from the same keyword arguments.
            heads.append(CVAHead(**params))
        self.heads = heads

    def forward(self, x, scale_idx, project_only=False):
        """Run ``x`` through the head selected by ``scale_idx``."""
        head = self.heads[scale_idx]
        return head(x, project_only=project_only)
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
if __name__ == "__main__":
    # Smoke test: run a (batch, tokens, channels) latent through the head.
    model = CVAHead(
        768,
        512,
        hidden_dim=2048,
        bottleneck_dim=256,
        act_op=nn.GELU,
    )
    print(model)
    x = torch.randn(2, 36, 768)
    out = model(x, project_only=True)

    # A 3-D input keeps its (batch, tokens) layout; channels become out_dim.
    print("Output shape:", out.shape)  # Expected: (2, 36, 512)
    out2 = model(x, project_only=False)
    print("Output shape after prediction:", out2.shape)  # Expected: (2, 36, 512)
|
hf_src/layers/dino_head.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
from torch.nn.init import trunc_normal_
|
| 5 |
+
from torch.nn.utils.parametrizations import weight_norm
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class DINOHead(nn.Module):
    """MLP head with an optional weight-normalised output layer.

    Args:
        in_dim: Input feature width.
        out_dim: Output width of the last layer.
        use_bn: Whether the MLP uses batch normalisation.
        nlayers: Number of MLP layers (clamped to at least 1).
        hidden_dim: Hidden width of the MLP.
        bottleneck_dim: Output width of the MLP.
        mlp_bias: Bias flag forwarded to the MLP builder.
        use_last_layer: When false, the head stops after the MLP.
    """

    def __init__(
        self,
        in_dim,
        out_dim=2**16,
        use_bn=False,
        nlayers=3,
        hidden_dim=2048,
        bottleneck_dim=256,
        mlp_bias=True,
        use_last_layer=True,
    ):
        super().__init__()
        self.use_last_layer = use_last_layer

        self.mlp = _build_mlp(
            max(nlayers, 1),
            in_dim,
            bottleneck_dim,
            hidden_dim=hidden_dim,
            use_bn=use_bn,
            bias=mlp_bias,
        )

        if not use_last_layer:
            return

        self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
        # Set the first weight-norm parameter (original0) to one.
        self.last_layer.parametrizations.weight.original0.data.fill_(1)

    def init_weights(self) -> None:
        """Re-initialise every linear layer in the head."""
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal weights and zero biases for linear layers."""
        if not isinstance(m, nn.Linear):
            return
        trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

    def forward(self, x, **kwargs):
        """Apply the MLP, then L2-normalise and project if the last layer is on."""
        feats = self.mlp(x)
        if not self.use_last_layer:
            return feats

        feats = nn.functional.normalize(
            feats, dim=-1, p=2, eps=torch.finfo(feats.dtype).eps
        )
        return self.last_layer(feats)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def _build_mlp(
    nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True
):
    """Build the bottleneck MLP used by ``DINOHead``.

    With ``nlayers == 1`` a single ``nn.Linear(in_dim, bottleneck_dim)`` is
    returned. Otherwise the result is an ``nn.Sequential`` of
    ``Linear -> [BatchNorm1d] -> GELU`` stages of width ``hidden_dim``
    followed by a final ``Linear(hidden_dim, bottleneck_dim)``.

    ``bias`` applies to the hidden linears only; the final (or single) linear
    has a bias exactly when ``use_bn`` is false.
    """
    output_bias = not use_bn
    if nlayers == 1:
        return nn.Linear(in_dim, bottleneck_dim, bias=output_bias)

    def hidden_stage(fan_in):
        # One Linear -> [BatchNorm1d] -> GELU stage.
        stage = [nn.Linear(fan_in, hidden_dim, bias=bias)]
        if use_bn:
            stage.append(nn.BatchNorm1d(hidden_dim, track_running_stats=False))
        stage.append(nn.GELU())
        return stage

    modules = hidden_stage(in_dim)
    for _ in range(nlayers - 2):
        modules.extend(hidden_stage(hidden_dim))
    modules.append(nn.Linear(hidden_dim, bottleneck_dim, bias=output_bias))
    return nn.Sequential(*modules)
|
hf_src/layers/drop_path.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# References:
|
| 2 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 3 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
from torch import nn
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    """Randomly zero whole samples of ``x`` (stochastic depth).

    Args:
        x: Input tensor; the first dimension is treated as the sample axis.
        drop_prob: Probability of dropping a sample. ``None`` is treated as
            ``0.0`` because ``DropPath`` uses ``None`` as its default.
        training: The input is returned untouched when this is false.

    Returns:
        ``x`` itself when nothing is dropped, otherwise a new tensor in which
        each sample is either zeroed or scaled by ``1 / (1 - drop_prob)``.
    """
    # `not drop_prob` covers both 0.0 and None; the latter used to fail on
    # `1 - None` as soon as the module was put in training mode.
    if not drop_prob or not training:
        return x
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over all remaining dims, so it
    # works with tensors of any rank (not just 2D ConvNets).
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if keep_prob > 0.0:
        random_tensor.div_(keep_prob)
    return x * random_tensor
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class DropPath(nn.Module):
    """Per-sample stochastic depth, meant for the main path of residual blocks.

    The drop probability is stored as given (``None`` by default) and handed
    to ``drop_path`` together with the module's training flag.
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, drop_prob=self.drop_prob, training=self.training)
|
hf_src/layers/fp8_linear.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This software may be used and distributed in accordance with
|
| 4 |
+
# the terms of the DINOv3 License Agreement.
|
| 5 |
+
|
| 6 |
+
import re
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
from hf_src.utils import named_replace
|
| 11 |
+
from hf_src.layers.rope_attention import LinearKMaskedBias
|
| 12 |
+
|
| 13 |
+
# avoid division by zero when calculating scale
|
| 14 |
+
EPS = 1e-12
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def scale(t, amax_t):
    """Quantise ``t`` to ``float8_e4m3fn``.

    The scale is ``amax_t`` (clamped from below by ``EPS`` to avoid division
    by zero) divided by the largest finite fp8 value; ``t`` is divided by that
    scale before the cast.

    Returns:
        ``(t_fp8, scale_t)`` -- the quantised tensor and the float32 scale
        needed to undo the quantisation.
    """
    fp8_dtype = torch.float8_e4m3fn
    scale_t = amax_t.float().clamp(min=EPS) / torch.finfo(fp8_dtype).max
    return (t / scale_t).to(fp8_dtype), scale_t
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def matmul(first, amax_first, second_t, amax_second_t, bias):
    """Compute ``first @ second_t.T`` (+ ``bias``) through an fp8 matmul.

    Both operands are quantised with ``scale``; the product is produced in
    bfloat16 and the quantisation scales are multiplied back in afterwards.
    Judging by the callers in this file, the ``amax_*`` arguments hold one
    absolute maximum per row (``keepdim=True``) of the matching operand.
    """
    lhs_fp8, lhs_scale = scale(first, amax_first)
    rhs_t_fp8, rhs_t_scale = scale(second_t, amax_second_t)
    # PyTorch's row-wise scaled matmul kernel is based on CUTLASS and is quite
    # slow. Hence we fall back to an "unscaled" matmul (unit scales), which
    # uses cuBLAS, and apply the real scales manually afterwards.
    unit_scale_lhs = lhs_scale.new_ones((1, 1))
    unit_scale_rhs = rhs_t_scale.t().new_ones((1, 1))
    product = torch._scaled_mm(
        lhs_fp8,
        rhs_t_fp8.t(),
        scale_a=unit_scale_lhs,
        scale_b=unit_scale_rhs,
        bias=None,
        out_dtype=torch.bfloat16,
        use_fast_accum=False,
    )
    rescaled = (product * lhs_scale * rhs_t_scale.t()).to(torch.bfloat16)
    if bias is None:
        return rescaled
    return rescaled + bias
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
@torch.compiler.allow_in_graph
class Fp8LinearFn(torch.autograd.Function):
    """Autograd function for a linear layer whose forward matmul, and the
    input-gradient matmul, go through the fp8 ``matmul`` helper.

    ``a`` is the flattened 2-D input, ``b_t`` the weight as stored by
    ``nn.Linear`` and ``bias`` an optional bias.
    """

    @staticmethod
    def forward(ctx, a, b_t, bias):
        # Per-row absolute maxima drive the fp8 scales.
        row_amax_a = a.abs().amax(dim=-1, keepdim=True)
        row_amax_b_t = b_t.abs().amax(dim=-1, keepdim=True)
        result = matmul(a, row_amax_a, b_t, row_amax_b_t, bias)

        ctx.a_requires_grad = a.requires_grad
        ctx.b_requires_grad = b_t.requires_grad
        ctx.bias_requires_grad = bias is not None and bias.requires_grad

        # Only the global weight maximum is kept for the backward pass.
        ctx.save_for_backward(a, b_t, row_amax_b_t.max())
        return result

    @staticmethod
    def backward(ctx, grad_out):
        a, b_t, amax_b = ctx.saved_tensors
        grad_a = grad_b = grad_bias = None

        if ctx.a_requires_grad:
            b = b_t.t().contiguous()
            row_amax_grad = grad_out.abs().amax(dim=-1, keepdim=True)
            # Reuse the single saved weight maximum for every row of b.
            grad_a = matmul(
                grad_out, row_amax_grad, b, amax_b.repeat(b.shape[0], 1), None
            )
        if ctx.b_requires_grad:
            # The weight gradient uses a regular (non-fp8) matmul.
            grad_b = grad_out.t() @ a
        if ctx.bias_requires_grad:
            grad_bias = grad_out.sum(dim=0)

        return grad_a, grad_b, grad_bias
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class Fp8Linear(torch.nn.Linear):
    """``nn.Linear`` whose forward pass is computed by ``Fp8LinearFn``."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Collapse all leading dims into one, run the 2-D fp8 linear, then
        # restore the leading dims.
        leading_shape = input.shape[:-1]
        rows = input.flatten(end_dim=-2)
        return Fp8LinearFn.apply(rows, self.weight, self.bias).unflatten(
            0, leading_shape
        )
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class Fp8LinearKMaskedBias(LinearKMaskedBias):
    """``LinearKMaskedBias`` whose forward pass is computed by ``Fp8LinearFn``."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        masked_bias = None
        if self.bias is not None:
            # Element-wise mask applied to the bias before it is added.
            masked_bias = self.bias * self.bias_mask
        rows = input.flatten(end_dim=-2)
        out = Fp8LinearFn.apply(rows, self.weight, masked_bias)
        return out.unflatten(0, input.shape[:-1])
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def convert_linears_to_fp8(
    root_module: torch.nn.Module, *, filter: str
) -> torch.nn.Module:
    """Replace matching linear layers under ``root_module`` with fp8 variants.

    Args:
        root_module: Module tree handed to ``named_replace``.
        filter: Regular expression; only linear layers whose name matches
            (``re.search``) are converted.

    Returns:
        Whatever ``named_replace`` returns for ``root_module``.

    Raises:
        AssertionError: A matching layer is an unsupported ``nn.Linear``
            subclass, or no layer was converted at all.
        RuntimeError: A matching layer has a dimension that is not a multiple
            of 64.
    """
    filter_re = re.compile(filter)
    # Exact-type dispatch: other nn.Linear subclasses must not be converted
    # silently.
    fp8_classes = {
        torch.nn.Linear: Fp8Linear,
        LinearKMaskedBias: Fp8LinearKMaskedBias,
    }
    total_count = 0

    def replace(module: torch.nn.Module, name: str) -> torch.nn.Module:
        nonlocal total_count
        if not isinstance(module, torch.nn.Linear) or not filter_re.search(name):
            return module
        new_cls = fp8_classes.get(type(module))
        if new_cls is None:
            # Raised explicitly so the check survives `python -O`.
            raise AssertionError(str(type(module)))
        if module.in_features % 64 != 0 or module.out_features % 64 != 0:
            # This is not a strict requirement, but H100 TensorCores for fp8
            # operate on tiles of 64 elements anyways, and Inductor sometimes
            # pads inner dims to become multiples of 64. Also, if one day we
            # switch back to cuBLAS, it artificially requires dims to be
            # multiples of 16.
            raise RuntimeError(
                "fp8 requires all dimensions to be multiples of 64 "
                "(consider using ffn_layer=swiglu64 or higher)"
            )
        new_module = new_cls(
            in_features=module.in_features,
            out_features=module.out_features,
            bias=module.bias is not None,
            dtype=module.weight.dtype,
            device=module.weight.device,
        )
        # Share the existing parameters instead of the freshly created ones.
        new_module.weight = module.weight
        new_module.bias = module.bias
        # A new LinearKMaskedBias starts with a NaN-filled `bias_mask`; carry
        # over the source layer's mask so converting an already initialised
        # model does not turn the bias into NaN.
        bias_mask = getattr(module, "bias_mask", None)
        if bias_mask is not None:
            new_module.bias_mask = bias_mask
        total_count += 1
        return new_module

    out = named_replace(replace, root_module)
    if total_count <= 0:
        raise AssertionError("fp8: no layer found to convert")
    # Force re-compile everything
    torch._dynamo.reset_code_caches()
    from torch._inductor.cudagraph_trees import reset_cudagraph_trees

    reset_cudagraph_trees()
    return out
|
hf_src/layers/layer_scale.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
|
| 2 |
+
|
| 3 |
+
from typing import Union
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
from torch import nn
|
| 8 |
+
from torch import Tensor
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class LayerScale(nn.Module):
    """Multiply the input by a learnable per-channel vector ``gamma``.

    ``gamma`` is allocated with ``torch.empty`` and therefore holds arbitrary
    values until ``reset_parameters`` is called or the parameter is
    overwritten (for example by loading a state dict).

    Args:
        dim: Length of ``gamma``.
        init_values: Value written into ``gamma`` by ``reset_parameters``.
        inplace: Multiply the input in place instead of allocating an output.
        device: Device on which ``gamma`` is allocated.
    """

    def __init__(
        self,
        dim: int,
        init_values: Union[float, Tensor] = 1e-5,
        inplace: bool = False,
        device=None,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(torch.empty(dim, device=device))
        self.init_values = init_values

    def reset_parameters(self):
        """Fill ``gamma`` with ``init_values``."""
        nn.init.constant_(self.gamma, self.init_values)

    def forward(self, x: Tensor) -> Tensor:
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
|
hf_src/layers/mlp.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# References:
|
| 2 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 3 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
from typing import Callable, List, Optional
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
from torch import Tensor, nn
|
| 10 |
+
|
| 11 |
+
from hf_src.utils import cat_keep_shapes, uncat_with_shapes
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class ListForwardMixin:
    """Mixin adding ``forward_list`` on top of a subclass-provided ``forward``."""

    def forward(self, x: Tensor):
        raise NotImplementedError

    def forward_list(self, x_list: List[Tensor]) -> List[Tensor]:
        """Run ``forward`` once over all tensors in ``x_list``.

        The tensors are merged by ``cat_keep_shapes``, pushed through
        ``forward`` in a single call and split back by ``uncat_with_shapes``.
        """
        merged, shapes, num_tokens = cat_keep_shapes(x_list)
        return uncat_with_shapes(self.forward(merged), shapes, num_tokens)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class Mlp(nn.Module, ListForwardMixin):
    """Two-layer feed-forward block: ``fc1 -> act -> drop -> fc2 -> drop``.

    Args:
        in_features: Input width.
        hidden_features: Width of the hidden layer; falls back to
            ``in_features`` when not given.
        out_features: Output width; falls back to ``in_features`` when not given.
        act_layer: Factory for the activation module.
        drop: Dropout probability, used after the activation and after ``fc2``.
        bias: Whether both linear layers have a bias.
        device: Device on which the linear layers are created.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        drop: float = 0.0,
        bias: bool = True,
        device=None,
    ) -> None:
        super().__init__()
        width_out = out_features if out_features else in_features
        width_hidden = hidden_features if hidden_features else in_features
        self.fc1 = nn.Linear(in_features, width_hidden, bias=bias, device=device)
        self.act = act_layer()
        self.fc2 = nn.Linear(width_hidden, width_out, bias=bias, device=device)
        # A single dropout module is shared by both dropout sites.
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
|
hf_src/layers/patch_embed.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# References:
|
| 2 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 3 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
|
| 4 |
+
|
| 5 |
+
import math
|
| 6 |
+
from typing import Callable, Tuple, Union
|
| 7 |
+
|
| 8 |
+
from torch import Tensor, nn
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def make_2tuple(x):
    """Normalise ``x`` to a 2-tuple.

    An ``int`` is duplicated; a tuple of length 2 is returned as is; a list of
    length 2 is converted to a tuple (tuples come back as lists after a JSON
    round-trip of a config).

    Raises:
        AssertionError: ``x`` is a sequence whose length is not 2, or is
            neither an ``int``, a tuple nor a list.
    """
    if isinstance(x, (tuple, list)):
        # Raised explicitly so the validation survives `python -O`.
        if len(x) != 2:
            raise AssertionError(f"expected 2 elements, got {len(x)}")
        return x if isinstance(x, tuple) else tuple(x)

    if not isinstance(x, int):
        raise AssertionError(f"expected an int or a pair, got {type(x).__name__}")
    return (x, x)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class PatchEmbed(nn.Module):
|
| 21 |
+
"""
|
| 22 |
+
2D image to patch embedding: (B,C,H,W) -> (B,N,D)
|
| 23 |
+
|
| 24 |
+
Args:
|
| 25 |
+
img_size: Image size.
|
| 26 |
+
patch_size: Patch token size.
|
| 27 |
+
in_chans: Number of input image channels.
|
| 28 |
+
embed_dim: Number of linear projection output channels.
|
| 29 |
+
norm_layer: Normalization layer.
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
def __init__(
|
| 33 |
+
self,
|
| 34 |
+
img_size: Union[int, Tuple[int, int]] = 224,
|
| 35 |
+
patch_size: Union[int, Tuple[int, int]] = 16,
|
| 36 |
+
in_chans: int = 3,
|
| 37 |
+
embed_dim: int = 768,
|
| 38 |
+
norm_layer: Callable | None = None,
|
| 39 |
+
flatten_embedding: bool = True,
|
| 40 |
+
) -> None:
|
| 41 |
+
super().__init__()
|
| 42 |
+
|
| 43 |
+
image_HW = make_2tuple(img_size)
|
| 44 |
+
patch_HW = make_2tuple(patch_size)
|
| 45 |
+
patch_grid_size = (
|
| 46 |
+
image_HW[0] // patch_HW[0],
|
| 47 |
+
image_HW[1] // patch_HW[1],
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
self.img_size = image_HW
|
| 51 |
+
self.patch_size = patch_HW
|
| 52 |
+
self.patches_resolution = patch_grid_size
|
| 53 |
+
self.num_patches = patch_grid_size[0] * patch_grid_size[1]
|
| 54 |
+
|
| 55 |
+
self.in_chans = in_chans
|
| 56 |
+
self.embed_dim = embed_dim
|
| 57 |
+
|
| 58 |
+
self.flatten_embedding = flatten_embedding
|
| 59 |
+
|
| 60 |
+
self.proj = nn.Conv2d(
|
| 61 |
+
in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW
|
| 62 |
+
)
|
| 63 |
+
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
| 64 |
+
|
| 65 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 66 |
+
_, _, H, W = x.shape
|
| 67 |
+
# patch_H, patch_W = self.patch_size
|
| 68 |
+
# assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
|
| 69 |
+
# assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
|
| 70 |
+
|
| 71 |
+
x = self.proj(x) # B C H W
|
| 72 |
+
H, W = x.size(2), x.size(3)
|
| 73 |
+
x = x.flatten(2).transpose(1, 2) # B HW C
|
| 74 |
+
x = self.norm(x)
|
| 75 |
+
if not self.flatten_embedding:
|
| 76 |
+
x = x.reshape(-1, H, W, self.embed_dim) # B H W C
|
| 77 |
+
return x
|
| 78 |
+
|
| 79 |
+
def flops(self) -> float:
|
| 80 |
+
Ho, Wo = self.patches_resolution
|
| 81 |
+
flops = (
|
| 82 |
+
Ho
|
| 83 |
+
* Wo
|
| 84 |
+
* self.embed_dim
|
| 85 |
+
* self.in_chans
|
| 86 |
+
* (self.patch_size[0] * self.patch_size[1])
|
| 87 |
+
)
|
| 88 |
+
if self.norm is not None:
|
| 89 |
+
flops += Ho * Wo * self.embed_dim
|
| 90 |
+
return flops
|
| 91 |
+
|
| 92 |
+
def reset_parameters(self):
|
| 93 |
+
k = 1 / (self.in_chans * (self.patch_size[0] ** 2))
|
| 94 |
+
nn.init.uniform_(self.proj.weight, -math.sqrt(k), math.sqrt(k))
|
| 95 |
+
if self.proj.bias is not None:
|
| 96 |
+
nn.init.uniform_(self.proj.bias, -math.sqrt(k), math.sqrt(k))
|
hf_src/layers/rms_norm.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This software may be used and distributed in accordance with
|
| 4 |
+
# the terms of the DINOv3 License Agreement.
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from torch import Tensor, nn
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learnable
    per-channel gain.

    Args:
        dim: Size of the last dimension (length of ``weight``).
        eps: Added to the mean square before the reciprocal square root.
    """

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def reset_parameters(self) -> None:
        """Reset the gain to one."""
        nn.init.ones_(self.weight)

    def _norm(self, x: Tensor) -> Tensor:
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x: Tensor) -> Tensor:
        # Normalise in float32, cast back to the input dtype, then apply the gain.
        normed = self._norm(x.float())
        return normed.type_as(x) * self.weight
|
hf_src/layers/rope_attention.py
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This software may be used and distributed in accordance with
|
| 4 |
+
# the terms of the DINOv3 License Agreement.
|
| 5 |
+
|
| 6 |
+
import math
|
| 7 |
+
from typing import List, Tuple
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
|
| 12 |
+
from torch import Tensor, nn
|
| 13 |
+
|
| 14 |
+
from hf_src.utils import cat_keep_shapes, uncat_with_shapes
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# RoPE-related functions:
|
| 18 |
+
def rope_rotate_half(x: Tensor) -> Tensor:
    """Swap the two halves of the last dimension, negating the second half.

    x:   [ x0  x1  x2  x3  x4  x5]
    out: [-x3 -x4 -x5  x0  x1  x2]
    """
    first_half, second_half = torch.chunk(x, 2, dim=-1)
    return torch.cat((second_half.neg(), first_half), dim=-1)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def rope_apply(x: Tensor, sin: Tensor, cos: Tensor) -> Tensor:
    """Apply a rotary position embedding to ``x``.

    x:   [..., D], eg [x0,     x1,   x2,   x3,   x4,   x5]
    sin: [..., D], eg [sin0, sin1, sin2, sin0, sin1, sin2]
    cos: [..., D], eg [cos0, cos1, cos2, cos0, cos1, cos2]

    Computes ``x * cos + rotate_half(x) * sin`` where ``rotate_half`` swaps
    the two halves of the last dimension and negates the second one.
    """
    lower, upper = x.chunk(2, dim=-1)
    rotated = torch.cat([-upper, lower], dim=-1)
    return (x * cos) + (rotated * sin)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class LinearKMaskedBias(nn.Linear):
    """``nn.Linear`` whose bias is multiplied element-wise by a ``bias_mask``
    buffer before being added.

    ``out_features`` must be divisible by 3. The buffer only exists when the
    layer has a bias and is created filled with NaN, so it has to be
    overwritten (or loaded) before the layer gives meaningful output.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert self.out_features % 3 == 0
        if self.bias is None:
            return
        nan_mask = torch.full_like(self.bias, fill_value=math.nan)
        self.register_buffer("bias_mask", nan_mask)

    def forward(self, input: Tensor) -> Tensor:
        if self.bias is None:
            masked_bias = None
        else:
            # Cast the mask so the product keeps the bias dtype.
            masked_bias = self.bias * self.bias_mask.to(self.bias.dtype)
        return F.linear(input, self.weight, masked_bias)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class SelfAttention(nn.Module):
|
| 52 |
+
def __init__(
|
| 53 |
+
self,
|
| 54 |
+
dim: int,
|
| 55 |
+
num_heads: int = 8,
|
| 56 |
+
qkv_bias: bool = False,
|
| 57 |
+
proj_bias: bool = True,
|
| 58 |
+
attn_drop: float = 0.0,
|
| 59 |
+
proj_drop: float = 0.0,
|
| 60 |
+
mask_k_bias: bool = False,
|
| 61 |
+
device=None,
|
| 62 |
+
) -> None:
|
| 63 |
+
super().__init__()
|
| 64 |
+
self.num_heads = num_heads
|
| 65 |
+
head_dim = dim // num_heads
|
| 66 |
+
self.scale = head_dim**-0.5
|
| 67 |
+
|
| 68 |
+
linear_class = LinearKMaskedBias if mask_k_bias else nn.Linear
|
| 69 |
+
self.qkv = linear_class(dim, dim * 3, bias=qkv_bias, device=device)
|
| 70 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
| 71 |
+
self.proj = nn.Linear(dim, dim, bias=proj_bias, device=device)
|
| 72 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 73 |
+
|
| 74 |
+
def apply_rope(
|
| 75 |
+
self, q: Tensor, k: Tensor, rope: Tensor | Tuple[Tensor, Tensor]
|
| 76 |
+
) -> Tuple[Tensor, Tensor]:
|
| 77 |
+
# All operations will use the dtype of rope, the output is cast back to the dtype of q and k
|
| 78 |
+
q_dtype = q.dtype
|
| 79 |
+
k_dtype = k.dtype
|
| 80 |
+
sin, cos = rope
|
| 81 |
+
rope_dtype = sin.dtype
|
| 82 |
+
q = q.to(dtype=rope_dtype)
|
| 83 |
+
k = k.to(dtype=rope_dtype)
|
| 84 |
+
N = q.shape[-2]
|
| 85 |
+
prefix = N - sin.shape[-2]
|
| 86 |
+
assert prefix >= 0
|
| 87 |
+
q_prefix = q[:, :, :prefix, :]
|
| 88 |
+
q = rope_apply(q[:, :, prefix:, :], sin, cos) # [B, head, hw, D//head]
|
| 89 |
+
q = torch.cat((q_prefix, q), dim=-2) # [B, head, N, D//head]
|
| 90 |
+
k_prefix = k[:, :, :prefix, :]
|
| 91 |
+
k = rope_apply(k[:, :, prefix:, :], sin, cos) # [B, head, hw, D//head]
|
| 92 |
+
k = torch.cat((k_prefix, k), dim=-2) # [B, head, N, D//head]
|
| 93 |
+
q = q.to(dtype=q_dtype)
|
| 94 |
+
k = k.to(dtype=k_dtype)
|
| 95 |
+
return q, k
|
| 96 |
+
|
| 97 |
+
def forward(self, x: Tensor, attn_bias=None, rope: Tensor = None) -> Tensor:
|
| 98 |
+
qkv = self.qkv(x)
|
| 99 |
+
attn_v = self.compute_attention(qkv=qkv, attn_bias=attn_bias, rope=rope)
|
| 100 |
+
x = self.proj(attn_v)
|
| 101 |
+
x = self.proj_drop(x)
|
| 102 |
+
return x
|
| 103 |
+
|
| 104 |
+
def forward_list(self, x_list, attn_bias=None, rope_list=None) -> List[Tensor]:
|
| 105 |
+
assert len(x_list) == len(rope_list) # should be enforced by the Block
|
| 106 |
+
x_flat, shapes, num_tokens = cat_keep_shapes(x_list)
|
| 107 |
+
qkv_flat = self.qkv(x_flat)
|
| 108 |
+
qkv_list = uncat_with_shapes(qkv_flat, shapes, num_tokens)
|
| 109 |
+
att_out = []
|
| 110 |
+
for _, (qkv, _, rope) in enumerate(zip(qkv_list, shapes, rope_list)):
|
| 111 |
+
att_out.append(self.compute_attention(qkv, attn_bias=attn_bias, rope=rope))
|
| 112 |
+
x_flat, shapes, num_tokens = cat_keep_shapes(att_out)
|
| 113 |
+
x_flat = self.proj(x_flat)
|
| 114 |
+
return uncat_with_shapes(x_flat, shapes, num_tokens)
|
| 115 |
+
|
| 116 |
+
def compute_attention(self, qkv: Tensor, attn_bias=None, rope=None) -> Tensor:
|
| 117 |
+
assert attn_bias is None
|
| 118 |
+
B, N, _ = qkv.shape
|
| 119 |
+
C = self.qkv.in_features
|
| 120 |
+
|
| 121 |
+
qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads)
|
| 122 |
+
q, k, v = torch.unbind(qkv, 2)
|
| 123 |
+
q, k, v = [t.transpose(1, 2) for t in [q, k, v]]
|
| 124 |
+
if rope is not None:
|
| 125 |
+
q, k = self.apply_rope(q, k, rope)
|
| 126 |
+
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
| 127 |
+
x = x.transpose(1, 2)
|
| 128 |
+
return x.reshape([B, N, C])
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
class CausalSelfAttention(nn.Module):
|
| 132 |
+
def __init__(
|
| 133 |
+
self,
|
| 134 |
+
dim: int,
|
| 135 |
+
num_heads: int = 8,
|
| 136 |
+
qkv_bias: bool = False,
|
| 137 |
+
proj_bias: bool = True,
|
| 138 |
+
attn_drop: float = 0.0,
|
| 139 |
+
proj_drop: float = 0.0,
|
| 140 |
+
) -> None:
|
| 141 |
+
super().__init__()
|
| 142 |
+
self.dim = dim
|
| 143 |
+
self.num_heads = num_heads
|
| 144 |
+
head_dim = dim // num_heads
|
| 145 |
+
self.scale = head_dim**-0.5
|
| 146 |
+
|
| 147 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
| 148 |
+
self.attn_drop = attn_drop
|
| 149 |
+
self.proj = nn.Linear(dim, dim, bias=proj_bias)
|
| 150 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 151 |
+
|
| 152 |
+
def init_weights(
|
| 153 |
+
self,
|
| 154 |
+
init_attn_std: float | None = None,
|
| 155 |
+
init_proj_std: float | None = None,
|
| 156 |
+
factor: float = 1.0,
|
| 157 |
+
) -> None:
|
| 158 |
+
init_attn_std = init_attn_std or (self.dim**-0.5)
|
| 159 |
+
init_proj_std = init_proj_std or init_attn_std * factor
|
| 160 |
+
nn.init.normal_(self.qkv.weight, std=init_attn_std)
|
| 161 |
+
nn.init.normal_(self.proj.weight, std=init_proj_std)
|
| 162 |
+
if self.qkv.bias is not None:
|
| 163 |
+
nn.init.zeros_(self.qkv.bias)
|
| 164 |
+
if self.proj.bias is not None:
|
| 165 |
+
nn.init.zeros_(self.proj.bias)
|
| 166 |
+
|
| 167 |
+
def forward(self, x: Tensor, is_causal: bool = True) -> Tensor:
|
| 168 |
+
B, N, C = x.shape
|
| 169 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
|
| 170 |
+
q, k, v = torch.unbind(qkv, 2)
|
| 171 |
+
q, k, v = [t.transpose(1, 2) for t in [q, k, v]]
|
| 172 |
+
x = torch.nn.functional.scaled_dot_product_attention(
|
| 173 |
+
q,
|
| 174 |
+
k,
|
| 175 |
+
v,
|
| 176 |
+
attn_mask=None,
|
| 177 |
+
dropout_p=self.attn_drop if self.training else 0,
|
| 178 |
+
is_causal=is_causal,
|
| 179 |
+
)
|
| 180 |
+
x = x.transpose(1, 2).contiguous().view(B, N, C)
|
| 181 |
+
x = self.proj_drop(self.proj(x))
|
| 182 |
+
return x
|
hf_src/layers/rope_block.py
ADDED
|
@@ -0,0 +1,299 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This software may be used and distributed in accordance with
|
| 4 |
+
# the terms of the DINOv3 License Agreement.
|
| 5 |
+
|
| 6 |
+
from typing import Callable, List, Optional
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from torch import Tensor, nn
|
| 10 |
+
|
| 11 |
+
from hf_src.utils import cat_keep_shapes, uncat_with_shapes
|
| 12 |
+
|
| 13 |
+
from .mlp import Mlp
|
| 14 |
+
from .layer_scale import LayerScale # , DropPath
|
| 15 |
+
from .rope_attention import CausalSelfAttention, SelfAttention
|
| 16 |
+
|
| 17 |
+
torch._dynamo.config.automatic_dynamic_shapes = False
|
| 18 |
+
torch._dynamo.config.accumulated_cache_size_limit = 1024
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class SelfAttentionBlock(nn.Module):
|
| 22 |
+
def __init__(
|
| 23 |
+
self,
|
| 24 |
+
dim: int,
|
| 25 |
+
num_heads: int,
|
| 26 |
+
ffn_ratio: float = 4.0,
|
| 27 |
+
qkv_bias: bool = False,
|
| 28 |
+
proj_bias: bool = True,
|
| 29 |
+
ffn_bias: bool = True,
|
| 30 |
+
drop: float = 0.0,
|
| 31 |
+
attn_drop: float = 0.0,
|
| 32 |
+
init_values=None,
|
| 33 |
+
drop_path: float = 0.0,
|
| 34 |
+
act_layer: Callable[..., nn.Module] = nn.GELU,
|
| 35 |
+
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
|
| 36 |
+
attn_class: Callable[..., nn.Module] = SelfAttention,
|
| 37 |
+
ffn_layer: Callable[..., nn.Module] = Mlp,
|
| 38 |
+
mask_k_bias: bool = False,
|
| 39 |
+
device=None,
|
| 40 |
+
) -> None:
|
| 41 |
+
super().__init__()
|
| 42 |
+
# print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
|
| 43 |
+
self.norm1 = norm_layer(dim)
|
| 44 |
+
self.attn = attn_class(
|
| 45 |
+
dim,
|
| 46 |
+
num_heads=num_heads,
|
| 47 |
+
qkv_bias=qkv_bias,
|
| 48 |
+
proj_bias=proj_bias,
|
| 49 |
+
attn_drop=attn_drop,
|
| 50 |
+
proj_drop=drop,
|
| 51 |
+
mask_k_bias=mask_k_bias,
|
| 52 |
+
device=device,
|
| 53 |
+
)
|
| 54 |
+
self.ls1 = (
|
| 55 |
+
LayerScale(dim, init_values=init_values, device=device)
|
| 56 |
+
if init_values
|
| 57 |
+
else nn.Identity()
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
self.norm2 = norm_layer(dim)
|
| 61 |
+
mlp_hidden_dim = int(dim * ffn_ratio)
|
| 62 |
+
self.mlp = ffn_layer(
|
| 63 |
+
in_features=dim,
|
| 64 |
+
hidden_features=mlp_hidden_dim,
|
| 65 |
+
act_layer=act_layer,
|
| 66 |
+
drop=drop,
|
| 67 |
+
bias=ffn_bias,
|
| 68 |
+
device=device,
|
| 69 |
+
)
|
| 70 |
+
self.ls2 = (
|
| 71 |
+
LayerScale(dim, init_values=init_values, device=device)
|
| 72 |
+
if init_values
|
| 73 |
+
else nn.Identity()
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
self.sample_drop_ratio = drop_path
|
| 77 |
+
|
| 78 |
+
@staticmethod
|
| 79 |
+
def _maybe_index_rope(
|
| 80 |
+
rope: tuple[Tensor, Tensor] | None, indices: Tensor
|
| 81 |
+
) -> tuple[Tensor, Tensor] | None:
|
| 82 |
+
if rope is None:
|
| 83 |
+
return None
|
| 84 |
+
|
| 85 |
+
sin, cos = rope
|
| 86 |
+
assert sin.ndim == cos.ndim
|
| 87 |
+
if sin.ndim == 4:
|
| 88 |
+
# If the rope embedding has a batch dimension (is different for each batch element), index into it
|
| 89 |
+
return sin[indices], cos[indices] # [batch, heads, patches, embed_dim]
|
| 90 |
+
else:
|
| 91 |
+
# No batch dimension, do not index
|
| 92 |
+
return sin, cos # [heads, patches, embed_dim] or [patches, embed_dim]
|
| 93 |
+
|
| 94 |
+
def _forward(self, x: Tensor, rope=None) -> Tensor:
    """
    Single-tensor reference implementation, mirroring the list operator.
    `forward` does not call this; it runs the list op on [x] instead.

    In training mode with a positive `sample_drop_ratio`, the attention and
    FFN residuals are each computed on an independently drawn random subset
    of the batch and added back with `torch.index_add`, scaled by
    `batch / subset_size`. Otherwise the plain residual form is used.
    """
    batch_size, _, _ = x.shape
    kept = max(int(batch_size * (1 - self.sample_drop_ratio)), 1)
    scale = batch_size / kept

    if not (self.training and self.sample_drop_ratio > 0.0):
        after_attn = x + self.ls1(self.attn(self.norm1(x), rope=rope))
        return after_attn + self.ls2(self.mlp(self.norm2(after_attn)))

    # Attention branch: residual computed on a random subset of samples.
    attn_idx = torch.randperm(batch_size, device=x.device)[:kept]
    attn_input = x[attn_idx]
    attn_rope = self._maybe_index_rope(rope, attn_idx)
    attn_residual = self.attn(self.norm1(attn_input), rope=attn_rope)
    after_attn = torch.index_add(
        x,
        dim=0,
        source=self.ls1(attn_residual),
        index=attn_idx,
        alpha=scale,
    )

    # FFN branch: a second, independently drawn subset.
    ffn_idx = torch.randperm(batch_size, device=x.device)[:kept]
    ffn_residual = self.mlp(self.norm2(after_attn[ffn_idx]))
    return torch.index_add(
        after_attn,
        dim=0,
        source=self.ls2(ffn_residual),
        index=ffn_idx,
        alpha=scale,
    )
|
| 135 |
+
|
| 136 |
+
def _forward_list(self, x_list: List[Tensor], rope_list=None) -> List[Tensor]:
    """
    List operator: processes several token tensors in one call.

    In the sample-drop (training) path the selected tokens of all inputs are
    concatenated before the norm layers to save on elementwise operations;
    torch-compile memory planning is expected to hide the concat overhead.

    Args:
        x_list: input tensors; the first dimension of each is its batch size.
        rope_list: one RoPE entry (a (sin, cos) pair or None) per input, or
            None to run every input without RoPE.

    Returns:
        A list with one output tensor per input, in the same order.
    """
    if not (self.training and self.sample_drop_ratio > 0.0):
        # Plain residual path. `rope_list` defaults to None, which used to
        # make the zip() below raise TypeError; treat it as "no rope".
        if rope_list is None:
            rope_list = [None] * len(x_list)
        x_out = []
        for x, rope in zip(x_list, rope_list):
            x_attn = x + self.ls1(self.attn(self.norm1(x), rope=rope))
            x_out.append(x_attn + self.ls2(self.mlp(self.norm2(x_attn))))
        return x_out

    b_list = [x.shape[0] for x in x_list]
    sample_subset_sizes = [
        max(int(b * (1 - self.sample_drop_ratio)), 1) for b in b_list
    ]
    # Residuals are rescaled by batch / subset size when added back.
    residual_scale_factors = [
        b / sample_subset_size
        for b, sample_subset_size in zip(b_list, sample_subset_sizes)
    ]

    # --- attention branch on a random subset of each input ---
    indices_1_list = [
        (torch.randperm(b, device=x.device))[:sample_subset_size]
        for x, b, sample_subset_size in zip(x_list, b_list, sample_subset_sizes)
    ]
    x_subset_1_list = [x[indices_1] for x, indices_1 in zip(x_list, indices_1_list)]

    if rope_list is not None:
        rope_subset_list = [
            self._maybe_index_rope(rope, indices_1)
            for rope, indices_1 in zip(rope_list, indices_1_list)
        ]
    else:
        rope_subset_list = rope_list

    # NOTE(review): cat_keep_shapes / uncat_with_shapes are defined elsewhere;
    # presumably they flatten the list into one tensor and restore it.
    flattened, shapes, num_tokens = cat_keep_shapes(x_subset_1_list)
    norm1 = uncat_with_shapes(self.norm1(flattened), shapes, num_tokens)
    residual_1_list = self.attn.forward_list(norm1, rope_list=rope_subset_list)

    x_attn_list = [
        torch.index_add(
            x,
            dim=0,
            source=self.ls1(residual_1),
            index=indices_1,
            alpha=residual_scale_factor,
        )
        for x, residual_1, indices_1, residual_scale_factor in zip(
            x_list, residual_1_list, indices_1_list, residual_scale_factors
        )
    ]

    # --- FFN branch on a second, independently drawn subset ---
    indices_2_list = [
        (torch.randperm(b, device=x.device))[:sample_subset_size]
        for x, b, sample_subset_size in zip(x_list, b_list, sample_subset_sizes)
    ]
    x_subset_2_list = [
        x[indices_2] for x, indices_2 in zip(x_attn_list, indices_2_list)
    ]
    flattened, shapes, num_tokens = cat_keep_shapes(x_subset_2_list)
    norm2_flat = self.norm2(flattened)
    norm2_list = uncat_with_shapes(norm2_flat, shapes, num_tokens)

    residual_2_list = self.mlp.forward_list(norm2_list)

    return [
        torch.index_add(
            x_attn,
            dim=0,
            source=self.ls2(residual_2),
            index=indices_2,
            alpha=residual_scale_factor,
        )
        for x_attn, residual_2, indices_2, residual_scale_factor in zip(
            x_attn_list, residual_2_list, indices_2_list, residual_scale_factors
        )
    ]
|
| 219 |
+
|
| 220 |
+
def forward(self, x_or_x_list, rope_or_rope_list=None) -> List[Tensor]:
    """Run the block on a single tensor or on a list of tensors.

    A tensor input is wrapped into a one-element list (together with its
    rope entry), sent through `_forward_list` so both input kinds share one
    implementation, and the single result is unwrapped again. A list input
    is forwarded as is; a missing rope list becomes one `None` per input.
    Any other input type raises AssertionError.
    """
    if isinstance(x_or_x_list, Tensor):
        # Route through the list op to match implementations; the direct
        # equivalent would be self._forward(x, rope=rope).
        outputs = self._forward_list([x_or_x_list], rope_list=[rope_or_rope_list])
        return outputs[0]

    if isinstance(x_or_x_list, list):
        ropes = rope_or_rope_list
        if ropes is None:
            ropes = [None] * len(x_or_x_list)
        return self._forward_list(x_or_x_list, rope_list=ropes)

    raise AssertionError
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
class CausalSelfAttentionBlock(nn.Module):
|
| 236 |
+
def __init__(
|
| 237 |
+
self,
|
| 238 |
+
dim: int,
|
| 239 |
+
num_heads: int,
|
| 240 |
+
ffn_ratio: float = 4.0,
|
| 241 |
+
ls_init_value: Optional[float] = None,
|
| 242 |
+
is_causal: bool = True,
|
| 243 |
+
act_layer: Callable = nn.GELU,
|
| 244 |
+
norm_layer: Callable = nn.LayerNorm,
|
| 245 |
+
dropout_prob: float = 0.0,
|
| 246 |
+
):
|
| 247 |
+
super().__init__()
|
| 248 |
+
|
| 249 |
+
self.dim = dim
|
| 250 |
+
self.is_causal = is_causal
|
| 251 |
+
self.ls1 = (
|
| 252 |
+
LayerScale(dim, init_values=ls_init_value)
|
| 253 |
+
if ls_init_value
|
| 254 |
+
else nn.Identity()
|
| 255 |
+
)
|
| 256 |
+
self.attention_norm = norm_layer(dim)
|
| 257 |
+
self.attention = CausalSelfAttention(
|
| 258 |
+
dim, num_heads, attn_drop=dropout_prob, proj_drop=dropout_prob
|
| 259 |
+
)
|
| 260 |
+
|
| 261 |
+
self.ffn_norm = norm_layer(dim)
|
| 262 |
+
ffn_hidden_dim = int(dim * ffn_ratio)
|
| 263 |
+
self.feed_forward = Mlp(
|
| 264 |
+
in_features=dim,
|
| 265 |
+
hidden_features=ffn_hidden_dim,
|
| 266 |
+
drop=dropout_prob,
|
| 267 |
+
act_layer=act_layer,
|
| 268 |
+
)
|
| 269 |
+
|
| 270 |
+
self.ls2 = (
|
| 271 |
+
LayerScale(dim, init_values=ls_init_value)
|
| 272 |
+
if ls_init_value
|
| 273 |
+
else nn.Identity()
|
| 274 |
+
)
|
| 275 |
+
|
| 276 |
+
def init_weights(
|
| 277 |
+
self,
|
| 278 |
+
init_attn_std: float | None = None,
|
| 279 |
+
init_proj_std: float | None = None,
|
| 280 |
+
init_fc_std: float | None = None,
|
| 281 |
+
factor: float = 1.0,
|
| 282 |
+
) -> None:
|
| 283 |
+
init_attn_std = init_attn_std or (self.dim**-0.5)
|
| 284 |
+
init_proj_std = init_proj_std or init_attn_std * factor
|
| 285 |
+
init_fc_std = init_fc_std or (2 * self.dim) ** -0.5
|
| 286 |
+
self.attention.init_weights(init_attn_std, init_proj_std)
|
| 287 |
+
self.attention_norm.reset_parameters()
|
| 288 |
+
nn.init.normal_(self.feed_forward.fc1.weight, std=init_fc_std)
|
| 289 |
+
nn.init.normal_(self.feed_forward.fc2.weight, std=init_proj_std)
|
| 290 |
+
self.ffn_norm.reset_parameters()
|
| 291 |
+
|
| 292 |
+
def forward(
|
| 293 |
+
self,
|
| 294 |
+
x: torch.Tensor,
|
| 295 |
+
):
|
| 296 |
+
|
| 297 |
+
x_attn = x + self.ls1(self.attention(self.attention_norm(x), self.is_causal))
|
| 298 |
+
x_ffn = x_attn + self.ls2(self.feed_forward(self.ffn_norm(x_attn)))
|
| 299 |
+
return x_ffn
|
hf_src/layers/rope_position_encoding.py
ADDED
|
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This software may be used and distributed in accordance with
|
| 4 |
+
# the terms of the DINOv3 License Agreement.
|
| 5 |
+
|
| 6 |
+
import math
|
| 7 |
+
from typing import Literal
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
from torch import Tensor, nn
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# RoPE positional embedding with no mixing of coordinates (axial) and no learnable weights
|
| 15 |
+
# Supports two parametrizations of the rope parameters: either using `base` or `min_period` and `max_period`.
|
| 16 |
+
class RopePositionEmbedding(nn.Module):
|
| 17 |
+
def __init__(
|
| 18 |
+
self,
|
| 19 |
+
embed_dim: int,
|
| 20 |
+
*,
|
| 21 |
+
num_heads: int,
|
| 22 |
+
base: float | None = 100.0,
|
| 23 |
+
min_period: float | None = None,
|
| 24 |
+
max_period: float | None = None,
|
| 25 |
+
normalize_coords: Literal["min", "max", "separate"] = "separate",
|
| 26 |
+
shift_coords: float | None = None,
|
| 27 |
+
jitter_coords: float | None = None,
|
| 28 |
+
rescale_coords: float | None = None,
|
| 29 |
+
dtype: torch.dtype | None = None,
|
| 30 |
+
device: torch.device | None = None,
|
| 31 |
+
):
|
| 32 |
+
super().__init__()
|
| 33 |
+
assert embed_dim % (4 * num_heads) == 0
|
| 34 |
+
both_periods = min_period is not None and max_period is not None
|
| 35 |
+
if (base is None and not both_periods) or (base is not None and both_periods):
|
| 36 |
+
raise ValueError(
|
| 37 |
+
"Either `base` or `min_period`+`max_period` must be provided."
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
D_head = embed_dim // num_heads
|
| 41 |
+
self.base = base
|
| 42 |
+
self.min_period = min_period
|
| 43 |
+
self.max_period = max_period
|
| 44 |
+
self.D_head = D_head
|
| 45 |
+
self.normalize_coords = normalize_coords
|
| 46 |
+
self.shift_coords = shift_coords
|
| 47 |
+
self.jitter_coords = jitter_coords
|
| 48 |
+
self.rescale_coords = rescale_coords
|
| 49 |
+
|
| 50 |
+
# Needs persistent=True because we do teacher.load_state_dict(student.state_dict()) to initialize the teacher
|
| 51 |
+
self.dtype = dtype # Don't rely on self.periods.dtype
|
| 52 |
+
self.register_buffer(
|
| 53 |
+
"periods",
|
| 54 |
+
torch.empty(D_head // 4, device=device, dtype=dtype),
|
| 55 |
+
persistent=True,
|
| 56 |
+
)
|
| 57 |
+
self._init_weights()
|
| 58 |
+
|
| 59 |
+
def forward(self, *, H: int, W: int) -> tuple[Tensor, Tensor]:
|
| 60 |
+
device = self.periods.device
|
| 61 |
+
dtype = self.dtype
|
| 62 |
+
dd = {"device": device, "dtype": dtype}
|
| 63 |
+
|
| 64 |
+
# Prepare coords in range [-1, +1]
|
| 65 |
+
if self.normalize_coords == "max":
|
| 66 |
+
max_HW = max(H, W)
|
| 67 |
+
coords_h = torch.arange(0.5, H, **dd) / max_HW # [H]
|
| 68 |
+
coords_w = torch.arange(0.5, W, **dd) / max_HW # [W]
|
| 69 |
+
elif self.normalize_coords == "min":
|
| 70 |
+
min_HW = min(H, W)
|
| 71 |
+
coords_h = torch.arange(0.5, H, **dd) / min_HW # [H]
|
| 72 |
+
coords_w = torch.arange(0.5, W, **dd) / min_HW # [W]
|
| 73 |
+
elif self.normalize_coords == "separate":
|
| 74 |
+
coords_h = torch.arange(0.5, H, **dd) / H # [H]
|
| 75 |
+
coords_w = torch.arange(0.5, W, **dd) / W # [W]
|
| 76 |
+
else:
|
| 77 |
+
raise ValueError(f"Unknown normalize_coords: {self.normalize_coords}")
|
| 78 |
+
coords = torch.stack(
|
| 79 |
+
torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1
|
| 80 |
+
) # [H, W, 2]
|
| 81 |
+
coords = coords.flatten(0, 1) # [HW, 2]
|
| 82 |
+
coords = 2.0 * coords - 1.0 # Shift range [0, 1] to [-1, +1]
|
| 83 |
+
|
| 84 |
+
# Shift coords by adding a uniform value in [-shift, shift]
|
| 85 |
+
if self.training and self.shift_coords is not None:
|
| 86 |
+
shift_hw = torch.empty(2, **dd).uniform_(
|
| 87 |
+
-self.shift_coords, self.shift_coords
|
| 88 |
+
)
|
| 89 |
+
coords += shift_hw[None, :]
|
| 90 |
+
|
| 91 |
+
# Jitter coords by multiplying the range [-1, 1] by a log-uniform value in [1/jitter, jitter]
|
| 92 |
+
if self.training and self.jitter_coords is not None:
|
| 93 |
+
jitter_max = np.log(self.jitter_coords)
|
| 94 |
+
jitter_min = -jitter_max
|
| 95 |
+
jitter_hw = torch.empty(2, **dd).uniform_(jitter_min, jitter_max).exp()
|
| 96 |
+
coords *= jitter_hw[None, :]
|
| 97 |
+
|
| 98 |
+
# Rescale coords by multiplying the range [-1, 1] by a log-uniform value in [1/rescale, rescale]
|
| 99 |
+
if self.training and self.rescale_coords is not None:
|
| 100 |
+
rescale_max = np.log(self.rescale_coords)
|
| 101 |
+
rescale_min = -rescale_max
|
| 102 |
+
rescale_hw = torch.empty(1, **dd).uniform_(rescale_min, rescale_max).exp()
|
| 103 |
+
coords *= rescale_hw
|
| 104 |
+
|
| 105 |
+
# Prepare angles and sin/cos
|
| 106 |
+
angles = (
|
| 107 |
+
2 * math.pi * coords[:, :, None] / self.periods[None, None, :]
|
| 108 |
+
) # [HW, 2, D//4]
|
| 109 |
+
angles = angles.flatten(1, 2) # [HW, D//2]
|
| 110 |
+
angles = angles.tile(2) # [HW, D]
|
| 111 |
+
cos = torch.cos(angles) # [HW, D]
|
| 112 |
+
sin = torch.sin(angles) # [HW, D]
|
| 113 |
+
|
| 114 |
+
return sin, cos # 2 * [HW, D]
|
| 115 |
+
|
| 116 |
+
def _init_weights(self):
|
| 117 |
+
device = self.periods.device
|
| 118 |
+
dtype = self.dtype
|
| 119 |
+
if self.base is not None:
|
| 120 |
+
periods = self.base ** (
|
| 121 |
+
2
|
| 122 |
+
* torch.arange(self.D_head // 4, device=device, dtype=dtype)
|
| 123 |
+
/ (self.D_head // 2)
|
| 124 |
+
) # [D//4]
|
| 125 |
+
else:
|
| 126 |
+
base = self.max_period / self.min_period
|
| 127 |
+
exponents = torch.linspace(
|
| 128 |
+
0, 1, self.D_head // 4, device=device, dtype=dtype
|
| 129 |
+
) # [D//4] range [0, 1]
|
| 130 |
+
periods = base**exponents # range [1, max_period / min_period]
|
| 131 |
+
periods = periods / base # range [min_period / max_period, 1]
|
| 132 |
+
periods = periods * self.max_period # range [min_period, max_period]
|
| 133 |
+
self.periods.data = periods
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
if __name__ == "__main__":
    # Demo: plot one RoPE sine component over the patch grid for a few
    # (base, grid size) combinations.
    import torch
    import numpy as np
    import matplotlib.pyplot as plt

    def get_rope_values(H, W, embed_dim, num_heads, base, period_index=3):
        """Return sin(2 * pi * h / period) on an H x W grid as a numpy array.

        Periods and coordinates use the same formulas as the
        `RopePositionEmbedding` class in this file (`base` parametrization,
        "separate" coordinate normalisation).

        Args:
            H, W: grid height and width in patches.
            embed_dim, num_heads: determine D_head = embed_dim // num_heads.
            base: RoPE base used to derive the periods.
            period_index: which of the D_head // 4 periods to visualise.
                For base > 1 index 0 is the shortest period; the default of
                3 selects the fourth one.
        """
        D_head = embed_dim // num_heads
        # Debug output: number of periods, half head dim, and their ratio.
        print(D_head // 4, D_head // 2, (D_head // 4) / (D_head // 2))
        periods = base ** (2 * torch.arange(D_head // 4) / (D_head // 2))
        period = periods[period_index]

        # Patch-centre coordinates normalised separately per axis to [0, 1].
        coords_h = torch.arange(0.5, H) / H
        coords_w = torch.arange(0.5, W) / W
        grid_h, grid_w = torch.meshgrid(coords_h, coords_w, indexing="ij")

        # Convert to [-1, 1]
        grid_h = 2.0 * grid_h - 1.0
        grid_w = 2.0 * grid_w - 1.0

        # Sine value along the H-coordinate only (for visualisation).
        vals = torch.sin(2 * np.pi * grid_h / period)
        return vals.numpy()

    # Settings
    embed_dim = 768
    num_heads = 12
    bases = [100, 10000]
    sizes = [14, 28]

    fig, axes = plt.subplots(2, 2, figsize=(12, 10))

    for i, base in enumerate(bases):
        for j, size in enumerate(sizes):
            vals = get_rope_values(size, size, embed_dim, num_heads, base)

            ax = axes[i, j]
            im = ax.imshow(vals, cmap="RdBu", extent=[-1, 1, -1, 1])
            ax.set_title(f"Base: {base} | Grid: {size}x{size}")
            ax.set_xlabel("Width (Normalized)")
            ax.set_ylabel("Height (Normalized)")
            plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

    plt.tight_layout()
    plt.show()
|
hf_src/layers/sparse_linear.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This software may be used and distributed in accordance with
|
| 4 |
+
# the terms of the DINOv3 License Agreement.
|
| 5 |
+
|
| 6 |
+
import logging
|
| 7 |
+
|
| 8 |
+
from typing import Callable
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
import xformers.ops as xops
|
| 14 |
+
|
| 15 |
+
from hf_src.utils import named_apply, named_replace
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class LinearW24(torch.nn.Linear):
    """`nn.Linear` whose weight can be 2:4-sparsified on the fly.

    While `sparsity_enabled` is False (the initial state) this behaves
    exactly like `nn.Linear`. When enabled, the weight is passed through
    `xops.sparsify24` on every forward call before the linear product.
    """

    ALGO = "largest_abs_values_greedy"

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.sparsity_enabled = False

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply the dense or the sparsified linear map to `input`."""
        if not self.sparsity_enabled:
            return super().forward(input)

        leading_shape = input.shape[:-1]
        rows = input.flatten(end_dim=-2)
        num_rows = rows.shape[0]

        # Pad the row count up to a multiple of 8.
        # NOTE: This should be torch-compiled away
        pad_rows = -num_rows % 8
        if pad_rows != 0:
            rows = F.pad(rows, [0, 0, 0, pad_rows])

        w_sparse = xops.sparsify24(
            self.weight,
            algo=self.ALGO,
            gradient="ste",
            backend="cusparselt",
        )
        out = F.linear(rows, w_sparse, self.bias)
        # Drop the padding rows and restore the original leading dimensions.
        return out[:num_rows].unflatten(dim=0, sizes=leading_shape)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def replace_linears_with_sparse_linear(
    root_module: nn.Module, *, filter_fn: Callable[[str], bool]
) -> nn.Module:
    """Swap selected `nn.Linear` modules under `root_module` for `LinearW24`.

    `filter_fn` receives a module's name and is consulted only for
    `nn.Linear` instances. Each replacement shares the original weight and
    bias parameters. Subclasses of `nn.Linear` that pass the filter trigger
    an assertion, as does finding nothing to replace.

    Returns:
        Whatever `named_replace` returns for `root_module`.
    """
    swapped_names = []

    def replace(module: nn.Module, name: str) -> nn.Module:
        selected = isinstance(module, nn.Linear) and filter_fn(name)
        if not selected:
            return module
        assert type(module) is nn.Linear, "Subtypes not supported"
        sparse_linear = LinearW24(
            in_features=module.in_features,
            out_features=module.out_features,
            bias=module.bias is not None,
            dtype=module.weight.dtype,
            device=module.weight.device,
        )
        # Reuse the existing parameters instead of the freshly created ones.
        sparse_linear.weight = module.weight
        sparse_linear.bias = module.bias
        swapped_names.append(name)
        return sparse_linear

    out = named_replace(replace, root_module)
    assert len(swapped_names) > 0, "2:4 sparsity: no layer found to sparsify"
    return out
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def update_24sparsity(root_module: nn.Module, enabled: bool) -> int:
    """Set `sparsity_enabled` on every `LinearW24` visited under `root_module`.

    Afterwards the dynamo code caches and the inductor cudagraph trees are
    reset to force a re-compile of everything.

    Returns:
        The number of `LinearW24` modules that were updated.
    """
    touched = []

    def maybe_apply_sparsity(module: nn.Module, name: str) -> nn.Module:
        if isinstance(module, LinearW24):
            touched.append(name)
            module.sparsity_enabled = enabled
        return module

    named_apply(maybe_apply_sparsity, root_module)

    # Force re-compile everything
    torch._dynamo.reset_code_caches()
    from torch._inductor.cudagraph_trees import reset_cudagraph_trees

    reset_cudagraph_trees()
    return len(touched)
|
hf_src/layers/swiglu_ffn.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
from typing import Callable, Optional
|
| 4 |
+
|
| 5 |
+
from torch import Tensor, nn
|
| 6 |
+
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class SwiGLUFFN(nn.Module):
    """SwiGLU feed-forward layer: `w3(silu(x1) * x2)` with `[x1, x2] = w12(x)`.

    `hidden_features` and `out_features` default to `in_features` when falsy.
    `act_layer` and `drop` are accepted but not used by this implementation.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features
        # One fused projection yields both the gate and the value halves.
        self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)

    def forward(self, x: Tensor) -> Tensor:
        """Project, gate with SiLU, and project back."""
        gate, value = self.w12(x).chunk(2, dim=-1)
        return self.w3(F.silu(gate) * value)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
# xformers is considered enabled unless the XFORMERS_DISABLED environment
# variable is set (to any value).
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None

# Start from the pure-PyTorch fallback and upgrade to the xformers SwiGLU
# only when it is enabled and importable.
SwiGLU = SwiGLUFFN
XFORMERS_AVAILABLE = False
if XFORMERS_ENABLED:
    try:
        from xformers.ops import SwiGLU
    except ImportError:
        # xformers is not installed: keep the fallback selected above.
        pass
    else:
        XFORMERS_AVAILABLE = True
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class SwiGLUFFNFused(SwiGLU):
    """`SwiGLU` variant with a reduced, 8-aligned hidden width.

    The requested hidden width is scaled by 2/3 (truncated to int) and then
    rounded up to the next multiple of 8 before being handed to the base
    class. `act_layer` and `drop` are accepted but not forwarded.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features

        two_thirds = int(hidden_features * 2 / 3)
        aligned_hidden = (two_thirds + 7) // 8 * 8  # round up to a multiple of 8

        super().__init__(
            in_features=in_features,
            hidden_features=aligned_hidden,
            out_features=out_features,
            bias=bias,
        )
|
hf_src/model/__init__.py
ADDED
|
File without changes
|
hf_src/model/image/__init__.py
ADDED
|
File without changes
|
hf_src/model/image/vitv2/__init__.py
ADDED
|
File without changes
|
hf_src/model/image/vitv2/transformer.py
ADDED
|
@@ -0,0 +1,475 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adapted from: https://github.com/facebookresearch/dinov2/blob/main/dinov2/models/vision_transformer.py
|
| 2 |
+
# References:
|
| 3 |
+
# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
|
| 4 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
| 5 |
+
|
| 6 |
+
import math
|
| 7 |
+
|
| 8 |
+
from functools import partial
|
| 9 |
+
from typing import Sequence, Tuple, Union, Callable
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import torch.utils.checkpoint
|
| 14 |
+
|
| 15 |
+
from torch.nn.init import trunc_normal_
|
| 16 |
+
from torch.nn.functional import interpolate
|
| 17 |
+
|
| 18 |
+
from hf_src.layers import (
|
| 19 |
+
Mlp,
|
| 20 |
+
PatchEmbed,
|
| 21 |
+
SwiGLUFFNFused,
|
| 22 |
+
MemEffAttention,
|
| 23 |
+
NestedTensorBlock as Block,
|
| 24 |
+
LayerScale,
|
| 25 |
+
RMSNorm,
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def named_apply(
    fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False
) -> nn.Module:
    """Recursively call `fn(module=..., name=...)` over a module tree.

    Args:
        fn: callback invoked with keyword arguments `module` and `name`.
        module: root of the tree to walk.
        name: dotted name of `module`; children get `"<name>.<child>"`.
        depth_first: if True a module is visited after its children,
            otherwise before them.
        include_root: whether `fn` is applied to `module` itself; every
            descendant is always visited.

    Returns:
        `module`, unchanged by this function itself.
    """
    visit_before_children = include_root and not depth_first
    visit_after_children = include_root and depth_first

    if visit_before_children:
        fn(module=module, name=name)

    for child_name, child_module in module.named_children():
        qualified = f"{name}.{child_name}" if name else child_name
        named_apply(
            fn=fn,
            module=child_module,
            name=qualified,
            depth_first=depth_first,
            include_root=True,
        )

    if visit_after_children:
        fn(module=module, name=name)
    return module
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class BlockChunk(nn.ModuleList):
    """A `ModuleList` that runs its modules in sequence when called."""

    def forward(self, x, return_attention=False):
        """Chain `x` through all modules.

        Adaptation for returning attentions: only the last module receives
        the `return_attention` keyword, and its result is returned directly.
        An empty chunk returns `x` unchanged.
        """
        last_index = len(self) - 1
        for index, block in enumerate(self):
            if index == last_index:
                return block(x, return_attention=return_attention)
            x = block(x)
        return x
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class ViTv2(nn.Module):
|
| 60 |
+
def __init__(
    self,
    *,
    img_size=518,
    patch_size=16,
    in_chans=3,
    embed_dim=768,
    depth=12,
    num_heads=12,
    mlp_ratio=4.0,
    qkv_bias=True,
    ffn_bias=True,
    proj_bias=True,
    drop_path_rate=0.0,
    drop_path_uniform=True,
    init_values=None,  # for layerscale: None or 0 => no layerscale
    embed_layer=PatchEmbed,
    act_layer=nn.GELU,
    block_fn=Block,
    ffn_layer="mlp",
    block_chunks=0,
    num_register_tokens=0,
    interpolate_antialias=False,
    interpolate_offset=0.1,
    num_classes=None,
    **ignored_kwargs,
):
    """
    Args:
        img_size (int, tuple): input image size
        patch_size (int, tuple): patch size
        in_chans (int): number of input channels
        embed_dim (int): embedding dimension
        depth (int): depth of transformer
        num_heads (int): number of attention heads
        mlp_ratio (float): ratio of mlp hidden dim to embedding dim
        qkv_bias (bool): enable bias for qkv if True
        proj_bias (bool): enable bias for proj in attn if True
        ffn_bias (bool): enable bias for ffn if True
        drop_path_rate (float): stochastic depth rate
        drop_path_uniform (bool): apply uniform drop rate across blocks
        init_values (float): layer-scale init values
        embed_layer (nn.Module): patch embedding layer
        act_layer (nn.Module): MLP activation layer
        block_fn (nn.Module): transformer block class
        ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
        block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
        num_register_tokens: (int) number of extra cls tokens (so-called "registers")
        interpolate_antialias: (bool) flag to apply anti-aliasing when interpolating positional embeddings
        interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
        num_classes: (int or None) output size of the linear head; None keeps an identity head
        **ignored_kwargs: accepted for call-site compatibility and discarded
    """
    # nn.Module.__init__ raises TypeError on unknown keyword arguments, so
    # the extra kwargs must be dropped here rather than forwarded.
    super().__init__()

    norm_layer = partial(nn.LayerNorm, eps=1e-6)
    self.img_size = img_size

    self.num_features = self.embed_dim = embed_dim

    self.num_tokens = 1
    self.n_blocks = depth
    self.num_heads = num_heads
    self.patch_size = patch_size
    self.num_register_tokens = num_register_tokens
    self.interpolate_antialias = interpolate_antialias
    self.interpolate_offset = interpolate_offset

    self.patch_embed = embed_layer(
        img_size=img_size,
        patch_size=patch_size,
        in_chans=in_chans,
        embed_dim=embed_dim,
    )
    num_patches = self.patch_embed.num_patches

    self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
    # One position per patch plus one per class token.
    self.pos_embed = nn.Parameter(
        torch.zeros(1, num_patches + self.num_tokens, embed_dim)
    )
    assert num_register_tokens >= 0
    self.register_tokens = (
        nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim))
        if num_register_tokens
        else None
    )

    if drop_path_uniform is True:
        dpr = [drop_path_rate] * depth
    else:
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, depth)
        ]  # stochastic depth decay rule

    # Resolve the ffn_layer name to a constructor.
    if ffn_layer == "mlp":
        ffn_layer = Mlp
    elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
        ffn_layer = SwiGLUFFNFused
    elif ffn_layer == "identity":

        def f(*args, **kwargs):
            return nn.Identity()

        ffn_layer = f
    else:
        raise NotImplementedError

    blocks_list = [
        block_fn(
            dim=embed_dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            ffn_bias=ffn_bias,
            drop_path=dpr[i],
            norm_layer=norm_layer,
            act_layer=act_layer,
            ffn_layer=ffn_layer,
            init_values=init_values,
        )
        for i in range(depth)
    ]
    if block_chunks > 0:
        self.chunked_blocks = True
        chunked_blocks = []
        chunksize = depth // block_chunks
        for i in range(0, depth, chunksize):
            # this is to keep the block index consistent if we chunk the block list
            chunked_blocks.append(
                [nn.Identity()] * i + blocks_list[i : i + chunksize]
            )
        self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
    else:
        self.chunked_blocks = False
        self.blocks = nn.ModuleList(blocks_list)

    self.mask_token = None
    self.norm = norm_layer(embed_dim)
    self.norm_patch = None

    self.head = (
        nn.Identity() if num_classes is None else nn.Linear(embed_dim, num_classes)
    )

    # Initialize the model's weights
    self.init_weights()
|
| 206 |
+
|
| 207 |
+
def init_weights(self):
    """Initialise the token/positional parameters, then every sub-module.

    The positional embedding gets a truncated normal (std 0.02); the class
    token and, when present, the register tokens get a near-zero normal
    (std 1e-6); the mask token, when present, is zeroed. Finally
    ``init_weights_vit`` is applied to the sub-modules via ``named_apply``.
    """
    trunc_normal_(self.pos_embed, std=0.02)
    nn.init.normal_(self.cls_token, std=1e-6)

    registers = self.register_tokens
    if registers is not None:
        nn.init.normal_(registers, std=1e-6)

    mask_token = self.mask_token
    if mask_token is not None:
        nn.init.zeros_(mask_token)

    named_apply(init_weights_vit, self)
|
| 215 |
+
|
| 216 |
+
def interpolate_pos_encoding(self, x, w, h):
    """Return positional embeddings resized to the token grid of ``x``.

    Args:
        x: token tensor whose second dimension is ``1 + number of patches``
            and whose last dimension is the embedding size.
        w, h: the two spatial sizes of the input image, in the order they
            are unpacked from ``x.shape`` by the caller.

    Returns:
        ``self.pos_embed`` unchanged when the patch count matches the stored
        embedding and ``w == h``; otherwise a tensor of shape
        ``(1, 1 + w0 * h0, dim)`` cast to ``x.dtype``, where
        ``w0 = w // patch_size`` and ``h0 = h // patch_size``.
    """
    previous_dtype = x.dtype
    npatch = x.shape[1] - 1
    N = self.pos_embed.shape[1] - 1
    if npatch == N and w == h:
        return self.pos_embed

    # Interpolate in float32 and cast back at the end.
    pos_embed = self.pos_embed.float()
    class_pos_embed = pos_embed[:, 0]
    patch_pos_embed = pos_embed[:, 1:]
    dim = x.shape[-1]
    w0 = w // self.patch_size
    h0 = h // self.patch_size

    sqrt_N = math.sqrt(N)
    grid = int(sqrt_N)  # stored embeddings are laid out on a grid x grid square
    patch_grid = patch_pos_embed.reshape(1, grid, grid, dim).permute(0, 3, 1, 2)

    if self.interpolate_offset:
        # Historical work-around: nudge the target size by a small offset so
        # that floor(grid * scale) cannot fall one short because of floating
        # point error.
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        target_w = w0 + self.interpolate_offset
        target_h = h0 + self.interpolate_offset
        resize_kwargs = {
            "scale_factor": (float(target_w) / sqrt_N, float(target_h) / sqrt_N)
        }
    else:
        # Without an offset an exact scale factor can round down to one
        # cell too few; request the output size explicitly instead.
        target_w, target_h = w0, h0
        resize_kwargs = {"size": (w0, h0)}

    # NOTE: ``antialias=self.interpolate_antialias`` is deliberately not passed;
    # it was disabled in the original implementation.
    patch_pos_embed = interpolate(patch_grid, mode="bicubic", **resize_kwargs)

    assert int(target_w) == patch_pos_embed.shape[-2]
    assert int(target_h) == patch_pos_embed.shape[-1]
    patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
    return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(
        previous_dtype
    )
|
| 249 |
+
|
| 250 |
+
def prepare_tokens_with_masks(self, x, masks=None):
    """Turn an image batch into the token sequence fed to the blocks.

    Steps: patch-embed the 4-D input, optionally overwrite masked patch
    positions with the mask token, prepend the class token, add the
    (possibly interpolated) positional encoding and, when register tokens
    exist, insert them between the class token and the patch tokens.

    Args:
        x: 4-D input batch; its last two sizes are forwarded to
            ``interpolate_pos_encoding``.
        masks: optional boolean tensor selecting patch positions to replace.

    Returns:
        The assembled token tensor.
    """
    _, _, w, h = x.shape
    tokens = self.patch_embed(x)

    if masks is not None:
        # NOTE(review): relies on self.mask_token being set; the constructor
        # visible in this file leaves it as None - confirm who assigns it.
        replacement = self.mask_token.to(tokens.dtype).unsqueeze(0)
        tokens = torch.where(masks.unsqueeze(-1), replacement, tokens)

    cls = self.cls_token.expand(tokens.shape[0], -1, -1)
    tokens = torch.cat((cls, tokens), dim=1)
    tokens = tokens + self.interpolate_pos_encoding(tokens, w, h)

    if self.register_tokens is None:
        return tokens

    # Registers carry no positional encoding: they are spliced in afterwards.
    registers = self.register_tokens.expand(tokens.shape[0], -1, -1)
    return torch.cat((tokens[:, :1], registers, tokens[:, 1:]), dim=1)
|
| 272 |
+
|
| 273 |
+
def forward_features_list(self, x_list, masks_list):
    """Run several inputs through the backbone together.

    Each ``(input, mask)`` pair is tokenised separately; every block is then
    called with the whole list of token tensors.

    Returns:
        One dict per input with keys ``"latent"`` (normalised class token),
        ``"patch_latent"`` (normalised patch tokens) and ``"raw_latent"``
        (un-normalised class token).
    """
    token_list = []
    for image, mask in zip(x_list, masks_list):
        token_list.append(self.prepare_tokens_with_masks(image, mask))

    for block in self.blocks:
        token_list = block(token_list)

    first_patch = self.num_register_tokens + 1
    # Patch tokens use their own norm only when one has been configured.
    patch_norm = self.norm if self.norm_patch is None else self.norm_patch

    results = []
    for tokens, _ in zip(token_list, masks_list):
        prefix = self.norm(tokens[:, :first_patch])
        patches = patch_norm(tokens[:, first_patch:])
        results.append(
            {
                "latent": prefix[:, 0],
                "patch_latent": patches,
                "raw_latent": tokens[:, 0],
            }
        )
    return results
|
| 299 |
+
|
| 300 |
+
def forward_features(self, x, masks=None, last_self_attention=False):
    """Compute backbone features for a single batch (or delegate for a list).

    Args:
        x: an input batch, or a ``list`` of batches (handled by
            ``forward_features_list``).
        masks: optional patch mask(s). For list input, ``None`` is expanded
            to one ``None`` per element so the call does not fail when no
            masks are supplied.
        last_self_attention: when true, the last block is asked to also
            return its attention and the class-token-to-patch slice of it is
            included in the result.

    Returns:
        For non-list input, a dict with keys ``"latent"``, ``"patch_latent"``,
        ``"raw_latent"``, ``"last_self_attention"`` (``None`` unless
        requested) and ``"logits"``.
    """
    if isinstance(x, list):
        if masks is None:
            # Mirror forward(): zip() over a None mask list would raise.
            masks = [None] * len(x)
        return self.forward_features_list(x, masks)

    x = self.prepare_tokens_with_masks(x, masks)

    last_index = len(self.blocks) - 1
    for i, blk in enumerate(self.blocks):
        if i < last_index:
            x = blk(x)
        else:
            # Only the final block is asked for its attention.
            x = blk(x, return_attention=last_self_attention)

    attn = None
    if last_self_attention:
        x, attn = x
        # Attention is selected from the cls token to the patch tokens only
        # Thus, we ignore the cls from the patch tokens (i.e., start from 1)
        attn = attn[:, :, 0, self.num_register_tokens + 1 :]

    first_patch = self.num_register_tokens + 1
    cls_tokens = self.norm(x[:, :first_patch])

    if self.norm_patch is None:
        patch_tokens = self.norm(x[:, first_patch:])
    else:
        patch_tokens = self.norm_patch(x[:, first_patch:])

    return {
        "latent": cls_tokens[:, 0],
        "patch_latent": patch_tokens,
        "raw_latent": x[:, 0],
        "last_self_attention": attn,
        "logits": self.head(cls_tokens[:, 0]),
    }
|
| 333 |
+
|
| 334 |
+
def forward_head(self, x):
    """Project ``x`` with the projection head, optionally L2-normalising.

    NOTE(review): ``self.projection_head`` and ``self.l2_norm`` are not set by
    the constructor visible in this file - presumably provided elsewhere;
    confirm before relying on this method.
    """
    projected = self.projection_head(x)
    if not self.l2_norm:
        return projected
    # Unit-length rows along dim 1 (the l2-norm bottleneck).
    return nn.functional.normalize(projected, dim=1, p=2)
|
| 340 |
+
|
| 341 |
+
def _get_intermediate_layers_not_chunked(self, x, n=1):
|
| 342 |
+
x = self.prepare_tokens_with_masks(x)
|
| 343 |
+
# If n is an int, take the n last blocks. If it's a list, take them
|
| 344 |
+
output, total_block_len = [], len(self.blocks)
|
| 345 |
+
blocks_to_take = (
|
| 346 |
+
range(total_block_len - n, total_block_len) if isinstance(n, int) else n
|
| 347 |
+
)
|
| 348 |
+
for i, blk in enumerate(self.blocks):
|
| 349 |
+
x = blk(x)
|
| 350 |
+
if i in blocks_to_take:
|
| 351 |
+
output.append(x)
|
| 352 |
+
assert len(output) == len(
|
| 353 |
+
blocks_to_take
|
| 354 |
+
), f"only {len(output)} / {len(blocks_to_take)} blocks found"
|
| 355 |
+
return output
|
| 356 |
+
|
| 357 |
+
def _get_intermediate_layers_chunked(self, x, n=1):
|
| 358 |
+
x = self.prepare_tokens_with_masks(x)
|
| 359 |
+
output, i, total_block_len = [], 0, len(self.blocks[-1])
|
| 360 |
+
# If n is an int, take the n last blocks. If it's a list, take them
|
| 361 |
+
blocks_to_take = (
|
| 362 |
+
range(total_block_len - n, total_block_len) if isinstance(n, int) else n
|
| 363 |
+
)
|
| 364 |
+
for block_chunk in self.blocks:
|
| 365 |
+
for blk in block_chunk[i:]: # Passing the nn.Identity()
|
| 366 |
+
x = blk(x)
|
| 367 |
+
if i in blocks_to_take:
|
| 368 |
+
output.append(x)
|
| 369 |
+
i += 1
|
| 370 |
+
assert len(output) == len(
|
| 371 |
+
blocks_to_take
|
| 372 |
+
), f"only {len(output)} / {len(blocks_to_take)} blocks found"
|
| 373 |
+
return output
|
| 374 |
+
|
| 375 |
+
def get_intermediate_layers(
    self,
    x: torch.Tensor,
    n: Union[int, Sequence] = 1,  # Layers or n last layers to take
    reshape: bool = False,
    return_class_token: bool = False,
    norm=True,
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
    """Return patch features (and optionally class tokens) of selected blocks.

    Args:
        x: 4-D input batch.
        n: int (last ``n`` blocks) or explicit block indices.
        reshape: when true, patch tokens are reshaped to
            ``(B, C, w // patch_size, h // patch_size)``.
        return_class_token: when true, each entry is a
            ``(patch_features, class_token)`` pair.
        norm: when true, the final norm layer(s) are applied.

    Returns:
        A tuple with one entry per selected block.
    """
    if self.chunked_blocks:
        hidden_states = self._get_intermediate_layers_chunked(x, n)
    else:
        hidden_states = self._get_intermediate_layers_not_chunked(x, n)

    first_patch = 1 + self.num_register_tokens
    class_tokens = []
    patch_features = []
    for hidden in hidden_states:
        if norm:
            class_tokens.append(self.norm(hidden[:, :first_patch])[:, 0])
            patch_norm = self.norm if self.norm_patch is None else self.norm_patch
            patch_features.append(patch_norm(hidden[:, first_patch:]))
        else:
            class_tokens.append(hidden[:, 0])
            patch_features.append(hidden[:, first_patch:])

    if reshape:
        B, _, w, h = x.shape
        rows, cols = w // self.patch_size, h // self.patch_size
        patch_features = [
            feat.reshape(B, rows, cols, -1).permute(0, 3, 1, 2).contiguous()
            for feat in patch_features
        ]

    if return_class_token:
        return tuple(zip(patch_features, class_tokens))
    return tuple(patch_features)
|
| 420 |
+
|
| 421 |
+
def forward(self, xs, masks=None, last_self_attention=False, **kwargs):
    """Dispatch to the single-input or multi-input feature path.

    A list or tuple is sent to ``forward_features_list`` (with one ``None``
    mask per element when ``masks`` is omitted); anything else goes to
    ``forward_features``. Extra keyword arguments are accepted and ignored.
    """
    if isinstance(xs, (list, tuple)):
        per_input_masks = [None] * len(xs) if masks is None else masks
        return self.forward_features_list(xs, per_input_masks)
    return self.forward_features(xs, masks, last_self_attention)
|
| 429 |
+
|
| 430 |
+
def forward_backbone(self, x, last_self_attention=False):
    """Return the normalised class token stacked in front of the patch tokens.

    Returns:
        The concatenated token tensor, or ``(tokens, attention)`` when
        ``last_self_attention`` is true.
    """
    features = self.forward_features(x, last_self_attention=last_self_attention)
    # Re-attach the class token as the first position of the sequence.
    tokens = torch.cat(
        (features["latent"].unsqueeze(1), features["patch_latent"]), dim=1
    )
    if not last_self_attention:
        return tokens
    return tokens, features["last_self_attention"]
|
| 439 |
+
|
| 440 |
+
def get_last_selfattention(self, x, masks=None):
    """Return the attention produced by the last block.

    Adapted from https://gitlab.com/ziegleto-machine-learning/dino/-/tree/main/

    Raises:
        NotImplementedError: if ``x`` is a list.
    """
    if isinstance(x, list):
        raise NotImplementedError("Not implemented for list of inputs")

    tokens = self.prepare_tokens_with_masks(x, masks)

    # All blocks but the last update the tokens; the last one only yields
    # its attention.
    final_index = len(self.blocks) - 1
    for index, block in enumerate(self.blocks):
        if index < final_index:
            tokens = block(tokens)
        else:
            _, attn = block(tokens, return_attention=True)
    return attn
|
| 457 |
+
|
| 458 |
+
|
| 459 |
+
def init_weights_vit(module: nn.Module, name: str = ""):
    """Per-module weight initialiser, meant to be driven by ``named_apply``.

    Linear layers get truncated-normal weights (std 0.02) and zero bias; if a
    linear layer carries a ``bias_mask``, it is set to one everywhere except
    the middle third of the output features. LayerNorm, LayerScale,
    PatchEmbed and RMSNorm modules are reset through their own
    ``reset_parameters``.

    Args:
        module: the module to initialise in place.
        name: qualified module name (unused here).
    """
    if isinstance(module, nn.Linear):
        torch.nn.init.trunc_normal_(module.weight, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
        bias_mask = getattr(module, "bias_mask", None)
        if bias_mask is not None:
            # NOTE(review): the zeroed middle third presumably corresponds to
            # the key slice of a fused qkv projection - confirm.
            out_features = module.out_features
            bias_mask.fill_(1)
            bias_mask[out_features // 3 : 2 * out_features // 3].fill_(0)

    if isinstance(module, (nn.LayerNorm, LayerScale, PatchEmbed, RMSNorm)):
        module.reset_parameters()
|
hf_src/utils/__init__.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This software may be used and distributed in accordance with
|
| 4 |
+
# the terms of the DINOv3 License Agreement.
|
| 5 |
+
|
| 6 |
+
from .dtype import as_torch_dtype
|
| 7 |
+
from .utils import (
|
| 8 |
+
cat_keep_shapes,
|
| 9 |
+
count_parameters,
|
| 10 |
+
fix_random_seeds,
|
| 11 |
+
get_conda_env,
|
| 12 |
+
get_sha,
|
| 13 |
+
named_apply,
|
| 14 |
+
named_replace,
|
| 15 |
+
uncat_with_shapes,
|
| 16 |
+
)
|
hf_src/utils/download.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import requests
|
| 3 |
+
import hashlib
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def check_sha1(filename, sha1_hash):
    """Tell whether a file's SHA-1 digest equals the expected one.

    Parameters
    ----------
    filename : str
        Path to the file.
    sha1_hash : str
        Expected sha1 hash in hexadecimal digits.

    Returns
    -------
    bool
        Whether the file content matches the expected hash.
    """
    digest = hashlib.sha1()
    with open(filename, "rb") as stream:
        # Hash in 1 MiB pieces so the whole file is never held in memory.
        for piece in iter(lambda: stream.read(1048576), b""):
            digest.update(piece)
    return digest.hexdigest() == sha1_hash
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def download(url, path=None, overwrite=False, sha1_hash=None):
    """Download a given URL to a local file.

    https://github.com/junfu1115/DANet/blob/master/encoding/utils/files.py

    The payload is streamed into ``<fname>.part`` and moved over the
    destination only once the transfer has finished, so an interrupted
    download never leaves a truncated file that a later call would accept as
    complete. The HTTP response is always closed.

    Parameters
    ----------
    url : str
        URL to download
    path : str, optional
        Destination path to store downloaded file. By default stores to the
        current directory with same name as in url.
    overwrite : bool, optional
        Whether to overwrite destination file if already exists.
    sha1_hash : str, optional
        Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified
        but doesn't match.

    Returns
    -------
    str
        The file path of the downloaded file.

    Raises
    ------
    RuntimeError
        If the server does not answer with status 200.
    UserWarning
        If ``sha1_hash`` is given and the downloaded content does not match.
    """
    url_basename = url.split("/")[-1]
    if path is None:
        fname = url_basename
    else:
        path = os.path.expanduser(path)
        fname = os.path.join(path, url_basename) if os.path.isdir(path) else path

    if (
        overwrite
        or not os.path.exists(fname)
        or (sha1_hash and not check_sha1(fname, sha1_hash))
    ):
        dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
        # exist_ok avoids a race between an existence check and creation.
        os.makedirs(dirname, exist_ok=True)

        print("Downloading %s from %s..." % (fname, url))
        partial = fname + ".part"
        try:
            with requests.get(url, stream=True) as r:
                if r.status_code != 200:
                    raise RuntimeError("Failed downloading url %s" % url)
                total_length = r.headers.get("content-length")
                chunks = r.iter_content(chunk_size=1024)
                if total_length is not None:
                    # Progress bar only when the server announces the size.
                    chunks = tqdm(
                        chunks,
                        total=int(int(total_length) / 1024.0 + 0.5),
                        unit="KB",
                        unit_scale=False,
                        dynamic_ncols=True,
                    )
                with open(partial, "wb") as f:
                    for chunk in chunks:
                        if chunk:  # filter out keep-alive new chunks
                            f.write(chunk)
            os.replace(partial, fname)
        finally:
            # Still present only when the transfer did not complete.
            if os.path.exists(partial):
                os.remove(partial)

        if sha1_hash and not check_sha1(fname, sha1_hash):
            raise UserWarning(
                "File {} is downloaded but the content hash does not match. "
                "The repo may be outdated or download may be incomplete. "
                'If the "repo_url" is overridden, consider switching to '
                "the default repo.".format(fname)
            )

    return fname
|
hf_src/utils/dtype.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This software may be used and distributed in accordance with
|
| 4 |
+
# the terms of the DINOv3 License Agreement.
|
| 5 |
+
|
| 6 |
+
from typing import Dict, Union
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
# Anything accepted by as_torch_dtype: a numpy dtype name, a numpy dtype
# instance, or an already-converted torch dtype.
TypeSpec = Union[str, np.dtype, torch.dtype]


# Numpy dtypes that have a direct torch counterpart.
_NUMPY_TO_TORCH_DTYPE: Dict[np.dtype, torch.dtype] = {
    np.dtype("bool"): torch.bool,
    np.dtype("uint8"): torch.uint8,
    np.dtype("int8"): torch.int8,
    np.dtype("int16"): torch.int16,
    np.dtype("int32"): torch.int32,
    np.dtype("int64"): torch.int64,
    np.dtype("float16"): torch.float16,
    np.dtype("float32"): torch.float32,
    np.dtype("float64"): torch.float64,
    np.dtype("complex64"): torch.complex64,
    np.dtype("complex128"): torch.complex128,
}


def as_torch_dtype(dtype: TypeSpec) -> torch.dtype:
    """Convert a dtype specification to the matching ``torch.dtype``.

    Args:
        dtype: a ``torch.dtype`` (returned unchanged), a string understood by
            ``np.dtype`` or a ``np.dtype`` instance.

    Returns:
        The corresponding ``torch.dtype``.

    Raises:
        AssertionError: if ``dtype`` is none of the accepted kinds.
        KeyError: if the numpy dtype has no entry in the conversion table.
    """
    if isinstance(dtype, torch.dtype):
        return dtype
    if isinstance(dtype, str):
        dtype = np.dtype(dtype)
    assert isinstance(
        dtype, np.dtype
    ), f"Expected an instance of numpy dtype, got {type(dtype)}"
    return _NUMPY_TO_TORCH_DTYPE[dtype]
|
hf_src/utils/masking.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def complete_mask_randomly_np(mask, num_masking_patches, rng):
    """Top up ``mask`` with random cells until it has the requested count.

    Cells are switched on through ``mask.reshape(-1)``, and ``mask`` itself is
    returned. Nothing happens when the mask already holds at least
    ``num_masking_patches`` true cells.

    Args:
        mask: boolean array.
        num_masking_patches: desired number of true cells.
        rng: object providing ``choice(values, size=..., replace=False)``.

    Returns:
        ``mask``.
    """
    cells = mask.reshape(-1)
    shortfall = num_masking_patches - cells.sum()
    if shortfall > 0:
        free_cells = np.flatnonzero(np.logical_not(cells))
        picked = rng.choice(free_cells, size=shortfall, replace=False)
        cells[picked] = True
    return mask
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class IBotMasker:
    """Block-wise random mask generator over an ``h x w`` grid of patches.

    Rectangles with a random area and aspect ratio are stamped onto the mask
    until the requested number of cells is reached (or no rectangle fits);
    any remaining shortfall is filled with individually chosen cells.
    """

    def __init__(
        self,
        input_size,
        num_masking_patches=None,
        min_num_patches=0,
        max_num_patches=None,
        min_aspect=0.3,
        max_aspect=3.33,
        max_tries=10,
    ):
        """Store the grid geometry and the rectangle sampling limits.

        Args:
            input_size: grid size, an int (square grid) or ``(h, w)``.
            num_masking_patches: stored on the instance; also the fallback
                for ``max_num_patches``.
            min_num_patches: lower bound of the sampled rectangle area.
            max_num_patches: upper bound of cells added per rectangle.
            min_aspect, max_aspect: aspect ratio bounds (sampled
                log-uniformly); a falsy ``max_aspect`` means ``1 / min_aspect``.
            max_tries: rectangle attempts per ``_mask`` call.
        """
        size = (input_size, input_size) if isinstance(input_size, int) else input_size
        self.h, self.w = size
        self.num_patches = self.h * self.w

        self.min_num_patches = min_num_patches
        self.num_masking_patches = num_masking_patches
        self.max_num_patches = max_num_patches or num_masking_patches

        self.log_min_aspect = np.log(min_aspect)
        self.log_max_aspect = np.log(max_aspect or 1 / min_aspect)

        self.max_tries = max_tries

    def __call__(self, num_masking_patches, starting_mask=None, rng=None):
        """Build a boolean ``(h, w)`` mask with ``num_masking_patches`` true cells.

        Args:
            num_masking_patches: target number of masked cells.
            starting_mask: optional mask to start from (copied, not modified).
            rng: numpy random generator; a fresh default one when omitted.
        """
        rng = np.random.default_rng() if rng is None else rng
        if starting_mask is None:
            mask = np.zeros((self.h, self.w), dtype=np.bool_)
        else:
            mask = starting_mask.copy()

        masked = mask.sum()
        while masked < num_masking_patches:
            budget = num_masking_patches - masked
            if self.max_num_patches is not None:
                budget = min(budget, self.max_num_patches)

            added = self._mask(mask, budget, rng)
            if added == 0:
                # No rectangle could be placed; the random fill finishes the job.
                break
            masked += added

        return complete_mask_randomly_np(mask, num_masking_patches, rng)

    def _mask(self, mask, max_mask_patches, rng):
        """Try to stamp one rectangle onto ``mask`` in place.

        Returns:
            The number of newly masked cells, or 0 if no attempt succeeded
            within ``max_tries``.
        """
        for _ in range(self.max_tries):
            area = rng.uniform(self.min_num_patches, max_mask_patches)
            ratio = np.exp(rng.uniform(self.log_min_aspect, self.log_max_aspect))

            rect_h = int(round(np.sqrt(area * ratio)))
            rect_w = int(round(np.sqrt(area / ratio)))

            # The rectangle must be non-empty and strictly smaller than the grid.
            if not (0 < rect_h < self.h and 0 < rect_w < self.w):
                continue

            top = rng.integers(0, self.h - rect_h + 1)
            left = rng.integers(0, self.w - rect_w + 1)

            window = mask[top : top + rect_h, left : left + rect_w]
            fresh = (~window).sum()

            if 0 < fresh <= max_mask_patches:
                window[:] = True
                return fresh

        return 0
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def generate_masks(
|
| 96 |
+
mask_generator, number_of_samples, mask_prob=0.5, per_sample_range=(0.1, 0.5)
|
| 97 |
+
):
|
| 98 |
+
num_masks = int(number_of_samples * mask_prob)
|
| 99 |
+
num_tokens = mask_generator.num_patches
|
| 100 |
+
prob_per_sample = np.linspace(*per_sample_range, num=num_masks)
|
| 101 |
+
masks = [
|
| 102 |
+
(
|
| 103 |
+
mask_generator(num_masking_patches=int(prob_per_sample[i] * num_tokens))
|
| 104 |
+
if i < num_masks
|
| 105 |
+
else mask_generator(num_masking_patches=0)
|
| 106 |
+
)
|
| 107 |
+
for i in range(number_of_samples)
|
| 108 |
+
]
|
| 109 |
+
random.shuffle(masks)
|
| 110 |
+
masks = np.stack(masks, dtype=bool)
|
| 111 |
+
masks = torch.from_numpy(masks).flatten(1, -1)
|
| 112 |
+
|
| 113 |
+
return masks
|
hf_src/utils/seedlet_masking.py
ADDED
|
File without changes
|
hf_src/utils/utils.py
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This software may be used and distributed in accordance with
|
| 4 |
+
# the terms of the DINOv3 License Agreement.
|
| 5 |
+
|
| 6 |
+
import logging
|
| 7 |
+
import os
|
| 8 |
+
import random
|
| 9 |
+
import subprocess
|
| 10 |
+
from typing import Callable, List, Optional, Tuple
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
import torch
|
| 14 |
+
from torch import Tensor, nn
|
| 15 |
+
|
| 16 |
+
logger = logging.getLogger("dinov3")
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def cat_keep_shapes(x_list: List[Tensor]) -> Tuple[Tensor, List[Tuple[int]], List[int]]:
    """Flatten tensors to 2-D over all but the last dim and concatenate them.

    Returns:
        ``(flattened, shapes, num_tokens)``: the concatenated 2-D tensor, the
        original shape of every input, and the number of rows each input
        contributes (element count with the last dimension removed).
    """
    shapes = []
    num_tokens = []
    rows = []
    for tensor in x_list:
        shapes.append(tensor.shape)
        num_tokens.append(tensor[..., 0].numel())
        rows.append(tensor.flatten(0, -2))
    return torch.cat(rows), shapes, num_tokens
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def uncat_with_shapes(
    flattened: Tensor, shapes: List[Tuple[int]], num_tokens: List[int]
) -> List[Tensor]:
    """Split a concatenated 2-D tensor back into per-input tensors.

    Each piece is restored to its recorded leading dimensions, with the last
    dimension taken from ``flattened`` (so it may differ from the recorded
    one).

    Args:
        flattened: tensor whose rows are the concatenated tokens.
        shapes: original shape of every input.
        num_tokens: number of rows belonging to every input.
    """
    feature_dim = flattened.shape[-1]
    pieces = torch.split_with_sizes(flattened, num_tokens, dim=0)
    restored = []
    for piece, shape in zip(pieces, shapes):
        restored.append(piece.reshape((*shape[:-1], feature_dim)))
    return restored
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def named_replace(
    fn: Callable,
    module: nn.Module,
    name: str = "",
    depth_first: bool = True,
    include_root: bool = False,
) -> nn.Module:
    """Recursively replace sub-modules with whatever ``fn`` returns for them.

    ``fn`` is called as ``fn(module=..., name=...)`` with the dotted module
    name and must return the module to keep in that position.

    Args:
        fn: replacement callback.
        module: root of the tree to walk.
        name: qualified name of ``module``.
        depth_first: call ``fn`` on a module after (true) or before (false)
            its children.
        include_root: whether ``fn`` is also applied to ``module`` itself;
            children are always included.

    Returns:
        The (possibly replaced) root module.
    """
    if include_root and not depth_first:
        module = fn(module=module, name=name)

    # Snapshot the children: the loop re-assigns them on the parent.
    for local_name, child in list(module.named_children()):
        qualified = f"{name}.{local_name}" if name else local_name
        replacement = named_replace(
            fn=fn,
            module=child,
            name=qualified,
            depth_first=depth_first,
            include_root=True,
        )
        setattr(module, local_name, replacement)

    if include_root and depth_first:
        module = fn(module=module, name=name)
    return module
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def named_apply(
    fn: Callable,
    module: nn.Module,
    name: str = "",
    depth_first: bool = True,
    include_root: bool = False,
) -> nn.Module:
    """Call ``fn(module=..., name=...)`` on every sub-module of ``module``.

    Args:
        fn: callback invoked for its side effects.
        module: root of the tree to walk.
        name: qualified name of ``module``.
        depth_first: visit a module after (true) or before (false) its
            children.
        include_root: whether ``module`` itself is visited; children always
            are.

    Returns:
        ``module`` (unchanged reference).
    """
    visit_before_children = include_root and not depth_first
    visit_after_children = include_root and depth_first

    if visit_before_children:
        fn(module=module, name=name)

    for local_name, child in module.named_children():
        qualified = f"{name}.{local_name}" if name else local_name
        named_apply(
            fn=fn,
            module=child,
            name=qualified,
            depth_first=depth_first,
            include_root=True,
        )

    if visit_after_children:
        fn(module=module, name=name)
    return module
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def fix_random_seeds(seed: int = 31):
    """
    Seed the torch (CPU and all CUDA devices), numpy and ``random`` generators.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for seed_fn in seeders:
        seed_fn(seed)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def get_sha() -> str:
    """Describe the git state of the checkout containing this file.

    Best effort: any failure while calling git leaves the fields not yet
    filled at their defaults (``"N/A"`` / ``"clean"``) instead of raising.

    Returns:
        A string of the form ``"sha: ..., status: ..., branch: ..."``.
    """
    repo_dir = os.path.dirname(os.path.abspath(__file__))

    def _git(*arguments):
        output = subprocess.check_output(["git", *arguments], cwd=repo_dir)
        return output.decode("ascii").strip()

    sha, status, branch = "N/A", "clean", "N/A"
    try:
        sha = _git("rev-parse", "HEAD")
        # Output intentionally discarded.
        # NOTE(review): presumably run to refresh the index before
        # diff-index - confirm.
        subprocess.check_output(["git", "diff"], cwd=repo_dir)
        pending = _git("diff-index", "HEAD")
        status = "has uncommited changes" if pending else "clean"
        branch = _git("rev-parse", "--abbrev-ref", "HEAD")
    except Exception:
        pass
    return f"sha: {sha}, status: {status}, branch: {branch}"
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def get_conda_env() -> Tuple[Optional[str], Optional[str]]:
    """Return ``(CONDA_DEFAULT_ENV, CONDA_PREFIX)`` from the environment.

    Either entry is ``None`` when the corresponding variable is not set.
    """
    lookup = os.environ.get
    return lookup("CONDA_DEFAULT_ENV"), lookup("CONDA_PREFIX")
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def count_parameters(module: nn.Module) -> int:
    """Return the total number of elements over all parameters of ``module``."""
    return sum(param.nelement() for param in module.parameters())
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def has_batchnorms(model: nn.Module) -> bool:
    """Tell whether ``model`` (or any sub-module) is a batch-norm layer.

    Checked types: ``BatchNorm1d``, ``BatchNorm2d``, ``BatchNorm3d`` and
    ``SyncBatchNorm``.
    """
    batchnorm_types = (
        nn.BatchNorm1d,
        nn.BatchNorm2d,
        nn.BatchNorm3d,
        nn.SyncBatchNorm,
    )
    return any(
        isinstance(submodule, batchnorm_types)
        for _, submodule in model.named_modules()
    )
|
modelling_vitv2.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Union
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from transformers import PreTrainedModel
|
| 6 |
+
|
| 7 |
+
from configuration_vitv2 import ViTv2Config
|
| 8 |
+
from hf_src.model.image.vitv2.transformer import ViTv2
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class ViTv2PretrainedModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around the ``ViTv2`` backbone."""

    config_class = ViTv2Config

    # Config attributes forwarded verbatim as ViTv2 keyword arguments.
    _BACKBONE_FIELDS = (
        "img_size",
        "patch_size",
        "embed_dim",
        "depth",
        "num_heads",
        "mlp_ratio",
        "init_values",
        "num_register_tokens",
    )

    def __init__(self, config: ViTv2Config):
        """Build the backbone from ``config`` and run ``post_init``."""
        super().__init__(config)

        backbone_kwargs = {
            field: getattr(config, field) for field in self._BACKBONE_FIELDS
        }
        self.backbone = ViTv2(**backbone_kwargs)

        self.post_init()

    def forward(self, *args, **kwargs) -> dict[str, Union[torch.Tensor, None]]:
        """Forward every argument to the backbone and return its output."""
        return self.backbone(*args, **kwargs)
|
requirements.txt
CHANGED
|
@@ -5,3 +5,5 @@ transformers>=4.38.0
|
|
| 5 |
scikit-learn>=1.3.0
|
| 6 |
Pillow>=9.0.0
|
| 7 |
numpy>=1.24.0
|
|
|
|
|
|
|
|
|
| 5 |
scikit-learn>=1.3.0
|
| 6 |
Pillow>=9.0.0
|
| 7 |
numpy>=1.24.0
|
| 8 |
+
einops
|
| 9 |
+
opt_einsum
|