khanhmse committed on
Commit
a7ec61a
·
verified ·
1 Parent(s): 81b65cc

Upload inference.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. inference.py +21 -4
inference.py CHANGED
@@ -54,8 +54,15 @@ def load_ensemble(repo_id: str = None, local_dir: str = None):
54
  config = json.load(f)
55
 
56
  device = "cuda" if torch.cuda.is_available() else "cpu"
57
- tokenizer1 = AutoTokenizer.from_pretrained(config["base_models"][0], use_fast=True)
58
- tokenizer2 = AutoTokenizer.from_pretrained(config["base_models"][1], use_fast=False)
 
 
 
 
 
 
 
59
  model1 = AutoModelForQuestionAnswering.from_pretrained(config["base_models"][0]).to(device)
60
  model2 = AutoModelForQuestionAnswering.from_pretrained(config["base_models"][1]).to(device)
61
 
@@ -112,8 +119,18 @@ def predict(question: str, context: str, ensemble: dict, max_answer_len: int = 3
112
  inp1 = {k: v.to(dev) for k, v in enc1.items()}
113
  inp2 = {k: v.to(dev) for k, v in enc2.items()}
114
 
115
- seq_ids = enc1.sequence_ids(0)
116
- ctx_idx = [i for i, s in enumerate(seq_ids) if s == 1]
 
 
 
 
 
 
 
 
 
 
117
  if not ctx_idx:
118
  return "", 1.0
119
 
 
54
  config = json.load(f)
55
 
56
  device = "cuda" if torch.cuda.is_available() else "cpu"
57
+ def _load_tok(mid, use_fast=True):
58
+ try:
59
+ return AutoTokenizer.from_pretrained(mid, use_fast=use_fast)
60
+ except Exception as e:
61
+ if "sentencepiece" in str(e).lower() and use_fast:
62
+ return AutoTokenizer.from_pretrained(mid, use_fast=False)
63
+ raise
64
+ tokenizer1 = _load_tok(config["base_models"][0])
65
+ tokenizer2 = AutoTokenizer.from_pretrained(config["base_models"][1], use_fast=False) # PhoBERT requires use_fast=False
66
  model1 = AutoModelForQuestionAnswering.from_pretrained(config["base_models"][0]).to(device)
67
  model2 = AutoModelForQuestionAnswering.from_pretrained(config["base_models"][1]).to(device)
68
 
 
119
  inp1 = {k: v.to(dev) for k, v in enc1.items()}
120
  inp2 = {k: v.to(dev) for k, v in enc2.items()}
121
 
122
+ try:
123
+ seq_ids = enc1.sequence_ids(0)
124
+ except Exception:
125
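+ # Slow tokenizer: assumes layout [CLS] q [SEP] ctx [SEP], sep=2. NOTE(review): RoBERTa-style pair encodings are <s> q </s></s> ctx </s> (two adjacent separators) — confirm the slice between the first two separators is not empty.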
126
+ sep_id = t1.convert_tokens_to_ids(t1.sep_token or "</s>")
127
+ ids = enc1["input_ids"][0].tolist()
128
+ sep_pos = [i for i, x in enumerate(ids) if x == sep_id]
129
+ if len(sep_pos) < 2:
130
+ return "", 1.0
131
+ ctx_idx = list(range(sep_pos[0] + 1, sep_pos[1]))
132
+ else:
133
+ ctx_idx = [i for i, s in enumerate(seq_ids) if s == 1]
134
  if not ctx_idx:
135
  return "", 1.0
136