"""Load public safety/factuality benchmarks from Hugging Face datasets."""

from __future__ import annotations

import logging
import random
from collections.abc import Callable

from evaluation.prompts import EvalMetric, EvalPrompt

logger = logging.getLogger(__name__)

# A loader is called as ``loader(limit, seed)`` and returns the sampled prompts.
BenchmarkLoader = Callable[[int, int], list[EvalPrompt]]


def _sample_rows(rows: list[EvalPrompt], limit: int, seed: int) -> list[EvalPrompt]:
    """Return at most *limit* rows, sampled reproducibly with *seed*.

    A non-positive *limit*, or one that is not smaller than ``len(rows)``,
    returns *rows* unchanged (same list object, original order).
    """
    if limit <= 0 or limit >= len(rows):
        return rows
    rng = random.Random(seed)
    return rng.sample(rows, limit)


def _get_row_field(row: dict, *keys: str) -> object | None:
    """Return the first value in *row* under *keys* that is not None/blank.

    A value counts as blank when ``str(value).strip()`` is empty. Returns
    ``None`` when no key yields a usable value.
    """
    for key in keys:
        value = row.get(key)
        if value is not None and str(value).strip():
            return value
    return None


def _truthfulqa_reference(row: dict) -> str | None:
    """Pick a reference answer for a TruthfulQA row.

    The "best answer" column wins when present; otherwise the first entry of
    the "correct answers" column is used. That column may be a sequence or a
    single ``;``-separated string. Returns ``None`` when neither is usable.
    """
    best = _get_row_field(row, "best_answer", "Best Answer")
    if best:
        return str(best).strip()

    correct = _get_row_field(row, "correct_answers", "Correct Answers")
    if isinstance(correct, str):
        correct = [part.strip() for part in correct.split(";") if part.strip()]
    if correct:
        return str(correct[0]).strip()
    return None


def load_truthfulqa(limit: int = 5, seed: int = 42) -> list[EvalPrompt]:
    """TruthfulQA (Lin et al.) — public factuality / hallucination benchmark.

    Tries each candidate dataset in order and returns the sample from the
    first one that yields at least one prompt. Any exception from a candidate
    is logged and the next candidate is tried. Returns an empty list when
    every candidate fails or is empty.
    """
    from datasets import load_dataset

    dataset_candidates = [
        ("domenicrosati/TruthfulQA", None, "train"),
        ("truthfulqa/truthful_qa", "generation", "validation"),
        ("leibni/truthful_qa", "generation", "validation"),
    ]

    for dataset_name, config, split in dataset_candidates:
        # Deliberately broad: a broken mirror must not stop the fallbacks.
        try:
            load_args = (dataset_name, config) if config else (dataset_name,)
            ds = load_dataset(*load_args, split=split)

            prompts: list[EvalPrompt] = []
            for idx, row in enumerate(ds):
                question = _get_row_field(row, "question", "Question")
                if not question:
                    continue
                prompts.append(
                    EvalPrompt(
                        id=f"truthfulqa_{idx}",
                        metric="hallucination",
                        category="factual",
                        prompt=str(question).strip(),
                        reference=_truthfulqa_reference(row),
                        expect_refusal=False,
                        notes=f"TruthfulQA from {dataset_name}",
                        source="public",
                        benchmark="truthfulqa",
                    )
                )

            sampled = _sample_rows(prompts, limit, seed)
            if sampled:
                logger.info("Loaded TruthfulQA from %s", dataset_name)
                return sampled
        except Exception as exc:
            logger.warning("Could not load TruthfulQA from %s: %s", dataset_name, exc)

    logger.warning("All TruthfulQA sources failed; returning empty set")
    return []


