model_length
Browse files- hive_token_classification.py +20 -0
hive_token_classification.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Dict
|
| 2 |
+
from transformers import Pipeline, AutoModel, AutoTokenizer
|
| 3 |
+
from transformers.pipelines.base import GenericTensor, ModelOutput
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class HiveTokenClassification(Pipeline):
    """Token-classification pipeline that delegates the whole forward pass to the model.

    The underlying model is expected to expose a
    ``predict(inputs, tokenizer, **kwargs)`` method that performs its own
    tokenization, so ``preprocess`` is a deliberate pass-through.  The only
    pipeline-level knob is ``output_style``, which is routed to the forward
    step.  ``postprocess`` wraps the raw predictions together with their count.
    """

    def _sanitize_parameters(self, **kwargs):
        # Only ``output_style`` is recognized; it belongs to the forward step.
        # Preprocess and postprocess take no parameters, hence the empty dicts.
        forward_kwargs = {key: kwargs[key] for key in ("output_style",) if key in kwargs}
        return {}, forward_kwargs, {}

    def preprocess(self, input_: Any, **preprocess_parameters: Dict) -> Dict[str, GenericTensor]:
        # Intentionally a no-op: the model's own ``predict`` handles
        # tokenization, so the raw input is forwarded untouched.
        return input_

    def _forward(self, input_tensors: Dict[str, GenericTensor], **forward_parameters: Dict) -> ModelOutput:
        # Delegate to the model's custom ``predict`` (which also receives the
        # tokenizer) instead of calling the model directly.
        return self.model.predict(input_tensors, self.tokenizer, **forward_parameters)

    def postprocess(self, model_outputs: ModelOutput, **postprocess_parameters: Dict) -> Any:
        # Expose predictions under fixed keys; ``model_length`` is simply the
        # number of returned outputs.
        return {"output": model_outputs, "model_length": len(model_outputs)}
|