import warnings

import gradio as gr
import torch
from transformers import BertTokenizer

from model import BertClassifier
|
|
# Instantiate the classifier and load the fine-tuned weights onto the CPU
# (map_location="cpu"), so no GPU is required to serve the demo.
model = BertClassifier()
state_dict = torch.load("bert_mc_dropout.pt", map_location="cpu")

# strict=False lets the checkpoint load even when its keys do not exactly
# match the model's. On its own, though, it silently skips weights the
# checkpoint lacks, leaving them at whatever BertClassifier() initialised
# them to. Report any mismatch so a partially loaded model is noticed.
_load_result = model.load_state_dict(state_dict, strict=False)
if _load_result.missing_keys:
    warnings.warn(
        "Checkpoint did not provide these model parameters; they keep their "
        f"initial values: {_load_result.missing_keys}",
        RuntimeWarning,
        stacklevel=1,
    )
if _load_result.unexpected_keys:
    warnings.warn(
        "Checkpoint contains keys the model does not use (ignored): "
        f"{_load_result.unexpected_keys}",
        RuntimeWarning,
        stacklevel=1,
    )

# eval() switches modules such as dropout to inference behaviour.
# NOTE(review): the checkpoint name suggests MC-dropout training; with eval()
# every prediction is a single deterministic pass. Confirm this is intended.
model.eval()


# Pretrained vocabulary used to turn raw text into input ids / attention masks.
# Presumably the same vocabulary the classifier was fine-tuned with; confirm
# against the training setup.
_TOKENIZER_CHECKPOINT = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(_TOKENIZER_CHECKPOINT)


# Maps the index chosen by argmax over the model output to a display name.
# Position in this tuple == class index (0..3).
label_map = dict(enumerate(("World", "Sports", "Business", "Sci/Tech")))


def predict(text: str) -> str:
    """Classify a news snippet and report the top label with its confidence.

    Args:
        text: Raw text from the Gradio textbox.

    Returns:
        A string of the form ``"<label> (confidence: 0.93)"``. If ``text`` is
        empty or whitespace-only, returns a short prompt asking for input.
    """
    # Without this guard an empty box would still be run through the model
    # and yield an arbitrary label with a confidence score.
    if not text or not text.strip():
        return "Please enter some text to classify."

    # truncation + max_length cap the input at 512 tokens.
    tokens = tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    )
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        outputs = model(tokens["input_ids"], tokens["attention_mask"])

    # NOTE(review): assumes the model returns unnormalised logits shaped
    # (batch, num_labels); confirm BertClassifier does not already apply softmax.
    probs = torch.softmax(outputs, dim=1)
    conf, predicted = torch.max(probs, dim=1)
    return f"{label_map[predicted.item()]} (confidence: {conf.item():.2f})"


# Single-textbox Gradio UI wired to ``predict``; the output pane shows the
# plain string that ``predict`` returns.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=4, placeholder="Enter your news text..."),
    outputs="text",
)

# Start the web server only when run as a script, so importing this module
# (e.g. from tests or another app) does not launch the UI as a side effect.
if __name__ == "__main__":
    demo.launch()