# ViTay-translation / model.py
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
# Special token ids and language ids.
PAD, BOS, EOS, UNK = 0, 1, 2, 3
LANG2ID = {"vi": 0, "ty": 1}

@dataclass
class ModelConfig:
    vocab_size: int
    d_model: int = 384
    num_heads: int = 6
    d_ff: int = 1536
    num_encoder_layers: int = 6
    num_decoder_layers: int = 6
    max_pos: int = 1024
    emb_dropout: float = 0.1
    attn_pdrop: float = 0.1
    resid_pdrop: float = 0.1
    layerdrop: float = 0.1
    pad_token_id: int = 0
    tie_embeddings: bool = True
    num_langs: int = 2  # 0: vi, 1: ty


class PositionalEmbedding(nn.Module):
    def __init__(self, max_pos, d_model):
        super().__init__()
        self.weight = nn.Embedding(max_pos, d_model)

    def forward(self, positions):
        return self.weight(positions)


class Seq2SeqTransformer(nn.Module):
    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.cfg = cfg
        self.token_emb = nn.Embedding(cfg.vocab_size, cfg.d_model, padding_idx=cfg.pad_token_id)
        self.lang_emb = nn.Embedding(cfg.num_langs, cfg.d_model)
        self.pos_emb = PositionalEmbedding(cfg.max_pos, cfg.d_model)
        self.emb_drop = nn.Dropout(cfg.emb_dropout)
        self.enc_layer = nn.TransformerEncoderLayer(
            d_model=cfg.d_model, nhead=cfg.num_heads, dim_feedforward=cfg.d_ff,
            dropout=cfg.resid_pdrop, activation="gelu", batch_first=True, norm_first=True
        )
        self.encoder = nn.TransformerEncoder(self.enc_layer, num_layers=cfg.num_encoder_layers)
        self.dec_layer = nn.TransformerDecoderLayer(
            d_model=cfg.d_model, nhead=cfg.num_heads, dim_feedforward=cfg.d_ff,
            dropout=cfg.resid_pdrop, activation="gelu", batch_first=True, norm_first=True
        )
        self.decoder = nn.TransformerDecoder(self.dec_layer, num_layers=cfg.num_decoder_layers)
        self.ln_enc = nn.RMSNorm(cfg.d_model)
        self.ln_dec = nn.RMSNorm(cfg.d_model)
        self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
        if cfg.tie_embeddings:
            # Tie the output projection to the input embedding matrix.
            self.lm_head.weight = self.token_emb.weight

    def encode(self, src_ids, src_lang_id):
        src_padding_mask = src_ids.eq(self.cfg.pad_token_id)
        x = self._embed(src_ids, src_lang_id)
        enc = self.encoder(x, src_key_padding_mask=src_padding_mask)
        return self.ln_enc(enc), src_padding_mask

    def decode(self, tgt_ids, enc_out, src_padding_mask, tgt_lang_id):
        tgt_padding_mask = tgt_ids.eq(self.cfg.pad_token_id)
        T = tgt_ids.size(1)
        # Boolean upper-triangular mask blocks attention to future target positions.
        causal = torch.triu(torch.ones(T, T, device=tgt_ids.device, dtype=torch.bool), 1)
        y = self._embed(tgt_ids, tgt_lang_id)
        dec = self.decoder(
            y, enc_out,
            tgt_mask=causal,
            tgt_key_padding_mask=tgt_padding_mask,
            memory_key_padding_mask=src_padding_mask
        )
        return self.ln_dec(dec)

    def _embed(self, input_ids, lang_id):
        B, T = input_ids.size()
        pos = torch.arange(T, device=input_ids.device)
        if T > self.cfg.max_pos:
            # Positions beyond max_pos reuse the last positional embedding.
            pos = pos.clamp_max(self.cfg.max_pos - 1)
        pos = pos.unsqueeze(0).expand(B, T)
        x = (self.token_emb(input_ids)
             + self.pos_emb(pos)
             + self.lang_emb(torch.full((B, T), lang_id, device=input_ids.device)))
        return self.emb_drop(x)

    def forward(self, src_ids, tgt_in_ids, src_lang_id, tgt_lang_id, labels=None):
        enc_out, src_padding_mask = self.encode(src_ids, src_lang_id)
        dec_out = self.decode(tgt_in_ids, enc_out, src_padding_mask, tgt_lang_id)
        logits = self.lm_head(dec_out)
        loss = None
        if labels is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)),
                                   labels.view(-1), ignore_index=self.cfg.pad_token_id)
        return logits, loss

    @torch.no_grad()
    def generate(self, src_ids, src_lang_id, tgt_lang_id, max_len=128, bos_id=1, eos_id=2, beam_size=4,
                 length_penalty=0.8):
        device = src_ids.device
        enc_out, src_padding_mask = self.encode(src_ids, src_lang_id)
        B = src_ids.size(0)
        assert B == 1, "generate() currently supports batch size 1 only"

        def lp(alpha, L):
            # GNMT-style length penalty: scores are normalised by ((5 + L) / 6) ** alpha.
            return ((5 + L) / 6) ** alpha

        # Start from a single BOS hypothesis; the first step expands it into beam_size beams.
        beams = [{"tokens": torch.tensor([bos_id], device=device), "logprob": 0.0, "finished": False}]
        for _ in range(max_len):
            all_cand = []
            for b in beams:
                if b["finished"]:
                    # Finished hypotheses are carried over unchanged.
                    all_cand.append(b)
                    continue
                tgt = b["tokens"].unsqueeze(0)
                dec_h = self.decode(tgt, enc_out, src_padding_mask, tgt_lang_id)
                logit = self.lm_head(dec_h[:, -1, :])
                logprobs = F.log_softmax(logit, dim=-1).squeeze(0)
                topv, topi = torch.topk(logprobs, beam_size)
                for score, tok in zip(topv.tolist(), topi.tolist()):
                    new_toks = torch.cat([b["tokens"], torch.tensor([tok], device=device)])
                    all_cand.append({"tokens": new_toks, "logprob": b["logprob"] + score, "finished": tok == eos_id})
            beams = sorted(all_cand, key=lambda x: x["logprob"] / lp(length_penalty, len(x["tokens"])), reverse=True)[:beam_size]
            if all(b["finished"] for b in beams):
                break
        best = max(beams, key=lambda x: x["logprob"] / lp(length_penalty, len(x["tokens"])))
        return best["tokens"]
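

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the released checkpoint): shows
# how the model is typically instantiated, trained on one dummy step, and used
# for decoding. The vocab size and the token ids in the dummy batch below are
# placeholders, not values from the actual ViTay tokenizer.
if __name__ == "__main__":
    cfg = ModelConfig(vocab_size=32000)  # placeholder vocab size
    model = Seq2SeqTransformer(cfg).eval()

    # Dummy vi -> ty example using the special ids defined above.
    src = torch.tensor([[BOS, 17, 42, 99, EOS]])
    tgt_in = torch.tensor([[BOS, 23, 58]])
    labels = torch.tensor([[23, 58, EOS]])

    logits, loss = model(src, tgt_in,
                         src_lang_id=LANG2ID["vi"], tgt_lang_id=LANG2ID["ty"],
                         labels=labels)
    print(logits.shape, loss.item())

    # Beam-search decoding for a single source sentence.
    hyp = model.generate(src, src_lang_id=LANG2ID["vi"], tgt_lang_id=LANG2ID["ty"],
                         max_len=32, beam_size=4)
    print(hyp.tolist())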