# IrishCore-DiffMask-135M-v1-rc3 / inference_mask.py
# Uploaded by temsa via the upload-large-folder tool (commit 85b0a00, verified).
#!/usr/bin/env python3
"""CLI inference script: mask PII spans in text with the IrishCore token/span model."""
from __future__ import annotations
import argparse
import json
import os
# Force a torch-only `transformers` backend BEFORE transformers is imported:
# skip TF / Flax / torchvision probing so import is fast and deterministic.
os.environ.setdefault("TRANSFORMERS_NO_TF", "1")
os.environ.setdefault("TRANSFORMERS_NO_FLAX", "1")
os.environ.setdefault("TRANSFORMERS_NO_TORCHVISION", "1")
os.environ["USE_TF"] = "0"
os.environ["USE_FLAX"] = "0"
os.environ["USE_TORCH"] = "1"
import torch
from transformers import AutoConfig
# Project-local decoding helpers and threshold readers shared with the
# training/eval scripts in this repository.
from common import (
    boundary_label_thresholds_from_config,
    decode_token_presence_segments,
    label_max_span_tokens_from_config,
    label_min_nonspace_chars_from_config,
    label_names_from_config,
    safe_auto_tokenizer,
    token_extend_thresholds_from_config,
    token_label_thresholds_from_config,
)
from multitask_model import IrishCoreTokenSpanModel
def replacement(label: str) -> str:
    """Return the mask placeholder substituted for a span of the given PII label."""
    return "[PII:" + label + "]"
def mask_text(text: str, spans: list[dict]) -> str:
    """Return `text` with every detected span replaced by its PII placeholder.

    Spans are applied right-to-left (sorted descending by start, then end) so
    that earlier character offsets remain valid while later portions of the
    string are rewritten.
    """
    masked = text
    ordered = sorted(spans, key=lambda s: (s["start"], s["end"]), reverse=True)
    for entry in ordered:
        prefix = masked[: entry["start"]]
        suffix = masked[entry["end"] :]
        masked = prefix + f"[PII:{entry['label']}]" + suffix
    return masked
def predict(text: str, model, tokenizer, min_score: float):
    """Run the token/span model over `text` and return decoded PII spans.

    Tokenizes with character offsets, scores every token with the multitask
    heads (presence / start / end), then delegates segment decoding to
    `decode_token_presence_segments` using the thresholds stored on the model
    config. Each returned span dict additionally gains a "replacement" key
    holding the mask placeholder for its label.
    """
    batch = tokenizer(text, return_offsets_mapping=True, return_tensors="pt", truncation=True)
    # Offsets map each token back to character positions in `text`; they are
    # consumed by the decoder, not the model, so pop them before the forward pass.
    char_offsets = [tuple(pair) for pair in batch.pop("offset_mapping")[0].tolist()]
    target = next(model.parameters()).device
    batch = {name: tensor.to(target) for name, tensor in batch.items()}
    with torch.no_grad():
        result = model(**batch)
    # Per-token probabilities for the single (batch index 0) sequence.
    presence = torch.sigmoid(result.token_logits[0]).cpu().numpy()
    starts = torch.sigmoid(result.start_logits[0]).cpu().numpy()
    ends = torch.sigmoid(result.end_logits[0]).cpu().numpy()
    cfg = model.config
    spans = decode_token_presence_segments(
        text,
        char_offsets,
        presence,
        label_names_from_config(cfg),
        min_score,
        token_label_thresholds_from_config(cfg, min_score),
        token_extend_thresholds_from_config(cfg),
        label_max_span_tokens_from_config(cfg),
        label_min_nonspace_chars_from_config(cfg),
        boundary_label_thresholds_from_config(cfg),
        start_scores=starts,
        end_scores=ends,
    )
    for entry in spans:
        entry["replacement"] = replacement(entry["label"])
    return spans
def main() -> None:
    """CLI entry point: load model + tokenizer, mask --text, print the result.

    With --json, emits the full result object (spans, thresholds, masked text);
    otherwise prints only the masked text.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", required=True)
    parser.add_argument("--text", required=True)
    parser.add_argument("--min-score", type=float, default=0.5)
    parser.add_argument("--device", choices=["auto", "cpu", "cuda"], default="auto")
    parser.add_argument("--json", action="store_true")
    args = parser.parse_args()
    tokenizer = safe_auto_tokenizer(args.model)
    config = AutoConfig.from_pretrained(args.model)
    model = IrishCoreTokenSpanModel.from_pretrained(args.model, config=config)
    # "auto" prefers CUDA when available; otherwise honor the explicit choice.
    if args.device != "auto":
        target = args.device
    else:
        target = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(target)
    model.eval()
    spans = predict(args.text, model, tokenizer, args.min_score)
    payload = {
        "model": args.model,
        "backend": "transformers_token_span",
        "min_score": args.min_score,
        "spans": spans,
        "masked_text": mask_text(args.text, spans),
    }
    if args.json:
        print(json.dumps(payload, indent=2, ensure_ascii=False))
    else:
        print(payload["masked_text"])
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()