Spaces:
Sleeping
Sleeping
"""
UMBRA Dataset Loader — fetches and caches all training datasets from HuggingFace Hub.
Datasets used:
1. Anthropic/hh-rlhf — GRPO preference training (chosen/rejected pairs)
2. ai4privacy/pii-masking-300k — Sentrix PII validation test cases
3. truthful_qa — Calibration / honest-uncertainty reward signal
4. lmsys/toxic-chat — Adversarial & prompt-injection lines for Manipulator NPC
5. daily_dialog — Natural multi-turn dialogue for NPC script expansion
All datasets are downloaded once and cached under .cache/umbra_datasets/.
Set UMBRA_OFFLINE=1 env var to skip downloads and use only cached data.
"""
import os
import logging
from pathlib import Path

logger = logging.getLogger(__name__)

# Every loader in this module passes this directory as `cache_dir`.
# The path is relative, so it resolves against the current working directory.
CACHE_DIR = Path(".cache/umbra_datasets")
CACHE_DIR.mkdir(parents=True, exist_ok=True)

# UMBRA_OFFLINE=1 means "do not download; serve from the local cache only".
OFFLINE = os.environ.get("UMBRA_OFFLINE", "0") == "1"
if OFFLINE:
    # The loaders import `datasets` lazily inside each function, so exporting
    # the HuggingFace offline switches here happens before that import.
    # `setdefault` keeps any value the user exported explicitly.
    # NOTE(review): if another module imports `datasets` before this one, the
    # library may already have read these variables - confirm import order.
    os.environ.setdefault("HF_DATASETS_OFFLINE", "1")
    os.environ.setdefault("HF_HUB_OFFLINE", "1")
# ── Loaders ──────────────────────────────────────────────────────────────────
def load_hh_rlhf(split: str = "train", max_samples: int = 5000):
    """
    Fetch Anthropic/hh-rlhf (chosen / rejected conversation pairs).
    Used for: GRPO preference training prompt extraction.
    Expected schema: {"chosen": str, "rejected": str}

    Args:
        split: Dataset split to request.
        max_samples: Keep only the first `max_samples` rows when the dataset
            is longer; a falsy value (0 / None) disables truncation.

    Returns:
        The loaded (possibly truncated) dataset, cached under CACHE_DIR.
    """
    from datasets import load_dataset

    logger.info(f"[Dataset] Loading hh-rlhf split={split}, max={max_samples}")
    request = {"split": split, "cache_dir": str(CACHE_DIR)}
    dataset = load_dataset("Anthropic/hh-rlhf", **request)

    # Truncate only when a limit was given and the dataset exceeds it.
    limit_given = bool(max_samples)
    if limit_given and len(dataset) > max_samples:
        dataset = dataset.select(range(max_samples))

    logger.info(f"[Dataset] hh-rlhf loaded: {len(dataset)} rows")
    return dataset
def load_pii_masking(split: str = "train", max_samples: int = 2000):
    """
    Fetch ai4privacy/pii-masking-300k (text with annotated PII).
    Used for: Sentrix regex validation and false-positive measurement.
    Expected schema: {"source_text": str, "target_text": str, "privacy_mask": list}

    Args:
        split: Dataset split to request.
        max_samples: Keep only the first `max_samples` rows when the dataset
            is longer; a falsy value (0 / None) disables truncation.

    Returns:
        The loaded (possibly truncated) dataset, cached under CACHE_DIR.
    """
    from datasets import load_dataset

    logger.info(f"[Dataset] Loading pii-masking-300k split={split}, max={max_samples}")
    request = {"split": split, "cache_dir": str(CACHE_DIR)}
    dataset = load_dataset("ai4privacy/pii-masking-300k", **request)

    # Truncate only when a limit was given and the dataset exceeds it.
    limit_given = bool(max_samples)
    if limit_given and len(dataset) > max_samples:
        dataset = dataset.select(range(max_samples))

    logger.info(f"[Dataset] pii-masking-300k loaded: {len(dataset)} rows")
    return dataset
