---
license: apache-2.0
datasets:
- WorkInTheDark/FairytaleQA
language:
- en
metrics:
- f1
- accuracy
- recall
base_model:
- google-bert/bert-base-uncased
pipeline_tag: text-classification
library_name: transformers
---

# BertForStorySkillClassification

## Model Overview

`BertForStorySkillClassification` is a BERT-based text classification model that categorizes story-related questions into one of the following 7 classes:

1. **Character**
2. **Setting**
3. **Feeling**
4. **Action**
5. **Causal Relationship**
6. **Outcome Resolution**
7. **Prediction**

This model is suitable for applications in education, literary analysis, and story comprehension.

---

## Model Architecture

- **Base Model**: `bert-base-uncased`
- **Classification Layer**: A fully connected layer on top of BERT for 7-class classification.
- **Input**: A question (e.g., "Who is the main character in the story?"), a QA pair (e.g., "why could n't alice get a doll as a child ? \ because her family was very poor "), or a QA pair plus story context (e.g., "why could n't alice get a doll as a child ? \ because her family was very poor \ alice is ... ").
- **Output**: Predicted label and confidence score.
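The three input formats above differ only in how much text is concatenated before tokenization. A minimal helper for building such inputs might look like the sketch below; `build_input` is a hypothetical name, and the `" \ "` separator is inferred from the examples above:

```python
def build_input(question, answer=None, context=None, sep=" \\ "):
    """Join a question, optional answer, and optional story context
    into a single input string for the classifier."""
    parts = [question]
    if answer is not None:
        parts.append(answer)
    if context is not None:
        parts.append(context)
    return sep.join(parts)

# Question-only, QA-pair, and QA-pair-plus-context inputs:
print(build_input("who is the main character in the story ?"))
print(build_input("why could n't alice get a doll as a child ?",
                  "because her family was very poor"))
```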
---

## Quick Start

### Install Dependencies

Ensure you have the `transformers` library installed:

```bash
pip install transformers
```

### Load Model and Tokenizer

```python
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model = AutoModelForSequenceClassification.from_pretrained("curious008/BertForStorySkillClassification")
tokenizer = AutoTokenizer.from_pretrained("curious008/BertForStorySkillClassification")
```

### Use the `predict` Method for Inference

```python
# Single-text prediction
result = model.predict(
    texts="Where does this story take place?",
    tokenizer=tokenizer,
    return_probabilities=True
)
print(result)
# Output: [{'text': 'Where does this story take place?', 'label': 'setting', 'score': 0.93178}]

# Batch prediction
results = model.predict(
    texts=[
        "Why is the character sad?",
        "How does the story end?",
        "why could n't alice get a doll as a child ? because her family was very poor ",
    ],
    tokenizer=tokenizer,
    batch_size=16,
    device="cuda"
)
print(results)
"""
Output:
[{'text': 'Why is the character sad?', 'label': 'causal relationship'},
 {'text': 'How does the story end?', 'label': 'action'},
 {'text': "why could n't alice get a doll as a child ? because her family was very poor ", 'label': 'causal relationship'}]
"""
```

## Training Details

### Dataset

Source: [FairytaleQAData](https://github.com/uci-soe/FairytaleQAData)

### Training Parameters

- Learning Rate: 2e-5
- Batch Size: 32
- Epochs: 3
- Optimizer: AdamW

### Performance Metrics

- Accuracy: 97.3%
- Recall: 96.59%
- F1 Score: 96.96%

## Notes

1. **Input Length**: The model supports a maximum input length of 512 tokens. Longer texts will be truncated.
2. **Device Support**: The model supports both CPU and GPU inference. GPU is recommended for faster performance.
3. **Tokenizer**: Always use the matching tokenizer (`AutoTokenizer`) for the model.
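If the custom `predict` helper is not available in your setup, raw logits from a standard forward pass can be converted into the same label/score format. The following is a minimal sketch; the label order in `ID2LABEL` is an assumption for illustration — the authoritative mapping lives in the model's `config.id2label`:

```python
import torch
import torch.nn.functional as F

# Assumed label order for illustration only; check config.id2label for the real mapping.
ID2LABEL = {0: "character", 1: "setting", 2: "feeling", 3: "action",
            4: "causal relationship", 5: "outcome resolution", 6: "prediction"}

def logits_to_predictions(logits: torch.Tensor, id2label=ID2LABEL):
    """Convert a (batch, num_labels) logits tensor into labels with confidence scores."""
    probs = F.softmax(logits, dim=-1)
    scores, ids = probs.max(dim=-1)
    return [{"label": id2label[i.item()], "score": round(s.item(), 5)}
            for i, s in zip(ids, scores)]

# Dummy logits standing in for model(**inputs).logits; index 1 ("setting") dominates.
dummy = torch.tensor([[0.1, 6.0, 0.2, 0.1, 0.1, 0.1, 0.1]])
print(logits_to_predictions(dummy))
```

In a real run, `logits` would come from `model(**tokenizer(texts, return_tensors="pt", truncation=True, padding=True)).logits`.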
## Citation

If you use this model, please cite the following:

```bibtex
@misc{BertForStorySkillClassification,
  author       = {curious},
  title        = {BertForStorySkillClassification: A BERT-based Model for Story Question Classification},
  year         = {2025},
  publisher    = {Hugging Face},
  howpublished = {\url{https://huggingface.co/curious008/BertForStorySkillClassification}}
}
```

## License

This model is open-sourced under the Apache 2.0 License. For more details, see the [LICENSE](https://www.apache.org/licenses/LICENSE-2.0) file.