from transformers import Pipeline

from .classifier import Classifier


class MultiTaskClassifierPipeline(Pipeline):
    """Pipeline that tokenizes text, runs the model and decodes its logits.

    Decoding is delegated to a ``Classifier`` built from the entry stored
    under ``Classifier.MODEL_CONFIG`` in the model config's
    ``task_specific_params``.
    """

    def _sanitize_parameters(self, **kwargs):
        """Return the (preprocess, forward, postprocess) keyword dicts.

        No call-time options are supported: ``kwargs`` is ignored and all
        three dicts are always empty.
        """
        return {}, {}, {}

    def preprocess(self, inputs):
        """Tokenize ``inputs`` and move the encoding to the pipeline device.

        The tokenizer pads to ``max_length``, truncates, and returns tensors
        for ``self.framework``.
        """
        encoded = self.tokenizer(
            inputs,
            padding="max_length",
            truncation=True,
            return_tensors=self.framework,
        )
        return encoded.to(self.device)

    def _forward(self, model_inputs):
        """Call the model with ``model_inputs`` expanded as keyword arguments."""
        outputs = self.model(**model_inputs)
        return outputs

    def postprocess(self, model_outputs):
        """Decode the model's logits and return the first decoded result.

        A ``Classifier`` is constructed on every call from the model config,
        then fed the logits converted via ``.numpy()``; only element ``0`` of
        what ``get_results`` returns is handed back.
        """
        task_params = self.model.config.task_specific_params
        # Build the decoder before touching the logits so that a missing or
        # invalid config entry is the first thing to fail.
        decoder = Classifier(task_params[Classifier.MODEL_CONFIG])
        logits_array = model_outputs.logits.numpy()
        results = decoder.get_results(logits_array)
        return results[0]