| import torch | |
| import numpy as np | |
| import re | |
| from transformers import ( | |
| DistilBertTokenizer, DistilBertForSequenceClassification, | |
| RobertaTokenizer, RobertaForSequenceClassification, | |
| BertTokenizer, BertForSequenceClassification | |
| ) | |
| from huggingface_hub import hf_hub_download | |
| import gc | |
# Run inference on GPU when available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Lazily-populated ensemble state; filled in by loadResources().
models = []        # loaded transformer models, one per ensemble member
tokenizers = []    # tokenizer paired with each model, same index order
maxLengths = []    # per-model tokenizer max sequence length
# Soft-vote weight applied to each model's softmax output.
# NOTE(review): weights sum to 0.999, not 1.0, so the averaged
# probabilities are very slightly under-normalized — confirm intentional.
modelWeights = [0.333, 0.333, 0.333]
# Weighted fake-probability above this value is classified as fake.
optimalThreshold = 0.45
# Minimum fraction of models that must agree with the final verdict;
# below this, predict_review reports "uncertain".
uncertaintyThreshold = 0.67
# Class index 0 = genuine, index 1 = fake (matches the probability
# indexing used in ensemblePredict).
CLASS_NAMES = ['genuine', 'fake']
# Guards loadResources() so the ensemble is only initialized once.
models_loaded = False
def validateText(text):
    """Return True when *text* is a string containing at least one
    non-whitespace character, False otherwise (including non-strings)."""
    if not isinstance(text, str):
        return False
    # A stripped string is non-empty iff it contains at least one word.
    return bool(text.strip())
def cleanReview(text):
    """Normalize raw review text for the ensemble.

    Strips URLs, HTML-like tags and runs of repeated '!', '?' or '.',
    then collapses all whitespace to single spaces. Non-string or
    falsy input yields the empty string.
    """
    if not isinstance(text, str) or not text:
        return ""
    cleaned = re.sub(r'http\S+|www\.\S+', '', text)      # drop URLs
    cleaned = re.sub(r'<[^>]+>', '', cleaned)            # drop tag-like markup
    cleaned = re.sub(r'([!?.])\1+', r'\1', cleaned)      # "!!!" -> "!"
    # Collapse all runs of whitespace (incl. newlines/tabs) to one space.
    cleaned = ' '.join(cleaned.split())
    return cleaned.strip()
