# Hugging Face Space app.py — commit 22a3411 ("Update app.py") by chitlchow, 924 bytes.
import gradio as gr
import torch
from src.model import BertClassifier, RobertaClassifier
from transformers import BertTokenizer
from datetime import datetime
# Inference runs on CPU; map_location below lets a checkpoint that was saved
# from a GPU session load onto this device.
device = torch.device('cpu')
model_name = 'bert-base-uncased'
# NOTE(review): the second argument is presumably the dropout rate — confirm in src/model.py.
model = BertClassifier(model_name, 0.5)
model.to(device)
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints
# (torch >= 1.13 supports weights_only=True for plain state dicts).
model.load_state_dict(torch.load('models/bert-all-data.pth', map_location=device))
# nn.Module instances start in training mode, where dropout layers stay active
# and would make repeated predictions on the same text differ. Switch to
# evaluation mode once, before serving any requests.
model.eval()
tokenizer = BertTokenizer.from_pretrained(model_name)
def ai_text_classifier(text: str) -> dict:
    """Score how likely *text* is to be AI-generated.

    Serves as the Gradio callback; the returned mapping is rendered by a
    "label" output as class names with confidences.

    Args:
        text: Raw input text. Anything beyond 512 tokens is truncated.

    Returns:
        ``{"AI": p, "Others": 1 - p}`` where ``p`` is the single value the
        model produces for the input (assumed to be P(AI) — see note below).
    """
    # Tokenize to fixed-length (512) padded/truncated tensors on the model's device.
    # Fixed-length padding is kept as-is in case the model relies on it.
    tokens = tokenizer(text, return_tensors='pt', max_length=512,
                       padding='max_length', truncation=True).to(device)
    # Pure inference: disable autograd so no gradient graph is built per request.
    with torch.no_grad():
        # .item() requires a single-element tensor. NOTE(review): presumably
        # the model emits one probability for the "AI" class — confirm in src/model.py.
        prob = model(tokens['input_ids'], tokens['attention_mask']).item()
    # Two-class distribution for the Gradio Label component.
    return {
        "AI": prob,
        'Others': 1 - prob
    }
# Wrap the classifier in a text-in / label-out Gradio UI and start serving it.
demo = gr.Interface(
    fn=ai_text_classifier,  # callback returning {"AI": p, "Others": 1 - p}
    inputs="text",          # single free-text box
    outputs="label",        # renders the returned dict as class confidences
)
demo.launch()