"""
PII-Scrub-Assistant-v1 -- Environment
======================================
3 tasks, programmatic graders (0.0-1.0), partial-progress reward.

``compute_reward`` returns two numbers per submission:

* ``reward`` -- shaped signal: +0.5 per fully redacted PII category, -0.2 per
  surplus ``[REDACTED]`` tag, -0.3 if any protected order ID is missing and
  -1.0 for structural damage.  It is not clamped.
* ``score`` -- fraction of PII categories fully redacted minus smaller
  deductions (0.05 per surplus tag, 0.2 for lost order IDs, 0.3 for
  structural damage), clamped to [0.0, 1.0].
"""
from __future__ import annotations

import random
from typing import Optional

from models import ScrubObservation, ScrubState

TASKS = {
    "task_1": {
        "difficulty": "easy",
        "instruction": "Redact the phone number in the following text. Replace any PII with [REDACTED].",
        "original_text": "Hi, please call me back at 555-867-5309 regarding my order.",
        "pii_items": {"phone": ["555-867-5309"]},
        "order_ids": [],
    },
    "task_2": {
        "difficulty": "medium",
        "instruction": (
            "Redact all PII (names and email addresses) in the following paragraph. "
            "Replace each piece of PII with [REDACTED]."
        ),
        "original_text": (
            "Dear Support,\n\n"
            "My name is Johnathan Doe and I recently purchased your premium plan. "
            "Could you send the invoice to johndoe92@example.com? "
            "I need it for my expense report.\n\n"
            "Thanks,\nJohnathan Doe"
        ),
        "pii_items": {"name": ["Johnathan Doe"], "email": ["johndoe92@example.com"]},
        "order_ids": [],
    },
    "task_3": {
        "difficulty": "hard",
        "instruction": (
            "Redact all PII (names and email addresses) but DO NOT redact Order IDs "
            "(e.g. #12345) or System IDs. Replace each piece of PII with [REDACTED]."
        ),
        "original_text": (
            "Hello,\n\n"
            "I'm Alice Johnson and I need help with order #98432. "
            "My email is alice.johnson@corpmail.net and my secondary contact is "
            "alice_j@inbox.org. Please also check order #10271.\n\n"
            "Regards,\nAlice Johnson"
        ),
        "pii_items": {
            "name": ["Alice Johnson"],
            "email": ["alice.johnson@corpmail.net", "alice_j@inbox.org"],
        },
        "order_ids": ["#98432", "#10271"],
    },
}

# Replacement token every piece of PII is expected to become.
_TAG = "[REDACTED]"


def _all_pii_removed(redacted, pii_items):
    """Map each PII category to True when none of its values survive in *redacted*.

    Matching ignores case (``str.casefold``) so that a re-cased value such as
    "ALICE JOHNSON" is still reported as a leak instead of a redaction.
    """
    folded = redacted.casefold()
    return {
        category: all(value.casefold() not in folded for value in values)
        for category, values in pii_items.items()
    }


def _order_ids_preserved(redacted, order_ids):
    """Return True when every protected order/system ID appears verbatim in *redacted*."""
    return all(oid in redacted for oid in order_ids)


def _count_over_redactions(original, redacted, pii_items):
    """Return how many ``[REDACTED]`` tags exceed the number of PII occurrences.

    The expected tag count is the total number of (non-overlapping) occurrences
    of every PII value in *original*; fewer tags than that is not penalised here.
    """
    expected = sum(original.count(v) for vs in pii_items.values() for v in vs)
    return max(0, redacted.count(_TAG) - expected)


def _structural_damage(original, redacted, pii_items):
    """Return True when *redacted*'s length strays too far from the ideal redaction.

    The ideal length assumes every PII occurrence is replaced by exactly one tag.
    The tolerance is the larger of 20 characters or 15% of that ideal length.
    """
    delta = sum(
        original.count(v) * (len(_TAG) - len(v))
        for vs in pii_items.values()
        for v in vs
    )
    exp_len = len(original) + delta
    return abs(len(redacted) - exp_len) > max(20, int(exp_len * 0.15))


