| from transformers import Pipeline | |
| from .classifier import Classifier | |
class MultiTaskClassifierPipeline(Pipeline):
    """Pipeline that tokenizes its inputs, runs the model, and decodes the logits.

    Decoding is delegated to ``Classifier``, which is built from
    ``model.config.task_specific_params[Classifier.MODEL_CONFIG]``.
    """

    def _sanitize_parameters(self, **kwargs):
        """Return the (preprocess, forward, postprocess) keyword-argument dicts.

        No call-time options are supported: ``kwargs`` is ignored and all
        three dicts are always empty.
        """
        preprocess_kwargs = {}
        postprocess_kwargs = {}
        return preprocess_kwargs, {}, postprocess_kwargs

    def preprocess(self, inputs):
        """Tokenize ``inputs`` and move the encoded tensors to ``self.device``.

        Tensors are produced for the pipeline's framework
        (``return_tensors=self.framework``), with truncation enabled.
        """
        # NOTE(review): padding="max_length" pads every input to a fixed
        # length rather than to the longest item — confirm this is required.
        encoded = self.tokenizer(
            inputs,
            padding="max_length",
            truncation=True,
            return_tensors=self.framework,
        )
        return encoded.to(self.device)

    def _forward(self, model_inputs):
        """Call the model with the tokenized inputs as keyword arguments."""
        return self.model(**model_inputs)

    def postprocess(self, model_outputs):
        """Turn the model's logits into the first result from ``Classifier``.

        Returns element ``[0]`` of ``Classifier.get_results`` applied to the
        logits converted to a NumPy array.
        """
        classifier_config = self.model.config.task_specific_params[Classifier.MODEL_CONFIG]
        classifier = Classifier(classifier_config)

        logits = model_outputs.logits
        if self.framework == "pt":
            # Upcast before converting: NumPy has no bfloat16 dtype, so
            # calling ``.numpy()`` on a bf16 tensor raises.
            logits = logits.float()
        return classifier.get_results(logits.numpy())[0]