Fix remote tokenizer loading for transformers inference
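Replaces the inline load_tokenizer() workaround in inference_mask.py with the shared safe_auto_tokenizer() helper imported from onnx_token_classifier, so the transformers entry point loads tokenizers through one maintained code path instead of a local copy of the fallback logic.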
inference_mask.py CHANGED (+3 -48)
@@ -2,9 +2,6 @@
 import argparse
 import json
 import os
-import shutil
-import tempfile
-from pathlib import Path

 os.environ.setdefault("TRANSFORMERS_NO_TF", "1")
 os.environ.setdefault("TRANSFORMERS_NO_FLAX", "1")
@@ -14,52 +11,10 @@ os.environ["USE_FLAX"] = "0"
 os.environ["USE_TORCH"] = "1"

 import torch
-from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline
+from transformers import AutoModelForTokenClassification, pipeline

 from irish_core_decoder import repair_irish_core_spans
-
-
-def load_tokenizer(model_ref: str):
-    tokenizer_ref = model_ref
-    tokenizer_path = Path(model_ref)
-    if tokenizer_path.exists():
-        tokenizer_cfg_path = tokenizer_path / "tokenizer_config.json"
-        if tokenizer_cfg_path.exists():
-            data = json.loads(tokenizer_cfg_path.read_text(encoding="utf-8"))
-            if "fix_mistral_regex" in data:
-                tmpdir = Path(tempfile.mkdtemp(prefix="openmed_tokenizer_"))
-                keep = {
-                    "tokenizer_config.json",
-                    "tokenizer.json",
-                    "special_tokens_map.json",
-                    "vocab.txt",
-                    "vocab.json",
-                    "merges.txt",
-                    "added_tokens.json",
-                    "sentencepiece.bpe.model",
-                    "spiece.model",
-                }
-                for child in tokenizer_path.iterdir():
-                    if child.is_file() and child.name in keep:
-                        shutil.copy2(child, tmpdir / child.name)
-                data.pop("fix_mistral_regex", None)
-                (tmpdir / "tokenizer_config.json").write_text(
-                    json.dumps(data, ensure_ascii=False, indent=2) + "\n",
-                    encoding="utf-8",
-                )
-                tokenizer_ref = str(tmpdir)
-    try:
-        return AutoTokenizer.from_pretrained(tokenizer_ref, use_fast=True, fix_mistral_regex=True)
-    except Exception:
-        pass
-    try:
-        return AutoTokenizer.from_pretrained(tokenizer_ref, use_fast=True, fix_mistral_regex=False)
-    except TypeError:
-        pass
-    try:
-        return AutoTokenizer.from_pretrained(tokenizer_ref, use_fast=True)
-    except Exception:
-        return AutoTokenizer.from_pretrained(tokenizer_ref, use_fast=False)
+from onnx_token_classifier import safe_auto_tokenizer


 def mask_text(text: str, spans: list[dict]) -> str:
@@ -79,7 +34,7 @@ def main() -> None:
     parser.add_argument("--json", action="store_true")
     args = parser.parse_args()

-    tokenizer = load_tokenizer(args.model)
+    tokenizer = safe_auto_tokenizer(args.model)
     model = AutoModelForTokenClassification.from_pretrained(args.model)
     if args.device == "auto":
         device = "cuda" if torch.cuda.is_available() else "cpu"
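For context, safe_auto_tokenizer lives in onnx_token_classifier and its body is not part of this diff. Below is a minimal sketch of the fallback it presumably wraps, reconstructed from the deleted load_tokenizer() above; the sketch's name and its exact fallback order are assumptions, not the helper's actual code.

# Sketch only: assumed behavior of safe_auto_tokenizer, not the real implementation.
from transformers import AutoTokenizer

def safe_auto_tokenizer_sketch(model_ref: str):
    """Load a tokenizer, degrading from the fast to the slow implementation."""
    try:
        # Fast (Rust-backed) tokenizer. This is the path that can fail when a
        # remote repo's tokenizer_config.json carries keys such as
        # "fix_mistral_regex" that the installed transformers version rejects.
        return AutoTokenizer.from_pretrained(model_ref, use_fast=True)
    except Exception:
        # Pure-Python tokenizer as the last resort.
        return AutoTokenizer.from_pretrained(model_ref, use_fast=False)

The deleted in-file version went further, copying the tokenizer files to a temp dir and stripping the offending key before retrying. Whatever the shared helper does internally, keeping that logic in one module means the ONNX and transformers entry points can no longer drift apart.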