# ord-training-simple / src/data_utils.py
# Author: Vaishnav14220
# Commit 29a351f — Orchestrate full ORD training pipeline in Space
"""
Data utilities for tokenization and dataset loading.
"""
from transformers import PreTrainedTokenizerFast
from datasets import DatasetDict
from config import MAX_INPUT, MAX_TARGET, TOKENIZER_NAME, SPECIAL_TOKENS, CACHE_DIR
def load_tokenizer():
    """Load the tokenizer from the Hub, falling back to a local tokenizer.json.

    Returns:
        PreTrainedTokenizerFast: tokenizer with SPECIAL_TOKENS registered.
    """
    print(f"Loading tokenizer from {TOKENIZER_NAME}...")
    try:
        tokenizer = PreTrainedTokenizerFast.from_pretrained(TOKENIZER_NAME)
    except Exception as e:
        print(f"Error loading from Hub, trying local fallback: {e}")
        # from_pretrained() expects a directory or repo id, not a bare JSON
        # file, so the original call could never succeed as a fallback. A raw
        # tokenizer.json must be loaded via the tokenizer_file constructor kwarg.
        tokenizer = PreTrainedTokenizerFast(tokenizer_file="./tokenizer.json")
    # Register special tokens (a no-op for tokens already present).
    tokenizer.add_special_tokens(SPECIAL_TOKENS)
    return tokenizer
# Module-level singleton, loaded once at import time.
# NOTE(review): this is an import-time side effect (may hit the network for
# the Hub download) — importing this module requires the tokenizer to resolve.
tokenizer = load_tokenizer()
def tokenize_batch(batch):
    """Tokenize one batch of source/target pairs for seq2seq training.

    Sources are truncated/padded to MAX_INPUT tokens and targets to
    MAX_TARGET tokens; the target token ids are attached as "labels".
    """
    encoded = tokenizer(
        batch["source"],
        truncation=True,
        max_length=MAX_INPUT,
        padding="max_length",
    )
    # Switch the tokenizer into target mode for the label side.
    with tokenizer.as_target_tokenizer():
        target_encoding = tokenizer(
            batch["target"],
            truncation=True,
            max_length=MAX_TARGET,
            padding="max_length",
        )
    # NOTE(review): labels keep the pad token id; if the trainer does not mask
    # padding to -100, loss is computed on pad positions — verify the collator.
    encoded["labels"] = target_encoding["input_ids"]
    return encoded
def load_tokenized(name: str):
    """Load the cached DatasetDict called *name* and return its tokenized form.

    The raw "source"/"target" columns are dropped after tokenization.
    """
    print(f"Loading {name} dataset from cache...")
    dataset = DatasetDict.load_from_disk(str(CACHE_DIR / name))

    print(f"Tokenizing {name} dataset...")
    result = dataset.map(
        tokenize_batch,
        batched=True,
        remove_columns=["source", "target"],
        desc=f"Tokenizing {name}",
    )
    print(f"{name} dataset tokenization complete!")
    return result
def get_tokenizer():
    """Return the module-level tokenizer singleton loaded at import time."""
    return tokenizer