File size: 3,166 Bytes
02c45ef
 
 
 
b363dbf
02c45ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b363dbf
 
 
 
 
 
 
 
02c45ef
 
 
 
 
b363dbf
02c45ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import torch
import numpy as np
import re
from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast
from huggingface_hub import hf_hub_download

# Run inference on GPU when available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Lazily populated by load_resources() on first prediction:
# shared tokenizer and the list of ensemble member models.
tokenizer = None
models = None

def load_resources():
    """Lazily download and initialize the tokenizer and the 5-model ensemble.

    Idempotent: returns immediately once the module-level ``tokenizer`` and
    ``models`` globals are both populated. Each checkpoint is fetched from the
    Hugging Face Hub, its weights are loaded into a fresh DistilBERT sequence
    classifier, and the model is moved to ``device`` and put in eval mode.

    Bug fix: the global ``models`` is only assigned after ALL five members
    load successfully. Previously ``models = []`` was set before the download
    loop, so a mid-loop failure left a non-None partial list behind and every
    subsequent call skipped loading and ran with a broken ensemble.
    """
    global tokenizer, models

    if tokenizer is not None and models is not None:
        return

    print("loading models...")

    tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')

    # Hyperparameters must match those used when the checkpoints were trained.
    num_classes = 2
    dropout = 0.4

    # Build into a local list; publish to the global only on full success so
    # a failed/partial load leaves ``models`` as None and is retried next call.
    loaded = []
    for i in range(1, 6):
        model_filename = f"ensemble_model_{i}.pth"

        print(f"downloading {model_filename}...")
        model_path = hf_hub_download(
            repo_id="codingcoolfun9ed/sentinelcheck-models",
            filename=model_filename
        )

        model = DistilBertForSequenceClassification.from_pretrained(
            'distilbert-base-uncased',
            num_labels=num_classes,
            dropout=dropout
        )
        # map_location keeps CPU-only hosts working with GPU-saved checkpoints.
        model.load_state_dict(torch.load(model_path, map_location=device))
        model = model.to(device)
        model.eval()
        loaded.append(model)

    models = loaded
    print("models loaded")

def cleanText(text):
    """Normalize raw review text for the model.

    Strips HTML-ish tags, collapses all runs of whitespace to single spaces,
    and lowercases. Any falsy input (None, "", 0) yields the empty string.
    """
    if not text:
        return ""
    # Remove anything that looks like a markup tag, then squeeze whitespace.
    stripped = re.sub(r'<[^>]+>', '', str(text))
    collapsed = ' '.join(stripped.split())
    return collapsed.lower().strip()

def getLengthCategory(text):
    """Bucket *text* into a coarse length label by whitespace word count.

    Boundaries (inclusive upper limits): 20 -> 'short', 50 -> 'short-medium',
    100 -> 'medium', 200 -> 'long', above that 'very-long'.
    """
    count = len(text.split())
    for limit, label in ((20, 'short'),
                         (50, 'short-medium'),
                         (100, 'medium'),
                         (200, 'long')):
        if count <= limit:
            return label
    return 'very-long'

def predict_review(text):
    """Classify a review as fake/real/uncertain with the DistilBERT ensemble.

    Triggers lazy model loading on first use. Returns a dict with keys:
    "prediction" ("fake"/"real"/"uncertain", or "invalid" for empty input),
    "confidence", "is_fake", plus "length_category" and "token_count" for
    valid input. NOTE: "token_count" is a whitespace word count of the
    cleaned text, not a tokenizer token count.
    """
    load_resources()

    cleaned = cleanText(text)

    # Guard: nothing survived preprocessing -> no classification possible.
    if not cleaned:
        return {
            "prediction": "invalid",
            "confidence": 0.0,
            "is_fake": False,
            "error": "empty text after preprocessing"
        }

    encoding = tokenizer(
        cleaned,
        truncation=True,
        padding='max_length',
        max_length=256,
        return_tensors='pt'
    )

    ids = encoding['input_ids'].to(device)
    mask = encoding['attention_mask'].to(device)

    # Gather each member's softmax distribution, then average for the
    # ensemble vote (soft voting).
    member_probs = []
    with torch.no_grad():
        for member in models:
            logits = member(input_ids=ids, attention_mask=mask).logits
            member_probs.append(torch.softmax(logits, dim=1).cpu().numpy())

    ensemble = np.mean(member_probs, axis=0)[0]
    real_p = ensemble[0]
    fake_p = ensemble[1]

    is_fake = fake_p > 0.5
    confidence = max(fake_p, real_p)
    label = "fake" if is_fake else "real"

    # A low-margin ensemble vote is reported as uncertain rather than guessed.
    if confidence < 0.75:
        label = "uncertain"

    return {
        "prediction": label,
        "confidence": float(confidence),
        "is_fake": bool(is_fake),
        "length_category": getLengthCategory(cleaned),
        "token_count": len(cleaned.split())
    }