import torch
import torch.nn as nn
from torch import Tensor
from torch.utils.data import Dataset, DataLoader
from datasets import Dataset as ArrowDataset
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
import config
from src import utils
class TranslationDataset(Dataset):
"""
A "lazy" Dataset.
Uses the high-level PreTrainedTokenizerFast wrapper.
"""
def __init__(
self,
dataset: ArrowDataset,
tokenizer: PreTrainedTokenizerFast,
max_len_src: int,
max_len_tgt: int,
src_lang: str = "en",
tgt_lang: str = "vi",
):
super().__init__()
self.dataset = dataset
self.tokenizer = tokenizer
self.max_len_src = max_len_src
self.max_len_tgt = max_len_tgt
self.src_lang = src_lang
self.tgt_lang = tgt_lang
def __len__(self) -> int:
return len(self.dataset)
def __getitem__(self, index: int) -> dict[str, list[int]]:
item = self.dataset[index]["translation"]
src_text = item[self.src_lang]
tgt_text = item[self.tgt_lang]
# We set add_special_tokens=False for manual control.
src_encoding = self.tokenizer(
src_text,
truncation=True,
max_length=self.max_len_src,
add_special_tokens=False, # (Source has no SOS/EOS)
)
tgt_encoding = self.tokenizer(
tgt_text,
truncation=True,
max_length=self.max_len_tgt - 2, # (Reserve 2 spots for SOS/EOS)
add_special_tokens=False,
)
# Manually add SOS/EOS to target
src_ids = src_encoding["input_ids"]
tgt_ids = (
[config.SOS_TOKEN_ID] + tgt_encoding["input_ids"] + [config.EOS_TOKEN_ID]
)
return {"src_ids": src_ids, "tgt_ids": tgt_ids}
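# Illustrative example (token IDs below are made up, not from a real tokenizer):
# for a pair like {"en": "hello world", "vi": "xin chao the gioi"}, __getitem__
# returns roughly
#   {"src_ids": [17, 942], "tgt_ids": [config.SOS_TOKEN_ID, 53, 1208, 77, config.EOS_TOKEN_ID]}
# i.e. the source carries no special tokens, the target is wrapped in SOS/EOS,
# and padding is deferred to the DataCollator below.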
class DataCollator:
"""
Implements a custom collate_fn.
1. Takes a list of dicts (from __getitem__); SOS/EOS were already added to the
target by the Dataset, so no special tokens are added here.
2. Creates decoder inputs and labels (shifted by one position)
3. Dynamically pads all sequences *in the batch*
4. Builds the source padding mask and the combined target mask (padding & look-ahead)
5. Returns a single dict of tensors
"""
def __init__(self, pad_token_id: int):
self.pad_token_id = pad_token_id
def __call__(self, batch: list[dict[str, list[int]]]) -> dict[str, Tensor]:
# 1. Get raw ID lists from the batch
src_ids_list = [item["src_ids"] for item in batch]
tgt_ids_list = [item["tgt_ids"] for item in batch] # (Already has SOS/EOS)
# 2. Create shifted inputs/labels
# Decoder input (T_tgt): [SOS, w1, w2, w3]
dec_input_ids_list = [ids[:-1] for ids in tgt_ids_list]
# Label (T_tgt): [w1, w2, w3, EOS]
labels_list = [ids[1:] for ids in tgt_ids_list]
# 3. Dynamic Padding
# We use torch.nn.utils.rnn.pad_sequence
# (Note: batch_first=True means (B, T))
src_ids_padded = nn.utils.rnn.pad_sequence(
[torch.tensor(ids) for ids in src_ids_list],
batch_first=True,
padding_value=self.pad_token_id,
)
dec_input_ids_padded = nn.utils.rnn.pad_sequence(
[torch.tensor(ids) for ids in dec_input_ids_list],
batch_first=True,
padding_value=self.pad_token_id,
)
labels_padded = nn.utils.rnn.pad_sequence(
[torch.tensor(ids) for ids in labels_list],
batch_first=True,
padding_value=self.pad_token_id, # (Loss will ignore this ID)
)
# 4. Get the sequence length
_, T_tgt = dec_input_ids_padded.shape
# 5. Create Masks (on CPU)
# (Mask 1) Source padding mask (for Encoder MHA & Cross-Attn)
# Shape: (B, 1, 1, T_src)
src_mask = utils.create_padding_mask(src_ids_padded, self.pad_token_id)
# (Mask 2) Target padding mask (for Decoder MHA)
# Shape: (B, 1, 1, T_tgt)
tgt_padding_mask = utils.create_padding_mask(
dec_input_ids_padded, self.pad_token_id
)
# (Mask 3) Target look-ahead mask (for Decoder MHA)
# Shape: (1, 1, T_tgt, T_tgt)
look_ahead_mask = utils.create_look_ahead_mask(T_tgt)
# (Mask 4) Combined target mask
# Shape: (B, 1, T_tgt, T_tgt)
tgt_mask = tgt_padding_mask & look_ahead_mask
return {
"src_ids": src_ids_padded, # (B, T_src)
"tgt_input_ids": dec_input_ids_padded, # (B, T_tgt)
"labels": labels_padded, # (B, T_tgt)
"src_mask": src_mask, # (B, 1, 1, T_src)
"tgt_mask": tgt_mask, # (B, 1, T_tgt, T_tgt)
}
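# For reference, the mask helpers imported from src.utils are assumed to behave
# roughly like the sketch below (boolean masks, True = "may attend", broadcastable
# over the head dimension); the actual implementations live in src/utils.py:
#
#   def create_padding_mask(ids: Tensor, pad_token_id: int) -> Tensor:
#       # (B, T) -> (B, 1, 1, T); True where the token is NOT padding
#       return (ids != pad_token_id).unsqueeze(1).unsqueeze(2)
#
#   def create_look_ahead_mask(size: int) -> Tensor:
#       # (1, 1, T, T); True on and below the diagonal, so position i may only
#       # attend to positions <= i
#       return torch.tril(torch.ones(size, size, dtype=torch.bool)).unsqueeze(0).unsqueeze(0)
#
# With these shapes, tgt_padding_mask & look_ahead_mask broadcasts to (B, 1, T_tgt, T_tgt).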
def get_translation_datasets(
tokenizer: PreTrainedTokenizerFast,
) -> tuple[TranslationDataset, TranslationDataset, TranslationDataset]:
"""
A Factory function to automate the data pipeline setup.
It performs 3 steps:
1. Loads and cleans raw data (using src.utils).
2. Instantiates the TranslationDataset for Train, Val, and Test splits.
3. Returns the 3 PyTorch datasets ready for the DataLoader.
Args:
tokenizer: The trained tokenizer.
Returns:
Tuple containing (train_ds, val_ds, test_ds)
"""
# 1. Load the raw, cleaned data (returns the train/val/test splits)
# This keeps train.py free of raw data handling logic.
train_data, val_data, test_data = utils.get_raw_data(
config.DATA_PATH, num_workers=config.NUM_WORKERS
)
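# Keep only the first config.NUM_SAMPLES_TO_USE training examples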
train_data = train_data.select(range(config.NUM_SAMPLES_TO_USE))
print(f"Building PyTorch Datasets...")
# 2. Instantiate the Train Dataset
# (Uses global config for max_length)
train_ds = TranslationDataset(
dataset=train_data,
tokenizer=tokenizer,
max_len_src=config.MAX_SEQ_LEN,
max_len_tgt=config.MAX_SEQ_LEN,
)
# 3. Instantiate the Validation Dataset
val_ds = TranslationDataset(
dataset=val_data,
tokenizer=tokenizer,
max_len_src=config.MAX_SEQ_LEN,
max_len_tgt=config.MAX_SEQ_LEN,
)
# 4. Instantiate the Test Dataset
test_ds = TranslationDataset(
dataset=test_data,
tokenizer=tokenizer,
max_len_src=config.MAX_SEQ_LEN,
max_len_tgt=config.MAX_SEQ_LEN,
)
print(
f"Datasets created: Train={len(train_ds)}, Val={len(val_ds)}, Test={len(test_ds)}"
)
return train_ds, val_ds, test_ds
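# Usage sketch (illustrative): these datasets pair directly with the DataCollator above, e.g.
#   train_ds, val_ds, test_ds = get_translation_datasets(tokenizer)
#   collator = DataCollator(pad_token_id=config.PAD_TOKEN_ID)
#   batch = collator([train_ds[i] for i in range(4)])  # dict of padded tensors + masks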
def get_dataloaders(
tokenizer: PreTrainedTokenizerFast,
) -> tuple[DataLoader, DataLoader, DataLoader]:
"""
A high-level Factory function to create DataLoaders.
This function abstracts away all the data pipeline complexity:
- Loading/Cleaning raw data
- Creating PyTorch Datasets
- Instantiating the DataCollator (dynamic padding)
- Creating DataLoaders with the correct batch size and workers
Args:
tokenizer: The trained tokenizer.
Returns:
Tuple containing (train_loader, val_loader, test_loader)
"""
# 1. Create the Datasets (using the factory function we made earlier)
train_ds, val_ds, test_ds = get_translation_datasets(tokenizer)
# 2. Instantiate the Collator
# (We need config to get PAD_TOKEN_ID)
collator = DataCollator(pad_token_id=config.PAD_TOKEN_ID)
print(
f"Building DataLoaders (Batch Size: {config.BATCH_SIZE}, Workers: {config.NUM_WORKERS})..."
)
# 3. Create Train DataLoader
# (Shuffle = True is CRITICAL for training)
train_loader = DataLoader(
train_ds,
batch_size=config.BATCH_SIZE,
shuffle=True,
num_workers=config.NUM_WORKERS,
collate_fn=collator,
pin_memory=True if config.DEVICE == "cuda" else False, # (Optimization)
prefetch_factor=2,
persistent_workers=True,
)
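# Note: prefetch_factor and persistent_workers require num_workers > 0.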
# 4. Create Validation DataLoader
# (Shuffle = False for reproducible validation)
val_loader = DataLoader(
val_ds,
batch_size=2 * config.BATCH_SIZE,
shuffle=False,
num_workers=config.NUM_WORKERS,
collate_fn=collator,
pin_memory=True if config.DEVICE == "cuda" else False,
prefetch_factor=2,
persistent_workers=True,
)
# 5. Create Test DataLoader
test_loader = DataLoader(
test_ds,
batch_size=2 * config.BATCH_SIZE,
shuffle=False,
num_workers=2,
collate_fn=collator,
pin_memory=True if config.DEVICE == "cuda" else False,
prefetch_factor=2,
)
print(f"DataLoader (train) created with {len(train_loader)} batches.")
print(f"DataLoader (val) created with {len(val_loader)} batches.")
print(f"DataLoader (test) created with {len(test_loader)} batches.")
return train_loader, val_loader, test_loader
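if __name__ == "__main__":
    # Minimal smoke test (illustrative; assumes config.TOKENIZER_PATH points to a
    # tokenizer previously saved with save_pretrained()):
    tok = PreTrainedTokenizerFast.from_pretrained(config.TOKENIZER_PATH)
    train_loader, val_loader, test_loader = get_dataloaders(tok)
    batch = next(iter(train_loader))
    for name, tensor in batch.items():
        print(name, tuple(tensor.shape), tensor.dtype)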