codingcoolfun9ed commited on
Commit
8eec530
·
verified ·
1 Parent(s): 53b725a

updating this in prep to ship new single model with different hyperparameters

Browse files
Files changed (1) hide show
  1. api/predict.py +34 -42
api/predict.py CHANGED
@@ -6,42 +6,34 @@ from huggingface_hub import hf_hub_download
6
 
7
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
8
  tokenizer = None
9
- models = None
10
 
11
  def load_resources():
12
- global tokenizer, models
13
 
14
- if tokenizer is not None and models is not None:
15
  return
16
 
17
- print("loading models...")
18
 
19
  tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
20
 
21
- num_classes = 2
22
- dropout = 0.4
23
-
24
- models = []
25
- for i in range(1, 6):
26
- model_filename = f"ensemble_model_{i}.pth"
27
-
28
- print(f"downloading {model_filename}...")
29
- model_path = hf_hub_download(
30
- repo_id="codingcoolfun9ed/sentinelcheck-models",
31
- filename=model_filename
32
- )
33
-
34
- model = DistilBertForSequenceClassification.from_pretrained(
35
- 'distilbert-base-uncased',
36
- num_labels=num_classes,
37
- dropout=dropout
38
- )
39
- model.load_state_dict(torch.load(model_path, map_location=device))
40
- model = model.to(device)
41
- model.eval()
42
- models.append(model)
43
-
44
- print("models loaded")
45
 
46
  def cleanText(text):
47
  if not text:
@@ -88,33 +80,33 @@ def predict_review(text):
88
  return_tensors='pt'
89
  )
90
 
91
- input_ids = encoding['input_ids'].to(device)
92
- attention_mask = encoding['attention_mask'].to(device)
93
 
94
- allOutputs = []
95
  with torch.no_grad():
96
- for model in models:
97
- outputs = model(input_ids=input_ids, attention_mask=attention_mask)
98
- probs = torch.softmax(outputs.logits, dim=1)
99
- allOutputs.append(probs.cpu().numpy())
100
 
101
- avgProbs = np.mean(allOutputs, axis=0)[0]
102
- fakeProb = avgProbs[1]
103
- realProb = avgProbs[0]
104
 
105
- isFake = fakeProb > 0.75
106
  confidence = max(fakeProb, realProb)
107
- prediction = "fake" if isFake else "real"
108
 
109
  if confidence < 0.75:
110
  prediction = "uncertain"
 
 
 
 
111
 
112
  lengthCat = getLengthCategory(cleaned)
113
 
114
  return {
115
  "prediction": prediction,
116
  "confidence": float(confidence),
117
- "is_fake": bool(isFake),
118
  "length_category": lengthCat,
119
- "token_count": len(cleaned.split())
 
 
120
  }
 
6
 
7
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
8
  tokenizer = None
9
+ model = None
10
 
11
  def load_resources():
12
+ global tokenizer, model
13
 
14
+ if tokenizer is not None and model is not None:
15
  return
16
 
17
+ print("loading model...")
18
 
19
  tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
20
 
21
+ print("downloading model_2.pth...")
22
+ modelPath = hf_hub_download(
23
+ repo_id="codingcoolfun9ed/sentinelcheck-models",
24
+ filename="model_2.pth"
25
+ )
26
+
27
+ model = DistilBertForSequenceClassification.from_pretrained(
28
+ 'distilbert-base-uncased',
29
+ num_labels=2,
30
+ dropout=0.4
31
+ )
32
+ model.load_state_dict(torch.load(modelPath, map_location=device))
33
+ model = model.to(device)
34
+ model.eval()
35
+
36
+ print("model loaded")
 
 
 
 
 
 
 
 
37
 
38
  def cleanText(text):
39
  if not text:
 
80
  return_tensors='pt'
81
  )
82
 
83
+ inputIds = encoding['input_ids'].to(device)
84
+ attentionMask = encoding['attention_mask'].to(device)
85
 
 
86
  with torch.no_grad():
87
+ outputs = model(input_ids=inputIds, attention_mask=attentionMask)
88
+ probs = torch.softmax(outputs.logits, dim=1).cpu().numpy()[0]
 
 
89
 
90
+ fakeProb = probs[1]
91
+ realProb = probs[0]
 
92
 
 
93
  confidence = max(fakeProb, realProb)
 
94
 
95
  if confidence < 0.75:
96
  prediction = "uncertain"
97
+ isFake = None
98
+ else:
99
+ isFake = fakeProb > realProb
100
+ prediction = "fake" if isFake else "real"
101
 
102
  lengthCat = getLengthCategory(cleaned)
103
 
104
  return {
105
  "prediction": prediction,
106
  "confidence": float(confidence),
107
+ "is_fake": isFake,
108
  "length_category": lengthCat,
109
+ "token_count": len(cleaned.split()),
110
+ "fake_probability": float(fakeProb),
111
+ "real_probability": float(realProb)
112
  }