File size: 3,052 Bytes
5a27052
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
#!/usr/bin/env python3
"""
Utility for loading Nvidia's Aegis AI Content Safety Dataset 2.0 with
the exact fields needed for prompt injection detection experiments.

Only the `prompt` text and the normalized `prompt_label` fields are kept.
Labels are mapped to integers: `safe -> 0`, `unsafe -> 1`.
"""

from __future__ import annotations

from typing import Dict, Optional

from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict, load_dataset

# Hugging Face Hub identifier passed to `load_dataset`.
DATASET_NAME = "nvidia/Aegis-AI-Content-Safety-Dataset-2.0"
# Raw string prompt labels -> integer class ids (safe -> 0, unsafe -> 1).
LABEL_MAP = dict(safe=0, unsafe=1)
# The only columns kept from each split; all others are dropped.
SELECTED_COLUMNS = "prompt prompt_label".split()


def _map_labels(batch: Dict[str, list]) -> Dict[str, list]:
    """Swap a batch's string labels for their integer ids.

    Used as a ``map(..., batched=True)`` callback: ``batch["prompt_label"]``
    holds a list of label strings. That entry is replaced by a new list of
    ``LABEL_MAP`` ids and the same ``batch`` dict is returned. A label missing
    from ``LABEL_MAP`` raises ``KeyError``.
    """
    to_id = LABEL_MAP.__getitem__
    batch["prompt_label"] = list(map(to_id, batch["prompt_label"]))
    return batch


def _prepare_split(ds: Dataset) -> Dataset:
    """
    Narrow a single split to ``SELECTED_COLUMNS`` and integer-encode its labels.

    Label conversion is delegated to ``_map_labels`` via a batched ``map``.
    """
    return ds.select_columns(SELECTED_COLUMNS).map(
        _map_labels,
        batched=True,
    )


def load_aegis_dataset(
    split: Optional[str] = None,
    streaming: bool = False,
) -> Dataset | DatasetDict | IterableDataset | IterableDatasetDict:
    """
    Load the Aegis dataset with normalized `prompt_label`.

    Args:
        split: Optional split name ("train", "validation", "test", etc.).
            When None, every split returned by `load_dataset` is processed.
        streaming: Whether to stream the data instead of downloading it locally.

    Returns:
        - ``Dataset`` when ``split`` is given and ``streaming`` is False.
        - ``DatasetDict`` when ``split`` is None and ``streaming`` is False.
        - ``IterableDataset`` when ``split`` is given and ``streaming`` is True.
        - ``IterableDatasetDict`` when ``split`` is None and ``streaming`` is True.

        Every variant contains only the ``prompt`` column and the integer
        ``prompt_label`` column (``safe -> 0``, ``unsafe -> 1``).

    Raises:
        KeyError: if a ``prompt_label`` value is not a key of ``LABEL_MAP``.
            For streamed data the conversion is lazy, so the error surfaces
            while iterating rather than at load time.
    """
    dataset = load_dataset(DATASET_NAME, split=split, streaming=streaming)

    # Dataset and IterableDataset both provide `select_columns` and batched
    # `map`, so `_prepare_split` serves both. Using the native transforms
    # (rather than re-wrapping rows with `IterableDataset.from_generator`)
    # keeps a streamed split's own split name and sharding instead of
    # collapsing it into a single generator reported as the "train" split.
    if split is not None:
        return _prepare_split(dataset)

    # Multiple splits: keep the container type that matches the load mode.
    container = IterableDatasetDict if streaming else DatasetDict
    return container(
        {split_name: _prepare_split(split_ds) for split_name, split_ds in dataset.items()}
    )


if __name__ == "__main__":
    # Smoke test: load every split, then print its size followed by its first row.
    dataset_dict = load_aegis_dataset()
    for split_name in dataset_dict:
        split_ds = dataset_dict[split_name]
        print(f"{split_name}: {len(split_ds)} samples")
        print(split_ds[0])