# CoLMbo / mapper.py
import torch
import torch.nn as nn
from torch.nn import functional as nnf
from typing import Tuple, Optional
def get_sid_mapper(map_type: str, emb_size, prefix_size: int, gpt_embedding_size: int, prefix_length: int, clip_length: int, num_layers: int):
    """Build the SID (speaker) prefix mapper, either an MLP or a transformer, with all parameters trainable."""
    if map_type == 'mlp':
        mapper = MLP(emb_size, (prefix_size, (gpt_embedding_size * prefix_length) // 2, gpt_embedding_size * prefix_length))
    elif map_type == 'transformer':
        mapper = TransformerMapper(emb_size, prefix_size, gpt_embedding_size, prefix_length, clip_length, num_layers // 2)
else:
raise ValueError(f"Unknown mapping type {map_type}")
for p in mapper.parameters():
p.requires_grad = True
return mapper
def get_text_mapper(map_type: str, emb_size, prefix_size: int, gpt_embedding_size: int, prefix_length: int, clip_length: int, num_layers: int):
    """Build the text prefix mapper, either an MLP or a sequence transformer, with all parameters trainable."""
    if map_type == 'mlp':
        mapper = MLP(emb_size, (prefix_size, (gpt_embedding_size * prefix_length) // 2, gpt_embedding_size * prefix_length))
    elif map_type == 'transformer':
        mapper = TransformerMapperSeq(emb_size, prefix_size, gpt_embedding_size, prefix_length, clip_length, num_layers // 2)
else:
raise ValueError(f"Unknown mapping type {map_type}")
for p in mapper.parameters():
p.requires_grad = True
return mapper
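
# Usage sketch (hypothetical hyperparameters, not taken from this repo's training
# config): both factories return a mapper whose parameters are left trainable.
#
#   sid_mapper = get_sid_mapper('transformer', emb_size=None, prefix_size=192,
#                               gpt_embedding_size=768, prefix_length=40,
#                               clip_length=40, num_layers=8)
#   text_mapper = get_text_mapper('mlp', emb_size=None, prefix_size=768,
#                                 gpt_embedding_size=768, prefix_length=40,
#                                 clip_length=40, num_layers=8)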
def init_layer(layer):
    """Initialize a Linear or Convolutional layer."""
    nn.init.xavier_uniform_(layer.weight)
    if hasattr(layer, 'bias') and layer.bias is not None:
        layer.bias.data.fill_(0.)
def init_bn(bn):
"""Initialize a Batchnorm layer. """
bn.bias.data.fill_(0.)
bn.weight.data.fill_(1.)
class Projection(nn.Module):
    def __init__(self, d_in: int, d_out: int, p: float = 0.5) -> None:
super().__init__()
self.linear1 = nn.Linear(d_in, d_out, bias=False)
self.linear2 = nn.Linear(d_out, d_out, bias=False)
self.layer_norm = nn.LayerNorm(d_out)
self.drop = nn.Dropout(p)
self.init_weight()
def init_weight(self):
init_layer(self.linear1)
init_layer(self.linear2)
init_bn(self.layer_norm)
def forward(self, x: torch.Tensor) -> torch.Tensor:
embed1 = self.linear1(x)
embed2 = self.drop(self.linear2(nnf.gelu(embed1)))
embeds = self.layer_norm(embed1 + embed2)
return embeds
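
# Illustrative use (assumed sizes): the first linear output is kept as a residual and a
# GELU + dropout branch is added before LayerNorm.
#
#   proj = Projection(d_in=192, d_out=1024)
#   y = proj(torch.randn(8, 192))   # y.shape == torch.Size([8, 1024])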
class MLP(nn.Module):
def __init__(self, emb_size, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
super(MLP, self).__init__()
self.emb_size = emb_size
# if self.emb_size is not None:
# self.projector = Projection(emb_size, sizes[0])
layers = []
for i in range(len(sizes) - 1):
layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias))
if i < len(sizes) - 2:
layers.append(act())
self.model = nn.Sequential(*layers)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# if self.emb_size is not None:
# x = self.projector(x)
return self.model(x)
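
# For reference, the sizes tuple built by the factories above is
#   (prefix_size, (gpt_embedding_size * prefix_length) // 2, gpt_embedding_size * prefix_length)
# so with assumed values prefix_size=1024, gpt_embedding_size=768, prefix_length=40 it becomes
# Linear(1024, 15360) -> Tanh -> Linear(15360, 30720); the 30720-dim output would then be
# reshaped by the caller into a (prefix_length, gpt_embedding_size) prefix.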
class MlpTransformer(nn.Module):
def __init__(self, in_dim, h_dim, out_d: Optional[int] = None, act=nnf.relu, dropout=0.):
super().__init__()
out_d = out_d if out_d is not None else in_dim
self.fc1 = nn.Linear(in_dim, h_dim)
self.act = act
self.fc2 = nn.Linear(h_dim, out_d)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.dropout(x)
x = self.fc2(x)
x = self.dropout(x)
return x
class MultiHeadAttention(nn.Module):
def __init__(self, dim_self, dim_ref, num_heads, bias=True, dropout=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim_self // num_heads
self.scale = head_dim ** -0.5
self.to_queries = nn.Linear(dim_self, dim_self, bias=bias)
self.to_keys_values = nn.Linear(dim_ref, dim_self * 2, bias=bias)
self.project = nn.Linear(dim_self, dim_self)
self.dropout = nn.Dropout(dropout)
def forward(self, x, y=None, mask=None):
y = y if y is not None else x
b, n, c = x.shape
_, m, d = y.shape
# b n h dh
queries = self.to_queries(x).reshape(b, n, self.num_heads, c // self.num_heads)
# b m 2 h dh
keys_values = self.to_keys_values(y).reshape(b, m, 2, self.num_heads, c // self.num_heads)
keys, values = keys_values[:, :, 0], keys_values[:, :, 1]
attention = torch.einsum('bnhd,bmhd->bnmh', queries, keys) * self.scale
if mask is not None:
if mask.dim() == 2:
mask = mask.unsqueeze(1)
attention = attention.masked_fill(mask.unsqueeze(3), float("-inf"))
attention = attention.softmax(dim=2)
out = torch.einsum('bnmh,bmhd->bnhd', attention, values).reshape(b, n, c)
out = self.project(out)
return out, attention
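
# Shape walk-through with illustrative numbers (dim_self=768, num_heads=8, so head_dim=96):
#   queries                                           : (b, n, 8, 96)
#   keys, values                                      : (b, m, 8, 96) each
#   attention (post-softmax over dim=2, the key axis) : (b, n, m, 8)
#   out (after reshape and projection)                : (b, n, 768)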
class TransformerLayer(nn.Module):
    def __init__(self, dim_self, dim_ref, num_heads, mlp_ratio=4., bias=False, dropout=0., act=nnf.relu,
                 norm_layer: nn.Module = nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim_self)
        self.attn = MultiHeadAttention(dim_self, dim_ref, num_heads, bias=bias, dropout=dropout)
        self.norm2 = norm_layer(dim_self)
        self.mlp = MlpTransformer(dim_self, int(dim_self * mlp_ratio), act=act, dropout=dropout)

    def forward_with_attention(self, x, y=None, mask=None):
        x_, attention = self.attn(self.norm1(x), y, mask)
        x = x + x_
        x = x + self.mlp(self.norm2(x))
        return x, attention

    def forward(self, x, y=None, mask=None):
        x = x + self.attn(self.norm1(x), y, mask)[0]
        x = x + self.mlp(self.norm2(x))
        return x
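
# Note: the layer is pre-norm; each sub-block normalizes its input and adds the result
# back onto the residual stream:
#   x = x + attn(norm1(x), y)   # self-attention when y is None, cross-attention otherwise
#   x = x + mlp(norm2(x))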
class Transformer(nn.Module):
def __init__(self, dim_self: int, num_heads: int, num_layers: int, dim_ref: Optional[int] = None,
mlp_ratio: float = 2., act=nnf.relu, norm_layer: nn.Module = nn.LayerNorm, enc_dec: bool = False):
super(Transformer, self).__init__()
dim_ref = dim_ref if dim_ref is not None else dim_self
self.enc_dec = enc_dec
if enc_dec:
num_layers = num_layers * 2
layers = []
for i in range(num_layers):
if i % 2 == 0 and enc_dec: # cross
layers.append(TransformerLayer(dim_self, dim_ref, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
elif enc_dec: # self
layers.append(TransformerLayer(dim_self, dim_self, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
else: # self or cross
layers.append(TransformerLayer(dim_self, dim_ref, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
self.layers = nn.ModuleList(layers)
def forward_with_attention(self, x, y=None, mask=None):
attentions = []
for layer in self.layers:
x, att = layer.forward_with_attention(x, y, mask)
attentions.append(att)
return x, attentions
def forward(self, x, y=None, mask=None):
for i, layer in enumerate(self.layers):
if i % 2 == 0 and self.enc_dec: # cross
x = layer(x, y)
elif self.enc_dec: # self
x = layer(x, x, mask)
else: # self or cross
x = layer(x, y, mask)
return x
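
# When enc_dec=True the layer count is doubled and even-indexed layers cross-attend to y
# while odd-indexed layers self-attend; otherwise every layer attends to y if provided,
# else to x itself. A minimal self-attention-only example (assumed sizes):
#   t = Transformer(dim_self=768, num_heads=8, num_layers=4)
#   h = t(torch.randn(2, 80, 768))   # h.shape == torch.Size([2, 80, 768])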
class TransformerMapper(nn.Module):
def __init__(self, emb_size, dim_clip: int, dim_embedding: int, prefix_length: int, clip_length: int, num_layers: int = 8):
super(TransformerMapper, self).__init__()
        self.emb_size = emb_size
        # forward() routes x through self.projector whenever emb_size is given,
        # so the projector must actually be created here rather than left commented out
        if self.emb_size is not None:
            self.projector = Projection(emb_size, dim_clip)
self.clip_length = clip_length
self.transformer = Transformer(dim_embedding, 8, num_layers)
self.linear = nn.Linear(dim_clip, clip_length * dim_embedding)
self.prefix_const = nn.Parameter(torch.randn(prefix_length, dim_embedding), requires_grad=True)
def forward(self, x):
if self.emb_size is not None:
x = self.projector(x)
        # x: (batch, dim_clip), e.g. torch.Size([100, 1024])
        x = self.linear(x).view(x.shape[0], self.clip_length, -1)
        # x: (batch, clip_length, dim_embedding), e.g. torch.Size([100, 40, 768])
        prefix = self.prefix_const.unsqueeze(0).expand(x.shape[0], *self.prefix_const.shape)
        prefix = torch.cat((x, prefix), dim=1)  # batch x (clip_length + prefix_length) x dim = b x 40+40 x 768 (clip length is 40)
        out = self.transformer(prefix)[:, self.clip_length:]
        # out: (batch, prefix_length, dim_embedding), e.g. torch.Size([100, 40, 768]) -- the sid prefix
return out
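
# How the mapping works: the input embedding is expanded into clip_length pseudo-tokens,
# concatenated with a learned constant of prefix_length tokens, passed through the
# transformer, and only the constant's positions are returned as the prefix. Example with
# assumed sizes:
#   mapper = TransformerMapper(emb_size=None, dim_clip=1024, dim_embedding=768,
#                              prefix_length=40, clip_length=40, num_layers=4)
#   prefix = mapper(torch.randn(2, 1024))   # prefix.shape == torch.Size([2, 40, 768])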
class TransformerMapperSeq(nn.Module):
    def __init__(self, emb_size, dim_clip: int, dim_embedding: int, prefix_length: int, clip_length: int, num_layers: int = 8):
super(TransformerMapperSeq, self).__init__()
self.emb_size = emb_size
# if self.emb_size is not None:
# self.projector = Projection(emb_size, dim_clip)
self.clip_length = clip_length
self.transformer = Transformer(dim_embedding, 8, num_layers)
self.prefix_const = nn.Parameter(torch.randn(prefix_length, dim_embedding), requires_grad=True)
def forward(self, x):
# if self.emb_size is not None:
# x = self.projector(x)
        # debug shape recorded here: x e.g. torch.Size([32, 80, 768]) before the view below
        x = x.view(x.shape[0], self.clip_length, -1)
        prefix = self.prefix_const.unsqueeze(0).expand(x.shape[0], *self.prefix_const.shape)
        # debug shapes: prefix e.g. torch.Size([32, 40, 768]), x after the view e.g. torch.Size([32, 40, 1536])
        prefix = torch.cat((x, prefix), dim=1)
        out = self.transformer(prefix)[:, self.clip_length:]
        # out: (batch, prefix_length, dim_embedding), e.g. torch.Size([100, 80, 768]) -- the text prefix
return out
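
if __name__ == "__main__":
    # Minimal smoke test with made-up dimensions (not this repo's training configuration);
    # it only checks that both mappers run and emit prefixes of the expected shape.
    sid_mapper = get_sid_mapper(
        map_type='transformer', emb_size=None, prefix_size=192,
        gpt_embedding_size=768, prefix_length=40, clip_length=40, num_layers=8,
    )
    sid_prefix = sid_mapper(torch.randn(4, 192))
    print('sid prefix:', sid_prefix.shape)    # expected: torch.Size([4, 40, 768])

    text_mapper = get_text_mapper(
        map_type='transformer', emb_size=None, prefix_size=768,
        gpt_embedding_size=768, prefix_length=40, clip_length=40, num_layers=8,
    )
    text_prefix = text_mapper(torch.randn(4, 40, 768))
    print('text prefix:', text_prefix.shape)  # expected: torch.Size([4, 40, 768])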