# Provenance (Hugging Face upload residue): commit "Publish IrishCore-GlobalPointer-135M-v1-rc2",
# hash 6edc303 (verified), by user temsa.
#!/usr/bin/env python3
from __future__ import annotations
import argparse
import json
import os
# Force the pure-PyTorch backend before `transformers` is imported:
# opt out of the TF / Flax / torchvision integrations entirely.
for _var in ("TRANSFORMERS_NO_TF", "TRANSFORMERS_NO_FLAX", "TRANSFORMERS_NO_TORCHVISION"):
    os.environ.setdefault(_var, "1")
os.environ.update({"USE_TF": "0", "USE_FLAX": "0", "USE_TORCH": "1"})
import torch
from transformers import AutoConfig
from common import decode_span_matrix, safe_auto_tokenizer
from model import IrishCoreGlobalPointerModel
def replacement(label: str) -> str:
    """Return the mask token substituted for a detected PII span of *label*."""
    return "[PII:" + label + "]"
def mask_text(text: str, spans: list[dict]) -> str:
    """Return *text* with every span replaced by its PII mask token.

    Spans are applied from the rightmost one backwards so that earlier
    replacements never shift the character offsets of the spans still
    to be processed.
    """
    masked = text
    ordered = sorted(spans, key=lambda item: (item["start"], item["end"]), reverse=True)
    for span in ordered:
        prefix = masked[: span["start"]]
        suffix = masked[span["end"] :]
        masked = prefix + replacement(span["label"]) + suffix
    return masked
def predict(text: str, model, tokenizer, min_score: float):
    """Run the GlobalPointer model on *text* and return decoded PII spans.

    Each span dict returned by ``decode_span_matrix`` additionally gains a
    ``"replacement"`` key holding the mask token for its label.
    """
    encoded = tokenizer(text, return_offsets_mapping=True, return_tensors="pt", truncation=True)
    # Offset pairs map each token index back to character positions in *text*.
    offset_pairs = encoded.pop("offset_mapping")[0].tolist()
    offsets = [tuple(pair) for pair in offset_pairs]
    # Run inference on whatever device the model currently lives on.
    target_device = next(model.parameters()).device
    batch = {name: tensor.to(target_device) for name, tensor in encoded.items()}
    with torch.no_grad():
        output = model(**batch)
    # Per-span probabilities are independent, hence sigmoid rather than softmax.
    scores = torch.sigmoid(output.span_logits[0]).cpu().numpy()
    spans = decode_span_matrix(text, offsets, scores, model.config, min_score)
    for span in spans:
        span["replacement"] = replacement(span["label"])
    return spans
def main() -> None:
    """CLI entry point: load a GlobalPointer PII model and mask ``--text``."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", required=True)
    parser.add_argument("--text", required=True)
    parser.add_argument("--min-score", type=float, default=0.5)
    parser.add_argument("--device", choices=["auto", "cpu", "cuda"], default="auto")
    parser.add_argument("--json", action="store_true")
    args = parser.parse_args()

    tokenizer = safe_auto_tokenizer(args.model)
    config = AutoConfig.from_pretrained(args.model)
    model = IrishCoreGlobalPointerModel.from_pretrained(args.model, config=config)

    # "auto" prefers CUDA when available, otherwise falls back to CPU.
    if args.device != "auto":
        device = args.device
    else:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    model.eval()

    spans = predict(args.text, model, tokenizer, args.min_score)
    report = {
        "model": args.model,
        "backend": "transformers_global_pointer",
        "min_score": args.min_score,
        "spans": spans,
        "masked_text": mask_text(args.text, spans),
    }
    # --json emits the full structured report; the default prints masked text only.
    print(json.dumps(report, indent=2, ensure_ascii=False) if args.json else report["masked_text"])
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()