Spaces:
Sleeping
Sleeping
Garrulus21yyx commited on
Commit ·
c772fc0
1
Parent(s): 2a07c48
Add minimal Gradio Space files
Browse files- app.py +109 -0
- experiments/config.json +26 -0
- experiments/model.safetensors +3 -0
- experiments/special_tokens_map.json +7 -0
- experiments/tokenizer.json +0 -0
- experiments/tokenizer_config.json +56 -0
- experiments/vocab.txt +0 -0
- modeling_bert.py +101 -0
- requirements.txt +5 -0
app.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
import gradio as gr
|
| 4 |
+
import torch
|
| 5 |
+
from transformers import AutoTokenizer
|
| 6 |
+
|
| 7 |
+
from modeling_bert import BertForSequenceClassification
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
# 当前 app.py 所在目录
|
| 11 |
+
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 12 |
+
|
| 13 |
+
# 训练完成后保存模型的目录
|
| 14 |
+
MODEL_DIR = os.path.join(BASE_DIR, "experiments")
|
| 15 |
+
|
| 16 |
+
# 如果 Spaces 提供 GPU 就用 GPU,否则自动回退到 CPU
|
| 17 |
+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 18 |
+
|
| 19 |
+
# 类别 id 到文本标签的映射
|
| 20 |
+
ID2LABEL = {
|
| 21 |
+
0: "not_disaster",
|
| 22 |
+
1: "disaster",
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
# 加载 tokenizer
|
| 26 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
|
| 27 |
+
|
| 28 |
+
# 加载训练好的分类模型
|
| 29 |
+
model = BertForSequenceClassification.from_pretrained(MODEL_DIR)
|
| 30 |
+
model.to(DEVICE)
|
| 31 |
+
model.eval()
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def inference(input_text):
|
| 35 |
+
# 处理空输入,避免直接送进模型报错
|
| 36 |
+
input_text = (input_text or "").strip()
|
| 37 |
+
if not input_text:
|
| 38 |
+
return "Please input a sentence."
|
| 39 |
+
|
| 40 |
+
# 把文本编码成模型可接收的输入格式
|
| 41 |
+
# 包括 input_ids 和 attention_mask
|
| 42 |
+
inputs = tokenizer(
|
| 43 |
+
input_text,
|
| 44 |
+
max_length=128,
|
| 45 |
+
truncation=True,
|
| 46 |
+
padding="max_length",
|
| 47 |
+
return_tensors="pt",
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
# 把输入张量移动到和模型相同的设备上
|
| 51 |
+
inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
|
| 52 |
+
|
| 53 |
+
# 推理阶段不需要计算梯度
|
| 54 |
+
with torch.no_grad():
|
| 55 |
+
logits = model(**inputs).logits
|
| 56 |
+
|
| 57 |
+
# 取分数最高的类别作为最终预测
|
| 58 |
+
predicted_class_id = logits.argmax(dim=-1).item()
|
| 59 |
+
output = ID2LABEL[predicted_class_id]
|
| 60 |
+
return output
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
# 使用 Gradio Blocks 搭建一个简单网页界面
|
| 64 |
+
with gr.Blocks(css="""
|
| 65 |
+
.message.svelte-w6rprc.svelte-w6rprc.svelte-w6rprc {font-size: 20px; margin-top: 20px}
|
| 66 |
+
#component-2 > div.wrap.svelte-w6rprc {height: 600px;}
|
| 67 |
+
""") as demo:
|
| 68 |
+
gr.Markdown("# Disaster Tweet Classifier")
|
| 69 |
+
gr.Markdown("Input a sentence or tweet, and the model will predict whether it describes a real disaster.")
|
| 70 |
+
|
| 71 |
+
# 一行布局,里面放一个输入列
|
| 72 |
+
with gr.Row():
|
| 73 |
+
with gr.Column():
|
| 74 |
+
# 用户输入文本
|
| 75 |
+
input_text = gr.Textbox(
|
| 76 |
+
placeholder="Insert your text here...",
|
| 77 |
+
label="Input Text",
|
| 78 |
+
lines=4,
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
# 显示模型预测结果
|
| 82 |
+
answer = gr.Textbox(label="Prediction")
|
| 83 |
+
|
| 84 |
+
# 点击按钮后触发推理
|
| 85 |
+
generate_bt = gr.Button("Generate")
|
| 86 |
+
|
| 87 |
+
# 把按钮、输入框、输出框和推理函数绑定起来
|
| 88 |
+
generate_bt.click(
|
| 89 |
+
fn=inference,
|
| 90 |
+
inputs=[input_text],
|
| 91 |
+
outputs=[answer],
|
| 92 |
+
show_progress=True,
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
# 提供几个示例,方便在线体验
|
| 96 |
+
gr.Examples(
|
| 97 |
+
examples=[
|
| 98 |
+
["Forest fire near La Ronge Sask. Canada"],
|
| 99 |
+
["I love fruits and summer weather."],
|
| 100 |
+
["There is an emergency evacuation happening now in the building across the street."],
|
| 101 |
+
],
|
| 102 |
+
inputs=input_text,
|
| 103 |
+
outputs=answer,
|
| 104 |
+
fn=inference,
|
| 105 |
+
cache_examples=False,
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
# 启动 Gradio 服务
|
| 109 |
+
demo.launch()
|
experiments/config.json
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"BertForSequenceClassification"
|
| 4 |
+
],
|
| 5 |
+
"attention_probs_dropout_prob": 0.1,
|
| 6 |
+
"classifier_dropout": null,
|
| 7 |
+
"dtype": "float32",
|
| 8 |
+
"gradient_checkpointing": false,
|
| 9 |
+
"hidden_act": "gelu",
|
| 10 |
+
"hidden_dropout_prob": 0.1,
|
| 11 |
+
"hidden_size": 768,
|
| 12 |
+
"initializer_range": 0.02,
|
| 13 |
+
"intermediate_size": 3072,
|
| 14 |
+
"layer_norm_eps": 1e-12,
|
| 15 |
+
"max_position_embeddings": 512,
|
| 16 |
+
"model_type": "bert",
|
| 17 |
+
"num_attention_heads": 12,
|
| 18 |
+
"num_hidden_layers": 12,
|
| 19 |
+
"pad_token_id": 0,
|
| 20 |
+
"position_embedding_type": "absolute",
|
| 21 |
+
"problem_type": "single_label_classification",
|
| 22 |
+
"transformers_version": "4.57.6",
|
| 23 |
+
"type_vocab_size": 2,
|
| 24 |
+
"use_cache": true,
|
| 25 |
+
"vocab_size": 30522
|
| 26 |
+
}
|
experiments/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2f9bd5192084f1fc6013bc7e52e4ee9ead272cc3d1c15593c451f5620946a5d8
|
| 3 |
+
size 437958648
|
experiments/special_tokens_map.json
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cls_token": "[CLS]",
|
| 3 |
+
"mask_token": "[MASK]",
|
| 4 |
+
"pad_token": "[PAD]",
|
| 5 |
+
"sep_token": "[SEP]",
|
| 6 |
+
"unk_token": "[UNK]"
|
| 7 |
+
}
|
experiments/tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
experiments/tokenizer_config.json
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"added_tokens_decoder": {
|
| 3 |
+
"0": {
|
| 4 |
+
"content": "[PAD]",
|
| 5 |
+
"lstrip": false,
|
| 6 |
+
"normalized": false,
|
| 7 |
+
"rstrip": false,
|
| 8 |
+
"single_word": false,
|
| 9 |
+
"special": true
|
| 10 |
+
},
|
| 11 |
+
"100": {
|
| 12 |
+
"content": "[UNK]",
|
| 13 |
+
"lstrip": false,
|
| 14 |
+
"normalized": false,
|
| 15 |
+
"rstrip": false,
|
| 16 |
+
"single_word": false,
|
| 17 |
+
"special": true
|
| 18 |
+
},
|
| 19 |
+
"101": {
|
| 20 |
+
"content": "[CLS]",
|
| 21 |
+
"lstrip": false,
|
| 22 |
+
"normalized": false,
|
| 23 |
+
"rstrip": false,
|
| 24 |
+
"single_word": false,
|
| 25 |
+
"special": true
|
| 26 |
+
},
|
| 27 |
+
"102": {
|
| 28 |
+
"content": "[SEP]",
|
| 29 |
+
"lstrip": false,
|
| 30 |
+
"normalized": false,
|
| 31 |
+
"rstrip": false,
|
| 32 |
+
"single_word": false,
|
| 33 |
+
"special": true
|
| 34 |
+
},
|
| 35 |
+
"103": {
|
| 36 |
+
"content": "[MASK]",
|
| 37 |
+
"lstrip": false,
|
| 38 |
+
"normalized": false,
|
| 39 |
+
"rstrip": false,
|
| 40 |
+
"single_word": false,
|
| 41 |
+
"special": true
|
| 42 |
+
}
|
| 43 |
+
},
|
| 44 |
+
"clean_up_tokenization_spaces": false,
|
| 45 |
+
"cls_token": "[CLS]",
|
| 46 |
+
"do_lower_case": true,
|
| 47 |
+
"extra_special_tokens": {},
|
| 48 |
+
"mask_token": "[MASK]",
|
| 49 |
+
"model_max_length": 512,
|
| 50 |
+
"pad_token": "[PAD]",
|
| 51 |
+
"sep_token": "[SEP]",
|
| 52 |
+
"strip_accents": null,
|
| 53 |
+
"tokenize_chinese_chars": true,
|
| 54 |
+
"tokenizer_class": "BertTokenizer",
|
| 55 |
+
"unk_token": "[UNK]"
|
| 56 |
+
}
|
experiments/vocab.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
modeling_bert.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Optional, Tuple, Union
|
| 2 |
+
import torch
|
| 3 |
+
import torch.utils.checkpoint
|
| 4 |
+
from torch import nn
|
| 5 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
| 6 |
+
from transformers import BertPreTrainedModel, BertModel
|
| 7 |
+
from transformers.modeling_outputs import SequenceClassifierOutput
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class BertForSequenceClassification(BertPreTrainedModel):
|
| 11 |
+
def __init__(self, config):
|
| 12 |
+
super().__init__(config)
|
| 13 |
+
self.num_labels = config.num_labels
|
| 14 |
+
self.config = config
|
| 15 |
+
|
| 16 |
+
# 主干网络仍然是标准 BERT,用它提取整句语义表示。
|
| 17 |
+
self.bert = BertModel(config)
|
| 18 |
+
classifier_dropout = (
|
| 19 |
+
config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
|
| 20 |
+
)
|
| 21 |
+
# 分类头非常简单:dropout + 全连接层。
|
| 22 |
+
self.dropout = nn.Dropout(classifier_dropout)
|
| 23 |
+
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
| 24 |
+
|
| 25 |
+
# Initialize weights and apply final processing
|
| 26 |
+
self.post_init()
|
| 27 |
+
|
| 28 |
+
def forward(
|
| 29 |
+
self,
|
| 30 |
+
input_ids: Optional[torch.Tensor] = None,
|
| 31 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 32 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
| 33 |
+
position_ids: Optional[torch.Tensor] = None,
|
| 34 |
+
head_mask: Optional[torch.Tensor] = None,
|
| 35 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
| 36 |
+
labels: Optional[torch.Tensor] = None,
|
| 37 |
+
output_attentions: Optional[bool] = None,
|
| 38 |
+
output_hidden_states: Optional[bool] = None,
|
| 39 |
+
return_dict: Optional[bool] = None,
|
| 40 |
+
) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
|
| 41 |
+
r"""
|
| 42 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 43 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
| 44 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
| 45 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
| 46 |
+
"""
|
| 47 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 48 |
+
|
| 49 |
+
# 先经过 BERT 编码,得到 token 级表示和 pooled_output。
|
| 50 |
+
outputs = self.bert(
|
| 51 |
+
input_ids,
|
| 52 |
+
attention_mask=attention_mask,
|
| 53 |
+
token_type_ids=token_type_ids,
|
| 54 |
+
position_ids=position_ids,
|
| 55 |
+
head_mask=head_mask,
|
| 56 |
+
inputs_embeds=inputs_embeds,
|
| 57 |
+
output_attentions=output_attentions,
|
| 58 |
+
output_hidden_states=output_hidden_states,
|
| 59 |
+
return_dict=return_dict,
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
# outputs[1] 对应 [CLS] 的句级表示,常用于分类任务。
|
| 63 |
+
pooled_output = outputs[1]
|
| 64 |
+
|
| 65 |
+
pooled_output = self.dropout(pooled_output)
|
| 66 |
+
logits = self.classifier(pooled_output)
|
| 67 |
+
|
| 68 |
+
loss = None
|
| 69 |
+
if labels is not None:
|
| 70 |
+
# 根据任务形式自动选择 loss。
|
| 71 |
+
# 当前数据集是二分类,实际会走 single_label_classification + CrossEntropyLoss。
|
| 72 |
+
if self.config.problem_type is None:
|
| 73 |
+
if self.num_labels == 1:
|
| 74 |
+
self.config.problem_type = "regression"
|
| 75 |
+
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
| 76 |
+
self.config.problem_type = "single_label_classification"
|
| 77 |
+
else:
|
| 78 |
+
self.config.problem_type = "multi_label_classification"
|
| 79 |
+
|
| 80 |
+
if self.config.problem_type == "regression":
|
| 81 |
+
loss_fct = MSELoss()
|
| 82 |
+
if self.num_labels == 1:
|
| 83 |
+
loss = loss_fct(logits.squeeze(), labels.squeeze())
|
| 84 |
+
else:
|
| 85 |
+
loss = loss_fct(logits, labels)
|
| 86 |
+
elif self.config.problem_type == "single_label_classification":
|
| 87 |
+
loss_fct = CrossEntropyLoss()
|
| 88 |
+
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
| 89 |
+
elif self.config.problem_type == "multi_label_classification":
|
| 90 |
+
loss_fct = BCEWithLogitsLoss()
|
| 91 |
+
loss = loss_fct(logits, labels)
|
| 92 |
+
if not return_dict:
|
| 93 |
+
output = (logits,) + outputs[2:]
|
| 94 |
+
return ((loss,) + output) if loss is not None else output
|
| 95 |
+
|
| 96 |
+
return SequenceClassifierOutput(
|
| 97 |
+
loss=loss,
|
| 98 |
+
logits=logits,
|
| 99 |
+
hidden_states=outputs.hidden_states,
|
| 100 |
+
attentions=outputs.attentions,
|
| 101 |
+
)
|
requirements.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio
|
| 2 |
+
transformers>=4.30.2
|
| 3 |
+
torch>=2.0.0
|
| 4 |
+
safetensors
|
| 5 |
+
sentencepiece!=0.1.92
|