| |
| import argparse |
| import json |
| import os |
|
|
# Backend selection via environment variables. This block sits ahead of the
# `torch` / `transformers` imports; presumably those libraries read the
# variables at import time, so the ordering matters — TODO confirm.

# Soft defaults: a value already present in the environment is respected.
for _flag in (
    "TRANSFORMERS_NO_TF",
    "TRANSFORMERS_NO_FLAX",
    "TRANSFORMERS_NO_TORCHVISION",
):
    os.environ.setdefault(_flag, "1")

# Hard overrides: always forced, regardless of what the caller exported.
os.environ.update({"USE_TF": "0", "USE_FLAX": "0", "USE_TORCH": "1"})
|
|
| import torch |
| from transformers import AutoModelForTokenClassification, pipeline |
|
|
| from irish_core_decoder import repair_irish_core_spans |
| from onnx_token_classifier import safe_auto_tokenizer |
|
|
|
|
def mask_text(text: str, spans: list[dict]) -> str:
    """Return *text* with every span replaced by a ``[LABEL]`` placeholder.

    Args:
        text: The original string the span offsets refer to.
        spans: Dicts carrying ``"start"`` and ``"end"`` character offsets
            into *text* and a ``"label"`` used for the placeholder. The
            list may be in any order and is not modified.

    Returns:
        The masked string. Spans that overlap, nest inside, or duplicate an
        earlier span (in ``(start, end)`` order) are folded into that earlier
        span's placeholder, so every covered character is removed exactly
        once and only one placeholder is emitted for the merged region.
    """
    pieces: list[str] = []
    cursor = 0  # first offset of `text` that has not been emitted or masked
    for span in sorted(spans, key=lambda item: (item["start"], item["end"])):
        start, end = span["start"], span["end"]
        if start < cursor:
            # Overlaps the region already masked: swallow any extra tail
            # rather than splicing with offsets that no longer line up.
            cursor = max(cursor, end)
            continue
        pieces.append(text[cursor:start])
        pieces.append(f"[{span['label']}]")
        cursor = max(cursor, end)
    pieces.append(text[cursor:])
    return "".join(pieces)
|
|
|
|
def main() -> None:
    """Command-line entry point: tag one text and print its masked form.

    Reads ``--model``, ``--text``, ``--device``, the two score thresholds and
    ``--json`` from the command line, loads the tokenizer and the
    token-classification model from ``--model``, runs the transformers
    pipeline over the text, hands the predictions to
    ``repair_irish_core_spans`` and finally prints either the masked text or,
    with ``--json``, a JSON report containing the spans and the settings used.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--model", default=".")
    cli.add_argument("--text", required=True)
    cli.add_argument("--device", choices=["auto", "cpu", "cuda"], default="auto")
    cli.add_argument("--ppsn-min-score", type=float, default=0.55)
    cli.add_argument("--other-min-score", type=float, default=0.50)
    cli.add_argument("--json", action="store_true")
    args = cli.parse_args()

    tokenizer = safe_auto_tokenizer(args.model)
    model = AutoModelForTokenClassification.from_pretrained(args.model)

    # "auto" resolves to CUDA when torch reports it as available.
    device = args.device
    if device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    model.eval()

    on_gpu = device == "cuda"
    tagger = pipeline(
        "token-classification",
        model=model,
        tokenizer=tokenizer,
        aggregation_strategy="simple",
        device=0 if on_gpu else -1,
    )
    raw_predictions = tagger(args.text)

    spans = repair_irish_core_spans(
        args.text,
        model,
        tokenizer,
        raw_predictions,
        other_min_score=args.other_min_score,
        ppsn_min_score=args.ppsn_min_score,
    )
    masked = mask_text(args.text, spans)

    if not args.json:
        print(masked)
        return

    report = {
        "model": args.model,
        "masked_text": masked,
        "spans": spans,
        "ppsn_decoder": "word_aligned",
        "general_decoder": "irish_core_label_aware",
        "ppsn_min_score": args.ppsn_min_score,
        "other_min_score": args.other_min_score,
        "backend": "transformers",
    }
    print(json.dumps(report, indent=2, ensure_ascii=False))
|
|
|
|
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|