# Retroactive Reasoning Network (RRN) for Question Answering
## Model Description
This model implements a Retroactive Reasoning Network (RRN) for question answering. The RRN architecture performs multi-step reasoning through an iterative refinement process that retroactively updates hidden states.
### Key Features
- **Multi-step Reasoning**: The model performs 3 reasoning steps to iteratively refine its predictions.
- **Dynamic Reasoning Steps**: Enabled. A learned step controller determines the number of reasoning steps (min: 1, max: 5).
- **Gating Mechanism**: Selectively applies updates to hidden states.
- **Delta Magnitude Constraint**: Prevents destabilizing updates with a target ratio of 0.2.
- **Active Memory**: Stores and retrieves examples to enhance reasoning.
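To make the gating and delta magnitude features more concrete, here is a minimal sketch in plain Python. It assumes the constraint is a hard rescaling that keeps the update norm at or below 0.2 times the hidden-state norm, and that the gate is an element-wise weight in [0, 1]. The actual model may enforce the ratio differently (for example through a loss term), and the function names `constrain_delta` and `gated_update` are hypothetical.

```python
import math


def l2_norm(vec):
    """Euclidean norm of a list of floats."""
    return math.sqrt(sum(x * x for x in vec))


def constrain_delta(hidden, delta, target_ratio=0.2):
    """Rescale delta so that ||delta|| <= target_ratio * ||hidden||.

    Assumption: a hard clip on the norm ratio; the real model may differ.
    """
    limit = target_ratio * l2_norm(hidden)
    d_norm = l2_norm(delta)
    if d_norm == 0.0 or d_norm <= limit:
        return list(delta)
    scale = limit / d_norm
    return [scale * d for d in delta]


def gated_update(hidden, delta, gate):
    """Element-wise update h + g * delta, with gate values in [0, 1]."""
    return [h + g * d for h, g, d in zip(hidden, gate, delta)]


hidden = [3.0, 4.0]   # norm 5.0
delta = [6.0, 8.0]    # norm 10.0, far above 0.2 * 5.0
gate = [1.0, 0.0]     # update the first dimension only

small = constrain_delta(hidden, delta)        # rescaled to norm 1.0
updated = gated_update(hidden, small, gate)   # second dimension is untouched
```

With a gate of 0 the corresponding dimension keeps its previous value, which is what "selectively applies updates" amounts to in this sketch.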
## Usage
```python
import os

import torch
from transformers import AutoTokenizer

from model import EnhancedRRN_QA_Model  # the repository's model module must be importable

# torch.load reads from the filesystem, so this path must exist locally
# (for example, a clone of the repository)
model_dir = "will4381/rrn-qa"

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = EnhancedRRN_QA_Model(model_dir)

# Load custom components
model.qa_head.load_state_dict(torch.load(os.path.join(model_dir, "qa_head.pth")))
model.retroactive_update_layer.load_state_dict(torch.load(os.path.join(model_dir, "retroactive_layer.pth")))
model.gating_mechanism.load_state_dict(torch.load(os.path.join(model_dir, "gating_mechanism.pth")))

# If using learned dynamic steps
step_controller_path = os.path.join(model_dir, "step_controller.pth")
if os.path.exists(step_controller_path) and hasattr(model, "step_controller"):
    model.step_controller.load_state_dict(torch.load(step_controller_path))

# Example usage: question first, then the context passage
inputs = tokenizer("What is the capital of France?", "Paris is the capital of France.", return_tensors="pt")
outputs = model(**inputs)
```
```
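This README does not document what `outputs` contains. If the model returns per-token start and end logits, as extractive QA heads commonly do (an assumption here), the answer span could be picked with a search like the sketch below. The helper `best_span` and the toy logits are hypothetical and only illustrate the decoding step.

```python
def best_span(start_logits, end_logits, max_answer_len=30):
    """Return (start, end) token indices maximizing start + end score.

    Only spans with start <= end and at most max_answer_len tokens are considered.
    """
    best = (0, 0)
    best_score = float("-inf")
    for i, start_score in enumerate(start_logits):
        last = min(i + max_answer_len, len(end_logits))
        for j in range(i, last):
            score = start_score + end_logits[j]
            if score > best_score:
                best_score = score
                best = (i, j)
    return best


# Toy logits for a 4-token input; real logits would come from the model,
# e.g. converted to Python lists with .tolist() (assumed output format).
start_logits = [0.1, 0.2, 5.0, 0.3]
end_logits = [0.0, 0.1, 4.0, 1.0]
span = best_span(start_logits, end_logits)
```

The selected indices can then be mapped back to text by decoding the corresponding slice of `inputs["input_ids"]` with the tokenizer.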
## Training
This model was trained on the SQuAD dataset using a multi-step reasoning approach. The training code is included in the `code` directory of this repository.
To train your own model:
```bash
python code/train.py
```
To evaluate the model:
```bash
python code/test_model.py
```
## Model Architecture
The RRN architecture consists of:
1. A base language model (BERT)
2. A retroactive update layer that computes delta updates
3. A gating mechanism for selective updates
4. An enhanced QA head for answer prediction
5. A step controller for dynamic reasoning steps (if enabled)
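The way these components interact can be sketched as a refinement loop. The version below is a simplified illustration in plain Python, not the repository's implementation: `compute_delta`, `compute_gate`, and `should_halt` are hypothetical stand-ins for the retroactive update layer, the gating mechanism, and the step controller, and the hidden state is a flat list of floats rather than a tensor.

```python
def rrn_refine(hidden, compute_delta, compute_gate, should_halt,
               min_steps=1, max_steps=5):
    """Iteratively refine `hidden`, stopping when the controller says so.

    Returns the refined hidden state and the number of steps taken.
    """
    steps_taken = 0
    for step in range(max_steps):
        delta = compute_delta(hidden)        # proposed update
        gate = compute_gate(hidden, delta)   # per-dimension weight in [0, 1]
        hidden = [h + g * d for h, g, d in zip(hidden, gate, delta)]
        steps_taken = step + 1
        # The controller may only stop the loop once min_steps is reached.
        if steps_taken >= min_steps and should_halt(hidden, steps_taken):
            break
    return hidden, steps_taken


# Toy components: pull each value toward 1.0, halve every update,
# and halt once all values are within 0.2 of the target.
toward_one = lambda h: [1.0 - x for x in h]
half_gate = lambda h, d: [0.5] * len(h)
close_enough = lambda h, n: max(abs(1.0 - x) for x in h) < 0.2

final, steps = rrn_refine([0.0], toward_one, half_gate, close_enough)
```

In this toy run the loop halts after three steps. A controller that never signals a halt would simply run until `max_steps`.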
## Evaluation Results
- **Exact Match**: 78.80
- **F1**: 86.94
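For readers unfamiliar with these two metrics, the sketch below shows how exact match and token-level F1 are conventionally computed for SQuAD-style answers, per prediction and before averaging and scaling to a percentage. It is an illustration of the standard definitions and is assumed, not confirmed, to match what `code/test_model.py` does.

```python
import re
import string
from collections import Counter


def normalize(text):
    """Lowercase, strip punctuation and articles, collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match(prediction, truth):
    """1.0 if the normalized strings are identical, else 0.0."""
    return float(normalize(prediction) == normalize(truth))


def f1_score(prediction, truth):
    """Harmonic mean of token precision and recall after normalization."""
    pred_tokens = normalize(prediction).split()
    true_tokens = normalize(truth).split()
    overlap = sum((Counter(pred_tokens) & Counter(true_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(true_tokens)
    return 2 * precision * recall / (precision + recall)


em = exact_match("The Paris.", "paris")
f1 = f1_score("city of Paris", "Paris")
```

Exact match rewards only fully correct spans, while F1 gives partial credit for overlapping tokens, which is why the F1 figure is typically the higher of the two.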