Commit be52a7a by k999fff (parent: 0bbb0d4): Add new file
README.md ADDED
# BertForStorySkillClassification

## Model Overview
`BertForStorySkillClassification` is a BERT-based text classification model designed to categorize story-related questions into one of the following seven classes:
1. **Character**
2. **Setting**
3. **Feeling**
4. **Action**
5. **Causal Relationship**
6. **Outcome Resolution**
7. **Prediction**

This model is suitable for applications in education, literary analysis, and story comprehension.

---

## Model Architecture
- **Base Model**: `bert-base-uncased`
- **Classification Layer**: a fully connected layer on top of BERT's `[CLS]` representation for 7-class classification.
- **Input**: question text, optionally followed by an answer after a `<SEP>` marker (e.g., "Who is the main character in the story?" or "why could n't alice get a doll as a child ? <SEP> because her family was very poor"). The lower-cased, pre-tokenized form of the second example mirrors the FairytaleQA training data; see the sketch below.
- **Output**: predicted label and confidence score.
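
For question-plus-answer inputs, the two parts are simply joined with the `<SEP>` marker. A minimal sketch of building such an input string (the `build_input` helper is illustrative, not part of this repository):

```python
from typing import Optional

def build_input(question: str, answer: Optional[str] = None) -> str:
    """Join a question and an optional answer with the <SEP> marker."""
    return question if answer is None else f"{question} <SEP> {answer}"

print(build_input("why could n't alice get a doll as a child ?",
                  "because her family was very poor"))
# why could n't alice get a doll as a child ? <SEP> because her family was very poor
```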

---

## Quick Start

### Install Dependencies
Ensure you have the `transformers` library installed:
```bash
pip install transformers
```

### Load Model and Tokenizer

The model class lives in this repository's `modeling_bert_classifier.py` (registered via `auto_map` in `config.json`), so pass `trust_remote_code=True` when loading through the Auto classes:

```python
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model = AutoModelForSequenceClassification.from_pretrained(
    "curious008/BertForStorySkillClassification", trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained("curious008/BertForStorySkillClassification")
```

### Use the `predict` Method for Inference

`predict` moves only the tokenized inputs to the requested device, so move the model there first for GPU inference:

```python
# Single-text prediction
result = model.predict(
    texts="Where does this story take place?",
    tokenizer=tokenizer,
    return_probabilities=True,
)
print(result)
# Output: [{'text': 'Where does this story take place?', 'label': 'setting', 'score': 0.93178}]

# Batch prediction (model moved to the GPU first)
model.to("cuda")
results = model.predict(
    texts=[
        "Why is the character sad?",
        "How does the story end?",
        "why could n't alice get a doll as a child ? <SEP> because her family was very poor ",
    ],
    tokenizer=tokenizer,
    batch_size=16,
    device="cuda",
)
print(results)
"""
output:
[{'text': 'Why is the character sad?', 'label': 'causal relationship'},
 {'text': 'How does the story end?', 'label': 'action'},
 {'text': "why could n't alice get a doll as a child ? <SEP> because her family was very poor ",
  'label': 'causal relationship'}]
"""
```
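
Under the hood, `predict` is a thin wrapper around the forward pass. An equivalent sketch using plain `torch` calls (assuming the model and tokenizer loaded above; `forward` returns raw logits when no labels are given):

```python
import torch

inputs = tokenizer(
    "Where does this story take place?",
    return_tensors="pt", truncation=True, max_length=512,
).to(model.device)  # keep inputs on the same device as the model
with torch.no_grad():
    logits = model(**inputs)  # raw logits, shape [1, 7]
probs = torch.softmax(logits, dim=-1)
score, class_id = torch.max(probs, dim=-1)
print(model.config.id2label[class_id.item()], score.item())
```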

## Training Details

### Dataset
Source: [FairytaleQAData](https://github.com/uci-soe/FairytaleQAData)

### Training Parameters
- Learning Rate: 2e-5
- Batch Size: 32
- Epochs: 3
- Optimizer: AdamW

A sketch of a training step with these settings follows the metrics below.

### Performance Metrics
- Accuracy: 97.3%
- Recall: 96.59%
- F1 Score: 96.96%
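
A minimal sketch of a training step under these settings; `train_loader` (batches of raw texts and integer label ids) is assumed, and `forward` returns the cross-entropy loss when `labels` is passed:

```python
import torch
from torch.optim import AdamW

optimizer = AdamW(model.parameters(), lr=2e-5)
model.train()
for batch_texts, batch_labels in train_loader:  # assumed DataLoader yielding (texts, label ids)
    inputs = tokenizer(list(batch_texts), padding=True, truncation=True,
                       max_length=512, return_tensors="pt")
    loss = model(**inputs, labels=torch.as_tensor(batch_labels))
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```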

## Notes
1. **Input Length**: The model supports a maximum input length of 512 tokens; longer texts are truncated.
2. **Device Support**: The model supports both CPU and GPU inference; GPU is recommended for faster performance.
3. **Tokenizer**: Always use the matching tokenizer (`AutoTokenizer`) with this model; it includes the custom `<sep>` token added to the vocabulary.

## Citation

If you use this model, please cite the following:

```bibtex
@misc{BertForStorySkillClassification,
  author = {curious},
  title = {BertForStorySkillClassification: A BERT-based Model for Story Question Classification},
  year = {2025},
  publisher = {Hugging Face},
  howpublished = {\url{https://huggingface.co/curious008/BertForStorySkillClassification}}
}
```

## License
This model is open-sourced under the [Apache 2.0 License](https://www.apache.org/licenses/LICENSE-2.0).
added_tokens.json ADDED
```json
{
  "<sep>": 30522
}
```
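
This extends BERT's base vocabulary (ids 0–30521) with `<sep>` as id 30522, which is why `config.json` declares `vocab_size: 30523`. Since the tokenizer lower-cases input, the `<SEP>` marker in the README examples should resolve to this single added token; a quick check:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("curious008/BertForStorySkillClassification")
print(tokenizer.convert_tokens_to_ids("<sep>"))   # 30522
print(tokenizer.tokenize("why ? <SEP> because"))  # the marker stays a single '<sep>' token
```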
config.json ADDED
```json
{
  "revision": "1.0",
  "_name_or_path": "google-bert/bert-base-uncased",
  "architectures": [
    "BertForStorySkillClassification"
  ],
  "auto_map": {
    "AutoModelForSequenceClassification": "modeling_bert_classifier.BertForStorySkillClassification"
  },
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "id2label": {
    "0": "character",
    "1": "setting",
    "2": "feeling",
    "3": "action",
    "4": "causal relationship",
    "5": "outcome resolution",
    "6": "prediction"
  },
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "label2id": {
    "action": 3,
    "causal relationship": 4,
    "character": 0,
    "feeling": 2,
    "outcome resolution": 5,
    "prediction": 6,
    "setting": 1
  },
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "torch_dtype": "float32",
  "transformers_version": "4.20.1",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30523
}
```
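
Note that `num_labels` is not stored explicitly: `transformers` derives it from `id2label`, and `label2id` is its inverse. A quick consistency check:

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("curious008/BertForStorySkillClassification")
assert config.num_labels == 7  # derived from the 7-entry id2label map
assert all(config.label2id[label] == idx for idx, label in config.id2label.items())
```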
modeling_bert_classifier.py ADDED
```python
from typing import Dict, List, Union

import torch
import torch.nn as nn
from transformers import BertModel, BertPreTrainedModel, PreTrainedTokenizer


class BertForStorySkillClassification(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.bert = BertModel(config)
        self.classifier = nn.Linear(config.hidden_size, self.num_labels)
        self.post_init()

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        # last_hidden_state: [batch_size, seq_len, hidden_size];
        # keep only the [CLS] vector -> [batch_size, hidden_size]
        cls_hidden_state = outputs.last_hidden_state[:, 0, :]
        logits = self.classifier(cls_hidden_state)  # [batch_size, num_labels]
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            return loss
        return logits

    def predict(
        self,
        texts: Union[str, List[str]],
        tokenizer: PreTrainedTokenizer,
        batch_size: int = 32,
        return_probabilities: bool = False,
        device: Union[str, torch.device, None] = None,
    ) -> List[Dict]:
        """
        Run classification on the input texts.

        Args:
            texts: a single text or a list of texts, e.g. "Who is the
                character in the story?" or ["question 1", "question 2"]
            tokenizer: tokenizer instance (must match the model)
            batch_size: batch size (larger batches speed up inference)
            return_probabilities: whether to include the confidence score
                (by default only labels are returned)
            device: target device (e.g. "cuda" or "cpu"); defaults to the
                device the model currently lives on

        Returns:
            A list of predictions of the form:
            [{"text": "input text", "label": "predicted label", "score": confidence}, ...]
        """
        # Default to the model's own device
        if device is None:
            device = self.device

        # Normalize the input to a list
        if isinstance(texts, str):
            texts = [texts]

        predictions = []
        self.eval()  # disable dropout for deterministic inference

        # Batched prediction
        with torch.no_grad():
            for i in range(0, len(texts), batch_size):
                batch_texts = texts[i : i + batch_size]

                # Tokenize and move the tensors to the target device
                inputs = tokenizer(
                    batch_texts,
                    padding=True,
                    truncation=True,
                    return_tensors="pt",
                    max_length=512,  # matches BERT's maximum length
                ).to(device)

                # Forward pass, then pick the most probable class
                logits = self(**inputs)
                probs = torch.softmax(logits, dim=-1)
                scores, class_ids = torch.max(probs, dim=-1)

                # Map class ids back to label strings
                for text, class_id, score in zip(batch_texts, class_ids, scores):
                    label = self.config.id2label[class_id.item()]
                    result = {"text": text, "label": label}
                    if return_probabilities:
                        result["score"] = score.item()
                    predictions.append(result)

        return predictions
```
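
Because `forward` returns a bare tensor rather than a `ModelOutput`, inference should go through `predict` or the manual pattern shown in the README. A minimal local smoke test, assuming the files from this commit have been downloaded into `./BertForStorySkillClassification` (the path is illustrative):

```python
from transformers import AutoTokenizer

from modeling_bert_classifier import BertForStorySkillClassification

# Instantiate the custom class directly instead of going through auto_map
model = BertForStorySkillClassification.from_pretrained("./BertForStorySkillClassification")
tokenizer = AutoTokenizer.from_pretrained("./BertForStorySkillClassification")

print(model.predict("Who rescued the princess?", tokenizer=tokenizer, return_probabilities=True))
```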
pytorch_model.bin ADDED
```
version https://git-lfs.github.com/spec/v1
oid sha256:4a93fb540b17ffb5e5f40dae50d13001f35fdc30f25e1029b570209349b00d0d
size 438021741
```
special_tokens_map.json ADDED
```json
{
  "cls_token": "[CLS]",
  "mask_token": "[MASK]",
  "pad_token": "[PAD]",
  "sep_token": "[SEP]",
  "unk_token": "[UNK]"
}
```
tokenizer_config.json ADDED
```json
{
  "cls_token": "[CLS]",
  "do_basic_tokenize": true,
  "do_lower_case": true,
  "mask_token": "[MASK]",
  "name_or_path": "/remote-home/CS_IMIAPD_chensong22/python/weights/google-bert/bert-base-uncased",
  "never_split": null,
  "pad_token": "[PAD]",
  "sep_token": "[SEP]",
  "special_tokens_map_file": null,
  "strip_accents": null,
  "tokenize_chinese_chars": true,
  "tokenizer_class": "BertTokenizer",
  "unk_token": "[UNK]"
}
```
vocab.txt ADDED
(Diff too large to render; see the raw file in the repository.)