|
|
""" |
|
|
NERPA – Text anonymisation using the fine-tuned GLiNER2 model. |
|
|
|
|
|
Usage: |
|
|
python anonymise.py "My name is John Smith, born 15/03/1990. Email: john@example.com" |
|
|
python anonymise.py --file input.txt |
|
|
python anonymise.py --file input.txt --output anonymised.txt |
|
|
""" |
|
|
|
|
|
import argparse |
|
|
import logging |
|
|
import sys |
|
|
import warnings |
|
|
from typing import Optional |
|
|
|
|
|
warnings.filterwarnings("ignore", message=r".*incorrect regex pattern.*fix_mistral_regex.*") |
|
|
|
|
|
import torch |
|
|
from gliner2 import GLiNER2 |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
|
|
|
# Built-in PII entity types detected by default.
# Maps label -> natural-language description; the descriptions are handed to
# the GLiNER2 model as its extraction schema, so they act as prompts — the
# wording influences what the model matches, not just documentation.
PII_ENTITIES: dict[str, str] = {
    "LOCATION": "Address, country, city, postcode, street, any other location",
    "AGE": "Age of a person",
    "DIGITAL_KEYS": "Digital keys, passwords, pins used to access anything like servers, banks, APIs, accounts etc",
    "BANK_ACCOUNT_DETAILS": "Bank account details such as number, IBAN, SWIFT, routing numbers etc",
    "CARD_DETAILS": "Debit or credit card details such as card number, CVV, expiration etc",
    "DATE_TIME": "Generic date and time",
    "DATE_OF_BIRTH": "Date of birth",
    "PERSONAL_ID_NUMBERS": "Common personal identification numbers such as passport numbers, driving licenses, taxpayer and insurance numbers",
    "TECHNICAL_ID_NUMBERS": "IP and MAC addresses, serial numbers and any other technical ID numbers",
    "EMAIL": "Email",
    "PERSON_NAME": "Person name",
    "BUSINESS_NAME": "Business name",
    "PHONE": "Any personal or other phone numbers",
    "URL": "Any short or full URL",
    "USERNAME": "Username",
    "VEHICLE_ID_NUMBERS": "Any vehicle numbers like license plates, vehicle identification numbers",
}
|
|
|
|
|
# Minimum model confidence for a detection to be kept (overridable via --threshold).
CONFIDENCE_THRESHOLD = 0.25
# Long inputs are split into windows of this many characters before inference.
CHUNK_SIZE = 3000
# Characters shared between consecutive chunks, so an entity spanning a chunk
# boundary is still seen whole by at least one chunk.
CHUNK_OVERLAP = 100
# Number of chunks sent to the model per batch_extract_entities call.
BATCH_SIZE = 32
|
|
|
|
|
|
|
|
def load_model(model_path: str = ".") -> GLiNER2: |
|
|
"""Load the NERPA model onto the best available device.""" |
|
|
if torch.cuda.is_available(): |
|
|
device = torch.device("cuda") |
|
|
elif torch.backends.mps.is_available(): |
|
|
device = torch.device("mps") |
|
|
else: |
|
|
device = torch.device("cpu") |
|
|
|
|
|
model = GLiNER2.from_pretrained(model_path) |
|
|
try: |
|
|
model.to(device) |
|
|
except RuntimeError: |
|
|
logger.warning( |
|
|
"Failed to load model on %s, falling back to CPU.", device |
|
|
) |
|
|
model.to(torch.device("cpu")) |
|
|
return model |
|
|
|
|
|
|
|
|
def chunk_text(
    text: str,
    chunk_size: int = CHUNK_SIZE,
    overlap: int = CHUNK_OVERLAP,
) -> tuple[list[str], list[int]]:
    """Split text into overlapping chunks.

    Args:
        text: Input string; may be empty.
        chunk_size: Maximum chunk length in characters.
        overlap: Characters shared between consecutive chunks so entities
            spanning a chunk boundary are still seen whole by one chunk.

    Returns:
        A ``(chunks, starts)`` pair where ``starts[i]`` is the character
        offset of ``chunks[i]`` within ``text``.

    Raises:
        ValueError: If ``overlap >= chunk_size``. Previously a zero stride
            raised an opaque ``range()`` error and a negative stride silently
            returned no chunks at all, dropping the entire text.
    """
    if overlap >= chunk_size:
        raise ValueError(
            f"overlap ({overlap}) must be smaller than chunk_size ({chunk_size})"
        )
    if not text:
        return [], []
    chunks: list[str] = []
    starts: list[int] = []
    step = chunk_size - overlap
    for pos in range(0, len(text), step):
        chunks.append(text[pos : pos + chunk_size])
        starts.append(pos)
        # Once this chunk reaches the end of the text, stop: any further
        # chunks would be redundant sub-strings of this one and only cost
        # extra model calls downstream.
        if pos + chunk_size >= len(text):
            break
    return chunks, starts
