File size: 883 Bytes
4d16182
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

MODEL_DIR = "app/model"

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the tokenizer and the classifier once, at import time, from the same
# local directory, then move the model onto the chosen device.
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)
model = model.to(device)
# In this file the model is only used for prediction, so put it in
# evaluation mode (disables training-only behaviour such as dropout).
model.eval()


def classify_text(text: str, *, max_length: int = 256):
    """
    Classify a single text with the module-level tokenizer and model.

    Args:
      text: the input text to classify.
      max_length: maximum number of tokens kept after tokenization; longer
        inputs are truncated. Defaults to 256 (the previously hard-coded
        value), so existing callers see no change.

    Returns:
      pred: the predicted class index: 0 (fake) or 1 (real), according to
        the label mapping the model was trained with.
      confidence: the maximum softmax probability, i.e. the probability
        assigned to ``pred``, as a plain Python float in [0, 1].
    """
    encoded = tokenizer(
        text,
        truncation=True,
        padding=True,
        max_length=max_length,
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        logits = model(**encoded).logits
        # Softmax over the last (class) axis; row 0 is the single input text.
        probs = torch.softmax(logits, dim=-1)[0].cpu()

    # Take the argmax once; its probability is by definition the maximum.
    pred = int(probs.argmax())
    confidence = float(probs[pred])

    return pred, confidence