Add new file
- README.md +108 -0
- added_tokens.json +3 -0
- config.json +48 -0
- modeling_bert_classifier.py +84 -0
- pytorch_model.bin +3 -0
- special_tokens_map.json +7 -0
- tokenizer_config.json +15 -0
- vocab.txt +0 -0
README.md
ADDED
@@ -0,0 +1,108 @@
# BertForStorySkillClassification

## Model Overview

`BertForStorySkillClassification` is a BERT-based text classification model designed to categorize story-related questions into one of the following 7 classes:

1. **Character**
2. **Setting**
3. **Feeling**
4. **Action**
5. **Causal Relationship**
6. **Outcome Resolution**
7. **Prediction**

This model is suitable for applications in education, literary analysis, and story comprehension.

---

## Model Architecture

- **Base Model**: `bert-base-uncased`
- **Classification Layer**: A fully connected layer on top of BERT for 7-class classification.
- **Input**: Question text, either alone or followed by an answer after a `<SEP>` marker (e.g., "Who is the main character in the story?" or "why could n't alice get a doll as a child ? \<SEP> because her family was very poor ")
- **Output**: Predicted label and confidence score.

---

## Quick Start

### Install Dependencies

Ensure you have the `transformers` library installed:

```bash
pip install transformers
```

### Load Model and Tokenizer

The model class is defined in `modeling_bert_classifier.py` and registered via `auto_map` in `config.json`, so pass `trust_remote_code=True` when loading:

```python
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model = AutoModelForSequenceClassification.from_pretrained(
    "curious008/BertForStorySkillClassification", trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained("curious008/BertForStorySkillClassification")
```

### Use the `predict` Method for Inference

```python
# Single-text prediction
result = model.predict(
    texts="Where does this story take place?",
    tokenizer=tokenizer,
    return_probabilities=True,
)
print(result)
# Output: [{'text': 'Where does this story take place?', 'label': 'setting', 'score': 0.93178}]

# Batch prediction
results = model.predict(
    texts=[
        "Why is the character sad?",
        "How does the story end?",
        "why could n't alice get a doll as a child ? <SEP> because her family was very poor ",
    ],
    tokenizer=tokenizer,
    batch_size=16,
    device="cuda",  # assumes the model is already on the GPU, e.g. via model.to("cuda")
)
print(results)
"""
Output:
[{'text': 'Why is the character sad?', 'label': 'causal relationship'},
 {'text': 'How does the story end?', 'label': 'action'},
 {'text': "why could n't alice get a doll as a child ? <SEP> because her family was very poor ",
  'label': 'causal relationship'}]
"""
```

## Training Details

### Dataset

Source: [FairytaleQAData](https://github.com/uci-soe/FairytaleQAData)

### Training Parameters

- Learning Rate: 2e-5
- Batch Size: 32
- Epochs: 3
- Optimizer: AdamW
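
For orientation, here is a minimal sketch of how these hyperparameters could be wired into a fine-tuning loop. It is not the repository's training script; `train_texts` and `train_labels` (integer label ids per `label2id` in `config.json`) are assumed to already be prepared from FairytaleQAData.

```python
import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader

# Assumed: train_texts is a list of question strings, train_labels a list of ints (0-6).
loader = DataLoader(list(zip(train_texts, train_labels)), batch_size=32, shuffle=True)

optimizer = AdamW(model.parameters(), lr=2e-5)
model.train()
for epoch in range(3):
    for texts, labels in loader:
        inputs = tokenizer(list(texts), padding=True, truncation=True,
                           max_length=512, return_tensors="pt")
        loss = model(**inputs, labels=labels)  # this forward returns the loss when labels are given
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```
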
### Performance Metrics

- Accuracy: 97.3%
- Recall: 96.59%
- F1 Score: 96.96%

## Notes

1. **Input Length**: The model supports a maximum input length of 512 tokens. Longer texts are truncated.
2. **Device Support**: The model supports both CPU and GPU inference. GPU is recommended for faster performance.
3. **Tokenizer**: Always use the matching tokenizer (via `AutoTokenizer`) with the model.

## Citation

If you use this model, please cite the following:

```
@misc{BertForStorySkillClassification,
  author       = {curious},
  title        = {BertForStorySkillClassification: A BERT-based Model for Story Question Classification},
  year         = {2025},
  publisher    = {Hugging Face},
  howpublished = {\url{https://huggingface.co/curious008/BertForStorySkillClassification}}
}
```

## License

This model is open-sourced under the Apache 2.0 License. For more details, see the [LICENSE](https://www.apache.org/licenses/LICENSE-2.0) file.

added_tokens.json
ADDED
@@ -0,0 +1,3 @@

{
  "<sep>": 30522
}
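
`<sep>` is a single extra token appended after BERT's base vocabulary (ids 0-30521), which is why `vocab_size` in `config.json` is 30523. Note that it is lowercase: with `do_lower_case: true` in `tokenizer_config.json`, a raw `<SEP>` marker is lowercased before lookup and so matches this token. A plausible sketch of how such a token gets registered (not taken from this repo):

```python
from transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
tokenizer.add_tokens(["<sep>"])                # appended after the base vocab, landing at id 30522
model = BertModel.from_pretrained("bert-base-uncased")
model.resize_token_embeddings(len(tokenizer))  # grows the embedding matrix to 30523 rows
```
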
config.json
ADDED
@@ -0,0 +1,48 @@

{
  "revision": "1.0",
  "_name_or_path": "google-bert/bert-base-uncased",
  "architectures": [
    "BertForStorySkillClassification"
  ],
  "auto_map": {
    "AutoModelForSequenceClassification": "modeling_bert_classifier.BertForStorySkillClassification"
  },
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "id2label": {
    "0": "character",
    "1": "setting",
    "2": "feeling",
    "3": "action",
    "4": "causal relationship",
    "5": "outcome resolution",
    "6": "prediction"
  },
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "label2id": {
    "action": 3,
    "causal relationship": 4,
    "character": 0,
    "feeling": 2,
    "outcome resolution": 5,
    "prediction": 6,
    "setting": 1
  },
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "torch_dtype": "float32",
  "transformers_version": "4.20.1",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30523
}
modeling_bert_classifier.py
ADDED
@@ -0,0 +1,84 @@

from typing import Dict, List, Union

import torch
import torch.nn as nn
from transformers import BertModel, BertPreTrainedModel, PreTrainedTokenizer


class BertForStorySkillClassification(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.bert = BertModel(config)
        self.classifier = nn.Linear(config.hidden_size, self.num_labels)
        self.post_init()

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        # last_hidden_state: [batch_size, seq_len, hidden_size]; take the [CLS] vector.
        cls_hidden_state = outputs.last_hidden_state[:, 0, :]
        logits = self.classifier(cls_hidden_state)  # [batch_size, num_labels]
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            return loss
        return logits

    def predict(
        self,
        texts: Union[str, List[str]],
        tokenizer: PreTrainedTokenizer,
        batch_size: int = 32,
        return_probabilities: bool = False,
        device: Union[str, torch.device, None] = None,
    ) -> List[Dict]:
        """
        Run classification over the input text(s).

        Args:
            texts: A single text or a list of texts, e.g. "Who is the character in the story?"
                or ["question 1", "question 2"].
            tokenizer: A tokenizer instance (must match the model).
            batch_size: Batch size (improves inference speed).
            return_probabilities: Whether to also return the confidence score (labels only by default).
            device: Target device (e.g. "cuda" or "cpu"); defaults to the model's current device.

        Returns:
            A list of predictions, formatted as:
            [{"text": input text, "label": predicted label, "score": confidence}, ...]
        """
        # Fall back to the device the model is already on.
        if device is None:
            device = self.device

        # Normalize the input to a list.
        if isinstance(texts, str):
            texts = [texts]

        predictions = []

        # Batched inference.
        with torch.no_grad():
            for i in range(0, len(texts), batch_size):
                batch_texts = texts[i : i + batch_size]

                # Tokenize and convert to tensors.
                inputs = tokenizer(
                    batch_texts,
                    padding=True,
                    truncation=True,
                    return_tensors="pt",
                    max_length=512,  # BERT's maximum input length
                ).to(device)

                # Model inference.
                logits = self(**inputs)
                probs = torch.softmax(logits, dim=-1)
                scores, class_ids = torch.max(probs, dim=-1)

                # Map class ids to labels (and scores, if requested).
                for text, class_id, score in zip(batch_texts, class_ids, scores):
                    label = self.config.id2label[class_id.item()]
                    result = {"text": text, "label": label}
                    if return_probabilities:
                        result["score"] = score.item()
                    predictions.append(result)

        return predictions
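
Note that this custom `forward` returns raw logits (or a bare loss when labels are given), not a `SequenceClassifierOutput`. If you want to bypass `predict`, here is a minimal manual decoding sketch, assuming `model` and `tokenizer` were loaded as in the README:

```python
import torch

inputs = tokenizer("Who is the main character in the story?", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs)               # raw logits, shape [1, 7]
probs = logits.softmax(dim=-1)
pred_id = probs.argmax(dim=-1).item()
print(model.config.id2label[pred_id], probs.max().item())  # e.g. "character" plus a confidence
```
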
pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@

version https://git-lfs.github.com/spec/v1
oid sha256:4a93fb540b17ffb5e5f40dae50d13001f35fdc30f25e1029b570209349b00d0d
size 438021741
special_tokens_map.json
ADDED
@@ -0,0 +1,7 @@

{
  "cls_token": "[CLS]",
  "mask_token": "[MASK]",
  "pad_token": "[PAD]",
  "sep_token": "[SEP]",
  "unk_token": "[UNK]"
}
tokenizer_config.json
ADDED
@@ -0,0 +1,15 @@

{
  "cls_token": "[CLS]",
  "do_basic_tokenize": true,
  "do_lower_case": true,
  "mask_token": "[MASK]",
  "name_or_path": "/remote-home/CS_IMIAPD_chensong22/python/weights/google-bert/bert-base-uncased",
  "never_split": null,
  "pad_token": "[PAD]",
  "sep_token": "[SEP]",
  "special_tokens_map_file": null,
  "strip_accents": null,
  "tokenize_chinese_chars": true,
  "tokenizer_class": "BertTokenizer",
  "unk_token": "[UNK]"
}
vocab.txt
ADDED
The diff for this file is too large to render.