|
|
|
|
|
|
|
|
def detect_entities(
    model: GLiNER2,
    text: str,
    entities: Optional[dict[str, str]] = None,
    threshold: float = CONFIDENCE_THRESHOLD,
) -> list[dict]:
    """
    Detect PII entities in text, returning a list of
    ``{"type": str, "start": int, "end": int, "score": float}`` dicts
    with character offsets into the original text.

    The text is chunked (see ``chunk_text``), chunks are run through the
    model in batches of ``BATCH_SIZE``, chunk-local offsets are mapped back
    to the original text, duplicates from overlapping chunks are resolved by
    keeping the highest-confidence label per (start, end) span, and finally
    overlapping spans are merged into single entities.
    """
    # Default to the full built-in schema when the caller passed nothing.
    entities = entities or PII_ENTITIES

    # Always detect DATE_TIME and DATE_OF_BIRTH together so the model can
    # disambiguate between them; results for the label the caller did NOT
    # request are filtered back out below.
    detect = dict(entities)
    if "DATE_TIME" in detect and "DATE_OF_BIRTH" not in detect:
        detect["DATE_OF_BIRTH"] = PII_ENTITIES["DATE_OF_BIRTH"]
    elif "DATE_OF_BIRTH" in detect and "DATE_TIME" not in detect:
        detect["DATE_TIME"] = PII_ENTITIES["DATE_TIME"]

    # Overlapping chunks + their start offsets in the original text.
    chunks, offsets = chunk_text(text)

    # Run the model over all chunks, BATCH_SIZE chunks per call. One result
    # dict per chunk, in chunk order (so it can be zipped with `offsets`).
    all_chunk_results: list[dict] = []
    for batch_start in range(0, len(chunks), BATCH_SIZE):
        batch = chunks[batch_start : batch_start + BATCH_SIZE]
        results = model.batch_extract_entities(
            batch,
            detect,
            include_confidence=True,
            include_spans=True,
            threshold=threshold,
        )
        all_chunk_results.extend(results)

    # Map chunk-local spans to global offsets and de-duplicate: the same
    # span can be reported by two overlapping chunks; keep the
    # highest-confidence label per exact (start, end) position.
    # NOTE(review): assumes each result has the shape
    # {"entities": {label: [{"start": int, "end": int, "confidence": float}, ...]}}
    # with chunk-local character offsets — confirm against the GLiNER2 API.
    seen: dict[tuple[int, int], dict] = {}
    for chunk_result, chunk_offset in zip(all_chunk_results, offsets):
        for label, occurrences in chunk_result["entities"].items():
            for occurrence in occurrences:
                start = occurrence["start"] + chunk_offset
                end = occurrence["end"] + chunk_offset
                position = (start, end)
                if (
                    position not in seen
                    or seen[position]["score"] < occurrence["confidence"]
                ):
                    seen[position] = {
                        "type": label,
                        "score": occurrence["confidence"],
                    }

    # Keep only labels the caller actually asked for (drops the companion
    # DATE_* label added above) and sort spans by position for merging.
    items = sorted(
        [
            (start, end, info)
            for (start, end), info in seen.items()
            if info["type"] in entities
        ],
        key=lambda x: (x[0], x[1]),
    )
    if not items:
        return []

    # Merge overlapping spans into single entities. When spans overlap, the
    # merged span covers their union and takes the type/score of the
    # highest-scoring contributor. Guarantees the returned spans are
    # non-overlapping and ordered, which `anonymise` relies on.
    merged: list[dict] = []
    current_start, current_end, current_info = items[0]
    for start, end, info in items[1:]:
        if start < current_end:
            # Overlap with the span under construction: extend it.
            current_end = max(current_end, end)
            if info["score"] > current_info["score"]:
                current_info = info
        else:
            # Gap reached: flush the current span and start a new one.
            merged.append({
                "type": current_info["type"],
                "start": current_start,
                "end": current_end,
                "score": current_info["score"],
            })
            current_start, current_end, current_info = start, end, info
    # Flush the final span.
    merged.append({
        "type": current_info["type"],
        "start": current_start,
        "end": current_end,
        "score": current_info["score"],
    })

    return merged
