MohmedAnik committed on
Commit
6ebe9b0
·
verified ·
1 Parent(s): 565f88f

Delete vision_tower.py

Browse files
Files changed (1) hide show
  1. vision_tower.py +0 -161
vision_tower.py DELETED
@@ -1,161 +0,0 @@
1
- import torch
2
- from torch import nn
3
- import torch.nn.init as init
4
- import torch.nn.functional as F
5
-
6
- from paths import *
7
-
8
- from typing import Dict, List, Optional, Set, Tuple, Union
9
- from transformers import AutoImageProcessor, AutoModel, Dinov2Model
10
- from transformers.models.dinov2.modeling_dinov2 import Dinov2Embeddings
11
- from transformers.models.dinov2.configuration_dinov2 import Dinov2Config
12
- import numpy as np
13
- from contextlib import nullcontext
14
-
15
def get_activation(activation):
    """Return a fresh activation module selected by (case-insensitive) name.

    Args:
        activation: name of the activation, e.g. ``'gelu'``, ``'silu'``.

    Returns:
        An ``nn.Module`` instance. Any unrecognized name falls back to
        ``nn.ReLU(inplace=True)``, matching the original elif-chain behavior.
    """
    # Hoist the lowercasing (the original re-evaluated it per branch) and
    # dispatch through factories so each call returns a fresh module.
    name = activation.lower()
    factories = {
        'gelu': lambda: nn.GELU(),
        'rrelu': lambda: nn.RReLU(inplace=True),
        'selu': lambda: nn.SELU(inplace=True),
        'silu': lambda: nn.SiLU(inplace=True),
        'hardswish': lambda: nn.Hardswish(inplace=True),
        'leakyrelu': lambda: nn.LeakyReLU(inplace=True),
        'sigmoid': lambda: nn.Sigmoid(),
        'tanh': lambda: nn.Tanh(),
    }
    return factories.get(name, lambda: nn.ReLU(inplace=True))()
34
-
35
-
36
-
37
class MLP_dim(nn.Module):
    """Two-stage MLP head: Linear -> BatchNorm -> activation -> Linear -> BatchNorm.

    Note the final stage applies batch-norm but no activation.
    """

    def __init__(
            self, in_dim=512, out_dim=1024, bias=True, activation='relu'):
        super().__init__()
        self.act = get_activation(activation)
        hidden = int(out_dim)
        # Stage 1: project to the hidden width and apply the nonlinearity.
        self.net1 = nn.Sequential(
            nn.Linear(in_dim, hidden, bias=bias),
            nn.BatchNorm1d(hidden),
            self.act,
        )
        # Stage 2: linear + batch-norm only (no activation on the output).
        self.net2 = nn.Sequential(
            nn.Linear(hidden, out_dim, bias=bias),
            nn.BatchNorm1d(out_dim),
        )

    def forward(self, x):
        hidden_out = self.net1(x)
        return self.net2(hidden_out)
54
-
55
class FLIP_Dinov2Embeddings(Dinov2Embeddings):
    """
    Construct the CLS token, mask token, position and patch embeddings.

    Variant of ``Dinov2Embeddings`` where ``bool_masked_pos`` is interpreted
    as per-sample integer indices of tokens to KEEP (gathered), rather than
    a boolean mask substituting a mask token (the stock behavior, visible in
    the commented-out ``torch.where`` branch).
    """

    def __init__(self, config: Dinov2Config) -> None:
        super().__init__(config)

    def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.Tensor] = None) -> torch.Tensor:
        batch_size = pixel_values.shape[0]
        height, width = pixel_values.shape[2], pixel_values.shape[3]

        # Patch-embed in the projection weight's dtype (handles fp16/bf16 inputs).
        proj_dtype = self.patch_embeddings.projection.weight.dtype
        tokens = self.patch_embeddings(pixel_values.to(dtype=proj_dtype))

        # add the [CLS] token to the embedded patch tokens
        cls_tok = self.cls_token.expand(batch_size, -1, -1)
        tokens = torch.cat((cls_tok, tokens), dim=1)

        # add positional encoding to each token
        tokens = tokens + self.interpolate_pos_encoding(tokens, height, width)

        if bool_masked_pos is not None:
            # Gather the kept token indices per sample (integer indices,
            # NOT a boolean mask, despite the parameter name).
            row_idx = torch.arange(tokens.shape[0]).unsqueeze(1)
            tokens = tokens[row_idx, bool_masked_pos]

        return self.dropout(tokens)
