File size: 924 Bytes
2e14f20
1d71f97
40191c9
643a0fa
 
 
4d1017a
 
 
 
 
 
643a0fa
22a3411
4d1017a
692b5ed
4d1017a
 
f039e46
a0c91c2
4d1017a
a955f77
1e7929c
 
a955f77
2e14f20
1d71f97
643a0fa
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import gradio as gr
import torch
from src.model import BertClassifier, RobertaClassifier
from transformers import BertTokenizer
from datetime import datetime

# Inference is pinned to the CPU; the checkpoint is remapped onto this device
# when it is loaded below.
device = torch.device('cpu')
model_name = 'bert-base-uncased'
# NOTE(review): the second argument (0.5) is presumably the dropout rate --
# confirm against src.model.BertClassifier.
model = BertClassifier(model_name, 0.5)
model.to(device)
model.load_state_dict(torch.load('models/bert-all-data.pth', map_location=device))
# Switch to evaluation mode so train-only layers such as dropout are disabled;
# otherwise the same input can yield a different score on every request.
# Assumes BertClassifier is a torch.nn.Module (it already relies on .to() and
# .load_state_dict()).
model.eval()
tokenizer = BertTokenizer.from_pretrained(model_name)

def ai_text_classifier(text: str) -> dict:
    """Score a piece of text with the module-level classifier.

    The text is tokenized to a fixed length of 512 tokens (padded or
    truncated as needed), passed through ``model``, and the single scalar
    the model returns is reported as the "AI" score.

    Args:
        text: Raw input text to classify.

    Returns:
        A dict with two entries: ``"AI"`` holds the model's scalar output and
        ``'Others'`` holds ``1 - output``.
    """
    # Convert text into tokens
    tokens = tokenizer(
        text,
        return_tensors='pt',
        max_length=512,
        padding='max_length',
        truncation=True,
    ).to(device)

    # Pure inference: disable autograd so no computation graph is built and
    # kept alive for each request.
    with torch.no_grad():
        prob = model(tokens['input_ids'], tokens['attention_mask']).item()

    # Return both scores keyed by label name
    return {
        "AI": prob,
        'Others': 1 - prob,
    }

# Wire the classifier into a simple web UI: one text input, one label output.
demo = gr.Interface(
    fn=ai_text_classifier,
    inputs="text",
    outputs="label",
)
# Start serving the interface.
demo.launch()