# sdg_predict/inference.py
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch


def load_model(model_name, device):
    """Load a tokenizer and sequence-classification model, moving the model to `device`."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name).to(device)
    model.eval()  # disable dropout etc. for deterministic inference
    return tokenizer, model
def batched(iterable, batch_size):
    """Yield successive `batch_size`-sized slices of a sequence.

    E.g. list(batched(["a", "b", "c"], 2)) -> [["a", "b"], ["c"]].
    Requires the input to support len() and slicing.
    """
    for i in range(0, len(iterable), batch_size):
        yield iterable[i:i + batch_size]
def predict(texts, tokenizer, model, device, batch_size=8, return_all_scores=True):
    """Run batched inference over `texts` and return per-text label scores.

    With `return_all_scores=True`, each result is a list of {"label", "score"}
    dicts covering every class; otherwise only the top-scoring class is kept.
    """
    results = []
    for batch_texts in batched(texts, batch_size):
        # Tokenize the batch, padding to the longest text and truncating at 512 tokens.
        inputs = tokenizer(
            batch_texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        ).to(device)
        with torch.no_grad():
            logits = model(**inputs).logits
        probs = torch.nn.functional.softmax(logits, dim=-1)
        for prob in probs:
            if return_all_scores:
                results.append([
                    {"label": model.config.id2label[i], "score": prob[i].item()}
                    for i in range(len(prob))
                ])
            else:
                top = torch.argmax(prob).item()
                results.append({
                    "label": model.config.id2label[top],
                    "score": prob[top].item(),
                })
    return results
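

# Minimal usage sketch (illustrative only): "my-org/sdg-classifier" is a
# hypothetical model name, not one referenced by this module.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer, model = load_model("my-org/sdg-classifier", device)
    predictions = predict(
        ["Clean water and sanitation for all.", "Affordable and clean energy."],
        tokenizer,
        model,
        device,
        batch_size=2,
        return_all_scores=False,
    )
    for pred in predictions:
        print(pred["label"], round(pred["score"], 4))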