| |
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import os |
|
|
# Force the `transformers` library onto the PyTorch backend only: these flags
# make it skip importing TensorFlow, Flax, and torchvision, which avoids slow
# or failing imports when those packages are absent.  They must be set before
# `import torch` / `from transformers import ...` below to take effect.
os.environ.setdefault("TRANSFORMERS_NO_TF", "1")
os.environ.setdefault("TRANSFORMERS_NO_FLAX", "1")
os.environ.setdefault("TRANSFORMERS_NO_TORCHVISION", "1")
# Hard overrides (not setdefault): always disable TF/Flax and require torch,
# even if the caller's environment says otherwise.
os.environ["USE_TF"] = "0"
os.environ["USE_FLAX"] = "0"
os.environ["USE_TORCH"] = "1"
|
|
| import torch |
| from transformers import AutoConfig |
|
|
| from common import ( |
| boundary_label_thresholds_from_config, |
| decode_token_presence_segments, |
| label_max_span_tokens_from_config, |
| label_min_nonspace_chars_from_config, |
| label_names_from_config, |
| safe_auto_tokenizer, |
| token_extend_thresholds_from_config, |
| token_label_thresholds_from_config, |
| ) |
| from multitask_model import IrishCoreTokenSpanModel |
|
|
|
|
def replacement(label: str) -> str:
    """Return the placeholder string used to mask a span of the given *label*."""
    return "[PII:" + label + "]"
|
|
|
|
def mask_text(text: str, spans: list[dict]) -> str:
    """Return *text* with each span replaced by its `[PII:LABEL]` placeholder.

    Spans are applied from the end of the string backwards so that earlier
    character offsets stay valid as replacements change the string length.
    """
    ordered = sorted(spans, key=lambda s: (s["start"], s["end"]), reverse=True)
    masked = text
    for s in ordered:
        head = masked[: s["start"]]
        tail = masked[s["end"] :]
        masked = head + replacement(s["label"]) + tail
    return masked
|
|
|
|
def predict(text: str, model, tokenizer, min_score: float):
    """Run the token/span PII model over *text* and decode character spans.

    Returns the span dicts produced by ``decode_token_presence_segments``,
    each augmented with a ``"replacement"`` key holding its mask string.
    ``min_score`` is the base detection threshold (per-label thresholds from
    the model config may override it).
    """
    batch = tokenizer(text, return_offsets_mapping=True, return_tensors="pt", truncation=True)
    # Offsets map each token back to its character range in the raw text.
    char_offsets = [tuple(pair) for pair in batch.pop("offset_mapping")[0].tolist()]
    target = next(model.parameters()).device
    batch = {name: tensor.to(target) for name, tensor in batch.items()}
    with torch.no_grad():
        result = model(**batch)
    # Per-token probabilities for label presence and span start/end boundaries.
    presence = torch.sigmoid(result.token_logits[0]).cpu().numpy()
    starts = torch.sigmoid(result.start_logits[0]).cpu().numpy()
    ends = torch.sigmoid(result.end_logits[0]).cpu().numpy()
    cfg = model.config
    spans = decode_token_presence_segments(
        text,
        char_offsets,
        presence,
        label_names_from_config(cfg),
        min_score,
        token_label_thresholds_from_config(cfg, min_score),
        token_extend_thresholds_from_config(cfg),
        label_max_span_tokens_from_config(cfg),
        label_min_nonspace_chars_from_config(cfg),
        boundary_label_thresholds_from_config(cfg),
        start_scores=starts,
        end_scores=ends,
    )
    for item in spans:
        item["replacement"] = replacement(item["label"])
    return spans
|
|
|
|
def main() -> None:
    """CLI entry point: load a model, detect PII spans in --text, print result.

    With --json the full result dict (spans, thresholds, masked text) is
    printed as JSON; otherwise only the masked text is written to stdout.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", required=True)
    parser.add_argument("--text", required=True)
    parser.add_argument("--min-score", type=float, default=0.5)
    parser.add_argument("--device", choices=["auto", "cpu", "cuda"], default="auto")
    parser.add_argument("--json", action="store_true")
    args = parser.parse_args()

    tokenizer = safe_auto_tokenizer(args.model)
    config = AutoConfig.from_pretrained(args.model)
    model = IrishCoreTokenSpanModel.from_pretrained(args.model, config=config)
    # "auto" prefers the GPU when one is available, otherwise uses the CPU.
    if args.device != "auto":
        target = args.device
    else:
        target = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(target)
    model.eval()

    spans = predict(args.text, model, tokenizer, args.min_score)
    payload = {
        "model": args.model,
        "backend": "transformers_token_span",
        "min_score": args.min_score,
        "spans": spans,
        "masked_text": mask_text(args.text, spans),
    }
    if args.json:
        print(json.dumps(payload, indent=2, ensure_ascii=False))
    else:
        print(payload["masked_text"])
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|