from datasets import load_dataset
from transformers import DistilBertTokenizerFast

# AG News has 4 categories. These mappings are the single source of truth
# used by both training (loss computation) and serving (label decoding).
LABELS = ["World", "Sports", "Business", "Sci/Tech"]
ID2LABEL = {i: label for i, label in enumerate(LABELS)}
LABEL2ID = {label: i for i, label in enumerate(LABELS)}
NUM_LABELS = len(LABELS)


def load_ag_news(
    tokenizer_name: str = "distilbert-base-uncased",
    max_length: int = 128,
):
    """Load AG News, tokenize it, and format it for PyTorch training.

    Args:
        tokenizer_name: Name or path passed to
            ``DistilBertTokenizerFast.from_pretrained``.
        max_length: Token length every example is truncated and padded to.

    Returns:
        A ``(dataset, tokenizer)`` pair. ``dataset`` is what
        ``load_dataset("ag_news")`` returns, after these changes:
        tokenization, renaming the ``"label"`` column to ``"labels"``, and
        setting the torch format on ``input_ids``, ``attention_mask`` and
        ``labels``. ``tokenizer`` is the loaded tokenizer instance.

    Raises:
        ValueError: If a split's ``"label"`` feature exposes class names that
            differ from ``LABELS``. A silent mismatch would make the model
            learn one id order while serving decodes with another.
    """
    dataset = load_dataset("ag_news")
    tokenizer = DistilBertTokenizerFast.from_pretrained(tokenizer_name)

    # Guard the "single source of truth" claim: the integer ids in the data
    # must mean the same thing as the indices of LABELS.
    for split_name, split in dataset.items():
        label_feature = split.features.get("label")
        dataset_names = getattr(label_feature, "names", None)
        if dataset_names is not None and list(dataset_names) != LABELS:
            raise ValueError(
                f"AG News split {split_name!r} has label names "
                f"{list(dataset_names)!r}, expected {LABELS!r}"
            )

    def tokenize(batch):
        # Pad every example to exactly max_length so all rows share one length.
        return tokenizer(
            batch["text"],
            truncation=True,
            max_length=max_length,
            padding="max_length",
        )

    # batched=True hands the tokenizer 1000 articles per call instead of one,
    # amortizing per-call overhead.
    dataset = dataset.map(tokenize, batched=True, batch_size=1000)

    # Trainer expects "labels" (plural) — AG News ships it as "label" (singular)
    dataset = dataset.rename_column("label", "labels")

    # Return PyTorch tensors, not Python lists
    dataset.set_format("torch", columns=["input_ids", "attention_mask", "labels"])

    return dataset, tokenizer