def loadResources():
    """Download and initialize the three-model ensemble from the HF Hub.

    Populates the module-level ``models``, ``tokenizers`` and ``maxLengths``
    lists (index-aligned) and sets ``models_loaded`` on success. Idempotent:
    returns immediately when the ensemble is already loaded. Any per-model
    failure is logged and re-raised, aborting the whole load.
    """
    global models, tokenizers, maxLengths, models_loaded
    if models_loaded:
        return
    print("loading ensemble models...", flush=True)
    # One entry per ensemble member: checkpoint filename in the Hub repo,
    # architecture family, base pretrained weights, and per-model max
    # tokenization length.
    modelConfigs = [
        {
            'filename': 'ensemble_model_1.pth',
            'type': 'distilbert',
            'name': 'distilbert-base-uncased',
            'maxLen': 128
        },
        {
            'filename': 'ensemble_model_2.pth',
            'type': 'roberta',
            'name': 'roberta-base',
            'maxLen': 192
        },
        {
            'filename': 'ensemble_model_3.pth',
            'type': 'bert',
            'name': 'bert-base-uncased',
            'maxLen': 256
        }
    ]
    for i, config in enumerate(modelConfigs, 1):
        try:
            print(f"loading model {i}: {config['type']}", flush=True)
            # Fetch (or reuse cached) fine-tuned checkpoint file.
            modelPath = hf_hub_download(
                repo_id="codingcoolfun9ed/sentinelcheck-models",
                filename=config['filename']
            )
            # Instantiate the matching architecture with a fresh 2-class head;
            # the fine-tuned weights are loaded over it below.
            if config['type'] == 'distilbert':
                tokenizer = DistilBertTokenizer.from_pretrained(config['name'])
                model = DistilBertForSequenceClassification.from_pretrained(
                    config['name'],
                    num_labels=2
                )
            elif config['type'] == 'roberta':
                tokenizer = RobertaTokenizer.from_pretrained(config['name'])
                model = RobertaForSequenceClassification.from_pretrained(
                    config['name'],
                    num_labels=2
                )
            elif config['type'] == 'bert':
                tokenizer = BertTokenizer.from_pretrained(config['name'])
                model = BertForSequenceClassification.from_pretrained(
                    config['name'],
                    num_labels=2
                )
            else:
                raise ValueError(f"unknown model type: {config['type']}")
            # SECURITY NOTE(review): weights_only=False unpickles arbitrary
            # objects from a downloaded file — relies on trusting this Hub
            # repo; confirm, or switch to weights_only=True if the
            # checkpoints allow it.
            checkpoint = torch.load(modelPath, map_location=device, weights_only=False)
            if 'state_dict' not in checkpoint:
                raise ValueError(f"model {i} missing state_dict")
            # NOTE(review): strict=False silently ignores missing/unexpected
            # keys — confirm the checkpoints match these architectures.
            model.load_state_dict(checkpoint['state_dict'], strict=False)
            model = model.to(device)
            model.eval()  # inference mode (disables dropout, etc.)
            # Inference-only: freeze all parameters.
            for param in model.parameters():
                param.requires_grad = False
            models.append(model)
            tokenizers.append(tokenizer)
            maxLengths.append(config['maxLen'])
            # Release the raw checkpoint dict promptly to limit peak memory.
            del checkpoint
            gc.collect()
            print(f"model {i} loaded successfully", flush=True)
        except Exception as e:
            print(f"error loading model {i}: {str(e)}", flush=True)
            raise
    models_loaded = True
    print("all ensemble models loaded", flush=True)
def ensemblePredict(text):
    """Run the weighted soft-vote ensemble over *text*.

    Cleans and validates the input, then averages each model's softmax
    probabilities using modelWeights. Returns a dict with 'genuineProb',
    'fakeProb', 'isFake' (fakeProb > optimalThreshold) and 'agreement'
    (fraction of per-model argmax votes matching the final verdict).
    On invalid input or any runtime failure, returns neutral 0.5/0.5
    probabilities with isFake=None and an 'error' key instead of raising.
    """
    # Lazy initialization: the first call pays the model-loading cost.
    if not models_loaded:
        loadResources()
    if not isinstance(text, str):
        text = str(text)
    text = cleanReview(text)
    if not validateText(text):
        # Nothing usable survived preprocessing — neutral fallback.
        return {
            'fakeProb': 0.5,
            'genuineProb': 0.5,
            'isFake': None,
            'agreement': 0.0,
            'error': 'invalid_text'
        }
    # Accumulator for the weighted average of per-model class probabilities.
    # NOTE(review): modelWeights sum to 0.999, so the result is very
    # slightly under-normalized — confirm intentional.
    weightedProbs = torch.zeros(1, 2).to(device)
    allPreds = []  # per-model argmax votes, used for the agreement score
    try:
        with torch.no_grad():
            for tokenizer, model, maxLen, weight in zip(tokenizers, models, maxLengths, modelWeights):
                inputs = tokenizer(
                    text,
                    return_tensors='pt',
                    truncation=True,
                    max_length=maxLen,
                    padding='max_length'
                )
                inputIds = inputs['input_ids'].to(device)
                attentionMask = inputs['attention_mask'].to(device)
                outputs = model(input_ids=inputIds, attention_mask=attentionMask)
                probs = torch.softmax(outputs.logits, dim=1)
                weightedProbs += probs * weight
                _, pred = torch.max(probs, 1)
                allPreds.append(pred.item())
                # Free per-model tensors eagerly to keep peak memory low.
                del inputs, inputIds, attentionMask, outputs, probs, pred
        probs = weightedProbs[0].cpu().numpy()
        genuineProb = float(probs[0])  # class index 0 = genuine
        fakeProb = float(probs[1])     # class index 1 = fake
        # Verdict uses the tuned threshold rather than a plain 0.5 argmax.
        isFake = fakeProb > optimalThreshold
        finalPred = 1 if isFake else 0
        # Agreement: share of individual models voting with the final call.
        agreementCount = sum(1 for p in allPreds if p == finalPred)
        agreement = float(agreementCount) / len(allPreds)
        del weightedProbs, allPreds
        gc.collect()
        return {
            'genuineProb': genuineProb,
            'fakeProb': fakeProb,
            'isFake': isFake,
            'agreement': agreement
        }
    except Exception as e:
        # Best-effort API: log and degrade to an uncertain result rather
        # than propagating inference errors to callers.
        print(f"prediction error: {str(e)}", flush=True)
        return {
            'fakeProb': 0.5,
            'genuineProb': 0.5,
            'isFake': None,
            'agreement': 0.0,
            'error': str(e)
        }
