# sentinelcheck-api / api/predict.py
# Ensemble fake-review prediction endpoint.
# NOTE: optimal threshold slightly changed to fit new and improved models.
import torch
import numpy as np
import re
from transformers import (
DistilBertTokenizer, DistilBertForSequenceClassification,
RobertaTokenizer, RobertaForSequenceClassification,
BertTokenizer, BertForSequenceClassification
)
from huggingface_hub import hf_hub_download
import gc
# Run inference on GPU when available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Parallel lists populated by loadResources(): models[i] is paired with
# tokenizers[i] and maxLengths[i] for the i-th ensemble member.
models = []
tokenizers = []
maxLengths = []
# Per-model contribution to the weighted probability (equal thirds).
modelWeights = [0.333, 0.333, 0.333]
# A review is labelled fake when the weighted fake probability exceeds this.
optimalThreshold = 0.40
# Below this inter-model agreement ratio the prediction is reported "uncertain".
uncertaintyThreshold = 0.67
# Label order matches the 2-label heads: index 0 = genuine, index 1 = fake.
CLASS_NAMES = ['genuine', 'fake']
# Lazy-load guard flipped to True once loadResources() completes.
models_loaded = False
def validateText(text):
    """Return True iff *text* is a string with at least one non-whitespace char.

    The original checked both ``len(text) > 0`` and ``len(text.split()) > 0``
    after stripping, but once whitespace is stripped those two conditions are
    equivalent — a single truthiness test gives identical results.
    """
    if not isinstance(text, str):
        return False
    # A stripped string is truthy iff it still contains a visible character.
    return bool(text.strip())
def cleanReview(text):
    """Normalize raw review text for the models.

    Strips URLs and HTML tags, collapses repeated sentence punctuation
    (``!!!`` -> ``!``), and squeezes all whitespace runs to single spaces.
    Non-string or empty input yields the empty string.
    """
    if not text or not isinstance(text, str):
        return ""
    # Apply each (pattern, replacement) pass in order.
    substitutions = (
        (r'http\S+|www\.\S+', ''),   # URLs
        (r'<[^>]+>', ''),            # HTML tags
        (r'([!?.])\1+', r'\1'),      # repeated ! ? . -> single char
    )
    for pattern, replacement in substitutions:
        text = re.sub(pattern, replacement, text)
    # Collapse whitespace runs, then trim the ends.
    return ' '.join(text.split()).strip()
