File size: 5,706 Bytes
d03f57f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
"""
PII-Scrub-Assistant-v1 -- Environment
======================================
3 tasks, programmatic graders (0.0-1.0), partial-progress reward.
"""
from __future__ import annotations
import random
from typing import Optional
from models import ScrubObservation, ScrubState


# Task catalogue, keyed by task id. Each entry holds:
#   difficulty    -- human-facing label ("easy"/"medium"/"hard"); compute_reward
#                    does not read it
#   instruction   -- prompt shown to the agent alongside the text
#   original_text -- the text the agent must redact
#   pii_items     -- {category: [exact strings]}; every listed string must be
#                    absent from the submission (case-sensitive substring test),
#                    and each category is one unit of reward/score
#   order_ids     -- strings that must survive redaction verbatim; an empty list
#                    disables the order-ID check
TASKS = {
    "task_1": {
        "difficulty": "easy",
        "instruction": "Redact the phone number in the following text. Replace any PII with [REDACTED].",
        "original_text": "Hi, please call me back at 555-867-5309 regarding my order.",
        "pii_items": {"phone": ["555-867-5309"]},
        "order_ids": [],
    },
    "task_2": {
        "difficulty": "medium",
        "instruction": "Redact all PII (names and email addresses) in the following paragraph. Replace each piece of PII with [REDACTED].",
        # The name occurs twice, so a correct answer contains three [REDACTED] tags.
        "original_text": "Dear Support,\n\nMy name is Johnathan Doe and I recently purchased your premium plan. Could you send the invoice to johndoe92@example.com? I need it for my expense report.\n\nThanks,\nJohnathan Doe",
        "pii_items": {"name": ["Johnathan Doe"], "email": ["johndoe92@example.com"]},
        "order_ids": [],
    },
    "task_3": {
        "difficulty": "hard",
        # NOTE(review): the instruction mentions System IDs, but this text contains
        # only order IDs; order_ids below is the only thing actually checked.
        "instruction": "Redact all PII (names and email addresses) but DO NOT redact Order IDs (e.g. #12345) or System IDs. Replace each piece of PII with [REDACTED].",
        "original_text": "Hello,\n\nI'm Alice Johnson and I need help with order #98432. My email is alice.johnson@corpmail.net and my secondary contact is alice_j@inbox.org. Please also check order #10271.\n\nRegards,\nAlice Johnson",
        "pii_items": {"name": ["Alice Johnson"], "email": ["alice.johnson@corpmail.net", "alice_j@inbox.org"]},
        "order_ids": ["#98432", "#10271"],
    },
}


def _all_pii_removed(redacted, pii_items):
    return {t: all(v not in redacted for v in vs) for t, vs in pii_items.items()}

def _order_ids_preserved(redacted, order_ids):
    return all(oid in redacted for oid in order_ids)

def _count_over_redactions(original, redacted, pii_items):
    expected = sum(original.count(v) for vs in pii_items.values() for v in vs)
    return max(0, redacted.count("[REDACTED]") - expected)

def _structural_damage(original, redacted, pii_items):
    tag = "[REDACTED]"
    delta = sum(original.count(v) * (len(tag) - len(v)) for vs in pii_items.values() for v in vs)
    exp_len = len(original) + delta
    return abs(len(redacted) - exp_len) > max(20, int(exp_len * 0.15))


