
Retroactive Reasoning Network (RRN) for Question Answering

Model Description

This model implements a Retroactive Reasoning Network (RRN) for question answering tasks. The RRN architecture enables multi-step reasoning through an iterative refinement process in which each step retroactively updates the encoder's hidden states.
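The README does not spell out the update rule. As an illustration only, one plausible formalization of "retroactively updating hidden states" is a gated residual update; this form is assumed here, not taken from model.py. Let H^(0) be the BERT hidden states and K the number of reasoning steps. f_update and f_gate are hypothetical names standing in for the retroactive update layer and the gating mechanism:

```latex
\begin{aligned}
\Delta^{(k)} &= f_{\mathrm{update}}\big(H^{(k)}\big), \\
G^{(k)} &= \sigma\big(f_{\mathrm{gate}}(H^{(k)}, \Delta^{(k)})\big), \\
H^{(k+1)} &= H^{(k)} + G^{(k)} \odot \Delta^{(k)}, \qquad k = 0, \dots, K-1.
\end{aligned}
```

Under this reading, a gate value near 0 leaves a hidden state unchanged, a value near 1 applies the full update, and the QA head reads the final states H^(K).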

Key Features

  • Multi-step Reasoning: The model performs 3 reasoning steps by default to iteratively refine its predictions.
  • Dynamic Reasoning Steps: Enabled. A learned step controller chooses the number of reasoning steps (min: 1, max: 5).
  • Gating Mechanism: Selectively applies updates to hidden states.
  • Delta Magnitude Constraint: Limits the magnitude of each hidden-state update (target ratio: 0.2) to prevent destabilizing updates.
  • Active Memory: Stores and retrieves examples to enhance reasoning.
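The gating mechanism and delta magnitude constraint listed above can be sketched in plain Python. This is a minimal sketch under two assumptions the README does not state: the 0.2 target ratio is read as the norm of the update relative to the norm of the hidden state, and it is enforced as a hard cap (the actual model may instead use a soft penalty during training). The function names are hypothetical, and the gate values would come from the learned gating mechanism rather than being passed in by hand.

```python
import math


def l2_norm(vec):
    """Euclidean norm of a list of floats."""
    return math.sqrt(sum(x * x for x in vec))


def constrain_delta(hidden, delta, target_ratio=0.2):
    """Rescale delta so that ||delta|| <= target_ratio * ||hidden||.

    Assumption: the target ratio compares the update norm to the
    hidden-state norm and is applied as a hard cap.
    """
    max_norm = target_ratio * l2_norm(hidden)
    norm = l2_norm(delta)
    if norm == 0.0 or norm <= max_norm:
        return list(delta)  # already small enough: leave unchanged
    scale = max_norm / norm
    return [d * scale for d in delta]


def gated_update(hidden, delta, gate):
    """Apply h + g * d elementwise; gate values are expected in [0, 1]."""
    return [h + g * d for h, d, g in zip(hidden, delta, gate)]


hidden = [3.0, 4.0]   # norm 5.0, so the update norm is capped at 1.0
delta = [0.0, 2.0]    # norm 2.0, rescaled by 0.5
constrained = constrain_delta(hidden, delta)
print(constrained)                                   # [0.0, 1.0]
print(gated_update(hidden, constrained, [1.0, 0.5]))  # [3.0, 4.5]
```

Capping the update relative to the current hidden state keeps each step's change proportional to the state it modifies, so repeated refinement steps cannot drift arbitrarily far from the encoder output.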

Usage

from transformers import AutoTokenizer
from huggingface_hub import snapshot_download
import os
import torch

# EnhancedRRN_QA_Model is defined in model.py from this repository,
# which must be on your Python path
from model import EnhancedRRN_QA_Model

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("will4381/rrn-qa")
model = EnhancedRRN_QA_Model("will4381/rrn-qa")

# Download the repository files: a Hub repo ID is not a local directory,
# so os.path.join("will4381/rrn-qa", ...) alone would not find the weights
repo_dir = snapshot_download("will4381/rrn-qa")

def load_weights(filename):
    # Load a saved state dict from the downloaded repository onto the CPU
    return torch.load(os.path.join(repo_dir, filename), map_location="cpu")

# Load custom components
model.qa_head.load_state_dict(load_weights("qa_head.pth"))
model.retroactive_update_layer.load_state_dict(load_weights("retroactive_layer.pth"))
model.gating_mechanism.load_state_dict(load_weights("gating_mechanism.pth"))

# If using learned dynamic steps
if os.path.exists(os.path.join(repo_dir, "step_controller.pth")) and hasattr(model, "step_controller"):
    model.step_controller.load_state_dict(load_weights("step_controller.pth"))

# Example usage (inference mode: no dropout, no gradient tracking)
model.eval()
inputs = tokenizer("What is the capital of France?", "Paris is the capital of France.", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
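The README does not document the structure of outputs. If, as is common for extractive QA heads, the model exposes per-token start and end logits (an assumption; the actual attribute names depend on model.py), the answer span can be recovered by picking the highest-scoring valid (start, end) pair. A pure-Python sketch, where best_span and max_answer_len are hypothetical names:

```python
def best_span(start_logits, end_logits, max_answer_len=30):
    """Return the (start, end) token indices maximizing
    start_logits[s] + end_logits[e], subject to s <= e < s + max_answer_len.
    Returns None when there are no tokens to choose from."""
    best = None
    best_score = float("-inf")
    for s, s_score in enumerate(start_logits):
        # Only consider end positions at or after the start, within the length cap
        for e in range(s, min(s + max_answer_len, len(end_logits))):
            score = s_score + end_logits[e]
            if score > best_score:
                best_score = score
                best = (s, e)
    return best


start = [0.1, 2.0, 0.3, -1.0]
end = [0.0, 0.5, 3.0, 1.0]
print(best_span(start, end))                    # (1, 2)
print(best_span(start, end, max_answer_len=1))  # (2, 2)
```

With the tokenizer from the example above, a span can be turned back into text with tokenizer.decode(inputs["input_ids"][0][start:end + 1]). A fuller implementation would also exclude question tokens and special tokens such as [CLS] from the search.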

Training

This model was trained on the SQuAD dataset using a multi-step reasoning approach. The training code is included in the code directory of this repository.

To train your own model:

python code/train.py

To evaluate the model:

python code/test_model.py

Model Architecture

The RRN architecture consists of:

  1. A base language model (BERT)
  2. A retroactive update layer that computes delta updates
  3. A gating mechanism for selective updates
  4. An enhanced QA head for answer prediction
  5. A step controller for dynamic reasoning steps (if enabled)
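The components above can be tied together using the dynamic step range from Key Features (min: 1, max: 5). The README does not say how the step controller decides when to stop, so the sketch below assumes a simple halting rule, similar in spirit to Adaptive Computation Time: after each refinement step the controller returns a halting score, and refinement stops once that score reaches a threshold. All function names and the 0.5 threshold are hypothetical stand-ins, and the QA head that would read the final hidden state is omitted.

```python
from typing import Callable, List, Tuple

Vector = List[float]


def rrn_forward(
    hidden: Vector,
    update_fn: Callable[[Vector], Vector],         # ~ retroactive update layer
    gate_fn: Callable[[Vector, Vector], Vector],   # ~ gating mechanism
    halt_fn: Callable[[Vector], float],            # ~ step controller
    min_steps: int = 1,
    max_steps: int = 5,
    halt_threshold: float = 0.5,
) -> Tuple[Vector, int]:
    """Refine `hidden` for between min_steps and max_steps reasoning steps.

    Returns the refined hidden state and the number of steps taken.
    """
    steps = 0
    while steps < max_steps:
        delta = update_fn(hidden)
        gate = gate_fn(hidden, delta)
        # Gated residual update: h <- h + g * delta (elementwise)
        hidden = [h + g * d for h, g, d in zip(hidden, gate, delta)]
        steps += 1
        # The controller may only stop refinement once the minimum is reached
        if steps >= min_steps and halt_fn(hidden) >= halt_threshold:
            break
    return hidden, steps


# Toy stand-ins for the learned components (hypothetical)
def toy_update(h: Vector) -> Vector:
    return [0.1 * x for x in h]


def toy_gate(h: Vector, delta: Vector) -> Vector:
    return [1.0] * len(h)


def toy_halt(h: Vector) -> float:
    return 1.0 if h[0] > 1.2 else 0.0


refined, steps = rrn_forward([1.0, 2.0], toy_update, toy_gate, toy_halt)
print(steps)  # 2
```

Letting the controller stop early saves computation on easy questions, while max_steps bounds the cost on hard ones.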

Evaluation Results

Exact Match (EM): 78.80
F1: 86.94
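Assuming these are the standard SQuAD metrics (the README does not say which split was evaluated), exact match and F1 are computed per question on normalized answer strings, taking the maximum over the reference answers, then averaged over all questions and scaled by 100. The sketch below reimplements the answer normalization used by the official SQuAD v1.1 evaluation script; it is not the repository's code/test_model.py, and squad_scores is a hypothetical helper name.

```python
import re
import string
from collections import Counter

PUNCTUATION = set(string.punctuation)


def normalize_answer(text):
    """Lowercase, drop punctuation and English articles, collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in PUNCTUATION)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match(prediction, ground_truth):
    return float(normalize_answer(prediction) == normalize_answer(ground_truth))


def f1_score(prediction, ground_truth):
    """Token-level F1 between normalized prediction and reference."""
    pred_tokens = normalize_answer(prediction).split()
    gold_tokens = normalize_answer(ground_truth).split()
    common = Counter(pred_tokens) & Counter(gold_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)


def max_over_ground_truths(metric_fn, prediction, ground_truths):
    """Score against every reference answer and keep the best match."""
    return max(metric_fn(prediction, gt) for gt in ground_truths)


def squad_scores(predictions, references):
    """predictions: list of strings; references: list of lists of gold answers."""
    n = len(predictions)
    em = sum(max_over_ground_truths(exact_match, p, refs)
             for p, refs in zip(predictions, references))
    f1 = sum(max_over_ground_truths(f1_score, p, refs)
             for p, refs in zip(predictions, references))
    return {"exact_match": 100.0 * em / n, "f1": 100.0 * f1 / n}


print(exact_match("The Paris!", "paris"))          # 1.0
print(f1_score("capital city Paris", "Paris"))     # 0.5
print(squad_scores(["Paris", "capital city Paris"],
                   [["paris"], ["Paris", "the capital"]]))
# {'exact_match': 50.0, 'f1': 75.0}
```

Because F1 gives partial credit for overlapping tokens while exact match does not, F1 is always at least as high as exact match on the same predictions.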