| from typing import Any, Dict | |
| from transformers import Pipeline, AutoModel, AutoTokenizer | |
| from transformers.pipelines.base import GenericTensor, ModelOutput | |
class HiveTokenClassification(Pipeline):
    """Token-classification pipeline that delegates inference to ``model.predict``.

    NOTE(review): ``predict(inputs, tokenizer, **kwargs)`` is a custom method
    on the wrapped model, not part of the standard ``PreTrainedModel`` API.
    Confirm its signature against the model's own code.

    The tokenizer is handed to ``model.predict`` and ``preprocess`` passes the
    raw input through unchanged. Tokenization is therefore presumably done
    inside ``predict``.

    Recognised call-time keyword arguments:
        output_style: optional. Forwarded to ``model.predict`` only when the
            caller supplies it.
    """

    def _sanitize_parameters(self, **kwargs: Any):
        """Split call-time kwargs into (preprocess, forward, postprocess) dicts.

        Only ``output_style`` is recognised, and it is routed to ``_forward``.
        This method ignores every other keyword.
        """
        forward_parameters: Dict[str, Any] = {}
        if "output_style" in kwargs:
            forward_parameters["output_style"] = kwargs["output_style"]
        return {}, forward_parameters, {}

    def preprocess(self, input_: Any, **preprocess_parameters: Any) -> Any:
        """Return *input_* unchanged; the raw input goes straight to ``_forward``."""
        return input_

    def _forward(self, input_tensors: Any, **forward_parameters: Any) -> ModelOutput:
        """Run ``self.model.predict`` on the input with ``self.tokenizer``.

        Args:
            input_tensors: the value returned by ``preprocess``, i.e. the raw
                pipeline input. It is not a tensor dict despite the name.
            **forward_parameters: the forward dict built by
                ``_sanitize_parameters``. It may contain ``output_style``.

        Returns:
            Whatever ``self.model.predict`` returns.
        """
        # ``_sanitize_parameters`` omits the key when the caller did not pass
        # ``output_style``. Indexing it unconditionally would raise KeyError,
        # so forward it only when present and let ``predict`` apply its own
        # default (or report the missing argument itself).
        predict_kwargs: Dict[str, Any] = {}
        if "output_style" in forward_parameters:
            predict_kwargs["output_style"] = forward_parameters["output_style"]
        return self.model.predict(input_tensors, self.tokenizer, **predict_kwargs)

    def postprocess(self, model_outputs: ModelOutput, **postprocess_parameters: Any) -> Any:
        """Wrap the prediction together with its length.

        Args:
            model_outputs: the value returned by ``_forward``. It must support
                ``len()``.

        Returns:
            ``{"output": model_outputs, "length": len(model_outputs)}``.
        """
        return {"output": model_outputs, "length": len(model_outputs)}