# Source: Forge-EMB-mmclip / utils / model_utils.py
# Uploaded by AL-GR ("Upload model", commit 2666e68, verified).
from torch import nn
import torch
from collections import OrderedDict
from transformers.models.bert.tokenization_bert import BertTokenizer
class text_process(object):
    """Callable that converts a raw string into a 1-D tensor of BERT token ids.

    The text is first normalized with ``_preprocess_text`` and then encoded by
    a ``BertTokenizer`` whose ``model_max_length`` is set to ``context_length``.
    Encoding is requested with ``truncation=True`` and ``padding='max_length'``.

    Args:
        context_length: Passed to the tokenizer as ``model_max_length``.
        mlm_probability: Stored and shown in ``repr`` only; nothing in this
            class reads it otherwise.
        bert_path: Directory or model id handed to
            ``BertTokenizer.from_pretrained``. Defaults to ``'./bert'``, the
            location that used to be hard-coded.
    """

    def __init__(self, context_length=80, mlm_probability=0.15, bert_path='./bert'):
        self.context_length = context_length
        self.mlm_probability = mlm_probability
        self.bert_path = bert_path
        # The vocabulary at `bert_path` is expected to be a Chinese BERT vocab.
        self.tokenizer = BertTokenizer.from_pretrained(bert_path, model_max_length=context_length)

    def __call__(self, text):
        """Tokenize ``text`` and return the first row of ``input_ids``.

        Only the token ids are returned; the attention mask produced by the
        tokenizer is discarded.
        NOTE(review): presumably ``text`` is a single string; only row 0 of
        the encoded batch is returned in any case.
        """
        encoded = self.tokenizer(
            _preprocess_text(text),
            return_tensors="pt",
            truncation=True,
            padding='max_length',
        )
        return encoded['input_ids'][0]

    def __repr__(self):
        # The label text (including "content_length") is kept exactly as it
        # was, in case anything compares or parses this representation.
        return (
            "(DataAugmentationForBERT,\n"
            f" content_length = {self.context_length},\n"
            f" mlm_probability = {self.mlm_probability},\n"
            ")"
        )
class LayerNorm(nn.LayerNorm):
    """LayerNorm that computes in float32 and returns the caller's dtype.

    The input is cast to float32 before the parent ``forward`` runs, and the
    result is cast back, so half-precision inputs come back in half precision.
    """

    def forward(self, x: torch.Tensor):
        input_dtype = x.dtype
        normalized = super().forward(x.to(torch.float32))
        return normalized.to(input_dtype)
class QuickGELU(nn.Module):
    """Element-wise activation ``x * sigmoid(1.702 * x)``."""

    def forward(self, x: torch.Tensor):
        gate = torch.sigmoid(1.702 * x)
        return x * gate
class ResidualAttentionBlock(nn.Module):
    """Transformer block with LayerNorm applied before each sub-layer.

    Computes ``x + attn(ln_1(x))`` followed by ``x + mlp(ln_2(x))``, where the
    MLP is Linear(d_model -> 4*d_model), QuickGELU, Linear(4*d_model -> d_model).

    Args:
        d_model: Embedding width of the tokens.
        n_head: Number of attention heads.
        attn_mask: Optional attention mask forwarded to
            ``nn.MultiheadAttention``. Either a boolean mask or an additive
            float mask; ``None`` disables masking.
    """

    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()
        # Built with the module defaults (batch_first is not set), so inputs
        # are expected in the sequence-first layout.
        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model)),
        ]))
        self.ln_2 = LayerNorm(d_model)
        # Kept as a plain attribute (not a registered buffer), so it does not
        # follow ``module.to(...)``; ``attention`` moves it on demand.
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        """Run self-attention on ``x`` with the stored mask, if any.

        The mask is moved to ``x``'s device and cached back on the module so
        later calls do not repeat the transfer. Non-boolean masks are also
        cast to ``x.dtype``. Boolean masks keep their dtype: casting them to a
        float type would turn "masked" positions into an additive ``+1.0``
        bias instead of excluding them from attention.
        """
        mask = self.attn_mask
        if mask is not None:
            if mask.dtype == torch.bool:
                mask = mask.to(device=x.device)
            else:
                mask = mask.to(dtype=x.dtype, device=x.device)
            self.attn_mask = mask
        # need_weights=False: only the attended output (index 0) is used.
        return self.attn(x, x, x, need_weights=False, attn_mask=mask)[0]

    def forward(self, x: torch.Tensor):
        """Apply the attention and MLP sub-layers, each with a residual add."""
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x
class Transformer(nn.Module):
    """Stack of ``layers`` residual attention blocks run with checkpointing.

    Args:
        width: Embedding width passed to every block.
        layers: Number of ``ResidualAttentionBlock`` instances in the stack.
        heads: Number of attention heads per block.
        attn_mask: Optional attention mask; the same tensor object is handed
            to every block.
    """

    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])

    def forward(self, x: torch.Tensor):
        """Pass ``x`` through every block in order and return the result.

        Each block runs under activation checkpointing
        (``use_reentrant=False``), which recomputes the block's forward pass
        during backward instead of storing its intermediate activations.
        """
        # Import the submodule explicitly: plain ``import torch`` does not
        # guarantee that ``torch.utils.checkpoint`` has been loaded, and
        # attribute access on an unloaded submodule raises AttributeError.
        from torch.utils.checkpoint import checkpoint

        for resblock in self.resblocks:
            x = checkpoint(resblock, x, use_reentrant=False)
        return x
