# File size: 3,660 Bytes
# a65f8c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import torch
import torch.nn as nn
from torch import Tensor
from dataclasses import dataclass
from typing import Optional

import sys
import pathlib

# Make the project root importable: starting from this file's directory, climb
# towards the filesystem root and append the first ancestor that contains a
# marker file to sys.path. The filesystem root itself is never tested, and
# sys.path is left untouched when no marker is found.
root = pathlib.Path(__file__).resolve().parent
while root != root.parent:
    if any((root / name).exists() for name in ("requirements.txt", "README.md")):
        sys.path.append(str(root))
        break
    root = root.parent

from model.embedding import ThaiEmbedding
from model.transformer_block import TransformerBlock

@dataclass
class ModelConfig:
    """Hyperparameters read by ``ThaiTransformerEncoder``.

    Every field has a default, so ``ModelConfig()`` is a usable configuration.
    Field order is part of the positional constructor signature.
    """

    # Forwarded to ThaiEmbedding as ``vocab_size``.
    vocab_size: int = 32000
    # Model width: forwarded to the embedding, every block and the final LayerNorm.
    d_model: int = 256
    # Forwarded to each TransformerBlock as ``num_heads``.
    num_heads: int = 8
    # How many TransformerBlock instances the encoder stacks.
    num_layers: int = 6
    # Forwarded to each TransformerBlock as ``d_ff``.
    d_ff: int = 1024
    # Forwarded to ThaiEmbedding as ``max_seq_len``.
    max_seq_len: int = 512
    # Forwarded to both the embedding and every block.
    dropout: float = 0.1
    # Token id treated as padding when no attention_mask is supplied.
    pad_token_id: int = 0

class ThaiTransformerEncoder(nn.Module):
    """
    Encoder backbone: an embedding layer, a stack of Transformer blocks and a
    closing LayerNorm.

    input_ids → embedding → N x TransformerBlock → LayerNorm → hidden states
    """

    def __init__(self, config: ModelConfig):
        """Build the embedding, ``config.num_layers`` blocks and the final norm."""
        super().__init__()
        self.config = config

        # Input embedding (reportedly token + positional embeddings with a
        # layer norm; the details live in ThaiEmbedding).
        self.embedding = ThaiEmbedding(
            vocab_size=config.vocab_size,
            d_model=config.d_model,
            max_seq_len=config.max_seq_len,
            dropout=config.dropout,
        )

        # Every block is built from the same hyperparameters, one instance per layer.
        block_kwargs = dict(
            d_model=config.d_model,
            num_heads=config.num_heads,
            d_ff=config.d_ff,
            dropout=config.dropout,
        )
        self.blocks = nn.ModuleList()
        for _ in range(config.num_layers):
            self.blocks.append(TransformerBlock(**block_kwargs))

        # Final norm (optional, but common for pre-norm architectures).
        self.norm = nn.LayerNorm(config.d_model)

    def forward(self, input_ids: Tensor, attention_mask: Optional[Tensor] = None):
        """
        Encode a batch of token ids.

        Args:
            input_ids: token ids handed to the embedding layer.
            attention_mask: optional mask in which 0 marks positions to ignore
                (HuggingFace convention). When omitted, positions equal to
                ``config.pad_token_id`` are treated as padding.

        Returns:
            A tuple ``(hidden, attn_weights)``: the output of the last block
            after the final LayerNorm, and a list holding the second value
            returned by each block, in layer order (kept for
            visualization / debugging).

        Note:
            The padding mask passed to TransformerBlock uses True for
            "mask this position out", the opposite of the HuggingFace
            convention.
        """
        if attention_mask is None:
            # No explicit mask: fall back to the configured pad token id.
            padding_mask = input_ids == self.config.pad_token_id
        else:
            padding_mask = attention_mask == 0  # (B, T) bool

        hidden = self.embedding(input_ids)

        # Thread the hidden states through the stack, collecting what each
        # layer reports as its attention weights.
        per_layer_attn = []
        for layer in self.blocks:
            hidden, layer_attn = layer(hidden, padding_mask=padding_mask)
            per_layer_attn.append(layer_attn)

        return self.norm(hidden), per_layer_attn

if __name__ == "__main__":
    # Smoke test: build an encoder, run one padded batch forward and backward,
    # and check output shape, number of attention maps and absence of NaNs.
    cfg = ModelConfig(vocab_size=32000, d_model=256,
                      num_heads=8, num_layers=6, d_ff=1024)
    encoder = ThaiTransformerEncoder(cfg)

    # Dummy input with padding. Random ids start at 1, which keeps them
    # distinct from the pad id only while cfg.pad_token_id is 0.
    B, T = 2, 32
    input_ids = torch.randint(1, cfg.vocab_size, (B, T))
    input_ids[1, 20:] = cfg.pad_token_id  # second sequence is padding from position 20 on

    # HuggingFace-style mask: 1 for real tokens, 0 for padding.
    attention_mask = (input_ids != cfg.pad_token_id).long()

    hidden, attn_weights = encoder(input_ids, attention_mask)

    # Expected sizes are read from cfg so the checks follow the configuration.
    assert hidden.shape == (B, T, cfg.d_model), f"wrong shape: {hidden.shape}"
    assert len(attn_weights) == cfg.num_layers, (
        f"ต้องได้ attn weights ครบ {cfg.num_layers} layers"
    )
    assert not torch.isnan(hidden).any(), "NaN in output!"

    # NOTE: this only confirms that a backward pass runs through the padded
    # batch; it does not verify that padded positions leave the outputs of
    # real tokens unchanged.
    loss = hidden.sum()
    loss.backward()
    print(f"params: {sum(p.numel() for p in encoder.parameters()):,}")
    print("encoder OK")