# nerpa / anonymise.py
# Last commit: akhatre — "add reference to gliner2 zero shot capabilities" (29ae185)
"""
NERPA – Text anonymisation using the fine-tuned GLiNER2 model.
Usage:
python anonymise.py "My name is John Smith, born 15/03/1990. Email: john@example.com"
python anonymise.py --file input.txt
python anonymise.py --file input.txt --output anonymised.txt
"""
import argparse
import logging
import sys
import warnings
from typing import Optional

# Suppress a known-noisy warning; installed before the torch/gliner2 imports
# below, presumably because the warning fires at import/model-load time
# (NOTE(review): confirm — the filter must precede whatever emits it).
warnings.filterwarnings("ignore", message=r".*incorrect regex pattern.*fix_mistral_regex.*")

import torch
from gliner2 import GLiNER2

# Module-level logger. Nothing in this file configures logging, so records
# below the root default (WARNING) are dropped unless the application sets
# up handlers itself.
logger = logging.getLogger(__name__)
# Entity types the model was fine-tuned to recognise, with descriptions
# that guide the bi-encoder towards better detection.
# The keys double as the placeholder text in the anonymised output
# (e.g. a detected name becomes "[PERSON_NAME]").
PII_ENTITIES: dict[str, str] = {
    "LOCATION": "Address, country, city, postcode, street, any other location",
    "AGE": "Age of a person",
    "DIGITAL_KEYS": "Digital keys, passwords, pins used to access anything like servers, banks, APIs, accounts etc",
    "BANK_ACCOUNT_DETAILS": "Bank account details such as number, IBAN, SWIFT, routing numbers etc",
    "CARD_DETAILS": "Debit or credit card details such as card number, CVV, expiration etc",
    "DATE_TIME": "Generic date and time",
    "DATE_OF_BIRTH": "Date of birth",
    "PERSONAL_ID_NUMBERS": "Common personal identification numbers such as passport numbers, driving licenses, taxpayer and insurance numbers",
    "TECHNICAL_ID_NUMBERS": "IP and MAC addresses, serial numbers and any other technical ID numbers",
    "EMAIL": "Email",
    "PERSON_NAME": "Person name",
    "BUSINESS_NAME": "Business name",
    "PHONE": "Any personal or other phone numbers",
    "URL": "Any short or full URL",
    "USERNAME": "Username",
    "VEHICLE_ID_NUMBERS": "Any vehicle numbers like license plates, vehicle identification numbers",
}

# Minimum model confidence for a detection to be kept (see detect_entities).
CONFIDENCE_THRESHOLD = 0.25
# Long inputs are scanned in windows of CHUNK_SIZE characters; consecutive
# windows share CHUNK_OVERLAP characters so an entity spanning a boundary
# appears whole in at least one window.
CHUNK_SIZE = 3000
CHUNK_OVERLAP = 100
# Number of chunks per batch_extract_entities call.
BATCH_SIZE = 32
def load_model(model_path: str = ".") -> GLiNER2:
    """Load the NERPA model onto the best available device (CUDA > MPS > CPU)."""
    if torch.cuda.is_available():
        preferred = "cuda"
    elif torch.backends.mps.is_available():
        preferred = "mps"
    else:
        preferred = "cpu"

    model = GLiNER2.from_pretrained(model_path)
    try:
        model.to(torch.device(preferred))
    except RuntimeError:
        # Accelerator move can fail (e.g. out of memory); CPU always works.
        logger.warning(
            "Failed to load model on %s, falling back to CPU.", preferred
        )
        model.to(torch.device("cpu"))
    return model