def compute_reward(original, redacted, pii_items, order_ids):
    """Grade one redaction attempt.

    Args:
        original: the unredacted task text.
        redacted: the agent's submission.
        pii_items: {category: [exact PII strings]} that must disappear.
        order_ids: strings that must survive verbatim (may be empty).

    Returns:
        (reward, score, feedback), where:

        - reward: +0.5 per PII category fully removed, -0.2 per
          over-redaction, -0.3 if any order/system ID is missing, and -1.0
          on structural damage. It is not clamped.
        - score: the fraction of PII categories fully removed, minus 0.05
          per over-redaction, 0.2 for a missing order/system ID, and 0.3 for
          structural damage, clamped to [0.0, 1.0].
        - feedback: the individual verdicts joined with " | ".
    """
    parts, reward, n_ok = [], 0.0, 0
    for category, ok in _all_pii_removed(redacted, pii_items).items():
        if ok:
            reward += 0.5
            n_ok += 1
            parts.append(f"[OK] {category} redacted")
        else:
            parts.append(f"[FAIL] {category} NOT fully redacted")

    over = _count_over_redactions(original, redacted, pii_items)
    # Evaluate each check once so reward, score and feedback always agree.
    ids_lost = bool(order_ids) and not _order_ids_preserved(redacted, order_ids)
    damaged = _structural_damage(original, redacted, pii_items)

    if over > 0:
        reward -= over * 0.2
        parts.append(f"[FAIL] Over-redacted ({over}), -{over*0.2:.1f}")
    if order_ids:
        if ids_lost:
            reward -= 0.3
            parts.append("[FAIL] Order/System ID(s) redacted (-0.3)")
        else:
            parts.append("[OK] Order/System IDs preserved")
    if damaged:
        reward -= 1.0
        parts.append("[FAIL] Structural damage (-1.0)")

    total = len(pii_items)
    # With no PII categories there is nothing to miss, so start from full credit.
    base = n_ok / total if total else 1.0
    ded = over * 0.05
    if ids_lost:
        ded += 0.2
    if damaged:
        ded += 0.3
    score = max(0.0, min(1.0, base - ded))
    return reward, score, " | ".join(parts)


class ScrubEnvironment:
    """Single-attempt PII-redaction environment.

    ``reset`` picks a task (at random when ``task_id`` is omitted) and binds
    the episode to a player/session pair. ``step`` grades one submission
    and ends the episode. ``state`` reports a snapshot of the current episode.
    """

    def __init__(self):
        # No episode is active until reset() is called; step() refuses to run.
        self._task_id = None
        self._task = None
        self._player_id = None
        self._session_id = None
        self._step_count = 0
        self._done = True
        self._last_score = None
        self._last_reward = None

    def _validate(self, pid, sid):
        """Raise PermissionError if pid/sid differ from the bound player/session.

        A falsy binding (e.g. before the first reset) is not checked.
        """
        bound_player = self._player_id
        bound_session = self._session_id
        if bound_player and pid != bound_player:
            raise PermissionError(f"Player mismatch: expected {self._player_id!r}")
        if bound_session and sid != bound_session:
            raise PermissionError(f"Session mismatch: expected {self._session_id!r}")

    def reset(self, player_id, session_id, task_id=None):
        """Start a new episode and return its initial observation.

        Raises:
            ValueError: if *task_id* is given but not a key of TASKS.
        """
        if task_id is None:
            task_id = random.choice(list(TASKS))
        if task_id not in TASKS:
            raise ValueError(f"Unknown task: {task_id}")
        task = TASKS[task_id]
        self._player_id = player_id
        self._session_id = session_id
        self._task_id = task_id
        self._task = task
        self._step_count = 0
        self._done = False
        self._last_score = None
        self._last_reward = None
        return ScrubObservation(
            task_id=task_id,
            original_text=task["original_text"],
            instruction=task["instruction"],
        )

    def step(self, player_id, session_id, redacted_text):
        """Grade *redacted_text* and end the episode.

        Raises:
            RuntimeError: if no episode is active.
            PermissionError: if the caller does not match the bound player/session.
        """
        if self._done or not self._task:
            raise RuntimeError("Call /reset first.")
        self._validate(player_id, session_id)
        self._step_count += 1
        task = self._task
        original = task["original_text"]
        reward, score, feedback = compute_reward(
            original, redacted_text, task["pii_items"], task["order_ids"]
        )
        self._done = True
        self._last_score = score
        self._last_reward = reward
        return ScrubObservation(
            task_id=self._task_id,
            original_text=original,
            instruction=task["instruction"],
            score=score,
            reward=reward,
            done=True,
            feedback=feedback,
        )

    def state(self):
        """Return a ScrubState snapshot of the current (or last) episode."""
        task = self._task
        return ScrubState(
            current_task_id=self._task_id,
            player_id=self._player_id,
            session_id=self._session_id,
            step_count=self._step_count,
            done=self._done,
            last_score=self._last_score,
            last_reward=self._last_reward,
            original_text=task["original_text"] if task else None,
        )