serdaryildiz committed on
Commit
2b8e195
·
verified ·
1 Parent(s): affe4a1

Upload 24 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,11 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ images/test1.png filter=lfs diff=lfs merge=lfs -text
37
+ images/test2.JPG filter=lfs diff=lfs merge=lfs -text
38
+ images/test3.png filter=lfs diff=lfs merge=lfs -text
39
+ images/test4.png filter=lfs diff=lfs merge=lfs -text
40
+ images/test5.png filter=lfs diff=lfs merge=lfs -text
41
+ images/test6.png filter=lfs diff=lfs merge=lfs -text
42
+ images/test7.png filter=lfs diff=lfs merge=lfs -text
43
+ images/test8.png filter=lfs diff=lfs merge=lfs -text
Model/TRCaptionNet.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import numpy
4
+ from PIL import Image
5
+ from torchvision import transforms
6
+
7
+ from Model import clip
8
+ import torch
9
+ from torch import nn
10
+ from transformers import AutoTokenizer, BertTokenizer
11
+ from Model.bert import BertLMHeadModel, BertConfig
12
+ from Model.clip.model import Transformer
13
+ from Model.dino import DinoV2
14
+
15
+
16
class Proj(nn.Module):
    """Project vision-encoder token features into the 768-dim decoder space.

    A single-layer Transformer refines the encoder tokens, then a linear
    layer maps them to the BERT hidden size (768).
    """

    def __init__(self, encoder_output_size, num_head=16):
        super().__init__()
        self.encoder_output_size = encoder_output_size
        # one transformer layer over the encoder tokens, then project to 768
        self.transformer = Transformer(encoder_output_size, 1, num_head)
        self.linear = nn.Linear(encoder_output_size, 768)

    def forward(self, x):
        # the CLIP-style Transformer expects sequence-first input
        seq_first = x.permute(1, 0, 2)  # NLD -> LND
        refined = self.transformer(seq_first)
        batch_first = refined.permute(1, 0, 2)  # LND -> NLD
        return self.linear(batch_first)
