import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import math

# ---------------- Positional Encoding ----------------
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=512):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (B, T, D)
        return x + self.pe[:, :x.size(1)].to(x.device)
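
# Minimal shape check for PositionalEncoding (the sizes below are illustrative
# assumptions, not values from any training config):
#
#   pe = PositionalEncoding(d_model=64, max_len=128)
#   x = torch.zeros(2, 10, 64)            # (B, T, D)
#   out = pe(x)                           # encoding added elementwise, shape stays (2, 10, 64)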

# ---------------- Transformer (adjusted to match the training loop) ----------------
class TransformerSeq2Seq(nn.Module):
    """
    Designed so that forward(src_embedded, tgt_input_ids, src_attn_mask=None, tgt_attn_mask=None) takes:
    - src_embedded: (B, S, E) — the source embeddings can be computed outside the model (e.g. embedding_src[src_ids])
    - tgt_input_ids: (B, T)  — token ids for the decoder input (BOS .. token_{n-1})
    - src_attn_mask / tgt_attn_mask: (B, S) / (B, T) with 1 for real tokens, 0 for pad
    """
    def __init__(self,
                 embed_dim,
                 vocab_size,                   # target vocab size (output dim)
                 embedding_decoder=None,       # pretrained weights (np array or torch.Tensor) or None
                 num_heads=2,
                 num_layers=2,
                 dim_feedforward=256,
                 dropout=0.1,
                 freeze_decoder_emb=True,
                 max_len=512):
        super().__init__()
        self.embed_dim = embed_dim
        self.vocab_size = vocab_size

        # positional encoding
        self.pos_encoder = PositionalEncoding(embed_dim, max_len=max_len)

        # decoder embedding (pretrained optional)
        if embedding_decoder is None:
            self.embedding_decoder = nn.Embedding(vocab_size, embed_dim)
        else:
            if not isinstance(embedding_decoder, torch.Tensor):
                embedding_decoder = torch.tensor(embedding_decoder, dtype=torch.float)
            self.embedding_decoder = nn.Embedding.from_pretrained(embedding_decoder, freeze=freeze_decoder_emb)

        # encoder/decoder (batch_first True -> inputs shape (B, T, E))
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads,
                                                        dim_feedforward=dim_feedforward, dropout=dropout,
                                                        batch_first=True)
        self.encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)

        self.decoder_layer = nn.TransformerDecoderLayer(d_model=embed_dim, nhead=num_heads,
                                                        dim_feedforward=dim_feedforward, dropout=dropout,
                                                        batch_first=True)
        self.decoder = nn.TransformerDecoder(self.decoder_layer, num_layers=num_layers)

        self.output_proj = nn.Linear(embed_dim, vocab_size)

    def forward(self, src_embedded, tgt_input_ids, src_attn_mask=None, tgt_attn_mask=None):
        """
        src_embedded : (B, S, E)
        tgt_input_ids: (B, T)
        src_attn_mask: (B, S) mask, 1 for real tokens, 0 for pad (optional)
        tgt_attn_mask: (B, T) same convention
        """
        device = src_embedded.device
        # tgt embedding
        tgt_embedded = self.embedding_decoder(tgt_input_ids)  # (B, T, E)

        # add positional encoding
        src = self.pos_encoder(src_embedded)  # (B, S, E)
        tgt = self.pos_encoder(tgt_embedded)  # (B, T, E)

        # prepare key_padding_mask: True at positions that should be masked (pad positions)
        src_key_padding_mask = None
        tgt_key_padding_mask = None
        if src_attn_mask is not None:
            src_key_padding_mask = (src_attn_mask == 0).to(device)  # (B, S), bool
        if tgt_attn_mask is not None:
            tgt_key_padding_mask = (tgt_attn_mask == 0).to(device)  # (B, T)

        # encode
        memory = self.encoder(src, src_key_padding_mask=src_key_padding_mask)  # (B, S, E)

        # causal mask for decoder (T x T)
        T = tgt.size(1)
        if T > 0:
            tgt_mask = torch.triu(torch.full((T, T), float('-inf'), device=device), diagonal=1)
        else:
            tgt_mask = None

        # decode
        output = self.decoder(tgt, memory,
                              tgt_mask=tgt_mask,
                              tgt_key_padding_mask=tgt_key_padding_mask,
                              memory_key_padding_mask=src_key_padding_mask)  # (B, T, E)

        logits = self.output_proj(output)  # (B, T, vocab)
        return logits
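
# Illustrative smoke test for TransformerSeq2Seq (all sizes are assumptions for
# the sketch, not the project's real vocab/embedding dimensions):
#
#   model = TransformerSeq2Seq(embed_dim=64, vocab_size=1000)
#   src_emb = torch.randn(2, 7, 64)              # (B, S, E) pre-embedded source
#   tgt_ids = torch.randint(0, 1000, (2, 5))     # (B, T) decoder input ids
#   logits = model(src_emb, tgt_ids)             # -> (2, 5, 1000)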

# ---------------- Helpers to apply embedding_src (tensor or nn.Embedding) ----------------
def apply_src_embedding(embedding_src, src_ids):
    """
    embedding_src can be:
      - torch.Tensor of shape (vocab_src, embed_dim) -> indexed directly
      - nn.Embedding instance -> called with the ids
    src_ids: LongTensor (B, S)
    Returns: (B, S, E) float tensor on the same device as src_ids
    """
    if isinstance(embedding_src, nn.Embedding):
        return embedding_src(src_ids)
    else:
        # assume it's a tensor/ndarray
        if not isinstance(embedding_src, torch.Tensor):
            embedding_src = torch.tensor(embedding_src, dtype=torch.float, device=src_ids.device)
        else:
            embedding_src = embedding_src.to(src_ids.device)
        return embedding_src[src_ids]
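
# Example use with a raw weight matrix instead of nn.Embedding (shapes are
# illustrative assumptions):
#
#   weights = torch.randn(5000, 64)              # (vocab_src, embed_dim)
#   ids = torch.randint(0, 5000, (2, 7))         # (B, S)
#   emb = apply_src_embedding(weights, ids)      # -> (2, 7, 64)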

@torch.no_grad()
def translate(model, src_sentence, tokenizer_src, tokenizer_tgt, embedding_src, device, max_len=50):
    model.eval()
    inputs = tokenizer_src([src_sentence], return_tensors="pt", padding=True, truncation=True, max_length=128)
    src_ids = inputs["input_ids"].to(device)      # (1, S)
    src_attn = inputs.get("attention_mask", None)
    if src_attn is not None:
        src_attn = src_attn.to(device)

    src_embedded = apply_src_embedding(embedding_src, src_ids)  # (1, S, E)

    decoded_ids = [tokenizer_tgt.cls_token_id]  # CLS doubles as BOS (BERT-style tokenizer)
    for _ in range(max_len):
        decoder_input = torch.tensor([decoded_ids], device=device)
        # no tgt_attn_mask needed during decoding; the causal mask is built inside the model
        logits = model(src_embedded, decoder_input, src_attn_mask=src_attn, tgt_attn_mask=None)
        next_token = logits[:, -1, :].argmax(dim=-1).item()
        decoded_ids.append(next_token)
        if next_token == tokenizer_tgt.sep_token_id:
            break

    return tokenizer_tgt.decode(decoded_ids, skip_special_tokens=True)
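
# Hedged end-to-end smoke test with random weights only. All sizes here are
# assumptions made for this sketch, and no real tokenizers or pretrained
# embeddings are involved; translate() still needs real tokenizers to run.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = TransformerSeq2Seq(embed_dim=64, vocab_size=1000).to(device)
    embedding_src = torch.randn(5000, 64)                    # stand-in source embedding table
    src_ids = torch.randint(0, 5000, (2, 7), device=device)  # (B, S)
    src_emb = apply_src_embedding(embedding_src, src_ids)    # (2, 7, 64)
    tgt_ids = torch.randint(0, 1000, (2, 5), device=device)  # (B, T)
    logits = model(src_emb, tgt_ids)
    print(logits.shape)  # expected: torch.Size([2, 5, 1000])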