| """ |
| Baseline Inference -- Llama 3.1 8B via HF Inference Router. |
| Set HF_READ_TOKEN to use the LLM, otherwise falls back to regex. |
| """ |
| import os, re, sys, uuid, requests |
|
|
# Optional Hugging Face token; when unset the script uses the regex scrubber.
HF_READ_TOKEN = os.environ.get("HF_READ_TOKEN")
# Base URL of the scrub environment server, overridable via SCRUB_ENV_URL.
BASE_URL = os.environ.get("SCRUB_ENV_URL", "http://localhost:7860")
# Identity sent with every request; the session id is fresh for each run.
PLAYER_ID = "baseline-llama-v1"
SESSION_ID = f"{uuid.uuid4()}"
|
|
|
|
def llama_scrub(text, instruction):
    """Redact PII from ``text`` using Llama 3.1 8B via the HF Inference Router.

    Args:
        text: Raw text to redact; sent as the user message.
        instruction: Task instruction, placed first in the system prompt.

    Returns:
        The model's reply with surrounding whitespace stripped, or an empty
        string when the reply carries no text content.
    """
    # Function-local import: the module can be loaded (and the regex fallback
    # used) without the `openai` package installed.
    from openai import OpenAI

    client = OpenAI(base_url="https://router.huggingface.co/v1", api_key=HF_READ_TOKEN)
    response = client.chat.completions.create(
        model="meta-llama/Llama-3.1-8B-Instruct",
        messages=[
            {"role": "system", "content": f"{instruction}\nReplace PII with [REDACTED]. Keep Order/System IDs. Return ONLY the redacted text."},
            {"role": "user", "content": text},
        ],
        temperature=0.0,
        max_tokens=1024,
    )
    # `message.content` is optional in the API response; guard against None
    # so an empty reply does not crash with AttributeError on `.strip()`.
    content = response.choices[0].message.content
    return (content or "").strip()
|
|
|
|
# Heuristic PII patterns used by the regex fallback scrubber.
# Phone: optional country prefix, an area group (optionally parenthesised),
# then two digit groups; separators may be "-", "." or whitespace.
PHONE_RE = re.compile(
    r"(\+?\d{1,3}[-.\s]?)?(\(?\d{2,4}\)?[-.\s]?)(\d{3,4}[-.\s]?\d{3,4})"
)
# Email: local part, "@", domain, and a TLD of at least two letters.
EMAIL_RE = re.compile(
    r"[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}"
)
# Name: two or more consecutive Capitalised words.
NAME_RE = re.compile(
    r"\b([A-Z][a-z]+(?:\s+[A-Z][a-z]+)+)\b"
)


def regex_scrub(text):
    """Return ``text`` with emails, phone numbers and names replaced by ``[REDACTED]``.

    The patterns are applied one after another in a fixed order: emails,
    then phone numbers, then names.
    """
    placeholder = "[REDACTED]"
    scrubbed = text
    for pattern in (EMAIL_RE, PHONE_RE, NAME_RE):
        scrubbed = pattern.sub(placeholder, scrubbed)
    return scrubbed
|
|
def scrub(text, instruction):
    """Redact ``text`` with the LLM when HF_READ_TOKEN is set, else with regexes.

    ``instruction`` is only used on the LLM path.
    """
    if not HF_READ_TOKEN:
        return regex_scrub(text)
    return llama_scrub(text, instruction)
|
|
|
|
def run_task(tid, timeout=30):
    """Run one task against the environment server and return its score.

    Posts to ``/reset`` to fetch the task observation, scrubs its
    ``original_text`` according to its ``instruction``, submits the result
    to ``/step`` and prints a one-line summary.

    Args:
        tid: Task identifier sent as ``task_id``.
        timeout: Seconds to wait for each HTTP request before giving up, so
            an unresponsive server cannot hang the run indefinitely.

    Returns:
        The ``score`` field of the step response, or ``0.0`` if it is absent.

    Raises:
        requests.RequestException: On connection failure, timeout, or a
            non-2xx response from either endpoint.
    """
    identity = {"player_id": PLAYER_ID, "session_id": SESSION_ID}

    r = requests.post(f"{BASE_URL}/reset", json={**identity, "task_id": tid}, timeout=timeout)
    r.raise_for_status()
    obs = r.json()

    redacted = scrub(obs["original_text"], obs["instruction"])

    action = {"action_id": 1, "redacted_text": redacted}
    r = requests.post(f"{BASE_URL}/step", json={**identity, "action": action}, timeout=timeout)
    r.raise_for_status()
    res = r.json()

    # Read every field with a default so a partial response cannot raise
    # KeyError in the summary line before the score is returned.
    score = res.get("score", 0.0)
    reward = res.get("reward", 0.0)
    feedback = res.get("feedback", "")
    print(f" {tid}: score={score:.2f} reward={reward:.1f} | {feedback}")
    return score
|
|
|
|
def main():
    """Check that the server is up, run the three baseline tasks, print the average.

    Exits with status 1 if the ``/health`` check fails.
    """
    mode = "Llama-3.1-8B" if HF_READ_TOKEN else "Regex"
    print(f"\nBaseline [{mode}]\n" + "=" * 50)
    try:
        requests.get(f"{BASE_URL}/health", timeout=3).raise_for_status()
    except requests.RequestException:
        # Only network/HTTP failures mean "server not reachable"; a bare
        # except would also swallow KeyboardInterrupt and unrelated bugs.
        print(f"[ERROR] Server not reachable at {BASE_URL}")
        sys.exit(1)
    scores = {t: run_task(t) for t in ["task_1", "task_2", "task_3"]}
    print(f"\nAverage: {sum(scores.values())/len(scores):.2f}")


if __name__ == "__main__":
    main()
|
|