# temsa's picture
# Upload folder using huggingface_hub
# d49b0d0 verified
#!/usr/bin/env python3
import argparse
import json
import os
# Force a torch-only transformers configuration BEFORE transformers is
# imported: disable the TF/Flax/torchvision backends so the import does not
# try to load frameworks that may not be installed.
os.environ.setdefault("TRANSFORMERS_NO_TF", "1")
os.environ.setdefault("TRANSFORMERS_NO_FLAX", "1")
os.environ.setdefault("TRANSFORMERS_NO_TORCHVISION", "1")
os.environ["USE_TF"] = "0"
os.environ["USE_FLAX"] = "0"
os.environ["USE_TORCH"] = "1"
import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline
# NOTE(review): third-party `regex` package used as a drop-in for stdlib `re`
# (the patterns below would also compile under stdlib `re`) — presumably for
# its Unicode handling; confirm before swapping.
import regex as re
# One "word" per alphanumeric run or single punctuation character; used by
# tokenize_with_spans() to build the word-level spans the PPSN decoder scores.
TOKEN_RE = re.compile(r"[A-Za-z0-9]+|[^\w\s]", re.UNICODE)
# Irish Eircode: a routing key (valid letter + two digits, or the special
# "D6W") followed by an optional space and a 4-character unique identifier.
# NOTE(review): re.IGNORECASE makes the explicit uppercase classes also accept
# lowercase input.
EIRCODE_RE = re.compile(r"^(?:[ACDEFHKNPRTVWXY]\d{2}|D6W)\s?[0-9ACDEFHKNPRTVWXY]{4}$", re.IGNORECASE)
# Entity labels kept from the general pipeline output by merge_spans().
# PPSN is listed here but explicitly excluded in merge_spans(); it is decoded
# separately by word_aligned_ppsn_spans().
ALLOWED = {
"PPSN",
"ACCOUNT_NUMBER",
"BANK_ROUTING_NUMBER",
"CREDIT_DEBIT_CARD",
"PASSPORT_NUMBER",
"POSTCODE",
"PHONE_NUMBER",
"EMAIL",
"FIRST_NAME",
"LAST_NAME",
"SWIFT_BIC",
}
def tokenize_with_spans(text: str):
    """Split *text* with TOKEN_RE into (token, start, end) triples."""
    triples = []
    for match in TOKEN_RE.finditer(text):
        triples.append((match.group(0), match.start(), match.end()))
    return triples
def normalize_label(label: str) -> str:
    """Strip a BIO prefix ("B-"/"I-") from *label* and return it upper-cased.

    A falsy input (None, "") normalizes to the empty string.
    """
    cleaned = (label or "").strip()
    if cleaned[:2] in ("B-", "I-"):
        cleaned = cleaned[2:]
    return cleaned.upper()
def looks_like_eircode(value: str) -> bool:
    """Return True when *value* (whitespace-trimmed) matches an Eircode."""
    candidate = value.strip()
    return bool(EIRCODE_RE.match(candidate))
def ppsn_label_ids(model):
    """Return the sorted integer ids of every model label ending in "PPSN".

    Handles id2label mappings whose keys are strings (as deserialized from
    JSON configs) or whose values are None.
    """
    return sorted(
        int(raw_id)
        for raw_id, raw_label in model.config.id2label.items()
        if str(raw_label or "").strip().endswith("PPSN")
    )
def word_aligned_ppsn_spans(text: str, model, tokenizer, threshold: float):
    """Decode PPSN spans by scoring each regex word against the PPSN labels.

    The text is split with TOKEN_RE, re-encoded word-aligned, and every word
    receives the maximum PPSN-label probability over its sub-tokens.
    Consecutive words scoring at or above *threshold* are merged into one
    character span.

    Returns a list of dicts with keys "start", "end", "score", "label"
    (always "PPSN") and "text".
    """
    pieces = tokenize_with_spans(text)
    if not pieces:
        return []
    words = [word for word, _, _ in pieces]
    encoded = tokenizer(words, is_split_into_words=True, return_tensors="pt", truncation=True)
    word_ids = encoded.word_ids(batch_index=0)
    device = next(model.parameters()).device
    encoded = {k: v.to(device) for k, v in encoded.items()}
    with torch.no_grad():
        logits = model(**encoded).logits[0]
    probs = torch.softmax(logits, dim=-1)
    label_ids = ppsn_label_ids(model)
    # Single pass over the sub-tokens instead of the original
    # O(words * tokens) rescan: fold each token's best PPSN probability into
    # its word's running maximum. Words lost to truncation keep score 0.0.
    word_scores = [0.0] * len(pieces)
    if label_ids:
        token_best = probs[:, label_ids].max(dim=-1).values
        for token_index, wid in enumerate(word_ids):
            if wid is None:
                continue  # special tokens ([CLS]/[SEP]/padding) map to no word
            score = float(token_best[token_index])
            if score > word_scores[wid]:
                word_scores[wid] = score
    # Merge consecutive above-threshold words into contiguous spans.
    spans = []
    active = None
    for (_, start, end), score in zip(pieces, word_scores):
        if score >= threshold:
            if active is None:
                active = {"start": start, "end": end, "score": score}
            else:
                active["end"] = end
                active["score"] = max(active["score"], score)
        elif active is not None:
            spans.append(active)
            active = None
    if active is not None:
        spans.append(active)
    for span in spans:
        span["label"] = "PPSN"
        span["text"] = text[span["start"]:span["end"]]
    return spans
