bert-question-classifier / classifier_pipeline.py
lekhnathrijal's picture
Upload MultiTaskClassifierPipeline
692d9ea
from transformers import Pipeline
from .classifier import Classifier
class MultiTaskClassifierPipeline(Pipeline):
    """Custom pipeline: tokenize the input, run the model, decode the logits.

    Decoding is delegated to ``Classifier``, which is configured from the
    model config's ``task_specific_params`` entry keyed by
    ``Classifier.MODEL_CONFIG``.
    """

    def _sanitize_parameters(self, **kwargs):
        """Return the preprocess / forward / postprocess keyword dicts.

        This pipeline exposes no tunable parameters: whatever is passed in
        ``kwargs`` is ignored and three empty dicts are returned.
        """
        return {}, {}, {}

    def preprocess(self, inputs):
        """Tokenize ``inputs`` and move the encoding to the pipeline device.

        The tokenizer is called with ``padding="max_length"`` and truncation
        enabled, returning tensors for the pipeline's framework.
        """
        encoded = self.tokenizer(
            inputs,
            padding="max_length",
            truncation=True,
            return_tensors=self.framework,
        )
        return encoded.to(self.device)

    def _forward(self, model_inputs):
        """Call the model with the tokenized inputs as keyword arguments."""
        outputs = self.model(**model_inputs)
        return outputs

    def postprocess(self, model_outputs):
        """Decode the model logits into a classification result.

        A ``Classifier`` is built from the model config on every call; the
        logits are converted with ``.numpy()`` and handed to
        ``Classifier.get_results``. Only the first element of what that
        method returns is passed back to the caller.
        """
        classifier_params = self.model.config.task_specific_params[
            Classifier.MODEL_CONFIG
        ]
        decoder = Classifier(classifier_params)
        # NOTE(review): .numpy() presumably relies on the logits already being
        # on CPU and detached when this runs — confirm against the base
        # Pipeline's forward handling.
        logits_array = model_outputs.logits.numpy()
        results = decoder.get_results(logits_array)
        return results[0]