mavietduc committed
Commit fc6dcba · verified · 1 Parent(s): 015742b

Update model.py

Files changed (1)
  1. model.py +135 -139
model.py CHANGED
@@ -1,139 +1,135 @@
- from dataclasses import dataclass
-
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
-
-
- PAD, BOS, EOS, UNK = 0, 1, 2, 3
- LANG2ID = {"vi": 0, "ty": 1}
-
- @dataclass
- class ModelConfig:
-     vocab_size: int
-     d_model: int = 384
-     num_heads: int = 6
-     d_ff: int = 1536
-     num_encoder_layers: int = 6
-     num_decoder_layers: int = 6
-     max_pos: int = 1024
-     emb_dropout: float = 0.1
-     attn_pdrop: float = 0.1
-     resid_pdrop: float = 0.1
-     layerdrop: float = 0.1
-     pad_token_id: int = 0
-     tie_embeddings: bool = True
-     num_langs: int = 2  # 0: vi, 1: ty
-
-
- class PositionalEmbedding(nn.Module):
-     def __init__(self, max_pos, d_model):
-         super().__init__()
-         self.weight = nn.Embedding(max_pos, d_model)
-
-     def forward(self, positions):
-         return self.weight(positions)
-
-
- class Seq2SeqTransformer(nn.Module):
-     def __init__(self, cfg: ModelConfig):
-         super().__init__()
-         self.cfg = cfg
-         self.token_emb = nn.Embedding(cfg.vocab_size, cfg.d_model, padding_idx=cfg.pad_token_id)
-         self.lang_emb = nn.Embedding(cfg.num_langs, cfg.d_model)
-         self.pos_emb = PositionalEmbedding(cfg.max_pos, cfg.d_model)
-         self.emb_drop = nn.Dropout(cfg.emb_dropout)
-
-         self.enc_layer = nn.TransformerEncoderLayer(
-             d_model=cfg.d_model, nhead=cfg.num_heads, dim_feedforward=cfg.d_ff,
-             dropout=cfg.resid_pdrop, activation="gelu", batch_first=True, norm_first=True
-         )
-         self.encoder = nn.TransformerEncoder(self.enc_layer, num_layers=cfg.num_encoder_layers)
-
-         self.dec_layer = nn.TransformerDecoderLayer(
-             d_model=cfg.d_model, nhead=cfg.num_heads, dim_feedforward=cfg.d_ff,
-             dropout=cfg.resid_pdrop, activation="gelu", batch_first=True, norm_first=True
-         )
-         self.decoder = nn.TransformerDecoder(self.dec_layer, num_layers=cfg.num_decoder_layers)
-
-         self.ln_enc = nn.RMSNorm(cfg.d_model)
-         self.ln_dec = nn.RMSNorm(cfg.d_model)
-         self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
-         if cfg.tie_embeddings:
-             self.lm_head.weight = self.token_emb.weight
-
-     def encode(self, src_ids, src_lang_id):
-         # padding mask: True = position is masked out
-         src_padding_mask = src_ids.eq(self.cfg.pad_token_id)  # (B, T_src)
-         x = self._embed(src_ids, src_lang_id)  # (B, T_src, C)
-         enc = self.encoder(x, src_key_padding_mask=src_padding_mask)
-         return self.ln_enc(enc), src_padding_mask  # keep the final RMSNorm of the stack
-
-     def decode(self, tgt_ids, enc_out, src_padding_mask, tgt_lang_id):
-         tgt_padding_mask = tgt_ids.eq(self.cfg.pad_token_id)  # (B, T_tgt)
-         T = tgt_ids.size(1)
-         # causal mask: True = MASKED (upper triangle)
-         causal = torch.triu(torch.ones(T, T, device=tgt_ids.device, dtype=torch.bool), 1)
-
-         y = self._embed(tgt_ids, tgt_lang_id)  # (B, T_tgt, C)
-         dec = self.decoder(
-             y, enc_out,
-             tgt_mask=causal,  # (T, T)
-             tgt_key_padding_mask=tgt_padding_mask,  # (B, T_tgt)
-             memory_key_padding_mask=src_padding_mask  # (B, T_src)
-         )
-         return self.ln_dec(dec)
-
-     def _embed(self, input_ids, lang_id):
-         B, T = input_ids.size()
-         pos = torch.arange(T, device=input_ids.device)
-         if T > self.cfg.max_pos:
-             pos = pos.clamp_max(self.cfg.max_pos - 1)
-         pos = pos.unsqueeze(0).expand(B, T)
-         x = (self.token_emb(input_ids)
-              + self.pos_emb(pos)
-              + self.lang_emb(torch.full((B, T), lang_id, device=input_ids.device)))
-         return self.emb_drop(x)
-
-     def forward(self, src_ids, tgt_in_ids, src_lang_id, tgt_lang_id, labels=None):
-         enc_out, src_padding_mask = self.encode(src_ids, src_lang_id)
-         dec_out = self.decode(tgt_in_ids, enc_out, src_padding_mask, tgt_lang_id)
-         logits = self.lm_head(dec_out)
-         loss = None
-         if labels is not None:
-             loss = F.cross_entropy(logits.view(-1, logits.size(-1)),
-                                    labels.view(-1), ignore_index=self.cfg.pad_token_id)
-         return logits, loss
-
-     @torch.no_grad()
-     def generate(self, src_ids, src_lang_id, tgt_lang_id, max_len=128, bos_id=1, eos_id=2, beam_size=4,
-                  length_penalty=0.8):
-         device = src_ids.device
-         enc_out, src_padding_mask = self.encode(src_ids, src_lang_id)
-         B = src_ids.size(0)
-         assert B == 1, "Beam search demo for batch=1"
-         beams = [{"tokens": torch.tensor([bos_id], device=device),
-                   "logprob": 0.0, "finished": False}]
-         for _ in range(max_len):
-             all_cand = []
-             for b in beams:
-                 if b["finished"]:
-                     all_cand.append(b)
-                     continue
-                 tgt = b["tokens"].unsqueeze(0)
-                 dec_h = self.decode(tgt, enc_out, src_padding_mask, tgt_lang_id)
-                 logit = self.lm_head(dec_h[:, -1, :])
-                 logprobs = F.log_softmax(logit, dim=-1).squeeze(0)
-                 topv, topi = torch.topk(logprobs, beam_size)
-                 for score, tok in zip(topv.tolist(), topi.tolist()):
-                     new_toks = torch.cat([b["tokens"], torch.tensor([tok], device=device)])
-                     all_cand.append({"tokens": new_toks, "logprob": b["logprob"] + score, "finished": tok == eos_id})
-
-             def lp(alpha, L):
-                 return ((5 + L) / 6) ** alpha
-
-             beams = sorted(all_cand, key=lambda x: x["logprob"] / lp(length_penalty, len(x["tokens"])),
-                            reverse=True)[:beam_size]
-             if all(b["finished"] for b in beams): break
-         best = max(beams, key=lambda x: x["logprob"] / (((5 + len(x["tokens"])) / 6) ** length_penalty))
-         return best["tokens"]
 
+ from dataclasses import dataclass
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ PAD, BOS, EOS, UNK = 0, 1, 2, 3
+ LANG2ID = {"vi": 0, "ty": 1}
+
+ @dataclass
+ class ModelConfig:
+     vocab_size: int
+     d_model: int = 384
+     num_heads: int = 6
+     d_ff: int = 1536
+     num_encoder_layers: int = 6
+     num_decoder_layers: int = 6
+     max_pos: int = 1024
+     emb_dropout: float = 0.1
+     attn_pdrop: float = 0.1
+     resid_pdrop: float = 0.1
+     layerdrop: float = 0.1
+     pad_token_id: int = 0
+     tie_embeddings: bool = True
+     num_langs: int = 2  # 0: vi, 1: ty
+
+
+ class PositionalEmbedding(nn.Module):
+     def __init__(self, max_pos, d_model):
+         super().__init__()
+         self.weight = nn.Embedding(max_pos, d_model)
+
+     def forward(self, positions):
+         return self.weight(positions)
+
+
+ class Seq2SeqTransformer(nn.Module):
+     def __init__(self, cfg: ModelConfig):
+         super().__init__()
+         self.cfg = cfg
+         self.token_emb = nn.Embedding(cfg.vocab_size, cfg.d_model, padding_idx=cfg.pad_token_id)
+         self.lang_emb = nn.Embedding(cfg.num_langs, cfg.d_model)
+         self.pos_emb = PositionalEmbedding(cfg.max_pos, cfg.d_model)
+         self.emb_drop = nn.Dropout(cfg.emb_dropout)
+
+         self.enc_layer = nn.TransformerEncoderLayer(
+             d_model=cfg.d_model, nhead=cfg.num_heads, dim_feedforward=cfg.d_ff,
+             dropout=cfg.resid_pdrop, activation="gelu", batch_first=True, norm_first=True
+         )
+         self.encoder = nn.TransformerEncoder(self.enc_layer, num_layers=cfg.num_encoder_layers)
+
+         self.dec_layer = nn.TransformerDecoderLayer(
+             d_model=cfg.d_model, nhead=cfg.num_heads, dim_feedforward=cfg.d_ff,
+             dropout=cfg.resid_pdrop, activation="gelu", batch_first=True, norm_first=True
+         )
+         self.decoder = nn.TransformerDecoder(self.dec_layer, num_layers=cfg.num_decoder_layers)
+
+         self.ln_enc = nn.RMSNorm(cfg.d_model)
+         self.ln_dec = nn.RMSNorm(cfg.d_model)
+         self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
+         if cfg.tie_embeddings:
+             self.lm_head.weight = self.token_emb.weight
+
+     def encode(self, src_ids, src_lang_id):
+         src_padding_mask = src_ids.eq(self.cfg.pad_token_id)
+         x = self._embed(src_ids, src_lang_id)
+         enc = self.encoder(x, src_key_padding_mask=src_padding_mask)
+         return self.ln_enc(enc), src_padding_mask
+
+     def decode(self, tgt_ids, enc_out, src_padding_mask, tgt_lang_id):
+         tgt_padding_mask = tgt_ids.eq(self.cfg.pad_token_id)
+         T = tgt_ids.size(1)
+         causal = torch.triu(torch.ones(T, T, device=tgt_ids.device, dtype=torch.bool), 1)
+
+         y = self._embed(tgt_ids, tgt_lang_id)
+         dec = self.decoder(
+             y, enc_out,
+             tgt_mask=causal,
+             tgt_key_padding_mask=tgt_padding_mask,
+             memory_key_padding_mask=src_padding_mask
+         )
+         return self.ln_dec(dec)
+
+     def _embed(self, input_ids, lang_id):
+         B, T = input_ids.size()
+         pos = torch.arange(T, device=input_ids.device)
+         if T > self.cfg.max_pos:
+             pos = pos.clamp_max(self.cfg.max_pos - 1)
+         pos = pos.unsqueeze(0).expand(B, T)
+         x = (self.token_emb(input_ids)
+              + self.pos_emb(pos)
+              + self.lang_emb(torch.full((B, T), lang_id, device=input_ids.device)))
+         return self.emb_drop(x)
+
+     def forward(self, src_ids, tgt_in_ids, src_lang_id, tgt_lang_id, labels=None):
+         enc_out, src_padding_mask = self.encode(src_ids, src_lang_id)
+         dec_out = self.decode(tgt_in_ids, enc_out, src_padding_mask, tgt_lang_id)
+         logits = self.lm_head(dec_out)
+         loss = None
+         if labels is not None:
+             loss = F.cross_entropy(logits.view(-1, logits.size(-1)),
+                                    labels.view(-1), ignore_index=self.cfg.pad_token_id)
+         return logits, loss
+
+     @torch.no_grad()
+     def generate(self, src_ids, src_lang_id, tgt_lang_id, max_len=128, bos_id=1, eos_id=2, beam_size=4,
+                  length_penalty=0.8):
+         device = src_ids.device
+         enc_out, src_padding_mask = self.encode(src_ids, src_lang_id)
+         B = src_ids.size(0)
+         assert B == 1, "Beam search demo for batch=1"
+         beams = [{"tokens": torch.tensor([bos_id], device=device), "logprob": 0.0, "finished": False}]
+         for _ in range(max_len):
+             all_cand = []
+             for b in beams:
+                 if b["finished"]:
+                     all_cand.append(b)
+                     continue
+                 tgt = b["tokens"].unsqueeze(0)
+                 dec_h = self.decode(tgt, enc_out, src_padding_mask, tgt_lang_id)
+                 logit = self.lm_head(dec_h[:, -1, :])
+                 logprobs = F.log_softmax(logit, dim=-1).squeeze(0)
+                 topv, topi = torch.topk(logprobs, beam_size)
+                 for score, tok in zip(topv.tolist(), topi.tolist()):
+                     new_toks = torch.cat([b["tokens"], torch.tensor([tok], device=device)])
+                     all_cand.append({"tokens": new_toks, "logprob": b["logprob"] + score, "finished": tok == eos_id})
+
+             def lp(alpha, L):
+                 return ((5 + L) / 6) ** alpha
+
+             beams = sorted(all_cand, key=lambda x: x["logprob"] / lp(length_penalty, len(x["tokens"])), reverse=True)[:beam_size]
+             if all(b["finished"] for b in beams): break
+         best = max(beams, key=lambda x: x["logprob"] / (((5 + len(x["tokens"])) / 6) ** length_penalty))
+         return best["tokens"]
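
For reference, a minimal smoke test of the updated model.py, assuming the file is importable from the working directory and a PyTorch build that ships nn.RMSNorm (2.4 or newer); the vocabulary size and token ids below are placeholders, not values from this repo's tokenizer:

import torch

from model import LANG2ID, ModelConfig, Seq2SeqTransformer

cfg = ModelConfig(vocab_size=8000)      # placeholder vocab size
model = Seq2SeqTransformer(cfg).eval()

src = torch.tensor([[1, 42, 57, 2]])    # BOS, two dummy tokens, EOS
tgt_in = torch.tensor([[1, 99, 100]])   # decoder input starts with BOS
labels = torch.tensor([[99, 100, 2]])   # decoder input shifted left, ends with EOS

# training-style forward pass: returns logits of shape (1, 3, 8000) and a scalar loss
logits, loss = model(src, tgt_in, src_lang_id=LANG2ID["vi"], tgt_lang_id=LANG2ID["ty"], labels=labels)
print(logits.shape, loss.item())

# beam-search decoding for a single sentence (generate asserts batch size 1)
out = model.generate(src, src_lang_id=LANG2ID["vi"], tgt_lang_id=LANG2ID["ty"], max_len=16)
print(out.tolist())                     # generated token ids, starting with BOS

With an untrained model the outputs are random, but the shapes and the BOS/EOS plumbing can be checked this way before wiring up a real tokenizer and training loop.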