Spaces:
Runtime error
Runtime error
Commit ·
2ffc97e
1
Parent(s): 8fe5582
fix: no need torch
Browse files- onnx_model.py +4 -2
onnx_model.py
CHANGED
|
@@ -8,7 +8,6 @@ from typing import Any
|
|
| 8 |
import numpy as np
|
| 9 |
import onnxruntime as ort
|
| 10 |
from loguru import logger
|
| 11 |
-
from onnxruntime.transformers.io_binding_helper import TypeHelper
|
| 12 |
|
| 13 |
|
| 14 |
@dataclass
|
|
@@ -36,7 +35,10 @@ class ONNXModel:
|
|
| 36 |
else:
|
| 37 |
self.device = "cpu"
|
| 38 |
|
| 39 |
-
self.io_types =
|
|
|
|
|
|
|
|
|
|
| 40 |
|
| 41 |
self.input_names = [el.name for el in model.get_inputs()]
|
| 42 |
self.output_name = model.get_outputs()[0].name
|
|
|
|
| 8 |
import numpy as np
|
| 9 |
import onnxruntime as ort
|
| 10 |
from loguru import logger
|
|
|
|
| 11 |
|
| 12 |
|
| 13 |
@dataclass
|
|
|
|
| 35 |
else:
|
| 36 |
self.device = "cpu"
|
| 37 |
|
| 38 |
+
self.io_types = {
|
| 39 |
+
"input_ids": np.int32,
|
| 40 |
+
"attention_mask": np.bool_
|
| 41 |
+
}
|
| 42 |
|
| 43 |
self.input_names = [el.name for el in model.get_inputs()]
|
| 44 |
self.output_name = model.get_outputs()[0].name
|