def compute_reward(original, redacted, pii_items, order_ids):
    """Grade one redaction attempt.

    Args:
        original: The unredacted task text.
        redacted: The agent's submission.
        pii_items: Mapping of PII category -> list of literal values to remove.
        order_ids: Literal IDs that must survive unchanged (may be empty).

    Returns:
        ``(reward, score, feedback)``: the unclamped shaped reward, the score
        clamped to [0.0, 1.0], and a " | "-joined human-readable breakdown.
    """
    parts, reward, n_ok = [], 0.0, 0
    for category, ok in _all_pii_removed(redacted, pii_items).items():
        if ok:
            reward += 0.5
            n_ok += 1
            parts.append(f"[OK] {category} redacted")
        else:
            parts.append(f"[FAIL] {category} NOT fully redacted")

    over = _count_over_redactions(original, redacted, pii_items)
    if over > 0:
        reward -= over * 0.2
        parts.append(f"[FAIL] Over-redacted ({over}), -{over*0.2:.1f}")

    # Each check is evaluated once and shared by the reward and the score so the
    # two can never disagree.
    ids_lost = bool(order_ids) and not _order_ids_preserved(redacted, order_ids)
    if order_ids:
        if ids_lost:
            reward -= 0.3
            parts.append("[FAIL] Order/System ID(s) redacted (-0.3)")
        else:
            parts.append("[OK] Order/System IDs preserved")

    damaged = _structural_damage(original, redacted, pii_items)
    if damaged:
        reward -= 1.0
        parts.append("[FAIL] Structural damage (-1.0)")

    total = len(pii_items)
    base = n_ok / total if total else 1.0
    ded = over * 0.05
    if ids_lost:
        ded += 0.2
    if damaged:
        ded += 0.3
    score = max(0.0, min(1.0, base - ded))
    return reward, score, " | ".join(parts)


class ScrubEnvironment:
    """Single-submission PII-redaction episode bound to one player/session pair.

    ``reset`` selects a task and binds the caller's ids; ``step`` grades exactly
    one submission and ends the episode; ``state`` returns a snapshot.
    """

    def __init__(self):
        self._task_id = self._task = self._player_id = self._session_id = None
        self._step_count = 0
        self._done = True  # no episode until reset() is called
        self._last_score = self._last_reward = None

    def _validate(self, pid, sid):
        """Raise PermissionError if *pid*/*sid* do not match the bound ids.

        A falsy bound id (e.g. None before the first reset) disables that check.
        The expected ids are deliberately NOT echoed in the error message: doing
        so would let a mismatched caller learn the real ids and retry with them.
        """
        if self._player_id and pid != self._player_id:
            raise PermissionError("Player mismatch")
        if self._session_id and sid != self._session_id:
            raise PermissionError("Session mismatch")

    def reset(self, player_id, session_id, task_id: Optional[str] = None):
        """Start a new episode on *task_id* (random when None) and bind the caller.

        Raises:
            ValueError: if *task_id* is not a key of ``TASKS``.
        """
        if task_id is None:
            task_id = random.choice(list(TASKS))
        if task_id not in TASKS:
            raise ValueError(f"Unknown task: {task_id}")
        self._player_id, self._session_id = player_id, session_id
        self._task_id, self._task = task_id, TASKS[task_id]
        self._step_count, self._done = 0, False
        self._last_score = self._last_reward = None
        return ScrubObservation(
            task_id=task_id,
            original_text=self._task["original_text"],
            instruction=self._task["instruction"],
        )

    def step(self, player_id, session_id, redacted_text: str):
        """Grade *redacted_text* and finish the episode.

        Raises:
            RuntimeError: if no episode is active (not reset, or already graded).
            PermissionError: if the caller's ids do not match the bound ones.
            TypeError: if *redacted_text* is not a ``str``.
        """
        if self._done or not self._task:
            raise RuntimeError("Call /reset first.")
        self._validate(player_id, session_id)
        # Non-str input (e.g. a list) would otherwise be "graded" via list.count/len.
        if not isinstance(redacted_text, str):
            raise TypeError(
                f"redacted_text must be str, got {type(redacted_text).__name__}"
            )
        self._step_count += 1
        reward, score, feedback = compute_reward(
            self._task["original_text"],
            redacted_text,
            self._task["pii_items"],
            self._task["order_ids"],
        )
        self._done, self._last_score, self._last_reward = True, score, reward
        return ScrubObservation(
            task_id=self._task_id,
            original_text=self._task["original_text"],
            instruction=self._task["instruction"],
            score=score,
            reward=reward,
            done=True,
            feedback=feedback,
        )

    def state(self):
        """Return a ScrubState snapshot of the current (or last) episode."""
        return ScrubState(
            current_task_id=self._task_id,
            player_id=self._player_id,
            session_id=self._session_id,
            step_count=self._step_count,
            done=self._done,
            last_score=self._last_score,
            last_reward=self._last_reward,
            original_text=self._task["original_text"] if self._task else None,
        )