import math
import torch
import torch.nn as nn
@torch.no_grad()
def sinusoidal_1d_pe(S: int, D: int, device=None) -> torch.Tensor:
    assert D % 2 == 0, f"D must be even; got D={D}"
device = device or torch.device('cpu')
d = D // 2
pos = torch.arange(S, device=device, dtype=torch.float32) # [S]
k = torch.arange(d, device=device, dtype=torch.float32) # [d]
omega = torch.exp(-math.log(10000.0) * k / d) # [d]
# Broadcast: [S,1]*[d] -> [S,d]
pe_sin = torch.sin(pos[..., None] * omega) # [S,d]
pe_cos = torch.cos(pos[..., None] * omega) # [S,d]
pe = torch.cat([pe_sin, pe_cos], dim=-1).unsqueeze(0).contiguous() # [1,S,D]
return pe
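# Hedged sanity check for sinusoidal_1d_pe (illustrative only; the S/D values
# below are arbitrary assumptions, not part of any model configuration).
def _demo_sinusoidal_pe():
    pe = sinusoidal_1d_pe(S=16, D=64)
    assert pe.shape == (1, 16, 64)
    # Position 0: the sin half is all zeros, the cos half is all ones.
    assert torch.allclose(pe[0, 0, :32], torch.zeros(32))
    assert torch.allclose(pe[0, 0, 32:], torch.ones(32))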
def causal_mask(T:int, device):
"""
```
tensor([[False, True, True, True],
[False, False, True, True],
[False, False, False, True],
[False, False, False, False]])
```
"""
m = torch.triu(torch.ones(T, T, device=device, dtype=torch.bool), diagonal=1)
return m # [T,T]
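# Hedged usage sketch: causal_mask follows nn.MultiheadAttention's boolean
# attn_mask convention (True = attention is NOT allowed). T=4 below is an
# arbitrary assumption for illustration.
def _demo_causal_mask():
    m = causal_mask(4, torch.device('cpu'))
    assert m.shape == (4, 4) and m.dtype == torch.bool
    assert not m[3, 0]   # the last token may attend to the first ...
    assert m[0, 3]       # ... but the first token may not attend to the last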
class DecoderAttnBlock(nn.Module):
def __init__(self,
embed_dim:int,
heads:int=8,
dropout:float=0.1,
mlp_ratio:float=4.0):
super().__init__()
self.norm1 = nn.LayerNorm(embed_dim)
self.norm2 = nn.LayerNorm(embed_dim)
self.norm3 = nn.LayerNorm(embed_dim)
self.self_attn = nn.MultiheadAttention(embed_dim, heads, batch_first=True, dropout=dropout)
self.cross_attn = nn.MultiheadAttention(embed_dim, heads, batch_first=True, dropout=dropout)
self.mlp = nn.Sequential(
nn.Linear(embed_dim, int(embed_dim*mlp_ratio)),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(int(embed_dim*mlp_ratio), embed_dim),
nn.Dropout(dropout)
)
def forward(self,
x:torch.Tensor,
enc_out,
self_attn_mask=None,
self_key_padding=None,
enc_key_padding=None):
        # x: [B,T,D], enc_out: [B,S,D]; pre-norm residual sub-blocks
        # 1) masked self-attention over the target sequence
        h = self.norm1(x)
        x = x + self.self_attn(h, h, h,
                               attn_mask=self_attn_mask,            # [T,T] causal mask
                               key_padding_mask=self_key_padding,   # [B,T]
                               need_weights=False)[0]
        # 2) cross-attention over the encoder output
        h = self.norm2(x)
        x = x + self.cross_attn(h, enc_out, enc_out,
                                key_padding_mask=enc_key_padding,   # [B,S]
                                need_weights=False)[0]
        # 3) position-wise MLP
        h = self.norm3(x)
        x = x + self.mlp(h)
return x # [B,T,D]
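# Hedged shape check for DecoderAttnBlock (illustrative only; B/T/S and the
# embedding size below are arbitrary assumptions).
def _demo_decoder_block():
    B, T, S, dim = 2, 5, 7, 64
    blk = DecoderAttnBlock(dim, heads=8)
    x = torch.randn(B, T, dim)          # target-side features
    enc = torch.randn(B, S, dim)        # stand-in for encoder output
    mask = causal_mask(T, x.device)
    out = blk(x, enc, self_attn_mask=mask)
    assert out.shape == (B, T, dim)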
class ViTDecoder(nn.Module):
def __init__(self,
vocab_size:int,
pad_id:int,
bos_id:int,
eos_id:int,
dim:int,
depth:int=4,
heads:int=8,
dropout:float=0.1,
device=torch.device('cpu')):
super().__init__()
self.pad_id = pad_id
self.bos_id = bos_id
self.eos_id = eos_id
self.embed = nn.Embedding(vocab_size, dim, padding_idx=pad_id).to(device)
self.dropout = nn.Dropout(dropout).to(device)
self.blocks = nn.ModuleList([DecoderAttnBlock(dim, heads, dropout) for _ in range(depth)]).to(device)
self.norm = nn.LayerNorm(dim).to(device)
self.lm_head = nn.Linear(dim, vocab_size, bias=False).to(device)
        # optional weight tying: self.lm_head.weight = self.embed.weight
def forward(self,
enc_out:torch.Tensor, # [B,S,D]
tins:torch.Tensor): # [B,T]
B, T = tins.shape
        x = self.embed(tins)                                # [B,T,D]
        x = x + sinusoidal_1d_pe(T, x.size(-1), x.device)   # add 1D sinusoidal PE
        x = self.dropout(x)
attn_mask = causal_mask(T, x.device) # [T,T]
key_pad = (tins == self.pad_id) # [B,T]
enc_key_pad = None
for blk in self.blocks:
            x = blk(x, enc_out, self_attn_mask=attn_mask, self_key_padding=key_pad, enc_key_padding=enc_key_pad)
x = self.norm(x)
logits = self.lm_head(x) # [B,T,vocab]
return logits
@torch.no_grad()
def generate(self,
enc_out,
max_len:int=256
):
B, S, D = enc_out.shape
x_ids = torch.full((B,1), self.bos_id, device=enc_out.device, dtype=torch.long)
finished = torch.zeros(B, dtype=torch.bool, device=enc_out.device)
for t in range(1, max_len+1):
x = self.embed(x_ids) # [B,t,D]
x = x + sinusoidal_1d_pe(x.size(1), x.size(2), x.device)
            attn_mask = causal_mask(x.size(1), x.device)    # [t,t]
            key_pad = (x_ids == self.pad_id)                # [B,t]
            for blk in self.blocks:
                x = blk(x, enc_out, self_attn_mask=attn_mask, self_key_padding=key_pad)
x = self.norm(x)
logits = self.lm_head(x[:, -1]) # [B,vocab]
next_id = torch.argmax(logits, dim=-1) # greedy
x_ids = torch.cat([x_ids, next_id[:,None]], dim=1)
finished = finished | (next_id == self.eos_id)
if bool(finished.all()):
break
return x_ids # [B, <=max_len+1]
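# Hedged end-to-end sketch for ViTDecoder (illustrative only; the vocabulary
# size, special-token ids and tensor shapes below are arbitrary assumptions).
def _demo_vit_decoder():
    B, S, dim = 2, 49, 64
    dec = ViTDecoder(vocab_size=100, pad_id=0, bos_id=1, eos_id=2, dim=dim, depth=2, heads=8)
    enc_out = torch.randn(B, S, dim)          # stand-in for ViT encoder features
    tins = torch.randint(3, 100, (B, 10))     # teacher-forcing input tokens
    logits = dec(enc_out, tins)
    assert logits.shape == (B, 10, 100)
    ids = dec.generate(enc_out, max_len=12)   # greedy decoding from <bos>
    assert ids.shape[0] == B and ids.shape[1] <= 13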
# ------ OLD VERSION: LSTM decoder with additive (Bahdanau-style) attention ------
class AdditiveAttn(nn.Module):
def __init__(self, fdim=512, hdim=512, dim=128):
super().__init__()
self.linear_image = nn.Linear(fdim, dim, bias=False)
self.linear_hidden = nn.Linear(hdim, dim, bias=False)
self.linear_score = nn.Linear(dim, 1, bias=False)
@torch.no_grad()
def precompute_image(self, image_features: torch.Tensor) -> torch.Tensor:
"""
image_features: [B, S, D_img]
return: proj_image [B, S, A] (A=dim)
"""
return self.linear_image(image_features)
def forward_cached(self,
proj_image: torch.Tensor, # [B, S, A] precomputed
image_features: torch.Tensor, # [B, S, D_img] bmm
hidden_state: torch.Tensor # [B, H]
) -> tuple[torch.Tensor, torch.Tensor]:
proj_hidden = self.linear_hidden(hidden_state).unsqueeze(1)
combined = torch.tanh(proj_image + proj_hidden)
scores = self.linear_score(combined).squeeze(-1)
weights = torch.softmax(scores, dim=-1)
# [B,1,S] x [B,S,D] -> [B,1,D] -> [B,D]
context = torch.bmm(weights.unsqueeze(1), image_features).squeeze(1)
return context, weights
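# Hedged shape check for AdditiveAttn (illustrative only; B/S and the feature
# dimensions below are arbitrary assumptions).
def _demo_additive_attn():
    B, S = 2, 49
    attn = AdditiveAttn(fdim=512, hdim=256, dim=128)
    feats = torch.randn(B, S, 512)
    proj = attn.precompute_image(feats)                 # [B,S,128], reused every step
    ctx, w = attn.forward_cached(proj, feats, torch.randn(B, 256))
    assert ctx.shape == (B, 512) and w.shape == (B, S)
    assert torch.allclose(w.sum(-1), torch.ones(B), atol=1e-5)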
class LSTMDecoder(nn.Module):
def __init__(self,
vocab_size,
token_feats,
hidden_size,
D_img,
pad_id, bos_id, eos_id,
device=torch.device('cpu')):
super().__init__()
self.device = device
self.embed = nn.Embedding(vocab_size, token_feats, padding_idx=pad_id)
self.ctx_proj = nn.Linear(D_img, token_feats, bias=False)
self.mix = nn.Sequential(
nn.Linear(token_feats*2, token_feats, bias=False),
nn.ReLU(inplace=True),
nn.LayerNorm(token_feats)
)
self.cell = nn.LSTMCell(token_feats, hidden_size)
self.head = nn.Linear(hidden_size, vocab_size, bias=False)
self.attn = AdditiveAttn(D_img, hidden_size, dim=128)
self.h0_fc = nn.Linear(D_img, hidden_size)
self.c0_fc = nn.Linear(D_img, hidden_size)
self.bos_id, self.eos_id = bos_id, eos_id
self.to(device)
def _init_state(self, img_seq): # [B,S,D]
g = img_seq.mean(1) # [B,D]
h0 = torch.tanh(self.h0_fc(g)) * 0.5
c0 = torch.tanh(self.c0_fc(g)) * 0.5
return h0, c0
def forward(self, img_seq, tokens_in): # tokens_in:[B,T]
B, T = tokens_in.shape
img_seq = img_seq.contiguous()
proj_image = self.attn.precompute_image(img_seq) # [B,S,A]
h, c = self._init_state(img_seq)
V = self.head.out_features
outs = torch.empty(B, T, V, device=img_seq.device, dtype=torch.float32)
for t in range(T):
emb = self.embed(tokens_in[:, t]) # [B,E]
ctx, _ = self.attn.forward_cached(proj_image, img_seq, h) # [B,D]
ctxE = self.ctx_proj(ctx) # [B,E]
x = torch.cat([emb, ctxE], dim=-1) # [B,2E]
            x = self.mix(x) + emb                        # residual keeps the raw embedding
h, c = self.cell(x, (h, c)) # [B,H]
outs[:, t, :] = self.head(h) # [B,V]
return outs
@torch.no_grad()
def generate(self, img_seq, max_len=50):
B = img_seq.size(0)
img_seq = img_seq.contiguous()
proj_image = self.attn.precompute_image(img_seq)
h, c = self._init_state(img_seq)
tok = torch.full((B,), self.bos_id, device=img_seq.device, dtype=torch.long)
V = self.head.out_features
outs = torch.empty(B, max_len, dtype=torch.long, device=img_seq.device)
for t in range(max_len):
emb = self.embed(tok) # [B,E]
ctx, _ = self.attn.forward_cached(proj_image, img_seq, h)
ctxE = self.ctx_proj(ctx) # [B,E]
x = torch.cat([emb, ctxE], dim=-1) # [B,2E]
x = self.mix(x) + emb
h, c = self.cell(x, (h, c))
nxt = self.head(h).argmax(-1) # [B]
outs[:, t] = nxt
tok = nxt
if (nxt == self.eos_id).all():
return outs[:, :t+1]
return outs
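# Hedged smoke test (illustrative only; every dimension, vocabulary size and
# token id below is an arbitrary assumption chosen just to exercise the code).
if __name__ == "__main__":
    B, S, D_img, V = 2, 49, 512, 100
    img_seq = torch.randn(B, S, D_img)                  # stand-in image features
    tokens_in = torch.randint(3, V, (B, 12))            # teacher-forcing tokens

    lstm_dec = LSTMDecoder(vocab_size=V, token_feats=256, hidden_size=512,
                           D_img=D_img, pad_id=0, bos_id=1, eos_id=2)
    print("LSTMDecoder logits:", tuple(lstm_dec(img_seq, tokens_in).shape))          # (2, 12, 100)
    print("LSTMDecoder sample:", tuple(lstm_dec.generate(img_seq, max_len=8).shape))

    vit_dec = ViTDecoder(vocab_size=V, pad_id=0, bos_id=1, eos_id=2, dim=256, depth=2)
    enc_out = torch.randn(B, S, 256)
    print("ViTDecoder logits:", tuple(vit_dec(enc_out, tokens_in).shape))             # (2, 12, 100)
    print("ViTDecoder sample:", tuple(vit_dec.generate(enc_out, max_len=8).shape))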