from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F

# Special token ids and language ids; these must match the tokenizer.
PAD, BOS, EOS, UNK = 0, 1, 2, 3
LANG2ID = {"vi": 0, "ty": 1}
|
|
@dataclass
class ModelConfig:
    vocab_size: int
    d_model: int = 384
    num_heads: int = 6
    d_ff: int = 1536
    num_encoder_layers: int = 6
    num_decoder_layers: int = 6
    max_pos: int = 1024
    emb_dropout: float = 0.1
    attn_pdrop: float = 0.1   # not wired up: nn.TransformerEncoderLayer exposes a single dropout rate
    resid_pdrop: float = 0.1  # used as the dropout rate of the encoder/decoder layers
    layerdrop: float = 0.1    # reserved: no layer-drop logic is implemented below
    pad_token_id: int = 0
    tie_embeddings: bool = True
    num_langs: int = 2
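
# A typical instantiation (sketch; the vocab size depends on the tokenizer,
# so 16000 below is only illustrative):
#   cfg = ModelConfig(vocab_size=16000)
#   model = Seq2SeqTransformer(cfg)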
|
|
class PositionalEmbedding(nn.Module):
    """Learned absolute positional embeddings."""

    def __init__(self, max_pos, d_model):
        super().__init__()
        self.emb = nn.Embedding(max_pos, d_model)

    def forward(self, positions):
        return self.emb(positions)
|
|
class Seq2SeqTransformer(nn.Module):
    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.cfg = cfg
        self.token_emb = nn.Embedding(cfg.vocab_size, cfg.d_model, padding_idx=cfg.pad_token_id)
        self.lang_emb = nn.Embedding(cfg.num_langs, cfg.d_model)
        self.pos_emb = PositionalEmbedding(cfg.max_pos, cfg.d_model)
        self.emb_drop = nn.Dropout(cfg.emb_dropout)

        # nn.TransformerEncoder/Decoder deep-copy the template layer, so keep
        # it as a local rather than a module attribute; registering it as an
        # attribute would leave a duplicate, unused copy of its parameters.
        enc_layer = nn.TransformerEncoderLayer(
            d_model=cfg.d_model, nhead=cfg.num_heads, dim_feedforward=cfg.d_ff,
            dropout=cfg.resid_pdrop, activation="gelu", batch_first=True, norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=cfg.num_encoder_layers)

        dec_layer = nn.TransformerDecoderLayer(
            d_model=cfg.d_model, nhead=cfg.num_heads, dim_feedforward=cfg.d_ff,
            dropout=cfg.resid_pdrop, activation="gelu", batch_first=True, norm_first=True,
        )
        self.decoder = nn.TransformerDecoder(dec_layer, num_layers=cfg.num_decoder_layers)

        # Final norms: norm_first blocks leave the residual stream unnormalized.
        self.ln_enc = nn.RMSNorm(cfg.d_model)
        self.ln_dec = nn.RMSNorm(cfg.d_model)
        self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
        if cfg.tie_embeddings:
            # Weight tying: the output projection shares the input embedding matrix.
            self.lm_head.weight = self.token_emb.weight
|
    def encode(self, src_ids, src_lang_id):
        # True where src is padding; those positions are excluded from attention.
        src_padding_mask = src_ids.eq(self.cfg.pad_token_id)
        x = self._embed(src_ids, src_lang_id)
        enc = self.encoder(x, src_key_padding_mask=src_padding_mask)
        return self.ln_enc(enc), src_padding_mask
|
    def decode(self, tgt_ids, enc_out, src_padding_mask, tgt_lang_id):
        tgt_padding_mask = tgt_ids.eq(self.cfg.pad_token_id)
        T = tgt_ids.size(1)
        # Boolean causal mask: True above the diagonal blocks attention to future positions.
        causal = torch.triu(torch.ones(T, T, device=tgt_ids.device, dtype=torch.bool), 1)

        y = self._embed(tgt_ids, tgt_lang_id)
        dec = self.decoder(
            y, enc_out,
            tgt_mask=causal,
            tgt_key_padding_mask=tgt_padding_mask,
            memory_key_padding_mask=src_padding_mask,
        )
        return self.ln_dec(dec)
|
    def _embed(self, input_ids, lang_id):
        B, T = input_ids.size()
        pos = torch.arange(T, device=input_ids.device)
        if T > self.cfg.max_pos:
            # Sequences longer than max_pos reuse the last positional embedding.
            pos = pos.clamp_max(self.cfg.max_pos - 1)
        pos = pos.unsqueeze(0).expand(B, T)
        # Sum of token, position, and language embeddings; the language id is
        # broadcast to every position in the sequence.
        x = (self.token_emb(input_ids)
             + self.pos_emb(pos)
             + self.lang_emb(torch.full((B, T), lang_id, device=input_ids.device)))
        return self.emb_drop(x)
|
    def forward(self, src_ids, tgt_in_ids, src_lang_id, tgt_lang_id, labels=None):
        enc_out, src_padding_mask = self.encode(src_ids, src_lang_id)
        dec_out = self.decode(tgt_in_ids, enc_out, src_padding_mask, tgt_lang_id)
        logits = self.lm_head(dec_out)
        loss = None
        if labels is not None:
            # Padding positions in the labels are ignored by the loss.
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)),
                                   labels.view(-1), ignore_index=self.cfg.pad_token_id)
        return logits, loss
|
    @torch.no_grad()
    def generate(self, src_ids, src_lang_id, tgt_lang_id, max_len=128, bos_id=1, eos_id=2, beam_size=4,
                 length_penalty=0.8):
        device = src_ids.device
        enc_out, src_padding_mask = self.encode(src_ids, src_lang_id)
        B = src_ids.size(0)
        assert B == 1, "beam search is implemented for batch size 1 only"

        def lp(alpha, L):
            # GNMT-style length penalty.
            return ((5 + L) / 6) ** alpha

        # Start from a single beam; beam_size identical copies would only
        # expand into duplicate candidates and degenerate to greedy search.
        beams = [{"tokens": torch.tensor([bos_id], device=device), "logprob": 0.0, "finished": False}]
        for _ in range(max_len):
            all_cand = []
            for b in beams:
                if b["finished"]:
                    all_cand.append(b)
                    continue
                tgt = b["tokens"].unsqueeze(0)
                # No KV cache: the whole prefix is re-decoded at every step.
                dec_h = self.decode(tgt, enc_out, src_padding_mask, tgt_lang_id)
                logit = self.lm_head(dec_h[:, -1, :])
                logprobs = F.log_softmax(logit, dim=-1).squeeze(0)
                topv, topi = torch.topk(logprobs, beam_size)
                for score, tok in zip(topv.tolist(), topi.tolist()):
                    new_toks = torch.cat([b["tokens"], torch.tensor([tok], device=device)])
                    all_cand.append({"tokens": new_toks, "logprob": b["logprob"] + score,
                                     "finished": tok == eos_id})

            beams = sorted(all_cand,
                           key=lambda x: x["logprob"] / lp(length_penalty, len(x["tokens"])),
                           reverse=True)[:beam_size]
            if all(b["finished"] for b in beams):
                break
        best = max(beams, key=lambda x: x["logprob"] / lp(length_penalty, len(x["tokens"])))
        return best["tokens"]
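

# Sketch of one teacher-forced training step, assuming a (hypothetical) batch
# dict with the keys used below: "tgt_in_ids" is the BOS-prefixed target and
# "labels" the same target shifted left by one, ending in EOS.
def example_train_step(model, optimizer, batch):
    logits, loss = model(batch["src_ids"], batch["tgt_in_ids"],
                         batch["src_lang_id"], batch["tgt_lang_id"],
                         labels=batch["labels"])
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()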
|
|
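# Minimal smoke test (sketch): random ids over a tiny illustrative vocab, just
# to check shapes and that beam search returns a token sequence.
if __name__ == "__main__":
    cfg = ModelConfig(vocab_size=100, d_model=64, num_heads=4, d_ff=128,
                      num_encoder_layers=2, num_decoder_layers=2)
    model = Seq2SeqTransformer(cfg)
    model.eval()
    src = torch.randint(4, cfg.vocab_size, (2, 11))
    tgt_in = torch.randint(4, cfg.vocab_size, (2, 9))
    labels = torch.randint(4, cfg.vocab_size, (2, 9))
    logits, loss = model(src, tgt_in, LANG2ID["vi"], LANG2ID["ty"], labels=labels)
    print("logits:", tuple(logits.shape), "loss:", float(loss))
    out = model.generate(src[:1], LANG2ID["vi"], LANG2ID["ty"], max_len=8)
    print("generated:", out.tolist())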
|