""" data_loader.py ────────────── Handles all dataset loading, validation splitting, preprocessing and tokenisation. AG News label scheme: 0 = World 1 = Sports 2 = Business 3 = Sci/Tech """ import logging from typing import List, Optional, Tuple from datasets import load_dataset, DatasetDict from transformers import AutoTokenizer, PreTrainedTokenizerBase from config import CFG logging.basicConfig( level=logging.INFO, format="%(asctime)s %(levelname)-8s %(message)s", datefmt="%H:%M:%S", ) logger = logging.getLogger(__name__) # ── Public API ──────────────────────────────────────────────────────────────── def load_ag_news( max_train: Optional[int] = CFG.max_train_samples, max_eval: Optional[int] = CFG.max_eval_samples, max_test: Optional[int] = CFG.max_test_samples, ) -> DatasetDict: """ Load AG News from the HuggingFace datasets cache (downloads on first call). AG News ships with 'train' (120 K) and 'test' (7.6 K) only. We carve out a stratified 10 % of 'train' as the validation set. Returns ------- DatasetDict with splits: 'train', 'validation', 'test' """ logger.info("Loading AG News dataset …") raw = load_dataset("ag_news") # Stratified 90/10 train → train + validation tv = raw["train"].train_test_split( test_size=0.10, seed=CFG.seed, stratify_by_column="label", ) dataset = DatasetDict({ "train": tv["train"], "validation": tv["test"], "test": raw["test"], }) # Optional down-sampling (speeds up CPU training significantly) if max_train is not None: n = min(max_train, len(dataset["train"])) dataset["train"] = ( dataset["train"].shuffle(seed=CFG.seed).select(range(n)) ) if max_eval is not None: n = min(max_eval, len(dataset["validation"])) dataset["validation"] = ( dataset["validation"].shuffle(seed=CFG.seed).select(range(n)) ) if max_test is not None: n = min(max_test, len(dataset["test"])) dataset["test"] = dataset["test"].select(range(n)) logger.info( f" train={len(dataset['train']):,} " f"val={len(dataset['validation']):,} " f"test={len(dataset['test']):,}" ) return dataset def load_test_only() -> Tuple[List[str], List[int]]: """ Load only the test split (fast, no stratified split overhead). Used by compare_results.py. """ raw = load_dataset("ag_news") return list(raw["test"]["text"]), list(raw["test"]["label"]) def get_raw_splits(dataset: DatasetDict) -> Tuple: """ Return plain Python lists of (texts, labels) for all three splits. Used by the scikit-learn traditional ML pipeline. """ X_train = list(dataset["train"]["text"]) y_train = list(dataset["train"]["label"]) X_val = list(dataset["validation"]["text"]) y_val = list(dataset["validation"]["label"]) X_test = list(dataset["test"]["text"]) y_test = list(dataset["test"]["label"]) return X_train, y_train, X_val, y_val, X_test, y_test def get_tokenizer() -> PreTrainedTokenizerBase: """Download (or load from local HuggingFace cache) the DistilBERT tokeniser.""" logger.info(f"Loading tokeniser: {CFG.model_checkpoint}") return AutoTokenizer.from_pretrained(CFG.model_checkpoint) def tokenise_dataset( dataset: DatasetDict, tokenizer: PreTrainedTokenizerBase, ) -> DatasetDict: """ Tokenise all splits for the HuggingFace Trainer. Design decisions: - padding=False → pads at collation time via DataCollatorWithPadding (more memory-efficient than padding all to max_length) - num_proc=1 → required on Windows; fork-based multi-processing causes issues with PyTorch on Windows """ def _tokenise(batch: dict) -> dict: return tokenizer( batch["text"], truncation=True, max_length=CFG.max_length, padding=False, ) logger.info("Tokenising dataset …") tokenised = dataset.map( _tokenise, batched=True, batch_size=1_000, num_proc=1, remove_columns=["text"], desc="Tokenising", ) # HuggingFace Trainer requires the label column to be named 'labels' tokenised = tokenised.rename_column("label", "labels") tokenised.set_format("torch", columns=["input_ids", "attention_mask", "labels"]) return tokenised