|
|
|
|
|
|
|
|
def anonymise(text: str, detected: list[dict]) -> str:
    """Replace detected entities with placeholders like ``[PERSON_NAME]``.

    Each entity dict must provide ``type``, ``start`` and ``end``; the spans
    are assumed non-overlapping (as produced by ``detect_entities``).
    """
    pieces: list[str] = []
    cursor = 0
    for ent in sorted(detected, key=lambda e: e["start"]):
        # Copy the untouched text up to this entity, then its placeholder.
        pieces.append(text[cursor : ent["start"]])
        pieces.append("[" + ent["type"] + "]")
        cursor = ent["end"]
    # Tail of the text after the last entity (or the whole text if none).
    pieces.append(text[cursor:])
    return "".join(pieces)
|
|
|
|
|
|
|
|
def main() -> None:
    """Command-line entry point: read text, detect PII, emit anonymised text.

    The anonymised result goes to stdout (or ``--output``); diagnostic
    output such as ``--show-entities`` goes to stderr so stdout stays clean
    for piping.
    """
    parser = argparse.ArgumentParser(
        description="Anonymise PII in text using the NERPA model.",
    )
    parser.add_argument(
        "text", nargs="?", help="Text to anonymise (or use --file)",
    )
    parser.add_argument(
        "--file", "-f", help="Read text from a file instead",
    )
    parser.add_argument(
        "--output", "-o",
        help="Write anonymised text to file (default: stdout)",
    )
    parser.add_argument(
        "--model", "-m", default=".",
        help="Path to model directory (default: current dir)",
    )
    parser.add_argument(
        "--threshold", "-t", type=float, default=CONFIDENCE_THRESHOLD,
        help=f"Confidence threshold (default: {CONFIDENCE_THRESHOLD})",
    )
    parser.add_argument(
        "--show-entities", action="store_true",
        help="Print detected entities before anonymised text",
    )
    parser.add_argument(
        "--extra-entities", "-e", action="append", metavar="LABEL=DESCRIPTION",
        help=(
            "Additional custom entity types to detect alongside the built-in "
            "PII entities. Repeat for each type. Format: LABEL=\"Description\". "
            "Example: -e PRODUCT=\"Product name\" -e SKILL=\"Professional skill\""
        ),
    )
    args = parser.parse_args()

    # Resolve the input text: --file takes precedence over the positional
    # argument. `is not None` (rather than truthiness) so an explicitly
    # supplied empty string is accepted instead of triggering the error.
    if args.file:
        try:
            with open(args.file, encoding="utf-8") as f:
                text = f.read()
        except OSError as exc:
            sys.exit(f"Error reading {args.file}: {exc}")
    elif args.text is not None:
        text = args.text
    else:
        parser.error("Provide text as an argument or use --file")

    # Parse LABEL=DESCRIPTION pairs for user-supplied entity types.
    extra: dict[str, str] = {}
    if args.extra_entities:
        for item in args.extra_entities:
            if "=" not in item:
                parser.error(
                    f"Invalid --extra-entities value '{item}'. "
                    "Expected format: LABEL=\"Description\""
                )
            label, description = item.split("=", 1)
            extra[label.strip()] = description.strip()

    model = load_model(args.model)
    # None -> detect_entities falls back to the built-in PII_ENTITIES.
    all_entities = {**PII_ENTITIES, **extra} if extra else None
    detected = detect_entities(model, text, entities=all_entities, threshold=args.threshold)

    if args.show_entities:
        # Print directly to stderr. The previous logger.info() calls were
        # silent: logging is never configured in this script, so the root
        # logger's effective level stays at WARNING and --show-entities
        # produced no output at all.
        for entity in detected:
            span = text[entity["start"] : entity["end"]]
            print(
                f"  {entity['type']:<25} "
                f"[{entity['start']:5d}:{entity['end']:5d}] "
                f"(score={entity['score']:.2f}) {span!r}",
                file=sys.stderr,
            )

    result = anonymise(text, detected)

    # Emit the anonymised text to the requested destination.
    if args.output:
        try:
            with open(args.output, "w", encoding="utf-8") as f:
                f.write(result)
        except OSError as exc:
            sys.exit(f"Error writing {args.output}: {exc}")
    else:
        print(result)
|
|
|
|
|
|
|
|
# Script entry guard: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()
|
|
|