# PII-Scrub-Final-Submission / baseline_inference.py
# Author: krishuggingface — uploaded with huggingface_hub (commit d03f57f, verified).
"""
Baseline Inference -- Llama 3.1 8B via HF Inference Router.
Set HF_READ_TOKEN to use the LLM; otherwise the script falls back to regex-based redaction.
"""
import os, re, sys, uuid, requests
# Base URL of the scrub environment server. A trailing "/" is stripped so that
# f"{BASE_URL}/reset" and f"{BASE_URL}/step" never contain a double slash.
BASE_URL = os.getenv("SCRUB_ENV_URL", "http://localhost:7860").rstrip("/")
# Identifiers sent with every /reset and /step request.
PLAYER_ID = "baseline-llama-v1"
SESSION_ID = str(uuid.uuid4())  # fresh random id for each run of the script
# When set to a non-empty value, scrub() uses the LLM; otherwise it uses the regex fallback.
HF_READ_TOKEN = os.getenv("HF_READ_TOKEN")
def llama_scrub(text, instruction):
    """Redact PII from *text* with Llama 3.1 8B through the HF Inference Router.

    Args:
        text: The document to redact; sent as the user message.
        instruction: Task-specific guidance; prepended to the fixed redaction
            rules in the system prompt.

    Returns:
        The model reply with surrounding whitespace stripped. If the reply has
        no text content (``content`` is ``None`` or empty), the result of
        ``regex_scrub(text)`` is returned instead of crashing.
    """
    # Imported here rather than at module top, so the regex-only path
    # (no HF_READ_TOKEN) does not need the openai package installed.
    from openai import OpenAI

    client = OpenAI(base_url="https://router.huggingface.co/v1", api_key=HF_READ_TOKEN)
    r = client.chat.completions.create(
        model="meta-llama/Llama-3.1-8B-Instruct",
        messages=[
            {"role": "system", "content": f"{instruction}\nReplace PII with [REDACTED]. Keep Order/System IDs. Return ONLY the redacted text."},
            {"role": "user", "content": text},
        ],
        temperature=0.0, max_tokens=1024,
    )
    # `message.content` is Optional in the chat-completions schema; calling
    # .strip() on None would raise AttributeError and abort the whole run.
    content = r.choices[0].message.content
    if not content:
        return regex_scrub(text)
    return content.strip()
# Phone numbers: optional country code, an area code, then a 3-4 + 3-4 digit
# subscriber number. The lookarounds stop a match from starting or ending inside
# a longer alphanumeric or hyphenated token, so digit runs in IDs such as
# "ORD-12345678" or "SYS98765432" are left intact.
PHONE_RE = re.compile(r"(?<![\w-])(\+?\d{1,3}[-.\s]?)?(\(?\d{2,4}\)?[-.\s]?)(\d{3,4}[-.\s]?\d{3,4})(?!\w)")
EMAIL_RE = re.compile(r"[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}")
# Names: two or more consecutive Capitalized words. Separators are spaces/tabs
# only, so a match never swallows a line break between two lines.
NAME_RE = re.compile(r"\b([A-Z][a-z]+(?:[ \t]+[A-Z][a-z]+)+)\b")


def regex_scrub(text):
    """Replace emails, phone numbers and capitalized name runs with "[REDACTED]".

    Emails are replaced first, so their local parts and domains are not
    partially consumed by the phone or name patterns.

    Args:
        text: The document to redact.

    Returns:
        The redacted text.
    """
    tag = "[REDACTED]"
    text = EMAIL_RE.sub(tag, text)
    text = PHONE_RE.sub(tag, text)
    text = NAME_RE.sub(tag, text)
    return text
def scrub(text, instruction):
    """Redact *text*, choosing the backend from the HF_READ_TOKEN setting.

    Without a (non-empty) token the regex fallback is used and *instruction*
    is ignored; with one, both arguments go to the Llama backend.
    """
    if not HF_READ_TOKEN:
        return regex_scrub(text)
    return llama_scrub(text, instruction)
def run_task(tid, timeout=60):
    """Play one episode of task *tid* against the scrub environment.

    Calls ``/reset`` for the task, redacts the returned ``original_text`` with
    ``scrub`` (passing the returned ``instruction``), submits the result to
    ``/step`` and prints the score, reward and feedback from the response.

    Args:
        tid: Task identifier sent as ``task_id`` to ``/reset``.
        timeout: Per-request timeout in seconds for each HTTP call.

    Returns:
        The ``score`` from the ``/step`` response, or 0.0 if it is absent.

    Raises:
        requests.HTTPError: If either request returns an error status.
    """
    r = requests.post(
        f"{BASE_URL}/reset",
        json={"player_id": PLAYER_ID, "session_id": SESSION_ID, "task_id": tid},
        timeout=timeout,
    )
    r.raise_for_status()
    obs = r.json()

    redacted = scrub(obs["original_text"], obs["instruction"])

    r = requests.post(
        f"{BASE_URL}/step",
        json={
            "player_id": PLAYER_ID,
            "session_id": SESSION_ID,
            "action": {"action_id": 1, "redacted_text": redacted},
        },
        timeout=timeout,
    )
    r.raise_for_status()
    res = r.json()

    # Missing fields fall back to defaults, so the summary line can always be
    # printed and the returned score stays consistent with what was printed.
    score = res.get("score", 0.0)
    reward = res.get("reward", 0.0)
    feedback = res.get("feedback", "")
    print(f" {tid}: score={score:.2f} reward={reward:.1f} | {feedback}")
    return score
def main():
    """Check the server is reachable, run every task, and print the average score.

    Exits with status 1 if ``/health`` cannot be reached within 3 seconds or
    returns an error status.
    """
    mode = "Llama-3.1-8B" if HF_READ_TOKEN else "Regex"
    print(f"\nBaseline [{mode}]\n" + "=" * 50)
    try:
        requests.get(f"{BASE_URL}/health", timeout=3).raise_for_status()
    except requests.RequestException:
        # Narrow catch: connection errors, timeouts and HTTP error statuses
        # all derive from RequestException. A bare `except:` would also
        # swallow KeyboardInterrupt and SystemExit.
        print(f"[ERROR] Server not reachable at {BASE_URL}")
        sys.exit(1)
    scores = {t: run_task(t) for t in ["task_1", "task_2", "task_3"]}
    print(f"\nAverage: {sum(scores.values())/len(scores):.2f}")
# Run the baseline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()