nerpa / anonymise.py
hugging-hat's picture
Upload anonymise.py with huggingface_hub
a05a9b8 verified
raw
history blame
7.03 kB
"""
NERPA – Text anonymisation using the fine-tuned GLiNER2 model.
Usage:
python anonymise.py "My name is John Smith, born 15/03/1990. Email: john@example.com"
python anonymise.py --file input.txt
python anonymise.py --file input.txt --output anonymised.txt
"""
import argparse
import sys
from typing import Dict, List, Tuple
import torch
from gliner2 import GLiNER2
# Entity types the model was fine-tuned to recognise, with descriptions
# that guide the bi-encoder towards better detection.
PII_ENTITIES = {
"LOCATION": "Address, country, city, postcode, street, any other location",
"AGE": "Age of a person",
"DIGITAL_KEYS": "Digital keys, passwords, pins used to access anything like servers, banks, APIs, accounts etc",
"BANK_ACCOUNT_DETAILS": "Bank account details such as number, IBAN, SWIFT, routing numbers etc",
"CARD_DETAILS": "Debit or credit card details such as card number, CVV, expiration etc",
"DATE_TIME": "Generic date and time",
"DATE_OF_BIRTH": "Date of birth",
"PERSONAL_ID_NUMBERS": "Common personal identification numbers such as passport numbers, driving licenses, taxpayer and insurance numbers",
"TECHNICAL_ID_NUMBERS": "IP and MAC addresses, serial numbers and any other technical ID numbers",
"EMAIL": "Email",
"PERSON_NAME": "Person name",
"BUSINESS_NAME": "Business name",
"PHONE": "Any personal or other phone numbers",
"URL": "Any short or full URL",
"USERNAME": "Username",
"VEHICLE_ID_NUMBERS": "Any vehicle numbers like license plates, vehicle identification numbers",
}
CONFIDENCE_THRESHOLD = 0.25
CHUNK_SIZE = 3000
CHUNK_OVERLAP = 100
def load_model(model_path: str = ".") -> GLiNER2:
"""Load the NERPA model onto the best available device."""
if torch.cuda.is_available():
device = torch.device("cuda")
elif torch.backends.mps.is_available():
device = torch.device("mps")
else:
device = torch.device("cpu")
model = GLiNER2.from_pretrained(model_path)
model.to(device)
return model
def chunk_text(text: str, chunk_size: int = CHUNK_SIZE, overlap: int = CHUNK_OVERLAP) -> Tuple[List[str], List[int]]:
"""Split text into overlapping chunks, returning chunks and their start offsets."""
if not text:
return [], []
chunks, starts = [], []
step = chunk_size - overlap
pos = 0
while pos < len(text):
chunks.append(text[pos : pos + chunk_size])
starts.append(pos)
if pos + chunk_size >= len(text):
break
pos += step
return chunks, starts
def detect_entities(
model: GLiNER2,
text: str,
entities: Dict[str, str] = None,
threshold: float = CONFIDENCE_THRESHOLD,
) -> List[dict]:
"""
Detect PII entities in text, returning a list of
{"type": str, "start": int, "end": int, "score": float} dicts
with character offsets into the original text.
"""
entities = entities or PII_ENTITIES
# Always detect both date types so the model can disambiguate
detect = dict(entities)
if "DATE_TIME" in detect and "DATE_OF_BIRTH" not in detect:
detect["DATE_OF_BIRTH"] = PII_ENTITIES["DATE_OF_BIRTH"]
elif "DATE_OF_BIRTH" in detect and "DATE_TIME" not in detect:
detect["DATE_TIME"] = PII_ENTITIES["DATE_TIME"]
chunks, offsets = chunk_text(text)
all_chunk_results = []
for batch_start in range(0, len(chunks), 32):
batch = chunks[batch_start : batch_start + 32]
results = model.batch_extract_entities(
batch,
detect,
include_confidence=True,
include_spans=True,
threshold=threshold,
)
all_chunk_results.extend(results)
# Merge results across chunks: de-duplicate overlapping detections
seen: Dict[Tuple[int, int], dict] = {}
for chunk_result, chunk_offset in zip(all_chunk_results, offsets):
for label, occurrences in chunk_result["entities"].items():
for occ in occurrences:
start = occ["start"] + chunk_offset
end = occ["end"] + chunk_offset
pos = (start, end)
if pos not in seen or seen[pos]["score"] < occ["confidence"]:
seen[pos] = {"type": label, "score": occ["confidence"]}
# Merge overlapping spans, keeping highest confidence label
items = sorted(
[(s, e, info) for (s, e), info in seen.items() if info["type"] in entities],
key=lambda x: (x[0], x[1]),
)
if not items:
return []
merged = []
cur_s, cur_e, cur_info = items[0]
for s, e, info in items[1:]:
if s < cur_e: # overlapping
cur_e = max(cur_e, e)
if info["score"] > cur_info["score"]:
cur_info = info
else:
merged.append({"type": cur_info["type"], "start": cur_s, "end": cur_e, "score": cur_info["score"]})
cur_s, cur_e, cur_info = s, e, info
merged.append({"type": cur_info["type"], "start": cur_s, "end": cur_e, "score": cur_info["score"]})
return merged
def anonymise(text: str, detected: List[dict]) -> str:
"""Replace detected entities with placeholders like [PERSON_NAME]."""
# Process from end to start so offsets stay valid
result = text
for entity in sorted(detected, key=lambda e: e["start"], reverse=True):
placeholder = f'[{entity["type"]}]'
result = result[: entity["start"]] + placeholder + result[entity["end"] :]
return result
def main():
parser = argparse.ArgumentParser(description="Anonymise PII in text using the NERPA model.")
parser.add_argument("text", nargs="?", help="Text to anonymise (or use --file)")
parser.add_argument("--file", "-f", help="Read text from a file instead")
parser.add_argument("--output", "-o", help="Write anonymised text to file (default: stdout)")
parser.add_argument("--model", "-m", default=".", help="Path to model directory (default: current dir)")
parser.add_argument("--threshold", "-t", type=float, default=CONFIDENCE_THRESHOLD, help="Confidence threshold (default: 0.25)")
parser.add_argument("--show-entities", action="store_true", help="Print detected entities before anonymised text")
args = parser.parse_args()
if args.file:
with open(args.file) as f:
text = f.read()
elif args.text:
text = args.text
else:
parser.error("Provide text as an argument or use --file")
model = load_model(args.model)
detected = detect_entities(model, text, threshold=args.threshold)
if args.show_entities:
for e in detected:
print(f' {e["type"]:25s} [{e["start"]:5d}:{e["end"]:5d}] (score={e["score"]:.2f}) "{text[e["start"]:e["end"]]}"', file=sys.stderr)
print(file=sys.stderr)
result = anonymise(text, detected)
if args.output:
with open(args.output, "w") as f:
f.write(result)
else:
print(result)
if __name__ == "__main__":
main()