31
+
32
+
33
class TRCaptionNetpp(nn.Module):
    """Turkish image-captioning model.

    Frozen vision encoder (CLIP visual tower or DINOv2) -> optional projection
    to 768 dims -> BERT-based causal language decoder with cross-attention.
    """

    def __init__(self, config: dict):
        """Build the model from a config dict.

        Expected keys:
            max_length (int): maximum caption token length.
            proj (bool): whether to project encoder features to 768 dims.
            proj_num_head (int | None): heads for the transformer projection;
                None selects a plain linear projection.
            clip (str) or dino2 (str): which vision backbone to load.
            bert (str): HF model name OR a path to a BertConfig json file.
        """
        super().__init__()
        # parameters
        self.max_length = config["max_length"]
        self.proj_flag = config["proj"]
        assert type(self.proj_flag) is bool
        self.proj_num_head = config["proj_num_head"]

        # vision encoder
        if "clip" in config:
            self.vision_encoder, preprocess = clip.load(config["clip"], jit=False)
            self.vision_encoder.eval()
            self.vision_encoder = self.vision_encoder.visual
            # probe the encoder with a dummy image to discover its output width
            with torch.no_grad():
                dummy_input_image = preprocess(Image.fromarray(numpy.zeros((512, 512, 3), dtype=numpy.uint8))).to(
                    next(self.parameters()).device).half()
                encoder_output_size = self.vision_encoder(dummy_input_image.unsqueeze(0)).shape[-1]
        elif "dino2" in config:
            self.vision_encoder = DinoV2(config["dino2"])
            encoder_output_size = self.vision_encoder.get_output_dim()
        else:
            raise Exception("Image Encoder Init Error!")

        # language decoder: either a pretrained HF checkpoint (name/path not a
        # file) or a fresh model from a local BertConfig json, in which case
        # the tokenizer falls back to the Turkish BERT vocabulary.
        if not os.path.isfile(config["bert"]):
            self.language_decoder = BertLMHeadModel.from_pretrained(config["bert"],
                                                                    is_decoder=True,
                                                                    add_cross_attention=True)
            self.tokenizer = BertTokenizer.from_pretrained(config["bert"])
        else:
            med_config = BertConfig.from_json_file(config["bert"])
            self.language_decoder = BertLMHeadModel(config=med_config)
            self.tokenizer = BertTokenizer.from_pretrained("dbmdz/bert-base-turkish-cased")

        # proj
        if self.proj_flag:
            if self.proj_num_head is None:
                self.proj = nn.Linear(encoder_output_size, 768)
            else:
                self.proj = Proj(encoder_output_size, self.proj_num_head)
        else:
            self.proj = None

    def forward(self, images, captions):
        """Compute the language-modeling loss for a batch of images/captions."""
        # vision features are frozen: no grad through the encoder
        with torch.no_grad():
            image_embeds = self.vision_encoder(images.half()).float().detach()

        # BUGFIX: the original applied self.proj unconditionally, crashing with
        # TypeError when config["proj"] is False (self.proj is None);
        # generate() already guards the same way.
        if self.proj is not None:
            image_embeds = self.proj(image_embeds)

        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(images.device)

        captions = self.tokenizer(captions, padding='longest', truncation=True, max_length=self.max_length,
                                  return_tensors="pt").to(images.device)

        # token id 2 marks begin-of-sequence for the decoder
        captions.input_ids[:, 0] = 2
        # ignore padding and the BOS position in the LM loss
        decoder_targets = captions.input_ids.masked_fill(captions.input_ids == self.tokenizer.pad_token_id, -100)
        decoder_targets[:, 0] = -100

        decoder_output = self.language_decoder(input_ids=captions.input_ids,
                                               attention_mask=captions.attention_mask,
                                               encoder_hidden_states=image_embeds,
                                               encoder_attention_mask=image_atts,
                                               labels=decoder_targets,
                                               return_dict=True,
                                               )

        loss_lm = decoder_output.loss
        return loss_lm

    @torch.no_grad()
    def generate(self, images, max_length: int = None, min_length: int = 12, num_beams: int = 3,
                 repetition_penalty: float = 1.1):
        """Beam-search captions for a batch of images; returns decoded strings."""
        image_embeds = self.vision_encoder(images.half()).float()

        if self.proj is not None:
            image_embeds = self.proj(image_embeds)

        image_atts = torch.ones(image_embeds.shape[:-1], dtype=torch.long).to(images.device)
        model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask": image_atts}

        # every sequence starts from the BOS token (id 2)
        input_ids = torch.ones((image_embeds.shape[0], 1), device=images.device, dtype=torch.long)
        input_ids *= 2

        outputs = self.language_decoder.generate(input_ids=input_ids,
                                                 max_length=self.max_length if max_length is None else max_length,
                                                 min_length=min_length,
                                                 num_beams=num_beams,
                                                 eos_token_id=self.tokenizer.sep_token_id,
                                                 pad_token_id=self.tokenizer.pad_token_id,
                                                 repetition_penalty=repetition_penalty,
                                                 **model_kwargs)

        captions = [self.tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
        return captions
Model/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .TRCaptionNet import TRCaptionNetpp
Model/bert/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .med import BertLMHeadModel, BertConfig
Model/bert/med.py ADDED
@@ -0,0 +1,958 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ * Copyright (c) 2022, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ * Based on huggingface code base
8
+ * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
9
+ '''
10
+
11
+ import math
12
+ import os
13
+ import warnings
14
+ from dataclasses import dataclass
15
+ from typing import Optional, Tuple
16
+
17
+ import torch
18
+ from torch import Tensor, device, dtype, nn
19
+ import torch.utils.checkpoint
20
+ from torch import nn
21
+ from torch.nn import CrossEntropyLoss
22
+ import torch.nn.functional as F
23
+
24
+ from transformers.activations import ACT2FN
25
+ from transformers.file_utils import (
26
+ ModelOutput,
27
+ )
28
+ from transformers.modeling_outputs import (
29
+ BaseModelOutputWithPastAndCrossAttentions,
30
+ BaseModelOutputWithPoolingAndCrossAttentions,
31
+ CausalLMOutputWithCrossAttentions,
32
+ MaskedLMOutput,
33
+ MultipleChoiceModelOutput,
34
+ NextSentencePredictorOutput,
35
+ QuestionAnsweringModelOutput,
36
+ SequenceClassifierOutput,
37
+ TokenClassifierOutput,
38
+ )
39
+ from transformers.modeling_utils import (
40
+ PreTrainedModel,
41
+ apply_chunking_to_forward,
42
+ find_pruneable_heads_and_indices,
43
+ prune_linear_layer,
44
+ )
45
+ from transformers.utils import logging
46
+ from transformers.models.bert.configuration_bert import BertConfig
47
+
48
+
49
+ logger = logging.get_logger(__name__)
50
+
51
+
52
class BertEmbeddings(nn.Module):
    """Construct the embeddings from word and position embeddings."""

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)

        # self.LayerNorm keeps its TensorFlow-style capitalisation so that
        # TensorFlow checkpoints can still be loaded by variable name.
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")

        self.config = config

    def forward(
        self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
    ):
        """Embed token ids (or pre-computed embeddings) plus absolute positions."""
        shape = input_ids.size() if input_ids is not None else inputs_embeds.size()[:-1]
        seq_length = shape[1]

        if position_ids is None:
            # offset by past_key_values_length so cached decoding stays aligned
            position_ids = self.position_ids[:, past_key_values_length: seq_length + past_key_values_length]

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        embeddings = inputs_embeds
        if self.position_embedding_type == "absolute":
            embeddings += self.position_embeddings(position_ids)

        return self.dropout(self.LayerNorm(embeddings))
+ return embeddings
95
+
96
+
97
class BertSelfAttention(nn.Module):
    """Multi-head self- or cross-attention with key/value caching.

    NOTE(review): unlike the upstream BLIP version, key/value are always built
    from ``config.hidden_size`` (the ``encoder_width`` branch is commented
    out), so cross-attention assumes the encoder hidden states already have
    ``hidden_size`` features — confirm against the vision projection.
    """

    def __init__(self, config, is_cross_attention):
        super().__init__()
        self.config = config
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads)
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # self.query = nn.Linear(config.hidden_size, self.all_head_size)
        # if is_cross_attention:
        #     self.key = nn.Linear(config.encoder_width, self.all_head_size)
        #     self.value = nn.Linear(config.encoder_width, self.all_head_size)
        # else:
        #     self.key = nn.Linear(config.hidden_size, self.all_head_size)
        #     self.value = nn.Linear(config.hidden_size, self.all_head_size)
        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
        # when True (and cross-attending), attention maps/gradients are stashed
        # on the module for later inspection
        self.save_attention = False

    def save_attn_gradients(self, attn_gradients):
        # backward-hook target: stores attention gradients for visualisation
        self.attn_gradients = attn_gradients

    def get_attn_gradients(self):
        return self.attn_gradients

    def save_attention_map(self, attention_map):
        self.attention_map = attention_map

    def get_attention_map(self):
        return self.attention_map

    def transpose_for_scores(self, x):
        # (B, L, all_head_size) -> (B, num_heads, L, head_size)
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        """Attend; returns (context, [attention_probs], (key, value) cache)."""
        mixed_query_layer = self.query(hidden_states)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            # incremental decoding: append the new token's K/V to the cache
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        query_layer = self.transpose_for_scores(mixed_query_layer)

        # cache returned to the caller for the next decoding step
        past_key_value = (key_layer, value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            seq_length = hidden_states.size()[1]
            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r
            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        # scaled dot-product attention
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        if is_cross_attention and self.save_attention:
            self.save_attention_map(attention_probs)
            attention_probs.register_hook(self.save_attn_gradients)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs_dropped = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs_dropped = attention_probs_dropped * head_mask

        context_layer = torch.matmul(attention_probs_dropped, value_layer)

        # (B, heads, L, head_size) -> (B, L, all_head_size)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        # the K/V cache is always the last tuple element
        outputs = outputs + (past_key_value,)
        return outputs
229
+
230
+
231
class BertSelfOutput(nn.Module):
    """Post-attention projection: dense -> dropout -> residual add -> LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        # project and regularise, then normalise the residual sum
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
243
+
244
+
245
class BertAttention(nn.Module):
    """Attention block: BertSelfAttention plus the residual/LayerNorm output
    projection, with support for pruning attention heads."""

    def __init__(self, config, is_cross_attention=False):
        super().__init__()
        self.self = BertSelfAttention(config, is_cross_attention)
        self.output = BertSelfOutput(config)
        # heads already removed, so repeated prune calls stay consistent
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given heads from query/key/value and the output dense."""
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        """Run attention; returns (output, [attention_probs], kv_cache)."""
        self_outputs = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        attention_output = self.output(self_outputs[0], hidden_states)
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs
292
+
293
+
294
class BertIntermediate(nn.Module):
    """Feed-forward expansion: dense to intermediate_size plus activation."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # config.hidden_act may be a name (looked up in ACT2FN) or a callable
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states):
        return self.intermediate_act_fn(self.dense(hidden_states))
307
+
308
+
309
class BertOutput(nn.Module):
    """Feed-forward contraction: dense back to hidden_size, dropout,
    residual add, LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
321
+
322
+
323
class BertLayer(nn.Module):
    """One transformer layer: self-attention, cross-attention (only when
    mode == 'multimodal'), then the chunked feed-forward block."""

    def __init__(self, config, layer_num):
        super().__init__()
        self.config = config
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = BertAttention(config)
        self.layer_num = layer_num
        if self.config.add_cross_attention:
            self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
        mode=None,
    ):
        """Returns (layer_output, [attention weights...], present_key_value)."""
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
        )
        attention_output = self_attention_outputs[0]

        # middle elements are attention weights (if requested); last is the cache
        outputs = self_attention_outputs[1:-1]
        present_key_value = self_attention_outputs[-1]

        # Cross-attention over the encoder states runs only in multimodal mode.
        if mode == 'multimodal':
            assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"

            cross_attention_outputs = self.crossattention(
                attention_output,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                output_attentions=output_attentions,
            )
            attention_output = cross_attention_outputs[0]
            outputs = outputs + cross_attention_outputs[1:-1]  # add cross attentions if we output attention weights
        # feed-forward, optionally chunked along the sequence dimension
        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )
        outputs = (layer_output,) + outputs

        outputs = outputs + (present_key_value,)

        return outputs

    def feed_forward_chunk(self, attention_output):
        # one chunk of the intermediate -> output feed-forward transform
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output
387
+
388
+
389
class BertEncoder(nn.Module):
    """Stack of BertLayer modules with optional gradient checkpointing and
    per-layer key/value caching for incremental decoding."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([BertLayer(config, i) for i in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
        mode='multimodal',
    ):
        """Run every layer; returns BaseModelOutputWithPastAndCrossAttentions,
        or a tuple of the non-None fields when return_dict is False."""
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

        next_decoder_cache = () if use_cache else None

        for i in range(self.config.num_hidden_layers):
            layer_module = self.layer[i]
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:

                if use_cache:
                    logger.warn(
                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                    )
                    use_cache = False

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, past_key_value, output_attentions)

                    return custom_forward

                # NOTE(review): `mode=mode` is passed as a keyword to
                # checkpoint(), but custom_forward accepts only positionals —
                # verify this path actually works on the torch version in use.
                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    mode=mode,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                    mode=mode,
                )

            hidden_states = layer_outputs[0]
            if use_cache:
                # each layer's (key, value) cache is its last output element
                next_decoder_cache += (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )
487
+
488
+
489
class BertPooler(nn.Module):
    """Pools the sequence by transforming the first token's hidden state."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # the first ([CLS]) token stands in for the whole sequence
        cls_state = hidden_states[:, 0]
        return self.activation(self.dense(cls_state))
502
+
503
+
504
class BertPredictionHeadTransform(nn.Module):
    """Dense + activation + LayerNorm applied before the LM decoder head."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # activation may be given by name (ACT2FN lookup) or as a callable
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        return self.LayerNorm(self.transform_act_fn(self.dense(hidden_states)))
519
+
520
+
521
class BertLMPredictionHead(nn.Module):
    """Maps hidden states to vocabulary logits for language modelling."""

    def __init__(self, config):
        super().__init__()
        self.transform = BertPredictionHeadTransform(config)

        # The output weights are typically tied to the input embeddings; the
        # bias is an output-only parameter, one entry per vocabulary token.
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
        # Link the two so the bias is resized together with the decoder by
        # `resize_token_embeddings`.
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        return self.decoder(self.transform(hidden_states))
539
+
540
+
541
class BertOnlyMLMHead(nn.Module):
    """Thin wrapper exposing the LM prediction head under HF's naming."""

    def __init__(self, config):
        super().__init__()
        self.predictions = BertLMPredictionHead(config)

    def forward(self, sequence_output):
        return self.predictions(sequence_output)
549
+
550
+
551
class BertPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = BertConfig
    base_model_prefix = "bert"
    # position_ids is a registered buffer rebuilt at init; missing in old checkpoints
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def _init_weights(self, module):
        """ Initialize the weights """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        # Linear biases are zeroed separately (Embedding has no bias attribute).
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
572
+
573
+
574
+ class BertModel(BertPreTrainedModel):
575
+ """
576
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
577
+ cross-attention is added between the self-attention layers, following the architecture described in `Attention is
578
+ all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
579
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
580
+ argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
581
+ input to the forward pass.
582
+ """
583
+
584
    def __init__(self, config, add_pooling_layer=True):
        """Build embeddings, encoder and (optionally) the pooler.

        Args:
            config: BertConfig carrying model dimensions and flags.
            add_pooling_layer: when False, ``self.pooler`` is None and no
                pooled first-token output is available.
        """
        super().__init__(config)
        self.config = config

        self.embeddings = BertEmbeddings(config)

        self.encoder = BertEncoder(config)

        self.pooler = BertPooler(config) if add_pooling_layer else None

        self.init_weights()
595
+
596
+
597
    def get_input_embeddings(self):
        """Return the word-embedding table (HF accessor convention)."""
        return self.embeddings.word_embeddings
599
+
600
    def set_input_embeddings(self, value):
        """Replace the word-embedding table (used for embedding resizing/tying)."""
        self.embeddings.word_embeddings = value
602
+
603
    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            # delegate to each layer's BertAttention.prune_heads
            self.encoder.layer[layer].attention.prune_heads(heads)
610
+
611
+
612
+ def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor:
613
+ """
614
+ Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
615
+
616
+ Arguments:
617
+ attention_mask (:obj:`torch.Tensor`):
618
+ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
619
+ input_shape (:obj:`Tuple[int]`):
620
+ The shape of the input to the model.
621
+ device: (:obj:`torch.device`):
622
+ The device of the input to the model.
623
+
624
+ Returns:
625
+ :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
626
+ """
627
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
628
+ # ourselves in which case we just need to make it broadcastable to all heads.
629
+ if attention_mask.dim() == 3:
630
+ extended_attention_mask = attention_mask[:, None, :, :]
631
+ elif attention_mask.dim() == 2:
632
+ # Provided a padding mask of dimensions [batch_size, seq_length]
633
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
634
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
635
+ if is_decoder:
636
+ batch_size, seq_length = input_shape
637
+
638
+ seq_ids = torch.arange(seq_length, device=device)
639
+ causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
640
+ # in case past_key_values are used we need to add a prefix ones mask to the causal mask
641
+ # causal and attention masks must have same type with pytorch version < 1.3
642
+ causal_mask = causal_mask.to(attention_mask.dtype)
643
+
644
+ if causal_mask.shape[1] < attention_mask.shape[1]:
645
+ prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
646
+ causal_mask = torch.cat(
647
+ [
648
+ torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
649
+ causal_mask,
650
+ ],
651
+ axis=-1,
652
+ )
653
+
654
+ extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
655
+ else:
656
+ extended_attention_mask = attention_mask[:, None, None, :]
657
+ else:
658
+ raise ValueError(
659
+ "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
660
+ input_shape, attention_mask.shape
661
+ )
662
+ )
663
+
664
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
665
+ # masked positions, this operation will create a tensor which is 0.0 for
666
+ # positions we want to attend and -10000.0 for masked positions.
667
+ # Since we are adding it to the raw scores before the softmax, this is
668
+ # effectively the same as removing these entirely.
669
+ extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
670
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
671
+ return extended_attention_mask
672
+
673
+ def forward(
674
+ self,
675
+ input_ids=None,
676
+ attention_mask=None,
677
+ position_ids=None,
678
+ head_mask=None,
679
+ inputs_embeds=None,
680
+ encoder_embeds=None,
681
+ encoder_hidden_states=None,
682
+ encoder_attention_mask=None,
683
+ past_key_values=None,
684
+ use_cache=None,
685
+ output_attentions=None,
686
+ output_hidden_states=None,
687
+ return_dict=None,
688
+ is_decoder=False,
689
+ mode='multimodal',
690
+ ):
691
+ r"""
692
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
693
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
694
+ the model is configured as a decoder.
695
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
696
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
697
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
698
+ - 1 for tokens that are **not masked**,
699
+ - 0 for tokens that are **masked**.
700
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
701
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
702
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
703
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
704
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
705
+ use_cache (:obj:`bool`, `optional`):
706
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
707
+ decoding (see :obj:`past_key_values`).
708
+ """
709
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
710
+ output_hidden_states = (
711
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
712
+ )
713
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
714
+
715
+ if is_decoder:
716
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
717
+ else:
718
+ use_cache = False
719
+
720
+ if input_ids is not None and inputs_embeds is not None:
721
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
722
+ elif input_ids is not None:
723
+ input_shape = input_ids.size()
724
+ batch_size, seq_length = input_shape
725
+ device = input_ids.device
726
+ elif inputs_embeds is not None:
727
+ input_shape = inputs_embeds.size()[:-1]
728
+ batch_size, seq_length = input_shape
729
+ device = inputs_embeds.device
730
+ elif encoder_embeds is not None:
731
+ input_shape = encoder_embeds.size()[:-1]
732
+ batch_size, seq_length = input_shape
733
+ device = encoder_embeds.device
734
+ else:
735
+ raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds")
736
+
737
+ # past_key_values_length
738
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
739
+
740
+ if attention_mask is None:
741
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
742
+
743
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
744
+ # ourselves in which case we just need to make it broadcastable to all heads.
745
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape,
746
+ device, is_decoder)
747
+
748
+ # If a 2D or 3D attention mask is provided for the cross-attention
749
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
750
+ if encoder_hidden_states is not None:
751
+ if type(encoder_hidden_states) == list:
752
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
753
+ else:
754
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
755
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
756
+
757
+ if type(encoder_attention_mask) == list:
758
+ encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
759
+ elif encoder_attention_mask is None:
760
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
761
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
762
+ else:
763
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
764
+ else:
765
+ encoder_extended_attention_mask = None
766
+
767
+ # Prepare head mask if needed
768
+ # 1.0 in head_mask indicate we keep the head
769
+ # attention_probs has shape bsz x n_heads x N x N
770
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
771
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
772
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
773
+
774
+ if encoder_embeds is None:
775
+ embedding_output = self.embeddings(
776
+ input_ids=input_ids,
777
+ position_ids=position_ids,
778
+ inputs_embeds=inputs_embeds,
779
+ past_key_values_length=past_key_values_length,
780
+ )
781
+ else:
782
+ embedding_output = encoder_embeds
783
+
784
+ encoder_outputs = self.encoder(
785
+ embedding_output,
786
+ attention_mask=extended_attention_mask,
787
+ head_mask=head_mask,
788
+ encoder_hidden_states=encoder_hidden_states,
789
+ encoder_attention_mask=encoder_extended_attention_mask,
790
+ past_key_values=past_key_values,
791
+ use_cache=use_cache,
792
+ output_attentions=output_attentions,
793
+ output_hidden_states=output_hidden_states,
794
+ return_dict=return_dict,
795
+ mode=mode,
796
+ )
797
+ sequence_output = encoder_outputs[0]
798
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
799
+
800
+ if not return_dict:
801
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
802
+
803
+ return BaseModelOutputWithPoolingAndCrossAttentions(
804
+ last_hidden_state=sequence_output,
805
+ pooler_output=pooled_output,
806
+ past_key_values=encoder_outputs.past_key_values,
807
+ hidden_states=encoder_outputs.hidden_states,
808
+ attentions=encoder_outputs.attentions,
809
+ cross_attentions=encoder_outputs.cross_attentions,
810
+ )
811
+
812
+
813
+
814
class BertLMHeadModel(BertPreTrainedModel):
    """BERT with a language-modeling head, used as an autoregressive caption decoder.

    Wraps :class:`BertModel` (without pooler) plus an MLM head and computes a
    next-token prediction loss when ``labels`` are supplied. Also implements
    the hooks (:meth:`prepare_inputs_for_generation`, :meth:`_reorder_cache`)
    required by the HF ``generate()`` API.
    """

    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]

    def __init__(self, config):
        super().__init__(config)

        # No pooling layer: only per-token hidden states are needed for LM scoring.
        self.bert = BertModel(config, add_pooling_layer=False)
        self.cls = BertOnlyMLMHead(config)

        self.init_weights()

    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        self.cls.predictions.decoder = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        return_logits=False,
        is_decoder=True,
        reduction='mean',
        mode='multimodal',
    ):
        r"""
        encoder_hidden_states  (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
            ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
            ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]``
        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
        use_cache (:obj:`bool`, `optional`):
            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
            decoding (see :obj:`past_key_values`).
        Returns:
        Example::
            >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
            >>> import torch
            >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
            >>> config = BertConfig.from_pretrained("bert-base-cased")
            >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
            >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
            >>> outputs = model(**inputs)
            >>> prediction_logits = outputs.logits
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if labels is not None:
            # Caching is pointless during training (full sequence is scored at once).
            use_cache = False

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            is_decoder=is_decoder,
            mode=mode,
        )

        sequence_output = outputs[0]
        prediction_scores = self.cls(sequence_output)

        if return_logits:
            # Drop the last position: it has no next-token target.
            return prediction_scores[:, :-1, :].contiguous()

        lm_loss = None
        if labels is not None:
            # we are doing next-token prediction; shift prediction scores and input ids by one
            shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
            labels = labels[:, 1:].contiguous()
            # Label smoothing of 0.1 is applied to the cross-entropy loss.
            loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
            lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
            if reduction=='none':
                # Keep a per-sequence loss (summed over tokens) instead of a scalar.
                lm_loss = lm_loss.view(prediction_scores.size(0),-1).sum(1)

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((lm_loss,) + output) if lm_loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=lm_loss,
            logits=prediction_scores,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )

    def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
        # NOTE(review): uses the legacy `past` kwarg of the HF generation API —
        # confirm against the installed transformers version.
        input_shape = input_ids.shape
        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_shape)

        # cut decoder_input_ids if past is used
        if past is not None:
            input_ids = input_ids[:, -1:]

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "past_key_values": past,
            "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
            "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
            "is_decoder": True,
        }

    def _reorder_cache(self, past, beam_idx):
        # Reorders the KV cache along the batch dimension to follow beam search.
        reordered_past = ()
        for layer_past in past:
            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
        return reordered_past
Model/clip/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .clip import *
Model/clip/bpe_simple_vocab_16e6.txt.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
3
+ size 1356917
Model/clip/clip.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ import os
3
+ import urllib
4
+ import warnings
5
+ from typing import Any, Union, List
6
+ from pkg_resources import packaging
7
+
8
+ import torch
9
+ from PIL import Image
10
+ from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
11
+ from tqdm import tqdm
12
+
13
+ from .model import build_model
14
+ from .simple_tokenizer import SimpleTokenizer as _Tokenizer
15
+
16
+ try:
17
+ from torchvision.transforms import InterpolationMode
18
+ BICUBIC = InterpolationMode.BICUBIC
19
+ except ImportError:
20
+ BICUBIC = Image.BICUBIC
21
+
22
+
23
+ if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
24
+ warnings.warn("PyTorch version 1.7.1 or higher is recommended")
25
+
26
+
27
+ __all__ = ["available_models", "load", "tokenize"]
28
+ _tokenizer = _Tokenizer()
29
+
30
+ _MODELS = {
31
+ "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
32
+ "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
33
+ "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
34
+ "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
35
+ "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
36
+ "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
37
+ "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
38
+ "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
39
+ "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
40
+ }
41
+
42
+
43
+ def _download(url: str, root: str):
44
+ os.makedirs(root, exist_ok=True)
45
+ filename = os.path.basename(url)
46
+
47
+ expected_sha256 = url.split("/")[-2]
48
+ download_target = os.path.join(root, filename)
49
+
50
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
51
+ raise RuntimeError(f"{download_target} exists and is not a regular file")
52
+
53
+ if os.path.isfile(download_target):
54
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
55
+ return download_target
56
+ else:
57
+ warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
58
+
59
+ with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
60
+ with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
61
+ while True:
62
+ buffer = source.read(8192)
63
+ if not buffer:
64
+ break
65
+
66
+ output.write(buffer)
67
+ loop.update(len(buffer))
68
+
69
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
70
+ raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match")
71
+
72
+ return download_target
73
+
74
+
75
def _convert_image_to_rgb(image):
    """Force a PIL image into 3-channel RGB mode (handles RGBA/L/P inputs)."""
    return image.convert("RGB")
77
+
78
+
79
def _transform(n_px):
    """Build the CLIP image preprocessing pipeline for a model with input
    resolution ``n_px``: bicubic resize, center crop, RGB conversion, tensor
    conversion, and normalization with the per-channel mean/std constants
    shipped with CLIP."""
    return Compose([
        Resize(n_px, interpolation=BICUBIC),
        CenterCrop(n_px),
        _convert_image_to_rgb,
        ToTensor(),
        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
    ])
87
+
88
+
89
def available_models() -> List[str]:
    """Return the names of the CLIP models that can be loaded by `clip.load`."""
    # Iterating the dict yields its keys in insertion order.
    return list(_MODELS)
92
+
93
+
94
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
    """Load a CLIP model

    Parameters
    ----------
    name : str
        A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict

    device : Union[str, torch.device]
        The device to put the loaded model

    jit : bool
        Whether to load the optimized JIT model or more hackable non-JIT model (default).

    download_root: str
        path to download the model files; by default, it uses "~/.cache/clip"

    Returns
    -------
    model : torch.nn.Module
        The CLIP model

    preprocess : Callable[[PIL.Image], torch.Tensor]
        A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
    """
    if name in _MODELS:
        model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
    elif os.path.isfile(name):
        model_path = name
    else:
        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")

    with open(model_path, 'rb') as opened_file:
        try:
            # loading JIT archive
            model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
            state_dict = None
        except RuntimeError:
            # loading saved state dict
            if jit:
                warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
                jit = False
            state_dict = torch.load(opened_file, map_location="cpu")

    if not jit:
        # Rebuild a plain (hackable) nn.Module from the state dict.
        model = build_model(state_dict or model.state_dict()).to(device)
        if str(device) == "cpu":
            # CLIP checkpoints are fp16; fp16 is not supported on CPU.
            model.float()
        return model, _transform(model.visual.input_resolution)

    # --- JIT path: rewrite constants baked into the TorchScript graph. ---
    # patch the device names
    device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
    device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]

    def patch_device(module):
        # Replace hard-coded "cuda" device constants with the requested device.
        try:
            graphs = [module.graph] if hasattr(module, "graph") else []
        except RuntimeError:
            graphs = []

        if hasattr(module, "forward1"):
            graphs.append(module.forward1.graph)

        for graph in graphs:
            for node in graph.findAllNodes("prim::Constant"):
                if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
                    node.copyAttributes(device_node)

    model.apply(patch_device)
    patch_device(model.encode_image)
    patch_device(model.encode_text)

    # patch dtype to float32 on CPU
    if str(device) == "cpu":
        float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
        float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
        float_node = float_input.node()

        def patch_float(module):
            # Rewrite aten::to(..., torch.float16) casts to float32 for CPU execution.
            try:
                graphs = [module.graph] if hasattr(module, "graph") else []
            except RuntimeError:
                graphs = []

            if hasattr(module, "forward1"):
                graphs.append(module.forward1.graph)

            for graph in graphs:
                for node in graph.findAllNodes("aten::to"):
                    inputs = list(node.inputs())
                    for i in [1, 2]:  # dtype can be the second or third argument to aten::to()
                        if inputs[i].node()["value"] == 5:
                            inputs[i].node().copyAttributes(float_node)

        model.apply(patch_float)
        patch_float(model.encode_image)
        patch_float(model.encode_text)

        model.float()

    return model, _transform(model.input_resolution.item())
195
+
196
+
197
def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
    """
    Returns the tokenized representation of given input string(s)

    Parameters
    ----------
    texts : Union[str, List[str]]
        An input string or a list of input strings to tokenize

    context_length : int
        The context length to use; all CLIP models use 77 as the context length

    truncate: bool
        Whether to truncate the text in case its encoding is longer than the context length

    Returns
    -------
    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
    We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
    """
    if isinstance(texts, str):
        texts = [texts]

    # Wrap every caption with the start/end sentinel tokens.
    sot_token = _tokenizer.encoder["<|startoftext|>"]
    eot_token = _tokenizer.encoder["<|endoftext|>"]
    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
    if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
        result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
    else:
        result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)

    for i, tokens in enumerate(all_tokens):
        if len(tokens) > context_length:
            if truncate:
                # Keep the first context_length-1 tokens and force the final
                # position to be the end-of-text token.
                tokens = tokens[:context_length]
                tokens[-1] = eot_token
            else:
                raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
        # Remaining positions stay zero-padded.
        result[i, :len(tokens)] = torch.tensor(tokens)

    return result
Model/clip/model.py ADDED
@@ -0,0 +1,437 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+ from typing import Tuple, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+
9
+
10
class Bottleneck(nn.Module):
    """CLIP-style ResNet bottleneck block (1x1 -> 3x3 -> 1x1 convolutions).

    Differs from torchvision's bottleneck: all convolutions have stride 1 and
    downsampling is done with an explicit AvgPool2d ("anti-aliased" striding).
    """

    # Channel expansion factor of the final 1x1 convolution.
    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()

        # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu2 = nn.ReLU(inplace=True)

        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()

        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu3 = nn.ReLU(inplace=True)

        self.downsample = None
        self.stride = stride

        if stride > 1 or inplanes != planes * Bottleneck.expansion:
            # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
            self.downsample = nn.Sequential(OrderedDict([
                ("-1", nn.AvgPool2d(stride)),
                ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
                ("1", nn.BatchNorm2d(planes * self.expansion))
            ]))

    def forward(self, x: torch.Tensor):
        identity = x

        out = self.relu1(self.bn1(self.conv1(x)))
        out = self.relu2(self.bn2(self.conv2(out)))
        out = self.avgpool(out)
        out = self.bn3(self.conv3(out))

        # Match the residual branch's spatial size / channels when needed.
        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu3(out)
        return out
56
+
57
+
58
class AttentionPool2d(nn.Module):
    """QKV attention pooling over a spatial feature map.

    Flattens an NCHW feature map, prepends the mean feature as a query token,
    adds learned positional embeddings, and runs one round of multi-head
    attention; the output for the mean token is returned as the pooled vector.
    """

    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        super().__init__()
        # One position per spatial location, plus one for the prepended mean token.
        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1)  # NCHW -> (HW)NC
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC
        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
        # Functional MHA with separate q/k/v projection weights.
        x, _ = F.multi_head_attention_forward(
            query=x, key=x, value=x,
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False
        )

        # Return only the output at the mean-token position.
        return x[0]
93
+
94
+
95
class ModifiedResNet(nn.Module):
    """
    A ResNet class that is similar to torchvision's but contains the following changes:
    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
    - The final pooling layer is a QKV attention instead of an average pool
    """

    def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
        super().__init__()
        self.output_dim = output_dim
        self.input_resolution = input_resolution

        # the 3-layer stem
        self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width // 2)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(width // 2)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width)
        self.relu3 = nn.ReLU(inplace=True)
        self.avgpool = nn.AvgPool2d(2)

        # residual layers
        self._inplanes = width  # this is a *mutable* variable used during construction
        self.layer1 = self._make_layer(width, layers[0])
        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)

        embed_dim = width * 32  # the ResNet feature dimension
        # Stem + 4 stages give a total spatial downsampling of 32x.
        self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)

    def _make_layer(self, planes, blocks, stride=1):
        """Build one residual stage: a (possibly strided) bottleneck followed
        by `blocks - 1` stride-1 bottlenecks. Mutates self._inplanes."""
        layers = [Bottleneck(self._inplanes, planes, stride)]

        self._inplanes = planes * Bottleneck.expansion
        for _ in range(1, blocks):
            layers.append(Bottleneck(self._inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        def stem(x):
            x = self.relu1(self.bn1(self.conv1(x)))
            x = self.relu2(self.bn2(self.conv2(x)))
            x = self.relu3(self.bn3(self.conv3(x)))
            x = self.avgpool(x)
            return x

        # Match the input dtype to the (possibly fp16) conv weights.
        x = x.type(self.conv1.weight.dtype)
        x = stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.attnpool(x)

        return x
156
+
157
+
158
class LayerNorm(nn.LayerNorm):
    """LayerNorm that always normalizes in fp32, then casts back to the input dtype."""

    def forward(self, x: torch.Tensor):
        # Normalize in full precision for numerical stability under fp16 inference.
        input_dtype = x.dtype
        normalized = super().forward(x.type(torch.float32))
        return normalized.type(input_dtype)
165
+
166
+
167
class QuickGELU(nn.Module):
    """Fast GELU approximation used by CLIP: x * sigmoid(1.702 * x)."""

    def forward(self, x: torch.Tensor):
        return torch.sigmoid(1.702 * x) * x
170
+
171
+
172
class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: multi-head self-attention and a 4x-expansion MLP,
    each wrapped in a residual connection."""

    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        # Keep the cached mask on the same device/dtype as the activations.
        if self.attn_mask is not None:
            self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device)
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        return x + self.mlp(self.ln_2(x))
194
+
195
+
196
class Transformer(nn.Module):
    """A stack of ResidualAttentionBlocks operating on [seq, batch, width] tensors."""

    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.width = width
        self.layers = layers
        blocks = [ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]
        self.resblocks = nn.Sequential(*blocks)

    def forward(self, x: torch.Tensor):
        return self.resblocks(x)
205
+
206
+
207
class VisionTransformer(nn.Module):
    """CLIP's ViT image encoder.

    NOTE: unlike stock CLIP, ``forward`` returns the full token sequence
    (CLS pooling and the output projection are commented out below), so
    downstream modules receive per-patch features of shape
    [batch, grid**2 + 1, width].
    """

    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        # Patchify via a strided conv: one token per patch_size x patch_size tile.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)

        scale = width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        # One positional embedding per patch plus one for the class token.
        self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
        self.ln_pre = LayerNorm(width)

        self.transformer = Transformer(width, layers, heads)

        # ln_post / proj are kept for checkpoint compatibility but unused in forward().
        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

    def forward(self, x: torch.Tensor):
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        # x = self.ln_post(x[:, 0, :])
        #
        # if self.proj is not None:
        #     x = x @ self.proj

        return x
242
+
243
+
244
class CLIP(nn.Module):
    """Contrastive Language-Image Pre-training model: a visual encoder
    (ModifiedResNet or VisionTransformer) and a causal text transformer,
    both projected into a shared embedding space."""

    def __init__(self,
                 embed_dim: int,
                 # vision
                 image_resolution: int,
                 vision_layers: Union[Tuple[int, int, int, int], int],
                 vision_width: int,
                 vision_patch_size: int,
                 # text
                 context_length: int,
                 vocab_size: int,
                 transformer_width: int,
                 transformer_heads: int,
                 transformer_layers: int
                 ):
        super().__init__()

        self.context_length = context_length

        # A tuple/list of per-stage depths selects the ResNet backbone; an int selects the ViT.
        if isinstance(vision_layers, (tuple, list)):
            vision_heads = vision_width * 32 // 64
            self.visual = ModifiedResNet(
                layers=vision_layers,
                output_dim=embed_dim,
                heads=vision_heads,
                input_resolution=image_resolution,
                width=vision_width
            )
        else:
            vision_heads = vision_width // 64
            self.visual = VisionTransformer(
                input_resolution=image_resolution,
                patch_size=vision_patch_size,
                width=vision_width,
                layers=vision_layers,
                heads=vision_heads,
                output_dim=embed_dim
            )

        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask()
        )

        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
        self.ln_final = LayerNorm(transformer_width)

        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
        # Learnable temperature for the contrastive logits, initialized to 1/0.07.
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.initialize_parameters()

    def initialize_parameters(self):
        """Initialize embeddings, attention-pool and transformer weights with the
        scaled-normal scheme from the CLIP paper."""
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)

        if isinstance(self.visual, ModifiedResNet):
            if self.visual.attnpool is not None:
                std = self.visual.attnpool.c_proj.in_features ** -0.5
                nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)

            # Zero-init the last BN gamma in each bottleneck so residual branches start as identity.
            for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
                for name, param in resnet_block.named_parameters():
                    if name.endswith("bn3.weight"):
                        nn.init.zeros_(param)

        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
        attn_std = self.transformer.width ** -0.5
        fc_std = (2 * self.transformer.width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)

    def build_attention_mask(self):
        # lazily create causal attention mask, with full attention between the vision tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower diagonal
        return mask

    @property
    def dtype(self):
        # Use the first visual conv's dtype as the model-wide reference dtype.
        return self.visual.conv1.weight.dtype

    def encode_image(self, image):
        """Return visual features for a batch of preprocessed images."""
        return self.visual(image.type(self.dtype))

    def encode_text(self, text):
        """Return the text embedding for a batch of token-id sequences of length context_length."""
        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]

        x = x + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection

        return x

    def forward(self, image, text):
        """Return the temperature-scaled cosine-similarity logit matrices
        (image-to-text and its transpose)."""
        image_features = self.encode_image(image)
        text_features = self.encode_text(text)

        # normalized features
        image_features = image_features / image_features.norm(dim=1, keepdim=True)
        text_features = text_features / text_features.norm(dim=1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()

        # shape = [global_batch_size, global_batch_size]
        return logits_per_image, logits_per_text
374
+
375
+
376
def convert_weights(model: nn.Module):
    """Convert applicable model parameters to fp16, in place."""

    def _to_half(module):
        # Conv / linear layers: cast the weight and the optional bias.
        if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            module.weight.data = module.weight.data.half()
            if module.bias is not None:
                module.bias.data = module.bias.data.half()

        # MultiheadAttention stores its projections as plain tensor attributes.
        if isinstance(module, nn.MultiheadAttention):
            proj_attrs = ["in_proj_weight", "q_proj_weight", "k_proj_weight", "v_proj_weight",
                          "in_proj_bias", "bias_k", "bias_v"]
            for attr_name in proj_attrs:
                tensor = getattr(module, attr_name)
                if tensor is not None:
                    tensor.data = tensor.data.half()

        # CLIP-specific projection parameters that hang directly off modules.
        for attr_name in ("text_projection", "proj"):
            attr = getattr(module, attr_name, None)
            if attr is not None:
                attr.data = attr.data.half()

    model.apply(_to_half)
398
+
399
+
400
def build_model(state_dict: dict):
    """Instantiate a CLIP model whose architecture is inferred from the shapes of
    *state_dict*, convert it to fp16, load the weights and return it in eval mode."""
    # Only ViT checkpoints carry an explicit output projection for the visual tower.
    vit = "visual.proj" in state_dict

    if vit:
        vision_width = state_dict["visual.conv1.weight"].shape[0]
        vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
        vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
        # positional embedding holds one entry per patch plus the class token
        grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
        image_resolution = vision_patch_size * grid_size
    else:
        # ResNet: count the distinct block indices in each of the four stages.
        counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
        vision_layers = tuple(counts)
        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
        output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
        vision_patch_size = None
        assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
        # the ResNet stem + stages downsample by a total factor of 32
        image_resolution = output_width * 32

    embed_dim = state_dict["text_projection"].shape[1]
    context_length = state_dict["positional_embedding"].shape[0]
    vocab_size = state_dict["token_embedding.weight"].shape[0]
    transformer_width = state_dict["ln_final.weight"].shape[0]
    transformer_heads = transformer_width // 64
    transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))

    model = CLIP(
        embed_dim,
        image_resolution, vision_layers, vision_width, vision_patch_size,
        context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
    )

    # Metadata keys that are not parameters and would break load_state_dict.
    for key in ["input_resolution", "context_length", "vocab_size"]:
        if key in state_dict:
            del state_dict[key]

    convert_weights(model)
    model.load_state_dict(state_dict)
    return model.eval()
Model/clip/simple_tokenizer.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gzip
2
+ import html
3
+ import os
4
+ from functools import lru_cache
5
+
6
+ import ftfy
7
+ import regex as re
8
+
9
+
10
@lru_cache()
def default_bpe():
    """Return the path of the gzip'd BPE vocabulary shipped next to this module."""
    here = os.path.dirname(os.path.abspath(__file__))
    return os.path.join(here, "bpe_simple_vocab_16e6.txt.gz")
13
+
14
+
15
@lru_cache()
def bytes_to_unicode():
    """
    Return a dict mapping every utf-8 byte value (0-255) to a printable unicode character.

    The reversible BPE codes operate on unicode strings, so every byte needs a
    character representation that is neither whitespace nor a control character
    (the BPE code barfs on those). Printable latin-1 bytes map to themselves;
    the remaining bytes are shifted up past 255.
    """
    printable = (list(range(ord("!"), ord("~") + 1))
                 + list(range(ord("¡"), ord("¬") + 1))
                 + list(range(ord("®"), ord("ÿ") + 1)))
    byte_values = list(printable)
    codepoints = list(printable)
    shift = 0
    for b in range(2 ** 8):
        if b not in byte_values:
            byte_values.append(b)
            codepoints.append(2 ** 8 + shift)
            shift += 1
    return {b: chr(cp) for b, cp in zip(byte_values, codepoints)}
36
+
37
+
38
def get_pairs(word):
    """Return the set of adjacent symbol pairs in *word*.

    Word is represented as a tuple of symbols (symbols being variable-length strings).
    """
    left = word[0]  # preserves the original's IndexError on an empty word
    pairs = set()
    for right in word[1:]:
        pairs.add((left, right))
        left = right
    return pairs
48
+
49
+
50
def basic_clean(text):
    """Repair mojibake via ftfy, undo (possibly double) HTML escaping, and trim."""
    fixed = ftfy.fix_text(text)
    # unescape twice: some corpora are HTML-escaped twice over
    fixed = html.unescape(html.unescape(fixed))
    return fixed.strip()
54
+
55
+
56
def whitespace_clean(text):
    """Collapse every whitespace run to a single space and strip the ends."""
    return re.sub(r'\s+', ' ', text).strip()
60
+
61
+
62
class SimpleTokenizer(object):
    """Byte-level BPE tokenizer (CLIP-style).

    Text is lower-cased, cleaned, split with a regex, mapped to printable
    byte-characters, and merged greedily by BPE rank. Word-final symbols
    carry a '</w>' suffix so merges cannot cross word boundaries.
    """

    def __init__(self, bpe_path: str = default_bpe()):
        """
        Args:
            bpe_path: gzip'd merges file; first line is a header, the next
                49152-256-2 lines are space-separated merge pairs.
        """
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
        merges = merges[1:49152-256-2+1]
        merges = [tuple(merge.split()) for merge in merges]
        # Vocab: 256 byte chars, their '</w>' variants, one entry per merge, 2 specials.
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v+'</w>' for v in vocab]
        for merge in merges:
            vocab.append(''.join(merge))
        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        # Lower rank = earlier merge = higher priority.
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
        self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)

    def bpe(self, token):
        """Apply BPE merges to a single token; returns space-separated subwords.

        Results are memoized in self.cache.
        """
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + ( token[-1] + '</w>',)
        pairs = get_pairs(word)

        if not pairs:
            return token+'</w>'

        while True:
            # Greedily merge the best-ranked adjacent pair until none remains.
            bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:
                    # `first` does not occur again: copy the tail and stop.
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        """Tokenize *text* into a list of BPE token ids."""
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        """Map BPE token ids back to a (best-effort) utf-8 string."""
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
        return text
Model/dino/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .dino import DinoV2
Model/dino/dino.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy
2
+ import torch
3
+ from PIL import Image
4
+ from torch import nn
5
+ from torchvision import transforms
6
+
7
# ImageNet-style preprocessing used to probe the encoder's output dimensionality.
preprocess = transforms.Compose([transforms.Resize((224, 224)),
                                 transforms.ToTensor(),
                                 transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                      std=[0.229, 0.224, 0.225])])


class DinoV2(nn.Module):
    """Frozen DINOv2 backbone (loaded from torch.hub) used as a visual encoder.

    NOTE(review): the encoder is unconditionally moved to CUDA and cast to fp16
    at construction time — this class requires a GPU; confirm that is intended
    for every deployment target.
    """

    def __init__(self, model_name):
        """
        Args:
            model_name: torch.hub entry in the 'facebookresearch/dinov2' repo,
                e.g. 'dinov2_vitl14'. Downloads weights on first use.
        """
        super().__init__()
        self.vision_encoder = torch.hub.load('facebookresearch/dinov2', model_name)
        self.vision_encoder = self.vision_encoder.eval().cuda().half()
        return

    def forward(self, x):
        # Return per-patch tokens (class token excluded) from forward_features.
        return self.vision_encoder.forward_features(x)['x_norm_patchtokens']

    def get_output_dim(self):
        """Probe the encoder with a dummy image and return its feature width."""
        with torch.no_grad():
            dummpy_input_image = preprocess(Image.fromarray(numpy.zeros((512, 512, 3), dtype=numpy.uint8))).to(
                next(self.parameters()).device).half()
            encoder_output_size = self.vision_encoder(dummpy_input_image.unsqueeze(0)).shape[-1]
        return encoder_output_size
Model/vit.py ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ * Copyright (c) 2022, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ * Based on timm code base
8
+ * https://github.com/rwightman/pytorch-image-models/tree/master/timm
9
+ '''
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from functools import partial
15
+
16
+ from timm.models.vision_transformer import _cfg, PatchEmbed
17
+ from timm.models.registry import register_model
18
+ from timm.models.layers import trunc_normal_, DropPath
19
+ from timm.models.helpers import named_apply, adapt_input_conv
20
+
21
+ from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
22
+
23
class Mlp(nn.Module):
    """ MLP as used in Vision Transformer, MLP-Mixer and related networks:
    Linear -> activation -> dropout -> Linear -> dropout.
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        # Default both hidden and output widths to the input width.
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
42
+
43
+
44
class Attention(nn.Module):
    """Multi-head self-attention with optional capture of the attention map
    and its gradients (used for attention visualisation / Grad-CAM style tools)."""

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.attn_gradients = None
        self.attention_map = None

    def save_attn_gradients(self, attn_gradients):
        self.attn_gradients = attn_gradients

    def get_attn_gradients(self):
        return self.attn_gradients

    def save_attention_map(self, attention_map):
        self.attention_map = attention_map

    def get_attention_map(self):
        return self.attention_map

    def forward(self, x, register_hook=False):
        batch, tokens, channels = x.shape
        head_dim = channels // self.num_heads
        # Project to q/k/v in one pass, then split into per-head tensors.
        qkv = (self.qkv(x)
               .reshape(batch, tokens, 3, self.num_heads, head_dim)
               .permute(2, 0, 3, 1, 4))
        q, k, v = qkv.unbind(0)  # each [batch, heads, tokens, head_dim]

        scores = (q @ k.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(scores.softmax(dim=-1))

        if register_hook:
            self.save_attention_map(weights)
            weights.register_hook(self.save_attn_gradients)

        out = (weights @ v).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(out))
87
+
88
+
89
class Block(nn.Module):
    """Standard pre-norm ViT block: self-attention then MLP, each with a
    residual connection and optional stochastic depth (DropPath)."""

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)

        if use_grad_checkpointing:
            # Trade compute for memory: re-run attn/mlp during the backward pass.
            self.attn = checkpoint_wrapper(self.attn)
            self.mlp = checkpoint_wrapper(self.mlp)

    def forward(self, x, register_hook=False):
        x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))
        return x + self.drop_path(self.mlp(self.norm2(x)))
111
+
112
+
113
class VisionTransformer(nn.Module):
    """ Vision Transformer
    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`  -
        https://arxiv.org/abs/2010.11929

    NOTE: this variant has no classification head; forward() returns the full,
    normalized token sequence (CLS token first).
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None,
                 use_grad_checkpointing=False, ckpt_layer=0):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            num_classes (int): number of classes for classification head
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            qk_scale (float): override default qk scale of head_dim ** -0.5 if set
            representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
            drop_rate (float): dropout rate
            attn_drop_rate (float): attention dropout rate
            drop_path_rate (float): stochastic depth rate
            norm_layer: (nn.Module): normalization layer
            use_grad_checkpointing (bool): wrap the last `ckpt_layer` blocks with
                activation checkpointing to save memory
            ckpt_layer (int): number of trailing blocks to checkpoint
        """
        super().__init__()
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)

        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)

        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
                use_grad_checkpointing=(use_grad_checkpointing and i>=depth-ckpt_layer)
            )
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)

        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Truncated-normal linear weights; zero biases; identity-style LayerNorm.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        # These parameters are conventionally excluded from weight decay.
        return {'pos_embed', 'cls_token'}

    def forward(self, x, register_blk=-1):
        """Return normalized token embeddings [B, num_patches + 1, embed_dim].

        Args:
            x: image batch [B, C, H, W].
            register_blk: index of the block whose attention map/gradients
                should be captured (-1 disables capture).
        """
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)

        x = x + self.pos_embed[:,:x.size(1),:]
        x = self.pos_drop(x)

        for i,blk in enumerate(self.blocks):
            x = blk(x, register_blk==i)
        x = self.norm(x)

        return x

    @torch.jit.ignore()
    def load_pretrained(self, checkpoint_path, prefix=''):
        """Load official Google Brain Flax .npz weights (see _load_weights)."""
        _load_weights(self, checkpoint_path, prefix)
+
200
+
201
+ @torch.no_grad()
202
+ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
203
+ """ Load weights from .npz checkpoints for official Google Brain Flax implementation
204
+ """
205
+ import numpy as np
206
+
207
+ def _n2p(w, t=True):
208
+ if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
209
+ w = w.flatten()
210
+ if t:
211
+ if w.ndim == 4:
212
+ w = w.transpose([3, 2, 0, 1])
213
+ elif w.ndim == 3:
214
+ w = w.transpose([2, 0, 1])
215
+ elif w.ndim == 2:
216
+ w = w.transpose([1, 0])
217
+ return torch.from_numpy(w)
218
+
219
+ w = np.load(checkpoint_path)
220
+ if not prefix and 'opt/target/embedding/kernel' in w:
221
+ prefix = 'opt/target/'
222
+
223
+ if hasattr(model.patch_embed, 'backbone'):
224
+ # hybrid
225
+ backbone = model.patch_embed.backbone
226
+ stem_only = not hasattr(backbone, 'stem')
227
+ stem = backbone if stem_only else backbone.stem
228
+ stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
229
+ stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
230
+ stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
231
+ if not stem_only:
232
+ for i, stage in enumerate(backbone.stages):
233
+ for j, block in enumerate(stage.blocks):
234
+ bp = f'{prefix}block{i + 1}/unit{j + 1}/'
235
+ for r in range(3):
236
+ getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
237
+ getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
238
+ getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
239
+ if block.downsample is not None:
240
+ block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
241
+ block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
242
+ block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
243
+ embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
244
+ else:
245
+ embed_conv_w = adapt_input_conv(
246
+ model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
247
+ model.patch_embed.proj.weight.copy_(embed_conv_w)
248
+ model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
249
+ model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
250
+ pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
251
+ if pos_embed_w.shape != model.pos_embed.shape:
252
+ pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
253
+ pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
254
+ model.pos_embed.copy_(pos_embed_w)
255
+ model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
256
+ model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
257
+ # if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
258
+ # model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
259
+ # model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
260
+ # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
261
+ # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
262
+ # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
263
+ for i, block in enumerate(model.blocks.children()):
264
+ block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
265
+ mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
266
+ block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
267
+ block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
268
+ block.attn.qkv.weight.copy_(torch.cat([
269
+ _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
270
+ block.attn.qkv.bias.copy_(torch.cat([
271
+ _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
272
+ block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
273
+ block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
274
+ for r in range(2):
275
+ getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
276
+ getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
277
+ block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
278
+ block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
279
+
280
+
281
+ def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder):
282
+ # interpolate position embedding
283
+ embedding_size = pos_embed_checkpoint.shape[-1]
284
+ num_patches = visual_encoder.patch_embed.num_patches
285
+ num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches
286
+ # height (== width) for the checkpoint position embedding
287
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
288
+ # height (== width) for the new position embedding
289
+ new_size = int(num_patches ** 0.5)
290
+
291
+ if orig_size!=new_size:
292
+ # class_token and dist_token are kept unchanged
293
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
294
+ # only the position tokens are interpolated
295
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
296
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
297
+ pos_tokens = torch.nn.functional.interpolate(
298
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
299
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
300
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
301
+ print('reshape position embedding from %d to %d'%(orig_size ** 2,new_size ** 2))
302
+
303
+ return new_pos_embed
304
+ else:
305
+ return pos_embed_checkpoint
README.md CHANGED
@@ -1,14 +1,25 @@
1
  ---
2
- title: TRCaptionNetpp
3
- emoji: 😻
4
- colorFrom: yellow
5
- colorTo: pink
6
  sdk: gradio
7
- sdk_version: 5.47.2
8
  app_file: app.py
9
- pinned: false
10
- license: cc-by-4.0
11
- short_description: TRCaptionNet++
12
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
1
  ---
2
+ title: TRCaptionNet++
3
+ emoji: 🖼
4
+ colorFrom: red
5
+ colorTo: indigo
6
  sdk: gradio
 
7
  app_file: app.py
8
+ pinned: true
 
 
9
  ---
10
+ # Configuration
11
+ `title`: _string_
12
+ TRCaptionNet++
13
+ `emoji`: _string_
14
+ 🖼
15
+ `colorFrom`: _string_
16
+ red
17
+ `colorTo`: _string_
18
+ indigo
19
+ `sdk`: _string_
20
+ gradio
21
+ `app_file`: _string_
22
+ app.py
23
 
24
+ `pinned`: _boolean_
25
+ true
app.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import glob

import gradio as gr
import torch
from torchvision import transforms

from Model import TRCaptionNetpp

# Checkpoint holding the projection + decoder weights under the "model" key.
model_ckpt = "./checkpoints/TRCaptionNetpp_Large.pth"

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ImageNet mean/std normalization at 224x224 — presumably matches the
# encoder's pretraining pipeline; TODO confirm against training code.
preprocess = transforms.Compose([transforms.Resize((224, 224)),
                                 transforms.ToTensor(),
                                 transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                      std=[0.229, 0.224, 0.225])])

model = TRCaptionNetpp({
    "max_length": 35,
    "dino2": "dinov2_vitl14",
    "bert": "dbmdz/electra-base-turkish-mc4-cased-discriminator",
    "proj": True,
    "proj_num_head": 16
})
model.load_state_dict(torch.load(model_ckpt, map_location=device)["model"], strict=True)
model = model.to(device)
model.eval()


def inference(raw_image, min_length, repetition_penalty):
    """Generate a Turkish caption for a single PIL image.

    :param raw_image: PIL image supplied by the Gradio image widget.
    :param min_length: minimum caption length passed to the decoder.
    :param repetition_penalty: decoding repetition penalty.
    :return: the generated caption string.
    """
    batch = preprocess(raw_image).unsqueeze(0).to(device)
    caption = model.generate(batch, min_length=min_length, repetition_penalty=repetition_penalty)[0]
    return caption


# BUG FIX: the repetition-penalty slider's default (2.5) was outside its own
# [1, 2] range; widen the maximum so the intended default is representable.
inputs = [gr.Image(type='pil', interactive=True),
          gr.Slider(minimum=6, maximum=22, value=11, label="MINIMUM CAPTION LENGTH", step=1),
          gr.Slider(minimum=1, maximum=2.5, value=2.5, label="REPETITION PENALTY")]
outputs = gr.components.Textbox(label="Caption")
title = "TRCaptionNet"
paper_link = ""
github_link = "https://github.com/serdaryildiz/TRCaptionNetpp"
# BUG FIX: the anchor/paragraph tags were never closed, producing broken HTML.
description = (f"<p style='text-align: center'><a href='{github_link}' target='_blank'>"
               "TRCaptionNet++: A high-performance encoder-decoder based deep Turkish image "
               "captioning model fine-tuned with a large-scale set of pretrain data</a></p>")
examples = [[p] for p in glob.glob("images/*")]

article = f"<p style='text-align: center'><a href='{paper_link}' target='_blank'>Paper</a> | <a href='{github_link}' target='_blank'>Github Repo</a></p>"
css = ".output-image, .input-image, .image-preview {height: 600px !important}"

iface = gr.Interface(fn=inference,
                     inputs=inputs,
                     outputs=outputs,
                     title=title,
                     description=description,
                     examples=examples,
                     article=article,
                     css=css)
iface.launch()
images/test1.png ADDED

Git LFS Details

  • SHA256: dd096643426a9d750d3932ecb4a0128540f762cc96c7b096859f1f50cab068d1
  • Pointer size: 131 Bytes
  • Size of remote file: 302 kB
images/test2.JPG ADDED

Git LFS Details

  • SHA256: 23d93d36918d63cbd799d58f6f7dc5da65b88e8ee71a8300fb06b4ac3b297045
  • Pointer size: 131 Bytes
  • Size of remote file: 680 kB
images/test3.png ADDED

Git LFS Details

  • SHA256: 1a812a5d92d2d82f0c6507b7190d639e7ce2dc1f9ce3d0d2ff5d1f142af3e53f
  • Pointer size: 131 Bytes
  • Size of remote file: 880 kB
images/test4.png ADDED

Git LFS Details

  • SHA256: c15315704afc9f1bc2c7a752460acef2872d5b1f06e443e9c5166e903fe554ed
  • Pointer size: 131 Bytes
  • Size of remote file: 299 kB
images/test5.png ADDED

Git LFS Details

  • SHA256: 20f6b117bf2f4400aa37cc601a13e62b01955167b36c8834e419a4b1a1c1c1d8
  • Pointer size: 131 Bytes
  • Size of remote file: 865 kB
images/test6.png ADDED

Git LFS Details

  • SHA256: 5c5d7c9b27c46b38ce8ad505efe4309eb41e67bb5271171667ea212f6e65b6fc
  • Pointer size: 131 Bytes
  • Size of remote file: 783 kB
images/test7.png ADDED

Git LFS Details

  • SHA256: 8a8c4b312494d9d70d8801818d8b10331bc0a669b0b67abaee0c724ee298adb2
  • Pointer size: 131 Bytes
  • Size of remote file: 540 kB
images/test8.png ADDED

Git LFS Details

  • SHA256: 417418d79ee195f835c765ec7900f455ce96958246b86ae09567ac9a1520e0a5
  • Pointer size: 131 Bytes
  • Size of remote file: 412 kB
images/test9.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ torch==2.0.0
2
+ torchvision==0.15.1
3
+ opencv-python==4.6.0.66
4
+ transformers==4.27.3
5
+ ftfy==6.1.1
6
+ gradio==3.48.0
7
+ gdown==4.6.0