def chunk_text(
    text: str,
    chunk_size: int = CHUNK_SIZE,
    overlap: int = CHUNK_OVERLAP,
) -> tuple[list[str], list[int]]:
    """Split *text* into overlapping chunks.

    Returns ``(chunks, starts)`` where ``starts[i]`` is the character offset
    of ``chunks[i]`` in the original text. Consecutive chunks share *overlap*
    characters so a span crossing a boundary appears whole in at least one
    chunk.

    Raises:
        ValueError: if ``chunk_size <= 0`` or ``overlap`` is negative or
            ``>= chunk_size`` (the scan could not advance — previously this
            either raised an opaque ``range`` error or silently returned
            nothing).
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    if not 0 <= overlap < chunk_size:
        raise ValueError("overlap must satisfy 0 <= overlap < chunk_size")
    if not text:
        return [], []

    chunks: list[str] = []
    starts: list[int] = []
    step = chunk_size - overlap
    pos = 0
    while True:
        chunks.append(text[pos : pos + chunk_size])
        starts.append(pos)
        # Stop once a chunk reaches the end of the text; continuing would
        # only emit trailing chunks fully contained in this one.
        if pos + chunk_size >= len(text):
            break
        pos += step
    return chunks, starts
def detect_entities(
    model: GLiNER2,
    text: str,
    entities: Optional[dict[str, str]] = None,
    threshold: float = CONFIDENCE_THRESHOLD,
) -> list[dict]:
    """
    Detect PII entities in text, returning a list of
    ``{"type": str, "start": int, "end": int, "score": float}`` dicts
    with character offsets into the original text.

    Args:
        model: loaded GLiNER2 model (see ``load_model``).
        text: text to scan; it is split into overlapping chunks via
            ``chunk_text`` so long inputs fit the model.
        entities: label -> description mapping to detect; defaults to
            ``PII_ENTITIES``. Detections whose label is not in this mapping
            are filtered out of the result.
        threshold: minimum confidence forwarded to the model.
    """
    entities = entities or PII_ENTITIES
    # Always detect both date types so the model can disambiguate.
    detect = dict(entities)
    if "DATE_TIME" in detect and "DATE_OF_BIRTH" not in detect:
        detect["DATE_OF_BIRTH"] = PII_ENTITIES["DATE_OF_BIRTH"]
    elif "DATE_OF_BIRTH" in detect and "DATE_TIME" not in detect:
        detect["DATE_TIME"] = PII_ENTITIES["DATE_TIME"]

    # offsets[i] is the character position of chunks[i] in the full text.
    chunks, offsets = chunk_text(text)
    all_chunk_results: list[dict] = []
    for batch_start in range(0, len(chunks), BATCH_SIZE):
        batch = chunks[batch_start : batch_start + BATCH_SIZE]
        # NOTE(review): assumes batch_extract_entities returns one dict per
        # input chunk, shaped {"entities": {label: [{"start", "end",
        # "confidence"}, ...]}} with chunk-relative character offsets —
        # confirm against the gliner2 API docs.
        results = model.batch_extract_entities(
            batch,
            detect,
            include_confidence=True,
            include_spans=True,
            threshold=threshold,
        )
        all_chunk_results.extend(results)

    # Merge results across chunks: de-duplicate overlapping detections.
    # The same (start, end) span found in two overlapping chunks keeps
    # the label with the higher confidence.
    seen: dict[tuple[int, int], dict] = {}
    for chunk_result, chunk_offset in zip(all_chunk_results, offsets):
        for label, occurrences in chunk_result["entities"].items():
            for occurrence in occurrences:
                # Translate chunk-relative offsets to whole-text offsets.
                start = occurrence["start"] + chunk_offset
                end = occurrence["end"] + chunk_offset
                position = (start, end)
                if (
                    position not in seen
                    or seen[position]["score"] < occurrence["confidence"]
                ):
                    seen[position] = {
                        "type": label,
                        "score": occurrence["confidence"],
                    }

    # Merge overlapping spans, keeping the highest-confidence label.
    # NOTE: when two spans overlap they are fused into one span and
    # assigned the label with the higher confidence score.
    # The filter also drops DATE_* labels that were added above purely for
    # disambiguation when the caller did not request them.
    items = sorted(
        [
            (start, end, info)
            for (start, end), info in seen.items()
            if info["type"] in entities
        ],
        key=lambda x: (x[0], x[1]),
    )
    if not items:
        return []
    merged: list[dict] = []
    current_start, current_end, current_info = items[0]
    for start, end, info in items[1:]:
        if start < current_end:  # overlapping: fuse into the current span
            current_end = max(current_end, end)
            if info["score"] > current_info["score"]:
                current_info = info
        else:
            merged.append({
                "type": current_info["type"],
                "start": current_start,
                "end": current_end,
                "score": current_info["score"],
            })
            current_start, current_end, current_info = start, end, info
    # Flush the final open span.
    merged.append({
        "type": current_info["type"],
        "start": current_start,
        "end": current_end,
        "score": current_info["score"],
    })
    return merged
def anonymise(text: str, detected: list[dict]) -> str:
    """Replace detected entities with placeholders like ``[PERSON_NAME]``.

    Args:
        text: original text that the span offsets refer to.
        detected: list of ``{"type", "start", "end", ...}`` dicts. Spans are
            expected to be non-overlapping (as produced by detect_entities);
            any span starting inside an already-replaced region is skipped
            defensively — previously such input moved the copy cursor
            backwards, duplicating text and emitting stray placeholders.

    Returns:
        The text with each detected span replaced by ``[TYPE]``.
    """
    parts: list[str] = []
    prev_end = 0
    for entity in sorted(detected, key=lambda e: (e["start"], e["end"])):
        start, end = entity["start"], entity["end"]
        if start < prev_end:
            # Overlaps a span already replaced: skip instead of re-emitting.
            continue
        parts.append(text[prev_end:start])
        parts.append(f'[{entity["type"]}]')
        prev_end = end
    parts.append(text[prev_end:])
    return "".join(parts)
def main() -> None:
    """CLI entry point: parse arguments, load the model, detect and mask PII."""
    parser = argparse.ArgumentParser(
        description="Anonymise PII in text using the NERPA model.",
    )
    parser.add_argument(
        "text", nargs="?", help="Text to anonymise (or use --file)",
    )
    parser.add_argument(
        "--file", "-f", help="Read text from a file instead",
    )
    parser.add_argument(
        "--output", "-o",
        help="Write anonymised text to file (default: stdout)",
    )
    parser.add_argument(
        "--model", "-m", default=".",
        help="Path to model directory (default: current dir)",
    )
    parser.add_argument(
        "--threshold", "-t", type=float, default=CONFIDENCE_THRESHOLD,
        help=f"Confidence threshold (default: {CONFIDENCE_THRESHOLD})",
    )
    parser.add_argument(
        "--show-entities", action="store_true",
        help="Print detected entities before anonymised text",
    )
    parser.add_argument(
        "--extra-entities", "-e", action="append", metavar="LABEL=DESCRIPTION",
        help=(
            "Additional custom entity types to detect alongside the built-in "
            "PII entities. Repeat for each type. Format: LABEL=\"Description\". "
            "Example: -e PRODUCT=\"Product name\" -e SKILL=\"Professional skill\""
        ),
    )
    args = parser.parse_args()

    # BUG FIX: --show-entities reports via logger.info, but logging was never
    # configured, so the root logger's WARNING default silently dropped every
    # record and the flag did nothing. Configure a minimal handler here.
    logging.basicConfig(level=logging.INFO, format="%(message)s")

    if args.file:
        try:
            with open(args.file, encoding="utf-8") as f:
                text = f.read()
        except OSError as exc:
            sys.exit(f"Error reading {args.file}: {exc}")
    elif args.text is not None:
        # `is not None` so an explicit empty-string argument is accepted
        # (previously it fell through to the "missing input" error).
        text = args.text
    else:
        parser.error("Provide text as an argument or use --file")

    # Parse LABEL=DESCRIPTION pairs from --extra-entities.
    extra: dict[str, str] = {}
    if args.extra_entities:
        for item in args.extra_entities:
            if "=" not in item:
                parser.error(
                    f"Invalid --extra-entities value '{item}'. "
                    "Expected format: LABEL=\"Description\""
                )
            label, description = item.split("=", 1)
            extra[label.strip()] = description.strip()

    model = load_model(args.model)
    # None lets detect_entities use its PII_ENTITIES default.
    all_entities = {**PII_ENTITIES, **extra} if extra else None
    detected = detect_entities(
        model, text, entities=all_entities, threshold=args.threshold,
    )

    if args.show_entities:
        for entity in detected:
            span = text[entity["start"] : entity["end"]]
            logger.info(
                " %-25s [%5d:%5d] (score=%.2f) %r",
                entity["type"], entity["start"], entity["end"],
                entity["score"], span,
            )

    result = anonymise(text, detected)
    if args.output:
        try:
            with open(args.output, "w", encoding="utf-8") as f:
                f.write(result)
        except OSError as exc:
            sys.exit(f"Error writing {args.output}: {exc}")
    else:
        print(result)


if __name__ == "__main__":
    main()