Update inference.py
Browse files- inference.py +5 -7
inference.py
CHANGED
|
@@ -2,18 +2,13 @@ import os
|
|
| 2 |
from typing import Any, Dict, List, Optional
|
| 3 |
|
| 4 |
import torch
|
| 5 |
-
from transformers import AutoTokenizer
|
| 6 |
|
| 7 |
try:
|
| 8 |
from huggingface_hub import hf_hub_download
|
| 9 |
except Exception:
|
| 10 |
hf_hub_download = None
|
| 11 |
|
| 12 |
-
try:
|
| 13 |
-
from .modeling_rqa import RQAModelHF
|
| 14 |
-
except ImportError:
|
| 15 |
-
from modeling_rqa import RQAModelHF
|
| 16 |
-
|
| 17 |
|
| 18 |
ERROR_NAMES_RU = {
|
| 19 |
"false_causality": "Ложная причинно-следственная связь",
|
|
@@ -61,7 +56,10 @@ class RQAInferenceHF:
|
|
| 61 |
self.hidden_uncertain_margin = float(hidden_uncertain_margin)
|
| 62 |
self.error_uncertain_margin = float(error_uncertain_margin)
|
| 63 |
|
| 64 |
-
self.model =
|
|
|
|
|
|
|
|
|
|
| 65 |
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
|
| 66 |
|
| 67 |
cfg = self.model.config
|
|
|
|
| 2 |
from typing import Any, Dict, List, Optional
|
| 3 |
|
| 4 |
import torch
|
| 5 |
+
from transformers import AutoModel, AutoTokenizer
|
| 6 |
|
| 7 |
try:
|
| 8 |
from huggingface_hub import hf_hub_download
|
| 9 |
except Exception:
|
| 10 |
hf_hub_download = None
|
| 11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
ERROR_NAMES_RU = {
|
| 14 |
"false_causality": "Ложная причинно-следственная связь",
|
|
|
|
| 56 |
self.hidden_uncertain_margin = float(hidden_uncertain_margin)
|
| 57 |
self.error_uncertain_margin = float(error_uncertain_margin)
|
| 58 |
|
| 59 |
+
self.model = AutoModel.from_pretrained(
|
| 60 |
+
model_path,
|
| 61 |
+
trust_remote_code=True,
|
| 62 |
+
).to(self.device).eval()
|
| 63 |
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
|
| 64 |
|
| 65 |
cfg = self.model.config
|