class VisualTransformer(nn.Module):
    """Vision transformer that embeds an image as a class token plus patch tokens.

    The image is cut into non-overlapping ``patch_size`` x ``patch_size``
    patches by a strided convolution. A learned class embedding is prepended,
    learned positional embeddings are added, and the sequence is run through
    ``ln_pre``, the transformer stack, ``ln_post`` and a linear projection.

    Args:
        input_resolution: Side length of the (square) input image in pixels.
        patch_size: Side length of each patch; also the convolution stride.
        width: Token embedding width inside the transformer.
        layers: Number of transformer blocks.
        heads: Number of attention heads per block.
        output_dim: Width of the projected output features.
    """

    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim

        # Patch embedding: kernel == stride, so each output position sees
        # exactly one patch.
        self.conv1 = nn.Conv2d(
            in_channels=3,
            out_channels=width,
            kernel_size=patch_size,
            stride=patch_size,
            bias=False,
        )

        init_scale = width ** -0.5
        patches_per_side = input_resolution // patch_size
        sequence_length = patches_per_side ** 2 + 1  # +1 for the class token

        self.class_embedding = nn.Parameter(init_scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(init_scale * torch.randn(sequence_length, width))
        self.ln_pre = LayerNorm(width)

        self.transformer = Transformer(width, layers, heads)

        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(init_scale * torch.randn(width, output_dim))

    def forward(self, x: torch.Tensor):
        """Encode a batch of images.

        Returns:
            A tuple ``(cls_features, all_features)``: ``all_features`` holds
            the projected output for every token (class token first), and
            ``cls_features`` is its slice at token index 0.
        """
        patches = self.conv1(x)  # shape = [*, width, grid, grid]
        batch, channels = patches.shape[0], patches.shape[1]
        tokens = patches.reshape(batch, channels, -1)  # shape = [*, width, grid ** 2]
        tokens = tokens.permute(0, 2, 1)  # shape = [*, grid ** 2, width]

        # Broadcast the single class embedding to one token per batch item.
        cls_slot = torch.zeros(batch, 1, tokens.shape[-1], dtype=tokens.dtype, device=tokens.device)
        cls_tokens = self.class_embedding.to(tokens.dtype) + cls_slot
        tokens = torch.cat([cls_tokens, tokens], dim=1)  # shape = [*, grid ** 2 + 1, width]

        tokens = tokens + self.positional_embedding.to(tokens.dtype)
        # Original note: after this LayerNorm each token has norm ~ sqrt(d).
        tokens = self.ln_pre(tokens)

        tokens = tokens.permute(1, 0, 2)  # NLD -> LND
        tokens = self.transformer(tokens)
        tokens = tokens.permute(1, 0, 2)  # LND -> NLD

        # ln_post is applied to every token, not only the class token.
        tokens = self.ln_post(tokens)

        if self.proj is not None:
            tokens = tokens @ self.proj

        return tokens[:, 0, :], tokens
def _preprocess_text(text):
# adapt the text to Chinese BERT vocab
text = text.lower().replace("“", "\"").replace("”", "\"")
return text