|
|
import gradio as gr |
|
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
|
import torch |
|
|
|
|
|
|
|
|
|
|
|
# Load the fine-tuned classifier from the current directory (typical for a
# Hugging Face Space, where the weights sit next to this script).
MODEL_PATH = "."

# Tokenizer and sequence-classification model are loaded from the same path,
# so they are guaranteed to match the checkpoint's vocabulary and config.
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)

# Prefer GPU when available; fall back to CPU otherwise.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# Map class indices to display labels (Chinese):
# 0 -> "not disaster related", 1 -> "real disaster".
id2label = {0: "非灾难相关", 1: "真实灾难"}
|
|
|
|
|
|
|
|
|
|
|
def classify_text(text):
    """Classify a tweet as disaster-related or not.

    Args:
        text: The input tweet text (English).

    Returns:
        A dict mapping each label name to its predicted probability,
        suitable for a Gradio ``Label`` component.
    """
    # Tokenize, capping length at the model's 512-token limit.
    encoded = tokenizer(text, return_tensors="pt", truncation=True, max_length=512, padding=True)

    # Move every input tensor onto the same device as the model.
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

    # Inference only — no gradient tracking needed.
    with torch.no_grad():
        output = model(**encoded)

    # Logits -> per-class probabilities, back on CPU as a 1-D numpy array.
    probs = torch.softmax(output.logits, dim=1).squeeze().cpu().numpy()

    return {id2label[idx]: float(p) for idx, p in enumerate(probs)}
|
|
|
|
|
|
|
|
# Canned inputs shown beneath the text box; they mix literal disasters with
# figurative uses ("This is the bomb!") to showcase the classifier.
example_tweets = [
    ["Forest fire near La Ronge Sask. Canada"],
    ["Just got a new job! This is the bomb!"],
    ["My house is on fire!"],
    ["I love my cat."],
]

# Wire the classifier into a simple Gradio UI: free-text in, a
# two-class probability label out.
demo = gr.Interface(
    fn=classify_text,
    inputs=gr.Textbox(lines=5, placeholder="在这里输入一条英文推文..."),
    outputs=gr.Label(num_top_classes=2),
    title="Tweet Disaster Classifier",
    description="This is a text classification model based on BERT, which is used to determine whether a tweet describes a real disaster event. The model is trained on the 'nlp-getting-started' dataset from Kaggle.",
    examples=example_tweets,
)
|
|
|
|
|
# Start the Gradio web server only when run as a script (not on import).
if __name__ == "__main__":
    demo.launch()
|
|
|