86
-
87
class FLIP_DINOv2(Dinov2Model):
    """``Dinov2Model`` with FLIP-style index-gather embeddings swapped in."""

    def __init__(self, config):
        super().__init__(config)
        # Replace only the embedding layer; the rest of the model is untouched.
        self.embeddings = FLIP_Dinov2Embeddings(config)
92
-
93
class DINOv2_MLP(nn.Module):
    """DINOv2 backbone followed by an MLP head over the [CLS] token.

    Args:
        dino_mode: backbone size, one of 'small' / 'base' / 'large' / 'giant'.
        in_dim: input width of the MLP head (should match the backbone hidden size).
        out_dim: output width of the MLP head.
        evaluate: when False, (re)initialize the head weights and allow masking.
        mask_dino: when True (and not evaluating), randomly drop half of the
            patch tokens each forward pass.
        frozen_back: when True, run the backbone under ``torch.no_grad()``.

    Raises:
        ValueError: if ``dino_mode`` is not a recognized size.
    """

    def __init__(self,
                 dino_mode,
                 in_dim,
                 out_dim,
                 evaluate,
                 mask_dino,
                 frozen_back
                 ) -> None:
        super().__init__()
        checkpoints = {
            'base': DINO_BASE,
            'large': DINO_LARGE,
            'small': DINO_SMALL,
            'giant': DINO_GIANT,
        }
        if dino_mode not in checkpoints:
            # Fail fast: the original if/elif chain silently left
            # self.dinov2 unset, crashing later with an opaque AttributeError.
            raise ValueError(f"unknown dino_mode: {dino_mode!r}")
        self.dinov2 = FLIP_DINOv2.from_pretrained(checkpoints[dino_mode], cache_dir='./')

        self.down_sampler = MLP_dim(in_dim=in_dim, out_dim=out_dim)
        self.random_mask = False
        if not evaluate:
            # BUGFIX: init_weights only acts on nn.Linear; calling it directly
            # on the MLP container was a no-op. apply() visits every submodule
            # so the Linear layers actually receive xavier initialization.
            self.down_sampler.apply(self.init_weights)
            self.random_mask = mask_dino
        # Both torch.no_grad() and nullcontext() are reusable context managers,
        # so a single stored instance can guard every forward call.
        self.forward_mode = torch.no_grad() if frozen_back else nullcontext()

    def forward(self, img_inputs):
        """Encode images and return the down-sampled [CLS] feature.

        Args:
            img_inputs: processor output; must contain 'pixel_values'.
                The dict is NOT mutated (the original wrote
                'bool_masked_pos' into the caller's dict).
        """
        device = self.get_device()

        with self.forward_mode:
            if self.random_mask:
                batch = len(img_inputs['pixel_values'])
                # 256 = assumed number of patch tokens (16x16 grid) --
                # TODO confirm against the processor's image resolution.
                num_patches = 256
                # Keep a random half of the patch indices per sample, sorted,
                # shifted by +1 so index 0 (the [CLS] token) is always kept.
                keep = []
                for _ in range(batch):
                    perm = torch.randperm(num_patches)[:num_patches // 2]
                    keep.append(perm.sort().values + 1)
                keep = torch.stack(keep, dim=0)
                keep = torch.cat(
                    [torch.zeros(batch, 1, dtype=torch.long, device='cpu'), keep],
                    dim=1)
                # Shallow-copy so the caller's dict is left untouched.
                img_inputs = {**img_inputs, 'bool_masked_pos': keep.to(device)}

            dino_outputs = self.dinov2(**img_inputs)
            # Use the [CLS] token as the global image representation.
            cls_feature = dino_outputs.last_hidden_state[:, 0, :]

        # NOTE(review): the head runs OUTSIDE the no_grad context so a frozen
        # backbone does not also freeze the trainable head; the mangled source
        # did not preserve indentation here -- confirm against training code.
        return self.down_sampler(cls_feature)

    def get_device(self):
        """Device of the module's (first) parameter."""
        return next(self.parameters()).device

    def init_weights(self, m):
        """Per-module hook: xavier-init Linear weights, zero their biases."""
        if isinstance(m, nn.Linear):
            init.xavier_uniform_(m.weight)
            if m.bias is not None:
                init.constant_(m.bias, 0)
161
-