| from transformers import Pipeline | |
| import torch | |
| class PairClassificationPipeline(Pipeline): | |
| def _sanitize_parameters(self, **kwargs): | |
| preprocess_kwargs = {} | |
| if "text_pair" in kwargs: | |
| preprocess_kwargs["text_pair"] = kwargs["text_pair"] | |
| return preprocess_kwargs, {}, {} | |
| def preprocess(self, text, text_pair=None): | |
| return self.tokenizer(text, text_pair=text_pair, return_tensors="pt") | |
| def _forward(self, model_inputs): | |
| return self.model(**model_inputs) | |
| def postprocess(self, model_outputs): | |
| logits = model_outputs.logits | |
| probabilities = torch.nn.functional.softmax(logits, dim=-1) | |
| best_class = probabilities.argmax().item() | |
| label = self.model.config.id2label[best_class] | |
| score = probabilities.squeeze()[best_class].item() | |
| logits = logits.squeeze().tolist() | |
| return {"label": label, "score": score, "logits": logits} | |