def loadResources():
    """Download and initialize the three-model ensemble (idempotent).

    Populates the module-level parallel lists ``models``, ``tokenizers`` and
    ``maxLengths`` — one entry per ensemble member — then sets
    ``models_loaded`` so subsequent calls return immediately.  Raises on any
    load failure (the exception is logged and re-raised).
    """
    global models, tokenizers, maxLengths, models_loaded
    # Fast path: everything was already loaded by a previous call.
    if models_loaded:
        return
    print("loading ensemble models...", flush=True)
    # One config per ensemble member; maxLen is the tokenizer truncation
    # length used later at inference time (stored in maxLengths).
    modelConfigs = [
        {
            'filename': 'ensemble_model_1.pth',
            'type': 'distilbert',
            'name': 'distilbert-base-uncased',
            'maxLen': 128
        },
        {
            'filename': 'ensemble_model_2.pth',
            'type': 'roberta',
            'name': 'roberta-base',
            'maxLen': 192
        },
        {
            'filename': 'ensemble_model_3.pth',
            'type': 'bert',
            'name': 'bert-base-uncased',
            'maxLen': 256
        }
    ]
    for i, config in enumerate(modelConfigs, 1):
        try:
            print(f"loading model {i}: {config['type']}", flush=True)
            # Fetch the fine-tuned checkpoint from the Hub (cached locally
            # by hf_hub_download after the first call).
            modelPath = hf_hub_download(
                repo_id="codingcoolfun9ed/sentinelcheck-models",
                filename=config['filename']
            )
            # Instantiate the matching architecture with a fresh 2-label head;
            # the pretrained base weights are then overwritten below.
            if config['type'] == 'distilbert':
                tokenizer = DistilBertTokenizer.from_pretrained(config['name'])
                model = DistilBertForSequenceClassification.from_pretrained(
                    config['name'],
                    num_labels=2
                )
            elif config['type'] == 'roberta':
                tokenizer = RobertaTokenizer.from_pretrained(config['name'])
                model = RobertaForSequenceClassification.from_pretrained(
                    config['name'],
                    num_labels=2
                )
            elif config['type'] == 'bert':
                tokenizer = BertTokenizer.from_pretrained(config['name'])
                model = BertForSequenceClassification.from_pretrained(
                    config['name'],
                    num_labels=2
                )
            else:
                raise ValueError(f"unknown model type: {config['type']}")
            # NOTE(review): weights_only=False runs full pickle deserialization
            # — only safe because the checkpoint comes from the project's own
            # Hub repo; never use on untrusted files.
            checkpoint = torch.load(modelPath, map_location=device, weights_only=False)
            if 'state_dict' not in checkpoint:
                raise ValueError(f"model {i} missing state_dict")
            # strict=False tolerates key mismatches between checkpoint and
            # architecture — presumably head/prefix naming drift; TODO confirm.
            model.load_state_dict(checkpoint['state_dict'], strict=False)
            model = model.to(device)
            model.eval()
            # Inference only: freeze everything to avoid grad bookkeeping.
            for param in model.parameters():
                param.requires_grad = False
            # The three lists are appended in lockstep so index i always
            # refers to the same ensemble member in each list.
            models.append(model)
            tokenizers.append(tokenizer)
            maxLengths.append(config['maxLen'])
            # Free the raw checkpoint dict before loading the next model.
            del checkpoint
            gc.collect()
            print(f"model {i} loaded successfully", flush=True)
        except Exception as e:
            # Log and propagate — a partially loaded ensemble must not be
            # marked as ready.
            print(f"error loading model {i}: {str(e)}", flush=True)
            raise
    models_loaded = True
    print("all ensemble models loaded", flush=True)
def ensemblePredict(text):
    """Run the weighted ensemble over *text* and return raw probabilities.

    Returns a dict with 'genuineProb', 'fakeProb', 'isFake' (bool or None)
    and 'agreement' (fraction of models voting with the final label).
    On invalid input or any runtime failure a neutral 0.5/0.5 dict is
    returned with an 'error' key instead of raising.
    """
    # Ensure the ensemble is in memory (no-op after first call).
    if not models_loaded:
        loadResources()
    if not isinstance(text, str):
        text = str(text)
    text = cleanReview(text)
    if not validateText(text):
        # Nothing usable left after cleaning — neutral fallback result.
        return {
            'fakeProb': 0.5,
            'genuineProb': 0.5,
            'isFake': None,
            'agreement': 0.0,
            'error': 'invalid_text'
        }
    # Accumulator for the weighted class probabilities, shape (1, 2):
    # column 0 = genuine, column 1 = fake.
    weightedProbs = torch.zeros(1, 2).to(device)
    allPreds = []
    try:
        with torch.no_grad():
            # Each member tokenizes with its own tokenizer and max length.
            for tokenizer, model, maxLen, weight in zip(tokenizers, models, maxLengths, modelWeights):
                inputs = tokenizer(
                    text,
                    return_tensors='pt',
                    truncation=True,
                    max_length=maxLen,
                    padding='max_length'
                )
                inputIds = inputs['input_ids'].to(device)
                attentionMask = inputs['attention_mask'].to(device)
                outputs = model(input_ids=inputIds, attention_mask=attentionMask)
                probs = torch.softmax(outputs.logits, dim=1)
                # Weighted soft vote plus this member's hard vote.
                weightedProbs += probs * weight
                _, pred = torch.max(probs, 1)
                allPreds.append(pred.item())
                # Release per-model tensors before the next iteration.
                del inputs, inputIds, attentionMask, outputs, probs, pred
        probs = weightedProbs[0].cpu().numpy()
        genuineProb = float(probs[0])
        fakeProb = float(probs[1])
        # Decision uses the tuned threshold rather than argmax.
        isFake = fakeProb > optimalThreshold
        finalPred = 1 if isFake else 0
        # Fraction of individual models whose hard vote matches the decision.
        agreementCount = sum(1 for p in allPreds if p == finalPred)
        agreement = agreementCount / len(allPreds)
        del weightedProbs, allPreds
        gc.collect()
        return {
            'genuineProb': genuineProb,
            'fakeProb': fakeProb,
            'isFake': isFake,
            'agreement': agreement
        }
    except Exception as e:
        # Inference failure: log and degrade to the neutral fallback.
        print(f"prediction error: {str(e)}", flush=True)
        return {
            'fakeProb': 0.5,
            'genuineProb': 0.5,
            'isFake': None,
            'agreement': 0.0,
            'error': str(e)
        }
