|
|
from pathlib import Path |
|
|
import random |
|
|
import re |
|
|
from datetime import datetime |
|
|
import numpy as np |
|
|
import torch |
|
|
from torch import Tensor |
|
|
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast |
|
|
from jaxtyping import Bool, Int |
|
|
import model |
|
|
|
|
|
|
|
|
|
|
|
def seed_everything(seed: int = 42) -> None:
    """
    Seed every random number generator used in this project.

    Covers Python's `random`, NumPy, and PyTorch (CPU plus current and
    all CUDA devices), and switches cuDNN into deterministic mode so
    repeated runs produce identical results.

    Args:
        seed (int): The seed value to use.
    """
    # Apply the same seed to each RNG source in one pass.
    for seed_fn in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seed_fn(seed)

    # Trade cuDNN autotuning speed for bit-exact reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
|
|
|
|
|
|
|
|
def make_run_name(model_name: str, d_model: int) -> str:
    """Build a unique run identifier from the model name, width, and current time."""
    stamp: str = datetime.now().strftime("%Y%m%d_%H%M%S")
    return "-".join([model_name, f"{d_model}d", stamp])
|
|
|
|
|
|
|
|
|
|
|
def load_tokenizer(tokenizer_path: str | Path) -> PreTrainedTokenizerFast:
    """
    Load a trained tokenizer from file and register its special tokens.

    Args:
        tokenizer_path (str | Path): Path to the tokenizer JSON file.

    Returns:
        PreTrainedTokenizerFast: Loaded tokenizer with "[PAD]", "[UNK]",
            "[SOS]", and "[EOS]" registered as the pad/unk/bos/eos tokens,
            so attributes like ``tokenizer.pad_token_id`` resolve.
    """
    print(f"Loading tokenizer from {tokenizer_path}...")

    tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(tokenizer_path))
    # Register the special tokens the tokenizer file was trained with;
    # without this, pad_token_id / bos_token_id etc. would be None.
    tokenizer.pad_token = "[PAD]"
    tokenizer.unk_token = "[UNK]"
    tokenizer.bos_token = "[SOS]"
    tokenizer.eos_token = "[EOS]"
    return tokenizer
|
|
|
|
|
|
|
|
def create_padding_mask(
    input_ids: Int[Tensor, "B T_k"], pad_token_id: int
) -> Bool[Tensor, "B 1 1 T_k"]:
    """
    Build a broadcastable padding mask for the attention mechanism.

    Marks every position holding the padding token as False ("mask out")
    and every real token as True ("keep"). The two inserted singleton
    dimensions let the mask broadcast against attention scores of shape
    (B, H, T_q, T_k).

    Args:
        input_ids (Tensor): The input token IDs. Shape (B, T_k).
        pad_token_id (int): The ID of the padding token.

    Returns:
        Tensor: A boolean mask of shape (B, 1, 1, T_k).
            'True' means "keep" (not a pad token).
            'False' means "mask out" (is a pad token).
    """
    keep: Tensor = torch.ne(input_ids, pad_token_id)
    # (B, T_k) -> (B, 1, 1, T_k) so it broadcasts over heads and queries.
    return keep[:, None, None, :]
|
|
|
|
|
|
|
|
def create_look_ahead_mask(seq_len: int) -> Bool[Tensor, "1 1 T_q T_q"]:
    """
    Creates a causal (look-ahead) mask for the Decoder's self-attention.

    This mask prevents positions from attending to subsequent positions.
    It's a square matrix where the upper triangle (future) is False
    and the lower triangle (past/present) is True. Move it to the target
    device at the call site (e.g. ``.to(device)``).

    Args:
        seq_len (int): The sequence length (T_q).

    Returns:
        Tensor: A boolean mask of shape (1, 1, T_q, T_q).
            'True' means "keep" (allowed to see).
            'False' means "mask out" (future token).
    """
    # Build the lower-triangular keep-matrix directly in bool dtype,
    # avoiding the float ones -> compare-to-1 round trip.
    causal: Tensor = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

    # Add singleton batch and head dims: (T_q, T_q) -> (1, 1, T_q, T_q).
    return causal.unsqueeze(0).unsqueeze(0)
