# Provenance (scraped Hugging Face header, preserved as a comment):
# commit 001fa0a (verified) by temsa —
# "Fix remote tokenizer loading for transformers inference"
#!/usr/bin/env python3
import argparse
import json
import os
# Force a torch-only transformers backend. These environment variables are
# read by `transformers` at import time, so they MUST be set before the
# `from transformers import ...` line below — do not reorder.
os.environ.setdefault("TRANSFORMERS_NO_TF", "1")
os.environ.setdefault("TRANSFORMERS_NO_FLAX", "1")
os.environ.setdefault("TRANSFORMERS_NO_TORCHVISION", "1")
# Hard overrides (not setdefault): never let TF/Flax load, always use torch.
os.environ["USE_TF"] = "0"
os.environ["USE_FLAX"] = "0"
os.environ["USE_TORCH"] = "1"
import torch
from transformers import AutoModelForTokenClassification, pipeline
# Project-local helpers: span repair for Irish-core labels and a tokenizer
# loader (presumably a guarded AutoTokenizer wrapper — see onnx_token_classifier).
from irish_core_decoder import repair_irish_core_spans
from onnx_token_classifier import safe_auto_tokenizer
def mask_text(text: str, spans: list[dict]) -> str:
    """Return *text* with every span replaced by its bracketed label.

    Spans are applied right-to-left (descending start/end order) so that
    earlier character offsets remain valid while later regions of the
    string are being rewritten.
    """
    masked = text
    ordered = sorted(spans, key=lambda s: (s["start"], s["end"]), reverse=True)
    for span in ordered:
        head = masked[: span["start"]]
        tail = masked[span["end"]:]
        masked = "".join((head, "[", span["label"], "]", tail))
    return masked
def main() -> None:
    """CLI entry point: run token-classification NER over --text.

    Loads the model/tokenizer from --model, runs the HF pipeline with
    simple aggregation, repairs spans via the Irish-core decoder, and
    prints either the masked text or a full JSON report (--json).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default=".")
    parser.add_argument("--text", required=True)
    parser.add_argument("--device", choices=["auto", "cpu", "cuda"], default="auto")
    parser.add_argument("--ppsn-min-score", type=float, default=0.55)
    parser.add_argument("--other-min-score", type=float, default=0.50)
    parser.add_argument("--json", action="store_true")
    args = parser.parse_args()

    tokenizer = safe_auto_tokenizer(args.model)
    model = AutoModelForTokenClassification.from_pretrained(args.model)

    # Resolve the target device; "auto" prefers CUDA when available.
    target = args.device if args.device != "auto" else (
        "cuda" if torch.cuda.is_available() else "cpu"
    )
    model.to(target)
    model.eval()

    # HF pipeline takes a device index: 0 for the first GPU, -1 for CPU.
    nlp = pipeline(
        "token-classification",
        model=model,
        tokenizer=tokenizer,
        aggregation_strategy="simple",
        device=0 if target == "cuda" else -1,
    )
    general = nlp(args.text)

    # Post-process pipeline output with the label-aware Irish-core decoder.
    spans = repair_irish_core_spans(
        args.text,
        model,
        tokenizer,
        general,
        other_min_score=args.other_min_score,
        ppsn_min_score=args.ppsn_min_score,
    )

    result = {
        "model": args.model,
        "masked_text": mask_text(args.text, spans),
        "spans": spans,
        "ppsn_decoder": "word_aligned",
        "general_decoder": "irish_core_label_aware",
        "ppsn_min_score": args.ppsn_min_score,
        "other_min_score": args.other_min_score,
        "backend": "transformers",
    }
    print(
        json.dumps(result, indent=2, ensure_ascii=False)
        if args.json
        else result["masked_text"]
    )
# Standard script entry guard: only run inference when executed directly.
if __name__ == "__main__":
    main()