nexa-classify-api / data_loader.py
Prototype6239's picture
Upload folder using huggingface_hub
a229747 verified
Raw
History Blame Contribute Delete
4.7 kB
"""
data_loader.py
──────────────
Handles all dataset loading, validation splitting, preprocessing and tokenisation.
AG News label scheme:
0 = World, 1 = Sports, 2 = Business, 3 = Sci/Tech
"""
import logging
from typing import List, Optional, Tuple
from datasets import load_dataset, DatasetDict
from transformers import AutoTokenizer, PreTrainedTokenizerBase
from config import CFG
# Root logging is configured at import time: INFO level, with a short
# wall-clock timestamp and a left-aligned, 8-character level name.
_LOG_SETTINGS = {
    "level": logging.INFO,
    "format": "%(asctime)s %(levelname)-8s %(message)s",
    "datefmt": "%H:%M:%S",
}
logging.basicConfig(**_LOG_SETTINGS)

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# ── Public API ────────────────────────────────────────────────────────────────
def load_ag_news(
    max_train: Optional[int] = CFG.max_train_samples,
    max_eval: Optional[int] = CFG.max_eval_samples,
    max_test: Optional[int] = CFG.max_test_samples,
) -> DatasetDict:
    """
    Build the 'train', 'validation' and 'test' splits of AG News.

    The data comes from ``load_dataset("ag_news")`` (HuggingFace datasets
    cache; downloaded on first call). AG News ships with 'train' (120 K) and
    'test' (7.6 K) only, so a stratified 10 % of 'train' (stratified on
    'label', seeded with ``CFG.seed``) is carved out as the validation set.

    Parameters
    ----------
    max_train, max_eval, max_test
        Optional upper bounds on the size of each split. ``None`` leaves the
        split untouched. 'train' and 'validation' are shuffled with
        ``CFG.seed`` before being truncated; 'test' is truncated in its
        existing order (leading rows are kept).

    Returns
    -------
    DatasetDict with splits: 'train', 'validation', 'test'
    """
    logger.info("Loading AG News dataset …")
    source = load_dataset("ag_news")

    # Stratified 90/10 split of the original training data → train + validation
    parts = source["train"].train_test_split(
        test_size=0.10,
        seed=CFG.seed,
        stratify_by_column="label",
    )
    dataset = DatasetDict({
        "train": parts["train"],
        "validation": parts["test"],
        "test": source["test"],
    })

    # Optional down-sampling (speeds up CPU training significantly).
    # Each entry: (split name, requested cap, shuffle before truncating?)
    caps = (
        ("train", max_train, True),
        ("validation", max_eval, True),
        ("test", max_test, False),
    )
    for split_name, cap, reshuffle in caps:
        if cap is None:
            continue
        subset = dataset[split_name]
        keep = min(cap, len(subset))
        if reshuffle:
            subset = subset.shuffle(seed=CFG.seed)
        dataset[split_name] = subset.select(range(keep))

    sizes = {name: len(dataset[name]) for name in ("train", "validation", "test")}
    logger.info(
        f" train={sizes['train']:,} "
        f"val={sizes['validation']:,} "
        f"test={sizes['test']:,}"
    )
    return dataset
def load_test_only() -> Tuple[List[str], List[int]]:
    """
    Load only the test split of AG News (no stratified split overhead).
    Used by compare_results.py.

    Returns
    -------
    (texts, labels): two plain Python lists taken from the 'text' and
    'label' columns of the test split.
    """
    # Passing split="test" makes load_dataset return just that split,
    # instead of building every split and then discarding 'train'.
    test_split = load_dataset("ag_news", split="test")
    return list(test_split["text"]), list(test_split["label"])
def get_raw_splits(dataset: DatasetDict) -> Tuple:
    """
    Flatten the three splits into plain Python lists.
    Used by the scikit-learn traditional ML pipeline.

    Returns
    -------
    A 6-tuple ``(X_train, y_train, X_val, y_val, X_test, y_test)`` where each
    ``X_*`` is the split's 'text' column and each ``y_*`` its 'label' column,
    both copied into lists.
    """
    flattened = []
    # Order matters: callers unpack train, validation, test in that sequence.
    for split_name in ("train", "validation", "test"):
        flattened.append(list(dataset[split_name]["text"]))
        flattened.append(list(dataset[split_name]["label"]))
    return tuple(flattened)
def get_tokenizer() -> PreTrainedTokenizerBase:
    """
    Return the tokeniser for ``CFG.model_checkpoint``.

    Delegates to ``AutoTokenizer.from_pretrained``, which downloads the
    tokeniser or loads it from the local HuggingFace cache.
    """
    checkpoint = CFG.model_checkpoint
    logger.info(f"Loading tokeniser: {checkpoint}")
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    return tokenizer
def tokenise_dataset(
    dataset: DatasetDict,
    tokenizer: PreTrainedTokenizerBase,
) -> DatasetDict:
    """
    Tokenise every split so it can be fed to the HuggingFace Trainer.

    The 'text' column is replaced by the tokeniser outputs, 'label' is
    renamed to 'labels', and the result is formatted as torch tensors
    restricted to 'input_ids', 'attention_mask' and 'labels'.

    Design decisions:
    - padding=False → padding happens at collation time via
      DataCollatorWithPadding (more memory-efficient than padding
      everything to max_length)
    - num_proc=1 → required on Windows; fork-based multi-processing
      causes issues with PyTorch on Windows
    """

    def _encode_batch(examples: dict) -> dict:
        # Truncate to CFG.max_length; leave padding to the data collator.
        texts = examples["text"]
        return tokenizer(
            texts,
            truncation=True,
            max_length=CFG.max_length,
            padding=False,
        )

    logger.info("Tokenising dataset …")
    map_options = {
        "batched": True,
        "batch_size": 1_000,
        "num_proc": 1,
        "remove_columns": ["text"],
        "desc": "Tokenising",
    }
    encoded = dataset.map(_encode_batch, **map_options)

    # HuggingFace Trainer requires the label column to be named 'labels'
    encoded = encoded.rename_column("label", "labels")
    tensor_columns = ["input_ids", "attention_mask", "labels"]
    encoded.set_format("torch", columns=tensor_columns)
    return encoded