Garrulus21yyx commited on
Commit
c772fc0
·
1 Parent(s): 2a07c48

Add minimal Gradio Space files

Browse files
app.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import gradio as gr
4
+ import torch
5
+ from transformers import AutoTokenizer
6
+
7
+ from modeling_bert import BertForSequenceClassification
8
+
9
+
10
+ # 当前 app.py 所在目录
11
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
12
+
13
+ # 训练完成后保存模型的目录
14
+ MODEL_DIR = os.path.join(BASE_DIR, "experiments")
15
+
16
+ # 如果 Spaces 提供 GPU 就用 GPU,否则自动回退到 CPU
17
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
18
+
19
+ # 类别 id 到文本标签的映射
20
+ ID2LABEL = {
21
+ 0: "not_disaster",
22
+ 1: "disaster",
23
+ }
24
+
25
+ # 加载 tokenizer
26
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
27
+
28
+ # 加载训练好的分类模型
29
+ model = BertForSequenceClassification.from_pretrained(MODEL_DIR)
30
+ model.to(DEVICE)
31
+ model.eval()
32
+
33
+
34
+ def inference(input_text):
35
+ # 处理空输入,避免直接送进模型报错
36
+ input_text = (input_text or "").strip()
37
+ if not input_text:
38
+ return "Please input a sentence."
39
+
40
+ # 把文本编码成模型可接收的输入格式
41
+ # 包括 input_ids 和 attention_mask
42
+ inputs = tokenizer(
43
+ input_text,
44
+ max_length=128,
45
+ truncation=True,
46
+ padding="max_length",
47
+ return_tensors="pt",
48
+ )
49
+
50
+ # 把输入张量移动到和模型相同的设备上
51
+ inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
52
+
53
+ # 推理阶段不需要计算梯度
54
+ with torch.no_grad():
55
+ logits = model(**inputs).logits
56
+
57
+ # 取分数最高的类别作为最终预测
58
+ predicted_class_id = logits.argmax(dim=-1).item()
59
+ output = ID2LABEL[predicted_class_id]
60
+ return output
61
+
62
+
63
+ # 使用 Gradio Blocks 搭建一个简单网页界面
64
+ with gr.Blocks(css="""
65
+ .message.svelte-w6rprc.svelte-w6rprc.svelte-w6rprc {font-size: 20px; margin-top: 20px}
66
+ #component-2 > div.wrap.svelte-w6rprc {height: 600px;}
67
+ """) as demo:
68
+ gr.Markdown("# Disaster Tweet Classifier")
69
+ gr.Markdown("Input a sentence or tweet, and the model will predict whether it describes a real disaster.")
70
+
71
+ # 一行布局,里面放一个输入列
72
+ with gr.Row():
73
+ with gr.Column():
74
+ # 用户输入文本
75
+ input_text = gr.Textbox(
76
+ placeholder="Insert your text here...",
77
+ label="Input Text",
78
+ lines=4,
79
+ )
80
+
81
+ # 显示模型预测结果
82
+ answer = gr.Textbox(label="Prediction")
83
+
84
+ # 点击按钮后触发推理
85
+ generate_bt = gr.Button("Generate")
86
+
87
+ # 把按钮、输入框、输出框和推理函数绑定起来
88
+ generate_bt.click(
89
+ fn=inference,
90
+ inputs=[input_text],
91
+ outputs=[answer],
92
+ show_progress=True,
93
+ )
94
+
95
+ # 提供几个示例,方便在线体验
96
+ gr.Examples(
97
+ examples=[
98
+ ["Forest fire near La Ronge Sask. Canada"],
99
+ ["I love fruits and summer weather."],
100
+ ["There is an emergency evacuation happening now in the building across the street."],
101
+ ],
102
+ inputs=input_text,
103
+ outputs=answer,
104
+ fn=inference,
105
+ cache_examples=False,
106
+ )
107
+
108
+ # 启动 Gradio 服务
109
+ demo.launch()
experiments/config.json ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "BertForSequenceClassification"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.1,
6
+ "classifier_dropout": null,
7
+ "dtype": "float32",
8
+ "gradient_checkpointing": false,
9
+ "hidden_act": "gelu",
10
+ "hidden_dropout_prob": 0.1,
11
+ "hidden_size": 768,
12
+ "initializer_range": 0.02,
13
+ "intermediate_size": 3072,
14
+ "layer_norm_eps": 1e-12,
15
+ "max_position_embeddings": 512,
16
+ "model_type": "bert",
17
+ "num_attention_heads": 12,
18
+ "num_hidden_layers": 12,
19
+ "pad_token_id": 0,
20
+ "position_embedding_type": "absolute",
21
+ "problem_type": "single_label_classification",
22
+ "transformers_version": "4.57.6",
23
+ "type_vocab_size": 2,
24
+ "use_cache": true,
25
+ "vocab_size": 30522
26
+ }
experiments/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2f9bd5192084f1fc6013bc7e52e4ee9ead272cc3d1c15593c451f5620946a5d8
3
+ size 437958648
experiments/special_tokens_map.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "cls_token": "[CLS]",
3
+ "mask_token": "[MASK]",
4
+ "pad_token": "[PAD]",
5
+ "sep_token": "[SEP]",
6
+ "unk_token": "[UNK]"
7
+ }
experiments/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
experiments/tokenizer_config.json ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "[PAD]",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "100": {
12
+ "content": "[UNK]",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "101": {
20
+ "content": "[CLS]",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "102": {
28
+ "content": "[SEP]",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "103": {
36
+ "content": "[MASK]",
37
+ "lstrip": false,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ }
43
+ },
44
+ "clean_up_tokenization_spaces": false,
45
+ "cls_token": "[CLS]",
46
+ "do_lower_case": true,
47
+ "extra_special_tokens": {},
48
+ "mask_token": "[MASK]",
49
+ "model_max_length": 512,
50
+ "pad_token": "[PAD]",
51
+ "sep_token": "[SEP]",
52
+ "strip_accents": null,
53
+ "tokenize_chinese_chars": true,
54
+ "tokenizer_class": "BertTokenizer",
55
+ "unk_token": "[UNK]"
56
+ }
experiments/vocab.txt ADDED
The diff for this file is too large to render. See raw diff
 
modeling_bert.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional, Tuple, Union
2
+ import torch
3
+ import torch.utils.checkpoint
4
+ from torch import nn
5
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
6
+ from transformers import BertPreTrainedModel, BertModel
7
+ from transformers.modeling_outputs import SequenceClassifierOutput
8
+
9
+
10
+ class BertForSequenceClassification(BertPreTrainedModel):
11
+ def __init__(self, config):
12
+ super().__init__(config)
13
+ self.num_labels = config.num_labels
14
+ self.config = config
15
+
16
+ # 主干网络仍然是标准 BERT,用它提取整句语义表示。
17
+ self.bert = BertModel(config)
18
+ classifier_dropout = (
19
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
20
+ )
21
+ # 分类头非常简单:dropout + 全连接层。
22
+ self.dropout = nn.Dropout(classifier_dropout)
23
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
24
+
25
+ # Initialize weights and apply final processing
26
+ self.post_init()
27
+
28
+ def forward(
29
+ self,
30
+ input_ids: Optional[torch.Tensor] = None,
31
+ attention_mask: Optional[torch.Tensor] = None,
32
+ token_type_ids: Optional[torch.Tensor] = None,
33
+ position_ids: Optional[torch.Tensor] = None,
34
+ head_mask: Optional[torch.Tensor] = None,
35
+ inputs_embeds: Optional[torch.Tensor] = None,
36
+ labels: Optional[torch.Tensor] = None,
37
+ output_attentions: Optional[bool] = None,
38
+ output_hidden_states: Optional[bool] = None,
39
+ return_dict: Optional[bool] = None,
40
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
41
+ r"""
42
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
43
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
44
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
45
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
46
+ """
47
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
48
+
49
+ # 先经过 BERT 编码,得到 token 级表示和 pooled_output。
50
+ outputs = self.bert(
51
+ input_ids,
52
+ attention_mask=attention_mask,
53
+ token_type_ids=token_type_ids,
54
+ position_ids=position_ids,
55
+ head_mask=head_mask,
56
+ inputs_embeds=inputs_embeds,
57
+ output_attentions=output_attentions,
58
+ output_hidden_states=output_hidden_states,
59
+ return_dict=return_dict,
60
+ )
61
+
62
+ # outputs[1] 对应 [CLS] 的句级表示,常用于分类任务。
63
+ pooled_output = outputs[1]
64
+
65
+ pooled_output = self.dropout(pooled_output)
66
+ logits = self.classifier(pooled_output)
67
+
68
+ loss = None
69
+ if labels is not None:
70
+ # 根据任务形式自动选择 loss。
71
+ # 当前数据集是二分类,实际会走 single_label_classification + CrossEntropyLoss。
72
+ if self.config.problem_type is None:
73
+ if self.num_labels == 1:
74
+ self.config.problem_type = "regression"
75
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
76
+ self.config.problem_type = "single_label_classification"
77
+ else:
78
+ self.config.problem_type = "multi_label_classification"
79
+
80
+ if self.config.problem_type == "regression":
81
+ loss_fct = MSELoss()
82
+ if self.num_labels == 1:
83
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
84
+ else:
85
+ loss = loss_fct(logits, labels)
86
+ elif self.config.problem_type == "single_label_classification":
87
+ loss_fct = CrossEntropyLoss()
88
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
89
+ elif self.config.problem_type == "multi_label_classification":
90
+ loss_fct = BCEWithLogitsLoss()
91
+ loss = loss_fct(logits, labels)
92
+ if not return_dict:
93
+ output = (logits,) + outputs[2:]
94
+ return ((loss,) + output) if loss is not None else output
95
+
96
+ return SequenceClassifierOutput(
97
+ loss=loss,
98
+ logits=logits,
99
+ hidden_states=outputs.hidden_states,
100
+ attentions=outputs.attentions,
101
+ )
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ gradio
2
+ transformers>=4.30.2
3
+ torch>=2.0.0
4
+ safetensors
5
+ sentencepiece!=0.1.92