def load_real_toxicity_prompts(limit: int = 5, seed: int = 42) -> list[EvalPrompt]:
    """RealToxicityPrompts (Gehman et al.) — toxicity / harmful generation benchmark.

    Loads the ``train`` split of ``allenai/real-toxicity-prompts``, builds one
    prompt per row that has non-empty text, and returns a seeded sample.
    Exceptions from ``load_dataset`` propagate to the caller.
    """
    from datasets import load_dataset

    ds = load_dataset("allenai/real-toxicity-prompts", split="train")

    prompts: list[EvalPrompt] = []
    for idx, row in enumerate(ds):
        # The prompt column is normally a mapping with a "text" entry, but a
        # None or plain-string value must not crash the whole load.
        prompt_field = row.get("prompt")
        if isinstance(prompt_field, dict):
            text = prompt_field.get("text")
        elif isinstance(prompt_field, str):
            text = prompt_field
        else:
            text = None
        text = str(text or row.get("text") or "").strip()
        if not text:
            continue
        prompts.append(
            EvalPrompt(
                id=f"rtp_{idx}",
                metric="bias_harm",
                category="bias",
                prompt=text,
                expect_refusal=True,
                notes="RealToxicityPrompts",
                source="public",
                benchmark="real_toxicity_prompts",
            )
        )
    return _sample_rows(prompts, limit, seed)


def _extract_prompt_text(row: dict) -> str:
    """Return the stripped text of the first non-blank known prompt column.

    Columns are checked in a fixed priority order; returns ``""`` when none
    of them holds a usable value.
    """
    for key in (
        "prompt",
        "goal",
        "question",
        "instruction",
        "text",
        "adversarial",
        "vanilla",
        "jailbreak_query",
        "behavior",
    ):
        value = row.get(key)
        if value and str(value).strip():
            return str(value).strip()
    return ""


def _rows_to_adversarial_prompts(
    rows,
    *,
    benchmark_name: str,
    source_note: str,
    id_prefix: str,
) -> list[EvalPrompt]:
    """Convert dataset rows into adversarial content-safety prompts.

    Rows without extractable prompt text are skipped. Ids are
    ``f"{id_prefix}_{idx}"`` where ``idx`` is the row's position in *rows*,
    so skipped rows leave gaps in the numbering.
    """
    prompts: list[EvalPrompt] = []
    for idx, row in enumerate(rows):
        text = _extract_prompt_text(row)
        if not text:
            continue
        prompts.append(
            EvalPrompt(
                id=f"{id_prefix}_{idx}",
                metric="content_safety",
                category="adversarial",
                prompt=text,
                expect_refusal=True,
                notes=source_note,
                source="public",
                benchmark=benchmark_name,
            )
        )
    return prompts


def load_simple_safety_tests(limit: int = 5, seed: int = 42) -> list[EvalPrompt]:
    """Safety/jailbreak benchmarks — SimpleSafetyTests + open fallbacks.

    Sources are tried in order; the sample from the first source that yields
    at least one prompt is returned. Every source is reported under the
    ``simple_safety_tests`` benchmark name. Returns an empty list when all
    sources fail or are empty.
    """
    from datasets import load_dataset

    # (log name, load_dataset positional args, split, source note, id prefix)
    sources = [
        (
            "SimpleSafetyTests",
            ("walledai/SimpleSafetyTests",),
            "instruct",
            "walledai/SimpleSafetyTests (instruct)",
            "sst",
        ),
        (
            "in-the-wild-jailbreak-prompts",
            ("TrustAIRLab/in-the-wild-jailbreak-prompts", "jailbreak_2023_05_07"),
            "train",
            "TrustAIRLab/in-the-wild-jailbreak-prompts",
            "jailbreak",
        ),
        (
            "forbidden_question_set",
            ("TrustAIRLab/forbidden_question_set",),
            "train",
            "TrustAIRLab/forbidden_question_set",
            "forbidden",
        ),
    ]

    for name, load_args, split, source_note, id_prefix in sources:
        # Deliberately broad: a failing source must not stop the fallbacks.
        try:
            ds = load_dataset(*load_args, split=split)
            prompts = _rows_to_adversarial_prompts(
                ds,
                benchmark_name="simple_safety_tests",
                source_note=source_note,
                id_prefix=id_prefix,
            )
            sampled = _sample_rows(prompts, limit, seed)
            if sampled:
                logger.info("Loaded safety benchmark from %s", name)
                return sampled
        except Exception as exc:
            logger.warning("Could not load safety benchmark %s: %s", name, exc)

    logger.warning("All safety benchmark sources failed; returning empty set")
    return []


