thai-nlp-toolkit / model /encoder.py
puttimej's picture
Upload model/encoder.py with huggingface_hub
a65f8c6 verified
Raw
History Blame Contribute Delete
3.66 kB
import torch
import torch.nn as nn
from torch import Tensor
from dataclasses import dataclass
from typing import Optional
import sys
import pathlib
# Make the project root importable so that `model.*` resolves when this file
# is executed directly. Starting at this file's directory, climb towards the
# filesystem root and stop at the first directory that holds one of the marker
# files; that directory is appended to sys.path. If no marker is found, `root`
# ends up at the filesystem root and sys.path is left unchanged.
root = pathlib.Path(__file__).resolve().parent
while root != root.parent:
    if any((root / marker).exists() for marker in ("requirements.txt", "README.md")):
        sys.path.append(str(root))
        break
    root = root.parent
from model.embedding import ThaiEmbedding
from model.transformer_block import TransformerBlock
@dataclass
class ModelConfig:
    """Hyper-parameters consumed by ``ThaiTransformerEncoder``.

    Every field has a default, so ``ModelConfig()`` is a complete configuration.
    """

    # Forwarded to the embedding layer.
    vocab_size: int = 32_000
    # Hidden width shared by the embedding, every block and the final LayerNorm.
    d_model: int = 256
    # Attention heads per transformer block.
    num_heads: int = 8
    # Number of transformer blocks stacked in the encoder.
    num_layers: int = 6
    # Feed-forward width passed to every transformer block.
    d_ff: int = 1_024
    # Forwarded to the embedding layer as its maximum sequence length.
    max_seq_len: int = 512
    # Dropout rate passed to the embedding and to every block.
    dropout: float = 0.1
    # Token id treated as padding when no attention mask is supplied.
    pad_token_id: int = 0
class ThaiTransformerEncoder(nn.Module):
    """
    Shared encoder backbone: a stack of Transformer blocks over an embedding.

    input_ids → embedding → N x TransformerBlock → final LayerNorm → hidden states
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config

        # Embedding front-end (per the original author's note: token +
        # positional embeddings followed by a layer norm).
        self.embedding = ThaiEmbedding(
            vocab_size=config.vocab_size,
            d_model=config.d_model,
            max_seq_len=config.max_seq_len,
            dropout=config.dropout,
        )

        # All layers share the same hyper-parameters; build them one by one so
        # they are registered as "0" .. "num_layers - 1".
        layer_kwargs = dict(
            d_model=config.d_model,
            num_heads=config.num_heads,
            d_ff=config.d_ff,
            dropout=config.dropout,
        )
        self.blocks = nn.ModuleList()
        for _ in range(config.num_layers):
            self.blocks.append(TransformerBlock(**layer_kwargs))

        # Final normalisation applied to the output of the last block
        # (optional, but common for pre-norm architectures).
        self.norm = nn.LayerNorm(config.d_model)

    def forward(self, input_ids: Tensor, attention_mask: Optional[Tensor] = None):
        """
        Encode a batch of token ids.

        Args:
            input_ids: token ids handed to the embedding layer.
            attention_mask: optional mask where 0 marks positions to ignore.
                When omitted, positions equal to ``config.pad_token_id`` are
                treated as padding.

        Returns:
            A tuple ``(hidden, all_attn_weights)``: the output of the last
            block after the final LayerNorm, and a list holding the second
            value returned by each block, in layer order.

        Note:
            The blocks receive ``padding_mask`` with True meaning "mask this
            position out", the opposite of the HuggingFace attention-mask
            convention, hence the ``== 0`` inversion below.
        """
        if attention_mask is None:
            padding_mask = input_ids == self.config.pad_token_id
        else:
            padding_mask = attention_mask == 0  # (B, T) bool

        hidden = self.embedding(input_ids)

        # Keep every layer's attention weights for visualisation / debugging.
        all_attn_weights = []
        for layer in self.blocks:
            hidden, layer_attn = layer(hidden, padding_mask=padding_mask)
            all_attn_weights.append(layer_attn)

        return self.norm(hidden), all_attn_weights
if __name__ == "__main__":
    # Smoke test: build the default-sized encoder, run one padded batch
    # through it, and make sure a backward pass completes.
    cfg = ModelConfig(
        vocab_size=32000,
        d_model=256,
        num_heads=8,
        num_layers=6,
        d_ff=1024,
    )
    encoder = ThaiTransformerEncoder(cfg)

    # Dummy batch; ids start at 1 so that no random token collides with the
    # padding id.
    batch_size, seq_len = 2, 32
    input_ids = torch.randint(1, cfg.vocab_size, (batch_size, seq_len))
    # The second sequence is padded from position 20 onwards.
    input_ids[1, 20:] = cfg.pad_token_id
    attention_mask = (input_ids != cfg.pad_token_id).long()

    hidden, attn_weights = encoder(input_ids, attention_mask)

    assert hidden.shape == (batch_size, seq_len, cfg.d_model), f"wrong shape: {hidden.shape}"
    assert len(attn_weights) == cfg.num_layers, "ต้องได้ attn weights ครบ 6 layers"
    assert not torch.isnan(hidden).any(), "NaN in output!"

    # Only verifies that gradients can be computed end to end; it does not
    # check that padded positions leave the real tokens unaffected.
    hidden.sum().backward()

    print(f"params: {sum(p.numel() for p in encoder.parameters()):,}")
    print("encoder OK")