def getLengthCategory(text):
    """Bucket *text* by word count.

    Returns one of 'empty', 'short' (<=20 words), 'short-medium' (<=50),
    'medium' (<=100), 'long' (<=200) or 'very-long' (>200).
    """
    if not text:
        return 'empty'
    wordTotal = len(text.split())
    # Ordered (upper bound, label) pairs; first bound that fits wins.
    buckets = (
        (20, 'short'),
        (50, 'short-medium'),
        (100, 'medium'),
        (200, 'long'),
    )
    for upperBound, label in buckets:
        if wordTotal <= upperBound:
            return label
    return 'very-long'
def predict_review(text):
    """Public prediction entry point: classify a review as fake or genuine.

    Returns a dict with keys 'prediction' ("fake"/"genuine"/"uncertain"/
    "error"), 'confidence', 'is_fake', 'model_agreement' (percentage),
    'fake_probability', 'genuine_probability', 'length_category' and
    'token_count'; error responses additionally carry an 'error' message.
    """
    def _errorResponse(message, *, agreement=0.0, fake=0.0, genuine=0.0,
                       category="empty", tokens=0):
        # Single shape for all three failure paths below.
        return {
            "prediction": "error",
            "confidence": 0.0,
            "is_fake": None,
            "model_agreement": agreement,
            "fake_probability": fake,
            "genuine_probability": genuine,
            "length_category": category,
            "token_count": tokens,
            "error": message
        }

    # Guard: reject missing / non-string input before any processing.
    if not text or not isinstance(text, str):
        return _errorResponse("invalid input: text must be non-empty string")

    cleaned = cleanReview(text)
    if not cleaned.strip():
        return _errorResponse("empty text after preprocessing")

    result = ensemblePredict(text)
    if 'error' in result:
        return _errorResponse(
            result['error'],
            agreement=result['agreement'] * 100,
            fake=result['fakeProb'],
            genuine=result['genuineProb'],
            category=getLengthCategory(cleaned),
            tokens=len(cleaned.split())
        )

    fakeProb = result['fakeProb']
    genuineProb = result['genuineProb']
    agreement = result['agreement']

    # Low inter-model agreement downgrades the verdict to "uncertain".
    if agreement >= uncertaintyThreshold:
        label = "fake" if result['isFake'] else "genuine"
        fakeFlag = result['isFake']
    else:
        label = "uncertain"
        fakeFlag = None

    return {
        "prediction": label,
        "confidence": max(fakeProb, genuineProb),
        "is_fake": fakeFlag,
        "model_agreement": agreement * 100,
        "fake_probability": fakeProb,
        "genuine_probability": genuineProb,
        "length_category": getLengthCategory(cleaned),
        "token_count": len(cleaned.split())
    }