def load_truthful_qa(split: str = "validation"):
    """
    Fetch the "generation" configuration of truthful_qa.
    Used for: calibration reward (rewarding 'I don't know' on hard questions).
    Expected schema: {"question": str, "best_answer": str,
                      "correct_answers": list, "incorrect_answers": list}

    Args:
        split: Dataset split to request.

    Returns:
        The full requested split (no truncation), cached under CACHE_DIR.
    """
    from datasets import load_dataset

    logger.info(f"[Dataset] Loading truthful_qa split={split}")
    request = {"split": split, "cache_dir": str(CACHE_DIR)}
    dataset = load_dataset("truthful_qa", "generation", **request)
    logger.info(f"[Dataset] truthful_qa loaded: {len(dataset)} rows")
    return dataset
def load_toxic_chat(split: str = "train", max_samples: int = 1000):
    """
    Fetch the "toxicchat0124" configuration of lmsys/toxic-chat
    (adversarial user inputs including prompt injections).
    Used for: expanding Manipulator NPC script pool with real injection attempts.
    Expected schema: {"user_input": str, "model_output": str,
                      "toxicity_annotation": int, "jailbreaking": bool}

    Args:
        split: Dataset split to request.
        max_samples: Keep only the first `max_samples` rows when the dataset
            is longer; a falsy value (0 / None) disables truncation.

    Returns:
        The loaded (possibly truncated) dataset, cached under CACHE_DIR.
    """
    from datasets import load_dataset

    logger.info(f"[Dataset] Loading toxic-chat split={split}, max={max_samples}")
    request = {"split": split, "cache_dir": str(CACHE_DIR)}
    dataset = load_dataset("lmsys/toxic-chat", "toxicchat0124", **request)

    # Truncate only when a limit was given and the dataset exceeds it.
    limit_given = bool(max_samples)
    if limit_given and len(dataset) > max_samples:
        dataset = dataset.select(range(max_samples))

    logger.info(f"[Dataset] toxic-chat loaded: {len(dataset)} rows")
    return dataset
# daily_dialog act codes keyed by act name (0 is reserved for "no / unknown act").
_DAILY_DIALOG_ACT_CODES = {"inform": 1, "question": 2, "directive": 3, "commissive": 4}


def _to_daily_dialog_act(raw_act, class_names=None) -> int:
    """Translate one dyda_da act value into a daily_dialog act code (0 = unknown)."""
    if isinstance(raw_act, str):
        key = raw_act.strip().lower()
        if key in _DAILY_DIALOG_ACT_CODES:
            return _DAILY_DIALOG_ACT_CODES[key]
        if not key.isdigit():
            # Unrecognised label name: treat as "no act" rather than aborting the load.
            return 0
        raw_act = int(key)
    index = int(raw_act)
    if class_names and 0 <= index < len(class_names):
        # Resolve class index -> name -> code so the mirror's own label
        # ordering cannot shift the meaning of the act codes.
        name = str(class_names[index]).strip().lower()
        return _DAILY_DIALOG_ACT_CODES.get(name, 0)
    # No label metadata available: keep the numeric value unchanged.
    return index


