|
|
import torch
from transformers import AutoTokenizer  # used below but was never imported

from model import FakeBERT
|
|
|
|
|
# HuggingFace checkpoint used as the encoder backbone.
# NOTE(review): this is a 2-class *sentiment* checkpoint; presumably FakeBERT
# reuses only its base encoder and attaches a fresh head — confirm in model.py.
MODEL_NAME = "distilbert/distilbert-base-uncased-finetuned-sst-2-english"

# Path to the fine-tuned weights restored via load_state_dict below.
MODEL_PATH = "distilbert_best.pth"

# Maximum tokens per input; longer texts are truncated by the tokenizer.
MAX_LENGTH = 512

# Three output classes, matching id2label in predict_veracity:
# 0 = "F" (fake), 1 = "U" (unverified), 2 = "T" (true).
NUM_CLASSES = 3

# Prefer GPU when available, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
|
|
|
|
def predict_veracity(texts, model, tokenizer, device, max_length=None):
    """Classify each text as fake ("F"), unverified ("U"), or true ("T").

    Args:
        texts: List of input strings to classify.
        model: Classifier called as ``model(input_ids, attention_mask,
            token_type_ids)``, returning logits of shape (batch, num_classes).
        tokenizer: HuggingFace-style tokenizer producing "input_ids",
            "attention_mask", and (for some models) "token_type_ids".
        device: torch.device to run inference on.
        max_length: Truncation length for tokenization; defaults to the
            module-level MAX_LENGTH when None (late-bound sentinel).

    Returns:
        List of label strings ("F", "U", or "T"), one per input text.
    """
    # Robustness: short-circuit an empty batch instead of handing the
    # tokenizer/model a zero-length input.
    if not texts:
        return []

    if max_length is None:
        max_length = MAX_LENGTH

    model.eval()

    id2label = {0: "F", 1: "U", 2: "T"}

    encodings = tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )

    input_ids = encodings["input_ids"].to(device)
    attention_mask = encodings["attention_mask"].to(device)
    # DistilBERT tokenizers emit no token_type_ids; forward None in that case.
    token_type_ids = encodings.get("token_type_ids")
    if token_type_ids is not None:
        token_type_ids = token_type_ids.to(device)

    # inference_mode disables autograd bookkeeping for faster, safer eval.
    with torch.inference_mode():
        logits = model(input_ids, attention_mask, token_type_ids)
        preds = torch.argmax(logits, dim=1).tolist()

    # Any class id outside the mapping falls back to "U" (unverified).
    return [id2label.get(p, "U") for p in preds]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Script entry: load artifacts, then classify the input batch. ---

# Tokenizer matching the encoder checkpoint (AutoTokenizer imported at top).
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Build the classifier and restore the fine-tuned weights onto DEVICE.
model = FakeBERT(model_name=MODEL_NAME, num_classes=NUM_CLASSES).to(DEVICE)
state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
model.load_state_dict(state_dict)

# NOTE(review): `texts` was used but never defined anywhere in this file;
# this placeholder batch makes the script runnable — replace with real input.
texts = [
    "Example claim to classify.",
]

# BUG FIX: the function defined above is `predict_veracity`; the original
# called a nonexistent `predict_sentiment`, which raised NameError.
labels = predict_veracity(texts, model, tokenizer, DEVICE)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|