"""Fake-review detector backed by a weighted ensemble of three transformers.

Loads fine-tuned DistilBERT / RoBERTa / BERT checkpoints from the Hugging
Face Hub on first use and exposes ``predict_review(text)``, which returns a
dict with the ensemble verdict, confidence, per-class probabilities, and
inter-model agreement.
"""

import gc
import re

import numpy as np
import torch
from huggingface_hub import hf_hub_download
from transformers import (
    BertForSequenceClassification,
    BertTokenizer,
    DistilBertForSequenceClassification,
    DistilBertTokenizer,
    RobertaForSequenceClassification,
    RobertaTokenizer,
)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Parallel lists, one entry per ensemble member; populated lazily by
# loadResources() so importing this module stays cheap.
models = []
tokenizers = []
maxLengths = []

# Per-model mixing weights. NOTE: these sum to 0.999, not 1.0, so they are
# re-normalized inside ensemblePredict() to keep the output a proper
# probability distribution. The global values are kept as-is for any
# external code that reads or tunes them.
modelWeights = [0.333, 0.333, 0.333]

optimalThreshold = 0.45       # fake-probability cut-off for the fake/genuine call
uncertaintyThreshold = 0.67   # minimum inter-model agreement required to commit to a label
CLASS_NAMES = ['genuine', 'fake']
models_loaded = False


def validateText(text):
    """Return True iff *text* is a string containing at least one whitespace-separated token."""
    if not isinstance(text, str):
        return False
    return bool(text.strip().split())


def cleanReview(text):
    """Normalize a raw review for the tokenizers.

    Removes URLs and HTML tags, collapses runs of repeated sentence
    punctuation ("!!!" -> "!"), and squeezes all whitespace to single
    spaces. Returns "" for non-string or empty input. The transform is
    idempotent: cleaning already-clean text is a no-op.
    """
    if not text or not isinstance(text, str):
        return ""
    text = re.sub(r'http\S+|www\.\S+', '', text)   # strip URLs
    text = re.sub(r'<[^>]+>', '', text)            # strip HTML tags
    text = re.sub(r'([!?.])\1+', r'\1', text)      # collapse repeated punctuation
    # ' '.join(split()) already removes leading/trailing whitespace, so no
    # extra strip() is needed.
    return ' '.join(text.split())


def loadResources():
    """Download and initialize all three ensemble members (idempotent).

    Fills the module-level ``models`` / ``tokenizers`` / ``maxLengths``
    lists, freezes every model in eval mode on ``device``, and sets
    ``models_loaded``. Raises on the first model that fails to load.
    """
    global models, tokenizers, maxLengths, models_loaded
    if models_loaded:
        return

    print("loading ensemble models...", flush=True)
    modelConfigs = [
        {
            'filename': 'ensemble_model_1.pth',
            'type': 'distilbert',
            'name': 'distilbert-base-uncased',
            'maxLen': 128,
        },
        {
            'filename': 'ensemble_model_2.pth',
            'type': 'roberta',
            'name': 'roberta-base',
            'maxLen': 192,
        },
        {
            'filename': 'ensemble_model_3.pth',
            'type': 'bert',
            'name': 'bert-base-uncased',
            'maxLen': 256,
        },
    ]

    for i, config in enumerate(modelConfigs, 1):
        try:
            print(f"loading model {i}: {config['type']}", flush=True)
            modelPath = hf_hub_download(
                repo_id="codingcoolfun9ed/sentinelcheck-models",
                filename=config['filename'],
            )

            if config['type'] == 'distilbert':
                tokenizer = DistilBertTokenizer.from_pretrained(config['name'])
                model = DistilBertForSequenceClassification.from_pretrained(
                    config['name'], num_labels=2
                )
            elif config['type'] == 'roberta':
                tokenizer = RobertaTokenizer.from_pretrained(config['name'])
                model = RobertaForSequenceClassification.from_pretrained(
                    config['name'], num_labels=2
                )
            elif config['type'] == 'bert':
                tokenizer = BertTokenizer.from_pretrained(config['name'])
                model = BertForSequenceClassification.from_pretrained(
                    config['name'], num_labels=2
                )
            else:
                raise ValueError(f"unknown model type: {config['type']}")

            # SECURITY: weights_only=False un-pickles arbitrary objects from a
            # downloaded file; only safe because the repo above is trusted.
            # Prefer weights_only=True if the checkpoints allow it.
            checkpoint = torch.load(modelPath, map_location=device, weights_only=False)
            if 'state_dict' not in checkpoint:
                raise ValueError(f"model {i} missing state_dict")
            # NOTE(review): strict=False silently ignores missing/unexpected
            # keys — confirm the checkpoints really match these architectures.
            model.load_state_dict(checkpoint['state_dict'], strict=False)

            model = model.to(device)
            model.eval()
            for param in model.parameters():
                param.requires_grad = False  # inference only; freeze everything

            models.append(model)
            tokenizers.append(tokenizer)
            maxLengths.append(config['maxLen'])

            del checkpoint
            gc.collect()
            print(f"model {i} loaded successfully", flush=True)
        except Exception as e:
            print(f"error loading model {i}: {str(e)}", flush=True)
            raise

    models_loaded = True
    print("all ensemble models loaded", flush=True)


