chitlchow commited on
Commit
4d1017a
·
1 Parent(s): af96511

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -8
app.py CHANGED
@@ -4,18 +4,21 @@ from src.model import BertClassifier, RobertaClassifier
4
  from transformers import BertTokenizer
5
  from datetime import datetime
6
 
 
 
 
 
 
 
7
 
8
-
9
- def ai_text_classifier(text: str):
10
- device = torch.device('cpu')
11
- model_name = 'bert-base-uncased'
12
- model = BertClassifier(model_name, 0.5)
13
- model.to(device)
14
- model.load_state_dict(torch.load('models/bert-all-data.pth', map_location=device))
15
- tokenizer = BertTokenizer.from_pretrained(model_name)
16
  tokens = tokenizer(text, return_tensors='pt', max_length=512, padding='max_length', truncation=True).to(device)
 
 
17
  prob = model(tokens['input_ids'], tokens['attention_mask']).item()
18
 
 
19
  return {
20
  "AI": prob,
21
  'Others': 1 - prob
 
4
  from transformers import BertTokenizer
5
  from datetime import datetime
6
 
7
+ device = torch.device('cpu')
8
+ model_name = 'bert-base-uncased'
9
+ model = BertClassifier(model_name, 0.5)
10
+ model.to(device)
11
+ model.load_state_dict(torch.load('models/bert-all-data.pth', map_location=device))
12
+ tokenizer = BertTokenizer.from_pretrained(model_name)
13
 
14
+ def ai_text_classifier(text: str):
15
+ # Convert Text into tokens
 
 
 
 
 
 
16
  tokens = tokenizer(text, return_tensors='pt', max_length=512, padding='max_length', truncation=True).to(device)
17
+
18
+ # Get probability of the text
19
  prob = model(tokens['input_ids'], tokens['attention_mask']).item()
20
 
21
+ # Return the probability in dictionary
22
  return {
23
  "AI": prob,
24
  'Others': 1 - prob