# champ-chatbot / classes / pii_filter.py
# Author: qyle — commit 3da1373 (verified): "pii filter improved"
import logging
import re
from typing import Optional
from gliner import GLiNER
logger = logging.getLogger("uvicorn")
LABELS = [
"email",
"date_of_birth",
"last_name",
"street_address",
]
LABELS_PLACEHOLDERS = {
"email": "an email",
"phone_number": "a phone number",
"date_of_birth": "a date of birth",
"last_name": "a last name",
"street_address": "a location",
"ssn": "a ssn",
}
RE_SSN = r"\b\d{3}[- ]?\d{3}[- ]?\d{3}\b"
RE_ZIP = (
r"\b[ABCEGHJKLMNPRSTVXY]\d[ABCEGHJKLMNPRSTVWXYZ][ ]?\d[ABCEGHJKLMNPRSTVWXYZ]\d\b"
)
RE_PHONE = r"(?:\+?\d{1,3}[-\s.]?)?\(?\d{3}\)?[-\s.]?\d{3}[-\s.]?\d{4}"
def clean_backslashes(txt: str) -> str:
"""Cleans backslashes from a string.
For example, passing the string "It\'s not for everyone" will return "It's not for everyone".
Backslashes next to names or locations confuse the PII filter.
Args:
txt (str): String to clean
Returns:
str: Cleaned string
"""
return txt.replace("\\'", "'")
def chunk_text(text: str, max_chars: int = 1000) -> list[tuple[str, int]]:
"""
The text is sometimes too large for the model. We chunk it here so we can pass
each chunk to the model one by one.
"""
chunks = []
start = 0
text_len = len(text)
while start < text_len:
# On prend un bloc (environ 1000 caractères ~ 250-300 tokens)
end = start + max_chars
# Pour éviter de couper un mot au milieu, on recule jusqu'au dernier espace
if end < text_len:
end = text.rfind(" ", start, end)
if end <= start: # Si aucun espace n'est trouvé
end = start + max_chars
chunks.append((text[start:end], start))
# On avance le curseur (on peut ajouter un overlap ici si nécessaire)
start = end
return chunks
class PIIFilter:
_instance: Optional["PIIFilter"] = None
model: None
def __new__(cls):
if cls._instance is None:
logger.info("Loading the PII filter into memory...")
cls._instance = super(PIIFilter, cls).__new__(cls)
# TODO: manual SSN detection
cls._instance.model = GLiNER.from_pretrained("nvidia/gliner-PII")
return cls._instance
def sanitize(self, text: str) -> str:
if not text:
return text
text = clean_backslashes(text)
all_entities = []
# 1. Chunking pour GLiNER (max_chars=1000 pour rester sous les 384 tokens)
chunks = chunk_text(text, max_chars=1000)
for chunk, offset in chunks:
chunk_entities = self.model.predict_entities(chunk, LABELS, threshold=0.6)
for ent in chunk_entities:
all_entities.append(
{
"start": ent["start"] + offset,
"end": ent["end"] + offset,
"label": ent["label"],
}
)
# 2. Ajout des détections par Regex
regex_rules = [
(RE_SSN, "ssn"),
(RE_ZIP, "street_address"),
(RE_PHONE, "phone_number"),
]
for pattern, label in regex_rules:
for match in re.finditer(pattern, text):
all_entities.append(
{"start": match.start(), "end": match.end(), "label": label}
)
# 3. Gestion des chevauchements (Overlaps)
# Si deux entités se chevauchent, on garde la plus large.
all_entities.sort(key=lambda x: x["start"])
merged_entities = []
if all_entities:
current = all_entities[0]
for next_ent in all_entities[1:]:
if next_ent["start"] < current["end"]:
# Chevauchement trouvé, on prend l'enveloppe maximale
current["end"] = max(current["end"], next_ent["end"])
# On peut aussi décider ici quel label prioriser
else:
merged_entities.append(current)
current = next_ent
merged_entities.append(current)
# 4. Remplacement (en partant de la fin pour garder les index valides)
redacted_text = text
for entity in sorted(merged_entities, key=lambda x: x["start"], reverse=True):
placeholder = LABELS_PLACEHOLDERS[entity["label"]]
redacted_text = (
redacted_text[: entity["start"]]
+ placeholder
+ redacted_text[entity["end"] :]
)
return redacted_text