def merge_spans(text: str, general_spans: list[dict], ppsn_spans: list[dict], other_min_score: float):
    """Combine general pipeline entities with the dedicated PPSN spans.

    General spans are kept when their normalized label is in ALLOWED (PPSN
    excluded — it has its own decoder) and their score reaches
    *other_min_score*. PPSN spans are then added unless they look like an
    Eircode (a frequent PPSN false positive) or overlap an already-accepted
    span. The result is sorted by character offsets.
    """
    out = []
    for span in general_spans:
        label = normalize_label(span.get("entity_group") or span.get("entity") or "")
        if label not in ALLOWED or label == "PPSN":
            continue
        # Read the score exactly once. The original looked it up twice —
        # once defaulted, once not — which raised KeyError for a missing
        # "score" whenever other_min_score <= 0.
        score = float(span.get("score", 0.0))
        if score < other_min_score:
            continue
        start, end = int(span["start"]), int(span["end"])
        out.append({
            "label": label,
            "start": start,
            "end": end,
            "score": score,
            "text": text[start:end],
        })
    def overlaps(a, b):
        # Half-open interval intersection: spans touching end-to-start do
        # not count as overlapping.
        return not (a["end"] <= b["start"] or b["end"] <= a["start"])
    for span in ppsn_spans:
        if looks_like_eircode(span["text"]):
            continue  # shaped like a postcode, not a PPSN
        if any(overlaps(span, existing) for existing in out):
            continue
        out.append(span)
    out.sort(key=lambda item: (item["start"], item["end"]))
    return out
def mask_text(text: str, spans: list[dict]) -> str:
    """Return *text* with every span replaced by its "[LABEL]" placeholder.

    Spans are applied from the rightmost backwards so that earlier character
    offsets remain valid as the string shrinks/grows.
    """
    masked = text
    for span in sorted(spans, key=lambda item: (item["start"], item["end"]), reverse=True):
        placeholder = "[" + span["label"] + "]"
        masked = "".join((masked[:span["start"]], placeholder, masked[span["end"]:]))
    return masked
def main():
    """CLI entry point: detect and mask PII entities in --text.

    Loads a token-classification model, decodes general entities with the
    HF pipeline (simple aggregation) and PPSN entities with the dedicated
    word-aligned decoder, merges both span sets, and prints either the
    masked text or a full JSON report (--json).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default=".")
    parser.add_argument("--text", required=True)
    parser.add_argument("--device", choices=["auto", "cpu", "cuda"], default="auto")
    parser.add_argument("--ppsn-min-score", type=float, default=0.4)
    parser.add_argument("--other-min-score", type=float, default=0.5)
    parser.add_argument("--json", action="store_true")
    args = parser.parse_args()
    # `fix_mistral_regex` is not accepted by every transformers version; an
    # unsupported kwarg raises TypeError, in which case fall back to the
    # plain loader. Any other load failure (bad path, corrupt files) should
    # surface immediately — the previous broad `except Exception` masked the
    # real error behind a doomed retry.
    try:
        tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=True, fix_mistral_regex=True)
    except TypeError:
        tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=True)
    model = AutoModelForTokenClassification.from_pretrained(args.model)
    if args.device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        device = args.device
    model.to(device)
    model.eval()
    nlp = pipeline(
        "token-classification",
        model=model,
        tokenizer=tokenizer,
        aggregation_strategy="simple",
        device=0 if device == "cuda" else -1,
    )
    general = nlp(args.text)
    ppsn = word_aligned_ppsn_spans(args.text, model, tokenizer, threshold=args.ppsn_min_score)
    spans = merge_spans(args.text, general, ppsn, other_min_score=args.other_min_score)
    result = {
        "model": args.model,
        "masked_text": mask_text(args.text, spans),
        "spans": spans,
        "ppsn_decoder": "word_aligned",
        "ppsn_min_score": args.ppsn_min_score,
        "other_min_score": args.other_min_score,
    }
    if args.json:
        print(json.dumps(result, indent=2, ensure_ascii=False))
    else:
        print(result["masked_text"])
if __name__ == "__main__":
    main()