|
|
|
|
|
|
|
|
def greedy_decode_sentence(
    model: model.Transformer,
    src: Int[Tensor, "1 T_src"],
    src_mask: Bool[Tensor, "1 1 1 T_src"],
    max_len: int,
    sos_token_id: int,
    eos_token_id: int,
    device: torch.device,
) -> Int[Tensor, "T_out"]:
    """
    Performs greedy decoding for a single sentence.
    This is an autoregressive process (token by token): at each step the
    highest-probability next token is appended and the decoder is re-run
    over the full prefix (no key/value caching).

    Args:
        model: The trained Transformer model (already on device).
        src: The source token IDs (e.g., English).
        src_mask: The padding mask for the source.
        max_len: The maximum length to generate (including the [SOS] token).
        sos_token_id: The ID for [SOS] token.
        eos_token_id: The ID for [EOS] token.
        device: The device to run on.

    Returns:
        Tensor: The generated target token IDs as a 1-D tensor of shape
            (T_out,) — the batch dim is removed by the final squeeze(0).
            The sequence starts with [SOS] and, if generation terminated
            early, ends with [EOS].
    """
    # Disable dropout etc. for deterministic inference.
    model.eval()

    with torch.no_grad():
        # Encode the source once; the encoder output is reused at every
        # decoding step.
        src_embedded = model.src_embed(src)
        src_with_pos = model.pos_enc(src_embedded)
        enc_output: Tensor = model.encoder(src_with_pos, src_mask)

        # Start the target sequence with a single [SOS] token, shape (1, 1).
        decoder_input: Tensor = torch.tensor(
            [[sos_token_id]], dtype=torch.long, device=device
        )

        # Generate up to max_len - 1 new tokens ([SOS] already occupies one slot).
        for _ in range(max_len - 1):
            # Re-embed the whole prefix each step (no incremental caching).
            tgt_embedded = model.tgt_embed(decoder_input)
            tgt_with_pos = model.pos_enc(tgt_embedded)

            # Causal mask grows with the prefix; rebuilt every iteration.
            T_tgt = decoder_input.size(1)
            tgt_mask = create_look_ahead_mask(T_tgt).to(device)

            dec_output: Tensor = model.decoder(
                tgt_with_pos, enc_output, src_mask, tgt_mask
            )

            # Project decoder states to vocabulary logits.
            logits: Tensor = model.generator(dec_output)

            # Only the last position predicts the next token.
            last_token_logits = logits[:, -1, :]

            # Greedy choice: argmax over the vocabulary, kept as shape (1, 1).
            next_token: Tensor = torch.argmax(last_token_logits, dim=-1).unsqueeze(-1)

            # Append the chosen token to the growing prefix.
            decoder_input = torch.cat([decoder_input, next_token], dim=1)

            # Stop as soon as the model emits [EOS].
            if next_token.item() == eos_token_id:
                break

    # NOTE(review): squeeze(0) drops the batch dim, so the result is 1-D —
    # the original annotation claimed shape (1, T_out); corrected above.
    return decoder_input.squeeze(0)
|
|
|
|
|
|
|
|
def filter_and_detokenize(token_list: list[str], skip_special: bool = True) -> str:
    """
    Join tokens with spaces and repair punctuation spacing.

    Whitespace tokenization leaves a gap before punctuation marks and
    inside contractions; this reverses those artifacts.

    Args:
        token_list: The tokens to join.
        skip_special: When True, drop [PAD]/[UNK]/[SOS]/[EOS] first.

    Returns:
        str: The cleaned, detokenized sentence.
    """
    tokens = token_list
    if skip_special:
        specials = {"[PAD]", "[UNK]", "[SOS]", "[EOS]"}
        tokens = [tok for tok in tokens if tok not in specials]

    sentence = " ".join(tokens)

    # Drop the space tokenization inserted before punctuation marks.
    sentence = re.sub(r'\s([.,!?\'":;])', r"\1", sentence)
    # Re-attach contraction fragments (e.g. "he 's" -> "he's").
    sentence = re.sub(r"(\w)\s(\'\w)", r"\1\2", sentence)

    return sentence
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def translate(
    model: model.Transformer,
    tokenizer: PreTrainedTokenizerFast,
    sentence_en: str,
    device: torch.device,
    max_len: int,
    sos_token_id: int,
    eos_token_id: int,
    pad_token_id: int,
) -> str:
    """
    Translates a single English sentence to Vietnamese.

    Pipeline: tokenize the source -> build its padding mask -> greedy
    decode token IDs -> convert IDs back to tokens -> detokenize to a
    clean string (special tokens removed).

    Args:
        model: The trained Transformer model.
        tokenizer: The (PreTrainedTokenizerFast) tokenizer.
        sentence_en: The raw English input string.
        device: The device to run on.
        max_len: The max sequence length (from config); used both to
            truncate the source and to cap generation length.
        sos_token_id: The ID for [SOS].
        eos_token_id: The ID for [EOS].
        pad_token_id: The ID for [PAD].

    Returns:
        str: The translated Vietnamese string.
    """
    # Deterministic inference; greedy_decode_sentence also calls eval()
    # and no_grad(), so these are redundant but harmless here.
    model.eval()

    with torch.no_grad():
        # Tokenize without special tokens: [SOS]/[EOS] handling is done
        # explicitly by the decoder loop, not by the tokenizer.
        src_encoding = tokenizer(
            sentence_en,
            truncation=True,
            max_length=max_len,
            add_special_tokens=False,
        )

        # Wrap in an extra list to create the batch dim: shape (1, T_src).
        src_ids: Tensor = torch.tensor(
            [src_encoding["input_ids"]], dtype=torch.long
        ).to(device)

        # True where the token is real, False at [PAD] positions.
        src_mask: Tensor = create_padding_mask(src_ids, pad_token_id).to(device)

        # Autoregressive greedy decoding; returns the generated ID sequence.
        predicted_ids: Tensor = greedy_decode_sentence(
            model,
            src_ids,
            src_mask,
            max_len=max_len,
            sos_token_id=sos_token_id,
            eos_token_id=eos_token_id,
            device=device,
        )

        # Move to CPU before converting to plain Python ints.
        predicted_id_list = predicted_ids.cpu().tolist()

        # IDs -> token strings (may still contain [SOS]/[EOS]).
        predicted_token_list = tokenizer.convert_ids_to_tokens(predicted_id_list)

        # Strip special tokens and fix punctuation spacing.
        result_string = filter_and_detokenize(predicted_token_list, skip_special=True)

    return result_string
|
|
|
|
|
# Import-time side effect: confirms this utilities module was loaded.
print("Inference function `translate()` defined.")
|
|
|