jihao
commited on
Commit
·
f73bf08
1
Parent(s):
9d021bd
update eval files
Browse files- eva_vit_model/__init__.py +1 -0
- eva_vit_model/eva_vit.py +575 -0
- eva_vit_model/rope.py +137 -0
- eva_vit_model/transformer.py +625 -0
- eva_vit_model/uta_clip.py +31 -0
- imagenet_zeroshot_data.py +254 -0
- imagenet_zeroshot_eval.py +108 -0
- requirements.txt +6 -0
eva_vit_model/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .uta_clip import CLIP
|
eva_vit_model/eva_vit.py
ADDED
|
@@ -0,0 +1,575 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# --------------------------------------------------------
|
| 2 |
+
# Adapted from https://github.com/microsoft/unilm/tree/master/beit
|
| 3 |
+
# --------------------------------------------------------
|
| 4 |
+
import math
|
| 5 |
+
import os
|
| 6 |
+
from functools import partial
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
try:
|
| 11 |
+
from timm.models.layers import drop_path, to_2tuple, trunc_normal_
|
| 12 |
+
except:
|
| 13 |
+
from timm.layers import drop_path, to_2tuple, trunc_normal_
|
| 14 |
+
|
| 15 |
+
from .transformer import PatchDropout, LayerNorm
|
| 16 |
+
from .rope import VisionRotaryEmbedding, VisionRotaryEmbeddingFast
|
| 17 |
+
|
| 18 |
+
if os.getenv('ENV_TYPE') == 'deepspeed':
|
| 19 |
+
try:
|
| 20 |
+
from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
|
| 21 |
+
except:
|
| 22 |
+
from torch.utils.checkpoint import checkpoint
|
| 23 |
+
else:
|
| 24 |
+
from torch.utils.checkpoint import checkpoint
|
| 25 |
+
|
| 26 |
+
try:
|
| 27 |
+
import xformers.ops as xops
|
| 28 |
+
except ImportError:
|
| 29 |
+
xops = None
|
| 30 |
+
print("Please 'pip install xformers'")
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class DropPath(nn.Module):
|
| 34 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
| 35 |
+
"""
|
| 36 |
+
def __init__(self, drop_prob=None):
|
| 37 |
+
super(DropPath, self).__init__()
|
| 38 |
+
self.drop_prob = drop_prob
|
| 39 |
+
|
| 40 |
+
def forward(self, x):
|
| 41 |
+
return drop_path(x, self.drop_prob, self.training)
|
| 42 |
+
|
| 43 |
+
def extra_repr(self) -> str:
|
| 44 |
+
return 'p={}'.format(self.drop_prob)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class Mlp(nn.Module):
|
| 48 |
+
def __init__(
|
| 49 |
+
self,
|
| 50 |
+
in_features,
|
| 51 |
+
hidden_features=None,
|
| 52 |
+
out_features=None,
|
| 53 |
+
act_layer=nn.GELU,
|
| 54 |
+
norm_layer=nn.LayerNorm,
|
| 55 |
+
drop=0.,
|
| 56 |
+
subln=False,
|
| 57 |
+
|
| 58 |
+
):
|
| 59 |
+
super().__init__()
|
| 60 |
+
out_features = out_features or in_features
|
| 61 |
+
hidden_features = hidden_features or in_features
|
| 62 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
| 63 |
+
self.act = act_layer()
|
| 64 |
+
|
| 65 |
+
self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
|
| 66 |
+
|
| 67 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
| 68 |
+
self.drop = nn.Dropout(drop)
|
| 69 |
+
|
| 70 |
+
def forward(self, x):
|
| 71 |
+
x = self.fc1(x)
|
| 72 |
+
x = self.act(x)
|
| 73 |
+
# x = self.drop(x)
|
| 74 |
+
# commit this for the orignal BERT implement
|
| 75 |
+
x = self.ffn_ln(x)
|
| 76 |
+
|
| 77 |
+
x = self.fc2(x)
|
| 78 |
+
x = self.drop(x)
|
| 79 |
+
return x
|
| 80 |
+
|
| 81 |
+
class SwiGLU(nn.Module):
|
| 82 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.,
|
| 83 |
+
norm_layer=nn.LayerNorm, subln=False):
|
| 84 |
+
super().__init__()
|
| 85 |
+
out_features = out_features or in_features
|
| 86 |
+
hidden_features = hidden_features or in_features
|
| 87 |
+
|
| 88 |
+
self.w1 = nn.Linear(in_features, hidden_features)
|
| 89 |
+
self.w2 = nn.Linear(in_features, hidden_features)
|
| 90 |
+
|
| 91 |
+
self.act = act_layer()
|
| 92 |
+
self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
|
| 93 |
+
self.w3 = nn.Linear(hidden_features, out_features)
|
| 94 |
+
|
| 95 |
+
self.drop = nn.Dropout(drop)
|
| 96 |
+
|
| 97 |
+
def forward(self, x):
|
| 98 |
+
x1 = self.w1(x)
|
| 99 |
+
x2 = self.w2(x)
|
| 100 |
+
hidden = self.act(x1) * x2
|
| 101 |
+
x = self.ffn_ln(hidden)
|
| 102 |
+
x = self.w3(x)
|
| 103 |
+
x = self.drop(x)
|
| 104 |
+
return x
|
| 105 |
+
|
| 106 |
+
class Attention(nn.Module):
|
| 107 |
+
def __init__(
|
| 108 |
+
self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
|
| 109 |
+
proj_drop=0., window_size=None, attn_head_dim=None, xattn=False, rope=None, subln=False, norm_layer=nn.LayerNorm):
|
| 110 |
+
super().__init__()
|
| 111 |
+
self.num_heads = num_heads
|
| 112 |
+
head_dim = dim // num_heads
|
| 113 |
+
if attn_head_dim is not None:
|
| 114 |
+
head_dim = attn_head_dim
|
| 115 |
+
all_head_dim = head_dim * self.num_heads
|
| 116 |
+
self.scale = qk_scale or head_dim ** -0.5
|
| 117 |
+
|
| 118 |
+
self.subln = subln
|
| 119 |
+
if self.subln:
|
| 120 |
+
self.q_proj = nn.Linear(dim, all_head_dim, bias=False)
|
| 121 |
+
self.k_proj = nn.Linear(dim, all_head_dim, bias=False)
|
| 122 |
+
self.v_proj = nn.Linear(dim, all_head_dim, bias=False)
|
| 123 |
+
else:
|
| 124 |
+
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
|
| 125 |
+
|
| 126 |
+
if qkv_bias:
|
| 127 |
+
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
|
| 128 |
+
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
|
| 129 |
+
else:
|
| 130 |
+
self.q_bias = None
|
| 131 |
+
self.v_bias = None
|
| 132 |
+
|
| 133 |
+
if window_size:
|
| 134 |
+
self.window_size = window_size
|
| 135 |
+
self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
|
| 136 |
+
self.relative_position_bias_table = nn.Parameter(
|
| 137 |
+
torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
|
| 138 |
+
# cls to token & token 2 cls & cls to cls
|
| 139 |
+
|
| 140 |
+
# get pair-wise relative position index for each token inside the window
|
| 141 |
+
coords_h = torch.arange(window_size[0])
|
| 142 |
+
coords_w = torch.arange(window_size[1])
|
| 143 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
| 144 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
| 145 |
+
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
| 146 |
+
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
| 147 |
+
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
|
| 148 |
+
relative_coords[:, :, 1] += window_size[1] - 1
|
| 149 |
+
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
|
| 150 |
+
relative_position_index = \
|
| 151 |
+
torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype)
|
| 152 |
+
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
| 153 |
+
relative_position_index[0, 0:] = self.num_relative_distance - 3
|
| 154 |
+
relative_position_index[0:, 0] = self.num_relative_distance - 2
|
| 155 |
+
relative_position_index[0, 0] = self.num_relative_distance - 1
|
| 156 |
+
|
| 157 |
+
self.register_buffer("relative_position_index", relative_position_index)
|
| 158 |
+
else:
|
| 159 |
+
self.window_size = None
|
| 160 |
+
self.relative_position_bias_table = None
|
| 161 |
+
self.relative_position_index = None
|
| 162 |
+
|
| 163 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
| 164 |
+
# self.inner_attn_ln = norm_layer(all_head_dim) if subln else nn.Identity()
|
| 165 |
+
self.inner_attn_ln = nn.Identity()
|
| 166 |
+
# self.proj = nn.Linear(all_head_dim, all_head_dim)
|
| 167 |
+
self.proj = nn.Linear(all_head_dim, dim)
|
| 168 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 169 |
+
self.xattn = xattn
|
| 170 |
+
self.xattn_drop = attn_drop
|
| 171 |
+
|
| 172 |
+
self.rope = rope
|
| 173 |
+
|
| 174 |
+
def forward(self, x, rel_pos_bias=None, attn_mask=None):
|
| 175 |
+
B, N, C = x.shape
|
| 176 |
+
if self.subln:
|
| 177 |
+
q = F.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias)
|
| 178 |
+
k = F.linear(input=x, weight=self.k_proj.weight, bias=None)
|
| 179 |
+
v = F.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias)
|
| 180 |
+
|
| 181 |
+
q = q.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) # B, num_heads, N, C
|
| 182 |
+
k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
|
| 183 |
+
v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
|
| 184 |
+
else:
|
| 185 |
+
|
| 186 |
+
qkv_bias = None
|
| 187 |
+
if self.q_bias is not None:
|
| 188 |
+
qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
|
| 189 |
+
|
| 190 |
+
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
|
| 191 |
+
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) # 3, B, num_heads, N, C
|
| 192 |
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
| 193 |
+
|
| 194 |
+
if self.rope:
|
| 195 |
+
# slightly fast impl
|
| 196 |
+
q_t = q[:, :, 1:, :]
|
| 197 |
+
ro_q_t = self.rope(q_t)
|
| 198 |
+
q = torch.cat((q[:, :, :1, :], ro_q_t), -2).type_as(v)
|
| 199 |
+
|
| 200 |
+
k_t = k[:, :, 1:, :]
|
| 201 |
+
ro_k_t = self.rope(k_t)
|
| 202 |
+
k = torch.cat((k[:, :, :1, :], ro_k_t), -2).type_as(v)
|
| 203 |
+
|
| 204 |
+
if self.xattn:
|
| 205 |
+
q = q.permute(0, 2, 1, 3) # B, num_heads, N, C -> B, N, num_heads, C
|
| 206 |
+
k = k.permute(0, 2, 1, 3)
|
| 207 |
+
v = v.permute(0, 2, 1, 3)
|
| 208 |
+
|
| 209 |
+
x = xops.memory_efficient_attention(
|
| 210 |
+
q, k, v,
|
| 211 |
+
p=self.xattn_drop,
|
| 212 |
+
scale=self.scale,
|
| 213 |
+
)
|
| 214 |
+
x = x.reshape(B, N, -1)
|
| 215 |
+
x = self.inner_attn_ln(x)
|
| 216 |
+
x = self.proj(x)
|
| 217 |
+
x = self.proj_drop(x)
|
| 218 |
+
else:
|
| 219 |
+
q = q * self.scale
|
| 220 |
+
attn = (q @ k.transpose(-2, -1))
|
| 221 |
+
|
| 222 |
+
if self.relative_position_bias_table is not None:
|
| 223 |
+
relative_position_bias = \
|
| 224 |
+
self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
| 225 |
+
self.window_size[0] * self.window_size[1] + 1,
|
| 226 |
+
self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
|
| 227 |
+
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
| 228 |
+
attn = attn + relative_position_bias.unsqueeze(0).type_as(attn)
|
| 229 |
+
|
| 230 |
+
if rel_pos_bias is not None:
|
| 231 |
+
attn = attn + rel_pos_bias.type_as(attn)
|
| 232 |
+
|
| 233 |
+
if attn_mask is not None:
|
| 234 |
+
attn_mask = attn_mask.bool()
|
| 235 |
+
attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf"))
|
| 236 |
+
|
| 237 |
+
attn = attn.softmax(dim=-1)
|
| 238 |
+
attn = self.attn_drop(attn)
|
| 239 |
+
|
| 240 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
|
| 241 |
+
x = self.inner_attn_ln(x)
|
| 242 |
+
x = self.proj(x)
|
| 243 |
+
x = self.proj_drop(x)
|
| 244 |
+
return x
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
class Block(nn.Module):
|
| 248 |
+
|
| 249 |
+
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
|
| 250 |
+
drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
|
| 251 |
+
window_size=None, attn_head_dim=None, xattn=False, rope=None, postnorm=False,
|
| 252 |
+
subln=False, naiveswiglu=False):
|
| 253 |
+
super().__init__()
|
| 254 |
+
self.norm1 = norm_layer(dim)
|
| 255 |
+
self.attn = Attention(
|
| 256 |
+
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
| 257 |
+
attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim,
|
| 258 |
+
xattn=xattn, rope=rope, subln=subln, norm_layer=norm_layer)
|
| 259 |
+
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
| 260 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
| 261 |
+
self.norm2 = norm_layer(dim)
|
| 262 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
| 263 |
+
|
| 264 |
+
if naiveswiglu:
|
| 265 |
+
self.mlp = SwiGLU(
|
| 266 |
+
in_features=dim,
|
| 267 |
+
hidden_features=mlp_hidden_dim,
|
| 268 |
+
subln=subln,
|
| 269 |
+
norm_layer=norm_layer,
|
| 270 |
+
)
|
| 271 |
+
else:
|
| 272 |
+
self.mlp = Mlp(
|
| 273 |
+
in_features=dim,
|
| 274 |
+
hidden_features=mlp_hidden_dim,
|
| 275 |
+
act_layer=act_layer,
|
| 276 |
+
subln=subln,
|
| 277 |
+
drop=drop
|
| 278 |
+
)
|
| 279 |
+
|
| 280 |
+
if init_values is not None and init_values > 0:
|
| 281 |
+
self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
|
| 282 |
+
self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
|
| 283 |
+
else:
|
| 284 |
+
self.gamma_1, self.gamma_2 = None, None
|
| 285 |
+
|
| 286 |
+
self.postnorm = postnorm
|
| 287 |
+
|
| 288 |
+
def forward(self, x, rel_pos_bias=None, attn_mask=None):
|
| 289 |
+
if self.gamma_1 is None:
|
| 290 |
+
if self.postnorm:
|
| 291 |
+
x = x + self.drop_path(self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)))
|
| 292 |
+
x = x + self.drop_path(self.norm2(self.mlp(x)))
|
| 293 |
+
else:
|
| 294 |
+
x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))
|
| 295 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
| 296 |
+
else:
|
| 297 |
+
if self.postnorm:
|
| 298 |
+
x = x + self.drop_path(self.gamma_1 * self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)))
|
| 299 |
+
x = x + self.drop_path(self.gamma_2 * self.norm2(self.mlp(x)))
|
| 300 |
+
else:
|
| 301 |
+
x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))
|
| 302 |
+
x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
|
| 303 |
+
return x
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
class PatchEmbed(nn.Module):
|
| 307 |
+
""" Image to Patch Embedding
|
| 308 |
+
"""
|
| 309 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
|
| 310 |
+
super().__init__()
|
| 311 |
+
img_size = to_2tuple(img_size)
|
| 312 |
+
patch_size = to_2tuple(patch_size)
|
| 313 |
+
num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
|
| 314 |
+
self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
|
| 315 |
+
self.img_size = img_size
|
| 316 |
+
self.patch_size = patch_size
|
| 317 |
+
self.num_patches = num_patches
|
| 318 |
+
|
| 319 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
| 320 |
+
|
| 321 |
+
def forward(self, x, **kwargs):
|
| 322 |
+
B, C, H, W = x.shape
|
| 323 |
+
# FIXME look at relaxing size constraints
|
| 324 |
+
assert H == self.img_size[0] and W == self.img_size[1], \
|
| 325 |
+
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
| 326 |
+
x = self.proj(x).flatten(2).transpose(1, 2)
|
| 327 |
+
return x
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
class RelativePositionBias(nn.Module):
|
| 331 |
+
|
| 332 |
+
def __init__(self, window_size, num_heads):
|
| 333 |
+
super().__init__()
|
| 334 |
+
self.window_size = window_size
|
| 335 |
+
self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
|
| 336 |
+
self.relative_position_bias_table = nn.Parameter(
|
| 337 |
+
torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
|
| 338 |
+
# cls to token & token 2 cls & cls to cls
|
| 339 |
+
|
| 340 |
+
# get pair-wise relative position index for each token inside the window
|
| 341 |
+
coords_h = torch.arange(window_size[0])
|
| 342 |
+
coords_w = torch.arange(window_size[1])
|
| 343 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
| 344 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
| 345 |
+
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
| 346 |
+
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
| 347 |
+
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
|
| 348 |
+
relative_coords[:, :, 1] += window_size[1] - 1
|
| 349 |
+
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
|
| 350 |
+
relative_position_index = \
|
| 351 |
+
torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
|
| 352 |
+
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
| 353 |
+
relative_position_index[0, 0:] = self.num_relative_distance - 3
|
| 354 |
+
relative_position_index[0:, 0] = self.num_relative_distance - 2
|
| 355 |
+
relative_position_index[0, 0] = self.num_relative_distance - 1
|
| 356 |
+
|
| 357 |
+
self.register_buffer("relative_position_index", relative_position_index)
|
| 358 |
+
|
| 359 |
+
def forward(self):
|
| 360 |
+
relative_position_bias = \
|
| 361 |
+
self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
| 362 |
+
self.window_size[0] * self.window_size[1] + 1,
|
| 363 |
+
self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
|
| 364 |
+
return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
class EVAVisionTransformer(nn.Module):
|
| 368 |
+
""" Vision Transformer with support for patch or hybrid CNN input stage
|
| 369 |
+
"""
|
| 370 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
|
| 371 |
+
num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
|
| 372 |
+
drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, patch_dropout=0.,
|
| 373 |
+
use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False, rope=False,
|
| 374 |
+
use_mean_pooling=True, init_scale=0.001, grad_checkpointing=False, xattn=False, postnorm=False,
|
| 375 |
+
pt_hw_seq_len=16, intp_freq=False, naiveswiglu=False, subln=False, head_2mlp=False):
|
| 376 |
+
super().__init__()
|
| 377 |
+
self.image_size = img_size
|
| 378 |
+
self.num_classes = num_classes
|
| 379 |
+
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
| 380 |
+
self.head_2mlp = head_2mlp
|
| 381 |
+
|
| 382 |
+
self.patch_embed = PatchEmbed(
|
| 383 |
+
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
| 384 |
+
num_patches = self.patch_embed.num_patches
|
| 385 |
+
|
| 386 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
| 387 |
+
# self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
| 388 |
+
if use_abs_pos_emb:
|
| 389 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
|
| 390 |
+
else:
|
| 391 |
+
self.pos_embed = None
|
| 392 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
| 393 |
+
|
| 394 |
+
if use_shared_rel_pos_bias:
|
| 395 |
+
self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
|
| 396 |
+
else:
|
| 397 |
+
self.rel_pos_bias = None
|
| 398 |
+
|
| 399 |
+
if rope:
|
| 400 |
+
half_head_dim = embed_dim // num_heads // 2
|
| 401 |
+
hw_seq_len = img_size // patch_size
|
| 402 |
+
self.rope = VisionRotaryEmbeddingFast(
|
| 403 |
+
dim=half_head_dim,
|
| 404 |
+
pt_seq_len=pt_hw_seq_len,
|
| 405 |
+
ft_seq_len=hw_seq_len if intp_freq else None,
|
| 406 |
+
# patch_dropout=patch_dropout
|
| 407 |
+
)
|
| 408 |
+
else:
|
| 409 |
+
self.rope = None
|
| 410 |
+
|
| 411 |
+
self.naiveswiglu = naiveswiglu
|
| 412 |
+
|
| 413 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
| 414 |
+
self.use_rel_pos_bias = use_rel_pos_bias
|
| 415 |
+
self.blocks = nn.ModuleList([
|
| 416 |
+
Block(
|
| 417 |
+
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
| 418 |
+
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
|
| 419 |
+
init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None,
|
| 420 |
+
xattn=xattn, rope=self.rope, postnorm=postnorm, subln=subln, naiveswiglu=naiveswiglu)
|
| 421 |
+
for i in range(depth)])
|
| 422 |
+
self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
|
| 423 |
+
self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
|
| 424 |
+
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
| 425 |
+
|
| 426 |
+
if self.pos_embed is not None:
|
| 427 |
+
trunc_normal_(self.pos_embed, std=.02)
|
| 428 |
+
|
| 429 |
+
trunc_normal_(self.cls_token, std=.02)
|
| 430 |
+
# trunc_normal_(self.mask_token, std=.02)
|
| 431 |
+
|
| 432 |
+
self.apply(self._init_weights)
|
| 433 |
+
self.fix_init_weight()
|
| 434 |
+
|
| 435 |
+
if isinstance(self.head, nn.Linear):
|
| 436 |
+
trunc_normal_(self.head.weight, std=.02)
|
| 437 |
+
self.head.weight.data.mul_(init_scale)
|
| 438 |
+
self.head.bias.data.mul_(init_scale)
|
| 439 |
+
|
| 440 |
+
if head_2mlp:
|
| 441 |
+
self.proj = nn.Linear(embed_dim, 512)
|
| 442 |
+
self.out_norm = norm_layer(512)
|
| 443 |
+
self.head_clip = nn.Linear(512, num_classes)
|
| 444 |
+
del self.head
|
| 445 |
+
|
| 446 |
+
# setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
|
| 447 |
+
self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity()
|
| 448 |
+
|
| 449 |
+
self.grad_checkpointing = grad_checkpointing
|
| 450 |
+
|
| 451 |
+
def fix_init_weight(self):
|
| 452 |
+
def rescale(param, layer_id):
|
| 453 |
+
param.div_(math.sqrt(2.0 * layer_id))
|
| 454 |
+
|
| 455 |
+
for layer_id, layer in enumerate(self.blocks):
|
| 456 |
+
rescale(layer.attn.proj.weight.data, layer_id + 1)
|
| 457 |
+
if self.naiveswiglu:
|
| 458 |
+
rescale(layer.mlp.w3.weight.data, layer_id + 1)
|
| 459 |
+
else:
|
| 460 |
+
rescale(layer.mlp.fc2.weight.data, layer_id + 1)
|
| 461 |
+
|
| 462 |
+
def get_cast_dtype(self) -> torch.dtype:
|
| 463 |
+
return self.blocks[0].mlp.fc2.weight.dtype
|
| 464 |
+
|
| 465 |
+
def _init_weights(self, m):
|
| 466 |
+
if isinstance(m, nn.Linear):
|
| 467 |
+
trunc_normal_(m.weight, std=.02)
|
| 468 |
+
if m.bias is not None:
|
| 469 |
+
nn.init.constant_(m.bias, 0)
|
| 470 |
+
elif isinstance(m, nn.LayerNorm):
|
| 471 |
+
nn.init.constant_(m.bias, 0)
|
| 472 |
+
nn.init.constant_(m.weight, 1.0)
|
| 473 |
+
|
| 474 |
+
def get_num_layers(self):
|
| 475 |
+
return len(self.blocks)
|
| 476 |
+
|
| 477 |
+
def lock(self, unlocked_groups=0, freeze_bn_stats=False):
|
| 478 |
+
assert unlocked_groups == 0, 'partial locking not currently supported for this model'
|
| 479 |
+
for param in self.parameters():
|
| 480 |
+
param.requires_grad = False
|
| 481 |
+
|
| 482 |
+
@torch.jit.ignore
|
| 483 |
+
def set_grad_checkpointing(self, enable=True):
|
| 484 |
+
self.grad_checkpointing = enable
|
| 485 |
+
|
| 486 |
+
@torch.jit.ignore
|
| 487 |
+
def no_weight_decay(self):
|
| 488 |
+
return {'pos_embed', 'cls_token'}
|
| 489 |
+
|
| 490 |
+
def get_classifier(self):
|
| 491 |
+
return self.head
|
| 492 |
+
|
| 493 |
+
def reset_classifier(self, num_classes, global_pool=''):
|
| 494 |
+
self.num_classes = num_classes
|
| 495 |
+
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
| 496 |
+
|
| 497 |
+
def forward_features(self, x, return_all_features=False):
|
| 498 |
+
|
| 499 |
+
x = self.patch_embed(x)
|
| 500 |
+
batch_size, seq_len, _ = x.size()
|
| 501 |
+
|
| 502 |
+
cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
|
| 503 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
| 504 |
+
if self.pos_embed is not None:
|
| 505 |
+
x = x + self.pos_embed
|
| 506 |
+
x = self.pos_drop(x)
|
| 507 |
+
|
| 508 |
+
# a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
|
| 509 |
+
if os.getenv('RoPE') == '1':
|
| 510 |
+
if self.training and not isinstance(self.patch_dropout, nn.Identity):
|
| 511 |
+
x, patch_indices_keep = self.patch_dropout(x)
|
| 512 |
+
self.rope.forward = partial(self.rope.forward, patch_indices_keep=patch_indices_keep)
|
| 513 |
+
else:
|
| 514 |
+
self.rope.forward = partial(self.rope.forward, patch_indices_keep=None)
|
| 515 |
+
x = self.patch_dropout(x)
|
| 516 |
+
else:
|
| 517 |
+
x = self.patch_dropout(x)
|
| 518 |
+
|
| 519 |
+
rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
|
| 520 |
+
for blk in self.blocks:
|
| 521 |
+
if self.grad_checkpointing:
|
| 522 |
+
x = checkpoint(blk, x, (rel_pos_bias,))
|
| 523 |
+
else:
|
| 524 |
+
x = blk(x, rel_pos_bias=rel_pos_bias)
|
| 525 |
+
|
| 526 |
+
if not return_all_features:
|
| 527 |
+
x = self.norm(x)
|
| 528 |
+
if self.fc_norm is not None:
|
| 529 |
+
return self.fc_norm(x.mean(1))
|
| 530 |
+
else:
|
| 531 |
+
return x[:, 0]
|
| 532 |
+
return x
|
| 533 |
+
|
| 534 |
+
def forward(self, x, return_all_features=False):
|
| 535 |
+
if return_all_features:
|
| 536 |
+
return self.forward_features(x, return_all_features)
|
| 537 |
+
x = self.forward_features(x)
|
| 538 |
+
if self.head_2mlp:
|
| 539 |
+
x = self.proj(x)
|
| 540 |
+
x = self.out_norm(x)
|
| 541 |
+
x = self.head_clip(x)
|
| 542 |
+
else:
|
| 543 |
+
x = self.head(x)
|
| 544 |
+
return x
|
| 545 |
+
|
| 546 |
+
|
| 547 |
+
def eva_base_p16():
|
| 548 |
+
model = EVAVisionTransformer(
|
| 549 |
+
depth=12, embed_dim=768, num_heads=12, mlp_ratio=2.6667, num_classes=1024,
|
| 550 |
+
xattn=True, rope=True, intp_freq=True, naiveswiglu=True,
|
| 551 |
+
subln=True, use_mean_pooling=False, qkv_bias=True,
|
| 552 |
+
norm_layer=partial(LayerNorm, eps=1e-6)
|
| 553 |
+
)
|
| 554 |
+
return model
|
| 555 |
+
|
| 556 |
+
def eva_large_p14_336():
|
| 557 |
+
model = EVAVisionTransformer(
|
| 558 |
+
img_size=336,
|
| 559 |
+
depth=24, embed_dim=1024, num_heads=16, mlp_ratio=2.6667,patch_size=14, num_classes=1024,
|
| 560 |
+
xattn=True, rope=True, intp_freq=True, naiveswiglu=True,
|
| 561 |
+
subln=True, use_mean_pooling=False, qkv_bias=True,
|
| 562 |
+
norm_layer=partial(LayerNorm, eps=1e-6)
|
| 563 |
+
)
|
| 564 |
+
return model
|
| 565 |
+
|
| 566 |
+
|
| 567 |
+
def eva_giant_p14_336():
|
| 568 |
+
model = EVAVisionTransformer(
|
| 569 |
+
img_size=336,
|
| 570 |
+
depth=40, embed_dim=1408, num_heads=16, mlp_ratio=2.909133333333333,patch_size=14, num_classes=1024,
|
| 571 |
+
xattn=True, rope=True, intp_freq=True, naiveswiglu=True,
|
| 572 |
+
subln=True, use_mean_pooling=False, qkv_bias=True,
|
| 573 |
+
norm_layer=partial(LayerNorm, eps=1e-6)
|
| 574 |
+
)
|
| 575 |
+
return model
|
eva_vit_model/rope.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from math import pi
|
| 2 |
+
import torch
|
| 3 |
+
from torch import nn
|
| 4 |
+
from einops import rearrange, repeat
|
| 5 |
+
import logging
|
| 6 |
+
|
| 7 |
+
def broadcat(tensors, dim = -1):
|
| 8 |
+
num_tensors = len(tensors)
|
| 9 |
+
shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
|
| 10 |
+
assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions'
|
| 11 |
+
shape_len = list(shape_lens)[0]
|
| 12 |
+
dim = (dim + shape_len) if dim < 0 else dim
|
| 13 |
+
dims = list(zip(*map(lambda t: list(t.shape), tensors)))
|
| 14 |
+
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
|
| 15 |
+
assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), 'invalid dimensions for broadcastable concatentation'
|
| 16 |
+
max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
|
| 17 |
+
expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
|
| 18 |
+
expanded_dims.insert(dim, (dim, dims[dim]))
|
| 19 |
+
expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
|
| 20 |
+
tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
|
| 21 |
+
return torch.cat(tensors, dim = dim)
|
| 22 |
+
|
| 23 |
+
def rotate_half(x):
|
| 24 |
+
x = rearrange(x, '... (d r) -> ... d r', r = 2)
|
| 25 |
+
x1, x2 = x.unbind(dim = -1)
|
| 26 |
+
x = torch.stack((-x2, x1), dim = -1)
|
| 27 |
+
return rearrange(x, '... d r -> ... (d r)')
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class VisionRotaryEmbedding(nn.Module):
|
| 31 |
+
def __init__(
|
| 32 |
+
self,
|
| 33 |
+
dim,
|
| 34 |
+
pt_seq_len,
|
| 35 |
+
ft_seq_len=None,
|
| 36 |
+
custom_freqs = None,
|
| 37 |
+
freqs_for = 'lang',
|
| 38 |
+
theta = 10000,
|
| 39 |
+
max_freq = 10,
|
| 40 |
+
num_freqs = 1,
|
| 41 |
+
):
|
| 42 |
+
super().__init__()
|
| 43 |
+
if custom_freqs:
|
| 44 |
+
freqs = custom_freqs
|
| 45 |
+
elif freqs_for == 'lang':
|
| 46 |
+
freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
|
| 47 |
+
elif freqs_for == 'pixel':
|
| 48 |
+
freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
|
| 49 |
+
elif freqs_for == 'constant':
|
| 50 |
+
freqs = torch.ones(num_freqs).float()
|
| 51 |
+
else:
|
| 52 |
+
raise ValueError(f'unknown modality {freqs_for}')
|
| 53 |
+
|
| 54 |
+
if ft_seq_len is None: ft_seq_len = pt_seq_len
|
| 55 |
+
t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
|
| 56 |
+
|
| 57 |
+
freqs_h = torch.einsum('..., f -> ... f', t, freqs)
|
| 58 |
+
freqs_h = repeat(freqs_h, '... n -> ... (n r)', r = 2)
|
| 59 |
+
|
| 60 |
+
freqs_w = torch.einsum('..., f -> ... f', t, freqs)
|
| 61 |
+
freqs_w = repeat(freqs_w, '... n -> ... (n r)', r = 2)
|
| 62 |
+
|
| 63 |
+
freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim = -1)
|
| 64 |
+
|
| 65 |
+
self.register_buffer("freqs_cos", freqs.cos())
|
| 66 |
+
self.register_buffer("freqs_sin", freqs.sin())
|
| 67 |
+
|
| 68 |
+
logging.info(f'Shape of rope freq: {self.freqs_cos.shape}')
|
| 69 |
+
|
| 70 |
+
def forward(self, t, start_index = 0):
|
| 71 |
+
rot_dim = self.freqs_cos.shape[-1]
|
| 72 |
+
end_index = start_index + rot_dim
|
| 73 |
+
assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
|
| 74 |
+
t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
|
| 75 |
+
t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin)
|
| 76 |
+
|
| 77 |
+
return torch.cat((t_left, t, t_right), dim = -1)
|
| 78 |
+
|
| 79 |
+
class VisionRotaryEmbeddingFast(nn.Module):
|
| 80 |
+
def __init__(
|
| 81 |
+
self,
|
| 82 |
+
dim,
|
| 83 |
+
pt_seq_len,
|
| 84 |
+
ft_seq_len=None,
|
| 85 |
+
custom_freqs = None,
|
| 86 |
+
freqs_for = 'lang',
|
| 87 |
+
theta = 10000,
|
| 88 |
+
max_freq = 10,
|
| 89 |
+
num_freqs = 1,
|
| 90 |
+
patch_dropout = 0.
|
| 91 |
+
):
|
| 92 |
+
super().__init__()
|
| 93 |
+
if custom_freqs:
|
| 94 |
+
freqs = custom_freqs
|
| 95 |
+
elif freqs_for == 'lang':
|
| 96 |
+
freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
|
| 97 |
+
elif freqs_for == 'pixel':
|
| 98 |
+
freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
|
| 99 |
+
elif freqs_for == 'constant':
|
| 100 |
+
freqs = torch.ones(num_freqs).float()
|
| 101 |
+
else:
|
| 102 |
+
raise ValueError(f'unknown modality {freqs_for}')
|
| 103 |
+
|
| 104 |
+
if ft_seq_len is None: ft_seq_len = pt_seq_len
|
| 105 |
+
t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
|
| 106 |
+
|
| 107 |
+
freqs = torch.einsum('..., f -> ... f', t, freqs)
|
| 108 |
+
freqs = repeat(freqs, '... n -> ... (n r)', r = 2)
|
| 109 |
+
freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim = -1)
|
| 110 |
+
|
| 111 |
+
freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
|
| 112 |
+
freqs_sin = freqs.sin().view(-1, freqs.shape[-1])
|
| 113 |
+
|
| 114 |
+
self.patch_dropout = patch_dropout
|
| 115 |
+
|
| 116 |
+
self.register_buffer("freqs_cos", freqs_cos)
|
| 117 |
+
self.register_buffer("freqs_sin", freqs_sin)
|
| 118 |
+
|
| 119 |
+
logging.info(f'Shape of rope freq: {self.freqs_cos.shape}')
|
| 120 |
+
|
| 121 |
+
def forward(self, t, patch_indices_keep=None):
|
| 122 |
+
if patch_indices_keep is not None:
|
| 123 |
+
batch = t.size()[0]
|
| 124 |
+
batch_indices = torch.arange(batch)
|
| 125 |
+
batch_indices = batch_indices[..., None]
|
| 126 |
+
|
| 127 |
+
freqs_cos = repeat(self.freqs_cos, 'i j -> n i m j', n=t.shape[0], m=t.shape[1])
|
| 128 |
+
freqs_sin = repeat(self.freqs_sin, 'i j -> n i m j', n=t.shape[0], m=t.shape[1])
|
| 129 |
+
|
| 130 |
+
freqs_cos = freqs_cos[batch_indices, patch_indices_keep]
|
| 131 |
+
freqs_cos = rearrange(freqs_cos, 'n i m j -> n m i j')
|
| 132 |
+
freqs_sin = freqs_sin[batch_indices, patch_indices_keep]
|
| 133 |
+
freqs_sin = rearrange(freqs_sin, 'n i m j -> n m i j')
|
| 134 |
+
|
| 135 |
+
return t * freqs_cos + rotate_half(t) * freqs_sin
|
| 136 |
+
|
| 137 |
+
return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
|
eva_vit_model/transformer.py
ADDED
|
@@ -0,0 +1,625 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import logging
|
| 3 |
+
from collections import OrderedDict
|
| 4 |
+
import math
|
| 5 |
+
from typing import Callable, Optional, Sequence
|
| 6 |
+
import torch
|
| 7 |
+
from torch import nn
|
| 8 |
+
from torch.nn import functional as F
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
if os.getenv('ENV_TYPE') == 'deepspeed':
|
| 12 |
+
try:
|
| 13 |
+
import deepspeed
|
| 14 |
+
from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
|
| 15 |
+
except:
|
| 16 |
+
print("Please 'pip install deepspeed'")
|
| 17 |
+
deepspeed = None
|
| 18 |
+
from torch.utils.checkpoint import checkpoint
|
| 19 |
+
else:
|
| 20 |
+
from torch.utils.checkpoint import checkpoint
|
| 21 |
+
|
| 22 |
+
try:
|
| 23 |
+
import xformers.ops as xops
|
| 24 |
+
except ImportError:
|
| 25 |
+
xops = None
|
| 26 |
+
print("Please 'pip install xformers'")
|
| 27 |
+
|
| 28 |
+
class LayerNormFp32(nn.LayerNorm):
|
| 29 |
+
"""Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back)."""
|
| 30 |
+
def __init__(self, *args, **kwargs):
|
| 31 |
+
super().__init__(*args, **kwargs)
|
| 32 |
+
|
| 33 |
+
def forward(self, x: torch.Tensor):
|
| 34 |
+
output = F.layer_norm(
|
| 35 |
+
x.float(),
|
| 36 |
+
self.normalized_shape,
|
| 37 |
+
self.weight.float() if self.weight is not None else None,
|
| 38 |
+
self.bias.float() if self.bias is not None else None,
|
| 39 |
+
self.eps,
|
| 40 |
+
)
|
| 41 |
+
return output.type_as(x)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class LayerNorm(nn.LayerNorm):
|
| 45 |
+
"""Subclass torch's LayerNorm (with cast back to input dtype)."""
|
| 46 |
+
|
| 47 |
+
def forward(self, x: torch.Tensor):
|
| 48 |
+
orig_type = x.dtype
|
| 49 |
+
x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
|
| 50 |
+
return x.to(orig_type)
|
| 51 |
+
|
| 52 |
+
class QuickGELU(nn.Module):
|
| 53 |
+
# NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory
|
| 54 |
+
def forward(self, x: torch.Tensor):
|
| 55 |
+
return x * torch.sigmoid(1.702 * x)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class LayerScale(nn.Module):
|
| 59 |
+
def __init__(self, dim, init_values=1e-5, inplace=False):
|
| 60 |
+
super().__init__()
|
| 61 |
+
self.inplace = inplace
|
| 62 |
+
self.gamma = nn.Parameter(init_values * torch.ones(dim))
|
| 63 |
+
|
| 64 |
+
def forward(self, x):
|
| 65 |
+
return x.mul_(self.gamma) if self.inplace else x * self.gamma
|
| 66 |
+
|
| 67 |
+
class PatchDropout(nn.Module):
|
| 68 |
+
"""
|
| 69 |
+
https://arxiv.org/abs/2212.00794
|
| 70 |
+
"""
|
| 71 |
+
|
| 72 |
+
def __init__(self, prob, exclude_first_token=True):
|
| 73 |
+
super().__init__()
|
| 74 |
+
assert 0 <= prob < 1.
|
| 75 |
+
self.prob = prob
|
| 76 |
+
self.exclude_first_token = exclude_first_token # exclude CLS token
|
| 77 |
+
logging.info(f"os.getenv('RoPE')={os.getenv('RoPE')}")
|
| 78 |
+
|
| 79 |
+
def forward(self, x):
|
| 80 |
+
if not self.training or self.prob == 0.:
|
| 81 |
+
return x
|
| 82 |
+
|
| 83 |
+
if self.exclude_first_token:
|
| 84 |
+
cls_tokens, x = x[:, :1], x[:, 1:]
|
| 85 |
+
else:
|
| 86 |
+
cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
|
| 87 |
+
|
| 88 |
+
batch = x.size()[0]
|
| 89 |
+
num_tokens = x.size()[1]
|
| 90 |
+
|
| 91 |
+
batch_indices = torch.arange(batch)
|
| 92 |
+
batch_indices = batch_indices[..., None]
|
| 93 |
+
|
| 94 |
+
keep_prob = 1 - self.prob
|
| 95 |
+
num_patches_keep = max(1, int(num_tokens * keep_prob))
|
| 96 |
+
|
| 97 |
+
rand = torch.randn(batch, num_tokens)
|
| 98 |
+
patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
|
| 99 |
+
|
| 100 |
+
x = x[batch_indices, patch_indices_keep]
|
| 101 |
+
|
| 102 |
+
if self.exclude_first_token:
|
| 103 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
| 104 |
+
|
| 105 |
+
if self.training and os.getenv('RoPE') == '1':
|
| 106 |
+
return x, patch_indices_keep
|
| 107 |
+
|
| 108 |
+
return x
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def _in_projection_packed(
|
| 112 |
+
q: torch.Tensor,
|
| 113 |
+
k: torch.Tensor,
|
| 114 |
+
v: torch.Tensor,
|
| 115 |
+
w: torch.Tensor,
|
| 116 |
+
b: Optional[torch.Tensor] = None,
|
| 117 |
+
):
|
| 118 |
+
"""
|
| 119 |
+
https://github.com/pytorch/pytorch/blob/db2a237763eb8693a20788be94f8c192e762baa8/torch/nn/functional.py#L4726
|
| 120 |
+
"""
|
| 121 |
+
E = q.size(-1)
|
| 122 |
+
if k is v:
|
| 123 |
+
if q is k:
|
| 124 |
+
# self-attention
|
| 125 |
+
return F.linear(q, w, b).chunk(3, dim=-1)
|
| 126 |
+
else:
|
| 127 |
+
# encoder-decoder attention
|
| 128 |
+
w_q, w_kv = w.split([E, E * 2])
|
| 129 |
+
if b is None:
|
| 130 |
+
b_q = b_kv = None
|
| 131 |
+
else:
|
| 132 |
+
b_q, b_kv = b.split([E, E * 2])
|
| 133 |
+
return (F.linear(q, w_q, b_q),) + F.linear(k, w_kv, b_kv).chunk(2, dim=-1)
|
| 134 |
+
else:
|
| 135 |
+
w_q, w_k, w_v = w.chunk(3)
|
| 136 |
+
if b is None:
|
| 137 |
+
b_q = b_k = b_v = None
|
| 138 |
+
else:
|
| 139 |
+
b_q, b_k, b_v = b.chunk(3)
|
| 140 |
+
return F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v)
|
| 141 |
+
|
| 142 |
+
class Attention(nn.Module):
|
| 143 |
+
def __init__(
|
| 144 |
+
self,
|
| 145 |
+
dim,
|
| 146 |
+
num_heads=8,
|
| 147 |
+
qkv_bias=True,
|
| 148 |
+
scaled_cosine=False,
|
| 149 |
+
scale_heads=False,
|
| 150 |
+
logit_scale_max=math.log(1. / 0.01),
|
| 151 |
+
attn_drop=0.,
|
| 152 |
+
proj_drop=0.,
|
| 153 |
+
xattn=False,
|
| 154 |
+
rope=False
|
| 155 |
+
):
|
| 156 |
+
super().__init__()
|
| 157 |
+
self.scaled_cosine = scaled_cosine
|
| 158 |
+
self.scale_heads = scale_heads
|
| 159 |
+
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
|
| 160 |
+
self.num_heads = num_heads
|
| 161 |
+
self.head_dim = dim // num_heads
|
| 162 |
+
self.scale = self.head_dim ** -0.5
|
| 163 |
+
self.logit_scale_max = logit_scale_max
|
| 164 |
+
|
| 165 |
+
# keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original
|
| 166 |
+
self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
|
| 167 |
+
if qkv_bias:
|
| 168 |
+
self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
|
| 169 |
+
else:
|
| 170 |
+
self.in_proj_bias = None
|
| 171 |
+
|
| 172 |
+
if self.scaled_cosine:
|
| 173 |
+
self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
|
| 174 |
+
else:
|
| 175 |
+
self.logit_scale = None
|
| 176 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
| 177 |
+
if self.scale_heads:
|
| 178 |
+
self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
|
| 179 |
+
else:
|
| 180 |
+
self.head_scale = None
|
| 181 |
+
self.out_proj = nn.Linear(dim, dim)
|
| 182 |
+
self.out_drop = nn.Dropout(proj_drop)
|
| 183 |
+
self.xattn = xattn
|
| 184 |
+
self.xattn_drop = attn_drop
|
| 185 |
+
self.rope = rope
|
| 186 |
+
|
| 187 |
+
def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
|
| 188 |
+
L, N, C = x.shape
|
| 189 |
+
q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1)
|
| 190 |
+
if self.xattn:
|
| 191 |
+
q = q.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1)
|
| 192 |
+
k = k.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1)
|
| 193 |
+
v = v.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1)
|
| 194 |
+
|
| 195 |
+
x = xops.memory_efficient_attention(
|
| 196 |
+
q, k, v,
|
| 197 |
+
p=self.xattn_drop,
|
| 198 |
+
scale=self.scale if self.logit_scale is None else None,
|
| 199 |
+
attn_bias=xops.LowerTriangularMask() if attn_mask is not None else None,
|
| 200 |
+
)
|
| 201 |
+
else:
|
| 202 |
+
q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
|
| 203 |
+
k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
|
| 204 |
+
v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
|
| 205 |
+
|
| 206 |
+
if self.logit_scale is not None:
|
| 207 |
+
attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))
|
| 208 |
+
logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
|
| 209 |
+
attn = attn.view(N, self.num_heads, L, L) * logit_scale
|
| 210 |
+
attn = attn.view(-1, L, L)
|
| 211 |
+
else:
|
| 212 |
+
q = q * self.scale
|
| 213 |
+
attn = torch.bmm(q, k.transpose(-1, -2))
|
| 214 |
+
|
| 215 |
+
if attn_mask is not None:
|
| 216 |
+
if attn_mask.dtype == torch.bool:
|
| 217 |
+
new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
|
| 218 |
+
new_attn_mask.masked_fill_(attn_mask, float("-inf"))
|
| 219 |
+
attn_mask = new_attn_mask
|
| 220 |
+
attn += attn_mask
|
| 221 |
+
|
| 222 |
+
attn = attn.softmax(dim=-1)
|
| 223 |
+
attn = self.attn_drop(attn)
|
| 224 |
+
|
| 225 |
+
x = torch.bmm(attn, v)
|
| 226 |
+
|
| 227 |
+
if self.head_scale is not None:
|
| 228 |
+
x = x.view(N, self.num_heads, L, C) * self.head_scale
|
| 229 |
+
x = x.view(-1, L, C)
|
| 230 |
+
x = x.transpose(0, 1).reshape(L, N, C)
|
| 231 |
+
x = self.out_proj(x)
|
| 232 |
+
x = self.out_drop(x)
|
| 233 |
+
return x
|
| 234 |
+
|
| 235 |
+
class CustomAttention(nn.Module):
|
| 236 |
+
def __init__(
|
| 237 |
+
self,
|
| 238 |
+
dim,
|
| 239 |
+
num_heads=8,
|
| 240 |
+
qkv_bias=True,
|
| 241 |
+
scaled_cosine=True,
|
| 242 |
+
scale_heads=False,
|
| 243 |
+
logit_scale_max=math.log(1. / 0.01),
|
| 244 |
+
attn_drop=0.,
|
| 245 |
+
proj_drop=0.,
|
| 246 |
+
xattn=False
|
| 247 |
+
):
|
| 248 |
+
super().__init__()
|
| 249 |
+
self.scaled_cosine = scaled_cosine
|
| 250 |
+
self.scale_heads = scale_heads
|
| 251 |
+
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
|
| 252 |
+
self.num_heads = num_heads
|
| 253 |
+
self.head_dim = dim // num_heads
|
| 254 |
+
self.scale = self.head_dim ** -0.5
|
| 255 |
+
self.logit_scale_max = logit_scale_max
|
| 256 |
+
|
| 257 |
+
# keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original
|
| 258 |
+
self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
|
| 259 |
+
if qkv_bias:
|
| 260 |
+
self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
|
| 261 |
+
else:
|
| 262 |
+
self.in_proj_bias = None
|
| 263 |
+
|
| 264 |
+
if self.scaled_cosine:
|
| 265 |
+
self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
|
| 266 |
+
else:
|
| 267 |
+
self.logit_scale = None
|
| 268 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
| 269 |
+
if self.scale_heads:
|
| 270 |
+
self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
|
| 271 |
+
else:
|
| 272 |
+
self.head_scale = None
|
| 273 |
+
self.out_proj = nn.Linear(dim, dim)
|
| 274 |
+
self.out_drop = nn.Dropout(proj_drop)
|
| 275 |
+
self.xattn = xattn
|
| 276 |
+
self.xattn_drop = attn_drop
|
| 277 |
+
|
| 278 |
+
def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
|
| 279 |
+
q, k, v = _in_projection_packed(query, key, value, self.in_proj_weight, self.in_proj_bias)
|
| 280 |
+
N_q, B_q, C_q = q.shape
|
| 281 |
+
N_k, B_k, C_k = k.shape
|
| 282 |
+
N_v, B_v, C_v = v.shape
|
| 283 |
+
if self.xattn:
|
| 284 |
+
# B, N, C -> B, N, num_heads, C
|
| 285 |
+
q = q.permute(1, 0, 2).reshape(B_q, N_q, self.num_heads, -1)
|
| 286 |
+
k = k.permute(1, 0, 2).reshape(B_k, N_k, self.num_heads, -1)
|
| 287 |
+
v = v.permute(1, 0, 2).reshape(B_v, N_v, self.num_heads, -1)
|
| 288 |
+
|
| 289 |
+
x = xops.memory_efficient_attention(
|
| 290 |
+
q, k, v,
|
| 291 |
+
p=self.xattn_drop,
|
| 292 |
+
scale=self.scale if self.logit_scale is None else None,
|
| 293 |
+
attn_bias=xops.LowerTriangularMask() if attn_mask is not None else None
|
| 294 |
+
)
|
| 295 |
+
else:
|
| 296 |
+
# B*H, L, C
|
| 297 |
+
q = q.contiguous().view(N_q, B_q * self.num_heads, -1).transpose(0, 1)
|
| 298 |
+
k = k.contiguous().view(N_k, B_k * self.num_heads, -1).transpose(0, 1)
|
| 299 |
+
v = v.contiguous().view(N_v, B_v * self.num_heads, -1).transpose(0, 1)
|
| 300 |
+
|
| 301 |
+
if self.logit_scale is not None:
|
| 302 |
+
# B*H, N_q, N_k
|
| 303 |
+
attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))
|
| 304 |
+
logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
|
| 305 |
+
attn = attn.view(B_q, self.num_heads, N_q, N_k) * logit_scale
|
| 306 |
+
attn = attn.view(-1, N_q, N_k)
|
| 307 |
+
else:
|
| 308 |
+
q = q * self.scale
|
| 309 |
+
attn = torch.bmm(q, k.transpose(-1, -2))
|
| 310 |
+
|
| 311 |
+
if attn_mask is not None:
|
| 312 |
+
if attn_mask.dtype == torch.bool:
|
| 313 |
+
new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
|
| 314 |
+
new_attn_mask.masked_fill_(attn_mask, float("-inf"))
|
| 315 |
+
attn_mask = new_attn_mask
|
| 316 |
+
attn += attn_mask
|
| 317 |
+
|
| 318 |
+
attn = attn.softmax(dim=-1)
|
| 319 |
+
attn = self.attn_drop(attn)
|
| 320 |
+
|
| 321 |
+
x = torch.bmm(attn, v)
|
| 322 |
+
|
| 323 |
+
if self.head_scale is not None:
|
| 324 |
+
x = x.view(B_q, self.num_heads, N_q, C_q) * self.head_scale
|
| 325 |
+
x = x.view(-1, N_q, C_q)
|
| 326 |
+
x = x.transpose(0, 1).reshape(N_q, B_q, C_q)
|
| 327 |
+
x = self.out_proj(x)
|
| 328 |
+
x = self.out_drop(x)
|
| 329 |
+
return x
|
| 330 |
+
|
| 331 |
+
class CustomResidualAttentionBlock(nn.Module):
|
| 332 |
+
def __init__(
|
| 333 |
+
self,
|
| 334 |
+
d_model: int,
|
| 335 |
+
n_head: int,
|
| 336 |
+
mlp_ratio: float = 4.0,
|
| 337 |
+
ls_init_value: float = None,
|
| 338 |
+
act_layer: Callable = nn.GELU,
|
| 339 |
+
norm_layer: Callable = LayerNorm,
|
| 340 |
+
scale_cosine_attn: bool = False,
|
| 341 |
+
scale_heads: bool = False,
|
| 342 |
+
scale_attn: bool = False,
|
| 343 |
+
scale_fc: bool = False,
|
| 344 |
+
cross_attn: bool = False,
|
| 345 |
+
xattn: bool = False,
|
| 346 |
+
):
|
| 347 |
+
super().__init__()
|
| 348 |
+
|
| 349 |
+
self.ln_1 = norm_layer(d_model)
|
| 350 |
+
self.ln_1_k = norm_layer(d_model) if cross_attn else self.ln_1
|
| 351 |
+
self.ln_1_v = norm_layer(d_model) if cross_attn else self.ln_1
|
| 352 |
+
self.attn = CustomAttention(
|
| 353 |
+
d_model, n_head,
|
| 354 |
+
qkv_bias=True,
|
| 355 |
+
attn_drop=0.,
|
| 356 |
+
proj_drop=0.,
|
| 357 |
+
scaled_cosine=scale_cosine_attn,
|
| 358 |
+
scale_heads=scale_heads,
|
| 359 |
+
xattn=xattn
|
| 360 |
+
)
|
| 361 |
+
|
| 362 |
+
self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity()
|
| 363 |
+
self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
|
| 364 |
+
|
| 365 |
+
self.ln_2 = norm_layer(d_model)
|
| 366 |
+
mlp_width = int(d_model * mlp_ratio)
|
| 367 |
+
self.mlp = nn.Sequential(OrderedDict([
|
| 368 |
+
("c_fc", nn.Linear(d_model, mlp_width)),
|
| 369 |
+
('ln', norm_layer(mlp_width) if scale_fc else nn.Identity()),
|
| 370 |
+
("gelu", act_layer()),
|
| 371 |
+
("c_proj", nn.Linear(mlp_width, d_model))
|
| 372 |
+
]))
|
| 373 |
+
|
| 374 |
+
self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
|
| 375 |
+
|
| 376 |
+
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
|
| 377 |
+
q = q + self.ls_1(self.ln_attn(self.attn(self.ln_1(q), self.ln_1_k(k), self.ln_1_v(v), attn_mask=attn_mask)))
|
| 378 |
+
q = q + self.ls_2(self.mlp(self.ln_2(q)))
|
| 379 |
+
return q
|
| 380 |
+
|
| 381 |
+
class CustomTransformer(nn.Module):
|
| 382 |
+
def __init__(
|
| 383 |
+
self,
|
| 384 |
+
width: int,
|
| 385 |
+
layers: int,
|
| 386 |
+
heads: int,
|
| 387 |
+
mlp_ratio: float = 4.0,
|
| 388 |
+
ls_init_value: float = None,
|
| 389 |
+
act_layer: Callable = nn.GELU,
|
| 390 |
+
norm_layer: Callable = LayerNorm,
|
| 391 |
+
scale_cosine_attn: bool = True,
|
| 392 |
+
scale_heads: bool = False,
|
| 393 |
+
scale_attn: bool = False,
|
| 394 |
+
scale_fc: bool = False,
|
| 395 |
+
cross_attn: bool = False,
|
| 396 |
+
xattn: bool = False,
|
| 397 |
+
):
|
| 398 |
+
super().__init__()
|
| 399 |
+
self.width = width
|
| 400 |
+
self.layers = layers
|
| 401 |
+
self.grad_checkpointing = False
|
| 402 |
+
self.xattn = xattn
|
| 403 |
+
|
| 404 |
+
self.resblocks = nn.ModuleList([
|
| 405 |
+
CustomResidualAttentionBlock(
|
| 406 |
+
width,
|
| 407 |
+
heads,
|
| 408 |
+
mlp_ratio,
|
| 409 |
+
ls_init_value=ls_init_value,
|
| 410 |
+
act_layer=act_layer,
|
| 411 |
+
norm_layer=norm_layer,
|
| 412 |
+
scale_cosine_attn=scale_cosine_attn,
|
| 413 |
+
scale_heads=scale_heads,
|
| 414 |
+
scale_attn=scale_attn,
|
| 415 |
+
scale_fc=scale_fc,
|
| 416 |
+
cross_attn=cross_attn,
|
| 417 |
+
xattn=xattn)
|
| 418 |
+
for _ in range(layers)
|
| 419 |
+
])
|
| 420 |
+
|
| 421 |
+
def get_cast_dtype(self) -> torch.dtype:
|
| 422 |
+
return self.resblocks[0].mlp.c_fc.weight.dtype
|
| 423 |
+
|
| 424 |
+
def forward(self, q: torch.Tensor, k: torch.Tensor = None, v: torch.Tensor = None, attn_mask: Optional[torch.Tensor] = None):
|
| 425 |
+
if k is None and v is None:
|
| 426 |
+
k = v = q
|
| 427 |
+
for r in self.resblocks:
|
| 428 |
+
if self.grad_checkpointing and not torch.jit.is_scripting():
|
| 429 |
+
q = checkpoint(r, q, k, v, attn_mask)
|
| 430 |
+
else:
|
| 431 |
+
q = r(q, k, v, attn_mask=attn_mask)
|
| 432 |
+
return q
|
| 433 |
+
|
| 434 |
+
|
| 435 |
+
class ResidualAttentionBlock(nn.Module):
|
| 436 |
+
def __init__(
|
| 437 |
+
self,
|
| 438 |
+
d_model: int,
|
| 439 |
+
n_head: int,
|
| 440 |
+
mlp_ratio: float = 4.0,
|
| 441 |
+
ls_init_value: float = None,
|
| 442 |
+
act_layer: Callable = nn.GELU,
|
| 443 |
+
norm_layer: Callable = LayerNorm,
|
| 444 |
+
xattn: bool = False,
|
| 445 |
+
):
|
| 446 |
+
super().__init__()
|
| 447 |
+
|
| 448 |
+
self.ln_1 = norm_layer(d_model)
|
| 449 |
+
if xattn:
|
| 450 |
+
self.attn = Attention(d_model, n_head, xattn=True)
|
| 451 |
+
else:
|
| 452 |
+
self.attn = nn.MultiheadAttention(d_model, n_head)
|
| 453 |
+
self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
|
| 454 |
+
|
| 455 |
+
self.ln_2 = norm_layer(d_model)
|
| 456 |
+
mlp_width = int(d_model * mlp_ratio)
|
| 457 |
+
self.mlp = nn.Sequential(OrderedDict([
|
| 458 |
+
("c_fc", nn.Linear(d_model, mlp_width)),
|
| 459 |
+
("gelu", act_layer()),
|
| 460 |
+
("c_proj", nn.Linear(mlp_width, d_model))
|
| 461 |
+
]))
|
| 462 |
+
|
| 463 |
+
self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
|
| 464 |
+
self.xattn = xattn
|
| 465 |
+
|
| 466 |
+
def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
|
| 467 |
+
attn_mask = attn_mask.to(x.dtype) if attn_mask is not None else None
|
| 468 |
+
if self.xattn:
|
| 469 |
+
return self.attn(x, attn_mask=attn_mask)
|
| 470 |
+
return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0]
|
| 471 |
+
|
| 472 |
+
def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
|
| 473 |
+
x = x + self.ls_1(self.attention(self.ln_1(x), attn_mask=attn_mask))
|
| 474 |
+
x = x + self.ls_2(self.mlp(self.ln_2(x)))
|
| 475 |
+
return x
|
| 476 |
+
|
| 477 |
+
class Transformer(nn.Module):
|
| 478 |
+
def __init__(
|
| 479 |
+
self,
|
| 480 |
+
width: int,
|
| 481 |
+
layers: int,
|
| 482 |
+
heads: int,
|
| 483 |
+
mlp_ratio: float = 4.0,
|
| 484 |
+
ls_init_value: float = None,
|
| 485 |
+
act_layer: Callable = nn.GELU,
|
| 486 |
+
norm_layer: Callable = LayerNorm,
|
| 487 |
+
xattn: bool = False,
|
| 488 |
+
):
|
| 489 |
+
super().__init__()
|
| 490 |
+
self.width = width
|
| 491 |
+
self.layers = layers
|
| 492 |
+
self.grad_checkpointing = False
|
| 493 |
+
|
| 494 |
+
self.resblocks = nn.ModuleList([
|
| 495 |
+
ResidualAttentionBlock(
|
| 496 |
+
width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer, xattn=xattn)
|
| 497 |
+
for _ in range(layers)
|
| 498 |
+
])
|
| 499 |
+
|
| 500 |
+
def get_cast_dtype(self) -> torch.dtype:
|
| 501 |
+
return self.resblocks[0].mlp.c_fc.weight.dtype
|
| 502 |
+
|
| 503 |
+
def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
|
| 504 |
+
for r in self.resblocks:
|
| 505 |
+
if self.grad_checkpointing and not torch.jit.is_scripting():
|
| 506 |
+
x = checkpoint(r, x, attn_mask)
|
| 507 |
+
else:
|
| 508 |
+
x = r(x, attn_mask=attn_mask)
|
| 509 |
+
return x
|
| 510 |
+
|
| 511 |
+
|
| 512 |
+
class TextTransformer(nn.Module):
|
| 513 |
+
def __init__(
|
| 514 |
+
self,
|
| 515 |
+
context_length: int = 77,
|
| 516 |
+
vocab_size: int = 49408,
|
| 517 |
+
width: int = 512,
|
| 518 |
+
heads: int = 8,
|
| 519 |
+
layers: int = 12,
|
| 520 |
+
ls_init_value: float = None,
|
| 521 |
+
output_dim: int = 512,
|
| 522 |
+
act_layer: Callable = nn.GELU,
|
| 523 |
+
norm_layer: Callable = LayerNorm,
|
| 524 |
+
xattn: bool= False,
|
| 525 |
+
attn_mask: bool = True
|
| 526 |
+
):
|
| 527 |
+
super().__init__()
|
| 528 |
+
self.context_length = context_length
|
| 529 |
+
self.vocab_size = vocab_size
|
| 530 |
+
self.width = width
|
| 531 |
+
self.output_dim = output_dim
|
| 532 |
+
|
| 533 |
+
self.token_embedding = nn.Embedding(vocab_size, width)
|
| 534 |
+
self.positional_embedding = nn.Parameter(torch.empty(self.context_length, width))
|
| 535 |
+
self.transformer = Transformer(
|
| 536 |
+
width=width,
|
| 537 |
+
layers=layers,
|
| 538 |
+
heads=heads,
|
| 539 |
+
ls_init_value=ls_init_value,
|
| 540 |
+
act_layer=act_layer,
|
| 541 |
+
norm_layer=norm_layer,
|
| 542 |
+
xattn=xattn
|
| 543 |
+
)
|
| 544 |
+
|
| 545 |
+
self.xattn = xattn
|
| 546 |
+
self.ln_final = norm_layer(width)
|
| 547 |
+
self.text_projection = nn.Parameter(torch.empty(width, output_dim))
|
| 548 |
+
|
| 549 |
+
if attn_mask:
|
| 550 |
+
self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False)
|
| 551 |
+
else:
|
| 552 |
+
self.attn_mask = None
|
| 553 |
+
|
| 554 |
+
self.init_parameters()
|
| 555 |
+
|
| 556 |
+
def init_parameters(self):
|
| 557 |
+
nn.init.normal_(self.token_embedding.weight, std=0.02)
|
| 558 |
+
nn.init.normal_(self.positional_embedding, std=0.01)
|
| 559 |
+
|
| 560 |
+
proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
|
| 561 |
+
attn_std = self.transformer.width ** -0.5
|
| 562 |
+
fc_std = (2 * self.transformer.width) ** -0.5
|
| 563 |
+
for block in self.transformer.resblocks:
|
| 564 |
+
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
|
| 565 |
+
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
|
| 566 |
+
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
|
| 567 |
+
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
|
| 568 |
+
|
| 569 |
+
if self.text_projection is not None:
|
| 570 |
+
nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
|
| 571 |
+
|
| 572 |
+
def lock(self, unlocked_layers:int=0, freeze_layer_norm:bool=True):
|
| 573 |
+
if not unlocked_layers: # full freezing
|
| 574 |
+
for n, p in self.named_parameters():
|
| 575 |
+
p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
|
| 576 |
+
else:
|
| 577 |
+
raise ValueError("Not support partial freeze")
|
| 578 |
+
|
| 579 |
+
@torch.jit.ignore
|
| 580 |
+
def set_grad_checkpointing(self, enable=True):
|
| 581 |
+
self.transformer.grad_checkpointing = enable
|
| 582 |
+
|
| 583 |
+
@torch.jit.ignore
|
| 584 |
+
def no_weight_decay(self):
|
| 585 |
+
# return {'positional_embedding', 'token_embedding'}
|
| 586 |
+
return {'positional_embedding'}
|
| 587 |
+
|
| 588 |
+
def get_num_layers(self):
|
| 589 |
+
return self.transformer.layers
|
| 590 |
+
|
| 591 |
+
def build_attention_mask(self):
|
| 592 |
+
# lazily create causal attention mask, with full attention between the vision tokens
|
| 593 |
+
# pytorch uses additive attention mask; fill with -inf
|
| 594 |
+
mask = torch.empty(self.context_length, self.context_length)
|
| 595 |
+
mask.fill_(float("-inf"))
|
| 596 |
+
mask.triu_(1) # zero out the lower diagonal
|
| 597 |
+
return mask
|
| 598 |
+
|
| 599 |
+
def forward(self, text, return_all_features: bool=False):
|
| 600 |
+
cast_dtype = self.transformer.get_cast_dtype()
|
| 601 |
+
x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
|
| 602 |
+
|
| 603 |
+
x = x + self.positional_embedding.to(cast_dtype)
|
| 604 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
| 605 |
+
x = self.transformer(x, attn_mask=self.attn_mask)
|
| 606 |
+
# x = self.transformer(x) # no attention mask is applied
|
| 607 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
| 608 |
+
x = self.ln_final(x)
|
| 609 |
+
|
| 610 |
+
if not return_all_features:
|
| 611 |
+
# x.shape = [batch_size, n_ctx, transformer.width]
|
| 612 |
+
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
| 613 |
+
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
|
| 614 |
+
return x
|
| 615 |
+
|
| 616 |
+
|
| 617 |
+
def text_transformer():
|
| 618 |
+
model = TextTransformer(
|
| 619 |
+
width=1024,
|
| 620 |
+
output_dim=1024,
|
| 621 |
+
heads=16,
|
| 622 |
+
layers=24,
|
| 623 |
+
xattn=True
|
| 624 |
+
)
|
| 625 |
+
return model
|
eva_vit_model/uta_clip.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
|
| 7 |
+
from . import eva_vit
|
| 8 |
+
from .transformer import text_transformer
|
| 9 |
+
|
| 10 |
+
class CLIP(nn.Module):
|
| 11 |
+
def __init__(
|
| 12 |
+
self,
|
| 13 |
+
vision_model: str = 'eva_base_p16',
|
| 14 |
+
):
|
| 15 |
+
super().__init__()
|
| 16 |
+
self.visual = eva_vit.__dict__[vision_model]()
|
| 17 |
+
self.text = text_transformer()
|
| 18 |
+
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
| 19 |
+
|
| 20 |
+
def encode_image(self, image, normalize: bool = False):
|
| 21 |
+
features = self.visual(image)
|
| 22 |
+
return F.normalize(features, dim=-1) if normalize else features
|
| 23 |
+
|
| 24 |
+
def encode_text(self, text, normalize: bool = False):
|
| 25 |
+
features = self.text(text)
|
| 26 |
+
return F.normalize(features, dim=-1) if normalize else features
|
| 27 |
+
|
| 28 |
+
def forward(self, image, text):
|
| 29 |
+
image_features = self.encode_image(image, normalize=True)
|
| 30 |
+
text_features = self.encode_text(text, normalize=True)
|
| 31 |
+
return image_features, text_features, self.logit_scale.exp()
|
imagenet_zeroshot_data.py
ADDED
|
@@ -0,0 +1,254 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
|
| 3 |
+
imagenet_classnames = ["tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray",
|
| 4 |
+
"stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco",
|
| 5 |
+
"indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper",
|
| 6 |
+
"kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander",
|
| 7 |
+
"smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog",
|
| 8 |
+
"tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", "terrapin",
|
| 9 |
+
"box turtle", "banded gecko", "green iguana", "Carolina anole",
|
| 10 |
+
"desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard",
|
| 11 |
+
"Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile",
|
| 12 |
+
"American alligator", "triceratops", "worm snake", "ring-necked snake",
|
| 13 |
+
"eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake",
|
| 14 |
+
"vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra",
|
| 15 |
+
"green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake",
|
| 16 |
+
"sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider",
|
| 17 |
+
"barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider",
|
| 18 |
+
"tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl",
|
| 19 |
+
"quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet",
|
| 20 |
+
"coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck",
|
| 21 |
+
"red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby",
|
| 22 |
+
"koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch",
|
| 23 |
+
"snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab",
|
| 24 |
+
"fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab",
|
| 25 |
+
"isopod", "white stork", "black stork", "spoonbill", "flamingo", "little blue heron",
|
| 26 |
+
"great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot",
|
| 27 |
+
"bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher",
|
| 28 |
+
"pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion",
|
| 29 |
+
"Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel",
|
| 30 |
+
"Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle",
|
| 31 |
+
"Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound",
|
| 32 |
+
"English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound",
|
| 33 |
+
"Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound",
|
| 34 |
+
"Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire Terrier",
|
| 35 |
+
"Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier",
|
| 36 |
+
"Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier",
|
| 37 |
+
"Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier",
|
| 38 |
+
"Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer",
|
| 39 |
+
"Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier",
|
| 40 |
+
"Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier",
|
| 41 |
+
"Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever",
|
| 42 |
+
"Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla",
|
| 43 |
+
"English Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel",
|
| 44 |
+
"English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel",
|
| 45 |
+
"Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard",
|
| 46 |
+
"Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie",
|
| 47 |
+
"Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann",
|
| 48 |
+
"Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog",
|
| 49 |
+
"Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff",
|
| 50 |
+
"French Bulldog", "Great Dane", "St. Bernard", "husky", "Alaskan Malamute", "Siberian Husky",
|
| 51 |
+
"Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog",
|
| 52 |
+
"Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon",
|
| 53 |
+
"Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle",
|
| 54 |
+
"Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf",
|
| 55 |
+
"red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox",
|
| 56 |
+
"kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat",
|
| 57 |
+
"Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger",
|
| 58 |
+
"cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose",
|
| 59 |
+
"meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle",
|
| 60 |
+
"dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper",
|
| 61 |
+
"cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper",
|
| 62 |
+
"lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly",
|
| 63 |
+
"monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly",
|
| 64 |
+
"starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit",
|
| 65 |
+
"hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse",
|
| 66 |
+
"zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison",
|
| 67 |
+
"ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)",
|
| 68 |
+
"gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat",
|
| 69 |
+
"black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan",
|
| 70 |
+
"gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas monkey", "baboon", "macaque",
|
| 71 |
+
"langur", "black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin",
|
| 72 |
+
"howler monkey", "titi monkey", "Geoffroy's spider monkey", "common squirrel monkey",
|
| 73 |
+
"ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda",
|
| 74 |
+
"giant panda", "snoek fish", "eel", "silver salmon", "rock beauty fish", "clownfish",
|
| 75 |
+
"sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown",
|
| 76 |
+
"accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance",
|
| 77 |
+
"amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle",
|
| 78 |
+
"backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo",
|
| 79 |
+
"baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel",
|
| 80 |
+
"wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel",
|
| 81 |
+
"bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)",
|
| 82 |
+
"beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", "bikini",
|
| 83 |
+
"ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet",
|
| 84 |
+
"bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra",
|
| 85 |
+
"breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest",
|
| 86 |
+
"high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe",
|
| 87 |
+
"can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton",
|
| 88 |
+
"car wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran",
|
| 89 |
+
"CD player", "cello", "mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw",
|
| 90 |
+
"storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking",
|
| 91 |
+
"church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker",
|
| 92 |
+
"coffee mug", "coffeemaker", "spiral or coil", "combination lock", "computer keyboard",
|
| 93 |
+
"candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot",
|
| 94 |
+
"cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed",
|
| 95 |
+
"Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer",
|
| 96 |
+
"rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table",
|
| 97 |
+
"dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig",
|
| 98 |
+
"drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar",
|
| 99 |
+
"electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder",
|
| 100 |
+
"feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute",
|
| 101 |
+
"folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed",
|
| 102 |
+
"freight car", "French horn", "frying pan", "fur coat", "garbage truck",
|
| 103 |
+
"gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola",
|
| 104 |
+
"gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine",
|
| 105 |
+
"hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer",
|
| 106 |
+
"handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet",
|
| 107 |
+
"holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar",
|
| 108 |
+
"horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep",
|
| 109 |
+
"T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat",
|
| 110 |
+
"ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library",
|
| 111 |
+
"lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion",
|
| 112 |
+
"music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger bag",
|
| 113 |
+
"mailbox", "tights", "one-piece bathing suit", "manhole cover", "maraca", "marimba", "mask",
|
| 114 |
+
"matchstick", "maypole", "maze", "measuring cup", "medicine cabinet", "megalith", "microphone",
|
| 115 |
+
"microwave oven", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile",
|
| 116 |
+
"mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor",
|
| 117 |
+
"moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa",
|
| 118 |
+
"mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail",
|
| 119 |
+
"neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina",
|
| 120 |
+
"odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart",
|
| 121 |
+
"oxygen mask", "product packet / packaging", "paddle", "paddle wheel", "padlock", "paintbrush",
|
| 122 |
+
"pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench",
|
| 123 |
+
"parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case",
|
| 124 |
+
"pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", "Pickelhaube",
|
| 125 |
+
"picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball",
|
| 126 |
+
"pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag",
|
| 127 |
+
"plate rack", "farm plow", "plunger", "Polaroid camera", "pole", "police van", "poncho",
|
| 128 |
+
"pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug",
|
| 129 |
+
"printer", "prison", "missile", "projector", "hockey puck", "punching bag", "purse", "quill",
|
| 130 |
+
"quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel",
|
| 131 |
+
"recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator",
|
| 132 |
+
"remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser",
|
| 133 |
+
"rugby ball", "ruler measuring stick", "sneaker", "safe", "safety pin", "salt shaker", "sandal",
|
| 134 |
+
"sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard",
|
| 135 |
+
"CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store",
|
| 136 |
+
"shoji screen / room divider", "shopping basket", "shopping cart", "shovel", "shower cap",
|
| 137 |
+
"shower curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door",
|
| 138 |
+
"slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock",
|
| 139 |
+
"solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater",
|
| 140 |
+
"space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight",
|
| 141 |
+
"stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf",
|
| 142 |
+
"stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa",
|
| 143 |
+
"submarine", "suit", "sundial", "sunglasses", "sunglasses", "sunscreen", "suspension bridge",
|
| 144 |
+
"mop", "sweatshirt", "swim trunks / shorts", "swing", "electrical switch", "syringe",
|
| 145 |
+
"table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball",
|
| 146 |
+
"thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof",
|
| 147 |
+
"toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store",
|
| 148 |
+
"tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod",
|
| 149 |
+
"triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard",
|
| 150 |
+
"umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling",
|
| 151 |
+
"velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball",
|
| 152 |
+
"waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink",
|
| 153 |
+
"washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle",
|
| 154 |
+
"hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing",
|
| 155 |
+
"wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website",
|
| 156 |
+
"comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu",
|
| 157 |
+
"plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette",
|
| 158 |
+
"bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", "cabbage", "broccoli",
|
| 159 |
+
"cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber",
|
| 160 |
+
"artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange",
|
| 161 |
+
"lemon", "fig", "pineapple", "banana", "jackfruit", "cherimoya (custard apple)", "pomegranate",
|
| 162 |
+
"hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito",
|
| 163 |
+
"red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef",
|
| 164 |
+
"geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player",
|
| 165 |
+
"bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn",
|
| 166 |
+
"rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom",
|
| 167 |
+
"earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper"]
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
openai_imagenet_template = [
|
| 174 |
+
lambda c: f'a bad photo of a {c}.',
|
| 175 |
+
lambda c: f'a photo of many {c}.',
|
| 176 |
+
lambda c: f'a sculpture of a {c}.',
|
| 177 |
+
lambda c: f'a photo of the hard to see {c}.',
|
| 178 |
+
lambda c: f'a low resolution photo of the {c}.',
|
| 179 |
+
lambda c: f'a rendering of a {c}.',
|
| 180 |
+
lambda c: f'graffiti of a {c}.',
|
| 181 |
+
lambda c: f'a bad photo of the {c}.',
|
| 182 |
+
lambda c: f'a cropped photo of the {c}.',
|
| 183 |
+
lambda c: f'a tattoo of a {c}.',
|
| 184 |
+
lambda c: f'the embroidered {c}.',
|
| 185 |
+
lambda c: f'a photo of a hard to see {c}.',
|
| 186 |
+
lambda c: f'a bright photo of a {c}.',
|
| 187 |
+
lambda c: f'a photo of a clean {c}.',
|
| 188 |
+
lambda c: f'a photo of a dirty {c}.',
|
| 189 |
+
lambda c: f'a dark photo of the {c}.',
|
| 190 |
+
lambda c: f'a drawing of a {c}.',
|
| 191 |
+
lambda c: f'a photo of my {c}.',
|
| 192 |
+
lambda c: f'the plastic {c}.',
|
| 193 |
+
lambda c: f'a photo of the cool {c}.',
|
| 194 |
+
lambda c: f'a close-up photo of a {c}.',
|
| 195 |
+
lambda c: f'a black and white photo of the {c}.',
|
| 196 |
+
lambda c: f'a painting of the {c}.',
|
| 197 |
+
lambda c: f'a painting of a {c}.',
|
| 198 |
+
lambda c: f'a pixelated photo of the {c}.',
|
| 199 |
+
lambda c: f'a sculpture of the {c}.',
|
| 200 |
+
lambda c: f'a bright photo of the {c}.',
|
| 201 |
+
lambda c: f'a cropped photo of a {c}.',
|
| 202 |
+
lambda c: f'a plastic {c}.',
|
| 203 |
+
lambda c: f'a photo of the dirty {c}.',
|
| 204 |
+
lambda c: f'a jpeg corrupted photo of a {c}.',
|
| 205 |
+
lambda c: f'a blurry photo of the {c}.',
|
| 206 |
+
lambda c: f'a photo of the {c}.',
|
| 207 |
+
lambda c: f'a good photo of the {c}.',
|
| 208 |
+
lambda c: f'a rendering of the {c}.',
|
| 209 |
+
lambda c: f'a {c} in a video game.',
|
| 210 |
+
lambda c: f'a photo of one {c}.',
|
| 211 |
+
lambda c: f'a doodle of a {c}.',
|
| 212 |
+
lambda c: f'a close-up photo of the {c}.',
|
| 213 |
+
lambda c: f'a photo of a {c}.',
|
| 214 |
+
lambda c: f'the origami {c}.',
|
| 215 |
+
lambda c: f'the {c} in a video game.',
|
| 216 |
+
lambda c: f'a sketch of a {c}.',
|
| 217 |
+
lambda c: f'a doodle of the {c}.',
|
| 218 |
+
lambda c: f'a origami {c}.',
|
| 219 |
+
lambda c: f'a low resolution photo of a {c}.',
|
| 220 |
+
lambda c: f'the toy {c}.',
|
| 221 |
+
lambda c: f'a rendition of the {c}.',
|
| 222 |
+
lambda c: f'a photo of the clean {c}.',
|
| 223 |
+
lambda c: f'a photo of a large {c}.',
|
| 224 |
+
lambda c: f'a rendition of a {c}.',
|
| 225 |
+
lambda c: f'a photo of a nice {c}.',
|
| 226 |
+
lambda c: f'a photo of a weird {c}.',
|
| 227 |
+
lambda c: f'a blurry photo of a {c}.',
|
| 228 |
+
lambda c: f'a cartoon {c}.',
|
| 229 |
+
lambda c: f'art of a {c}.',
|
| 230 |
+
lambda c: f'a sketch of the {c}.',
|
| 231 |
+
lambda c: f'a embroidered {c}.',
|
| 232 |
+
lambda c: f'a pixelated photo of a {c}.',
|
| 233 |
+
lambda c: f'itap of the {c}.',
|
| 234 |
+
lambda c: f'a jpeg corrupted photo of the {c}.',
|
| 235 |
+
lambda c: f'a good photo of a {c}.',
|
| 236 |
+
lambda c: f'a plushie {c}.',
|
| 237 |
+
lambda c: f'a photo of the nice {c}.',
|
| 238 |
+
lambda c: f'a photo of the small {c}.',
|
| 239 |
+
lambda c: f'a photo of the weird {c}.',
|
| 240 |
+
lambda c: f'the cartoon {c}.',
|
| 241 |
+
lambda c: f'art of the {c}.',
|
| 242 |
+
lambda c: f'a drawing of the {c}.',
|
| 243 |
+
lambda c: f'a photo of the large {c}.',
|
| 244 |
+
lambda c: f'a black and white photo of a {c}.',
|
| 245 |
+
lambda c: f'the plushie {c}.',
|
| 246 |
+
lambda c: f'a dark photo of a {c}.',
|
| 247 |
+
lambda c: f'itap of a {c}.',
|
| 248 |
+
lambda c: f'graffiti of the {c}.',
|
| 249 |
+
lambda c: f'a toy {c}.',
|
| 250 |
+
lambda c: f'itap of my {c}.',
|
| 251 |
+
lambda c: f'a photo of a cool {c}.',
|
| 252 |
+
lambda c: f'a photo of a small {c}.',
|
| 253 |
+
lambda c: f'a tattoo of the {c}.',
|
| 254 |
+
]
|
imagenet_zeroshot_eval.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from tqdm import tqdm
|
| 3 |
+
import argparse
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch.utils.data import DataLoader
|
| 7 |
+
import torchvision.transforms as transforms
|
| 8 |
+
import torchvision.datasets as datasets
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
|
| 11 |
+
from timm.data.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
|
| 12 |
+
|
| 13 |
+
import eva_vit_model
|
| 14 |
+
from eva_vit_model import CLIP
|
| 15 |
+
from open_clip.tokenizer import tokenize
|
| 16 |
+
from imagenet_zeroshot_data import imagenet_classnames, openai_imagenet_template
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def main(args):
|
| 20 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 21 |
+
if torch.cuda.is_available():
|
| 22 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 23 |
+
torch.backends.cudnn.benchmark = True
|
| 24 |
+
torch.backends.cudnn.deterministic = False
|
| 25 |
+
torch.backends.cudnn.allow_tf32 = True
|
| 26 |
+
|
| 27 |
+
print(f"creating model: {args.model}")
|
| 28 |
+
model = CLIP(vision_model=args.model)
|
| 29 |
+
|
| 30 |
+
print(f"loading checkpoint from {args.ckpt_path}")
|
| 31 |
+
state_dict = torch.load(args.ckpt_path, map_location='cpu')
|
| 32 |
+
model.load_state_dict(state_dict, strict=True)
|
| 33 |
+
model.to(device)
|
| 34 |
+
|
| 35 |
+
def _convert_image_to_rgb(image):
|
| 36 |
+
return image.convert("RGB")
|
| 37 |
+
|
| 38 |
+
val_transform = transforms.Compose([
|
| 39 |
+
transforms.Resize(args.image_size, transforms.InterpolationMode.BICUBIC),
|
| 40 |
+
transforms.CenterCrop(args.image_size),
|
| 41 |
+
_convert_image_to_rgb,
|
| 42 |
+
transforms.ToTensor(),
|
| 43 |
+
transforms.Normalize(mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD)
|
| 44 |
+
])
|
| 45 |
+
|
| 46 |
+
val_dataset = datasets.ImageFolder(os.path.join(args.imagenet_path, 'val'), transform=val_transform)
|
| 47 |
+
val_loader = DataLoader(val_dataset, batch_size=args.batch_size, num_workers=args.workers)
|
| 48 |
+
|
| 49 |
+
model.eval()
|
| 50 |
+
classifier = zero_shot_classifier(model, imagenet_classnames, openai_imagenet_template, device)
|
| 51 |
+
top1, top5 = zero_shot_eval(model, classifier, val_loader, device)
|
| 52 |
+
print(f'ImageNet zeroshot top1: {top1:.4f}, top5: {top5:.4f}')
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def zero_shot_classifier(model, classnames, templates, device):
|
| 56 |
+
tokenizer = tokenize
|
| 57 |
+
|
| 58 |
+
with torch.no_grad():
|
| 59 |
+
zeroshot_weights = []
|
| 60 |
+
for classname in tqdm(classnames):
|
| 61 |
+
texts = [template(classname) for template in templates] # format with class
|
| 62 |
+
texts = tokenizer(texts).to(device=device) # tokenize
|
| 63 |
+
with torch.cuda.amp.autocast():
|
| 64 |
+
class_embeddings = model.encode_text(texts)
|
| 65 |
+
class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0)
|
| 66 |
+
class_embedding /= class_embedding.norm()
|
| 67 |
+
zeroshot_weights.append(class_embedding)
|
| 68 |
+
zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device)
|
| 69 |
+
return zeroshot_weights
|
| 70 |
+
|
| 71 |
+
def accuracy(output, target, topk=(1,)):
|
| 72 |
+
pred = output.topk(max(topk), 1, True, True)[1].t()
|
| 73 |
+
correct = pred.eq(target.view(1, -1).expand_as(pred))
|
| 74 |
+
return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk]
|
| 75 |
+
|
| 76 |
+
def zero_shot_eval(model, classifier, dataloader, device):
|
| 77 |
+
top1, top5, n = 0., 0., 0.
|
| 78 |
+
with torch.no_grad():
|
| 79 |
+
for images, target in tqdm(dataloader, unit_scale=args.batch_size):
|
| 80 |
+
images = images.to(device=device)
|
| 81 |
+
target = target.to(device=device)
|
| 82 |
+
|
| 83 |
+
with torch.cuda.amp.autocast():
|
| 84 |
+
image_features = model.encode_image(images)
|
| 85 |
+
image_features = F.normalize(image_features, dim=-1)
|
| 86 |
+
logits = 100. * image_features @ classifier
|
| 87 |
+
|
| 88 |
+
# measure accuracy
|
| 89 |
+
acc1, acc5 = accuracy(logits, target, topk=(1, 5))
|
| 90 |
+
top1 += acc1
|
| 91 |
+
top5 += acc5
|
| 92 |
+
n += images.size(0)
|
| 93 |
+
|
| 94 |
+
top1 = (top1 / n)
|
| 95 |
+
top5 = (top5 / n)
|
| 96 |
+
return top1, top5
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
if __name__ == '__main__':
|
| 100 |
+
parser = argparse.ArgumentParser(description='ImageNet zero shot evaluations', add_help=False)
|
| 101 |
+
parser.add_argument('--imagenet-path', default='path/to/imagenet', type=str, help='path to imagenet dataset')
|
| 102 |
+
parser.add_argument('--ckpt-path', default='path/to/ckpt', type=str, help='path to checkpoint')
|
| 103 |
+
parser.add_argument('--batch-size', default=64, type=int, help='batch size')
|
| 104 |
+
parser.add_argument('--model', default='eva_base_p16', type=str, help='model')
|
| 105 |
+
parser.add_argument('--image-size', default=224, type=int, help='image size for evaluation')
|
| 106 |
+
parser.add_argument('--workers', default=8, type=int)
|
| 107 |
+
args = parser.parse_args()
|
| 108 |
+
main(args)
|
requirements.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
tqdm
|
| 2 |
+
timm
|
| 3 |
+
torch
|
| 4 |
+
open_clip
|
| 5 |
+
torchvision
|
| 6 |
+
xformers
|