| | from torch import nn |
| | import torch |
| | from collections import OrderedDict |
| | from transformers.models.bert.tokenization_bert import BertTokenizer |
| |
|
| |
|
class text_process(object):
    """Tokenize raw text into fixed-length BERT input ids.

    Loads a BERT tokenizer from the local ``./bert`` directory and
    pads/truncates every input to ``context_length`` tokens.
    """

    def __init__(self, context_length=80, mlm_probability=0.15):
        # Maximum token sequence length fed to the tokenizer.
        self.context_length = context_length
        # Kept for MLM-style training pipelines; not used by __call__ itself.
        self.mlm_probability = mlm_probability

        bert_path = './bert'
        self.tokenizer = BertTokenizer.from_pretrained(bert_path, model_max_length=context_length)

    def __call__(self, text):
        # Normalize (lowercase, ASCII double quotes) before tokenizing.
        encoded = self.tokenizer(_preprocess_text(text), return_tensors="pt",
                                 truncation=True, padding='max_length')
        # NOTE(review): the attention mask produced by the tokenizer is
        # intentionally discarded here; callers receive only the
        # (context_length,) id tensor. Confirm downstream code rebuilds the
        # mask (e.g. from pad-token positions) if it needs one.
        return encoded['input_ids'][0]

    def __repr__(self):
        # Fixed: the first field was previously mislabeled "content_length"
        # although it reports self.context_length.
        return (
            "(DataAugmentationForBERT,\n"
            f" context_length = {self.context_length},\n"
            f" mlm_probability = {self.mlm_probability},\n"
            ")"
        )
| |
|
| |
|
class LayerNorm(nn.LayerNorm):
    """LayerNorm that normalizes in fp32 and casts back, so fp16 inputs are safe."""

    def forward(self, x: torch.Tensor):
        input_dtype = x.dtype
        # Upcast: fp16 LayerNorm is numerically fragile (small variance underflows).
        normalized = super().forward(x.to(torch.float32))
        return normalized.to(input_dtype)
| |
|
class QuickGELU(nn.Module):
    """Sigmoid approximation of GELU: ``x * sigmoid(1.702 * x)``."""

    def forward(self, x: torch.Tensor):
        gate = torch.sigmoid(1.702 * x)
        return x * gate
| |
|
class ResidualAttentionBlock(nn.Module):
    """Pre-LN transformer block: multi-head self-attention and a 4x-width
    MLP, each applied with a residual connection.

    Expects inputs shaped (seq_len, batch, d_model) — the default
    ``nn.MultiheadAttention`` layout (``batch_first=False``).
    """

    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        # Optional additive attention mask (e.g. causal); moved to the
        # input's dtype/device lazily at call time.
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        # Fixed: convert the mask into a LOCAL rather than reassigning
        # self.attn_mask on every call — mutating module state inside the
        # forward pass pins the mask to the last-seen dtype/device and
        # interferes with tracing/compilation.
        mask = self.attn_mask
        if mask is not None:
            mask = mask.to(dtype=x.dtype, device=x.device)
        return self.attn(x, x, x, need_weights=False, attn_mask=mask)[0]

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x
| |
|
class Transformer(nn.Module):
    """A stack of ``layers`` ResidualAttentionBlocks of width ``width``.

    Uses activation checkpointing during training to trade compute for
    memory.
    """

    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])

    def forward(self, x: torch.Tensor):
        # Checkpointing only pays off while gradients are tracked; under
        # torch.no_grad()/inference it adds pure overhead, so skip it there.
        # Outputs are identical either way.
        if torch.is_grad_enabled():
            for resblock in self.resblocks:
                x = torch.utils.checkpoint.checkpoint(resblock, x, use_reentrant=False)
            return x
        return self.resblocks(x)
| |
|
| |
|
class VisualTransformer(nn.Module):
    """ViT image encoder: patchify with a conv, prepend a class token, add
    positional embeddings, run the transformer, and project the output.

    Returns a tuple ``(class_token_features, all_token_features)``.
    """

    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        # Non-overlapping patch embedding: stride == kernel == patch_size.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)

        scale = width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        # One position per patch plus one for the class token.
        self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
        self.ln_pre = LayerNorm(width)

        self.transformer = Transformer(width, layers, heads)

        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

    def forward(self, x: torch.Tensor):
        patches = self.conv1(x)                       # (B, width, grid, grid)
        batch = patches.shape[0]
        # Flatten the spatial grid into a token sequence: (B, grid**2, width).
        patches = patches.reshape(batch, patches.shape[1], -1).permute(0, 2, 1)
        # Broadcast the learned class embedding to one token per sample.
        cls_token = self.class_embedding.to(patches.dtype) + torch.zeros(
            batch, 1, patches.shape[-1], dtype=patches.dtype, device=patches.device)
        tokens = torch.cat([cls_token, patches], dim=1)
        tokens = tokens + self.positional_embedding.to(tokens.dtype)
        tokens = self.ln_pre(tokens)

        # nn.MultiheadAttention expects (seq, batch, dim); swap back after.
        tokens = tokens.permute(1, 0, 2)
        tokens = self.transformer(tokens)
        tokens = tokens.permute(1, 0, 2)

        tokens = self.ln_post(tokens)

        if self.proj is not None:
            tokens = tokens @ self.proj

        return tokens[:, 0, :], tokens
| |
|
| | def _preprocess_text(text): |
| | |
| | text = text.lower().replace("“", "\"").replace("”", "\"") |
| | return text |