"""
Data utilities for tokenization and dataset loading.
"""
from transformers import PreTrainedTokenizerFast
from datasets import DatasetDict
from config import MAX_INPUT, MAX_TARGET, TOKENIZER_NAME, SPECIAL_TOKENS, CACHE_DIR


def load_tokenizer():
    """Load tokenizer from Hugging Face Hub."""
    print(f"Loading tokenizer from {TOKENIZER_NAME}...")
    try:
        tokenizer = PreTrainedTokenizerFast.from_pretrained(TOKENIZER_NAME)
    except Exception as e:
        print(f"Error loading from Hub, trying local fallback: {e}")
        # from_pretrained() expects a Hub repo id or model directory, not a raw
        # tokenizer.json file; load the local file via the tokenizer_file argument.
        tokenizer = PreTrainedTokenizerFast(tokenizer_file="./tokenizer.json")
    # Add special tokens
    tokenizer.add_special_tokens(SPECIAL_TOKENS)
    return tokenizer


tokenizer = load_tokenizer()


def tokenize_batch(batch):
    """Tokenize a batch of examples."""
    model_inputs = tokenizer(
        batch["source"],
        truncation=True,
        max_length=MAX_INPUT,
        padding="max_length",
    )
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(
            batch["target"],
            truncation=True,
            max_length=MAX_TARGET,
            padding="max_length",
        )
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs
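
# Note on the labels pass in the function above: `as_target_tokenizer()` is
# deprecated in recent transformers releases. A sketch of the equivalent call
# using the newer `text_target` argument (assuming transformers >= 4.22; not
# verified against this project's pinned version) would be:
#
#     labels = tokenizer(
#         text_target=batch["target"],
#         truncation=True,
#         max_length=MAX_TARGET,
#         padding="max_length",
#     )
#     model_inputs["labels"] = labels["input_ids"]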


def load_tokenized(name: str):
    """Load and tokenize dataset."""
    print(f"Loading {name} dataset from cache...")
    raw = DatasetDict.load_from_disk(str(CACHE_DIR / name))
    print(f"Tokenizing {name} dataset...")
    tokenized = raw.map(
        tokenize_batch,
        batched=True,
        remove_columns=["source", "target"],
        desc=f"Tokenizing {name}",
    )
    print(f"{name} dataset tokenization complete!")
    return tokenized


def get_tokenizer():
    """Get the loaded tokenizer."""
    return tokenizer
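

# Minimal usage sketch (illustrative only). Assumes a DatasetDict with "source"
# and "target" text columns has already been saved under CACHE_DIR by a separate
# preparation step; the dataset name "train" below is hypothetical.
if __name__ == "__main__":
    tok = get_tokenizer()
    print(f"Tokenizer loaded, vocabulary size (incl. added tokens): {len(tok)}")
    # tokenized = load_tokenized("train")  # uncomment once the cached dataset exists
    # print(tokenized)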