# Retroactive Reasoning Network (RRN) for Question Answering

## Model Description

This model implements a Retroactive Reasoning Network (RRN) for question answering. The RRN architecture performs multi-step reasoning through an iterative refinement process that retroactively updates the hidden states.
### Key Features

- **Multi-step Reasoning**: The model performs 3 reasoning steps to iteratively refine its predictions.
- **Dynamic Reasoning Steps**: Enabled. A learned step controller chooses the number of steps (min: 1, max: 5).
- **Gating Mechanism**: Selectively applies updates to hidden states.
- **Delta Magnitude Constraint**: Prevents destabilizing updates with a target ratio of 0.2.
- **Active Memory**: Stores and retrieves examples to enhance reasoning.
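
To make the gating mechanism and the delta magnitude constraint more concrete, here is a minimal, framework-free sketch. It rests on assumptions that this README does not confirm: that the "target ratio" is the norm of the update divided by the norm of the hidden state, that the constraint is enforced by rescaling (the actual model may use a loss penalty instead), and that gates are element-wise sigmoids. The names `constrain_delta` and `gated_update` are hypothetical and do not come from the repository.

```python
import math


def l2_norm(vec):
    """Euclidean norm of a list of floats."""
    return math.sqrt(sum(x * x for x in vec))


def constrain_delta(hidden, delta, target_ratio=0.2):
    """Rescale delta so that ||delta|| <= target_ratio * ||hidden||.

    Assumption: the ratio compares update norm to hidden-state norm.
    """
    limit = target_ratio * l2_norm(hidden)
    d_norm = l2_norm(delta)
    if d_norm == 0.0 or d_norm <= limit:
        return list(delta)  # already small enough, leave unchanged
    scale = limit / d_norm
    return [scale * d for d in delta]


def gated_update(hidden, delta, gate_logits, target_ratio=0.2):
    """Apply a constrained delta through element-wise sigmoid gates."""
    delta = constrain_delta(hidden, delta, target_ratio)
    gates = [1.0 / (1.0 + math.exp(-g)) for g in gate_logits]
    # A gate near 0 keeps the old value; a gate near 1 applies the full delta.
    return [h + g * d for h, g, d in zip(hidden, gates, delta)]


hidden = [3.0, 4.0]   # norm 5.0
delta = [6.0, 8.0]    # norm 10.0, far above 0.2 * 5.0 = 1.0
updated = gated_update(hidden, delta, gate_logits=[0.0, 0.0])
```

In this toy case the delta is scaled down to norm 1.0 before the half-open gates (sigmoid of 0 is 0.5) apply it, so the hidden state moves only slightly. In the real model these operations would run on tensors rather than Python lists.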
## Usage

```python
import os

import torch
from transformers import AutoTokenizer

from model import EnhancedRRN_QA_Model

# torch.load reads from the local filesystem, so MODEL_DIR must be a local
# directory that contains the .pth files listed below.
MODEL_DIR = "will4381/rrn-qa"

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
model = EnhancedRRN_QA_Model(MODEL_DIR)

# Load custom components
model.qa_head.load_state_dict(torch.load(os.path.join(MODEL_DIR, "qa_head.pth")))
model.retroactive_update_layer.load_state_dict(torch.load(os.path.join(MODEL_DIR, "retroactive_layer.pth")))
model.gating_mechanism.load_state_dict(torch.load(os.path.join(MODEL_DIR, "gating_mechanism.pth")))

# If using learned dynamic steps, load the step controller when it is available
step_controller_path = os.path.join(MODEL_DIR, "step_controller.pth")
if os.path.exists(step_controller_path) and hasattr(model, "step_controller"):
    model.step_controller.load_state_dict(torch.load(step_controller_path))

# Example usage: pass the question first, then the context passage
inputs = tokenizer("What is the capital of France?", "Paris is the capital of France.", return_tensors="pt")
outputs = model(**inputs)
```
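
The snippet above stops at `outputs`, and this README does not document what the model returns. As a sketch only, assuming the model yields per-token start and end scores (as extractive QA heads commonly do), an answer span could be chosen as below. The function `best_span` and the `max_answer_len` parameter are hypothetical, and the logits are plain Python lists for illustration.

```python
def best_span(start_logits, end_logits, max_answer_len=30):
    """Return (start, end) token indices maximizing start + end score.

    Only spans with start <= end and at most max_answer_len tokens are considered.
    """
    best_score = float("-inf")
    best = (0, 0)
    for start, start_score in enumerate(start_logits):
        last = min(start + max_answer_len, len(end_logits))
        for end in range(start, last):
            score = start_score + end_logits[end]
            if score > best_score:
                best_score = score
                best = (start, end)
    return best


# Toy scores for a 4-token input (made-up numbers, not model output)
start_logits = [0.1, 2.0, 0.3, 0.0]
end_logits = [0.0, 0.5, 3.0, 0.1]
span = best_span(start_logits, end_logits)
```

The selected token indices would then be mapped back to text, for example by decoding the matching slice of `inputs["input_ids"]` with the tokenizer.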
## Training

This model was trained on the SQuAD dataset using a multi-step reasoning approach. The training code is included in the `code` directory of this repository.

To train your own model:

```bash
python code/train.py
```

To evaluate the model:

```bash
python code/test_model.py
```
## Model Architecture

The RRN architecture consists of:

1. A base language model (BERT)
2. A retroactive update layer that computes delta updates
3. A gating mechanism for selective updates
4. An enhanced QA head for answer prediction
5. A step controller for dynamic reasoning steps (if enabled)
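
How these components could interact across reasoning steps is sketched below. This is an illustration under assumptions, not the repository's implementation: it supposes that the step controller emits a halting score after each update and that reasoning stops once that score crosses a threshold, within the minimum and maximum step counts. The names `run_reasoning`, `update_fn`, `halt_fn`, and `threshold` are all hypothetical.

```python
def run_reasoning(hidden, update_fn, halt_fn, min_steps=1, max_steps=5, threshold=0.5):
    """Iteratively refine `hidden`, stopping early when the controller says so.

    update_fn(hidden) -> new hidden state (stands in for delta + gating)
    halt_fn(hidden, step) -> halting score in [0, 1] (stands in for the controller)
    Returns the final hidden state and the number of steps taken.
    """
    steps = 0
    while steps < max_steps:
        hidden = update_fn(hidden)
        steps += 1
        # Never stop before min_steps; always stop at max_steps.
        if steps >= min_steps and halt_fn(hidden, steps) >= threshold:
            break
    return hidden, steps


# Toy stand-ins: each update halves the state; halt once it is small enough.
def halve(hidden):
    return [0.5 * x for x in hidden]


def halt_when_small(hidden, step):
    return 1.0 if max(abs(x) for x in hidden) < 0.2 else 0.0


final_hidden, steps_taken = run_reasoning([1.0], halve, halt_when_small)
```

With a fixed step count, the same loop applies by passing a `halt_fn` that always returns 0.0 and setting `max_steps` to the desired number of steps.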
## Evaluation Results

- **Exact Match**: 78.80
- **F1**: 86.94
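
For readers unfamiliar with these two numbers, the sketch below shows how SQuAD-style exact match and token-level F1 are usually computed for a single prediction. It assumes the scores above follow that convention, which the metric names and the SQuAD training data suggest but this README does not state. The code is not taken from `code/test_model.py`.

```python
import re
import string
from collections import Counter


def normalize_answer(text):
    """Lowercase, strip punctuation and English articles, collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match(prediction, reference):
    """1.0 if the normalized strings are identical, else 0.0."""
    return float(normalize_answer(prediction) == normalize_answer(reference))


def f1_score(prediction, reference):
    """Harmonic mean of token precision and recall after normalization."""
    pred_tokens = normalize_answer(prediction).split()
    ref_tokens = normalize_answer(reference).split()
    if not pred_tokens or not ref_tokens:
        # Both empty counts as a match; exactly one empty counts as a miss.
        return float(pred_tokens == ref_tokens)
    overlap = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)


em = exact_match("The Paris.", "paris")
f1 = f1_score("Paris France", "Paris")
```

Dataset-level scores are typically the average of these per-example values multiplied by 100, taking the best match when a question has several reference answers.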