temsa committed on
Commit
001fa0a
·
verified ·
1 Parent(s): d1ddb26

Fix remote tokenizer loading for transformers inference

Browse files
Files changed (1) hide show
  1. inference_mask.py +3 -48
inference_mask.py CHANGED
@@ -2,9 +2,6 @@
2
  import argparse
3
  import json
4
  import os
5
- import shutil
6
- import tempfile
7
- from pathlib import Path
8
 
9
  os.environ.setdefault("TRANSFORMERS_NO_TF", "1")
10
  os.environ.setdefault("TRANSFORMERS_NO_FLAX", "1")
@@ -14,52 +11,10 @@ os.environ["USE_FLAX"] = "0"
14
  os.environ["USE_TORCH"] = "1"
15
 
16
  import torch
17
- from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline
18
 
19
  from irish_core_decoder import repair_irish_core_spans
20
-
21
-
22
def load_tokenizer(model_ref: str):
    """Load a tokenizer for *model_ref*, tolerating the ``fix_mistral_regex`` quirk.

    When *model_ref* is a local directory whose ``tokenizer_config.json``
    contains a ``fix_mistral_regex`` entry, the relevant tokenizer files are
    copied into a fresh temporary directory, the offending key is stripped from
    the copied config, and loading proceeds from that scratch copy instead.
    Loading is then attempted with progressively weaker options: fast tokenizer
    with ``fix_mistral_regex=True``, then ``=False``, then without the flag,
    and finally a slow tokenizer as the last resort.

    Args:
        model_ref: A model identifier or a local directory path.

    Returns:
        The first tokenizer that loads successfully.
    """
    source = model_ref
    local_dir = Path(model_ref)
    config_file = local_dir / "tokenizer_config.json"
    if local_dir.exists() and config_file.exists():
        config = json.loads(config_file.read_text(encoding="utf-8"))
        if "fix_mistral_regex" in config:
            # Build a sanitized scratch copy of the tokenizer directory so the
            # stray config key never reaches AutoTokenizer's loader.
            scratch = Path(tempfile.mkdtemp(prefix="openmed_tokenizer_"))
            wanted = {
                "tokenizer_config.json",
                "tokenizer.json",
                "special_tokens_map.json",
                "vocab.txt",
                "vocab.json",
                "merges.txt",
                "added_tokens.json",
                "sentencepiece.bpe.model",
                "spiece.model",
            }
            for entry in local_dir.iterdir():
                if entry.is_file() and entry.name in wanted:
                    shutil.copy2(entry, scratch / entry.name)
            config.pop("fix_mistral_regex", None)
            (scratch / "tokenizer_config.json").write_text(
                json.dumps(config, ensure_ascii=False, indent=2) + "\n",
                encoding="utf-8",
            )
            source = str(scratch)
    # Fallback chain: each attempt is slightly more permissive than the last.
    try:
        return AutoTokenizer.from_pretrained(source, use_fast=True, fix_mistral_regex=True)
    except Exception:
        pass
    try:
        return AutoTokenizer.from_pretrained(source, use_fast=True, fix_mistral_regex=False)
    except TypeError:
        # Older transformers versions reject the keyword entirely.
        pass
    try:
        return AutoTokenizer.from_pretrained(source, use_fast=True)
    except Exception:
        return AutoTokenizer.from_pretrained(source, use_fast=False)
63
 
64
 
65
  def mask_text(text: str, spans: list[dict]) -> str:
@@ -79,7 +34,7 @@ def main() -> None:
79
  parser.add_argument("--json", action="store_true")
80
  args = parser.parse_args()
81
 
82
- tokenizer = load_tokenizer(args.model)
83
  model = AutoModelForTokenClassification.from_pretrained(args.model)
84
  if args.device == "auto":
85
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
2
  import argparse
3
  import json
4
  import os
 
 
 
5
 
6
  os.environ.setdefault("TRANSFORMERS_NO_TF", "1")
7
  os.environ.setdefault("TRANSFORMERS_NO_FLAX", "1")
 
11
  os.environ["USE_TORCH"] = "1"
12
 
13
  import torch
14
+ from transformers import AutoModelForTokenClassification, pipeline
15
 
16
  from irish_core_decoder import repair_irish_core_spans
17
+ from onnx_token_classifier import safe_auto_tokenizer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
 
20
  def mask_text(text: str, spans: list[dict]) -> str:
 
34
  parser.add_argument("--json", action="store_true")
35
  args = parser.parse_args()
36
 
37
+ tokenizer = safe_auto_tokenizer(args.model)
38
  model = AutoModelForTokenClassification.from_pretrained(args.model)
39
  if args.device == "auto":
40
  device = "cuda" if torch.cuda.is_available() else "cpu"