sentinelcheck-api / api /predict.py
codingcoolfun9ed's picture
change threshold again again again
53b725a verified
raw
history blame
3.17 kB
import torch
import numpy as np
import re
from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast
from huggingface_hub import hf_hub_download
# Run on GPU when one is visible, otherwise CPU; every model and input
# tensor below is moved to this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Lazily-initialized globals: populated exactly once by load_resources()
# on the first call to predict_review().
tokenizer = None
models = None
def load_resources():
    """Lazily download and initialize the tokenizer and 5-model ensemble.

    Populates the module-level ``tokenizer`` and ``models`` globals; if both
    are already set this is a no-op, so repeated calls are cheap. Each
    ensemble member is a DistilBERT sequence classifier whose fine-tuned
    weights are fetched from the Hugging Face Hub.
    """
    global tokenizer, models
    # Fast path: everything already initialized.
    if tokenizer is not None and models is not None:
        return
    print("loading models...")
    tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
    num_classes = 2
    dropout = 0.4
    models = []
    for member_idx in range(1, 6):
        model_filename = f"ensemble_model_{member_idx}.pth"
        print(f"downloading {model_filename}...")
        # hf_hub_download caches locally, so re-runs skip the network fetch.
        weights_path = hf_hub_download(
            repo_id="codingcoolfun9ed/sentinelcheck-models",
            filename=model_filename,
        )
        member = DistilBertForSequenceClassification.from_pretrained(
            'distilbert-base-uncased',
            num_labels=num_classes,
            dropout=dropout,
        )
        member.load_state_dict(torch.load(weights_path, map_location=device))
        member = member.to(device)
        member.eval()
        models.append(member)
    print("models loaded")
def cleanText(text):
    """Normalize raw review text for the model.

    Strips HTML-like tags, collapses all whitespace runs to single spaces,
    and lowercases. Falsy input (None, "", 0, ...) yields "".
    """
    if not text:
        return ""
    # Coerce non-string input (e.g. numbers) before doing string work.
    untagged = re.sub(r'<[^>]+>', '', str(text))
    collapsed = ' '.join(untagged.split())
    # join(split()) leaves no edge whitespace, so strip() is a safety no-op.
    return collapsed.lower().strip()
def getLengthCategory(text):
    """Bucket *text* into a coarse length label by whitespace word count."""
    word_count = len(text.split())
    # (inclusive upper bound, label) pairs, checked in ascending order.
    buckets = (
        (20, 'short'),
        (50, 'short-medium'),
        (100, 'medium'),
        (200, 'long'),
    )
    for limit, label in buckets:
        if word_count <= limit:
            return label
    return 'very-long'
def predict_review(text):
    """Classify one review with the 5-model ensemble.

    Returns a dict with keys: ``prediction`` ('fake' / 'real' / 'uncertain',
    or 'invalid' for empty input), ``confidence`` (max of the two averaged
    class probabilities), ``is_fake`` (True only when the averaged fake
    probability exceeds 0.75), plus ``length_category`` and ``token_count``
    for valid input.
    NOTE(review): ``token_count`` is a whitespace word count of the cleaned
    text, not a tokenizer token count.
    """
    load_resources()
    cleaned = cleanText(text)
    # Guard: nothing left after tag/whitespace stripping.
    if not cleaned:
        return {
            "prediction": "invalid",
            "confidence": 0.0,
            "is_fake": False,
            "error": "empty text after preprocessing"
        }
    encoding = tokenizer(
        cleaned,
        truncation=True,
        padding='max_length',
        max_length=256,
        return_tensors='pt'
    )
    ids = encoding['input_ids'].to(device)
    mask = encoding['attention_mask'].to(device)
    # Gather each ensemble member's softmax distribution, then average them.
    member_probs = []
    with torch.no_grad():
        for member in models:
            logits = member(input_ids=ids, attention_mask=mask).logits
            member_probs.append(torch.softmax(logits, dim=1).cpu().numpy())
    mean_probs = np.mean(member_probs, axis=0)[0]
    real_prob = mean_probs[0]
    fake_prob = mean_probs[1]
    flagged_fake = fake_prob > 0.75
    confidence = max(fake_prob, real_prob)
    # Label by the fake threshold, then downgrade low-confidence calls.
    label = "fake" if flagged_fake else "real"
    if confidence < 0.75:
        label = "uncertain"
    return {
        "prediction": label,
        "confidence": float(confidence),
        "is_fake": bool(flagged_fake),
        "length_category": getLengthCategory(cleaned),
        "token_count": len(cleaned.split())
    }