| from transformers import Pipeline |
| from collections.abc import Iterable |
| import torch |
|
|
class SciBertPaperClassifierPipeline(Pipeline):
    """Pipeline that classifies papers from their authors, title and abstract.

    Each paper object is expected to expose ``authors`` (a list of names or a
    single value), ``title`` and ``abstract`` attributes. The fields are
    rendered into one text prompt, tokenized, passed through ``self.model``,
    and the resulting logits are converted into per-label probabilities.
    """

    # Attributes a single paper must expose; also used to tell one paper apart
    # from a collection of papers in ``preprocess``.
    _PAPER_FIELDS = ("authors", "title", "abstract")

    def _sanitize_parameters(self, max_length=None, **kwargs):
        """Split call-time keyword arguments into per-stage parameter dicts.

        Args:
            max_length: Optional tokenizer truncation length forwarded to
                ``preprocess``; when omitted, ``preprocess``'s default is used.
            **kwargs: Any other keyword arguments are accepted and ignored.

        Returns:
            A ``(preprocess_params, forward_params, postprocess_params)`` tuple.
        """
        preprocess_params = {}
        if max_length is not None:
            preprocess_params["max_length"] = max_length
        return preprocess_params, {}, {}

    @classmethod
    def _is_single_paper(cls, obj):
        """Return True if *obj* has every attribute of a single paper."""
        return all(hasattr(obj, field) for field in cls._PAPER_FIELDS)

    @staticmethod
    def _format_paper(paper):
        """Render one paper as the text prompt fed to the tokenizer."""
        authors = paper.authors
        if isinstance(authors, list):
            authors = " ".join(authors)
        return f"AUTHORS: {authors} TITLE: {paper.title} ABSTRACT: {paper.abstract}"

    def preprocess(self, inputs, max_length=256):
        """Tokenize one paper or an iterable of papers into a model-ready batch.

        Args:
            inputs: A single paper object or an iterable of paper objects.
            max_length: Maximum number of tokens per prompt; longer prompts are
                truncated and shorter ones padded to the longest in the batch.

        Returns:
            The tokenizer output as PyTorch tensors, moved to ``self.device``.
        """
        # Check for the paper fields before falling back to the Iterable test:
        # a paper record can itself be iterable (e.g. a namedtuple) and would
        # otherwise be iterated field by field.
        if self._is_single_paper(inputs) or not isinstance(inputs, Iterable):
            inputs = [inputs]
        texts = [self._format_paper(paper) for paper in inputs]
        return self.tokenizer(
            texts,
            truncation=True,
            padding=True,
            max_length=max_length,
            return_tensors="pt",
        ).to(self.device)

    def _forward(self, model_inputs):
        """Run the model on the tokenized batch with gradient tracking disabled."""
        with torch.no_grad():
            return self.model(**model_inputs)

    def postprocess(self, model_outputs):
        """Convert model logits into per-label probabilities.

        Args:
            model_outputs: The model output; its ``logits`` are read, with the
                last dimension indexing labels.

        Returns:
            One list per input row, containing a ``{label: probability}`` dict
            for each label index in order (labels from
            ``self.model.config.id2label``). When the batch holds exactly one
            row, that row's list is returned directly instead of being wrapped.
        """
        # Upcast first: half-precision (fp16/bf16) logits may not be supported
        # by softmax on CPU in some PyTorch builds; no-op for fp32 logits.
        probs = torch.nn.functional.softmax(model_outputs.logits.float(), dim=-1)
        id2label = self.model.config.id2label
        results = [
            [{id2label[label_idx]: score} for label_idx, score in enumerate(row)]
            for row in probs.tolist()
        ]
        if len(results) == 1:
            return results[0]
        return results
|
|