File size: 3,090 Bytes
c772fc0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import os

import gradio as gr
import torch
from transformers import AutoTokenizer

from modeling_bert import BertForSequenceClassification


# 当前 app.py 所在目录
BASE_DIR = os.path.dirname(os.path.abspath(__file__))

# 训练完成后保存模型的目录
MODEL_DIR = os.path.join(BASE_DIR, "experiments")

# 如果 Spaces 提供 GPU 就用 GPU,否则自动回退到 CPU
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# 类别 id 到文本标签的映射
ID2LABEL = {
    0: "not_disaster",
    1: "disaster",
}

# 加载 tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)

# 加载训练好的分类模型
model = BertForSequenceClassification.from_pretrained(MODEL_DIR)
model.to(DEVICE)
model.eval()


def inference(input_text):
    # 处理空输入,避免直接送进模型报错
    input_text = (input_text or "").strip()
    if not input_text:
        return "Please input a sentence."

    # 把文本编码成模型可接收的输入格式
    # 包括 input_ids 和 attention_mask
    inputs = tokenizer(
        input_text,
        max_length=128,
        truncation=True,
        padding="max_length",
        return_tensors="pt",
    )

    # 把输入张量移动到和模型相同的设备上
    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}

    # 推理阶段不需要计算梯度
    with torch.no_grad():
        logits = model(**inputs).logits

    # 取分数最高的类别作为最终预测
    predicted_class_id = logits.argmax(dim=-1).item()
    output = ID2LABEL[predicted_class_id]
    return output


# 使用 Gradio Blocks 搭建一个简单网页界面
with gr.Blocks(css="""
.message.svelte-w6rprc.svelte-w6rprc.svelte-w6rprc {font-size: 20px; margin-top: 20px}
#component-2 > div.wrap.svelte-w6rprc {height: 600px;}
""") as demo:
    gr.Markdown("# Disaster Tweet Classifier")
    gr.Markdown("Input a sentence or tweet, and the model will predict whether it describes a real disaster.")

    # 一行布局,里面放一个输入列
    with gr.Row():
        with gr.Column():
            # 用户输入文本
            input_text = gr.Textbox(
                placeholder="Insert your text here...",
                label="Input Text",
                lines=4,
            )

            # 显示模型预测结果
            answer = gr.Textbox(label="Prediction")

            # 点击按钮后触发推理
            generate_bt = gr.Button("Generate")

    # 把按钮、输入框、输出框和推理函数绑定起来
    generate_bt.click(
        fn=inference,
        inputs=[input_text],
        outputs=[answer],
        show_progress=True,
    )

    # 提供几个示例,方便在线体验
    gr.Examples(
        examples=[
            ["Forest fire near La Ronge Sask. Canada"],
            ["I love fruits and summer weather."],
            ["There is an emergency evacuation happening now in the building across the street."],
        ],
        inputs=input_text,
        outputs=answer,
        fn=inference,
        cache_examples=False,
    )

# 启动 Gradio 服务
demo.launch()