sentinelcheck-api / api /predict.py
main
fresh deploy with external models
02c45ef
raw
history blame
3.03 kB
import torch
import numpy as np
import re
import os
from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast
# Project root is one level above this file's directory; the ensemble
# checkpoints (ensemble_model_1.pth .. ensemble_model_5.pth) live in
# <root>/models — see load_resources().
scriptDir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
modelsDir = os.path.join(scriptDir, "models")
# Prefer GPU when available; every model and input tensor is moved here.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Lazily-initialized globals, populated by load_resources() on first use.
tokenizer = None
models = None
def load_resources():
    """Lazily initialize the shared tokenizer and the 5-model ensemble.

    Populates the module-level ``tokenizer`` and ``models`` globals on the
    first call; later calls return immediately once both are set.
    """
    global tokenizer, models
    # Already initialized -> nothing to do.
    if tokenizer is not None and models is not None:
        return
    print("loading models...")
    tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
    num_classes = 2
    dropout = 0.4
    models = []
    for idx in range(1, 6):
        member = DistilBertForSequenceClassification.from_pretrained(
            'distilbert-base-uncased',
            num_labels=num_classes,
            dropout=dropout,
        )
        # Overwrite the pretrained head with this ensemble member's weights.
        checkpoint_path = os.path.join(modelsDir, f"ensemble_model_{idx}.pth")
        member.load_state_dict(torch.load(checkpoint_path, map_location=device))
        member = member.to(device)
        member.eval()
        models.append(member)
    print("models loaded")
def cleanText(text):
    """Normalize raw review text: drop HTML-like tags, collapse whitespace, lowercase.

    Falsy input (None, "", 0) yields an empty string.
    """
    if not text:
        return ""
    # Remove anything that looks like an HTML/XML tag.
    stripped = re.sub(r'<[^>]+>', '', str(text))
    # Collapse whitespace runs to single spaces, then lowercase and trim.
    return ' '.join(stripped.split()).lower().strip()
def getLengthCategory(text):
    """Bucket *text* into a coarse length label by whitespace-delimited word count."""
    count = len(text.split())
    # Each pair is (inclusive upper bound, label); checked in ascending order.
    buckets = (
        (20, 'short'),
        (50, 'short-medium'),
        (100, 'medium'),
        (200, 'long'),
    )
    for limit, label in buckets:
        if count <= limit:
            return label
    return 'very-long'
def predict_review(text):
    """Classify a review as real/fake via the averaged ensemble probabilities.

    Returns a dict with the label ("fake", "real", "uncertain", or "invalid"
    for empty input), the ensemble confidence, a boolean fake flag, and
    length metadata for the cleaned text.
    """
    load_resources()
    cleaned = cleanText(text)
    # Nothing left after preprocessing -> nothing to classify.
    if not cleaned:
        return {
            "prediction": "invalid",
            "confidence": 0.0,
            "is_fake": False,
            "error": "empty text after preprocessing"
        }
    encoded = tokenizer(
        cleaned,
        truncation=True,
        padding='max_length',
        max_length=256,
        return_tensors='pt'
    )
    ids = encoded['input_ids'].to(device)
    mask = encoded['attention_mask'].to(device)
    # Per-model class probabilities, collected without gradient tracking.
    with torch.no_grad():
        perModel = [
            torch.softmax(m(input_ids=ids, attention_mask=mask).logits, dim=1).cpu().numpy()
            for m in models
        ]
    # Average over the ensemble; index 0 = "real", index 1 = "fake".
    meanProbs = np.mean(perModel, axis=0)[0]
    realProb, fakeProb = meanProbs[0], meanProbs[1]
    isFake = fakeProb > 0.5
    confidence = max(fakeProb, realProb)
    # Majority label, demoted to "uncertain" below the 0.75 confidence floor.
    prediction = "fake" if isFake else "real"
    if confidence < 0.75:
        prediction = "uncertain"
    return {
        "prediction": prediction,
        "confidence": float(confidence),
        "is_fake": bool(isFake),
        "length_category": getLengthCategory(cleaned),
        # NOTE: this is a whitespace word count, not a tokenizer token count.
        "token_count": len(cleaned.split())
    }