import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import math
# ---------------- Positional Encoding ----------------
class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_len=512):
super().__init__()
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len).unsqueeze(1).float()
div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0) # (1, max_len, d_model)
self.register_buffer('pe', pe)
def forward(self, x):
# x: (B, T, D)
return x + self.pe[:, :x.size(1)].to(x.device)
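# ---------------- Example: PositionalEncoding sanity check (illustrative sketch) ----------------
# Adding positional information should preserve the tensor shape and give
# different positions different encodings.
def _demo_positional_encoding():
    pe = PositionalEncoding(d_model=16, max_len=32)
    x = torch.zeros(2, 10, 16)                        # dummy batch: (B=2, T=10, D=16)
    out = pe(x)
    assert out.shape == x.shape                       # shape unchanged
    assert not torch.allclose(out[0, 0], out[0, 1])   # positions 0 and 1 differ
    return out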
# ---------------- Transformer (adjusted to match the training code) ----------------
class TransformerSeq2Seq(nn.Module):
"""
Thiết kế sao cho forward(src_embedded, tgt_input_ids, src_attn_mask=None, tgt_attn_mask=None)
- src_embedded: (B, S, E) — bạn có thể pass embedding matrix bên ngoài (embedding_src[src_ids])
- tgt_input_ids: (B, T) — token ids cho decoder input (BOS.. token_{n-1})
- src_attn_mask / tgt_attn_mask: (B, S) / (B, T) with 1 for real tokens, 0 for pad
"""
def __init__(self,
embed_dim,
vocab_size, # target vocab size (output dim)
embedding_decoder=None, # pretrained weights (np array or torch.Tensor) or None
num_heads=2,
num_layers=2,
dim_feedforward=256,
dropout=0.1,
freeze_decoder_emb=True,
max_len=512):
super().__init__()
self.embed_dim = embed_dim
self.vocab_size = vocab_size
# positional encoding
self.pos_encoder = PositionalEncoding(embed_dim, max_len=max_len)
# decoder embedding (pretrained optional)
if embedding_decoder is None:
self.embedding_decoder = nn.Embedding(vocab_size, embed_dim)
else:
if not isinstance(embedding_decoder, torch.Tensor):
embedding_decoder = torch.tensor(embedding_decoder, dtype=torch.float)
self.embedding_decoder = nn.Embedding.from_pretrained(embedding_decoder, freeze=freeze_decoder_emb)
# encoder/decoder (batch_first True -> inputs shape (B, T, E))
self.encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads,
dim_feedforward=dim_feedforward, dropout=dropout,
batch_first=True)
self.encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)
self.decoder_layer = nn.TransformerDecoderLayer(d_model=embed_dim, nhead=num_heads,
dim_feedforward=dim_feedforward, dropout=dropout,
batch_first=True)
self.decoder = nn.TransformerDecoder(self.decoder_layer, num_layers=num_layers)
self.output_proj = nn.Linear(embed_dim, vocab_size)
def forward(self, src_embedded, tgt_input_ids, src_attn_mask=None, tgt_attn_mask=None):
"""
src_embedded : (B, S, E)
tgt_input_ids: (B, T)
src_attn_mask : (B, S) mask: 1 real token, 0 pad (optional)
tgt_attn_mask : (B, T) same
"""
device = src_embedded.device
# tgt embedding
tgt_embedded = self.embedding_decoder(tgt_input_ids) # (B, T, E)
# add positional encoding
src = self.pos_encoder(src_embedded) # (B, S, E)
tgt = self.pos_encoder(tgt_embedded) # (B, T, E)
# prepare key_padding_mask: True at positions that should be masked (pad positions)
src_key_padding_mask = None
tgt_key_padding_mask = None
if src_attn_mask is not None:
src_key_padding_mask = (src_attn_mask == 0).to(device) # (B, S), bool
if tgt_attn_mask is not None:
tgt_key_padding_mask = (tgt_attn_mask == 0).to(device) # (B, T)
# encode
memory = self.encoder(src, src_key_padding_mask=src_key_padding_mask) # (B, S, E)
# causal mask for decoder (T x T)
T = tgt.size(1)
if T > 0:
tgt_mask = torch.triu(torch.full((T, T), float('-inf'), device=device), diagonal=1)
else:
tgt_mask = None
# decode
output = self.decoder(tgt, memory,
tgt_mask=tgt_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=src_key_padding_mask) # (B, T, E)
logits = self.output_proj(output) # (B, T, vocab)
return logits
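# ---------------- Example: forward pass + training-style loss (illustrative sketch) ----------------
# The pad id (0) and the BOS-shifted decoder-input layout below are assumptions
# made for illustration; adapt them to the actual tokenizers / data pipeline.
def _demo_forward_and_loss():
    B, S, T, E, V = 2, 7, 5, 16, 100
    model = TransformerSeq2Seq(embed_dim=E, vocab_size=V, num_heads=2, num_layers=1)
    src_embedded = torch.randn(B, S, E)                 # stand-in for apply_src_embedding(...) output
    src_attn_mask = torch.ones(B, S, dtype=torch.long)  # 1 = real token, 0 = pad
    tgt_ids = torch.randint(1, V, (B, T + 1))           # BOS + T target tokens (assumed layout)
    tgt_input_ids = tgt_ids[:, :-1]                     # decoder input: BOS .. token_{n-1}
    labels = tgt_ids[:, 1:]                             # targets shifted left by one
    logits = model(src_embedded, tgt_input_ids, src_attn_mask=src_attn_mask)
    assert logits.shape == (B, T, V)
    loss = F.cross_entropy(logits.reshape(-1, V), labels.reshape(-1), ignore_index=0)
    return loss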
# ---------------- Helpers to apply embedding_src (tensor or nn.Embedding) ----------------
def apply_src_embedding(embedding_src, src_ids):
"""
embedding_src can be:
- torch.Tensor of shape (vocab_src, embed_dim) -> indexing
- nn.Embedding instance -> call( ids )
src_ids: LongTensor (B, S)
return: (B, S, E) float tensor on same device as src_ids
"""
if isinstance(embedding_src, nn.Embedding):
return embedding_src(src_ids)
else:
# assume it's a tensor/ndarray
if not isinstance(embedding_src, torch.Tensor):
embedding_src = torch.tensor(embedding_src, dtype=torch.float, device=src_ids.device)
else:
embedding_src = embedding_src.to(src_ids.device)
return embedding_src[src_ids]
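# ---------------- Example: apply_src_embedding usage (illustrative sketch) ----------------
# Both a raw weight tensor and an nn.Embedding module should yield a (B, S, E) tensor.
def _demo_apply_src_embedding():
    vocab_src, E = 50, 16
    src_ids = torch.randint(0, vocab_src, (2, 7))       # (B=2, S=7)
    weight = torch.randn(vocab_src, E)                   # pretrained-style weight matrix
    out_from_tensor = apply_src_embedding(weight, src_ids)
    out_from_module = apply_src_embedding(nn.Embedding(vocab_src, E), src_ids)
    assert out_from_tensor.shape == out_from_module.shape == (2, 7, E)
    return out_from_tensor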
@torch.no_grad()
def translate(model, src_sentence, tokenizer_src, tokenizer_tgt, embedding_src, device, max_len=50):
model.eval()
inputs = tokenizer_src([src_sentence], return_tensors="pt", padding=True, truncation=True, max_length=128)
src_ids = inputs["input_ids"].to(device) # (1, S)
src_attn = inputs.get("attention_mask", None)
if src_attn is not None:
src_attn = src_attn.to(device)
src_embedded = apply_src_embedding(embedding_src, src_ids) # (1, S, E)
decoded_ids = [tokenizer_tgt.cls_token_id]
for _ in range(max_len):
decoder_input = torch.tensor([decoded_ids], device=device)
        # during decoding we don't need tgt_attn_mask (the causal mask is built inside the model)
logits = model(src_embedded, decoder_input, src_attn_mask=src_attn, tgt_attn_mask=None)
next_token = logits[:, -1, :].argmax(dim=-1).item()
decoded_ids.append(next_token)
if next_token == tokenizer_tgt.sep_token_id:
break
    return tokenizer_tgt.decode(decoded_ids, skip_special_tokens=True)
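# ---------------- Example: wiring it together (illustrative sketch) ----------------
# The tokenizer checkpoint below is a placeholder chosen for illustration only; the
# original source/target tokenizers and pretrained embedding matrices are not
# specified in this file. translate() assumes the target tokenizer exposes
# cls_token_id / sep_token_id (BERT-style special tokens).
if __name__ == "__main__":
    _demo_positional_encoding()
    _demo_apply_src_embedding()
    print("demo loss:", _demo_forward_and_loss().item())
    # from transformers import AutoTokenizer
    # tokenizer_src = AutoTokenizer.from_pretrained("bert-base-multilingual-cased")  # assumed checkpoint
    # tokenizer_tgt = AutoTokenizer.from_pretrained("bert-base-multilingual-cased")  # assumed checkpoint
    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # embedding_src = torch.randn(tokenizer_src.vocab_size, 128)  # stand-in for pretrained src embeddings
    # model = TransformerSeq2Seq(embed_dim=128, vocab_size=tokenizer_tgt.vocab_size).to(device)
    # print(translate(model, "example sentence", tokenizer_src, tokenizer_tgt, embedding_src, device, max_len=30))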