Spaces:
Running
Running
File size: 4,700 Bytes
a229747 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 | """
data_loader.py
──────────────
Handles all dataset loading, validation splitting, preprocessing and tokenisation.
AG News label scheme:
0 = World 1 = Sports 2 = Business 3 = Sci/Tech
"""
import logging
from typing import List, Optional, Tuple
from datasets import load_dataset, DatasetDict
from transformers import AutoTokenizer, PreTrainedTokenizerBase
from config import CFG
# Configure logging as soon as the module is imported: INFO level, with each
# record rendered as "<HH:MM:SS> <LEVEL, padded to 8> <message>".
logging.basicConfig(
    datefmt="%H:%M:%S",
    format="%(asctime)s %(levelname)-8s %(message)s",
    level=logging.INFO,
)

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# ── Public API ────────────────────────────────────────────────────────────────
def load_ag_news(
    max_train: Optional[int] = CFG.max_train_samples,
    max_eval: Optional[int] = CFG.max_eval_samples,
    max_test: Optional[int] = CFG.max_test_samples,
) -> DatasetDict:
    """
    Build the 'train' / 'validation' / 'test' splits of AG News.

    AG News ships with 'train' (120 K) and 'test' (7.6 K) only, so 10 % of
    'train' is held out as the validation set.  The hold-out is stratified
    on the 'label' column and seeded with CFG.seed.  Each split can then
    optionally be capped in size.

    Parameters
    ----------
    max_train, max_eval, max_test
        Upper bound on the number of rows kept in the 'train', 'validation'
        and 'test' split respectively; None leaves the split untouched.
        'train' and 'validation' are shuffled (seed=CFG.seed) before being
        cut, whereas 'test' simply keeps its first rows in original order.

    Returns
    -------
    DatasetDict with splits: 'train', 'validation', 'test'
    """

    def _cap(split, limit, reshuffle):
        # Keep at most `limit` rows of `split`; a limit of None means no cap.
        if limit is None:
            return split
        keep = min(limit, len(split))
        if reshuffle:
            split = split.shuffle(seed=CFG.seed)
        return split.select(range(keep))

    logger.info("Loading AG News dataset β¦")
    source = load_dataset("ag_news")

    # Stratified 90/10 split: train -> train + validation
    held_out = source["train"].train_test_split(
        test_size=0.10,
        seed=CFG.seed,
        stratify_by_column="label",
    )

    # Optional down-sampling (speeds up CPU training significantly)
    dataset = DatasetDict({
        "train": _cap(held_out["train"], max_train, True),
        "validation": _cap(held_out["test"], max_eval, True),
        "test": _cap(source["test"], max_test, False),
    })

    logger.info(
        f" train={len(dataset['train']):,} "
        f"val={len(dataset['validation']):,} "
        f"test={len(dataset['test']):,}"
    )
    return dataset
def load_test_only() -> Tuple[List[str], List[int]]:
    """
    Load only the AG News test split as plain Python lists.

    Passing split="test" asks load_dataset for that single split directly,
    instead of building a DatasetDict of every split and then discarding
    'train'.  No stratified train/validation split is performed here.
    Used by compare_results.py.

    Returns
    -------
    (texts, labels) taken from the 'text' and 'label' columns of the test
    split, in dataset order.
    """
    test_split = load_dataset("ag_news", split="test")
    return list(test_split["text"]), list(test_split["label"])
def get_raw_splits(dataset: DatasetDict) -> Tuple:
    """
    Flatten the three splits into plain Python lists.

    Used by the scikit-learn traditional ML pipeline.

    Returns
    -------
    (X_train, y_train, X_val, y_val, X_test, y_test), where each X_* is a
    list built from the split's 'text' column and each y_* a list built
    from its 'label' column.
    """
    flattened = []
    # Order matters: callers unpack train, then validation, then test.
    for split_name in ("train", "validation", "test"):
        split = dataset[split_name]
        flattened.append(list(split["text"]))
        flattened.append(list(split["label"]))
    return tuple(flattened)
def get_tokenizer() -> PreTrainedTokenizerBase:
    """
    Return the tokeniser for the checkpoint named by CFG.model_checkpoint.

    The tokeniser is downloaded, or loaded from the local HuggingFace cache.
    """
    logger.info(f"Loading tokeniser: {CFG.model_checkpoint}")
    tokenizer = AutoTokenizer.from_pretrained(CFG.model_checkpoint)
    return tokenizer
def tokenise_dataset(
    dataset: DatasetDict,
    tokenizer: PreTrainedTokenizerBase,
) -> DatasetDict:
    """
    Tokenise every split so it can be fed to the HuggingFace Trainer.

    The 'text' column is replaced by the tokenizer's output, 'label' is
    renamed to 'labels', and the result is formatted as torch tensors
    restricted to 'input_ids', 'attention_mask' and 'labels'.

    Design decisions:
    - padding=False -> pads at collation time via DataCollatorWithPadding
      (more memory-efficient than padding all to max_length)
    - num_proc=1    -> required on Windows; fork-based multi-processing
      causes issues with PyTorch on Windows

    Returns
    -------
    DatasetDict holding the tokenised splits.
    """

    def _encode(examples: dict) -> dict:
        # Truncate to CFG.max_length; padding is deliberately left off here.
        texts = examples["text"]
        return tokenizer(
            texts,
            truncation=True,
            max_length=CFG.max_length,
            padding=False,
        )

    logger.info("Tokenising dataset β¦")

    map_options = {
        "batched": True,
        "batch_size": 1_000,
        "num_proc": 1,
        "remove_columns": ["text"],
        "desc": "Tokenising",
    }
    encoded = dataset.map(_encode, **map_options)

    # HuggingFace Trainer requires the label column to be named 'labels'
    encoded = encoded.rename_column("label", "labels")

    tensor_columns = ["input_ids", "attention_mask", "labels"]
    encoded.set_format("torch", columns=tensor_columns)
    return encoded
|