skatzR commited on
Commit
d9c4c42
·
verified ·
1 Parent(s): fd20285

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +5 -7
inference.py CHANGED
@@ -2,18 +2,13 @@ import os
2
  from typing import Any, Dict, List, Optional
3
 
4
  import torch
5
- from transformers import AutoTokenizer
6
 
7
  try:
8
  from huggingface_hub import hf_hub_download
9
  except Exception:
10
  hf_hub_download = None
11
 
12
- try:
13
- from .modeling_rqa import RQAModelHF
14
- except ImportError:
15
- from modeling_rqa import RQAModelHF
16
-
17
 
18
  ERROR_NAMES_RU = {
19
  "false_causality": "Ложная причинно-следственная связь",
@@ -61,7 +56,10 @@ class RQAInferenceHF:
61
  self.hidden_uncertain_margin = float(hidden_uncertain_margin)
62
  self.error_uncertain_margin = float(error_uncertain_margin)
63
 
64
- self.model = RQAModelHF.from_pretrained(model_path).to(self.device).eval()
 
 
 
65
  self.tokenizer = AutoTokenizer.from_pretrained(model_path)
66
 
67
  cfg = self.model.config
 
2
  from typing import Any, Dict, List, Optional
3
 
4
  import torch
5
+ from transformers import AutoModel, AutoTokenizer
6
 
7
  try:
8
  from huggingface_hub import hf_hub_download
9
  except Exception:
10
  hf_hub_download = None
11
 
 
 
 
 
 
12
 
13
  ERROR_NAMES_RU = {
14
  "false_causality": "Ложная причинно-следственная связь",
 
56
  self.hidden_uncertain_margin = float(hidden_uncertain_margin)
57
  self.error_uncertain_margin = float(error_uncertain_margin)
58
 
59
+ self.model = AutoModel.from_pretrained(
60
+ model_path,
61
+ trust_remote_code=True,
62
+ ).to(self.device).eval()
63
  self.tokenizer = AutoTokenizer.from_pretrained(model_path)
64
 
65
  cfg = self.model.config