import typing as tp

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer


class SciBertPaperClassifier:
    """Classify papers with a fine-tuned sequence-classification checkpoint.

    The model and tokenizer are loaded from ``model_path`` via the Hugging Face
    ``Auto*`` classes. The model is moved to CUDA when available (CPU otherwise)
    and switched to eval mode.
    """

    def __init__(self, model_path="trained_model", max_length=256):
        """Load the model and tokenizer.

        Args:
            model_path: path or identifier passed to ``from_pretrained`` for both
                the model and the tokenizer.
            max_length: maximum number of tokens per input text; longer texts are
                truncated. Defaults to 256, the value previously hard-coded.
        """
        self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.max_length = max_length
        self.model.to(self.device)
        self.model.eval()

    @staticmethod
    def _is_single_paper(inputs):
        """Return True when ``inputs`` is one paper rather than a collection of papers."""
        # A paper is recognised by its ``title`` attribute, so iterable paper
        # types (e.g. NamedTuples) are not mistaken for a batch. Anything that is
        # not iterable is also treated as a single paper, as before.
        return hasattr(inputs, "title") or not isinstance(inputs, tp.Iterable)

    @staticmethod
    def _format_text(paper):
        """Build the single input string the model sees for one paper."""
        authors = paper.authors
        # Author sequences are space-joined; any other value is formatted with str().
        if isinstance(authors, (list, tuple)):
            authors = " ".join(authors)
        return f"AUTHORS: {authors} TITLE: {paper.title} ABSTRACT: {paper.abstract}"

    def __call__(self, inputs):
        """Return per-label probabilities for one paper or a batch of papers.

        Args:
            inputs: a single paper or an iterable of papers. Each paper must expose
                ``authors`` (a list/tuple of names, or any value formatted with
                ``str``), ``title`` and ``abstract`` attributes.

        Returns:
            For each paper, a list of one-entry dicts ``{label: probability}`` in
            label-index order, with label names taken from
            ``model.config.id2label``. Probabilities are a softmax over the logits.
            When exactly one result is produced, that single list is returned
            instead of a list of lists. This also happens for a batch of length 1
            and is kept for backward compatibility. An empty batch returns ``[]``.
        """
        papers = [inputs] if self._is_single_paper(inputs) else list(inputs)
        if not papers:
            # The tokenizer cannot handle an empty batch.
            return []

        texts = [self._format_text(paper) for paper in papers]
        encoded = self.tokenizer(
            texts,
            truncation=True,
            padding=True,
            max_length=self.max_length,
            return_tensors="pt",
        ).to(self.device)

        with torch.no_grad():
            logits = self.model(**encoded).logits

        # One device->host transfer instead of a .item() sync per score.
        probs = torch.nn.functional.softmax(logits, dim=-1).cpu().tolist()
        id2label = self.model.config.id2label
        results = [
            [{id2label[label_idx]: score} for label_idx, score in enumerate(row)]
            for row in probs
        ]

        if len(results) == 1:
            return results[0]
        return results