---
license: apache-2.0
datasets:
- WorkInTheDark/FairytaleQA
language:
- en
metrics:
- f1
- accuracy
- recall
base_model:
- google-bert/bert-base-uncased
pipeline_tag: text-classification
library_name: transformers
---
# BertForStorySkillClassification

## Model Overview
`BertForStorySkillClassification` is a BERT-based text classification model designed to categorize story-related questions into one of the following 7 classes:
1. **Character**
2. **Setting**
3. **Feeling**
4. **Action**
5. **Causal Relationship**
6. **Outcome Resolution**
7. **Prediction**

This model is suitable for applications in education, literary analysis, and story comprehension.

---

## Model Architecture
- **Base Model**: `bert-base-uncased`
- **Classification Layer**: A fully connected layer on top of BERT for 7-class classification.
- **Input**: One of three text formats:
  - Question text (e.g., "Who is the main character in the story?")
  - QA text (e.g., "why could n't alice get a doll as a child ? \<SEP> because her family was very poor ")
  - QA pair + context (e.g., "why could n't alice get a doll as a child ? \<SEP> because her family was very poor \<context> alice is ... ")
- **Output**: Predicted label and confidence score.
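
The three input formats can be assembled with a small helper. This is a minimal sketch, assuming the `<SEP>` and `<context>` markers are plain text joined to the fields by single spaces, as the examples above suggest. `build_input` is a hypothetical name and is not part of the model's API.

```python
from typing import Optional

# Markers copied from the input examples in this card.
SEP_MARKER = "<SEP>"
CONTEXT_MARKER = "<context>"


def build_input(
    question: str,
    answer: Optional[str] = None,
    context: Optional[str] = None,
) -> str:
    """Join the available fields into one input string (hypothetical helper)."""
    parts = [question.strip()]
    if answer is not None:
        parts += [SEP_MARKER, answer.strip()]
    if context is not None:
        # The card only shows context attached to a full QA pair.
        if answer is None:
            raise ValueError("context is only shown together with a QA pair")
        parts += [CONTEXT_MARKER, context.strip()]
    return " ".join(parts)


question = "why could n't alice get a doll as a child ?"
answer = "because her family was very poor"
print(build_input(question))
print(build_input(question, answer))
print(build_input(question, answer, "alice is ..."))
```

Rejecting a context without an answer is a design choice of this sketch, made only because the card shows no question + context format.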

---

## Quick Start

### Install Dependencies
Ensure you have the `transformers` library installed:
```bash
pip install transformers
```

### Load Model and Tokenizer

```python
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model = AutoModelForSequenceClassification.from_pretrained("curious008/BertForStorySkillClassification")
tokenizer = AutoTokenizer.from_pretrained("curious008/BertForStorySkillClassification")
```

### Use the predict Method for Inference

```python
# Single text prediction
result = model.predict(
    texts="Where does this story take place?",
    tokenizer=tokenizer,
    return_probabilities=True
)
print(result)
# Output: [{'text': 'Where does this story take place?', 'label': 'setting', 'score': 0.93178}]

# Batch prediction
results = model.predict(
    texts=["Why is the character sad?", "How does the story end?","why could n't alice get a doll as a child ? <SEP> because her family was very poor "],
    tokenizer=tokenizer,
    batch_size=16,
    device="cuda"
)
print(results)
"""
output:
[{'text': 'Why is the character sad?', 'label': 'causal relationship'},
 {'text': 'How does the story end?', 'label': 'action'},
 {'text': "why could n't alice get a doll as a child ? <SEP> because her family was very poor ",
  'label': 'causal relationship'}]
"""
```
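
The `label` and `score` fields in the output above can be related to raw classifier logits roughly as follows. This is a sketch under two assumptions that the card does not confirm: that the score is the softmax probability of the top class, and that the label order matches the class list in the overview. The `LABELS` list and the `to_prediction` helper are hypothetical; the real mapping is defined by the model's configuration.

```python
import math
from typing import Dict, List, Union

# Assumed label order (taken from the class list in this card, lowercased).
LABELS = [
    "character",
    "setting",
    "feeling",
    "action",
    "causal relationship",
    "outcome resolution",
    "prediction",
]


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax over one row of logits."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def to_prediction(text: str, logits: List[float]) -> Dict[str, Union[str, float]]:
    """Pick the most probable class and report its probability as the score."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return {"text": text, "label": LABELS[best], "score": probs[best]}


# Made-up logits for illustration only; they are not real model output.
print(to_prediction("Where does this story take place?", [0.1, 4.0, 0.2, 0.0, 0.3, 0.1, 0.2]))
```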

## Training Details
### Dataset
Source: [FairytaleQAData](https://github.com/uci-soe/FairytaleQAData)

### Training Parameters
- **Learning Rate**: 2e-5
- **Batch Size**: 32
- **Epochs**: 3
- **Optimizer**: AdamW

### Performance Metrics
Accuracy: 97.3%

Recall: 96.59%

F1 Score: 96.96%
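
For readers who want to compute the same kinds of metrics on their own predictions, the following is a minimal sketch. It assumes macro averaging over classes for recall and F1; the card does not state which averaging was used for the numbers above, so results may not be directly comparable. `classification_metrics` is a hypothetical helper.

```python
from typing import Dict, Sequence


def classification_metrics(y_true: Sequence[str], y_pred: Sequence[str]) -> Dict[str, float]:
    """Accuracy plus macro-averaged recall and F1 (averaging scheme is an assumption)."""
    if len(y_true) != len(y_pred) or not y_true:
        raise ValueError("y_true and y_pred must be non-empty and equally long")

    pairs = list(zip(y_true, y_pred))
    accuracy = sum(t == p for t, p in pairs) / len(pairs)

    recalls, f1s = [], []
    for label in sorted(set(y_true) | set(y_pred)):
        tp = sum(t == label and p == label for t, p in pairs)
        fp = sum(t != label and p == label for t, p in pairs)
        fn = sum(t == label and p != label for t, p in pairs)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        recalls.append(recall)
        f1s.append(f1)

    return {
        "accuracy": accuracy,
        "recall": sum(recalls) / len(recalls),
        "f1": sum(f1s) / len(f1s),
    }


# Toy labels for illustration only; these are not the model's evaluation data.
gold = ["action", "action", "setting", "setting"]
pred = ["action", "setting", "setting", "setting"]
print(classification_metrics(gold, pred))
```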

## Notes
1. **Input Length**: The model supports a maximum input length of 512 tokens. Longer texts will be truncated.
2. **Device Support**: The model supports both CPU and GPU inference. GPU is recommended for faster performance.
3. **Tokenizer**: Always use the matching tokenizer (`AutoTokenizer`) for the model.

## Citation

If you use this model, please cite the following:

```
@misc{BertForStorySkillClassification,
  author = {curious},
  title = {BertForStorySkillClassification: A BERT-based Model for Story Question Classification},
  year = {2025},
  publisher = {Hugging Face},
  howpublished = {\url{https://huggingface.co/curious008/BertForStorySkillClassification}}
}
```

## License
This model is open-sourced under the Apache 2.0 License. For more details, see the [Apache License 2.0](https://www.apache.org/licenses/LICENSE-2.0) text.