# vi-en-transformer-25m/src/transformer_encoder_decoder.py
"""
TRANSFORMER ENCODER & DECODER (FIXED)
Xây dựng hoàn chỉnh Encoder và Decoder layers
"""
import torch
import torch.nn as nn
from .transformer_components import (
    MultiHeadAttention,
    PositionwiseFeedForward,
    ResidualConnection,
    LayerNorm,
    Embedding,
    PositionalEncoding
)
# ============================================================================
# 1. ENCODER LAYER
# ============================================================================
class EncoderLayer(nn.Module):
"""
Một layer của Transformer Encoder
Gồm:
1. Multi-Head Self-Attention
2. Add & Norm
3. Feed-Forward Network
4. Add & Norm
Args:
d_model: Dimension của model
n_heads: Số lượng attention heads
d_ff: Dimension của feed-forward network
dropout: Dropout rate
"""
def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
super().__init__()
# Multi-Head Self-Attention
self.self_attention = MultiHeadAttention(d_model, n_heads, dropout)
# Feed-Forward Network
self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
# Residual Connections
self.residual1 = ResidualConnection(d_model, dropout)
self.residual2 = ResidualConnection(d_model, dropout)
def forward(self, x, mask=None):
"""
Args:
x: Input [batch_size, seq_len, d_model]
            mask: Mask tensor [batch_size, 1, 1, seq_len] (masks padding positions)
Returns:
output: [batch_size, seq_len, d_model]
"""
        # 1. Self-attention with residual connection
        x = self.residual1(x, lambda t: self.self_attention(t, t, t, mask)[0])
        # 2. Feed-forward with residual connection
        x = self.residual2(x, self.feed_forward)
return x
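# Example (illustrative, not executed): a single EncoderLayer maps
# [batch, seq_len, d_model] to the same shape; the mask argument is optional.
#
#     layer = EncoderLayer(d_model=512, n_heads=8, d_ff=2048, dropout=0.1)
#     x = torch.randn(2, 10, 512)   # [batch, seq_len, d_model]
#     out = layer(x)                # out.shape == x.shape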
# ============================================================================
# 2. ENCODER
# ============================================================================
class Encoder(nn.Module):
"""
Transformer Encoder - Stack của N encoder layers
Args:
vocab_size: Kích thước vocabulary
d_model: Dimension của model
n_layers: Số lượng encoder layers
n_heads: Số lượng attention heads
d_ff: Dimension của feed-forward network
dropout: Dropout rate
max_len: Maximum sequence length
"""
def __init__(self, vocab_size, d_model, n_layers, n_heads, d_ff, dropout=0.1, max_len=5000):
super().__init__()
# Embedding layer
self.embedding = Embedding(vocab_size, d_model)
# Positional Encoding
self.pos_encoding = PositionalEncoding(d_model, max_len, dropout)
# Stack of Encoder Layers
self.layers = nn.ModuleList([
EncoderLayer(d_model, n_heads, d_ff, dropout)
for _ in range(n_layers)
])
# Final Layer Normalization
self.norm = LayerNorm(d_model)
def forward(self, src, src_mask=None):
"""
Args:
src: Source sequence [batch_size, src_len]
src_mask: Source mask [batch_size, 1, 1, src_len]
Returns:
output: [batch_size, src_len, d_model]
"""
# 1. Embedding + Positional Encoding
x = self.embedding(src)
x = self.pos_encoding(x)
# 2. Pass through encoder layers
for layer in self.layers:
x = layer(x, src_mask)
# 3. Final normalization
x = self.norm(x)
return x
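# Example (illustrative, not executed): the Encoder consumes token IDs,
# not embeddings; create_padding_mask (defined below) builds its mask.
#
#     enc = Encoder(vocab_size=10000, d_model=512, n_layers=6, n_heads=8, d_ff=2048)
#     src = torch.randint(1, 10000, (2, 10))       # [batch, src_len]
#     memory = enc(src, create_padding_mask(src))  # [batch, src_len, d_model]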
# ============================================================================
# 3. DECODER LAYER
# ============================================================================
class DecoderLayer(nn.Module):
"""
Một layer của Transformer Decoder
Gồm:
1. Masked Multi-Head Self-Attention
2. Add & Norm
3. Multi-Head Cross-Attention (với Encoder output)
4. Add & Norm
5. Feed-Forward Network
6. Add & Norm
Args:
d_model: Dimension của model
n_heads: Số lượng attention heads
d_ff: Dimension của feed-forward network
dropout: Dropout rate
"""
def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
super().__init__()
# Masked Multi-Head Self-Attention
self.self_attention = MultiHeadAttention(d_model, n_heads, dropout)
# Multi-Head Cross-Attention (Encoder-Decoder Attention)
self.cross_attention = MultiHeadAttention(d_model, n_heads, dropout)
# Feed-Forward Network
self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
# Residual Connections
self.residual1 = ResidualConnection(d_model, dropout)
self.residual2 = ResidualConnection(d_model, dropout)
self.residual3 = ResidualConnection(d_model, dropout)
def forward(self, x, encoder_output, src_mask=None, tgt_mask=None):
"""
Args:
x: Target input [batch_size, tgt_len, d_model]
encoder_output: Encoder output [batch_size, src_len, d_model]
src_mask: Source mask [batch_size, 1, 1, src_len]
tgt_mask: Target mask [batch_size, 1, tgt_len, tgt_len] (causal mask)
Returns:
output: [batch_size, tgt_len, d_model]
"""
        # 1. Masked self-attention with residual connection
        x = self.residual1(x, lambda t: self.self_attention(t, t, t, tgt_mask)[0])
        # 2. Cross-attention over the encoder output
        # Q comes from the decoder; K and V come from the encoder
        x = self.residual2(x, lambda t: self.cross_attention(t, encoder_output, encoder_output, src_mask)[0])
        # 3. Feed-forward with residual connection
        x = self.residual3(x, self.feed_forward)
return x
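# Example (illustrative, not executed): cross-attention lets every target
# position attend over all source positions, so tgt_len and src_len may differ.
#
#     layer = DecoderLayer(d_model=512, n_heads=8, d_ff=2048)
#     x = torch.randn(2, 12, 512)     # decoder stream [batch, tgt_len, d_model]
#     mem = torch.randn(2, 10, 512)   # encoder output [batch, src_len, d_model]
#     out = layer(x, mem)             # -> [2, 12, 512]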
# ============================================================================
# 4. DECODER
# ============================================================================
class Decoder(nn.Module):
"""
Transformer Decoder - Stack của N decoder layers
Args:
vocab_size: Kích thước vocabulary
d_model: Dimension của model
n_layers: Số lượng decoder layers
n_heads: Số lượng attention heads
d_ff: Dimension của feed-forward network
dropout: Dropout rate
max_len: Maximum sequence length
"""
def __init__(self, vocab_size, d_model, n_layers, n_heads, d_ff, dropout=0.1, max_len=5000):
super().__init__()
# Embedding layer
self.embedding = Embedding(vocab_size, d_model)
# Positional Encoding
self.pos_encoding = PositionalEncoding(d_model, max_len, dropout)
# Stack of Decoder Layers
self.layers = nn.ModuleList([
DecoderLayer(d_model, n_heads, d_ff, dropout)
for _ in range(n_layers)
])
# Final Layer Normalization
self.norm = LayerNorm(d_model)
# Output projection to vocabulary
self.fc_out = nn.Linear(d_model, vocab_size)
def forward(self, tgt, encoder_output, src_mask=None, tgt_mask=None):
"""
Args:
tgt: Target sequence [batch_size, tgt_len]
encoder_output: Encoder output [batch_size, src_len, d_model]
src_mask: Source mask [batch_size, 1, 1, src_len]
tgt_mask: Target mask [batch_size, 1, tgt_len, tgt_len]
Returns:
output: [batch_size, tgt_len, vocab_size]
"""
# 1. Embedding + Positional Encoding
x = self.embedding(tgt)
x = self.pos_encoding(x)
# 2. Pass through decoder layers
for layer in self.layers:
x = layer(x, encoder_output, src_mask, tgt_mask)
# 3. Final normalization
x = self.norm(x)
# 4. Project to vocabulary
output = self.fc_out(x)
return output
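# Note (illustrative, an assumed training setup rather than code from this
# repo): fc_out returns raw logits over the target vocabulary, so a typical
# teacher-forced training step pairs it with cross-entropy on shifted targets:
#
#     tgt_in, tgt_out = tgt[:, :-1], tgt[:, 1:]
#     logits = decoder(tgt_in, memory, src_mask, create_target_mask(tgt_in, pad_idx))
#     loss = nn.functional.cross_entropy(
#         logits.reshape(-1, logits.size(-1)), tgt_out.reshape(-1),
#         ignore_index=pad_idx)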
# ============================================================================
# 5. MASK FUNCTIONS (FIXED)
# ============================================================================
def create_padding_mask(seq, pad_idx=0):
"""
Tạo mask cho padding tokens
Args:
seq: Sequence [batch_size, seq_len]
pad_idx: Index của padding token
Returns:
mask: [batch_size, 1, 1, seq_len] (bool type)
"""
# Tạo mask: True cho non-padding, False cho padding
mask = (seq != pad_idx).unsqueeze(1).unsqueeze(2)
return mask # Returns bool tensor
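# Example: for seq = torch.tensor([[5, 7, 0]]) with pad_idx=0, this returns
# tensor([[[[True, True, False]]]]) of shape [1, 1, 1, 3].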
def create_causal_mask(seq_len, device):
"""
Tạo causal mask (look-ahead mask) cho decoder
Ngăn decoder nhìn thấy future tokens
Args:
seq_len: Length của sequence
device: Device (cuda hoặc cpu)
Returns:
mask: [1, 1, seq_len, seq_len] (bool type)
"""
# Tạo lower triangular matrix - FIXED: convert to bool
    mask = torch.tril(torch.ones(seq_len, seq_len, device=device, dtype=torch.bool))
    mask = mask.unsqueeze(0).unsqueeze(1)
return mask
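# Example: create_causal_mask(3, 'cpu')[0, 0] is the lower-triangular matrix
#     [[ True, False, False],
#      [ True,  True, False],
#      [ True,  True,  True]]
# i.e. position i may attend to positions j <= i only.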
def create_target_mask(tgt, pad_idx=0):
"""
Tạo mask kết hợp cho target sequence (padding + causal)
Args:
tgt: Target sequence [batch_size, tgt_len]
pad_idx: Index của padding token
Returns:
mask: [batch_size, 1, tgt_len, tgt_len] (bool type)
"""
batch_size, tgt_len = tgt.size()
device = tgt.device
# Padding mask - returns bool
padding_mask = (tgt != pad_idx).unsqueeze(1).unsqueeze(2) # [batch, 1, 1, tgt_len]
# Causal mask - returns bool
causal_mask = create_causal_mask(tgt_len, device) # [1, 1, tgt_len, tgt_len]
    # Combine the two masks (logical AND; both are bool)
mask = padding_mask & causal_mask
return mask
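# Illustrative sketch (an assumption about transformer_components, not code
# from it): bool masks like these are typically consumed by filling the
# masked-out (False) positions of the attention scores with -inf before the
# softmax, so they receive zero attention weight.
def _apply_mask_sketch(scores, mask):
    # scores: [batch, heads, q_len, k_len]; the mask broadcasts against it
    return scores.masked_fill(~mask, float('-inf'))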
# ============================================================================
# 6. TEST ENCODER & DECODER
# ============================================================================
if __name__ == "__main__":
    # Note: because of the relative imports above, run this file as a module
    # (e.g. python -m src.transformer_encoder_decoder), not as a plain script.
    print("="*70)
    print("TESTING THE ENCODER & DECODER")
    print("="*70)
# Hyperparameters
batch_size = 2
src_len = 10
tgt_len = 12
src_vocab_size = 10000
tgt_vocab_size = 8000
d_model = 512
n_layers = 6
n_heads = 8
d_ff = 2048
dropout = 0.1
pad_idx = 0
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"\nDevice: {device}")
    # Create dummy data
src = torch.randint(1, src_vocab_size, (batch_size, src_len)).to(device)
tgt = torch.randint(1, tgt_vocab_size, (batch_size, tgt_len)).to(device)
    # Create masks
src_mask = create_padding_mask(src, pad_idx).to(device)
tgt_mask = create_target_mask(tgt, pad_idx).to(device)
print(f"\nInput shapes:")
print(f" Source: {src.shape}")
print(f" Target: {tgt.shape}")
print(f" Source mask: {src_mask.shape}, dtype: {src_mask.dtype}")
print(f" Target mask: {tgt_mask.shape}, dtype: {tgt_mask.dtype}")
# Test Encoder
print("\n" + "="*70)
print("Test Encoder")
print("="*70)
encoder = Encoder(
vocab_size=src_vocab_size,
d_model=d_model,
n_layers=n_layers,
n_heads=n_heads,
d_ff=d_ff,
dropout=dropout
).to(device)
encoder_output = encoder(src, src_mask)
print(f"Encoder output shape: {encoder_output.shape}")
print(f"Expected: [{batch_size}, {src_len}, {d_model}]")
# Test Decoder
print("\n" + "="*70)
print("Test Decoder")
print("="*70)
decoder = Decoder(
vocab_size=tgt_vocab_size,
d_model=d_model,
n_layers=n_layers,
n_heads=n_heads,
d_ff=d_ff,
dropout=dropout
).to(device)
decoder_output = decoder(tgt, encoder_output, src_mask, tgt_mask)
print(f"Decoder output shape: {decoder_output.shape}")
print(f"Expected: [{batch_size}, {tgt_len}, {tgt_vocab_size}]")
    # Parameter counts
encoder_params = sum(p.numel() for p in encoder.parameters())
decoder_params = sum(p.numel() for p in decoder.parameters())
print("\n" + "="*70)
print("THỐNG KÊ MÔ HÌNH")
print("="*70)
print(f"Encoder parameters: {encoder_params:,}")
print(f"Decoder parameters: {decoder_params:,}")
print(f"Total parameters: {encoder_params + decoder_params:,}")
print("\n" + "="*70)
print("✓ ENCODER & DECODER HOẠT ĐỘNG ĐÚNG!")
print("="*70)