# base_model/model.py
# Uploaded with huggingface_hub by biancaganescu (commit 1894069, verified).
import torch
import torch.nn as nn
import math
class DualStreamTransformer(nn.Module):
    """Autoregressive text decoder that can optionally cross-attend to a DINO image embedding.

    Token ids are embedded by ``SimpleTextEmbedding`` and run through a stack
    of ``MultimodalDecoderLayer`` blocks under a causal mask. When image
    features are supplied and enabled, they are:

    * projected to ``d_model`` by ``DinoImageEmbedding``,
    * contextualised by ``Encoder``,
    * used as cross-attention memory in every decoder layer.

    The decoder output is projected to vocabulary logits.

    Args:
        vocab_size: Number of token ids (embedding rows and logit width).
        d_model: Hidden size shared by all sub-modules.
        n_head: Number of attention heads (``nn.MultiheadAttention`` requires
            it to divide ``d_model``).
        d_hid: Inner width of the feed-forward networks.
        num_encoder_layers: Depth of the image encoder.
        num_decoder_layers: Depth of the multimodal decoder.
        dino_dim: Feature size of the incoming DINO embedding.
        dropout: Dropout probability used throughout the model, including the
            text embedding.
        max_len: Size of the learned position table, i.e. the longest
            supported text sequence. Defaults to 128, the value that was
            previously hard-coded.
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 768,
        n_head: int = 8,
        d_hid: int = 768,
        num_encoder_layers: int = 5,
        num_decoder_layers: int = 8,
        dino_dim: int = 768,
        dropout: float = 0.1,
        max_len: int = 128,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_head = n_head
        self.d_hid = d_hid
        self.num_encoder_layers = num_encoder_layers
        self.num_decoder_layers = num_decoder_layers
        self.dino_dim = dino_dim
        self.dropout = dropout
        self.max_len = max_len
        # Forward the model-wide dropout. Previously the text embedding
        # silently kept its own default of 0.1 regardless of `dropout`.
        self.text_embedding = self.SimpleTextEmbedding(
            vocab_size, d_model, max_len=max_len, dropout=dropout
        )
        self.image_embedding = self.DinoImageEmbedding(dino_dim, d_model)
        self.image_encoder = self.Encoder(
            d_model, n_head, d_hid, num_encoder_layers, dropout
        )
        self.decoder = self.MultimodalDecoder(
            d_model, n_head, d_hid, num_decoder_layers, dropout
        )
        self.output_layer = nn.Linear(d_model, vocab_size)

    def forward(
        self, input_ids, dino_embedding=None, padding_mask=None, use_image: bool = False
    ):
        """Return vocabulary logits for every position of ``input_ids``.

        Args:
            input_ids: Integer token ids of shape ``(batch, seq_len)``.
            dino_embedding: Optional image features, either
                ``(batch, dino_dim)`` or ``(batch, n_tokens, dino_dim)``.
            padding_mask: Optional mask forwarded unchanged as
                ``key_padding_mask`` to every decoder self-attention.
            use_image: Whether to condition on ``dino_embedding`` at all.

        Returns:
            Logits of shape ``(batch, seq_len, vocab_size)``.

        Raises:
            IndexError: if ``seq_len`` exceeds the position-table size.
        """
        embedded = self.text_embedding(input_ids)

        # The image stream is skipped when it is disabled, absent, or all zeros.
        # The zero test looks at the whole batch at once, so an all-zero sample
        # inside an otherwise non-zero batch is still encoded. Turning the
        # tensor result into a Python bool also synchronises with the device.
        image_encoded = None
        if (
            use_image
            and dino_embedding is not None
            and not torch.all(dino_embedding == 0)
        ):
            image_encoded = self.image_encoder(self.image_embedding(dino_embedding))

        tgt_mask = self.decoder.generate_square_subsequent_mask(
            embedded.size(1), device=embedded.device
        )
        decoder_output = self.decoder(
            tgt=embedded,
            image_memory=image_encoded,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=padding_mask,
        )
        return self.output_layer(decoder_output)

    class SimpleTextEmbedding(nn.Module):
        """Scaled token embedding plus a learned absolute position embedding.

        Computes ``LayerNorm(Dropout(tok(x) * sqrt(d_model) + pos(0..seq_len-1)))``.
        """

        def __init__(self, vocab_size, d_model, max_len=128, dropout=0.1):
            super().__init__()
            self.token_embedding = nn.Embedding(vocab_size, d_model)
            self.position_embedding = nn.Embedding(max_len, d_model)
            self.layer_norm = nn.LayerNorm(d_model)
            self.dropout = nn.Dropout(p=dropout)
            self.d_model = d_model

        def forward(self, x):
            """Embed token ids ``x`` of shape ``(batch, seq_len)`` into ``(batch, seq_len, d_model)``.

            Raises:
                IndexError: if ``seq_len`` exceeds the position-table size.
            """
            batch_size, seq_len = x.size()
            max_len = self.position_embedding.num_embeddings
            if seq_len > max_len:
                # Fail with a readable message instead of an out-of-range
                # embedding lookup (an opaque device-side assert on CUDA).
                raise IndexError(
                    f"sequence length {seq_len} exceeds max_len={max_len}"
                )
            positions = (
                torch.arange(seq_len, device=x.device)
                .unsqueeze(0)
                .expand(batch_size, seq_len)
            )
            token_emb = self.token_embedding(x) * math.sqrt(self.d_model)
            pos_emb = self.position_embedding(positions)
            return self.layer_norm(self.dropout(token_emb + pos_emb))

    class DinoImageEmbedding(nn.Module):
        """Linear projection of DINO features from ``dino_dim`` to ``d_model``."""

        def __init__(self, dino_dim, d_model):
            super().__init__()
            self.projection_layer = nn.Linear(dino_dim, d_model)

        def forward(self, x):
            """Project ``x`` into a ``(batch, n_tokens, d_model)`` memory sequence.

            A ``(batch, dino_dim)`` input becomes a single-token sequence.
            A ``(batch, n_tokens, dino_dim)`` input is projected token-wise.
            Previously a 3-D input gained an extra axis and broke the encoder.
            """
            if x.dim() == 2:
                x = x.unsqueeze(1)
            return self.projection_layer(x)

    class Encoder(nn.Module):
        """Stack of ``n_layers`` ``nn.TransformerEncoderLayer`` blocks (GELU, batch-first).

        ``norm_first`` is left at PyTorch's default, so the layers are post-norm.
        """

        def __init__(
            self,
            d_model: int,
            n_head: int,
            d_hid: int,
            n_layers: int,
            dropout: float = 0.1,
        ):
            super().__init__()
            encoder_layer = nn.TransformerEncoderLayer(
                d_model, n_head, d_hid, dropout, activation="gelu", batch_first=True
            )
            self.encoder = nn.TransformerEncoder(encoder_layer, n_layers)

        def forward(self, src, src_mask=None, src_key_padding_mask=None):
            """Encode ``src`` (batch-first), forwarding the optional masks to PyTorch."""
            return self.encoder(
                src, mask=src_mask, src_key_padding_mask=src_key_padding_mask
            )

    class DynamicGating(nn.Module):
        """Learned element-wise blend of text and image features.

        ``gate = sigmoid(W [text; image] + b)``. The output is
        ``LayerNorm(Dropout(gate * text + (1 - gate) * image))``. When
        ``image_features`` is ``None`` the text features are returned unchanged.
        """

        def __init__(self, d_model: int, dropout: float = 0.1):
            super().__init__()
            self.gate_fc = nn.Linear(d_model * 2, d_model)
            self.dropout = nn.Dropout(dropout)
            self.layer_norm = nn.LayerNorm(d_model)

        def forward(self, text_features, image_features):
            if image_features is None:
                return text_features
            combined = torch.cat([text_features, image_features], dim=-1)
            gate = torch.sigmoid(self.gate_fc(combined))
            fused = gate * text_features + (1 - gate) * image_features
            return self.layer_norm(self.dropout(fused))

    class MultimodalDecoderLayer(nn.Module):
        """Pre-norm decoder layer with three residual sub-blocks.

        1. Causal self-attention.
        2. Optional gated cross-attention to image memory.
        3. GELU feed-forward network.
        """

        def __init__(self, d_model: int, n_head: int, d_hid: int, dropout: float = 0.1):
            super().__init__()
            self.self_attn = nn.MultiheadAttention(
                d_model, n_head, dropout=dropout, batch_first=True
            )
            self.cross_attn_txt_image = nn.MultiheadAttention(
                d_model, n_head, dropout=dropout, batch_first=True
            )
            self.norm1 = nn.LayerNorm(d_model)
            self.norm2 = nn.LayerNorm(d_model)
            self.norm3 = nn.LayerNorm(d_model)
            self.dropout = nn.Dropout(dropout)
            self.gate = DualStreamTransformer.DynamicGating(d_model, dropout)
            self.ff = nn.Sequential(
                nn.Linear(d_model, d_hid),
                nn.GELU(),
                nn.Dropout(dropout),
                nn.Linear(d_hid, d_model),
                nn.Dropout(dropout),
            )

        def forward(self, tgt, image_memory, tgt_mask=None, tgt_key_padding_mask=None):
            """Apply the layer to ``tgt``.

            Cross-attention is skipped when ``image_memory`` is ``None``.
            """
            tgt_norm = self.norm1(tgt)
            # is_causal=True is only a hint to PyTorch that tgt_mask is the
            # causal mask; the mask itself is still supplied explicitly.
            self_attn_output, _ = self.self_attn(
                tgt_norm,
                tgt_norm,
                tgt_norm,
                key_padding_mask=tgt_key_padding_mask,
                attn_mask=tgt_mask,
                is_causal=True,
            )
            tgt = tgt + self.dropout(self_attn_output)

            if image_memory is not None:
                tgt_norm = self.norm2(tgt)
                cross_attn_output, _ = self.cross_attn_txt_image(
                    tgt_norm, image_memory, image_memory
                )
                cross_attn_output = self.dropout(cross_attn_output)
                # NOTE(review): the gated result still contains tgt_norm, so
                # the normalised text is added on top of the residual again.
                # Kept as-is to preserve trained-checkpoint behaviour.
                fused = self.gate(tgt_norm, cross_attn_output)
                tgt = tgt + fused

            tgt_norm = self.norm3(tgt)
            # NOTE(review): self.ff already ends in Dropout, so dropout is
            # applied twice here. Kept for checkpoint/training parity.
            ff_output = self.ff(tgt_norm)
            return tgt + self.dropout(ff_output)

    class MultimodalDecoder(nn.Module):
        """Sequential stack of ``n_layers`` ``MultimodalDecoderLayer`` blocks."""

        def __init__(
            self,
            d_model: int,
            n_head: int,
            d_hid: int,
            n_layers: int,
            dropout: float = 0.1,
        ):
            super().__init__()
            self.layers = nn.ModuleList(
                [
                    DualStreamTransformer.MultimodalDecoderLayer(
                        d_model, n_head, d_hid, dropout
                    )
                    for _ in range(n_layers)
                ]
            )

        def generate_square_subsequent_mask(self, size, device=None):
            """Return a ``(size, size)`` boolean causal mask.

            ``True`` entries mark future positions, which PyTorch attention
            treats as disallowed. The mask is built on ``device``, or on the
            default device when ``device`` is ``None``.
            """
            return torch.triu(torch.ones(size, size, device=device), diagonal=1).bool()

        def forward(self, tgt, image_memory, tgt_mask, tgt_key_padding_mask=None):
            """Run ``tgt`` through every layer with the same memory and masks."""
            output = tgt
            for layer in self.layers:
                output = layer(
                    output,
                    image_memory,
                    tgt_mask=tgt_mask,
                    tgt_key_padding_mask=tgt_key_padding_mask,
                )
            return output