""" COMPLETE TRANSFORMER MODEL Mô hình Transformer hoàn chỉnh cho dịch máy Seq2Seq """ import torch import torch.nn as nn from .transformer_encoder_decoder import ( Encoder, Decoder, create_padding_mask, create_target_mask ) # ============================================================================ # TRANSFORMER MODEL # ============================================================================ class Transformer(nn.Module): """ Mô hình Transformer hoàn chỉnh cho Neural Machine Translation Args: src_vocab_size: Kích thước vocabulary source language tgt_vocab_size: Kích thước vocabulary target language d_model: Dimension của model (mặc định 512) n_layers: Số lượng encoder/decoder layers (mặc định 6) n_heads: Số lượng attention heads (mặc định 8) d_ff: Dimension của feed-forward network (mặc định 2048) dropout: Dropout rate (mặc định 0.1) max_len: Maximum sequence length (mặc định 5000) pad_idx: Index của padding token (mặc định 0) """ def __init__( self, src_vocab_size, tgt_vocab_size, d_model=512, n_layers=6, n_heads=8, d_ff=2048, dropout=0.1, max_len=5000, pad_idx=0 ): super().__init__() self.pad_idx = pad_idx # Encoder self.encoder = Encoder( vocab_size=src_vocab_size, d_model=d_model, n_layers=n_layers, n_heads=n_heads, d_ff=d_ff, dropout=dropout, max_len=max_len ) # Decoder self.decoder = Decoder( vocab_size=tgt_vocab_size, d_model=d_model, n_layers=n_layers, n_heads=n_heads, d_ff=d_ff, dropout=dropout, max_len=max_len ) # Khởi tạo weights self._init_weights() def _init_weights(self): """ Khởi tạo weights theo Xavier Uniform """ for p in self.parameters(): if p.dim() > 1: nn.init.xavier_uniform_(p) def forward(self, src, tgt): """ Forward pass Args: src: Source sequence [batch_size, src_len] tgt: Target sequence [batch_size, tgt_len] Returns: output: Logits [batch_size, tgt_len, tgt_vocab_size] """ # Tạo masks src_mask = create_padding_mask(src, self.pad_idx) tgt_mask = create_target_mask(tgt, self.pad_idx) # Encoder encoder_output = self.encoder(src, src_mask) # Decoder output = self.decoder(tgt, encoder_output, src_mask, tgt_mask) return output def encode(self, src): """ Chỉ chạy encoder (dùng khi inference) Args: src: Source sequence [batch_size, src_len] Returns: encoder_output: [batch_size, src_len, d_model] src_mask: [batch_size, 1, 1, src_len] """ src_mask = create_padding_mask(src, self.pad_idx) encoder_output = self.encoder(src, src_mask) return encoder_output, src_mask def decode(self, tgt, encoder_output, src_mask): """ Chỉ chạy decoder (dùng khi inference) Args: tgt: Target sequence [batch_size, tgt_len] encoder_output: Encoder output [batch_size, src_len, d_model] src_mask: Source mask [batch_size, 1, 1, src_len] Returns: output: Logits [batch_size, tgt_len, tgt_vocab_size] """ tgt_mask = create_target_mask(tgt, self.pad_idx) output = self.decoder(tgt, encoder_output, src_mask, tgt_mask) return output # ============================================================================ # TRANSFORMER WITH SHARED VOCABULARY & WEIGHT TYING # ============================================================================ class TransformerShared(nn.Module): """ Transformer với Shared Vocabulary và Weight Tying Đặc điểm: - Dùng chung 1 vocabulary cho cả source và target - Embedding input và output layer chia sẻ weights (Weight Tying) - Tiết kiệm ~50% parameters so với model riêng biệt - Học được mối liên hệ trực tiếp giữa 2 ngôn ngữ tốt hơn Args: vocab_size: Kích thước shared vocabulary d_model: Dimension của model (mặc định 512) n_layers: Số lượng encoder/decoder layers (mặc định 6) n_heads: Số lượng attention heads (mặc định 8) d_ff: Dimension của feed-forward network (mặc định 2048) dropout: Dropout rate (mặc định 0.1) max_len: Maximum sequence length (mặc định 5000) pad_idx: Index của padding token (mặc định 0) use_weight_tying: Có dùng weight tying không (mặc định True) """ def __init__( self, vocab_size, d_model=512, n_layers=6, n_heads=8, d_ff=2048, dropout=0.1, max_len=5000, pad_idx=0, use_weight_tying=True ): super().__init__() self.pad_idx = pad_idx self.d_model = d_model self.use_weight_tying = use_weight_tying # --- KHÁC BIỆT LỚN NHẤT --- # Chỉ tạo 1 Embedding matrix dùng cho cả 2 ngôn ngữ from .transformer_components import Embedding, PositionalEncoding self.shared_embedding = Embedding(vocab_size, d_model) # Positional Encoding self.pos_encoding = PositionalEncoding(d_model, max_len, dropout) # Encoder (dùng shared embedding) # Tạo Encoder nhưng sẽ thay embedding sau self.encoder = Encoder( vocab_size=vocab_size, # Dùng chung vocab_size d_model=d_model, n_layers=n_layers, n_heads=n_heads, d_ff=d_ff, dropout=dropout, max_len=max_len ) # Thay thế embedding của encoder bằng shared embedding # Quan trọng: Phải thay thế sau khi tạo Encoder self.encoder.embedding = self.shared_embedding # Decoder (dùng shared embedding) # Tạo Decoder nhưng sẽ thay embedding sau self.decoder = Decoder( vocab_size=vocab_size, # Dùng chung vocab_size d_model=d_model, n_layers=n_layers, n_heads=n_heads, d_ff=d_ff, dropout=dropout, max_len=max_len ) # Thay thế embedding của decoder bằng shared embedding # Quan trọng: Phải thay thế sau khi tạo Decoder self.decoder.embedding = self.shared_embedding # QUAN TRỌNG: Bỏ qua fc_out của Decoder vì chúng ta dùng output_layer riêng # Decoder.fc_out sẽ được thay thế bằng identity function self.decoder.fc_out = nn.Identity() # Output layer self.output_layer = nn.Linear(d_model, vocab_size, bias=False) # Khởi tạo weights TRƯỚC khi weight tying self._init_weights() # --- KÍCH HOẠT WEIGHT TYING --- if use_weight_tying: # Dòng này giúp tiết kiệm 50% tham số vocab # Embedding weight và output layer weight chia sẻ nhau # QUAN TRỌNG: Phải gán SAU _init_weights() để không bị reset self.output_layer.weight = self.shared_embedding.embedding.weight print("✓ Weight Tying enabled: Embedding và Output layer chia sẻ weights") def _init_weights(self): """ Khởi tạo weights theo Xavier Uniform """ for p in self.parameters(): if p.dim() > 1: nn.init.xavier_uniform_(p) def forward(self, src, tgt): """ Forward pass Args: src: Source sequence [batch_size, src_len] tgt: Target sequence [batch_size, tgt_len] Returns: output: Logits [batch_size, tgt_len, vocab_size] """ # Tạo masks src_mask = create_padding_mask(src, self.pad_idx) tgt_mask = create_target_mask(tgt, self.pad_idx) # Encoder (dùng shared embedding) encoder_output = self.encoder(src, src_mask) # Decoder (dùng shared embedding) decoder_output = self.decoder(tgt, encoder_output, src_mask, tgt_mask) # Debug: Kiểm tra shape # print(f"DEBUG: decoder_output shape: {decoder_output.shape}") # print(f"DEBUG: output_layer weight shape: {self.output_layer.weight.shape}") # Output layer (có thể share weight với embedding nếu use_weight_tying=True) output = self.output_layer(decoder_output) return output def encode(self, src): """ Chỉ chạy encoder (dùng khi inference) Args: src: Source sequence [batch_size, src_len] Returns: encoder_output: [batch_size, src_len, d_model] src_mask: [batch_size, 1, 1, src_len] """ src_mask = create_padding_mask(src, self.pad_idx) encoder_output = self.encoder(src, src_mask) return encoder_output, src_mask def decode(self, tgt, encoder_output, src_mask): """ Chỉ chạy decoder (dùng khi inference) Args: tgt: Target sequence [batch_size, tgt_len] encoder_output: Encoder output [batch_size, src_len, d_model] src_mask: Source mask [batch_size, 1, 1, src_len] Returns: output: Logits [batch_size, tgt_len, vocab_size] """ tgt_mask = create_target_mask(tgt, self.pad_idx) decoder_output = self.decoder(tgt, encoder_output, src_mask, tgt_mask) output = self.output_layer(decoder_output) return output # ============================================================================ # MODEL CONFIGURATION # ============================================================================ def get_model_config(model_size='base'): """ Trả về config cho các kích thước model khác nhau Args: model_size: 'tiny', 'small', 'base', 'large' Returns: config: Dictionary chứa hyperparameters """ configs = { 'tiny': { 'd_model': 256, 'n_layers': 2, 'n_heads': 4, 'd_ff': 1024, 'dropout': 0.1 }, 'small': { 'd_model': 256, 'n_layers': 4, 'n_heads': 8, 'd_ff': 1024, 'dropout': 0.1 }, 'medium': { # ~25M parameters với 32k vocab + weight tying 'd_model': 384, 'n_layers': 5, 'n_heads': 8, 'd_ff': 1536, 'dropout': 0.1 }, 'custom_25m': { # ~25M parameters với 32k vocab + weight tying 'd_model': 384, 'n_layers': 6, 'n_heads': 8, 'd_ff': 1536, # 4 * d_model 'dropout': 0.1 }, 'base': { 'd_model': 512, 'n_layers': 6, 'n_heads': 8, 'd_ff': 2048, 'dropout': 0.1 }, 'large': { 'd_model': 1024, 'n_layers': 6, 'n_heads': 16, 'd_ff': 4096, 'dropout': 0.1 } } return configs.get(model_size, configs['base']) def create_model(src_vocab_size, tgt_vocab_size, model_size='base', pad_idx=0, use_shared_vocab=True, use_weight_tying=True): """ Tạo Transformer model với Shared Vocabulary Args: src_vocab_size: Kích thước shared vocabulary tgt_vocab_size: Bỏ qua (giữ để tương thích, phải = src_vocab_size) model_size: Kích thước model ('tiny', 'small', 'base', 'large') pad_idx: Padding index use_shared_vocab: Luôn True (giữ để tương thích) use_weight_tying: Có dùng weight tying không (mặc định True) Returns: model: TransformerShared model config: Model configuration """ config = get_model_config(model_size) # Luôn dùng shared vocabulary vocab_size = src_vocab_size model = TransformerShared( vocab_size=vocab_size, d_model=config['d_model'], n_layers=config['n_layers'], n_heads=config['n_heads'], d_ff=config['d_ff'], dropout=config['dropout'], pad_idx=pad_idx, use_weight_tying=use_weight_tying ) return model, config # ============================================================================ # UTILITY FUNCTIONS # ============================================================================ def count_parameters(model): """ Đếm số lượng parameters của model Args: model: PyTorch model Returns: total: Tổng số parameters trainable: Số parameters có thể train """ total = sum(p.numel() for p in model.parameters()) trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) return total, trainable def print_model_info(model, model_size='base', use_shared_vocab=False): """ In thông tin về model Args: model: Transformer model model_size: Kích thước model use_shared_vocab: Có dùng shared vocabulary không """ total_params, trainable_params = count_parameters(model) print("="*70) print("THÔNG TIN MÔ HÌNH TRANSFORMER") print("="*70) print(f"\nKích thước model: {model_size.upper()}") if use_shared_vocab: print(f" Mode: SHARED VOCABULARY + WEIGHT TYING") if isinstance(model, TransformerShared) and model.use_weight_tying: print(f" ✓ Weight Tying: Enabled (tiết kiệm ~50% vocab params)") else: print(f" Mode: SEPARATE VOCABULARIES") print(f"\nSố lượng parameters:") print(f" - Total: {total_params:,}") print(f" - Trainable: {trainable_params:,}") print(f" - Model size: ~{total_params * 4 / (1024**2):.2f} MB (float32)") config = get_model_config(model_size) print(f"\nCấu hình:") print(f" - d_model: {config['d_model']}") print(f" - n_layers: {config['n_layers']}") print(f" - n_heads: {config['n_heads']}") print(f" - d_ff: {config['d_ff']}") print(f" - dropout: {config['dropout']}") print("="*70) # ============================================================================ # TEST COMPLETE MODEL # ============================================================================ if __name__ == "__main__": print("="*70) print("KIỂM TRA TRANSFORMER MODEL HOÀN CHỈNH") print("="*70) # Hyperparameters src_vocab_size = 10000 tgt_vocab_size = 8000 batch_size = 4 src_len = 15 tgt_len = 20 pad_idx = 0 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') print(f"\nDevice: {device}\n") # Test với các kích thước model khác nhau for model_size in ['tiny', 'small', 'base']: print(f"\n{'='*70}") print(f"TEST MODEL SIZE: {model_size.upper()}") print(f"{'='*70}\n") # Tạo model model, config = create_model( src_vocab_size=src_vocab_size, tgt_vocab_size=tgt_vocab_size, model_size=model_size, pad_idx=pad_idx ) model = model.to(device) # In thông tin model print_model_info(model, model_size) # Tạo dummy data src = torch.randint(1, src_vocab_size, (batch_size, src_len)).to(device) tgt = torch.randint(1, tgt_vocab_size, (batch_size, tgt_len)).to(device) # Forward pass print(f"\nForward pass:") print(f" Source shape: {src.shape}") print(f" Target shape: {tgt.shape}") with torch.no_grad(): output = model(src, tgt) print(f" Output shape: {output.shape}") print(f" Expected: [{batch_size}, {tgt_len}, {tgt_vocab_size}]") print(f" ✓ Shape correct!") # Test encode và decode riêng print(f"\nTest encode & decode separately:") with torch.no_grad(): encoder_output, src_mask = model.encode(src) decoder_output = model.decode(tgt, encoder_output, src_mask) print(f" Encoder output shape: {encoder_output.shape}") print(f" Decoder output shape: {decoder_output.shape}") print(f" ✓ Encode/Decode work correctly!") # Kiểm tra output giống nhau print(f"\nVerify output consistency:") with torch.no_grad(): output_combined = model(src, tgt) is_same = torch.allclose(output_combined, decoder_output, atol=1e-6) print(f" Forward == Encode+Decode: {is_same}") print(f" ✓ Model is consistent!") print("\n" + "="*70) print("✓ TẤT CẢ TESTS PASSED!") print("="*70) print("\n📝 GỢI Ý SỬ DỤNG:") print(" - Dùng 'tiny' để debug và test nhanh") print(" - Dùng 'small' để train trên CPU hoặc GPU nhỏ") print(" - Dùng 'base' để có kết quả tốt (cần GPU)") print(" - Dùng 'large' chỉ khi có GPU mạnh")