| # sdg_predict/inference.py | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| import torch | |
def load_model(model_name, device):
    """Load a tokenizer/model pair for sequence classification.

    The model is moved to *device* and switched to eval mode so it is
    ready for inference; the tokenizer needs no device placement.
    Returns a ``(tokenizer, model)`` tuple.
    """
    tok = AutoTokenizer.from_pretrained(model_name)
    mdl = AutoModelForSequenceClassification.from_pretrained(model_name)
    mdl = mdl.to(device)
    mdl.eval()  # disable dropout etc. for deterministic inference
    return tok, mdl
def batched(iterable, batch_size):
    """Yield successive slices of *iterable*, each at most *batch_size* long.

    Requires a sliceable sequence with ``len()`` (e.g. a list); the final
    slice may be shorter than *batch_size*. Yields nothing for an empty input.
    """
    start = 0
    total = len(iterable)
    while start < total:
        yield iterable[start:start + batch_size]
        start += batch_size
def predict(texts, tokenizer, model, device, batch_size=8, return_all_scores=True):
    """Classify *texts* with *model*, returning per-text label scores.

    Texts are processed in mini-batches of *batch_size* (so *texts* must be a
    sequence usable with ``batched``). Inputs are padded per batch and
    truncated to 512 tokens; logits are converted to probabilities with a
    softmax over the label dimension.

    Returns, per input text (in order):
      * if *return_all_scores* is true: a list of
        ``{"label": ..., "score": ...}`` dicts, one per class;
      * otherwise: a single ``{"label": ..., "score": ...}`` dict for the
        highest-probability class.
    """
    id2label = model.config.id2label  # hoisted: invariant across batches
    results = []
    for chunk in batched(texts, batch_size):
        encoded = tokenizer(
            chunk,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        ).to(device)
        with torch.no_grad():  # inference only — no autograd bookkeeping
            scores = torch.nn.functional.softmax(model(**encoded).logits, dim=-1)
        for row in scores:
            if return_all_scores:
                results.append([
                    {"label": id2label[idx], "score": row[idx].item()}
                    for idx in range(len(row))
                ])
            else:
                best = int(torch.argmax(row))
                results.append({"label": id2label[best], "score": row[best].item()})
    return results