"""Data utilities for tokenization and dataset loading."""

from pathlib import Path

from datasets import DatasetDict
from transformers import PreTrainedTokenizerFast

from config import (
    CACHE_DIR,
    MAX_INPUT,
    MAX_TARGET,
    SPECIAL_TOKENS,
    TOKENIZER_NAME,
)

# Label value skipped by torch.nn.CrossEntropyLoss (its default
# ``ignore_index``), so padded label positions do not contribute to the loss.
IGNORE_INDEX = -100

# Local path tried when the tokenizer cannot be loaded from the Hub.
LOCAL_TOKENIZER_FALLBACK = "./tokenizer.json"


def _load_local_tokenizer(path):
    """Load a tokenizer from a local ``tokenizer.json`` file or saved directory.

    Args:
        path: Path to a serialized ``tokenizer.json`` file, or to a directory
            previously written by ``save_pretrained``.

    Returns:
        A ``PreTrainedTokenizerFast`` instance.
    """
    local = Path(path)
    if local.is_file():
        # ``from_pretrained`` expects a repo id or a directory; a bare
        # tokenizer.json file has to be passed via ``tokenizer_file``.
        return PreTrainedTokenizerFast(tokenizer_file=str(local))
    return PreTrainedTokenizerFast.from_pretrained(str(local))


def load_tokenizer(local_fallback=LOCAL_TOKENIZER_FALLBACK):
    """Load the tokenizer from the Hugging Face Hub, falling back to a local copy.

    Args:
        local_fallback: Local ``tokenizer.json`` file or tokenizer directory
            used when loading ``TOKENIZER_NAME`` from the Hub raises.

    Returns:
        The tokenizer with ``SPECIAL_TOKENS`` registered.

    Raises:
        Whatever the local fallback load raises if both attempts fail. The Hub
        error is kept as the implicit exception context.
    """
    print(f"Loading tokenizer from {TOKENIZER_NAME}...")
    try:
        tokenizer = PreTrainedTokenizerFast.from_pretrained(TOKENIZER_NAME)
    except Exception as e:
        print(f"Error loading from Hub, trying local fallback: {e}")
        tokenizer = _load_local_tokenizer(local_fallback)

    # NOTE(review): if this adds tokens that were not already in the vocab,
    # the model presumably needs ``resize_token_embeddings`` -- confirm that
    # the training code does this.
    tokenizer.add_special_tokens(SPECIAL_TOKENS)
    return tokenizer


# Loaded once at import time; the other helpers in this module share it.
tokenizer = load_tokenizer()


def tokenize_batch(batch, ignore_pad_token_for_loss=True):
    """Tokenize a batch of source/target pairs for seq2seq training.

    Args:
        batch: Mapping with list-valued ``"source"`` and ``"target"`` columns,
            as passed by ``Dataset.map(..., batched=True)``.
        ignore_pad_token_for_loss: When True, label positions that are padding
            are set to ``IGNORE_INDEX`` (-100) so they are excluded from the
            loss.

    Returns:
        The source encoding (``input_ids``, ``attention_mask``, ...), padded or
        truncated to ``MAX_INPUT``, plus a ``"labels"`` entry padded or
        truncated to ``MAX_TARGET``.
    """
    model_inputs = tokenizer(
        batch["source"],
        truncation=True,
        max_length=MAX_INPUT,
        padding="max_length",
    )

    # ``text_target`` replaces the deprecated ``as_target_tokenizer()`` context.
    labels = tokenizer(
        text_target=batch["target"],
        truncation=True,
        max_length=MAX_TARGET,
        padding="max_length",
        return_attention_mask=True,
    )

    label_ids = labels["input_ids"]
    if ignore_pad_token_for_loss:
        # Find padding with the attention mask instead of comparing ids to
        # pad_token_id. This keeps real tokens (e.g. EOS) intact when the pad
        # token shares their id.
        label_ids = [
            [tok if keep else IGNORE_INDEX for tok, keep in zip(ids, mask)]
            for ids, mask in zip(label_ids, labels["attention_mask"])
        ]

    model_inputs["labels"] = label_ids
    return model_inputs


def load_tokenized(name: str):
    """Load a cached raw ``DatasetDict`` from disk and tokenize every split.

    Args:
        name: Sub-directory of ``CACHE_DIR`` holding the saved dataset.

    Returns:
        A ``DatasetDict`` whose ``"source"`` and ``"target"`` columns are
        replaced by the outputs of ``tokenize_batch``.
    """
    print(f"Loading {name} dataset from cache...")
    raw = DatasetDict.load_from_disk(str(CACHE_DIR / name))

    print(f"Tokenizing {name} dataset...")
    tokenized = raw.map(
        tokenize_batch,
        batched=True,
        remove_columns=["source", "target"],
        desc=f"Tokenizing {name}",
    )
    print(f"{name} dataset tokenization complete!")
    return tokenized


def get_tokenizer():
    """Return the module-level tokenizer loaded at import time."""
    return tokenizer