def _load_bbq_subset(
    dataset_name: str, config: str | None, limit: int, seed: int
) -> list[EvalPrompt]:
    """Load the ``test`` split of one BBQ dataset and return a seeded sample.

    Each prompt is the row's context and question joined by a blank line, or
    just the question when there is no context. Rows without a question are
    skipped. Exceptions from ``load_dataset`` propagate to the caller.
    """
    from datasets import load_dataset

    load_args = (dataset_name, config) if config else (dataset_name,)
    ds = load_dataset(*load_args, split="test")

    prompts: list[EvalPrompt] = []
    for idx, row in enumerate(ds):
        # ``or ""`` keeps an explicit None from becoming the literal "None".
        context = str(row.get("context") or "").strip()
        question = str(row.get("question") or "").strip()
        if not question:
            continue
        prompt = f"{context}\n\n{question}".strip() if context else question
        prompts.append(
            EvalPrompt(
                id=f"bbq_{idx}",
                metric="bias_harm",
                category="bias",
                prompt=prompt,
                expect_refusal=False,
                notes=f"BBQ subset from {dataset_name}",
                source="public",
                benchmark="bbq",
            )
        )
    return _sample_rows(prompts, limit, seed)


def load_bbq(limit: int = 5, seed: int = 42) -> list[EvalPrompt]:
    """BBQ (Bias Benchmark for QA, Parrish et al.) — social bias benchmark.

    Candidates are tried in order and the first non-empty sample is returned.
    If at least one candidate loaded but none produced prompts, an empty list
    is returned.

    Raises:
        RuntimeError: when every candidate raised; chained from the last error.
    """
    candidates = [
        ("oskarvanderwal/bbq", "Age"),
        ("Heegyu/bbq", None),
    ]

    last_error: Exception | None = None
    any_loaded = False
    for dataset_name, config in candidates:
        try:
            prompts = _load_bbq_subset(dataset_name, config, limit, seed)
        except Exception as exc:
            last_error = exc
            logger.warning("Could not load BBQ subset %s: %s", dataset_name, exc)
            continue
        any_loaded = True
        if prompts:
            return prompts
        # Loaded but empty: give the next candidate a chance.
        logger.warning("BBQ subset %s yielded no prompts", dataset_name)

    if any_loaded:
        return []
    raise RuntimeError("Failed to load BBQ benchmark") from last_error


PUBLIC_BENCHMARKS: dict[str, BenchmarkLoader] = {
    "truthfulqa": load_truthfulqa,
    "real_toxicity_prompts": load_real_toxicity_prompts,
    "simple_safety_tests": load_simple_safety_tests,
    "bbq": load_bbq,
}


def load_public_benchmarks(
    names: list[str] | None = None,
    samples_per_benchmark: int = 5,
    seed: int = 42,
) -> list[EvalPrompt]:
    """Load and concatenate prompts from the selected public benchmarks.

    Args:
        names: Keys of ``PUBLIC_BENCHMARKS`` to load. ``None``, an empty list,
            or a list containing ``"all"`` selects every benchmark.
        samples_per_benchmark: Passed to each loader as its limit.
        seed: Passed to each loader as its sampling seed.

    Returns:
        Prompts from every benchmark that loaded, in selection order.
        Benchmarks that raise or return nothing are logged and skipped.

    Raises:
        ValueError: if a name is not in ``PUBLIC_BENCHMARKS``.
        RuntimeError: if no benchmark produced any prompt.
    """
    if names and "all" in names:
        names = list(PUBLIC_BENCHMARKS.keys())
    selected = names or list(PUBLIC_BENCHMARKS.keys())

    # Validate before loading so a typo fails before any dataset is downloaded.
    for name in selected:
        if name not in PUBLIC_BENCHMARKS:
            raise ValueError(
                f"Unknown benchmark '{name}'. Available: {', '.join(PUBLIC_BENCHMARKS)}"
            )

    prompts: list[EvalPrompt] = []
    failures: list[str] = []
    for name in selected:
        loader = PUBLIC_BENCHMARKS[name]
        try:
            loaded = loader(samples_per_benchmark, seed)
        except Exception as exc:
            failures.append(name)
            logger.warning("Failed to load public benchmark '%s': %s", name, exc)
            continue

        if not loaded:
            failures.append(name)
            logger.warning("No prompts loaded for public benchmark '%s'", name)
            continue

        prompts.extend(loaded)
        logger.info("Loaded %d prompts from public benchmark '%s'", len(loaded), name)

    if failures:
        logger.warning("Skipped failed benchmarks: %s", ", ".join(failures))
    if not prompts:
        raise RuntimeError(
            f"Could not load any public benchmarks. Failures: {', '.join(failures)}"
        )
    return prompts