# GitHub Actions
# deploy: sync from github 53893a5e68b161d96fbd7b2b5468a9351ac8347d
# 0083b07
from datasets import load_dataset
from transformers import DistilBertTokenizerFast
# AG News has 4 categories, listed here in class-id order (index == id).
# These mappings are the single source of truth shared by training (loss
# computation) and serving (label decoding), so edit them only together.
LABELS = ["World", "Sports", "Business", "Sci/Tech"]
# id -> human-readable category name
ID2LABEL = dict(enumerate(LABELS))
# category name -> id (exact inverse of ID2LABEL)
LABEL2ID = {name: idx for idx, name in ID2LABEL.items()}
NUM_LABELS = len(ID2LABEL)
def load_ag_news(
    tokenizer_name: str = "distilbert-base-uncased",
    max_length: int = 128,
    batch_size: int = 1000,
):
    """Load AG News and tokenize it for DistilBERT sequence classification.

    Args:
        tokenizer_name: Model id or local path passed to
            ``DistilBertTokenizerFast.from_pretrained``.
        max_length: Every example is truncated and padded to exactly this
            many tokens.
        batch_size: Number of examples handed to the tokenizer per
            ``Dataset.map`` call.

    Returns:
        A ``(dataset, tokenizer)`` pair. ``dataset`` is the result of
        ``load_dataset("ag_news")`` after tokenization. Its label column is
        renamed to ``"labels"``, and torch formatting is applied to
        ``input_ids``, ``attention_mask`` and ``labels``. ``tokenizer`` is the
        loaded fast tokenizer.

    Raises:
        ValueError: If a split's label names differ from ``LABELS``. That
            would make ``ID2LABEL``/``LABEL2ID`` decode predictions into the
            wrong categories.
    """
    dataset = load_dataset("ag_news")

    # LABELS is declared the single source of truth for id <-> name decoding.
    # Fail loudly if the dataset's own class names (or their order) disagree,
    # instead of silently training/serving with a mismatched mapping.
    for split_name, split in dataset.items():
        label_names = getattr(split.features["label"], "names", None)
        if label_names is not None and list(label_names) != LABELS:
            raise ValueError(
                f"AG News split {split_name!r} has label names {label_names!r}, "
                f"expected {LABELS!r}"
            )

    tokenizer = DistilBertTokenizerFast.from_pretrained(tokenizer_name)

    def tokenize(batch):
        # Pad every row to the same fixed length so the torch-formatted
        # columns have a uniform shape per example.
        return tokenizer(
            batch["text"],
            truncation=True,
            max_length=max_length,
            padding="max_length",
        )

    # batched=True hands the fast tokenizer many texts per call instead of
    # invoking it once per article, which is much quicker.
    dataset = dataset.map(tokenize, batched=True, batch_size=batch_size)
    # The model's forward() takes its targets as "labels" (plural); AG News
    # ships the column as "label" (singular).
    dataset = dataset.rename_column("label", "labels")
    # Return PyTorch tensors (not Python lists) for the model-facing columns.
    dataset.set_format("torch", columns=["input_ids", "attention_mask", "labels"])
    return dataset, tokenizer