""" |
|
|
Data utilities for tokenization and dataset loading. |
|
|
""" |
|
|
from transformers import PreTrainedTokenizerFast |
|
|
from datasets import DatasetDict |
|
|
from config import MAX_INPUT, MAX_TARGET, TOKENIZER_NAME, SPECIAL_TOKENS, CACHE_DIR |
|
|
|
|
|
|
|
|
def load_tokenizer(): |
|
|
"""Load tokenizer from Hugging Face Hub.""" |
|
|
print(f"Loading tokenizer from {TOKENIZER_NAME}...") |
|
|
try: |
|
|
tokenizer = PreTrainedTokenizerFast.from_pretrained(TOKENIZER_NAME) |
|
|
except Exception as e: |
|
|
print(f"Error loading from Hub, trying local fallback: {e}") |
|
|
tokenizer = PreTrainedTokenizerFast.from_pretrained("./tokenizer.json") |
|
|
|
|
|
|
|
|
tokenizer.add_special_tokens(SPECIAL_TOKENS) |
|
|
return tokenizer |
|
|
|
|
|
|
|
|
tokenizer = load_tokenizer() |
|
|
|
|
|
|
|
|
def tokenize_batch(batch): |
|
|
"""Tokenize a batch of examples.""" |
|
|
model_inputs = tokenizer( |
|
|
batch["source"], |
|
|
truncation=True, |
|
|
max_length=MAX_INPUT, |
|
|
padding="max_length", |
|
|
) |
|
|
|
|
|
with tokenizer.as_target_tokenizer(): |
|
|
labels = tokenizer( |
|
|
batch["target"], |
|
|
truncation=True, |
|
|
max_length=MAX_TARGET, |
|
|
padding="max_length", |
|
|
) |
|
|
|
|
|
model_inputs["labels"] = labels["input_ids"] |
|
|
return model_inputs |
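
# Hypothetical illustration of the expected batch schema (the "source" and
# "target" columns come from the cached datasets; the strings below are made
# up for illustration only):
#
#   tokenize_batch({"source": ["some input text"], "target": ["some target text"]})
#   -> {"input_ids": [[...]], "attention_mask": [[...]], "labels": [[...]]}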


def load_tokenized(name: str):
    """Load and tokenize dataset."""
    print(f"Loading {name} dataset from cache...")
    raw = DatasetDict.load_from_disk(str(CACHE_DIR / name))

    print(f"Tokenizing {name} dataset...")
    tokenized = raw.map(
        tokenize_batch,
        batched=True,
        remove_columns=["source", "target"],
        desc=f"Tokenizing {name}",
    )

    print(f"{name} dataset tokenization complete!")
    return tokenized


def get_tokenizer():
    """Get the loaded tokenizer."""
    return tokenizer
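

# Minimal smoke test, not part of the original module and kept out of the
# import path. It only reports on the tokenizer loaded above; optionally pass
# the name of a cached dataset on the command line to run end-to-end
# tokenization (the dataset name is whatever exists under CACHE_DIR in your
# setup, not something this module defines).
if __name__ == "__main__":
    import sys

    tok = get_tokenizer()
    print(f"Tokenizer ready ({len(tok)} tokens, including special tokens).")

    if len(sys.argv) > 1:
        ds = load_tokenized(sys.argv[1])
        print(ds)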