will4381 commited on
Commit
904e3eb
·
verified ·
1 Parent(s): 3451ca0

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +69 -79
README.md CHANGED
@@ -1,79 +1,69 @@
1
- # Retroactive Reasoning Network (RRN) for Question Answering
2
-
3
- ## Model Description
4
-
5
- This model implements an Enhanced Retroactive Reasoning Network (RRN) for Question Answering tasks. The RRN architecture enables multi-step reasoning through an iterative refinement process that retroactively updates hidden states.
6
-
7
- ### Key Features
8
-
9
- - **Multi-step Reasoning**: The model performs 3 reasoning steps to iteratively refine its predictions.
10
- - **Dynamic Reasoning Steps**: Enabled - Uses a learned approach to determine the number of steps (min: 1, max: 5)
11
- - **Gating Mechanism**: Selectively applies updates to hidden states.
12
- - **Delta Magnitude Constraint**: Prevents destabilizing updates with a target ratio of 0.2.
13
- - **Active Memory**: Stores and retrieves examples to enhance reasoning.
14
-
15
- ## Usage
16
-
17
- ```python
18
- from transformers import AutoTokenizer
19
- from model import EnhancedRRN_QA_Model
20
-
21
- # Load tokenizer and model
22
- tokenizer = AutoTokenizer.from_pretrained("[MODEL_REPO_ID]")
23
- model = EnhancedRRN_QA_Model("[MODEL_REPO_ID]/base_model")
24
-
25
- # Load custom components
26
- import torch
27
- import os
28
-
29
- model.qa_head.load_state_dict(torch.load(os.path.join("[MODEL_REPO_ID]", "qa_head.pth")))
30
- model.retroactive_update_layer.load_state_dict(torch.load(os.path.join("[MODEL_REPO_ID]", "retroactive_layer.pth")))
31
- model.gating_mechanism.load_state_dict(torch.load(os.path.join("[MODEL_REPO_ID]", "gating_mechanism.pth")))
32
-
33
- # If using learned dynamic steps
34
- if os.path.exists(os.path.join("[MODEL_REPO_ID]", "step_controller.pth")) and hasattr(model, "step_controller"):
35
- model.step_controller.load_state_dict(torch.load(os.path.join("[MODEL_REPO_ID]", "step_controller.pth")))
36
-
37
- # Example usage
38
- inputs = tokenizer("What is the capital of France?", "Paris is the capital of France.", return_tensors="pt")
39
- outputs = model(**inputs)
40
- ```
41
-
42
- ## Training
43
-
44
- This model was trained on the SQuAD dataset using a multi-step reasoning approach. The training code is included in the `code` directory of this repository.
45
-
46
- To train your own model:
47
-
48
- ```bash
49
- python code/train.py
50
- ```
51
-
52
- To evaluate the model:
53
-
54
- ```bash
55
- python code/test_model.py
56
- ```
57
-
58
- ## Model Architecture
59
-
60
- The RRN architecture consists of:
61
-
62
- 1. A base language model (BERT)
63
- 2. A retroactive update layer that computes delta updates
64
- 3. A gating mechanism for selective updates
65
- 4. An enhanced QA head for answer prediction
66
- 5. A step controller for dynamic reasoning steps (if enabled)
67
-
68
- ## Citation
69
-
70
- If you use this model in your research, please cite:
71
-
72
- ```
73
- @article{rrn_qa_model,
74
- title={Retroactive Reasoning Networks for Question Answering},
75
- author={[Authors]},
76
- journal={[Journal]},
77
- year={2025}
78
- }
79
- ```
 
1
+ # Retroactive Reasoning Network (RRN) for Question Answering
2
+
3
+ ## Model Description
4
+
5
+ This model implements an Retroactive Reasoning Network (RRN) for Question Answering tasks. The RRN architecture enables multi-step reasoning through an iterative refinement process that retroactively updates hidden states.
6
+
7
+ ### Key Features
8
+
9
+ - **Multi-step Reasoning**: The model performs 3 reasoning steps to iteratively refine its predictions.
10
+ - **Dynamic Reasoning Steps**: Enabled - Uses a learned approach to determine the number of steps (min: 1, max: 5)
11
+ - **Gating Mechanism**: Selectively applies updates to hidden states.
12
+ - **Delta Magnitude Constraint**: Prevents destabilizing updates with a target ratio of 0.2.
13
+ - **Active Memory**: Stores and retrieves examples to enhance reasoning.
14
+
15
+ ## Usage
16
+
17
+ ```python
18
+ from transformers import AutoTokenizer
19
+ from model import EnhancedRRN_QA_Model
20
+
21
+ # Load tokenizer and model
22
+ tokenizer = AutoTokenizer.from_pretrained("will4381/rrn-qa")
23
+ model = EnhancedRRN_QA_Model("will4381/rrn-qa")
24
+
25
+ # Load custom components
26
+ import torch
27
+ import os
28
+
29
+ model.qa_head.load_state_dict(torch.load(os.path.join("will4381/rrn-qa", "qa_head.pth")))
30
+ model.retroactive_update_layer.load_state_dict(torch.load(os.path.join("will4381/rrn-qa", "retroactive_layer.pth")))
31
+ model.gating_mechanism.load_state_dict(torch.load(os.path.join("will4381/rrn-qa]", "gating_mechanism.pth")))
32
+
33
+ # If using learned dynamic steps
34
+ if os.path.exists(os.path.join("will4381/rrn-qa", "step_controller.pth")) and hasattr(model, "step_controller"):
35
+ model.step_controller.load_state_dict(torch.load(os.path.join("will4381/rrn-qa", "step_controller.pth")))
36
+
37
+ # Example usage
38
+ inputs = tokenizer("What is the capital of France?", "Paris is the capital of France.", return_tensors="pt")
39
+ outputs = model(**inputs)
40
+ ```
41
+
42
+ ## Training
43
+
44
+ This model was trained on the SQuAD dataset using a multi-step reasoning approach. The training code is included in the `code` directory of this repository.
45
+
46
+ To train your own model:
47
+
48
+ ```bash
49
+ python code/train.py
50
+ ```
51
+
52
+ To evaluate the model:
53
+
54
+ ```bash
55
+ python code/test_model.py
56
+ ```
57
+
58
+ ## Model Architecture
59
+
60
+ The RRN architecture consists of:
61
+
62
+ 1. A base language model (BERT)
63
+ 2. A retroactive update layer that computes delta updates
64
+ 3. A gating mechanism for selective updates
65
+ 4. An enhanced QA head for answer prediction
66
+ 5. A step controller for dynamic reasoning steps (if enabled)
67
+
68
+ ## Evaluation Results
69
+ {'exact_match': 78.79848628193, 'f1': 86.94253357952118}