def load_daily_dialog(split: str = "train", max_samples: int = 2000):
    """
    Load multi-turn dialogues with dialogue-act labels in daily_dialog format.
    Used for: expanding Agreeable, Liar, and Emotional NPC script pools.
    Act labels: 1=inform, 2=question, 3=directive, 4=commissive (0=unknown)
    Schema: {"dialog": list[str], "act": list[int], "emotion": list[int]}

    Loading order:
      1. silicone/dyda_da - flat utterance rows, regrouped here by Dialogue_ID.
         It carries no emotion labels, so "emotion" is all zeros.
      2. If that fails for any reason, a small hard-coded synthetic set,
         repeated up to `max_samples`, so training can always proceed.

    Act values from dyda_da are normalised by label NAME (string value or
    ClassLabel metadata) into the daily_dialog codes above.
    NOTE(review): the silicone dataset card lists commissive=0, directive=1,
    inform=2, question=3 - confirm against the revision actually served.

    Args:
        split: Dataset split to request.
        max_samples: Maximum number of dialogues to return. None means "all"
            for dyda_da and "one copy of the base set" for the fallback.

    Returns:
        A `datasets.Dataset` with the schema described above.
    """
    from collections import defaultdict
    from datasets import load_dataset, Dataset

    # -- Attempt 1: silicone/dyda_da ------------------------------------------
    try:
        logger.info(f"[Dataset] Loading silicone/dyda_da (daily_dialog replacement) split={split}")
        da_ds = load_dataset(
            "silicone", "dyda_da",
            split=split,
            cache_dir=str(CACHE_DIR),
        )
        # When the act column is a ClassLabel, its `names` list gives the
        # index -> label-name mapping; plain Value columns have no `names`.
        features = getattr(da_ds, "features", None) or {}
        class_names = getattr(features.get("Dialogue_Act"), "names", None)

        # Group flat utterance rows back into dialog format by Dialogue_ID.
        dialogs: dict = defaultdict(list)
        acts_map: dict = defaultdict(list)
        for row in da_ds:
            did = row["Dialogue_ID"]
            dialogs[did].append(row["Utterance"])
            acts_map[did].append(_to_daily_dialog_act(row["Dialogue_Act"], class_names))

        rows = [
            {
                "dialog": utterances,
                "act": acts_map[did],
                "emotion": [0] * len(utterances),  # dyda_da has no emotion field
            }
            for did, utterances in list(dialogs.items())[:max_samples]
        ]
        ds = Dataset.from_list(rows)
        logger.info(f"[Dataset] silicone/dyda_da loaded: {len(ds)} dialogs")
        return ds
    except Exception as e1:
        # Deliberate best-effort: any failure (network, schema, missing
        # config) drops to the synthetic set instead of aborting training.
        logger.warning(f"[Dataset] silicone/dyda_da unavailable ({e1}), using synthetic fallback")

    # -- Fallback: minimal hardcoded dialogues (NPC static scripts still work) --
    SYNTHETIC = [
        {"dialog": ["Sure, I'll help you with that.", "Of course, anything for you!", "Happy to assist anytime."],
         "act": [4, 4, 4], "emotion": [0, 0, 0]},
        {"dialog": ["You really should do this now.", "I'm telling you, it's urgent!", "Don't wait any longer."],
         "act": [3, 3, 3], "emotion": [1, 2, 1]},
        {"dialog": ["I didn't say that at all.", "Actually, that's not what I meant.", "You're misremembering."],
         "act": [1, 1, 1], "emotion": [0, 0, 0]},
        {"dialog": ["What do you think about it?", "Can you explain that further?", "Are you sure?"],
         "act": [2, 2, 2], "emotion": [0, 0, 0]},
        {"dialog": ["I promise I'll do better.", "You can count on me this time.", "I won't let you down."],
         "act": [4, 4, 4], "emotion": [0, 0, 0]},
        {"dialog": ["Listen to me carefully.", "You need to act on this immediately.", "There's no time to hesitate."],
         "act": [3, 3, 3], "emotion": [2, 3, 2]},
    ]
    if max_samples is None:
        # No limit requested: there is nothing to "fill", so use the base set once.
        repeated = list(SYNTHETIC)
    else:
        # Repeat to fill max_samples.
        repeated = (SYNTHETIC * (max_samples // len(SYNTHETIC) + 1))[:max_samples]
    ds = Dataset.from_list(repeated)
    logger.info(f"[Dataset] Using {len(ds)} synthetic fallback dialogues for NPC expansion")
    return ds
# ── Builders ─────────────────────────────────────────────────────────────────
def build_grpo_prompts(hh_ds, max_prompts: int = 2000) -> list[str]:
    """
    Extract prompt strings from hh-rlhf chosen/rejected pairs for GRPO training.
    The format is: "\\n\\nHuman: <prompt>\\n\\nAssistant: <response>"
    Only the text before the first Assistant turn is kept as the prompt.

    Args:
        hh_ds: Iterable of dict rows carrying a "chosen" conversation string.
            Rows whose "chosen" value is missing or None are skipped.
        max_prompts: Upper bound on the number of prompts returned. 0 (or a
            negative value) yields an empty list; None means no limit.

    Returns:
        Deduplicated prompt strings in first-seen order, each at least 20
        characters long after stripping.
    """
    prompts: list[str] = []
    seen: set[str] = set()
    for row in hh_ds:
        # Check the limit before adding, so max_prompts=0 really returns nothing.
        if max_prompts is not None and len(prompts) >= max_prompts:
            break
        text = row.get("chosen") or ""
        # Everything before the first Assistant turn is the opening human prompt.
        head, _, _ = text.partition("\n\nAssistant:")
        raw_prompt = head.replace("\n\nHuman:", "").strip()
        # Filter too-short or duplicate prompts.
        if len(raw_prompt) < 20 or raw_prompt in seen:
            continue
        seen.add(raw_prompt)
        prompts.append(raw_prompt)
    logger.info(f"[Dataset] Built {len(prompts)} GRPO prompts from hh-rlhf")
    return prompts
def build_npc_scripts_from_datasets(
    toxic_ds,
    dialog_ds,
    *,
    max_manipulator: int = 30,
    max_per_type: int = 25,
) -> dict[str, list[str]]:
    """
    Build expanded NPC script pools from real dataset samples.

    Manipulator <- toxic-chat rows with a truthy toxicity_annotation or jailbreaking flag
    Agreeable   <- dialogue utterances with a commissive act (act == 4)
    Emotional   <- dialogue utterances with a directive act (act == 3) and emotion > 0
    Liar        <- inform utterances (act == 1, not the first turn) prefixed with a
                   contradiction lead-in

    Args:
        toxic_ds: Iterable of dict rows with "user_input", "toxicity_annotation"
            and "jailbreaking" keys.
        dialog_ds: Iterable of dict rows with parallel "dialog", "act" and
            "emotion" lists; missing or short label lists count as 0.
        max_manipulator: Cap on Manipulator lines (default 30).
        max_per_type: Cap on each of the Agreeable / Emotional / Liar pools
            (default 25).

    Returns:
        {"Manipulator": [...], "Agreeable": [...], "Emotional": [...], "Liar": [...]}
        Lines are truncated to 200 characters (150 for the quoted part of a
        Liar line).
    """
    # The separator was mis-encoded in the recovered source; restored as an
    # em dash. NOTE(review): confirm the original character.
    liar_prefix = "Actually, that's not what I said \u2014 "

    manipulator_lines: list[str] = []
    agreeable_lines: list[str] = []
    emotional_lines: list[str] = []
    liar_lines: list[str] = []

    # -- Manipulator from toxic-chat ------------------------------------------
    for row in toxic_ds:
        if len(manipulator_lines) >= max_manipulator:
            break
        user_input = row.get("user_input", "") or ""
        is_toxic = bool(row.get("toxicity_annotation", 0)) or bool(row.get("jailbreaking", False))
        if is_toxic and len(user_input) > 10:
            manipulator_lines.append(user_input[:200].strip())

    # -- Agreeable / Emotional / Liar from dialogue data ----------------------
    for row in dialog_ds:
        dialog = row.get("dialog") or []
        acts = row.get("act") or []
        emotions = row.get("emotion") or []
        for i, utterance in enumerate(dialog):
            if not isinstance(utterance, str):
                continue  # tolerate None / malformed turns instead of crashing
            utterance = utterance.strip()
            if len(utterance) < 10:
                continue
            act = acts[i] if i < len(acts) else 0
            emotion = (emotions[i] if i < len(emotions) else 0) or 0
            if act == 4 and len(agreeable_lines) < max_per_type:  # commissive -> agreeable
                agreeable_lines.append(utterance[:200])
            elif act == 3 and emotion > 0 and len(emotional_lines) < max_per_type:  # directive + emotion
                emotional_lines.append(utterance[:200])
            elif act == 1 and i > 0 and len(liar_lines) < max_per_type:  # inform -> contradiction
                liar_lines.append(f"{liar_prefix}{utterance[:150]}")
        # Stop scanning once all three dialogue-derived pools are full.
        if (len(agreeable_lines) >= max_per_type
                and len(emotional_lines) >= max_per_type
                and len(liar_lines) >= max_per_type):
            break

    result = {
        "Manipulator": manipulator_lines,
        "Agreeable": agreeable_lines,
        "Emotional": emotional_lines,
        "Liar": liar_lines,
    }
    for npc_type, lines in result.items():
        logger.info(f"[Dataset] NPC '{npc_type}': {len(lines)} dataset script lines")
    return result
def build_pii_test_cases(pii_ds, max_cases: int = 500) -> list[dict]:
    """
    Build PII test cases from ai4privacy dataset rows.
    Each case: {"text": original_text, "redacted": masked_text, "pii_labels": list}
    Used to benchmark Sentrix false-positive and false-negative rates.

    Args:
        pii_ds: Iterable of dict rows with "source_text", "target_text" and
            "privacy_mask" keys. Rows missing either text are skipped.
        max_cases: Upper bound on the number of cases returned. 0 (or a
            negative value) yields an empty list; None means no limit.

    Returns:
        A list of case dicts in input order. "pii_labels" holds the "label"
        value of every dict entry in the row's privacy mask.
    """
    import json

    cases = []
    for row in pii_ds:
        # Check the limit before adding, so max_cases=0 really returns nothing.
        if max_cases is not None and len(cases) >= max_cases:
            break
        source = row.get("source_text", "") or ""
        target = row.get("target_text", "") or ""
        masks = row.get("privacy_mask", []) or []
        if isinstance(masks, str):
            # Defensive: decode a JSON-encoded mask list rather than silently
            # producing zero labels, which would make the benchmark count the
            # row as PII-free.
            # NOTE(review): confirm whether the dataset revision in use ever
            # serialises the mask as text.
            try:
                masks = json.loads(masks)
            except ValueError:
                masks = []
        if not isinstance(masks, (list, tuple)):
            masks = []
        if source and target:
            cases.append({
                "text": source,
                "redacted": target,
                "pii_labels": [m.get("label", "") for m in masks if isinstance(m, dict)],
            })
    logger.info(f"[Dataset] Built {len(cases)} PII test cases from pii-masking-300k")
    return cases
def build_calibration_qa(tqa_ds, max_cases: int = 200) -> list[dict]:
    """
    Build calibration QA records from truthful_qa rows.
    Used to reward the model for expressing appropriate uncertainty.

    Args:
        tqa_ds: Iterable of dict rows with "question", "best_answer" and
            "correct_answers" keys. Rows missing the question or the best
            answer are skipped.
        max_cases: Upper bound on the number of records returned. 0 (or a
            negative value) yields an empty list; None means no limit.

    Returns:
        A list of dicts, in input order, shaped as
        {"question": str, "answer": best_answer, "correct_answers": list}.
    """
    qa_pairs = []
    for row in tqa_ds:
        # Check the limit before adding, so max_cases=0 really returns nothing.
        if max_cases is not None and len(qa_pairs) >= max_cases:
            break
        question = row.get("question", "") or ""
        best_answer = row.get("best_answer", "") or ""
        correct = row.get("correct_answers", []) or []
        if question and best_answer:
            qa_pairs.append({
                "question": question,
                "answer": best_answer,
                "correct_answers": correct,
            })
    logger.info(f"[Dataset] Built {len(qa_pairs)} calibration QA pairs from truthful_qa")
    return qa_pairs
# ── Sentrix Benchmarking ─────────────────────────────────────────────────────
def benchmark_sentrix(pii_test_cases: list[dict]) -> dict:
    """
    Run the Sentrix PII guard against PII test cases and score its detections.

    A case counts as containing PII when its "pii_labels" list is non-empty.
    A case counts as detected when Sentrix raises SentrixBlockException, or
    when its result has a truthy "pii_found" or a "severity" of "block" /
    "warn". Cases with no PII and no detection are not counted.

    Args:
        pii_test_cases: Dicts with a "text" key and an optional "pii_labels"
            list.

    Returns:
        {"true_positives": int, "false_positives": int, "false_negatives": int,
         "precision": float, "recall": float, "f1": float}
        The three ratios are rounded to 4 decimals and are 0.0 when their
        denominator is zero.
    """
    import sys
    from pathlib import Path as _P

    # Make the directory two levels above this file importable so `sentrix`
    # resolves. Add it only once: repeated calls used to stack duplicates.
    project_root = str(_P(__file__).parent.parent)
    if project_root not in sys.path:
        sys.path.insert(0, project_root)
    from sentrix.pii_guard import run as sentrix_run, SentrixBlockException

    tp = fp = fn = 0
    for case in pii_test_cases:
        has_pii = bool(case.get("pii_labels"))
        try:
            # NOTE(review): assumes sentrix_run returns a dict-like result - confirm.
            result = sentrix_run(case["text"])
            detected = bool(
                result.get("pii_found", False)
                or result.get("severity") in ("block", "warn")
            )
        except SentrixBlockException:
            # A hard block is itself a detection.
            detected = True
        if has_pii and detected:
            tp += 1
        elif detected:
            fp += 1
        elif has_pii:
            fn += 1

    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0.0
    metrics = {
        "true_positives": tp,
        "false_positives": fp,
        "false_negatives": fn,
        "precision": round(precision, 4),
        "recall": round(recall, 4),
        "f1": round(f1, 4),
    }
    logger.info(f"[Sentrix Benchmark] {metrics}")
    return metrics
# ── Quick smoke test ─────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Smoke test: load a small slice of every dataset, run each builder, then
    # benchmark Sentrix on the resulting PII cases.
    logging.basicConfig(level=logging.INFO)
    print("Loading all UMBRA datasets (this downloads ~2-3 GB on first run)...\n")
    hh_ds = load_hh_rlhf(max_samples=100)
    pii_ds = load_pii_masking(max_samples=100)
    tqa_ds = load_truthful_qa()
    toxic_ds = load_toxic_chat(max_samples=100)
    dialog_ds = load_daily_dialog(max_samples=100)
    prompts = build_grpo_prompts(hh_ds, max_prompts=10)
    npc_pool = build_npc_scripts_from_datasets(toxic_ds, dialog_ds)
    pii_cases = build_pii_test_cases(pii_ds, max_cases=50)
    qa_pairs = build_calibration_qa(tqa_ds, max_cases=20)
    npc_counts = {k: len(v) for k, v in npc_pool.items()}
    print("\n=== UMBRA Dataset Summary ===")
    print(f"  GRPO prompts      : {len(prompts)}")
    print(f"  NPC script lines  : {npc_counts}")
    print(f"  PII test cases    : {len(pii_cases)}")
    print(f"  Calibration QA    : {len(qa_pairs)}")
    if prompts:
        print(f"\n  Sample GRPO prompt : {prompts[0][:100]}...")
    else:
        # Filtering can leave no prompts; report it instead of raising IndexError.
        print("\n  Sample GRPO prompt : (no prompts extracted)")
    print(f"\nRunning Sentrix benchmark on {len(pii_cases)} PII cases...")
    metrics = benchmark_sentrix(pii_cases)
    print(f"  Sentrix F1: {metrics['f1']}  Precision: {metrics['precision']}  Recall: {metrics['recall']}")
    print("\nAll datasets OK.")