|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torch import Tensor |
|
|
from torch.utils.data import Dataset, DataLoader |
|
|
from datasets import Dataset as ArrowDataset |
|
|
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast |
|
|
|
|
|
import config |
|
|
from src import utils |
|
|
|
|
|
|
|
|
class TranslationDataset(Dataset):
    """
    Lazily tokenized parallel-text dataset.

    Tokenization happens per item inside ``__getitem__`` via the
    high-level PreTrainedTokenizerFast wrapper, so no encoding work is
    done up front. Each item is a dict with raw token-id lists:
    ``{"src_ids": [...], "tgt_ids": [SOS, ..., EOS]}``.
    """

    def __init__(
        self,
        dataset: ArrowDataset,
        tokenizer: PreTrainedTokenizerFast,
        max_len_src: int,
        max_len_tgt: int,
        src_lang: str = "en",
        tgt_lang: str = "vi",
    ):
        super().__init__()
        # Language keys into each example's "translation" dict.
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang
        # Per-side truncation budgets (target budget includes SOS/EOS).
        self.max_len_src = max_len_src
        self.max_len_tgt = max_len_tgt
        self.dataset = dataset
        self.tokenizer = tokenizer

    def __len__(self) -> int:
        # Length is delegated to the underlying Arrow dataset.
        return len(self.dataset)

    def __getitem__(self, index: int) -> dict[str, list[int]]:
        """Tokenize one sentence pair on demand.

        Returns:
            Dict with ``src_ids`` (no special tokens) and ``tgt_ids``
            (wrapped in SOS/EOS), both as plain lists of ints.
        """
        pair = self.dataset[index]["translation"]

        # Source keeps its full length budget; no special tokens attached.
        src_ids = self.tokenizer(
            pair[self.src_lang],
            truncation=True,
            max_length=self.max_len_src,
            add_special_tokens=False,
        )["input_ids"]

        # Reserve two slots of the target budget for SOS and EOS,
        # which are added manually below.
        tgt_body = self.tokenizer(
            pair[self.tgt_lang],
            truncation=True,
            max_length=self.max_len_tgt - 2,
            add_special_tokens=False,
        )["input_ids"]

        tgt_ids = [config.SOS_TOKEN_ID, *tgt_body, config.EOS_TOKEN_ID]

        return {"src_ids": src_ids, "tgt_ids": tgt_ids}
|
|
|
|
|
|
|
|
class DataCollator:
    """Dynamic-padding ``collate_fn`` for the translation DataLoaders.

    Given a batch of ``{"src_ids", "tgt_ids"}`` dicts produced by
    ``TranslationDataset.__getitem__`` (targets already wrapped in
    SOS/EOS), it:

    1. derives decoder inputs (target minus last token) and labels
       (target minus first token) for teacher forcing,
    2. pads source, decoder-input, and label sequences to the longest
       example in the batch,
    3. builds the source padding mask and the combined target mask
       (padding AND look-ahead),
    4. returns everything as a single dict of tensors.
    """

    def __init__(self, pad_token_id: int):
        # Id used both as the padding fill value and for the padding masks.
        self.pad_token_id = pad_token_id

    def _pad(self, sequences: list[list[int]]) -> Tensor:
        """Stack variable-length id lists into one padded (B, T) tensor."""
        tensors = [torch.tensor(seq) for seq in sequences]
        return nn.utils.rnn.pad_sequence(
            tensors,
            batch_first=True,
            padding_value=self.pad_token_id,
        )

    def __call__(self, batch: list[dict[str, list[int]]]) -> dict[str, Tensor]:
        src_sequences = [example["src_ids"] for example in batch]
        tgt_sequences = [example["tgt_ids"] for example in batch]

        # Teacher forcing: decoder sees tgt[:-1] and learns to predict tgt[1:].
        decoder_inputs = [seq[:-1] for seq in tgt_sequences]
        labels = [seq[1:] for seq in tgt_sequences]

        src_batch = self._pad(src_sequences)
        decoder_batch = self._pad(decoder_inputs)
        label_batch = self._pad(labels)

        # Padded target length for the square look-ahead mask.
        tgt_len = decoder_batch.shape[1]

        src_mask = utils.create_padding_mask(src_batch, self.pad_token_id)

        # Combine padding mask with the causal mask so the decoder can
        # attend neither to padding nor to future positions.
        tgt_mask = utils.create_padding_mask(
            decoder_batch, self.pad_token_id
        ) & utils.create_look_ahead_mask(tgt_len)

        return {
            "src_ids": src_batch,
            "tgt_input_ids": decoder_batch,
            "labels": label_batch,
            "src_mask": src_mask,
            "tgt_mask": tgt_mask,
        }
