from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F

PAD, BOS, EOS, UNK = 0, 1, 2, 3
LANG2ID = {"vi": 0, "ty": 1}


@dataclass
class ModelConfig:
    vocab_size: int
    d_model: int = 384
    num_heads: int = 6
    d_ff: int = 1536
    num_encoder_layers: int = 6
    num_decoder_layers: int = 6
    max_pos: int = 1024
    emb_dropout: float = 0.1
    attn_pdrop: float = 0.1
    resid_pdrop: float = 0.1
    layerdrop: float = 0.1
    pad_token_id: int = 0
    tie_embeddings: bool = True
    num_langs: int = 2  # 0: vi, 1: ty


class PositionalEmbedding(nn.Module):
    """Learned absolute positional embeddings."""

    def __init__(self, max_pos, d_model):
        super().__init__()
        self.weight = nn.Embedding(max_pos, d_model)

    def forward(self, positions):
        return self.weight(positions)


class Seq2SeqTransformer(nn.Module):
    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.cfg = cfg
        self.token_emb = nn.Embedding(cfg.vocab_size, cfg.d_model, padding_idx=cfg.pad_token_id)
        self.lang_emb = nn.Embedding(cfg.num_langs, cfg.d_model)
        self.pos_emb = PositionalEmbedding(cfg.max_pos, cfg.d_model)
        self.emb_drop = nn.Dropout(cfg.emb_dropout)

        self.enc_layer = nn.TransformerEncoderLayer(
            d_model=cfg.d_model, nhead=cfg.num_heads, dim_feedforward=cfg.d_ff,
            dropout=cfg.resid_pdrop, activation="gelu", batch_first=True, norm_first=True
        )
        self.encoder = nn.TransformerEncoder(self.enc_layer, num_layers=cfg.num_encoder_layers)

        self.dec_layer = nn.TransformerDecoderLayer(
            d_model=cfg.d_model, nhead=cfg.num_heads, dim_feedforward=cfg.d_ff,
            dropout=cfg.resid_pdrop, activation="gelu", batch_first=True, norm_first=True
        )
        self.decoder = nn.TransformerDecoder(self.dec_layer, num_layers=cfg.num_decoder_layers)

        self.ln_enc = nn.RMSNorm(cfg.d_model)
        self.ln_dec = nn.RMSNorm(cfg.d_model)
        self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
        if cfg.tie_embeddings:
            self.lm_head.weight = self.token_emb.weight

    def encode(self, src_ids, src_lang_id):
        # True at PAD positions so encoder attention ignores padding.
        src_padding_mask = src_ids.eq(self.cfg.pad_token_id)
        x = self._embed(src_ids, src_lang_id)
        enc = self.encoder(x, src_key_padding_mask=src_padding_mask)
        return self.ln_enc(enc), src_padding_mask

    def decode(self, tgt_ids, enc_out, src_padding_mask, tgt_lang_id):
        tgt_padding_mask = tgt_ids.eq(self.cfg.pad_token_id)
        T = tgt_ids.size(1)
        # Boolean causal mask: True above the diagonal blocks attention to future tokens.
        causal = torch.triu(torch.ones(T, T, device=tgt_ids.device, dtype=torch.bool), 1)
        y = self._embed(tgt_ids, tgt_lang_id)
        dec = self.decoder(
            y, enc_out,
            tgt_mask=causal,
            tgt_key_padding_mask=tgt_padding_mask,
            memory_key_padding_mask=src_padding_mask,
        )
        return self.ln_dec(dec)

    def _embed(self, input_ids, lang_id):
        B, T = input_ids.size()
        pos = torch.arange(T, device=input_ids.device)
        if T > self.cfg.max_pos:
            # Sequences longer than max_pos reuse the last position embedding.
            pos = pos.clamp_max(self.cfg.max_pos - 1)
        pos = pos.unsqueeze(0).expand(B, T)
        # Sum of token, position, and language embeddings.
        x = (self.token_emb(input_ids)
             + self.pos_emb(pos)
             + self.lang_emb(torch.full((B, T), lang_id, device=input_ids.device)))
        return self.emb_drop(x)

    def forward(self, src_ids, tgt_in_ids, src_lang_id, tgt_lang_id, labels=None):
        enc_out, src_padding_mask = self.encode(src_ids, src_lang_id)
        dec_out = self.decode(tgt_in_ids, enc_out, src_padding_mask, tgt_lang_id)
        logits = self.lm_head(dec_out)
        loss = None
        if labels is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1),
                                   ignore_index=self.cfg.pad_token_id)
        return logits, loss

    @torch.no_grad()
    def generate(self, src_ids, src_lang_id, tgt_lang_id, max_len=128,
                 bos_id=1, eos_id=2, beam_size=4, length_penalty=0.8):
        """Beam-search decoding for a single source sentence."""
        device = src_ids.device
        enc_out, src_padding_mask = self.encode(src_ids, src_lang_id)
        B = src_ids.size(0)
        assert B == 1, "generate() only supports batch size 1"

        def lp(alpha, L):
            # GNMT-style length penalty: ((5 + L) / 6) ** alpha
            return ((5 + L) / 6) ** alpha

        # Start from a single BOS hypothesis; it fans out to beam_size candidates on
        # the first step. (Starting with beam_size identical beams would keep all
        # beams identical and degenerate into greedy decoding.)
        beams = [{"tokens": torch.tensor([bos_id], device=device), "logprob": 0.0, "finished": False}]
        for _ in range(max_len):
            all_cand = []
            for b in beams:
                if b["finished"]:
                    all_cand.append(b)
                    continue
                tgt = b["tokens"].unsqueeze(0)
                dec_h = self.decode(tgt, enc_out, src_padding_mask, tgt_lang_id)
                logit = self.lm_head(dec_h[:, -1, :])
                logprobs = F.log_softmax(logit, dim=-1).squeeze(0)
                topv, topi = torch.topk(logprobs, beam_size)
                for score, tok in zip(topv.tolist(), topi.tolist()):
                    new_toks = torch.cat([b["tokens"], torch.tensor([tok], device=device)])
                    all_cand.append({"tokens": new_toks,
                                     "logprob": b["logprob"] + score,
                                     "finished": tok == eos_id})
            # Keep the beam_size best candidates under the length-normalized score.
            beams = sorted(all_cand,
                           key=lambda x: x["logprob"] / lp(length_penalty, len(x["tokens"])),
                           reverse=True)[:beam_size]
            if all(b["finished"] for b in beams):
                break
        best = max(beams, key=lambda x: x["logprob"] / lp(length_penalty, len(x["tokens"])))
        return best["tokens"]
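

# --- Usage sketch (illustrative only, not part of the model definition) ---
# A minimal smoke test, assuming a hypothetical vocabulary of 8000 subwords and
# random token ids in place of a real tokenizer; actual inputs would come from a
# shared vi/ty vocabulary with PAD/BOS/EOS/UNK reserved as defined above.
if __name__ == "__main__":
    cfg = ModelConfig(vocab_size=8000)
    model = Seq2SeqTransformer(cfg)
    model.eval()

    # Fake batch: 2 source sentences (vi) of length 10, teacher-forced targets (ty) of length 12.
    src_ids = torch.randint(4, cfg.vocab_size, (2, 10))
    tgt_in = torch.randint(4, cfg.vocab_size, (2, 12))
    labels = torch.randint(4, cfg.vocab_size, (2, 12))

    logits, loss = model(src_ids, tgt_in,
                         src_lang_id=LANG2ID["vi"], tgt_lang_id=LANG2ID["ty"], labels=labels)
    print(logits.shape, loss.item())  # expected: torch.Size([2, 12, 8000]) and a scalar loss

    # Beam-search decoding expects a single sentence (batch size 1).
    out = model.generate(src_ids[:1], src_lang_id=LANG2ID["vi"], tgt_lang_id=LANG2ID["ty"], max_len=20)
    print(out.tolist())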