codingcoolfun9ed commited on
Commit
d31cc65
·
verified ·
1 Parent(s): 8eec530

updating this for the new ensemble for the final version AHHHH

Browse files
Files changed (1) hide show
  1. api/predict.py +231 -61
api/predict.py CHANGED
@@ -1,51 +1,201 @@
1
  import torch
2
  import numpy as np
3
  import re
4
- from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast
 
 
 
 
5
  from huggingface_hub import hf_hub_download
 
6
 
7
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
8
- tokenizer = None
9
- model = None
10
 
11
- def load_resources():
12
- global tokenizer, model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
- if tokenizer is not None and model is not None:
15
  return
16
 
17
- print("loading model...")
18
-
19
- tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
20
 
21
- print("downloading model_2.pth...")
22
- modelPath = hf_hub_download(
23
- repo_id="codingcoolfun9ed/sentinelcheck-models",
24
- filename="model_2.pth"
25
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
- model = DistilBertForSequenceClassification.from_pretrained(
28
- 'distilbert-base-uncased',
29
- num_labels=2,
30
- dropout=0.4
31
- )
32
- model.load_state_dict(torch.load(modelPath, map_location=device))
33
- model = model.to(device)
34
- model.eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
- print("model loaded")
37
 
38
- def cleanText(text):
39
- if not text:
40
- return ""
41
- text = str(text)
42
- text = re.sub(r'<[^>]+>', '', text)
43
- text = ' '.join(text.split())
44
- text = text.lower()
45
- text = text.strip()
46
- return text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
  def getLengthCategory(text):
 
 
49
  words = text.split()
50
  wordCount = len(words)
51
  if wordCount <= 20:
@@ -60,53 +210,73 @@ def getLengthCategory(text):
60
  return 'very-long'
61
 
62
  def predict_review(text):
63
- load_resources()
 
 
 
 
 
 
 
 
 
 
 
64
 
65
- cleaned = cleanText(text)
66
 
67
- if not cleaned:
68
  return {
69
- "prediction": "invalid",
70
  "confidence": 0.0,
71
- "is_fake": False,
 
 
 
 
 
72
  "error": "empty text after preprocessing"
73
  }
74
 
75
- encoding = tokenizer(
76
- cleaned,
77
- truncation=True,
78
- padding='max_length',
79
- max_length=256,
80
- return_tensors='pt'
81
- )
82
 
83
- inputIds = encoding['input_ids'].to(device)
84
- attentionMask = encoding['attention_mask'].to(device)
85
-
86
- with torch.no_grad():
87
- outputs = model(input_ids=inputIds, attention_mask=attentionMask)
88
- probs = torch.softmax(outputs.logits, dim=1).cpu().numpy()[0]
 
 
 
 
 
 
89
 
90
- fakeProb = probs[1]
91
- realProb = probs[0]
 
 
92
 
93
- confidence = max(fakeProb, realProb)
94
 
95
- if confidence < 0.75:
96
  prediction = "uncertain"
97
- isFake = None
98
  else:
99
- isFake = fakeProb > realProb
100
- prediction = "fake" if isFake else "real"
101
 
102
  lengthCat = getLengthCategory(cleaned)
 
103
 
104
  return {
105
  "prediction": prediction,
106
  "confidence": float(confidence),
107
- "is_fake": isFake,
108
- "length_category": lengthCat,
109
- "token_count": len(cleaned.split()),
110
  "fake_probability": float(fakeProb),
111
- "real_probability": float(realProb)
 
 
112
  }
 
1
  import torch
2
  import numpy as np
3
  import re
4
+ from transformers import (
5
+ DistilBertTokenizer, DistilBertForSequenceClassification,
6
+ RobertaTokenizer, RobertaForSequenceClassification,
7
+ BertTokenizer, BertForSequenceClassification
8
+ )
9
  from huggingface_hub import hf_hub_download
10
+ import gc
11
 
12
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 
13
 
14
# Ensemble state, filled in place by loadResources(). The three lists are
# index-aligned: models[i] is used with tokenizers[i] and maxLengths[i].
models = []
tokenizers = []
maxLengths = []

# Per-model weights for the weighted average of softmax outputs in
# ensemblePredict(). Exact thirds are used (rather than 0.333 each, which sums
# to 0.999) so the averaged output is a proper probability distribution.
modelWeights = [1.0 / 3.0] * 3

# ensemblePredict() labels a review fake when its averaged fake probability
# is strictly greater than this value.
optimalThreshold = 0.45

# predict_review() reports "uncertain" when the fraction of models agreeing
# with the final label is below this value.
# NOTE(review): with three models agreement is 0, 1/3, 2/3 or 1, so a 2-of-3
# vote (0.666...) falls below 0.67 and is reported as uncertain -- confirm
# that requiring unanimity is intended.
uncertaintyThreshold = 0.67

# Label names by class index (ensemblePredict() reads index 0 as the genuine
# probability and index 1 as the fake probability).
CLASS_NAMES = ['genuine', 'fake']
22
+
23
def validateText(text):
    """Return True when *text* is a string containing at least one word.

    Non-string values, empty strings and whitespace-only strings all
    yield False.
    """
    # str.split() with no separator discards every run of whitespace, so a
    # non-empty result already implies the stripped text is non-empty.
    return isinstance(text, str) and bool(text.split())
28
+
29
def cleanReview(text):
    """Normalise raw review text before it is tokenised.

    Steps, in order:
      1. remove HTML tags,
      2. remove URLs (anything starting with ``http`` or ``www.``),
      3. collapse runs of the same ``!``, ``?`` or ``.`` to one character,
      4. collapse all whitespace runs to single spaces and trim the ends.

    Returns an empty string for falsy or non-string input.
    """
    if not text or not isinstance(text, str):
        return ""

    # Tags must be stripped before URLs: the URL pattern is greedy up to the
    # next whitespace, so on '<a href="http://x">link</a>' it would otherwise
    # swallow the closing '">' and leave a dangling '<a href="' in the text.
    # NOTE(review): if the models were trained on text cleaned URL-first,
    # this changes the cleaned form of reviews containing such anchors.
    text = re.sub(r'<[^>]+>', '', text)
    text = re.sub(r'http\S+|www\.\S+', '', text)
    text = re.sub(r'([!?.])\1+', r'\1', text)

    # split()/join also trims leading and trailing whitespace.
    return ' '.join(text.split())
37
+
38
def loadResources():
    """Download and initialise the three ensemble members (idempotent).

    For each configured member this downloads its fine-tuned checkpoint from
    the ``codingcoolfun9ed/sentinelcheck-models`` Hub repo, builds the base
    model and tokenizer, loads the checkpoint's ``state_dict``, moves the
    model to ``device`` and puts it in eval mode with gradients disabled.

    The module-level ``models``, ``tokenizers`` and ``maxLengths`` lists are
    only extended once *every* member has loaded, so a failure part-way
    through never leaves a partial ensemble behind that a later call would
    mistake for a fully loaded one.

    Raises:
        ValueError: unknown model type, or a checkpoint without ``state_dict``.
        Exception: any download or load error is logged and re-raised.
    """
    if models:
        return

    print("loading ensemble models...", flush=True)

    modelConfigs = [
        {
            'filename': 'ensemble_model_1.pth',
            'type': 'distilbert',
            'name': 'distilbert-base-uncased',
            'maxLen': 128
        },
        {
            'filename': 'ensemble_model_2.pth',
            'type': 'roberta',
            'name': 'roberta-base',
            'maxLen': 192
        },
        {
            'filename': 'ensemble_model_3.pth',
            'type': 'bert',
            'name': 'bert-base-uncased',
            'maxLen': 256
        }
    ]

    # model type -> (tokenizer class, sequence-classification model class)
    modelClasses = {
        'distilbert': (DistilBertTokenizer, DistilBertForSequenceClassification),
        'roberta': (RobertaTokenizer, RobertaForSequenceClassification),
        'bert': (BertTokenizer, BertForSequenceClassification),
    }

    # Staged locally and published together at the end (see docstring).
    loadedModels = []
    loadedTokenizers = []
    loadedMaxLengths = []

    for i, config in enumerate(modelConfigs, 1):
        try:
            print(f"loading model {i}: {config['type']}", flush=True)

            modelPath = hf_hub_download(
                repo_id="codingcoolfun9ed/sentinelcheck-models",
                filename=config['filename']
            )

            if config['type'] not in modelClasses:
                raise ValueError(f"unknown model type: {config['type']}")
            tokenizerClass, modelClass = modelClasses[config['type']]

            tokenizer = tokenizerClass.from_pretrained(config['name'])
            model = modelClass.from_pretrained(config['name'], num_labels=2)

            # NOTE(security): weights_only=False unpickles arbitrary objects
            # from the downloaded file; this is only safe while the Hub repo
            # above is trusted.
            checkpoint = torch.load(modelPath, map_location=device, weights_only=False)

            if 'state_dict' not in checkpoint:
                raise ValueError(f"model {i} missing state_dict")

            # strict=False tolerates key mismatches; report them so a
            # checkpoint that silently fails to match does not go unnoticed.
            loadResult = model.load_state_dict(checkpoint['state_dict'], strict=False)
            if loadResult.missing_keys or loadResult.unexpected_keys:
                print(
                    f"warning: model {i} state_dict mismatch "
                    f"(missing={len(loadResult.missing_keys)}, "
                    f"unexpected={len(loadResult.unexpected_keys)})",
                    flush=True
                )

            model = model.to(device)
            model.eval()

            for param in model.parameters():
                param.requires_grad = False

            loadedModels.append(model)
            loadedTokenizers.append(tokenizer)
            loadedMaxLengths.append(config['maxLen'])

            # Drop the raw checkpoint before loading the next model.
            del checkpoint
            gc.collect()

            print(f"model {i} loaded successfully", flush=True)

        except Exception as e:
            print(f"error loading model {i}: {str(e)}", flush=True)
            raise

    # Mutate the existing list objects in place so any module that already
    # holds a reference to them sees the loaded ensemble.
    models.extend(loadedModels)
    tokenizers.extend(loadedTokenizers)
    maxLengths.extend(loadedMaxLengths)

    print("all ensemble models loaded", flush=True)
123
 
124
def ensemblePredict(text):
    """Run every ensemble member on *text* and combine their outputs.

    The text is coerced to ``str``, cleaned with ``cleanReview`` and checked
    with ``validateText``. Each model's softmax output is averaged using
    ``modelWeights`` (normalised by the total weight actually applied), and
    the review is labelled fake when the averaged fake probability is
    strictly greater than ``optimalThreshold``.

    Returns:
        dict with keys ``genuineProb``, ``fakeProb`` (floats), ``isFake``
        (bool) and ``agreement`` (fraction of models whose own argmax equals
        the final label). On invalid text or any inference failure the dict
        instead holds 0.5 / 0.5 probabilities, ``isFake=None``,
        ``agreement=0.0`` and an ``error`` message.

    Raises:
        Whatever ``loadResources`` raises; model loading happens outside the
        guarded inference section.
    """
    loadResources()

    if not isinstance(text, str):
        text = str(text)

    text = cleanReview(text)

    if not validateText(text):
        return {
            'fakeProb': 0.5,
            'genuineProb': 0.5,
            'isFake': None,
            'agreement': 0.0,
            'error': 'invalid_text'
        }

    try:
        weightedProbs = torch.zeros(1, 2).to(device)
        totalWeight = 0.0
        allPreds = []

        with torch.no_grad():
            for tokenizer, model, maxLen, weight in zip(tokenizers, models, maxLengths, modelWeights):
                inputs = tokenizer(
                    text,
                    return_tensors='pt',
                    truncation=True,
                    max_length=maxLen,
                    padding='max_length'
                )
                inputIds = inputs['input_ids'].to(device)
                attentionMask = inputs['attention_mask'].to(device)

                outputs = model(input_ids=inputIds, attention_mask=attentionMask)
                probs = torch.softmax(outputs.logits, dim=1)
                weightedProbs += probs * weight
                totalWeight += weight

                _, pred = torch.max(probs, 1)
                allPreds.append(pred.item())

        if not allPreds or totalWeight <= 0:
            # Explicit message instead of a later "division by zero".
            raise RuntimeError("no ensemble models available")

        # Normalise so the two outputs sum to 1 even when the weights do not.
        probs = (weightedProbs[0] / totalWeight).cpu().numpy()
        genuineProb = float(probs[0])
        fakeProb = float(probs[1])

        isFake = fakeProb > optimalThreshold

        finalPred = 1 if isFake else 0
        agreementCount = sum(1 for p in allPreds if p == finalPred)
        agreement = float(agreementCount) / len(allPreds)

        gc.collect()

        return {
            'genuineProb': genuineProb,
            'fakeProb': fakeProb,
            'isFake': isFake,
            'agreement': agreement
        }

    except Exception as e:
        # Best-effort: report the failure to the caller instead of raising.
        print(f"prediction error: {str(e)}", flush=True)
        return {
            'fakeProb': 0.5,
            'genuineProb': 0.5,
            'isFake': None,
            'agreement': 0.0,
            'error': str(e)
        }
195
 
196
  def getLengthCategory(text):
197
+ if not text:
198
+ return 'empty'
199
  words = text.split()
200
  wordCount = len(words)
201
  if wordCount <= 20:
 
210
  return 'very-long'
211
 
212
  def predict_review(text):
213
+ if not text or not isinstance(text, str):
214
+ return {
215
+ "prediction": "error",
216
+ "confidence": 0.0,
217
+ "is_fake": None,
218
+ "model_agreement": 0.0,
219
+ "fake_probability": 0.0,
220
+ "genuine_probability": 0.0,
221
+ "length_category": "empty",
222
+ "token_count": 0,
223
+ "error": "invalid input: text must be non-empty string"
224
+ }
225
 
226
+ cleaned = cleanReview(text)
227
 
228
+ if not cleaned or len(cleaned.strip()) == 0:
229
  return {
230
+ "prediction": "error",
231
  "confidence": 0.0,
232
+ "is_fake": None,
233
+ "model_agreement": 0.0,
234
+ "fake_probability": 0.0,
235
+ "genuine_probability": 0.0,
236
+ "length_category": "empty",
237
+ "token_count": 0,
238
  "error": "empty text after preprocessing"
239
  }
240
 
241
+ result = ensemblePredict(text)
 
 
 
 
 
 
242
 
243
+ if 'error' in result:
244
+ return {
245
+ "prediction": "error",
246
+ "confidence": 0.0,
247
+ "is_fake": None,
248
+ "model_agreement": result['agreement'],
249
+ "fake_probability": result['fakeProb'],
250
+ "genuine_probability": result['genuineProb'],
251
+ "length_category": getLengthCategory(cleaned),
252
+ "token_count": len(cleaned.split()),
253
+ "error": result['error']
254
+ }
255
 
256
+ fakeProb = result['fakeProb']
257
+ genuineProb = result['genuineProb']
258
+ isFake = result['isFake']
259
+ agreement = result['agreement']
260
 
261
+ confidence = max(fakeProb, genuineProb)
262
 
263
+ if agreement < uncertaintyThreshold:
264
  prediction = "uncertain"
265
+ isFakeOutput = None
266
  else:
267
+ prediction = "fake" if isFake else "genuine"
268
+ isFakeOutput = isFake
269
 
270
  lengthCat = getLengthCategory(cleaned)
271
+ tokenCount = len(cleaned.split())
272
 
273
  return {
274
  "prediction": prediction,
275
  "confidence": float(confidence),
276
+ "is_fake": isFakeOutput,
277
+ "model_agreement": float(agreement),
 
278
  "fake_probability": float(fakeProb),
279
+ "genuine_probability": float(genuineProb),
280
+ "length_category": lengthCat,
281
+ "token_count": tokenCount
282
  }