# CoLMbo / mapper.py
# Author: massabaali
# Commit f55a095 (verified): "Upload CoLMbo model weights and code"
import torch
import torch.nn as nn
from torch.nn import functional as nnf
from typing import Tuple, Optional
def get_sid_mapper(map_type: str, emb_size, prefix_size: int, gpt_embedding_size: int, prefix_length: int, clip_length: int, num_layers: int):
    """Build the speaker-embedding mapper selected by ``map_type``.

    Args:
        map_type: ``'mlp'`` or ``'transformer'``.
        emb_size: forwarded unchanged to the mapper constructor.
        prefix_size: input feature size of the mapper.
        gpt_embedding_size: width of each produced prefix embedding.
        prefix_length: number of prefix embeddings produced.
        clip_length: number of intermediate tokens (transformer mapper only).
        num_layers: layer budget; the transformer mapper receives half of it.

    Returns:
        The mapper module, with every parameter marked trainable.

    Raises:
        ValueError: if ``map_type`` is not recognised.
    """
    if map_type == 'transformer':
        half_layers = int(num_layers / 2)
        mapper = TransformerMapper(emb_size, prefix_size, gpt_embedding_size, prefix_length, clip_length, half_layers)
    elif map_type == 'mlp':
        flat_size = gpt_embedding_size * prefix_length
        mapper = MLP(emb_size, (prefix_size, flat_size // 2, flat_size))
    else:
        raise ValueError(f"Unknown mapping type {map_type}")
    # Make every parameter trainable.
    mapper.requires_grad_(True)
    return mapper
def get_text_mapper(map_type: str, emb_size, prefix_size: int, gpt_embedding_size: int, prefix_length: int, clip_length: int, num_layers: int):
    """Build the text mapper selected by ``map_type``.

    Args:
        map_type: ``'mlp'`` or ``'transformer'`` (the latter builds a
            ``TransformerMapperSeq``, which takes a sequence input).
        emb_size: forwarded unchanged to the mapper constructor.
        prefix_size: input feature size for the MLP mapper; passed through to
            the transformer mapper.
        gpt_embedding_size: width of each produced prefix embedding.
        prefix_length: number of prefix embeddings produced.
        clip_length: number of input tokens (transformer mapper only).
        num_layers: layer budget; the transformer mapper receives half of it.

    Returns:
        The mapper module, with every parameter marked trainable.

    Raises:
        ValueError: if ``map_type`` is not recognised.
    """
    if map_type == 'transformer':
        half_layers = int(num_layers / 2)
        mapper = TransformerMapperSeq(emb_size, prefix_size, gpt_embedding_size, prefix_length, clip_length, half_layers)
    elif map_type == 'mlp':
        flat_size = gpt_embedding_size * prefix_length
        mapper = MLP(emb_size, (prefix_size, flat_size // 2, flat_size))
    else:
        raise ValueError(f"Unknown mapping type {map_type}")
    # Make every parameter trainable.
    mapper.requires_grad_(True)
    return mapper
def init_layer(layer):
    """Xavier-uniform initialise ``layer.weight`` and zero ``layer.bias`` if present.

    Works for any object exposing a ``weight`` tensor (e.g. Linear or Conv
    layers). A missing or ``None`` bias is left alone.
    """
    nn.init.xavier_uniform_(layer.weight)
    bias = getattr(layer, 'bias', None)
    if bias is not None:
        bias.data.zero_()
def init_bn(bn):
    """Reset a normalization layer's affine parameters to the identity (bias 0, weight 1)."""
    for tensor, value in ((bn.bias, 0.), (bn.weight, 1.)):
        tensor.data.fill_(value)
class Projection(nn.Module):
    """Residual two-layer projection from ``d_in`` to ``d_out`` features.

    ``linear1`` produces a base embedding. A GELU -> ``linear2`` -> dropout branch
    is added to it, and the sum is passed through LayerNorm. Both linear
    layers are bias-free.
    """

    def __init__(self, d_in: int, d_out: int, p: float = 0.5) -> None:
        super().__init__()
        self.linear1 = nn.Linear(d_in, d_out, bias=False)
        self.linear2 = nn.Linear(d_out, d_out, bias=False)
        self.layer_norm = nn.LayerNorm(d_out)
        self.drop = nn.Dropout(p)
        self.init_weight()

    def init_weight(self):
        """Xavier-initialise both linear layers and reset the LayerNorm affine parameters."""
        for linear in (self.linear1, self.linear2):
            init_layer(linear)
        init_bn(self.layer_norm)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` (last dimension ``d_in``) to ``d_out`` features."""
        base = self.linear1(x)
        residual = self.drop(self.linear2(nnf.gelu(base)))
        return self.layer_norm(base + residual)
class MLP(nn.Module):
    """Stack of Linear layers with an activation between consecutive layers.

    ``sizes`` lists the feature widths ``(in, hidden..., out)``. The activation
    is inserted after every Linear except the last. ``emb_size`` is stored
    but unused (the optional input projector is disabled).
    """

    def __init__(self, emb_size, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
        super().__init__()
        self.emb_size = emb_size
        last_index = len(sizes) - 2
        modules = []
        for index, (width_in, width_out) in enumerate(zip(sizes[:-1], sizes[1:])):
            modules.append(nn.Linear(width_in, width_out, bias=bias))
            if index != last_index:
                modules.append(act())
        self.model = nn.Sequential(*modules)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the layer stack to ``x`` (last dimension ``sizes[0]``)."""
        return self.model(x)
class MlpTransformer(nn.Module):
    """Position-wise feed-forward block: fc1 -> act -> dropout -> fc2 -> dropout.

    ``out_d`` defaults to ``in_dim``. The same dropout module is applied after
    the activation and after the output projection.
    """

    def __init__(self, in_dim, h_dim, out_d: Optional[int] = None, act=nnf.relu, dropout=0.):
        super().__init__()
        if out_d is None:
            out_d = in_dim
        self.fc1 = nn.Linear(in_dim, h_dim)
        self.act = act
        self.fc2 = nn.Linear(h_dim, out_d)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = self.dropout(self.act(self.fc1(x)))
        return self.dropout(self.fc2(hidden))
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention from ``x`` (queries) to ``y`` (keys/values).

    Args:
        dim_self: feature size of ``x`` and of the output; split across heads.
        dim_ref: feature size of ``y``.
        num_heads: number of attention heads.
        bias: whether the query and key/value projections use a bias.
        dropout: dropout probability applied to the attention weights before
            they mix the values.
    """

    def __init__(self, dim_self, dim_ref, num_heads, bias=True, dropout=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim_self // num_heads
        self.scale = head_dim ** -0.5
        self.to_queries = nn.Linear(dim_self, dim_self, bias=bias)
        # A single projection yields both keys and values; forward splits them on a size-2 axis.
        self.to_keys_values = nn.Linear(dim_ref, dim_self * 2, bias=bias)
        self.project = nn.Linear(dim_self, dim_self)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, y=None, mask=None):
        """Attend from ``x`` to ``y`` (self-attention when ``y`` is None).

        Args:
            x: tensor of shape (batch, n, dim_self).
            y: optional tensor of shape (batch, m, dim_ref); defaults to ``x``.
            mask: optional boolean tensor. Positions that are True are excluded.
                A 2-D mask of shape (batch, m) applies to every query. A 3-D
                mask of shape (batch, n, m) is applied per query.

        Returns:
            ``(out, attention)``. ``out`` has shape (batch, n, dim_self).
            ``attention`` holds the softmax weights over ``m``, with shape
            (batch, n, m, num_heads), taken before dropout.
        """
        y = y if y is not None else x
        b, n, c = x.shape
        _, m, _ = y.shape
        head_dim = c // self.num_heads
        queries = self.to_queries(x).reshape(b, n, self.num_heads, head_dim)  # b n h dh
        keys_values = self.to_keys_values(y).reshape(b, m, 2, self.num_heads, head_dim)  # b m 2 h dh
        keys, values = keys_values[:, :, 0], keys_values[:, :, 1]
        attention = torch.einsum('bnhd,bmhd->bnmh', queries, keys) * self.scale
        if mask is not None:
            if mask.dim() == 2:
                mask = mask.unsqueeze(1)  # (b, m) -> (b, 1, m): same mask for every query
            attention = attention.masked_fill(mask.unsqueeze(3), float("-inf"))
        attention = attention.softmax(dim=2)
        # Dropout is applied only to the weights used for mixing, so the `dropout`
        # argument takes effect. The returned map stays un-dropped for inspection.
        out = torch.einsum('bnmh,bmhd->bnhd', self.dropout(attention), values).reshape(b, n, c)
        out = self.project(out)
        return out, attention
class TransformerLayer(nn.Module):
    """Pre-norm transformer block: attention then MLP, each added back residually.

    The attention queries come from ``x``. Keys/values come from ``y``, or from
    ``x`` when ``y`` is None.
    """

    def __init__(self, dim_self, dim_ref, num_heads, mlp_ratio=4., bias=False, dropout=0., act=nnf.relu,
                 norm_layer: nn.Module = nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim_self)
        self.attn = MultiHeadAttention(dim_self, dim_ref, num_heads, bias=bias, dropout=dropout)
        self.norm2 = norm_layer(dim_self)
        hidden_dim = int(dim_self * mlp_ratio)
        self.mlp = MlpTransformer(dim_self, hidden_dim, act=act, dropout=dropout)

    def forward_with_attention(self, x, y=None, mask=None):
        """Run the block and also return the attention weights from ``self.attn``."""
        attended, weights = self.attn(self.norm1(x), y, mask)
        x = x + attended
        return x + self.mlp(self.norm2(x)), weights

    def forward(self, x, y=None, mask=None):
        """Run the block, discarding the attention weights."""
        return self.forward_with_attention(x, y, mask)[0]
class Transformer(nn.Module):
    """Stack of ``TransformerLayer`` blocks.

    With ``enc_dec=False``, every layer attends from ``x`` to ``y`` (or to ``x``
    itself when ``y`` is None), using ``mask``.

    With ``enc_dec=True``, the stack holds ``2 * num_layers`` alternating layers.
    Even-indexed layers cross-attend to ``y`` without a mask. Odd-indexed
    layers self-attend over ``x`` using ``mask``.
    """

    def __init__(self, dim_self: int, num_heads: int, num_layers: int, dim_ref: Optional[int] = None,
                 mlp_ratio: float = 2., act=nnf.relu, norm_layer: nn.Module = nn.LayerNorm, enc_dec: bool = False):
        super(Transformer, self).__init__()
        dim_ref = dim_ref if dim_ref is not None else dim_self
        self.enc_dec = enc_dec
        if enc_dec:
            num_layers = num_layers * 2
        layers = []
        for i in range(num_layers):
            # Odd layers of an encoder-decoder stack self-attend, so their reference width is dim_self.
            layer_ref = dim_self if enc_dec and i % 2 == 1 else dim_ref
            layers.append(TransformerLayer(dim_self, layer_ref, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
        self.layers = nn.ModuleList(layers)

    def _layer_inputs(self, i, x, y, mask):
        """Return the ``(reference, mask)`` pair that layer ``i`` attends with."""
        if not self.enc_dec:
            return y, mask  # self-attention when y is None, otherwise cross-attention
        if i % 2 == 0:
            return y, None  # cross-attention over y, unmasked
        return x, mask  # self-attention over x

    def forward_with_attention(self, x, y=None, mask=None):
        """Run the stack and collect each layer's attention weights.

        Routes ``y``/``mask`` to each layer exactly as ``forward`` does. The
        previous version passed ``y`` and ``mask`` to every layer, which was
        wrong for ``enc_dec`` stacks.
        """
        attentions = []
        for i, layer in enumerate(self.layers):
            ref, layer_mask = self._layer_inputs(i, x, y, mask)
            x, att = layer.forward_with_attention(x, ref, layer_mask)
            attentions.append(att)
        return x, attentions

    def forward(self, x, y=None, mask=None):
        """Run ``x`` through every layer and return the final hidden states."""
        for i, layer in enumerate(self.layers):
            ref, layer_mask = self._layer_inputs(i, x, y, mask)
            x = layer(x, ref, layer_mask)
        return x
class TransformerMapper(nn.Module):
    """Map one conditioning vector to ``prefix_length`` embeddings with a transformer.

    The input vector is linearly expanded into ``clip_length`` tokens. A learned
    constant prefix of ``prefix_length`` tokens is appended, and the transformer
    output at the prefix positions is returned. The transformer uses 8 heads.

    ``emb_size`` is stored but unused, because the optional input projector is
    disabled.
    """

    def __init__(self, emb_size, dim_clip: int, dim_embedding: int, prefix_length: int, clip_length: int, num_layers: int = 8):
        super(TransformerMapper, self).__init__()
        self.emb_size = emb_size
        self.clip_length = clip_length
        self.transformer = Transformer(dim_embedding, 8, num_layers)
        self.linear = nn.Linear(dim_clip, clip_length * dim_embedding)
        self.prefix_const = nn.Parameter(torch.randn(prefix_length, dim_embedding), requires_grad=True)

    def forward(self, x):
        """Produce the prefix embeddings for a batch of conditioning vectors.

        Args:
            x: expected shape (batch, dim_clip).

        Returns:
            Tensor of shape (batch, prefix_length, dim_embedding).
        """
        # No projector is built in __init__, so none is applied here. The previous
        # `self.projector` call raised AttributeError whenever emb_size was set.
        batch = x.shape[0]
        x = self.linear(x).view(batch, self.clip_length, -1)
        prefix = self.prefix_const.unsqueeze(0).expand(batch, *self.prefix_const.shape)
        tokens = torch.cat((x, prefix), dim=1)  # (batch, clip_length + prefix_length, dim_embedding)
        return self.transformer(tokens)[:, self.clip_length:]
class TransformerMapperSeq(nn.Module):
    """Map a token sequence to ``prefix_length`` embeddings with a transformer.

    The input is viewed as ``clip_length`` tokens. A learned constant prefix of
    ``prefix_length`` tokens is appended, and the transformer output at the
    prefix positions is returned. The transformer uses 8 heads.

    ``emb_size`` is stored but unused, and ``dim_clip`` is accepted only for
    signature parity with ``TransformerMapper``.
    """

    def __init__(self, emb_size, dim_clip: int, dim_embedding: int, prefix_length: int, clip_length: int, num_layers: int = 8):
        super(TransformerMapperSeq, self).__init__()
        self.emb_size = emb_size
        self.clip_length = clip_length
        self.transformer = Transformer(dim_embedding, 8, num_layers)
        self.prefix_const = nn.Parameter(torch.randn(prefix_length, dim_embedding), requires_grad=True)

    def forward(self, x):
        """Produce prefix embeddings from ``x``.

        ``x`` is viewed as (batch, clip_length, -1). Its last dimension must equal
        ``dim_embedding`` so it can be concatenated with the learned prefix.
        Returns a tensor of shape (batch, prefix_length, dim_embedding).
        """
        batch = x.shape[0]
        tokens = x.view(batch, self.clip_length, -1)
        learned = self.prefix_const.unsqueeze(0).expand(batch, -1, -1)
        hidden = self.transformer(torch.cat((tokens, learned), dim=1))
        return hidden[:, self.clip_length:]