Spaces:
Running
Running
| """ | |
| data_loader.py | |
| ────────────── | |
| Handles all dataset loading, validation splitting, preprocessing and tokenisation. | |
| AG News label scheme: | |
| 0 = World 1 = Sports 2 = Business 3 = Sci/Tech | |
| """ | |
| import logging | |
| from typing import List, Optional, Tuple | |
| from datasets import load_dataset, DatasetDict | |
| from transformers import AutoTokenizer, PreTrainedTokenizerBase | |
| from config import CFG | |
# Configure the root logger as soon as this module is imported: INFO level,
# with a clock-time prefix and a left-aligned, 8-character level name.
# NOTE: basicConfig does nothing if the root logger already has handlers.
logging.basicConfig(
    datefmt="%H:%M:%S",
    format="%(asctime)s %(levelname)-8s %(message)s",
    level=logging.INFO,
)

# Module-scoped logger shared by every function in this file.
logger = logging.getLogger(__name__)
| # ── Public API ──────────────────────────────────────────────────────────────── | |
def _limit_split(split, limit, seed=None):
    """Return at most *limit* rows of *split*; shuffle first when *seed* is given."""
    if limit is None:
        return split
    n = min(limit, len(split))
    if seed is not None:
        split = split.shuffle(seed=seed)
    return split.select(range(n))


def load_ag_news(
    max_train: Optional[int] = CFG.max_train_samples,
    max_eval: Optional[int] = CFG.max_eval_samples,
    max_test: Optional[int] = CFG.max_test_samples,
) -> DatasetDict:
    """
    Load AG News from the HuggingFace datasets cache (downloads on first call).

    AG News ships with 'train' (120 K) and 'test' (7.6 K) only.
    We carve out a stratified 10 % of 'train' as the validation set.

    Parameters
    ----------
    max_train, max_eval, max_test
        Optional caps on the number of rows kept in the 'train',
        'validation' and 'test' splits. ``None`` keeps the full split.
        'train' and 'validation' are shuffled with ``CFG.seed`` before being
        truncated; 'test' keeps its original order and is truncated to its
        first ``max_test`` rows.

    Returns
    -------
    DatasetDict with splits: 'train', 'validation', 'test'

    Raises
    ------
    ValueError
        If any cap is negative (it would otherwise silently produce an
        empty split).
    """
    # Reject nonsensical caps before doing any (potentially slow) loading.
    for name, limit in (
        ("max_train", max_train),
        ("max_eval", max_eval),
        ("max_test", max_test),
    ):
        if limit is not None and limit < 0:
            raise ValueError(f"{name} must be non-negative or None, got {limit}")

    logger.info("Loading AG News dataset …")
    raw = load_dataset("ag_news")

    # Stratified 90/10 split of 'train' → train + validation
    tv = raw["train"].train_test_split(
        test_size=0.10,
        seed=CFG.seed,
        stratify_by_column="label",
    )

    # Optional down-sampling (speeds up CPU training significantly)
    dataset = DatasetDict({
        "train": _limit_split(tv["train"], max_train, seed=CFG.seed),
        "validation": _limit_split(tv["test"], max_eval, seed=CFG.seed),
        "test": _limit_split(raw["test"], max_test),
    })

    logger.info(
        f" train={len(dataset['train']):,} "
        f"val={len(dataset['validation']):,} "
        f"test={len(dataset['test']):,}"
    )
    return dataset
def load_test_only() -> Tuple[List[str], List[int]]:
    """
    Load only the test split (fast, no stratified split overhead).

    Used by compare_results.py.

    Returns
    -------
    (texts, labels)
        Two plain Python lists built from the 'text' and 'label' columns
        of the AG News 'test' split.
    """
    # Ask for the single split that is needed instead of materialising the
    # whole DatasetDict and indexing it twice.
    test_split = load_dataset("ag_news", split="test")
    return list(test_split["text"]), list(test_split["label"])
def get_raw_splits(
    dataset: DatasetDict,
) -> Tuple[List[str], List[int], List[str], List[int], List[str], List[int]]:
    """
    Return plain Python lists of (texts, labels) for all three splits.

    Used by the scikit-learn traditional ML pipeline.

    Parameters
    ----------
    dataset
        Mapping with 'train', 'validation' and 'test' entries, each exposing
        'text' and 'label' columns.

    Returns
    -------
    (X_train, y_train, X_val, y_val, X_test, y_test)
        Six lists, in that order.
    """
    def _texts_and_labels(split_name: str) -> Tuple[List[str], List[int]]:
        # Look the split up once, then copy both columns into plain lists.
        split = dataset[split_name]
        return list(split["text"]), list(split["label"])

    X_train, y_train = _texts_and_labels("train")
    X_val, y_val = _texts_and_labels("validation")
    X_test, y_test = _texts_and_labels("test")
    return X_train, y_train, X_val, y_val, X_test, y_test
def get_tokenizer() -> PreTrainedTokenizerBase:
    """
    Return the tokeniser for ``CFG.model_checkpoint``.

    The tokeniser is obtained through ``AutoTokenizer.from_pretrained``
    (downloaded, or loaded from the local HuggingFace cache). The original
    note describes it as the DistilBERT tokeniser; the actual model is
    whatever the configured checkpoint names.
    """
    checkpoint = CFG.model_checkpoint
    logger.info(f"Loading tokeniser: {checkpoint}")
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    return tokenizer
def tokenise_dataset(
    dataset: DatasetDict,
    tokenizer: PreTrainedTokenizerBase,
) -> DatasetDict:
    """
    Tokenise all splits for the HuggingFace Trainer.

    Design decisions:
    - padding=False → pads at collation time via DataCollatorWithPadding
      (more memory-efficient than padding all to max_length)
    - num_proc=1 → required on Windows; fork-based multi-processing
      causes issues with PyTorch on Windows

    Parameters
    ----------
    dataset
        Splits containing a 'text' column and a 'label' column.
    tokenizer
        Tokeniser applied to each batch of texts, truncating to
        ``CFG.max_length``.

    Returns
    -------
    DatasetDict whose 'text' column has been replaced by the tokeniser
    outputs, whose 'label' column is renamed to 'labels', and whose format
    is set to torch for 'input_ids', 'attention_mask' and 'labels'.
    """
    def _tokenise(batch: dict) -> dict:
        # Truncate only; padding is deferred to the data collator.
        return tokenizer(
            batch["text"],
            truncation=True,
            max_length=CFG.max_length,
            padding=False,
        )

    logger.info("Tokenising dataset …")
    tokenised = dataset.map(
        _tokenise,
        batched=True,
        batch_size=1_000,
        num_proc=1,
        remove_columns=["text"],
        desc="Tokenising",
    )

    # HuggingFace Trainer requires the label column to be named 'labels'
    tokenised = tokenised.rename_column("label", "labels")

    # set_format changes the format in place; only these columns are exposed.
    tokenised.set_format("torch", columns=["input_ids", "attention_mask", "labels"])
    return tokenised