def getLengthCategory(text):
    """Bucket *text* into a length category by whitespace-separated
    word count: 'empty', 'short' (<=20), 'short-medium' (<=50),
    'medium' (<=100), 'long' (<=200) or 'very-long'."""
    if not text:
        return 'empty'
    count = len(text.split())
    # Ascending (upper-bound, label) pairs; first match wins.
    buckets = (
        (20, 'short'),
        (50, 'short-medium'),
        (100, 'medium'),
        (200, 'long'),
    )
    for limit, label in buckets:
        if count <= limit:
            return label
    return 'very-long'
def _errorResult(message, *, agreement=0.0, fakeProb=0.0, genuineProb=0.0,
                 lengthCategory="empty", tokenCount=0):
    """Build the standard error-shaped response dict for predict_review.

    Centralizes the 9-key payload that was previously duplicated in three
    near-identical literals, so the response schema can't drift.
    """
    return {
        "prediction": "error",
        "confidence": 0.0,
        "is_fake": None,
        "model_agreement": agreement,
        "fake_probability": fakeProb,
        "genuine_probability": genuineProb,
        "length_category": lengthCategory,
        "token_count": tokenCount,
        "error": message
    }


def predict_review(text):
    """Public entry point: classify a review as genuine, fake or uncertain.

    Returns a dict with keys: prediction ('genuine' | 'fake' | 'uncertain'
    | 'error'), confidence (max class probability), is_fake (bool, or None
    when uncertain/error), model_agreement, fake_probability,
    genuine_probability, length_category, token_count, plus an 'error' key
    on failure. Never raises.
    """
    if not text or not isinstance(text, str):
        return _errorResult("invalid input: text must be non-empty string")
    cleaned = cleanReview(text)
    if not cleaned or len(cleaned.strip()) == 0:
        return _errorResult("empty text after preprocessing")
    # ensemblePredict cleans internally, so passing the raw text is fine;
    # 'cleaned' is still used below for length/token metadata.
    result = ensemblePredict(text)
    if 'error' in result:
        # Propagate the ensemble's neutral probabilities alongside the error.
        return _errorResult(
            result['error'],
            agreement=result['agreement'],
            fakeProb=result['fakeProb'],
            genuineProb=result['genuineProb'],
            lengthCategory=getLengthCategory(cleaned),
            tokenCount=len(cleaned.split())
        )
    fakeProb = result['fakeProb']
    genuineProb = result['genuineProb']
    isFake = result['isFake']
    agreement = result['agreement']
    confidence = max(fakeProb, genuineProb)
    # Withhold a verdict when too few models agree with the final call.
    if agreement < uncertaintyThreshold:
        prediction = "uncertain"
        isFakeOutput = None
    else:
        prediction = "fake" if isFake else "genuine"
        isFakeOutput = isFake
    return {
        "prediction": prediction,
        "confidence": float(confidence),
        "is_fake": isFakeOutput,
        "model_agreement": float(agreement),
        "fake_probability": float(fakeProb),
        "genuine_probability": float(genuineProb),
        "length_category": getLengthCategory(cleaned),
        "token_count": len(cleaned.split())
    }