# RSICRC/src/model.py
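"""Model definitions for image change captioning (ICC).

ICCModel pairs a Siamese image encoder (a pretrained visual backbone such as a
CLIP ResNet/ViT, or a torchvision ResNet, followed by an attentive transformer
encoder over the two feature maps) with a transformer caption decoder that
cross-attends to the fused visual tokens.
"""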
import torch
from torch import nn
from einops import rearrange
import math
from torch import Tensor
import torchvision.models as models
from torch.nn import functional as F
class ICCModel(nn.Module):
def __init__(self, device, pretrained, backbone, d_model, vocab_size, max_len,
num_heads, h_dim, a_dim, encoder_layers, decoder_layers, dropout,
learnable=False, fine_tune=True, tie_embeddings=True, prenorm=False):
super(ICCModel, self).__init__()
self.feature_dim = d_model
visual = pretrained.visual if pretrained else None
self.encoder = ImagesEncoder(device, visual, backbone, d_model, num_heads, h_dim, a_dim, dropout,
encoder_layers, fine_tune)
self.decoder = Decoder(device, d_model, vocab_size, max_len, num_heads,
decoder_layers, dropout,
learnable=learnable, tie_embeddings=tie_embeddings, prenorm=prenorm)
def forward(self, img1, img2, input_ids, labels, attention_mask):
vis_emb, vis_toks = self.encoder(img1, img2)
cap_loss, text_emb, lm_logits, weights = self.decoder(input_ids, labels, attention_mask, vis_toks)
return cap_loss, vis_emb, text_emb, vis_toks, lm_logits, weights
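# Expected tensor shapes (as implemented above):
#   img1, img2:        (batch, 3, H, W)   the two images being compared
#   input_ids, labels: (batch, seq)       caption token ids
#   attention_mask:    (batch, seq) bool  key-padding mask, True marks padding
# Returns: caption cross-entropy loss, pooled visual embedding (batch, d_model),
# pooled text embedding (batch, d_model), fused visual tokens (h*w, batch, d_model),
# LM logits (batch, seq, vocab_size), and the decoder's cross-attention weights.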
class ImagesEncoder(nn.Module):
def __init__(self, device, pretrained, backbone, d_model, num_heads, h_dim, a_dim, dropout,
encoder_layers, fine_tune):
super(ImagesEncoder, self).__init__()
self.encoder = Encoder(pretrained, backbone, d_model, fine_tune)
self.encoder_trans = AttentiveEncoder(device, encoder_layers,
[self.encoder.feat_size, self.encoder.feat_size, d_model], num_heads,
hidden_dim=h_dim, attention_dim=a_dim, dropout=dropout)
self.cos = torch.nn.CosineSimilarity(dim=1)
self.Conv1 = nn.Conv2d(d_model * 2, d_model, kernel_size=1)
self.LN = resblock(d_model, d_model)
self.att_pool = nn.MultiheadAttention(d_model, num_heads)
self.att_pool_norm = nn.LayerNorm(d_model)
self.img_queries = nn.Parameter(torch.randn(1, d_model))
def forward(self, img1, img2):
feat1 = self.encoder(img1)
feat2 = self.encoder(img2)
x1, x2 = self.encoder_trans(feat1, feat2) # batch_size, channel, enc_image_size, enc_image_size
x_sam = self.cos(x1, x2)
x = torch.cat([x1, x2], dim=1) + x_sam.unsqueeze(1) # batch_size, 2channel, enc_image_size, enc_image_size
x = self.LN(self.Conv1(x))
batch, channel = x.size(0), x.size(1)
x = x.view(batch, channel, -1).permute(2, 0, 1) # h*w, batch, dim
img_queries = self.img_queries.unsqueeze(1).repeat(1, x.shape[1], 1) # L,N,E
img_emb = self.att_pool(img_queries, x, x, need_weights=False)[0]
img_emb = self.att_pool_norm(img_emb) # 1, batch, d_model
cls = img_emb[0]
return cls, x
class Encoder(nn.Module):
def __init__(self, pretrained, backbone, d_model, fine_tune):
super(Encoder, self).__init__()
self.backbone = backbone
if 'rn' in backbone.lower():
modules = list(pretrained.children())[:-1]
self.net = nn.Sequential(*modules)
self.feat_dim = 2048
self.feat_size = 7
elif 'b-32' in backbone.lower():
self.net = pretrained
self.net.output_tokens = True
self.feat_dim = 768
self.feat_size = 7
elif 'l-14' in backbone.lower():
self.net = pretrained
self.net.output_tokens = True
self.feat_dim = 1024
self.feat_size = 16
elif backbone == 'resnet50':
net = models.resnet50(pretrained=True)
modules = list(net.children())[:-2]
self.net = nn.Sequential(*modules)
self.feat_dim = 2048
self.feat_size = 8
elif backbone == 'resnet101':
net = models.resnet101(pretrained=True)
modules = list(net.children())[:-2]
self.net = nn.Sequential(*modules)
self.feat_dim = 2048
self.feat_size = 8
        else:
            raise ValueError(f'Unsupported backbone: {backbone}')
        self.proj = None
if self.feat_dim != d_model:
self.proj = nn.Conv2d(self.feat_dim, d_model, kernel_size=1)
self.fine_tune(fine_tune)
def forward(self, image):
feat = self.net(image) # batch, feat_dim, feat_size, feat_size
if 'vit' in self.backbone.lower():
feat = feat[1].reshape(-1, self.feat_size, self.feat_size, self.feat_dim).permute(0, 3, 1, 2)
if self.proj:
feat = self.proj(feat)
return feat
def fine_tune(self, fine_tune=True):
for p in self.net.parameters():
p.requires_grad = False
if 'resnet' in self.backbone:
to_finetune = list(self.net.children())[-5:]
elif 'vit' in self.backbone.lower():
to_finetune = list(self.net.children())[-2:] # only transformer layers
else:
to_finetune = list(self.net.children())[-3:] # only fine-tune convolutional blocks 2 through 4
for c in to_finetune:
for p in c.parameters():
p.requires_grad = fine_tune
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout=0.):
super(FeedForward, self).__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class MultiHeadAtt(nn.Module):
def __init__(self, dim_q, dim_kv, attention_dim, heads=8, dropout=0.):
super(MultiHeadAtt, self).__init__()
project_out = not (heads == 1 and attention_dim == dim_kv)
self.heads = heads
self.scale = (attention_dim // self.heads) ** -0.5
self.to_q = nn.Linear(dim_q, attention_dim, bias=False)
self.to_k = nn.Linear(dim_kv, attention_dim, bias=False)
self.to_v = nn.Linear(dim_kv, attention_dim, bias=False)
self.attend = nn.Softmax(dim=-1)
self.dropout = nn.Dropout(dropout)
self.to_out = nn.Sequential(
nn.Linear(attention_dim, dim_q),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x1, x2, x3):
q = self.to_q(x1)
k = self.to_k(x2)
        v = self.to_v(x3)
q = rearrange(q, 'b n (h d) -> b h n d', h=self.heads)
k = rearrange(k, 'b n (h d) -> b h n d', h=self.heads)
v = rearrange(v, 'b n (h d) -> b h n d', h=self.heads)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.dropout(self.attend(dots))
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out) # (b,n,dim)
class Transformer(nn.Module):
def __init__(self, dim_q, dim_kv, heads, attention_dim, hidden_dim, dropout=0., norm_first=False):
super(Transformer, self).__init__()
self.norm_first = norm_first
self.att = MultiHeadAtt(dim_q, dim_kv, attention_dim, heads=heads, dropout=dropout)
self.feedforward = FeedForward(dim_q, hidden_dim, dropout=dropout)
self.norm1 = nn.LayerNorm(dim_q)
self.norm2 = nn.LayerNorm(dim_q)
def forward(self, x1, x2, x3):
if self.norm_first:
x = self.att(self.norm1(x1), self.norm1(x2), self.norm1(x3)) + x1
x = self.feedforward(self.norm2(x)) + x
else:
x = self.norm1(self.att(x1, x2, x3) + x1)
x = self.norm2(self.feedforward(x) + x)
return x
class AttentiveEncoder(nn.Module):
def __init__(self, device, n_layers, feature_size, heads, hidden_dim=512, attention_dim=512, dropout=0.):
super(AttentiveEncoder, self).__init__()
h_feat, w_feat, channels = feature_size
self.device = device
self.h_embedding = nn.Embedding(h_feat, int(channels / 2))
self.w_embedding = nn.Embedding(w_feat, int(channels / 2))
self.selftrans = nn.ModuleList([])
for i in range(n_layers):
self.selftrans.append(nn.ModuleList([
Transformer(channels, channels, heads, attention_dim, hidden_dim, dropout, norm_first=False),
Transformer(channels * 2, channels * 2, heads, attention_dim, hidden_dim, dropout, norm_first=False),
]))
self._reset_parameters()
def _reset_parameters(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def forward(self, img1, img2):
batch, c, h, w = img1.shape
pos_h = torch.arange(h).to(self.device)
pos_w = torch.arange(w).to(self.device)
        # index each spatial axis with its own embedding table
        embed_h = self.h_embedding(pos_h)
        embed_w = self.w_embedding(pos_w)
pos_embedding = torch.cat([embed_w.unsqueeze(0).repeat(h, 1, 1),
embed_h.unsqueeze(1).repeat(1, w, 1)],
dim=-1)
pos_embedding = pos_embedding.permute(2, 0, 1).unsqueeze(0).repeat(batch, 1, 1, 1)
img1 = img1 + pos_embedding
img2 = img2 + pos_embedding
img1 = img1.view(batch, c, -1).transpose(-1, 1) # batch, hw, c
img2 = img2.view(batch, c, -1).transpose(-1, 1)
img_sa1, img_sa2 = img1, img2
for (l, m) in self.selftrans:
img_sa1 = l(img_sa1, img_sa1, img_sa1) + img_sa1
img_sa2 = l(img_sa2, img_sa2, img_sa2) + img_sa2
img = torch.cat([img_sa1, img_sa2], dim=-1)
img = m(img, img, img)
img_sa1 = img[:, :, :c] + img1
img_sa2 = img[:, :, c:] + img2
img1 = img_sa1.reshape(batch, h, w, c).transpose(-1, 1)
img2 = img_sa2.reshape(batch, h, w, c).transpose(-1, 1)
return img1, img2
class resblock(nn.Module):
def __init__(self, inchannel, outchannel, stride=1, shortcut=None):
super(resblock, self).__init__()
self.left = nn.Sequential(
nn.Conv2d(inchannel, int(outchannel / 2), kernel_size=1),
# nn.LayerNorm(int(outchannel/2),dim=1),
nn.BatchNorm2d(int(outchannel / 2)),
nn.ReLU(),
nn.Conv2d(int(outchannel / 2), int(outchannel / 2), kernel_size=3, stride=1, padding=1),
# nn.LayerNorm(int(outchannel/2),dim=1),
nn.BatchNorm2d(int(outchannel / 2)),
nn.ReLU(),
nn.Conv2d(int(outchannel / 2), outchannel, kernel_size=1),
# nn.LayerNorm(int(outchannel / 1),dim=1)
nn.BatchNorm2d(outchannel)
)
self.right = shortcut
def forward(self, x):
out = self.left(x)
residual = x
out = out + residual
return F.relu(out)
class Decoder(nn.Module):
def __init__(self, device, h_dim, vocab_size, max_len, n_head, n_layers, dropout=0.10,
learnable=False, tie_embeddings=True, prenorm=False):
super(Decoder, self).__init__()
self.embed_dim = h_dim
self.vocab_size = vocab_size
self.dropout = dropout
self.device = device
self.tokens_embed = nn.Embedding(vocab_size, self.embed_dim)
self.position_encoding = PositionalEncoding(self.embed_dim, dropout=dropout, max_len=max_len,
device=device, learnable=learnable)
self.uni_decoder = nn.ModuleList(
[DecoderLayer(h_dim, h_dim, n_head, dim_feedforward=h_dim * 4, dropout=self.dropout, prenorm=prenorm,
crossattention=False) for _ in range(n_layers)])
self.cross_decoder = nn.ModuleList(
[DecoderLayer(h_dim, h_dim, n_head, dim_feedforward=h_dim * 4, dropout=self.dropout, prenorm=prenorm,
crossattention=True) for _ in range(n_layers)])
self.lm_head = nn.Linear(h_dim, vocab_size, bias=False)
if tie_embeddings:
self.tokens_embed.weight = self.lm_head.weight
self.dropout = nn.Dropout(p=self.dropout)
self.init_weights()
self.loss_fn = nn.CrossEntropyLoss()
def init_weights(self):
self.tokens_embed.weight.data.uniform_(-0.1, 0.1)
self.lm_head.weight.data.uniform_(-0.1, 0.1)
def forward(self, input_ids=None, labels=None, pad_mask=None, img_emb=None):
att_weights = None
mask = torch.tril(torch.ones(input_ids.shape[1], input_ids.shape[1]))
mask = ~mask.bool()
mask = mask.to(self.device)
inputs_embeds = self.tokens_embed(input_ids)
inputs_embeds = self.position_encoding(inputs_embeds) # batch, seq, e_dim
inputs_embeds = inputs_embeds.permute(1, 0, 2) # seq, batch, e_dim
# seq, batch, emb_dim
out = inputs_embeds
for block in self.uni_decoder:
out, _ = block(out, None, tgt_mask=mask, tgt_key_padding_mask=pad_mask)
if pad_mask is not None: # not inference
cls = []
for i in range(pad_mask.shape[0]):
end = pad_mask[i].shape[0] - pad_mask[i].count_nonzero()
cls.append(out[end - 1, i, :])
cls = torch.stack(cls) # batch, emb_dim
else:
cls = None
if img_emb is None:
return None, cls, None, None
for block in self.cross_decoder:
out, att_weights = block(out, img_emb, tgt_mask=mask, tgt_key_padding_mask=pad_mask)
lm_logits = self.lm_head(self.dropout(out)) # seq, batch, voc_dim
lm_logits = lm_logits.permute(1, 0, 2) # batch, seq, voc_dim
if labels is not None: # not inference
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_logits = shift_logits.view(-1, shift_logits.size(-1))
shift_labels = labels[..., 1:].contiguous()
shift_labels = shift_labels.view(-1)
loss = self.loss_fn(shift_logits, shift_labels)
else:
loss = None
return loss, cls, lm_logits, att_weights
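

# Minimal greedy-decoding sketch for inference with a trained ICCModel.
# The special token ids `bos_id`/`eos_id` and `max_steps` are assumptions made
# for illustration; they are not defined anywhere in this file.
@torch.no_grad()
def greedy_caption(model, img1, img2, bos_id, eos_id, max_steps=50, device='cpu'):
    model.eval()
    _, vis_toks = model.encoder(img1.to(device), img2.to(device))  # (h*w, batch, d_model)
    seq = torch.full((img1.size(0), 1), bos_id, dtype=torch.long, device=device)
    for _ in range(max_steps):
        # labels=None and pad_mask=None select the inference path of Decoder.forward
        _, _, lm_logits, _ = model.decoder(seq, None, None, vis_toks)  # (batch, len, vocab)
        next_tok = lm_logits[:, -1, :].argmax(dim=-1, keepdim=True)
        seq = torch.cat([seq, next_tok], dim=1)
        if (next_tok == eos_id).all():
            break
    return seq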
class PositionalEncoding(nn.Module):
def __init__(self, d_model, dropout, max_len, device, learnable=False):
super(PositionalEncoding, self).__init__()
self.learnable = learnable
self.max_len = max_len
self.device = device
self.dropout = nn.Dropout(p=dropout)
if not learnable:
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
self.register_buffer('pe', pe)
else:
self.pos_emb = nn.Embedding(max_len, int(d_model))
def forward(self, x):
if self.learnable:
position_ids = torch.arange(x.size(1), dtype=torch.long).to(self.device)
position_ids = position_ids.unsqueeze(0).view(-1, x.size(1)) # batch, seq
x = x + self.pos_emb(position_ids)
else:
x = x + self.pe[:, :x.size(1), :]
return self.dropout(x)
class DecoderLayer(nn.Module):
def __init__(self, d_model, img_dim, nhead, dim_feedforward=2048, dropout=0.1, layer_norm_eps=1e-5,
prenorm=False, crossattention=False):
super(DecoderLayer, self).__init__()
self.prenorm = prenorm
self.crossattention = crossattention
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
if crossattention:
self.cross_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, kdim=img_dim, vdim=img_dim)
self.mha_dropout = nn.Dropout(dropout)
self.mha_norm = nn.LayerNorm(d_model, eps=layer_norm_eps)
self.ff_linear1 = nn.Linear(d_model, dim_feedforward)
self.ff_dropout = nn.Dropout(dropout)
self.ff_linear2 = nn.Linear(dim_feedforward, d_model)
self.sa_norm = nn.LayerNorm(d_model, eps=layer_norm_eps)
self.ff_norm = nn.LayerNorm(d_model, eps=layer_norm_eps)
self.sa_dropout = nn.Dropout(dropout)
self.activation = nn.GELU()
def forward(self, tgt: Tensor, memory: Tensor, tgt_mask=None,
memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None):
att_weight = None
x = tgt
if self.prenorm:
x = x + self._sa_block(self.sa_norm(x), tgt_mask, tgt_key_padding_mask)
if self.crossattention:
enc_att, att_weight = self._mha_block(self.mha_norm(x), memory, memory_mask, memory_key_padding_mask)
x = x + enc_att
x = x + self._ff_block(self.ff_norm(x))
else:
x = self.sa_norm(x + self._sa_block(x, tgt_mask, tgt_key_padding_mask))
if self.crossattention:
enc_att, att_weight = self._mha_block(x, memory, memory_mask, memory_key_padding_mask)
x = self.mha_norm(x + enc_att)
x = self.ff_norm(x + self._ff_block(x))
return x, att_weight
def _sa_block(self, x, attn_mask, key_padding_mask):
x = self.self_attn(x, x, x, # L,N,E
attn_mask=attn_mask, # L, S
key_padding_mask=key_padding_mask, # N, S
is_causal=True,
need_weights=False)[0]
return self.sa_dropout(x)
def _mha_block(self, x, mem, attn_mask, key_padding_mask):
x, att_weight = self.cross_attn(x, mem, mem,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
is_causal=False,
need_weights=True)
return self.mha_dropout(x), att_weight
def _ff_block(self, x):
x = self.ff_linear2(self.ff_dropout(self.activation(self.ff_linear1(x))))
return self.ff_dropout(x)
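

# Minimal smoke test: a sketch with assumed hyperparameters (d_model=512, a toy
# vocabulary, one encoder/decoder layer); it exercises the torchvision ResNet-50
# branch, so no CLIP checkpoint is needed (ImageNet weights are downloaded).
if __name__ == '__main__':
    device = 'cpu'
    model = ICCModel(device=device, pretrained=None, backbone='resnet50',
                     d_model=512, vocab_size=1000, max_len=40,
                     num_heads=8, h_dim=512, a_dim=512,
                     encoder_layers=1, decoder_layers=1, dropout=0.1).to(device)
    img1 = torch.randn(2, 3, 256, 256)               # first image of each pair
    img2 = torch.randn(2, 3, 256, 256)               # second image of each pair
    input_ids = torch.randint(0, 1000, (2, 12))      # toy caption token ids
    pad_mask = torch.zeros(2, 12, dtype=torch.bool)  # True would mark padding
    loss, vis_emb, text_emb, vis_toks, logits, weights = model(
        img1, img2, input_ids, input_ids.clone(), pad_mask)
    print(loss.item(), vis_emb.shape, text_emb.shape, logits.shape)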