ventuero committed on
Commit
43b64a9
·
1 Parent(s): 7d4bc98
Files changed (1) hide show
  1. src/antispam_api/loader.py +10 -6
src/antispam_api/loader.py CHANGED
@@ -1,4 +1,5 @@
1
  import torch
 
2
  from transformers import AutoModelForSequenceClassification, AutoTokenizer
3
 
4
  from antispam_api.config import MODEL_NAME
@@ -6,12 +7,15 @@ from antispam_api.config import MODEL_NAME
6
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7
 
8
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
9
- model = AutoModelForSequenceClassification.from_pretrained(
10
- MODEL_NAME, num_labels=1,
11
- ).to(device).eval()
 
 
 
12
 
13
 
14
- def get_spam_score(text: str):
15
  encoding = tokenizer(
16
  text,
17
  padding="max_length",
@@ -24,6 +28,6 @@ def get_spam_score(text: str):
24
 
25
  with torch.no_grad():
26
  logits = model(input_ids, attention_mask=attention_mask).logits
27
- score = torch.sigmoid(logits).cpu().numpy()[0][0]
28
 
29
- return float(score)
 
1
  import torch
2
+ import torch.nn.functional as F # noqa: N812
3
  from transformers import AutoModelForSequenceClassification, AutoTokenizer
4
 
5
  from antispam_api.config import MODEL_NAME
 
7
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8
 
9
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
10
+ model = (
11
+ AutoModelForSequenceClassification
12
+ .from_pretrained(MODEL_NAME)
13
+ .to(device)
14
+ .eval()
15
+ )
16
 
17
 
18
+ def get_scam_score(text: str) -> float:
19
  encoding = tokenizer(
20
  text,
21
  padding="max_length",
 
28
 
29
  with torch.no_grad():
30
  logits = model(input_ids, attention_mask=attention_mask).logits
31
+ probs = F.softmax(logits, dim=-1).cpu().numpy()[0]
32
 
33
+ return float(probs[1])