def ensemblePredict(text):
    """Run the weighted ensemble on one review.

    Returns a dict with 'genuineProb', 'fakeProb', 'isFake' (bool, or None
    on failure), 'agreement' (fraction of models voting with the final
    call), and an 'error' key only when something went wrong. Probabilities
    are a proper distribution: the mixing weights are normalized by their
    sum, fixing the original 0.999 total.
    """
    if not models_loaded:
        loadResources()

    if not isinstance(text, str):
        text = str(text)
    text = cleanReview(text)

    if not validateText(text):
        return {
            'fakeProb': 0.5,
            'genuineProb': 0.5,
            'isFake': None,
            'agreement': 0.0,
            'error': 'invalid_text',
        }

    weightedProbs = torch.zeros(1, 2).to(device)
    allPreds = []
    weightTotal = sum(modelWeights)  # re-normalize so output probs sum to 1

    try:
        with torch.no_grad():
            for tokenizer, model, maxLen, weight in zip(
                tokenizers, models, maxLengths, modelWeights
            ):
                inputs = tokenizer(
                    text,
                    return_tensors='pt',
                    truncation=True,
                    max_length=maxLen,
                    padding='max_length',
                )
                inputIds = inputs['input_ids'].to(device)
                attentionMask = inputs['attention_mask'].to(device)

                outputs = model(input_ids=inputIds, attention_mask=attentionMask)
                probs = torch.softmax(outputs.logits, dim=1)
                weightedProbs += probs * (weight / weightTotal)

                _, pred = torch.max(probs, 1)
                allPreds.append(pred.item())

                del inputs, inputIds, attentionMask, outputs, probs, pred

        probs = weightedProbs[0].cpu().numpy()
        genuineProb = float(probs[0])
        fakeProb = float(probs[1])

        isFake = fakeProb > optimalThreshold
        finalPred = 1 if isFake else 0
        agreementCount = sum(1 for p in allPreds if p == finalPred)
        agreement = float(agreementCount) / len(allPreds)

        del weightedProbs, allPreds
        gc.collect()

        return {
            'genuineProb': genuineProb,
            'fakeProb': fakeProb,
            'isFake': isFake,
            'agreement': agreement,
        }
    except Exception as e:
        print(f"prediction error: {str(e)}", flush=True)
        return {
            'fakeProb': 0.5,
            'genuineProb': 0.5,
            'isFake': None,
            'agreement': 0.0,
            'error': str(e),
        }


def getLengthCategory(text):
    """Bucket a review by word count: empty/short/short-medium/medium/long/very-long."""
    if not text:
        return 'empty'
    wordCount = len(text.split())
    if wordCount <= 20:
        return 'short'
    elif wordCount <= 50:
        return 'short-medium'
    elif wordCount <= 100:
        return 'medium'
    elif wordCount <= 200:
        return 'long'
    else:
        return 'very-long'


def predict_review(text):
    """Public entry point: classify one review as fake/genuine/uncertain.

    Returns a dict with 'prediction' ("fake"/"genuine"/"uncertain"/"error"),
    'confidence', 'is_fake' (None when uncertain or on error),
    'model_agreement', both class probabilities, 'length_category',
    'token_count', and — only on failure — an 'error' message. A prediction
    is reported as "uncertain" when inter-model agreement falls below
    ``uncertaintyThreshold``.
    """
    if not text or not isinstance(text, str):
        return {
            "prediction": "error",
            "confidence": 0.0,
            "is_fake": None,
            "model_agreement": 0.0,
            "fake_probability": 0.0,
            "genuine_probability": 0.0,
            "length_category": "empty",
            "token_count": 0,
            "error": "invalid input: text must be non-empty string",
        }

    cleaned = cleanReview(text)
    if not cleaned or len(cleaned.strip()) == 0:
        return {
            "prediction": "error",
            "confidence": 0.0,
            "is_fake": None,
            "model_agreement": 0.0,
            "fake_probability": 0.0,
            "genuine_probability": 0.0,
            "length_category": "empty",
            "token_count": 0,
            "error": "empty text after preprocessing",
        }

    # Pass the already-cleaned text; cleanReview is idempotent, so this is
    # identical to the old double-clean but avoids the redundant pass.
    result = ensemblePredict(cleaned)

    if 'error' in result:
        return {
            "prediction": "error",
            "confidence": 0.0,
            "is_fake": None,
            "model_agreement": result['agreement'],
            "fake_probability": result['fakeProb'],
            "genuine_probability": result['genuineProb'],
            "length_category": getLengthCategory(cleaned),
            "token_count": len(cleaned.split()),
            "error": result['error'],
        }

    fakeProb = result['fakeProb']
    genuineProb = result['genuineProb']
    isFake = result['isFake']
    agreement = result['agreement']
    confidence = max(fakeProb, genuineProb)

    if agreement < uncertaintyThreshold:
        # Models disagree too much to commit either way.
        prediction = "uncertain"
        isFakeOutput = None
    else:
        prediction = "fake" if isFake else "genuine"
        isFakeOutput = isFake

    return {
        "prediction": prediction,
        "confidence": float(confidence),
        "is_fake": isFakeOutput,
        "model_agreement": float(agreement),
        "fake_probability": float(fakeProb),
        "genuine_probability": float(genuineProb),
        "length_category": getLengthCategory(cleaned),
        "token_count": len(cleaned.split()),
    }