Commit: Upload inference.py with huggingface_hub
Changed file: inference.py (+21 lines, −4 lines)
inference.py
CHANGED
|
@@ -54,8 +54,15 @@ def load_ensemble(repo_id: str = None, local_dir: str = None):
|
|
| 54 |
config = json.load(f)
|
| 55 |
|
| 56 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 57 |
-
|
| 58 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
model1 = AutoModelForQuestionAnswering.from_pretrained(config["base_models"][0]).to(device)
|
| 60 |
model2 = AutoModelForQuestionAnswering.from_pretrained(config["base_models"][1]).to(device)
|
| 61 |
|
|
@@ -112,8 +119,18 @@ def predict(question: str, context: str, ensemble: dict, max_answer_len: int = 3
|
|
| 112 |
inp1 = {k: v.to(dev) for k, v in enc1.items()}
|
| 113 |
inp2 = {k: v.to(dev) for k, v in enc2.items()}
|
| 114 |
|
| 115 |
-
|
| 116 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
if not ctx_idx:
|
| 118 |
return "", 1.0
|
| 119 |
|
|
|
|
| 54 |
config = json.load(f)
|
| 55 |
|
| 56 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 57 |
+
def _load_tok(mid, use_fast=True):
    """Load a tokenizer for model id *mid*, preferring the fast implementation.

    If building the fast tokenizer fails because the `sentencepiece`
    package is unavailable, retry once with the slow (pure-Python)
    tokenizer; any other failure is propagated unchanged.
    """
    try:
        return AutoTokenizer.from_pretrained(mid, use_fast=use_fast)
    except Exception as err:
        # Only fall back when we actually asked for a fast tokenizer and
        # the error message points at a missing sentencepiece backend.
        retry_slow = "sentencepiece" in str(err).lower() and use_fast
        if retry_slow:
            return AutoTokenizer.from_pretrained(mid, use_fast=False)
        raise
|
| 64 |
+
tokenizer1 = _load_tok(config["base_models"][0])
|
| 65 |
+
tokenizer2 = AutoTokenizer.from_pretrained(config["base_models"][1], use_fast=False) # PhoBERT cần use_fast=False
|
| 66 |
model1 = AutoModelForQuestionAnswering.from_pretrained(config["base_models"][0]).to(device)
|
| 67 |
model2 = AutoModelForQuestionAnswering.from_pretrained(config["base_models"][1]).to(device)
|
| 68 |
|
|
|
|
| 119 |
inp1 = {k: v.to(dev) for k, v in enc1.items()}
|
| 120 |
inp2 = {k: v.to(dev) for k, v in enc2.items()}
|
| 121 |
|
| 122 |
+
try:
|
| 123 |
+
seq_ids = enc1.sequence_ids(0)
|
| 124 |
+
except Exception:
|
| 125 |
+
# Slow tokenizer: RoBERTa layout [CLS] q [SEP] ctx [SEP], sep=2
|
| 126 |
+
sep_id = t1.convert_tokens_to_ids(t1.sep_token or "</s>")
|
| 127 |
+
ids = enc1["input_ids"][0].tolist()
|
| 128 |
+
sep_pos = [i for i, x in enumerate(ids) if x == sep_id]
|
| 129 |
+
if len(sep_pos) < 2:
|
| 130 |
+
return "", 1.0
|
| 131 |
+
ctx_idx = list(range(sep_pos[0] + 1, sep_pos[1]))
|
| 132 |
+
else:
|
| 133 |
+
ctx_idx = [i for i, s in enumerate(seq_ids) if s == 1]
|
| 134 |
if not ctx_idx:
|
| 135 |
return "", 1.0
|
| 136 |
|