File size: 2,422 Bytes
d03f57f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
"""
Baseline Inference -- Llama 3.1 8B via HF Inference Router.
Set HF_READ_TOKEN to use the LLM, otherwise falls back to regex.
"""
import os, re, sys, uuid, requests

# Root URL of the scrub environment server; overridable through SCRUB_ENV_URL.
BASE_URL = os.environ.get("SCRUB_ENV_URL", "http://localhost:7860")
# Fixed identifier sent as "player_id" with every request.
PLAYER_ID = "baseline-llama-v1"
# Fresh random UUID generated once per process and sent as "session_id".
SESSION_ID = f"{uuid.uuid4()}"
# Hugging Face token; when unset (or empty) the regex fallback is used instead of the LLM.
HF_READ_TOKEN = os.environ.get("HF_READ_TOKEN")


def llama_scrub(text, instruction):
    """Redact PII from ``text`` using Llama 3.1 8B Instruct via the HF Inference Router.

    Args:
        text: Raw text to redact; sent as the user message.
        instruction: Task instruction; placed at the start of the system prompt,
            followed by the fixed redaction rules.

    Returns:
        The model's reply with surrounding whitespace stripped, or an empty
        string when the reply carries no text content.
    """
    # Imported inside the function, so the openai package is only needed
    # when this LLM path is actually taken.
    from openai import OpenAI

    client = OpenAI(base_url="https://router.huggingface.co/v1", api_key=HF_READ_TOKEN)
    r = client.chat.completions.create(
        model="meta-llama/Llama-3.1-8B-Instruct",
        messages=[
            {"role": "system", "content": f"{instruction}\nReplace PII with [REDACTED]. Keep Order/System IDs. Return ONLY the redacted text."},
            {"role": "user", "content": text},
        ],
        # temperature 0 keeps the baseline as repeatable as the backend allows.
        temperature=0.0, max_tokens=1024,
    )
    content = r.choices[0].message.content
    # The message content is optional in the chat-completions schema; guard
    # against None so we do not crash with AttributeError on .strip().
    return (content or "").strip()


# Heuristic patterns for the regex fallback.
# Phone: optional country-code-like prefix, a 2-4 digit group (optionally
# parenthesised), then two 3-4 digit groups; separators may be "-", "." or whitespace.
PHONE_RE = re.compile(r"(\+?\d{1,3}[-.\s]?)?(\(?\d{2,4}\)?[-.\s]?)(\d{3,4}[-.\s]?\d{3,4})")
# Email: local part, "@", domain, and a TLD of at least two letters.
EMAIL_RE = re.compile(r"[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}")
# Name: two or more consecutive Capitalized words separated by whitespace.
NAME_RE = re.compile(r"\b([A-Z][a-z]+(?:\s+[A-Z][a-z]+)+)\b")


def regex_scrub(text):
    """Replace every email, phone-number and name match in ``text`` with ``[REDACTED]``.

    The patterns are applied in a fixed order (emails, then phones, then
    names), each pass running on the output of the previous one.

    Args:
        text: The string to scrub.

    Returns:
        The scrubbed string.
    """
    scrubbed = text
    for pattern in (EMAIL_RE, PHONE_RE, NAME_RE):
        scrubbed = pattern.sub("[REDACTED]", scrubbed)
    return scrubbed

def scrub(text, instruction):
    """Redact ``text`` with the LLM when HF_READ_TOKEN is set, otherwise with regexes.

    Args:
        text: The string to redact.
        instruction: Task instruction; only used on the LLM path.

    Returns:
        The redacted string produced by whichever backend was selected.
    """
    if not HF_READ_TOKEN:
        return regex_scrub(text)
    return llama_scrub(text, instruction)


def run_task(tid):
    """Run one task against the environment server and return its score.

    Resets the environment for ``tid``, scrubs the returned text, submits the
    redaction through ``/step`` and prints a one-line summary.

    Args:
        tid: Task identifier sent as ``task_id`` to ``/reset``.

    Returns:
        The ``score`` field of the ``/step`` response, or 0.0 when it is absent.

    Raises:
        requests.HTTPError: If either endpoint answers with an error status.
        requests.RequestException: On connection failures or timeouts.
        KeyError: If the ``/reset`` response lacks ``original_text`` or ``instruction``.
    """
    # Bound each HTTP call so an unresponsive server cannot hang the run forever.
    timeout_s = 30
    identity = {"player_id": PLAYER_ID, "session_id": SESSION_ID}

    r = requests.post(f"{BASE_URL}/reset", json={**identity, "task_id": tid}, timeout=timeout_s)
    r.raise_for_status()
    obs = r.json()

    redacted = scrub(obs["original_text"], obs["instruction"])

    r = requests.post(
        f"{BASE_URL}/step",
        json={**identity, "action": {"action_id": 1, "redacted_text": redacted}},
        timeout=timeout_s,
    )
    r.raise_for_status()
    res = r.json()

    # Read every field defensively so a missing key cannot raise before the
    # score is returned (the return value already defaulted to 0.0).
    score = res.get("score", 0.0)
    reward = res.get("reward", 0.0)
    feedback = res.get("feedback", "")
    print(f"  {tid}: score={score:.2f}  reward={reward:.1f}  | {feedback}")
    return score


def main():
    """Run the baseline over task_1..task_3 and print the average score.

    Prints which backend is active (LLM when HF_READ_TOKEN is set, regex
    otherwise), checks the server's ``/health`` endpoint, and exits with
    status 1 if that check fails.
    """
    mode = "Llama-3.1-8B" if HF_READ_TOKEN else "Regex"
    print(f"\nBaseline [{mode}]\n" + "=" * 50)
    try:
        requests.get(f"{BASE_URL}/health", timeout=3).raise_for_status()
    except requests.RequestException:
        # Catch only HTTP/connection errors: a bare except would also swallow
        # KeyboardInterrupt/SystemExit and misreport them as an unreachable server.
        print(f"[ERROR] Server not reachable at {BASE_URL}")
        sys.exit(1)
    scores = {t: run_task(t) for t in ["task_1", "task_2", "task_3"]}
    print(f"\nAverage: {sum(scores.values())/len(scores):.2f}")

# Run the baseline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()