---
license: apache-2.0
datasets:
- WorkInTheDark/FairytaleQA
language:
- en
metrics:
- f1
- accuracy
- recall
base_model:
- google-bert/bert-base-uncased
pipeline_tag: text-classification
library_name: transformers
---
# BertForStorySkillClassification
## Model Overview
`BertForStorySkillClassification` is a BERT-based text classification model designed to categorize story-related questions into one of the following 7 classes:
1. **Character**
2. **Setting**
3. **Feeling**
4. **Action**
5. **Causal Relationship**
6. **Outcome Resolution**
7. **Prediction**
This model is suitable for applications in education, literary analysis, and story comprehension.
---
## Model Architecture
- **Base Model**: `bert-base-uncased`
- **Classification Layer**: A fully connected layer on top of BERT for 7-class classification.
- **Input**: A question (e.g., `"Who is the main character in the story?"`), a QA pair (e.g., `"why could n't alice get a doll as a child ? <SEP> because her family was very poor "`), or a QA pair plus context (e.g., `"why could n't alice get a doll as a child ? <SEP> because her family was very poor <context> alice is ... "`).
- **Output**: Predicted label and confidence score.
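The three input formats above can be composed with a small helper. This is a hypothetical sketch (the function name `build_input` is not part of the model's API); the `<SEP>` and `<context>` separators follow the examples in this card.

```python
# Hypothetical helper illustrating the three accepted input formats.
def build_input(question, answer=None, context=None):
    """Compose a model input string from a question and optional answer/context."""
    text = question
    if answer is not None:
        text += " <SEP> " + answer
    if context is not None:
        text += " <context> " + context
    return text

print(build_input("Who is the main character in the story?"))
print(build_input("why could n't alice get a doll as a child ?",
                  "because her family was very poor"))
```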
---
## Quick Start
### Install Dependencies
Ensure you have the `transformers` library installed:
```bash
pip install transformers
```
### Load Model and Tokenizer
```python
from transformers import AutoModelForSequenceClassification, AutoTokenizer
model = AutoModelForSequenceClassification.from_pretrained("curious008/BertForStorySkillClassification")
tokenizer = AutoTokenizer.from_pretrained("curious008/BertForStorySkillClassification")
```
### Use the `predict` Method for Inference
```python
# Single text prediction
# Single text prediction
result = model.predict(
    texts="Where does this story take place?",
    tokenizer=tokenizer,
    return_probabilities=True
)
print(result)
# Output: [{'text': 'Where does this story take place?', 'label': 'setting', 'score': 0.93178}]

# Batch prediction
results = model.predict(
    texts=[
        "Why is the character sad?",
        "How does the story end?",
        "why could n't alice get a doll as a child ? <SEP> because her family was very poor ",
    ],
    tokenizer=tokenizer,
    batch_size=16,
    device="cuda"
)
print(results)
# Output:
# [{'text': 'Why is the character sad?', 'label': 'causal relationship'},
#  {'text': 'How does the story end?', 'label': 'action'},
#  {'text': "why could n't alice get a doll as a child ? <SEP> because her family was very poor ",
#   'label': 'causal relationship'}]
```
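The `predict` helper above is specific to this checkpoint's model class. For illustration, the mapping from raw logits to a labeled prediction can be sketched as below; the logit values are dummy numbers, and the index-to-label order is an assumption based on the class list earlier in this card.

```python
import math

# The seven skill labels from this card (index order assumed for illustration).
LABELS = ["character", "setting", "feeling", "action",
          "causal relationship", "outcome resolution", "prediction"]

def logits_to_prediction(logits):
    """Softmax over raw logits, then pick the highest-scoring label."""
    m = max(logits)                               # subtract max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    idx = max(range(len(probs)), key=probs.__getitem__)
    return {"label": LABELS[idx], "score": probs[idx]}

# Dummy logits for illustration only (not real model output).
print(logits_to_prediction([0.1, 3.2, -0.5, 0.0, 0.3, -1.0, 0.2]))
```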
## Training Details
### Dataset
Source: [FairytaleQAData](https://github.com/uci-soe/FairytaleQAData)
### Training Parameters
- **Learning Rate**: 2e-5
- **Batch Size**: 32
- **Epochs**: 3
- **Optimizer**: AdamW
### Performance Metrics
- **Accuracy**: 97.3%
- **Recall**: 96.59%
- **F1 Score**: 96.96%
## Notes
1. **Input Length**: The model supports a maximum input length of 512 tokens. Longer texts will be truncated.
2. **Device Support**: The model supports both CPU and GPU inference; GPU is recommended for faster performance.
3. **Tokenizer**: Always use the matching tokenizer (loaded via `AutoTokenizer`) with this model.
## Citation
If you use this model, please cite the following:
```
@misc{BertForStorySkillClassification,
author = {curious},
title = {BertForStorySkillClassification: A BERT-based Model for Story Question Classification},
year = {2025},
publisher = {Hugging Face},
howpublished = {\url{https://huggingface.co/curious008/BertForStorySkillClassification}}
}
```
## License
This model is open-sourced under the Apache 2.0 License. For more details, see the [LICENSE](https://www.apache.org/licenses/LICENSE-2.0) file.