import math

import torch
import torch.nn as nn
def sinusoidal_1d_pe(S: int, D: int, device=None) -> torch.Tensor:
    """Standard 1D sinusoidal positional encoding, returned as [1, S, D]."""
    assert D % 2 == 0, f"D must be even; got D={D}"
    device = device or torch.device('cpu')
    d = D // 2
    pos = torch.arange(S, device=device, dtype=torch.float32)   # [S]
    k = torch.arange(d, device=device, dtype=torch.float32)     # [d]
    omega = torch.exp(-math.log(10000.0) * k / d)                # [d]
    # Broadcast: [S,1] * [d] -> [S,d]
    pe_sin = torch.sin(pos[..., None] * omega)                   # [S,d]
    pe_cos = torch.cos(pos[..., None] * omega)                   # [S,d]
    pe = torch.cat([pe_sin, pe_cos], dim=-1).unsqueeze(0).contiguous()  # [1,S,D]
    return pe
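
# Hedged usage sketch (not part of the original file): the function name and
# helper below are only illustrative; sizes S=16, D=64 are assumptions.
def _demo_sinusoidal_pe():
    pe = sinusoidal_1d_pe(S=16, D=64)       # [1, 16, 64]
    assert pe.shape == (1, 16, 64)
    # first d channels hold the sin terms, the remaining d channels the cos terms
    return pe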
def causal_mask(T: int, device):
    """Boolean [T,T] mask; True marks future positions that must NOT be attended to.

    For T=4:
    ```
    tensor([[False,  True,  True,  True],
            [False, False,  True,  True],
            [False, False, False,  True],
            [False, False, False, False]])
    ```
    """
    m = torch.triu(torch.ones(T, T, device=device, dtype=torch.bool), diagonal=1)
    return m  # [T,T]
class DecoderAttnBlock(nn.Module):
    """Pre-norm decoder block: masked self-attention, cross-attention, MLP."""
    def __init__(self,
                 embed_dim: int,
                 heads: int = 8,
                 dropout: float = 0.1,
                 mlp_ratio: float = 4.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.norm3 = nn.LayerNorm(embed_dim)
        self.self_attn = nn.MultiheadAttention(embed_dim, heads, batch_first=True, dropout=dropout)
        self.cross_attn = nn.MultiheadAttention(embed_dim, heads, batch_first=True, dropout=dropout)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, int(embed_dim * mlp_ratio)),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(int(embed_dim * mlp_ratio), embed_dim),
            nn.Dropout(dropout),
        )
    def forward(self,
                x: torch.Tensor,
                enc_out: torch.Tensor,
                self_attn_mask=None,
                self_key_padding=None,
                enc_key_padding=None):
        # x: [B,T,D], enc_out: [B,S,D]
        h = self.norm1(x)
        x = x + self.self_attn(h, h, h,
                               attn_mask=self_attn_mask,           # [T,T] causal mask
                               key_padding_mask=self_key_padding,  # [B,T]
                               need_weights=False)[0]
        h = self.norm2(x)
        x = x + self.cross_attn(h, enc_out, enc_out,
                                key_padding_mask=enc_key_padding,  # [B,S]
                                need_weights=False)[0]
        h = self.norm3(x)
        x = x + self.mlp(h)
        return x  # [B,T,D]
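
# Hedged usage sketch (helper name and all sizes are assumptions, not from the
# original code): runs a single pre-norm block with a causal mask and no padding.
def _demo_decoder_block():
    blk = DecoderAttnBlock(embed_dim=64, heads=4)
    x = torch.randn(2, 7, 64)        # [B,T,D] decoder-side tokens
    enc = torch.randn(2, 49, 64)     # [B,S,D] encoder output
    y = blk(x, enc, self_attn_mask=causal_mask(7, x.device))
    assert y.shape == x.shape        # residual block preserves [B,T,D]
    return y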
class ViTDecoder(nn.Module):
    """Transformer decoder on top of ViT encoder features (current version)."""
    def __init__(self,
                 vocab_size: int,
                 pad_id: int,
                 bos_id: int,
                 eos_id: int,
                 dim: int,
                 depth: int = 4,
                 heads: int = 8,
                 dropout: float = 0.1,
                 device=torch.device('cpu')):
        super().__init__()
        self.pad_id = pad_id
        self.bos_id = bos_id
        self.eos_id = eos_id
        self.embed = nn.Embedding(vocab_size, dim, padding_idx=pad_id).to(device)
        self.dropout = nn.Dropout(dropout).to(device)
        self.blocks = nn.ModuleList([DecoderAttnBlock(dim, heads, dropout) for _ in range(depth)]).to(device)
        self.norm = nn.LayerNorm(dim).to(device)
        self.lm_head = nn.Linear(dim, vocab_size, bias=False).to(device)
        # self.lm_head.weight = self.embed.weight  # optional weight tying
    def forward(self,
                enc_out: torch.Tensor,   # [B,S,D]
                tins: torch.Tensor):     # [B,T]
        B, T = tins.shape
        x = self.embed(tins)                                   # [B,T,D]
        x = x + sinusoidal_1d_pe(T, x.size(-1), x.device)      # add 1D positional encoding
        x = self.dropout(x)
        attn_mask = causal_mask(T, x.device)                   # [T,T]
        key_pad = (tins == self.pad_id)                        # [B,T]
        enc_key_pad = None
        for blk in self.blocks:
            x = blk(x, enc_out,
                    self_attn_mask=attn_mask,
                    self_key_padding=key_pad,
                    enc_key_padding=enc_key_pad)
        x = self.norm(x)
        logits = self.lm_head(x)  # [B,T,vocab]
        return logits
    def generate(self,
                 enc_out: torch.Tensor,
                 max_len: int = 256):
        """Greedy decoding: start from BOS and stop once every sequence has emitted EOS."""
        B, S, D = enc_out.shape
        x_ids = torch.full((B, 1), self.bos_id, device=enc_out.device, dtype=torch.long)
        finished = torch.zeros(B, dtype=torch.bool, device=enc_out.device)
        for t in range(1, max_len + 1):
            x = self.embed(x_ids)                                       # [B,t,D]
            x = x + sinusoidal_1d_pe(x.size(1), x.size(2), x.device)
            attn_mask = causal_mask(x.size(1), x.device)                # [t,t], shared by all blocks
            key_pad = (x_ids == self.pad_id)                            # [B,t]
            for blk in self.blocks:
                x = blk(x, enc_out, self_attn_mask=attn_mask, self_key_padding=key_pad)
            x = self.norm(x)
            logits = self.lm_head(x[:, -1])                             # [B,vocab]
            next_id = torch.argmax(logits, dim=-1)                      # greedy
            x_ids = torch.cat([x_ids, next_id[:, None]], dim=1)
            finished = finished | (next_id == self.eos_id)
            if bool(finished.all()):
                break
        return x_ids  # [B, <=max_len+1]
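
# Hedged smoke-test sketch (helper name and all sizes are assumptions): pushes a
# dummy encoder output and token batch through ViTDecoder, then decodes greedily.
def _demo_vit_decoder():
    dec = ViTDecoder(vocab_size=100, pad_id=0, bos_id=1, eos_id=2,
                     dim=64, depth=2, heads=4)
    enc_out = torch.randn(2, 49, 64)            # [B,S,D] dummy patch features
    tokens = torch.randint(3, 100, (2, 7))      # [B,T] dummy target tokens
    logits = dec(enc_out, tokens)               # [2, 7, 100]
    ids = dec.generate(enc_out, max_len=5)      # [2, <=6], starts with BOS
    return logits.shape, ids.shape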
# ------ OLD VERSION ------
class AdditiveAttn(nn.Module):
    """Bahdanau-style additive attention over image features."""
    def __init__(self, fdim=512, hdim=512, dim=128):
        super().__init__()
        self.linear_image = nn.Linear(fdim, dim, bias=False)
        self.linear_hidden = nn.Linear(hdim, dim, bias=False)
        self.linear_score = nn.Linear(dim, 1, bias=False)

    def precompute_image(self, image_features: torch.Tensor) -> torch.Tensor:
        """
        image_features: [B, S, D_img]
        return: proj_image [B, S, A] (A=dim)
        """
        return self.linear_image(image_features)

    def forward_cached(self,
                       proj_image: torch.Tensor,      # [B, S, A] precomputed
                       image_features: torch.Tensor,  # [B, S, D_img], used in the bmm
                       hidden_state: torch.Tensor     # [B, H]
                       ) -> tuple[torch.Tensor, torch.Tensor]:
        proj_hidden = self.linear_hidden(hidden_state).unsqueeze(1)   # [B,1,A]
        combined = torch.tanh(proj_image + proj_hidden)               # [B,S,A]
        scores = self.linear_score(combined).squeeze(-1)              # [B,S]
        weights = torch.softmax(scores, dim=-1)                       # [B,S]
        # [B,1,S] x [B,S,D] -> [B,1,D] -> [B,D]
        context = torch.bmm(weights.unsqueeze(1), image_features).squeeze(1)
        return context, weights
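
# Hedged example (helper name and sizes are assumptions): project the image
# features once with precompute_image, then reuse that projection per step.
def _demo_additive_attn():
    attn = AdditiveAttn(fdim=512, hdim=512, dim=128)
    feats = torch.randn(2, 49, 512)                  # [B,S,D_img]
    proj = attn.precompute_image(feats)              # [B,S,128], computed once
    h = torch.zeros(2, 512)                          # [B,H] decoder hidden state
    ctx, w = attn.forward_cached(proj, feats, h)     # [B,512], [B,S]
    assert ctx.shape == (2, 512) and w.shape == (2, 49)
    return ctx, w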
class LSTMDecoder(nn.Module):
    def __init__(self,
                 vocab_size,
                 token_feats,
                 hidden_size,
                 D_img,
                 pad_id, bos_id, eos_id,
                 device=torch.device('cpu')):
        super().__init__()
        self.device = device
        self.embed = nn.Embedding(vocab_size, token_feats, padding_idx=pad_id)
        self.ctx_proj = nn.Linear(D_img, token_feats, bias=False)
        self.mix = nn.Sequential(
            nn.Linear(token_feats * 2, token_feats, bias=False),
            nn.ReLU(inplace=True),
            nn.LayerNorm(token_feats),
        )
        self.cell = nn.LSTMCell(token_feats, hidden_size)
        self.head = nn.Linear(hidden_size, vocab_size, bias=False)
        self.attn = AdditiveAttn(D_img, hidden_size, dim=128)
        self.h0_fc = nn.Linear(D_img, hidden_size)
        self.c0_fc = nn.Linear(D_img, hidden_size)
        self.bos_id, self.eos_id = bos_id, eos_id
        self.to(device)

    def _init_state(self, img_seq):  # img_seq: [B,S,D]
        g = img_seq.mean(1)                    # [B,D] mean-pooled image feature
        h0 = torch.tanh(self.h0_fc(g)) * 0.5
        c0 = torch.tanh(self.c0_fc(g)) * 0.5
        return h0, c0
    def forward(self, img_seq, tokens_in):  # img_seq: [B,S,D], tokens_in: [B,T]
        B, T = tokens_in.shape
        img_seq = img_seq.contiguous()
        proj_image = self.attn.precompute_image(img_seq)  # [B,S,A], computed once
        h, c = self._init_state(img_seq)
        V = self.head.out_features
        outs = torch.empty(B, T, V, device=img_seq.device, dtype=torch.float32)
        for t in range(T):
            emb = self.embed(tokens_in[:, t])                           # [B,E]
            ctx, _ = self.attn.forward_cached(proj_image, img_seq, h)   # [B,D]
            ctxE = self.ctx_proj(ctx)                                   # [B,E]
            x = torch.cat([emb, ctxE], dim=-1)                          # [B,2E]
            x = self.mix(x) + emb   # residual keeps the raw embedding signal
            h, c = self.cell(x, (h, c))                                 # [B,H]
            outs[:, t, :] = self.head(h)                                # [B,V]
        return outs
    def generate(self, img_seq, max_len=50):
        B = img_seq.size(0)
        img_seq = img_seq.contiguous()
        proj_image = self.attn.precompute_image(img_seq)
        h, c = self._init_state(img_seq)
        tok = torch.full((B,), self.bos_id, device=img_seq.device, dtype=torch.long)
        V = self.head.out_features
        outs = torch.empty(B, max_len, dtype=torch.long, device=img_seq.device)
        for t in range(max_len):
            emb = self.embed(tok)                                       # [B,E]
            ctx, _ = self.attn.forward_cached(proj_image, img_seq, h)
            ctxE = self.ctx_proj(ctx)                                   # [B,E]
            x = torch.cat([emb, ctxE], dim=-1)                          # [B,2E]
            x = self.mix(x) + emb
            h, c = self.cell(x, (h, c))
            nxt = self.head(h).argmax(-1)   # greedy: [B]
            outs[:, t] = nxt
            tok = nxt
            if (nxt == self.eos_id).all():
                return outs[:, :t + 1]
        return outs
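
# Hedged smoke-test sketch (helper name and all sizes are assumptions): a
# teacher-forced pass and greedy generation for the LSTM baseline decoder.
def _demo_lstm_decoder():
    dec = LSTMDecoder(vocab_size=100, token_feats=64, hidden_size=128,
                      D_img=512, pad_id=0, bos_id=1, eos_id=2)
    img_seq = torch.randn(2, 49, 512)            # [B,S,D_img] dummy features
    tokens_in = torch.randint(3, 100, (2, 7))    # [B,T] shifted-right targets
    logits = dec(img_seq, tokens_in)             # [2, 7, 100]
    ids = dec.generate(img_seq, max_len=5)       # [2, <=5]
    return logits.shape, ids.shape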