| from transformers import Pipeline | |
| from collections.abc import Iterable | |
| import torch | |
class SciBertPaperClassifierPipeline(Pipeline):
    """Hugging Face pipeline that classifies papers from their metadata.

    Each input is expected to be an object exposing ``.authors``, ``.title``
    and ``.abstract`` attributes (e.g. a paper record). ``authors`` may be a
    plain string or any iterable of strings (list, tuple, ...).
    """

    def _sanitize_parameters(self, **kwargs):
        # This pipeline takes no extra call-time parameters, so nothing is
        # routed to preprocess/_forward/postprocess.
        return {}, {}, {}

    def preprocess(self, inputs):
        """Tokenize a single paper or an iterable of papers into model inputs.

        Returns the tokenizer's tensor batch, moved to the pipeline device.
        """
        # A single (non-iterable) paper object becomes a batch of one.
        if not isinstance(inputs, Iterable):
            inputs = [inputs]
        texts = []
        for paper in inputs:
            authors = paper.authors
            # Join any non-string iterable of author names with spaces.
            # (The original only special-cased `list`, so a tuple of authors
            # was rendered as "('A', 'B')" via str() formatting — fixed here;
            # list and plain-string inputs behave exactly as before.)
            if isinstance(authors, Iterable) and not isinstance(authors, str):
                authors = " ".join(authors)
            texts.append(
                f"AUTHORS: {authors} TITLE: {paper.title} ABSTRACT: {paper.abstract}"
            )
        # NOTE(review): max_length=256 can truncate long abstracts — confirm
        # it matches the model's training configuration.
        return self.tokenizer(
            texts, truncation=True, padding=True, max_length=256, return_tensors="pt"
        ).to(self.device)

    def _forward(self, model_inputs):
        """Run the underlying model with gradient tracking disabled."""
        with torch.no_grad():
            return self.model(**model_inputs)

    def postprocess(self, model_outputs):
        """Convert model logits into per-label probability dicts.

        Returns, for each input in the batch, a list of single-entry
        ``{label_name: probability}`` dicts (one per class). A batch of
        exactly one input is unwrapped to just that inner list, preserving
        the original pipeline's calling convention.
        """
        probs = torch.nn.functional.softmax(model_outputs.logits, dim=-1)
        results = [
            [
                {self.model.config.id2label[label_idx]: score.item()}
                for label_idx, score in enumerate(row)
            ]
            for row in probs
        ]
        return results[0] if len(results) == 1 else results