File size: 1,087 Bytes
4560edd
43b64a9
a29fcc9
7d4bc98
4560edd
a29fcc9
 
 
4560edd
7d4bc98
 
 
43b64a9
 
 
 
 
 
7d4bc98
 
979a01c
7d4bc98
 
 
 
 
 
4560edd
7d4bc98
 
 
 
 
9406f70
 
 
 
7d4bc98
9406f70
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import torch
import torch.nn.functional as F  # noqa: N812
from huggingface_hub import login
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from antispam_api.config import HF_AUTH_TOKEN, MODEL_NAME

# Authenticate with the Hugging Face Hub once, at import time, using the
# token supplied by the project configuration.
login(HF_AUTH_TOKEN)

# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Load the tokenizer and the sequence-classification model named by
# MODEL_NAME. Both are module-level singletons shared by every call below.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
model.to(device)  # nn.Module.to moves the parameters in place and returns self
model.eval()  # evaluation mode; also returns self, so `model` is unchanged


def get_spam_score(text: str) -> float:
    """Return the classifier's spam probability for *text*.

    The text is tokenized (truncated to at most 512 tokens), run through the
    module-level ``model`` on ``device`` with gradient tracking disabled, and
    the logits are reduced to a single probability:

    * one output logit   -> ``sigmoid(logit)``
    * several logits     -> softmax over the classes, taking class index 1

    NOTE(review): class index 1 is presumably the "spam" label -- confirm
    against the model's ``id2label`` mapping.

    Args:
        text: The message to score.

    Returns:
        The probability as a built-in ``float``.
    """
    # Only one sequence is encoded per call, so no padding is required.
    # Padding to the full 512 tokens would make the model process 512
    # positions even for a one-line message, and would fail outright for
    # tokenizers that define no pad token.
    encoding = tokenizer(
        text,
        truncation=True,
        max_length=512,
        return_tensors="pt",
    )
    input_ids = encoding["input_ids"].to(device)
    attention_mask = encoding["attention_mask"].to(device)

    with torch.no_grad():
        logits = model(input_ids, attention_mask=attention_mask).logits

    # Batch size is always 1 here, so row 0 holds the only prediction.
    if logits.shape[1] == 1:
        probability = torch.sigmoid(logits[0, 0])
    else:
        probability = torch.softmax(logits[0], dim=-1)[1]

    # Tensor.item() copies the scalar to the host as a Python number without
    # a NumPy round-trip (NumPy cannot represent e.g. bfloat16 tensors).
    return float(probability.item())