| import torch |
| import torch.nn as nn |
| from torch import Tensor |
| from dataclasses import dataclass |
| from typing import Optional |
|
|
| import sys |
| import pathlib |
|
|
| |
# Make the project importable when this file is run directly: climb from this
# file's directory towards the filesystem root and add the first directory that
# holds a requirements.txt or README.md to sys.path. The filesystem root itself
# is never tested; if no marker is found, sys.path is left unchanged.
root = pathlib.Path(__file__).resolve().parent
while root != root.parent:
    if any((root / marker).exists() for marker in ("requirements.txt", "README.md")):
        sys.path.append(str(root))
        break
    root = root.parent
|
|
| from model.embedding import ThaiEmbedding |
| from model.transformer_block import TransformerBlock |
|
|
@dataclass
class ModelConfig:
    """Hyperparameters read by ThaiTransformerEncoder when it builds its layers."""

    # Forwarded to ThaiEmbedding as ``vocab_size``.
    vocab_size: int = 32000
    # Width shared by the embedding, every TransformerBlock and the final LayerNorm.
    d_model: int = 256
    # Forwarded to each TransformerBlock as ``num_heads``.
    num_heads: int = 8
    # How many TransformerBlock instances are stacked.
    num_layers: int = 6
    # Forwarded to each TransformerBlock as ``d_ff``.
    d_ff: int = 1024
    # Forwarded to ThaiEmbedding as ``max_seq_len``.
    max_seq_len: int = 512
    # Dropout value handed to both the embedding and every block.
    dropout: float = 0.1
    # Token id treated as padding when the caller supplies no attention_mask.
    pad_token_id: int = 0
|
|
class ThaiTransformerEncoder(nn.Module):
    """
    Shared encoder backbone built from a stack of Transformer blocks.

    Data flow: input_ids -> ThaiEmbedding -> num_layers x TransformerBlock
    -> LayerNorm -> hidden states.
    """

    def __init__(self, config: ModelConfig):
        """Build the embedding, the block stack and the final norm from ``config``."""
        super().__init__()
        self.config = config

        # Token embedding layer.
        self.embedding = ThaiEmbedding(
            vocab_size=config.vocab_size,
            d_model=config.d_model,
            max_seq_len=config.max_seq_len,
            dropout=config.dropout,
        )

        # Every block is constructed with the same hyperparameters.
        block_kwargs = {
            "d_model": config.d_model,
            "num_heads": config.num_heads,
            "d_ff": config.d_ff,
            "dropout": config.dropout,
        }
        self.blocks = nn.ModuleList(
            [TransformerBlock(**block_kwargs) for _ in range(config.num_layers)]
        )

        # Normalisation applied once, after the last block.
        self.norm = nn.LayerNorm(config.d_model)

    def forward(self, input_ids: Tensor, attention_mask: Optional[Tensor] = None):
        """
        Encode ``input_ids`` and collect what each block reports as attention.

        A padding mask is derived first. TransformerBlock treats True as
        "mask this position out" (the opposite of the HuggingFace
        convention), so the mask is True where ``attention_mask`` equals 0.
        Without an ``attention_mask`` it is True where ``input_ids`` equals
        ``config.pad_token_id``.

        Args:
            input_ids: Token ids handed to the embedding layer.
            attention_mask: Optional mask in which 0 marks positions to ignore.

        Returns:
            A tuple ``(hidden, attn_per_layer)``: the output of the final
            LayerNorm, and a list holding the second value returned by each
            block, in block order.
        """
        if attention_mask is None:
            padding_mask = input_ids == self.config.pad_token_id
        else:
            padding_mask = attention_mask == 0

        hidden = self.embedding(input_ids)

        attn_per_layer = []
        for layer in self.blocks:
            hidden, layer_attn = layer(hidden, padding_mask=padding_mask)
            attn_per_layer.append(layer_attn)

        return self.norm(hidden), attn_per_layer
|
|
if __name__ == "__main__":
    # Smoke test: build an encoder, push one padded batch through it, check
    # the output, then run a backward pass.
    cfg = ModelConfig(
        vocab_size=32000,
        d_model=256,
        num_heads=8,
        num_layers=6,
        d_ff=1024,
    )
    encoder = ThaiTransformerEncoder(cfg)

    # Random token ids in [1, 32000); the tail of the second row is set to 0.
    batch_size, seq_len = 2, 32
    input_ids = torch.randint(1, 32000, (batch_size, seq_len))
    input_ids[1, 20:] = 0

    # 1 wherever the token id is non-zero, 0 elsewhere.
    attention_mask = input_ids.ne(0).long()

    hidden, attn_weights = encoder(input_ids, attention_mask)

    expected_shape = (batch_size, seq_len, 256)
    assert hidden.shape == expected_shape, f"wrong shape: {hidden.shape}"
    assert len(attn_weights) == 6, "ต้องได้ attn weights ครบ 6 layers"
    assert not torch.isnan(hidden).any(), "NaN in output!"

    # Backward pass through a scalar reduction of the output.
    hidden.sum().backward()

    param_count = sum(p.numel() for p in encoder.parameters())
    print(f"params: {param_count:,}")
    print("encoder OK")
|
|