| import typing as tp | |
| from collections import namedtuple | |
| import torch | |
| from transformers import Pipeline, AutoModelForSequenceClassification | |
| from transformers.pipelines import PIPELINE_REGISTRY | |
class PapersClassificationPipeline(Pipeline):
    """Classify scientific papers given as dicts (title/authors/abstract) or raw strings.

    Accepts a single paper or an iterable of papers; returns per-label
    probability dicts (a single list for one paper, a list of lists for a batch).
    """

    def _sanitize_parameters(self, **kwargs):
        # No runtime-configurable parameters: empty kwargs for each of
        # preprocess / _forward / postprocess.
        return {}, {}, {}

    @staticmethod
    def _format_paper(paper):
        """Render a paper dict into the flat text string fed to the tokenizer.

        Assumes ``paper`` has "authors", "title" and "abstract" keys —
        TODO confirm against callers.
        """
        authors = paper["authors"]
        # BUG FIX: the original joined paper["title"] (character by character)
        # whenever authors was a list; join the authors list instead.
        if isinstance(authors, list):
            authors = " ".join(authors)
        return (
            f"AUTHORS: {authors} "
            f"TITLE: {paper['title']} ABSTRACT: {paper['abstract']}"
        )

    def preprocess(self, inputs):
        """Tokenize one paper or a batch of papers into model-ready tensors."""
        # Wrap a single item into a batch of one: a lone dict, a lone string,
        # or any non-iterable all count as "single paper".
        if (
            not isinstance(inputs, tp.Iterable)
            or isinstance(inputs, dict)
            or isinstance(inputs, str)
        ):
            inputs = [inputs]
        texts = [
            paper if isinstance(paper, str) else self._format_paper(paper)
            for paper in inputs
        ]
        return self.tokenizer(
            texts, truncation=True, padding=True, max_length=256, return_tensors="pt"
        ).to(self.device)

    def _forward(self, model_inputs):
        # Inference only — no gradient bookkeeping needed.
        with torch.no_grad():
            return self.model(**model_inputs)

    def postprocess(self, model_outputs):
        """Convert logits to {"label", "score"} dicts per class.

        Returns a single list for a one-paper batch, otherwise a list of lists.
        """
        probs = torch.nn.functional.softmax(model_outputs.logits, dim=-1)
        results = [
            [
                {"label": self.model.config.id2label[idx], "score": score.item()}
                for idx, score in enumerate(prob)
            ]
            for prob in probs
        ]
        return results[0] if len(results) == 1 else results
# Register the custom task with the transformers pipeline registry so it can
# be instantiated via `pipeline("paper-classification", model=...)`.
PIPELINE_REGISTRY.register_pipeline(
    "paper-classification",
    pipeline_class=PapersClassificationPipeline,
    pt_model=AutoModelForSequenceClassification,
)