# Source: PII-Scrub-Final-Submission / environment.py
# Page metadata captured with the file: krishuggingface;
# "Upload folder using huggingface_hub"; d03f57f (verified).
"""
PII-Scrub-Assistant-v1 -- Environment
======================================
3 tasks, programmatic graders (0.0-1.0), partial-progress reward.
"""
from __future__ import annotations
import random
from typing import Optional
from models import ScrubObservation, ScrubState
# Registry of scrub tasks, keyed by task id. Every entry carries:
#   difficulty    -- label for the task ("easy" / "medium" / "hard")
#   instruction   -- prompt text handed to the agent
#   original_text -- raw text that contains the PII
#   pii_items     -- PII type -> literal strings that must not survive redaction
#   order_ids     -- literal identifiers that must still be present afterwards
TASKS = {
    # Single phone number, nothing to preserve.
    "task_1": {
        "difficulty": "easy",
        "instruction": "Redact the phone number in the following text. Replace any PII with [REDACTED].",
        "original_text": "Hi, please call me back at 555-867-5309 regarding my order.",
        "pii_items": {
            "phone": ["555-867-5309"],
        },
        "order_ids": [],
    },
    # A name (appearing twice) plus one e-mail address.
    "task_2": {
        "difficulty": "medium",
        "instruction": "Redact all PII (names and email addresses) in the following paragraph. Replace each piece of PII with [REDACTED].",
        "original_text": "Dear Support,\n\nMy name is Johnathan Doe and I recently purchased your premium plan. Could you send the invoice to johndoe92@example.com? I need it for my expense report.\n\nThanks,\nJohnathan Doe",
        "pii_items": {
            "name": ["Johnathan Doe"],
            "email": ["johndoe92@example.com"],
        },
        "order_ids": [],
    },
    # Name and two e-mail addresses, with order ids that must be left intact.
    "task_3": {
        "difficulty": "hard",
        "instruction": "Redact all PII (names and email addresses) but DO NOT redact Order IDs (e.g. #12345) or System IDs. Replace each piece of PII with [REDACTED].",
        "original_text": "Hello,\n\nI'm Alice Johnson and I need help with order #98432. My email is alice.johnson@corpmail.net and my secondary contact is alice_j@inbox.org. Please also check order #10271.\n\nRegards,\nAlice Johnson",
        "pii_items": {
            "name": ["Alice Johnson"],
            "email": ["alice.johnson@corpmail.net", "alice_j@inbox.org"],
        },
        "order_ids": ["#98432", "#10271"],
    },
}
def _all_pii_removed(redacted, pii_items):
return {t: all(v not in redacted for v in vs) for t, vs in pii_items.items()}
def _order_ids_preserved(redacted, order_ids):
return all(oid in redacted for oid in order_ids)
def _count_over_redactions(original, redacted, pii_items):
expected = sum(original.count(v) for vs in pii_items.values() for v in vs)
return max(0, redacted.count("[REDACTED]") - expected)
def _structural_damage(original, redacted, pii_items):
tag = "[REDACTED]"
delta = sum(original.count(v) * (len(tag) - len(v)) for vs in pii_items.values() for v in vs)
exp_len = len(original) + delta
return abs(len(redacted) - exp_len) > max(20, int(exp_len * 0.15))
def compute_reward(original, redacted, pii_items, order_ids):
    """Grade a redaction attempt and explain the result.

    Two numbers are produced from the same four checks (PII removal,
    over-redaction, order-ID preservation, structural damage):

    * reward -- unclamped shaping signal: +0.5 for each PII type that is fully
      removed, -0.2 per surplus "[REDACTED]" tag, -0.3 if *order_ids* is
      non-empty and any ID is missing, -1.0 on structural damage.
    * score -- fraction of PII types fully removed (1.0 when *pii_items* is
      empty), minus 0.05 per surplus tag, 0.2 for missing order IDs and 0.3
      for structural damage, clamped to the range 0.0-1.0.

    Args:
        original: Unredacted source text.
        redacted: Candidate text submitted for grading.
        pii_items: Mapping of PII type name to the literal strings of that type.
        order_ids: Literal identifiers that must remain in the text; may be empty.

    Returns:
        A ``(reward, score, feedback)`` tuple, where feedback is the individual
        check messages joined with " | ".
    """
    # NOTE(review): an empty submission passes every PII-removal check, so it
    # is penalised only via the order-ID and structural-damage deductions --
    # confirm that this is the intended grading behaviour.

    # Run each check once; the helpers only read their arguments.
    removal_by_type = _all_pii_removed(redacted, pii_items)
    over = _count_over_redactions(original, redacted, pii_items)
    ids_lost = bool(order_ids) and not _order_ids_preserved(redacted, order_ids)
    damaged = _structural_damage(original, redacted, pii_items)

    messages = []
    reward = 0.0
    types_cleared = 0

    for pii_type, cleared in removal_by_type.items():
        if cleared:
            reward += 0.5
            types_cleared += 1
            messages.append(f"[OK] {pii_type} redacted")
        else:
            messages.append(f"[FAIL] {pii_type} NOT fully redacted")

    if over > 0:
        reward -= over * 0.2
        messages.append(f"[FAIL] Over-redacted ({over}), -{over*0.2:.1f}")

    # The order-ID line is only reported for tasks that define IDs.
    if order_ids:
        if ids_lost:
            reward -= 0.3
            messages.append("[FAIL] Order/System ID(s) redacted (-0.3)")
        else:
            messages.append("[OK] Order/System IDs preserved")

    if damaged:
        reward -= 1.0
        messages.append("[FAIL] Structural damage (-1.0)")

    # Normalised score: removal ratio minus deductions, clamped to [0, 1].
    type_count = len(pii_items)
    base = types_cleared / type_count if type_count else 1.0
    deduction = over * 0.05
    if ids_lost:
        deduction += 0.2
    if damaged:
        deduction += 0.3
    score = max(0.0, min(1.0, base - deduction))

    return reward, score, " | ".join(messages)