|
|
|
|
|
|
|
|
def get_translation_datasets(
    tokenizer: PreTrainedTokenizerFast,
) -> tuple[TranslationDataset, TranslationDataset, TranslationDataset]:
    """
    A Factory function to automate the data pipeline setup.

    It performs 3 steps:
    1. Loads and cleans raw data (using src.utils).
    2. Instantiates the TranslationDataset for Train, Val, and Test splits.
    3. Returns the 3 PyTorch datasets ready for the DataLoader.

    Args:
        tokenizer: The trained tokenizer.

    Returns:
        Tuple containing (train_ds, val_ds, test_ds)
    """

    train_data, val_data, test_data = utils.get_raw_data(
        config.DATA_PATH, num_workers=config.NUM_WORKERS
    )

    # Clamp the sample cap to the split size: select(range(n)) raises
    # if n exceeds the number of rows in the dataset.
    num_samples = min(config.NUM_SAMPLES_TO_USE, len(train_data))
    train_data = train_data.select(range(num_samples))

    print("Building PyTorch Datasets...")

    def _build(split: ArrowDataset) -> TranslationDataset:
        # All splits share the same tokenizer and length budgets.
        return TranslationDataset(
            dataset=split,
            tokenizer=tokenizer,
            max_len_src=config.MAX_SEQ_LEN,
            max_len_tgt=config.MAX_SEQ_LEN,
        )

    train_ds = _build(train_data)
    val_ds = _build(val_data)
    test_ds = _build(test_data)

    print(
        f"Datasets created: Train={len(train_ds)}, Val={len(val_ds)}, Test={len(test_ds)}"
    )

    return train_ds, val_ds, test_ds
|
|
|
|
|
|
|
|
def get_dataloaders(
    tokenizer: PreTrainedTokenizerFast,
) -> tuple[DataLoader, DataLoader, DataLoader]:
    """
    A high-level Factory function to create DataLoaders.

    This function abstracts away all the data pipeline complexity:
    - Loading/Cleaning raw data
    - Creating PyTorch Datasets
    - Instantiating the DataCollator (dynamic padding)
    - Creating DataLoaders with the correct batch size and workers

    Args:
        tokenizer: The trained tokenizer.

    Returns:
        Tuple containing (train_loader, val_loader, test_loader)
    """

    train_ds, val_ds, test_ds = get_translation_datasets(tokenizer)

    collator = DataCollator(pad_token_id=config.PAD_TOKEN_ID)

    print(
        f"Building DataLoaders (Batch Size: {config.BATCH_SIZE}, Workers: {config.NUM_WORKERS})..."
    )

    # Pinned host memory only helps (and is only meaningful) on CUDA.
    pin_memory = config.DEVICE == "cuda"

    def _mp_kwargs(num_workers: int, persistent: bool) -> dict:
        # prefetch_factor and persistent_workers are only valid with
        # multiprocessing workers; DataLoader raises a ValueError if
        # either is passed while num_workers == 0.
        if num_workers > 0:
            return {"prefetch_factor": 2, "persistent_workers": persistent}
        return {}

    train_loader = DataLoader(
        train_ds,
        batch_size=config.BATCH_SIZE,
        shuffle=True,
        num_workers=config.NUM_WORKERS,
        collate_fn=collator,
        pin_memory=pin_memory,
        **_mp_kwargs(config.NUM_WORKERS, persistent=True),
    )

    # Eval loaders double the batch size: no gradients are stored,
    # so more samples fit in memory per step.
    val_loader = DataLoader(
        val_ds,
        batch_size=2 * config.BATCH_SIZE,
        shuffle=False,
        num_workers=config.NUM_WORKERS,
        collate_fn=collator,
        pin_memory=pin_memory,
        **_mp_kwargs(config.NUM_WORKERS, persistent=True),
    )

    # NOTE(review): the test loader hard-codes 2 workers instead of
    # config.NUM_WORKERS and skips persistent_workers (it is consumed
    # once) — presumably deliberate; confirm before unifying.
    test_loader = DataLoader(
        test_ds,
        batch_size=2 * config.BATCH_SIZE,
        shuffle=False,
        num_workers=2,
        collate_fn=collator,
        pin_memory=pin_memory,
        **_mp_kwargs(2, persistent=False),
    )

    print(f"DataLoader (train) created with {len(train_loader)} batches.")
    print(f"DataLoader (val) created with {len(val_loader)} batches.")
    print(f"DataLoader (test) created with {len(test_loader)} batches.")

    return train_loader, val_loader, test_loader
|
|
|