from pathlib import Path
import random
import re
from datetime import datetime

import numpy as np
import torch
from torch import Tensor
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
from jaxtyping import Bool, Int

import model

# Special tokens stripped from decoded output by `filter_and_detokenize`.
_SPECIAL_TOKENS = frozenset({"[PAD]", "[UNK]", "[SOS]", "[EOS]"})

# A whitespace character directly followed by a punctuation mark.
_SPACE_BEFORE_PUNCT_RE = re.compile(r'\s([.,!?\'":;])')

# A word character, whitespace, then an apostrophe + word character.
_SPLIT_CONTRACTION_RE = re.compile(r"(\w)\s(\'\w)")


# Utility function to set random seed for reproducibility
def seed_everything(seed: int = 42) -> None:
    """
    Seed the Python, NumPy and PyTorch random number generators.

    Also switches cuDNN to deterministic mode and turns off its benchmark
    mode.

    Args:
        seed (int): The seed value to use.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def make_run_name(model_name: str, d_model: int) -> str:
    """
    Build a run name of the form "<model_name>-<d_model>d-<YYYYmmdd_HHMMSS>".

    The timestamp comes from `datetime.now()` (naive local time).

    Args:
        model_name (str): Name of the model.
        d_model (int): Model dimension, embedded in the name as "<d_model>d".

    Returns:
        str: The formatted run name.
    """
    time_tag: str = datetime.now().strftime("%Y%m%d_%H%M%S")
    return f"{model_name}-{d_model}d-{time_tag}"


# Utility function to load a trained tokenizer from disk
def load_tokenizer(tokenizer_path: str | Path) -> PreTrainedTokenizerFast:
    """
    Load a trained tokenizer from file and register its special tokens.

    Args:
        tokenizer_path (str | Path): Path to the tokenizer JSON file.

    Returns:
        PreTrainedTokenizerFast: The loaded tokenizer with pad/unk/bos/eos
            set to "[PAD]", "[UNK]", "[SOS]" and "[EOS]" respectively.
    """
    print(f"Loading tokenizer from {tokenizer_path}...")
    tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(tokenizer_path))

    tokenizer.pad_token = "[PAD]"
    tokenizer.unk_token = "[UNK]"
    tokenizer.bos_token = "[SOS]"  # bos = Beginning Of Sentence
    tokenizer.eos_token = "[EOS]"  # eos = End Of Sentence

    return tokenizer


def create_padding_mask(
    input_ids: Int[Tensor, "B T_k"], pad_token_id: int
) -> Bool[Tensor, "B 1 1 T_k"]:
    """
    Creates a padding mask for the attention mechanism.

    This mask identifies positions holding the padding token and prepares a
    mask tensor that, when broadcasted, will mask these positions in the
    attention scores matrix (B, H, T_q, T_k).

    Args:
        input_ids (Tensor): The input token IDs. Shape (B, T_k).
        pad_token_id (int): The ID of the padding token.

    Returns:
        Tensor: A boolean mask of shape (B, 1, 1, T_k).
                'True' means "keep" (not a pad token).
                'False' means "mask out" (is a pad token).
    """
    # 1. Create the base mask
    # (input_ids != pad_token_id) is True for real tokens, False for PAD
    # Shape: (B, T_k)
    mask: Tensor = input_ids != pad_token_id

    # 2. Add singleton dimensions for broadcasting
    # dim 1 broadcasts over heads (H), dim 2 over query positions (T_q)
    # Shape: (B, T_k) -> (B, 1, T_k) -> (B, 1, 1, T_k)
    return mask.unsqueeze(1).unsqueeze(2)


def create_look_ahead_mask(
    seq_len: int, device: torch.device | str | None = None
) -> Bool[Tensor, "1 1 T_q T_q"]:
    """
    Creates a causal (look-ahead) mask for the Decoder's self-attention.

    This mask prevents positions from attending to subsequent positions.
    It's a square matrix where the upper triangle (future) is False
    and the lower triangle (past/present) is True.

    Args:
        seq_len (int): The sequence length (T_q).
        device (torch.device | str | None): The device to create the tensor
            on (e.g., 'cuda'). Defaults to None, which uses PyTorch's
            default device.

    Returns:
        Tensor: A boolean mask of shape (1, 1, T_q, T_q).
                'True' means "keep" (allowed to see).
                'False' means "mask out" (future token).
    """
    # 1. Build an all-True square matrix directly as bool on the target
    #    device, so no float intermediate or host-to-device copy is needed.
    # Shape: (T_q, T_q)
    all_true = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device)

    # 2. Keep the lower triangular part (including the diagonal).
    # This sets the upper triangle (future) to False.
    # Example (T_q=3):
    # [[True, False, False],
    #  [True, True,  False],
    #  [True, True,  True ]]
    lower_triangular: Tensor = torch.tril(all_true)

    # 3. Add broadcasting dimensions
    # Shape: (T_q, T_q) -> (1, 1, T_q, T_q)
    return lower_triangular.unsqueeze(0).unsqueeze(0)


def greedy_decode_sentence(
    model: model.Transformer,
    src: Int[Tensor, "1 T_src"],  # Input: one sentence
    src_mask: Bool[Tensor, "1 1 1 T_src"],
    max_len: int,
    sos_token_id: int,
    eos_token_id: int,
    device: torch.device,
) -> Int[Tensor, "T_out"]:
    """
    Performs greedy decoding for a single sentence.
    This is an autoregressive process (token by token).

    Only a batch size of 1 is supported: the [EOS] check calls `.item()`
    on the newly generated token.

    Args:
        model: The trained Transformer model (already on device).
        src: The source token IDs (e.g., English).
        src_mask: The padding mask for the source.
        max_len: The maximum length to generate, counting the leading [SOS].
        sos_token_id: The ID for [SOS] token.
        eos_token_id: The ID for [EOS] token.
        device: The device to run on.

    Returns:
        Tensor: The generated target token IDs (e.g., Vietnamese) as a 1D
            tensor of shape (T_out,). It starts with [SOS] and ends with
            [EOS] if that token was produced before `max_len` was reached.
    """
    # Set model to eval mode
    model.eval()

    # No gradients needed
    with torch.no_grad():
        # --- 1. Encode the source *once* ---
        # (B, T_src) -> (B, T_src, D)
        src_embedded = model.src_embed(src)
        src_with_pos = model.pos_enc(src_embedded)
        enc_output: Tensor = model.encoder(src_with_pos, src_mask)

        # --- 2. Initialize the Decoder input ---
        # Start with the [SOS] token. Shape: (B=1, T_tgt=1)
        decoder_input: Tensor = torch.tensor(
            [[sos_token_id]], dtype=torch.long, device=device
        )

        # --- 3. Autoregressive Loop ---
        for _ in range(max_len - 1):  # (Max length - 1, since we have [SOS])
            # --- a. Get Target Embedding + Position ---
            # (B, T_tgt) -> (B, T_tgt, D)
            tgt_embedded = model.tgt_embed(decoder_input)
            tgt_with_pos = model.pos_enc(tgt_embedded)

            # --- b. Create Target Mask (Causal) ---
            # We must re-create the mask every loop,
            # as T_tgt (decoder_input.size(1)) is growing.
            # It is built directly on `device` to avoid a per-step copy.
            # Shape: (1, 1, T_tgt, T_tgt)
            T_tgt = decoder_input.size(1)
            tgt_mask = create_look_ahead_mask(T_tgt, device=device)

            # --- c. Run Decoder and Generator ---
            # (B, T_tgt, D)
            dec_output: Tensor = model.decoder(
                tgt_with_pos, enc_output, src_mask, tgt_mask
            )
            # (B, T_tgt, vocab_size)
            logits: Tensor = model.generator(dec_output)

            # --- d. Get the *last* token's logits ---
            # (B, T_tgt, vocab_size) -> (B, vocab_size)
            last_token_logits = logits[:, -1, :]

            # --- e. Greedy Search (get highest-scoring token) ---
            # (B, vocab_size) -> (B, 1)
            next_token: Tensor = torch.argmax(last_token_logits, dim=-1).unsqueeze(-1)

            # --- f. Append the new token ---
            # (B, T_tgt) + (B, 1) -> (B, T_tgt + 1)
            decoder_input = torch.cat([decoder_input, next_token], dim=1)

            # --- g. Check for [EOS] ---
            # If the *last* token we added is [EOS], stop generating.
            if next_token.item() == eos_token_id:
                break

    return decoder_input.squeeze(0)  # Return shape (T_out)


def filter_and_detokenize(token_list: list[str], skip_special: bool = True) -> str:
    """
    Manually joins tokens with a space and cleans up common punctuation
    issues caused by whitespace tokenization.

    Args:
        token_list: The tokens to join.
        skip_special: If True, drop "[PAD]", "[UNK]", "[SOS]" and "[EOS]"
            before joining.

    Returns:
        str: The joined string with the space before punctuation removed.
    """
    if skip_special:
        # 1. Filter out special tokens
        token_list = [tok for tok in token_list if tok not in _SPECIAL_TOKENS]

    # 2. Join with spaces
    detokenized_string = " ".join(token_list)

    # 3. Clean up punctuation
    # (This is a simple heuristic-based detokenizer)
    # Remove space before punctuation: "project ." -> "project."
    detokenized_string = _SPACE_BEFORE_PUNCT_RE.sub(r"\1", detokenized_string)
    # Handle contractions: "don 't" -> "don't"
    detokenized_string = _SPLIT_CONTRACTION_RE.sub(r"\1\2", detokenized_string)

    return detokenized_string


# High-level inference function that handles all steps:
# tokenize -> mask -> greedy decode -> detokenize.
def translate(
    model: model.Transformer,
    tokenizer: PreTrainedTokenizerFast,
    sentence_en: str,
    device: torch.device,
    max_len: int,
    sos_token_id: int,
    eos_token_id: int,
    pad_token_id: int,
) -> str:
    """
    Translates a single English sentence to Vietnamese.

    Args:
        model: The trained Transformer model.
        tokenizer: The (PreTrainedTokenizerFast) tokenizer.
        sentence_en: The raw English input string.
        device: The device to run on.
        max_len: The max sequence length (from config). Used both to
            truncate the source and to bound the generated output.
        sos_token_id: The ID for [SOS].
        eos_token_id: The ID for [EOS].
        pad_token_id: The ID for [PAD].

    Returns:
        str: The translated Vietnamese string. An input that tokenizes to
            zero tokens yields an empty string.
    """
    # Set model to evaluation mode
    model.eval()

    # Run inference in a no-gradient context
    with torch.no_grad():
        # 1. Tokenize the source (English) sentence
        src_encoding = tokenizer(
            sentence_en,
            truncation=True,
            max_length=max_len,
            add_special_tokens=False,  # (Encoder does not need SOS/EOS)
        )
        src_token_ids = src_encoding["input_ids"]

        # Nothing to translate: avoid feeding a zero-length (1, 0) source
        # sequence into the encoder and cross-attention.
        if not src_token_ids:
            return ""

        # 2. Convert to Tensor with a Batch dimension (B=1), on device
        # Shape: (1, T_src)
        src_ids: Tensor = torch.tensor(
            [src_token_ids], dtype=torch.long, device=device
        )

        # 3. Create the source padding mask
        # (derived from src_ids, so it is already on the same device)
        # Shape: (1, 1, 1, T_src)
        src_mask: Tensor = create_padding_mask(src_ids, pad_token_id)

        # 4. Generate the target (Vietnamese) token IDs
        # Shape: (T_out)
        predicted_ids: Tensor = greedy_decode_sentence(
            model,
            src_ids,
            src_mask,
            max_len=max_len,
            sos_token_id=sos_token_id,
            eos_token_id=eos_token_id,
            device=device,
        )

        # 5. Detokenize
        # Convert 1D device Tensor -> 1D CPU list
        predicted_id_list = predicted_ids.cpu().tolist()

        # 1D list of ids -> list of token strings
        predicted_token_list = tokenizer.convert_ids_to_tokens(predicted_id_list)

        # Join with spaces, remove special tokens, and fix punctuation.
        result_string = filter_and_detokenize(predicted_token_list, skip_special=True)

    return result_string


print("Inference function `translate()` defined.")