ventuero
fuck
979a01c
import torch
import torch.nn.functional as F # noqa: N812
from huggingface_hub import login
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from antispam_api.config import HF_AUTH_TOKEN, MODEL_NAME
# Authenticate with the Hugging Face Hub using the configured token before
# any tokenizer/model files are requested below.
login(HF_AUTH_TOKEN)

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = (
    torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)

# Tokenizer and classifier are both resolved from the same MODEL_NAME.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Load the classifier once at import time, place it on the selected device
# and switch it to evaluation mode; it is only used for inference here.
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
model = model.to(device)
model = model.eval()
def get_spam_score(text: str) -> float:
    """Return the classifier's score for ``text`` as a Python float.

    The text is tokenized (truncated to at most 512 tokens), run through the
    module-level ``model`` with gradient tracking disabled, and the resulting
    logits are converted to a score in ``[0, 1]``.

    Args:
        text: The message to score.

    Returns:
        For a single-logit head, the sigmoid of that logit. Otherwise the
        softmax probability of class index 1 (presumably the "spam" label;
        confirm against the model's label mapping).
    """
    # Only one sequence is scored per call, so no padding is requested:
    # padding a short message out to 512 tokens would make every request pay
    # the full-length forward pass for positions the attention mask hides
    # anyway. Truncation still caps long inputs at 512 tokens.
    encoding = tokenizer(
        text,
        truncation=True,
        max_length=512,
        return_tensors="pt",
    )
    input_ids = encoding["input_ids"].to(device)
    attention_mask = encoding["attention_mask"].to(device)

    with torch.no_grad():
        logits = model(input_ids, attention_mask=attention_mask).logits

    if logits.shape[1] == 1:
        # Single-output head: treat the lone logit as a binary score.
        return float(torch.sigmoid(logits).item())

    # Multi-class head: probability assigned to class index 1 for the single
    # sequence in the batch. ``.item()`` copies the scalar to the host
    # directly, so no NumPy conversion is needed.
    probabilities = torch.softmax(logits, dim=-1)
    return float(probabilities[0, 1].item())