# MC / app.py
# mirralz's picture
# Update app.py
# 9bd8832 verified
# app.py
import warnings

import gradio as gr
import torch
from transformers import BertTokenizer

from model import BertClassifier
model = BertClassifier()
# NOTE: torch.load unpickles the file, so the checkpoint must come from a
# trusted source.
state_dict = torch.load("bert_mc_dropout.pt", map_location="cpu")
# strict=False tolerates key mismatches between the checkpoint and the model.
# Report any mismatch so a partially initialised model does not go unnoticed.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    warnings.warn(
        "Checkpoint 'bert_mc_dropout.pt' does not fully match BertClassifier: "
        f"missing keys={list(load_result.missing_keys)}, "
        f"unexpected keys={list(load_result.unexpected_keys)}",
        RuntimeWarning,
    )
# Switch to evaluation mode for inference.
# NOTE(review): the checkpoint name suggests MC dropout, but eval() disables
# dropout layers unless BertClassifier re-enables them -- confirm this is intended.
model.eval()
# Tokenizer: loaded from the pretrained "bert-base-uncased" files and used by
# predict() to encode the input text into tensors.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
# AG News class labels, keyed by the class index looked up in predict().
label_map = dict(enumerate(("World", "Sports", "Business", "Sci/Tech")))
# Prediction function
def predict(text):
    """Classify a news snippet into one of the AG News categories.

    Args:
        text: Raw text entered by the user; may be empty or ``None``.

    Returns:
        ``"<label> (confidence: <p>)"``, where ``<label>`` comes from
        ``label_map`` and ``<p>`` is the highest softmax probability formatted
        to two decimals. If no text was supplied, a short prompt asking for
        input is returned instead.
    """
    # Blank input carries no content, yet the model would still emit a label
    # with a confidence score; return a prompt instead of a misleading result.
    if text is None or not text.strip():
        return "Please enter some news text to classify."

    tokens = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = model(tokens["input_ids"], tokens["attention_mask"])
        # Normalise the model outputs over dim 1 and pick the most probable entry.
        probs = torch.softmax(outputs, dim=1)
        conf, predicted = torch.max(probs, dim=1)
    return f"{label_map[predicted.item()]} (confidence: {conf.item():.2f})"
# Gradio interface: a four-line text box as input, plain text as output,
# with predict() as the handler.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=4, placeholder="Enter your news text..."),
    outputs="text",
)
demo.launch()