class ScrubEnvironment:
    """Single-episode environment for the PII scrub tasks.

    An episode is started with ``reset()``, which binds a player/session pair
    and a task, and is finished by the first ``step()``, which grades one
    submission. ``state()`` exposes a snapshot of the bookkeeping fields.
    """

    def __init__(self):
        """Create an environment with no active episode."""
        self._task_id = None
        self._task = None
        self._player_id = None
        self._session_id = None
        self._step_count = 0
        # Starts "done" so that step() is rejected until reset() is called.
        self._done = True
        self._last_score = None
        self._last_reward = None

    def _validate(self, pid, sid):
        """Check the caller against the bound player and session ids.

        A stored id that is falsy (e.g. not set yet) is not enforced.

        Raises:
            PermissionError: If *pid* or *sid* differs from the bound value.
        """
        bound_player = self._player_id
        if bound_player and pid != bound_player:
            raise PermissionError(f"Player mismatch: expected {bound_player!r}")
        bound_session = self._session_id
        if bound_session and sid != bound_session:
            raise PermissionError(f"Session mismatch: expected {bound_session!r}")

    def reset(self, player_id, session_id, task_id=None):
        """Start a new episode and return its initial observation.

        Args:
            player_id: Id to bind the episode to.
            session_id: Id to bind the episode to.
            task_id: Key into TASKS; a random task is chosen when None.

        Returns:
            A ScrubObservation with the task id, original text and instruction.

        Raises:
            ValueError: If *task_id* is not a key of TASKS.
        """
        if task_id is None:
            task_id = random.choice(list(TASKS))
        if task_id not in TASKS:
            raise ValueError(f"Unknown task: {task_id}")

        task = TASKS[task_id]
        self._player_id = player_id
        self._session_id = session_id
        self._task_id = task_id
        self._task = task
        self._step_count = 0
        self._done = False
        self._last_score = None
        self._last_reward = None

        return ScrubObservation(
            task_id=task_id,
            original_text=task["original_text"],
            instruction=task["instruction"],
        )

    def step(self, player_id, session_id, redacted_text):
        """Grade *redacted_text* for the active task and end the episode.

        Returns:
            A ScrubObservation carrying score, reward, feedback and done=True.

        Raises:
            RuntimeError: If there is no active episode (never reset, or the
                episode has already been stepped).
            PermissionError: If the caller ids do not match the bound ones.
        """
        if self._done or not self._task:
            raise RuntimeError("Call /reset first.")
        self._validate(player_id, session_id)

        self._step_count += 1
        task = self._task
        reward, score, feedback = compute_reward(
            task["original_text"],
            redacted_text,
            task["pii_items"],
            task["order_ids"],
        )

        # One submission per episode: mark finished and remember the result.
        self._done = True
        self._last_score = score
        self._last_reward = reward

        return ScrubObservation(
            task_id=self._task_id,
            original_text=task["original_text"],
            instruction=task["instruction"],
            score=score,
            reward=reward,
            done=True,
            feedback=feedback,
        )

    def state(self):
        """Return a ScrubState snapshot of the current bookkeeping fields.

        ``original_text`` is None when no task has been selected yet.
        """
        current_text = self._task["original_text"] if self._task else None
        return ScrubState(
            current_task_id=self._task_id,
            player_id=self._player_id,
            session_id=self._session_id,
            step_count=self._step_count,
            done=self._done,
            last_score=self._last_score,
            last_reward=self._last_reward,
            original_text=current_text,
        )