File size: 1,331 Bytes
7a028db
 
 
c690084
 
2bbbe37
c690084
 
 
 
2bbbe37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7a028db
 
 
 
2bbbe37
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import torch
from transformers import AutoTokenizer

from model import FakeBERT

# Model identifier handed to both the tokenizer loader and FakeBERT below.
MODEL_NAME = "distilbert/distilbert-base-uncased-finetuned-sst-2-english"
# Checkpoint file read with torch.load and fed to model.load_state_dict.
MODEL_PATH = "distilbert_best.pth"
# Default token budget per text; longer inputs are truncated by the tokenizer.
MAX_LENGTH = 512
# Size of the classifier output; predict_veracity maps indices 0, 1, 2 to labels.
NUM_CLASSES = 3
# Run on the GPU when one is visible to torch, otherwise stay on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def predict_veracity(texts, model, tokenizer, device, max_length=MAX_LENGTH):
    """Classify each text and return one label string per input.

    Args:
        texts: A single string or an iterable of strings to classify.
        model: Callable invoked as ``model(input_ids, attention_mask,
            token_type_ids)`` that returns logits with the class scores along
            dimension 1. It is switched to eval mode as a side effect.
        tokenizer: Callable invoked with ``padding=True``,
            ``truncation=True``, ``max_length`` and ``return_tensors="pt"``.
            Its result must provide ``"input_ids"`` and ``"attention_mask"``
            and may provide ``"token_type_ids"``.
        device: Device the encoded tensors are moved to before the forward
            pass. NOTE(review): the model is assumed to already live on this
            device; this function does not move it.
        max_length: Maximum token length passed to the tokenizer.

    Returns:
        A list with one label per input text, in input order: ``"F"`` for
        class 0, ``"U"`` for class 1 and ``"T"`` for class 2 (presumably
        false / unverified / true -- confirm against the training labels).
        Any other predicted index falls back to ``"U"``. Empty input yields
        an empty list.
    """
    model.eval()
    id2label = {0: "F", 1: "U", 2: "T"}

    # Normalise to a concrete list so generators are accepted and an empty
    # batch can be detected before it reaches the tokenizer and the model.
    if isinstance(texts, str):
        texts = [texts]
    else:
        texts = list(texts)
    if not texts:
        return []

    encodings = tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )

    input_ids = encodings["input_ids"].to(device)
    attention_mask = encodings["attention_mask"].to(device)
    # Some tokenizers do not emit token_type_ids; pass None through when absent.
    token_type_ids = encodings.get("token_type_ids")
    if token_type_ids is not None:
        token_type_ids = token_type_ids.to(device)

    # Pure inference: no autograd bookkeeping is needed.
    with torch.inference_mode():
        logits = model(input_ids, attention_mask, token_type_ids)
        preds = torch.argmax(logits, dim=1).tolist()

    return [id2label.get(p, "U") for p in preds]




# Load resources
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = FakeBERT(model_name=MODEL_NAME, num_classes=NUM_CLASSES).to(DEVICE)
# map_location remaps the checkpoint's tensors onto DEVICE while loading.
# NOTE(security): torch.load unpickles the file; only load checkpoints from a
# trusted source (consider weights_only=True where the torch version has it).
state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
model.load_state_dict(state_dict)

if __name__ == "__main__":
    import sys

    # Label the texts given on the command line. The original called an
    # undefined predict_sentiment() on an undefined `texts` variable.
    texts = sys.argv[1:]
    if not texts:
        raise SystemExit(f"usage: python {sys.argv[0]} TEXT [TEXT ...]")
    labels = predict_veracity(texts, model, tokenizer, DEVICE)
    for text, label in zip(texts, labels):
        print(f"{label}\t{text}")