import torch
import torch.nn as nn
from torch import Tensor
from torch.utils.data import Dataset, DataLoader
from datasets import Dataset as ArrowDataset
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
import config
from src import utils
class TranslationDataset(Dataset):
"""
A "lazy" Dataset.
Uses the high-level PreTrainedTokenizerFast wrapper.
"""
def __init__(
self,
dataset: ArrowDataset,
tokenizer: PreTrainedTokenizerFast,
max_len_src: int,
max_len_tgt: int,
src_lang: str = "en",
tgt_lang: str = "vi",
):
super().__init__()
self.dataset = dataset
self.tokenizer = tokenizer
self.max_len_src = max_len_src
self.max_len_tgt = max_len_tgt
self.src_lang = src_lang
self.tgt_lang = tgt_lang
def __len__(self) -> int:
return len(self.dataset)
def __getitem__(self, index: int) -> dict[str, list[int]]:
item = self.dataset[index]["translation"]
src_text = item[self.src_lang]
tgt_text = item[self.tgt_lang]
# We set add_special_tokens=False for manual control.
src_encoding = self.tokenizer(
src_text,
truncation=True,
max_length=self.max_len_src,
add_special_tokens=False, # (Source has no SOS/EOS)
)
tgt_encoding = self.tokenizer(
tgt_text,
truncation=True,
max_length=self.max_len_tgt - 2, # (Reserve 2 spots for SOS/EOS)
add_special_tokens=False,
)
# Manually add SOS/EOS to target
src_ids = src_encoding["input_ids"]
tgt_ids = (
[config.SOS_TOKEN_ID] + tgt_encoding["input_ids"] + [config.EOS_TOKEN_ID]
)
return {"src_ids": src_ids, "tgt_ids": tgt_ids}
class DataCollator:
"""
Implements a custom collate_fn.
1. Takes a list of dicts (from __getitem__)
2. Adds SOS/EOS (Wait, we did this in Dataset)
3. Creates decoder inputs and labels (shifted)
4. Dynamically pads all sequences *in the batch*
5. Creates all 3 required masks
6. Returns a single dict of tensors
"""
def __init__(self, pad_token_id: int):
self.pad_token_id = pad_token_id
def __call__(self, batch: list[dict[str, list[int]]]) -> dict[str, Tensor]:
# 1. Get raw ID lists from the batch
src_ids_list = [item["src_ids"] for item in batch]
tgt_ids_list = [item["tgt_ids"] for item in batch] # (Already has SOS/EOS)
# 2. Create shifted inputs/labels
# Decoder input (T_tgt): [SOS, w1, w2, w3]
dec_input_ids_list = [ids[:-1] for ids in tgt_ids_list]
# Label (T_tgt): [w1, w2, w3, EOS]
labels_list = [ids[1:] for ids in tgt_ids_list]
# 3. Dynamic Padding
# We use torch.nn.utils.rnn.pad_sequence
# (Note: batch_first=True means (B, T))
src_ids_padded = nn.utils.rnn.pad_sequence(
[torch.tensor(ids) for ids in src_ids_list],
batch_first=True,
padding_value=self.pad_token_id,
)
dec_input_ids_padded = nn.utils.rnn.pad_sequence(
[torch.tensor(ids) for ids in dec_input_ids_list],
batch_first=True,
padding_value=self.pad_token_id,
)
labels_padded = nn.utils.rnn.pad_sequence(
[torch.tensor(ids) for ids in labels_list],
batch_first=True,
padding_value=self.pad_token_id, # (Loss will ignore this ID)
)
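        # Illustration (hypothetical IDs, pad_token_id=0): padding
        # [[5, 6, 7], [8, 9]] yields tensor([[5, 6, 7], [8, 9, 0]]),
        # i.e. every sequence grows to the longest one in *this* batch.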
# 4. Get the sequence length
_, T_tgt = dec_input_ids_padded.shape
# 5. Create Masks (on CPU)
# (Mask 1) Source padding mask (for Encoder MHA & Cross-Attn)
# Shape: (B, 1, 1, T_src)
src_mask = utils.create_padding_mask(src_ids_padded, self.pad_token_id)
# (Mask 2) Target padding mask (for Decoder MHA)
# Shape: (B, 1, 1, T_tgt)
tgt_padding_mask = utils.create_padding_mask(
dec_input_ids_padded, self.pad_token_id
)
# (Mask 3) Target look-ahead mask (for Decoder MHA)
# Shape: (1, 1, T_tgt, T_tgt)
look_ahead_mask = utils.create_look_ahead_mask(T_tgt)
# (Mask 4) Combined target mask
# Shape: (B, 1, T_tgt, T_tgt)
tgt_mask = tgt_padding_mask & look_ahead_mask
return {
"src_ids": src_ids_padded, # (B, T_src)
"tgt_input_ids": dec_input_ids_padded, # (B, T_tgt)
"labels": labels_padded, # (B, T_tgt)
"src_mask": src_mask, # (B, 1, 1, T_src)
"tgt_mask": tgt_mask, # (B, 1, T_tgt, T_tgt)
}
def get_translation_datasets(
tokenizer: PreTrainedTokenizerFast,
) -> tuple[TranslationDataset, TranslationDataset, TranslationDataset]:
"""
    A factory function that automates the data pipeline setup.
It performs 3 steps:
1. Loads and cleans raw data (using src.utils).
2. Instantiates the TranslationDataset for Train, Val, and Test splits.
3. Returns the 3 PyTorch datasets ready for the DataLoader.
Args:
tokenizer: The trained tokenizer.
Returns:
Tuple containing (train_ds, val_ds, test_ds)
"""
    # 1. Load raw cleaned data (returns a (train, val, test) tuple of Datasets)
    # This keeps train.py free of raw-data handling logic.
train_data, val_data, test_data = utils.get_raw_data(
config.DATA_PATH, num_workers=config.NUM_WORKERS
)
    # (Optionally subsample the training split for faster experimentation)
    train_data = train_data.select(range(config.NUM_SAMPLES_TO_USE))
print(f"Building PyTorch Datasets...")
# 2. Instantiate the Train Dataset
# (Uses global config for max_length)
train_ds = TranslationDataset(
dataset=train_data,
tokenizer=tokenizer,
max_len_src=config.MAX_SEQ_LEN,
max_len_tgt=config.MAX_SEQ_LEN,
)
# 3. Instantiate the Validation Dataset
val_ds = TranslationDataset(
dataset=val_data,
tokenizer=tokenizer,
max_len_src=config.MAX_SEQ_LEN,
max_len_tgt=config.MAX_SEQ_LEN,
)
# 4. Instantiate the Test Dataset
test_ds = TranslationDataset(
dataset=test_data,
tokenizer=tokenizer,
max_len_src=config.MAX_SEQ_LEN,
max_len_tgt=config.MAX_SEQ_LEN,
)
print(
f"Datasets created: Train={len(train_ds)}, Val={len(val_ds)}, Test={len(test_ds)}"
)
return train_ds, val_ds, test_ds
def get_dataloaders(
tokenizer: PreTrainedTokenizerFast,
) -> tuple[DataLoader, DataLoader, DataLoader]:
"""
    A high-level factory function that creates the three DataLoaders.
This function abstracts away all the data pipeline complexity:
- Loading/Cleaning raw data
- Creating PyTorch Datasets
- Instantiating the DataCollator (dynamic padding)
- Creating DataLoaders with the correct batch size and workers
Args:
tokenizer: The trained tokenizer.
Returns:
Tuple containing (train_loader, val_loader, test_loader)
"""
    # 1. Create the Datasets (using the factory function defined above)
train_ds, val_ds, test_ds = get_translation_datasets(tokenizer)
# 2. Instantiate the Collator
# (We need config to get PAD_TOKEN_ID)
collator = DataCollator(pad_token_id=config.PAD_TOKEN_ID)
print(
f"Building DataLoaders (Batch Size: {config.BATCH_SIZE}, Workers: {config.NUM_WORKERS})..."
)
# 3. Create Train DataLoader
    # (shuffle=True is CRITICAL for training)
train_loader = DataLoader(
train_ds,
batch_size=config.BATCH_SIZE,
shuffle=True,
num_workers=config.NUM_WORKERS,
collate_fn=collator,
        pin_memory=(config.DEVICE == "cuda"),  # (Optimization: faster host-to-GPU copies)
        prefetch_factor=2,  # (prefetch_factor/persistent_workers need num_workers > 0)
        persistent_workers=True,
)
    # 4. Create Validation DataLoader
    # (shuffle=False for reproducible validation; evaluation runs without
    # gradients, so a larger batch size is typically affordable)
val_loader = DataLoader(
val_ds,
batch_size=2 * config.BATCH_SIZE,
shuffle=False,
num_workers=config.NUM_WORKERS,
collate_fn=collator,
        pin_memory=(config.DEVICE == "cuda"),
prefetch_factor=2,
persistent_workers=True,
)
# 5. Create Test DataLoader
test_loader = DataLoader(
test_ds,
batch_size=2 * config.BATCH_SIZE,
shuffle=False,
        num_workers=2,
collate_fn=collator,
        pin_memory=(config.DEVICE == "cuda"),
prefetch_factor=2,
)
print(f"DataLoader (train) created with {len(train_loader)} batches.")
print(f"DataLoader (val) created with {len(val_loader)} batches.")
print(f"DataLoader (test) created with {len(test_loader)} batches.")
return train_loader, val_loader, test_loader
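

if __name__ == "__main__":
    # Minimal smoke test: a sketch, not part of the training pipeline.
    # Assumes the trained tokenizer was saved with save_pretrained() at a
    # hypothetical config.TOKENIZER_PATH -- point this at your own path.
    tokenizer = PreTrainedTokenizerFast.from_pretrained(config.TOKENIZER_PATH)
    train_loader, _, _ = get_dataloaders(tokenizer)
    batch = next(iter(train_loader))
    for name, tensor in batch.items():
        print(f"{name}: shape={tuple(tensor.shape)}, dtype={tensor.dtype}")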