will4381 committed on
Commit 3451ca0 · verified · 1 Parent(s): 1f78c84

Upload folder using huggingface_hub
README.md ADDED
# Retroactive Reasoning Network (RRN) for Question Answering

## Model Description

This model implements an Enhanced Retroactive Reasoning Network (RRN) for Question Answering tasks. The RRN architecture enables multi-step reasoning through an iterative refinement process that retroactively updates hidden states.

### Key Features

- **Multi-step Reasoning**: The model performs 3 reasoning steps to iteratively refine its predictions.
- **Dynamic Reasoning Steps**: Enabled; a learned controller determines the number of steps (min: 1, max: 5).
- **Gating Mechanism**: Selectively applies updates to hidden states.
- **Delta Magnitude Constraint**: Prevents destabilizing updates with a target ratio of 0.2 (see the update rule below).
- **Active Memory**: Stores and retrieves examples to enhance reasoning.

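At each reasoning step $t$, the update applied to the hidden states can be summarized as follows (condensed from `code/model.py` in this repository; $\rho = 0.2$ is the target ratio, $\sigma$ the logistic function, and $\odot$ element-wise multiplication):

$$
\tilde{\Delta}_t = \Delta_t \cdot \min\!\left(1,\ \rho \,\frac{\lVert h_0 \rVert}{\lVert \Delta_t \rVert}\right), \qquad
g_t = \sigma\big(\mathrm{MLP}([h_t;\, \tilde{\Delta}_t])\big), \qquad
h_{t+1} = h_t + g_t \odot \tilde{\Delta}_t
$$

Norms are taken per token, so the constraint caps each token's update at 20% of the magnitude of its initial hidden state.
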
## Usage

```python
from transformers import AutoTokenizer
from model import EnhancedRRN_QA_Model

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("[MODEL_REPO_ID]")
model = EnhancedRRN_QA_Model("[MODEL_REPO_ID]/base_model")

# Load custom components
# (torch.load expects local files, so this assumes the repository has been
# downloaded to a local directory named [MODEL_REPO_ID], e.g. with
# huggingface_hub.snapshot_download)
import torch
import os

model.qa_head.load_state_dict(torch.load(os.path.join("[MODEL_REPO_ID]", "qa_head.pth")))
model.retroactive_update_layer.load_state_dict(torch.load(os.path.join("[MODEL_REPO_ID]", "retroactive_layer.pth")))
model.gating_mechanism.load_state_dict(torch.load(os.path.join("[MODEL_REPO_ID]", "gating_mechanism.pth")))

# If using learned dynamic steps
if os.path.exists(os.path.join("[MODEL_REPO_ID]", "step_controller.pth")) and hasattr(model, "step_controller"):
    model.step_controller.load_state_dict(torch.load(os.path.join("[MODEL_REPO_ID]", "step_controller.pth")))

# Example usage
inputs = tokenizer("What is the capital of France?", "Paris is the capital of France.", return_tensors="pt")
outputs = model(**inputs)
```

## Training

This model was trained on the SQuAD dataset using a multi-step reasoning approach. The training code is included in the `code` directory of this repository.

To train your own model:

```bash
python code/train.py
```

To evaluate the model:

```bash
python code/test_model.py
```

## Model Architecture

The RRN architecture consists of:

1. A base language model (BERT)
2. A retroactive update layer that computes delta updates
3. A gating mechanism for selective updates
4. An enhanced QA head for answer prediction
5. A step controller for dynamic reasoning steps (if enabled)

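How these components compose in one forward pass can be seen in a toy-scale version of the reasoning loop from `code/model.py` (random tensors stand in for BERT outputs; run from the `code/` directory; the full model additionally rescales each delta and uses the real reasoning trace):

```python
import torch
from modules import CrossAttentionDelta, GatingMechanism, EnhancedQAHead

hidden_dim, steps = 768, 3
h0 = torch.randn(1, 16, hidden_dim)   # stands in for BERT's last hidden state H(0)
trace = (h0,)                          # stands in for the full reasoning trace T
delta_layer = CrossAttentionDelta(hidden_dim)
gate_layer = GatingMechanism(hidden_dim)
qa_head = EnhancedQAHead(hidden_dim)

h = h0
for _ in range(steps):
    delta, _ = delta_layer(h, trace)   # cross-attention delta
    # (the full model also rescales delta so ||delta|| <= 0.2 * ||h0|| per token)
    gate = gate_layer(h, delta)        # per-token gate in (0, 1)
    h = h + gate * delta               # gated retroactive update

logits = qa_head(h)                    # final span prediction
print(logits["start_logits"].shape)    # torch.Size([1, 16])
```
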
## Citation

If you use this model in your research, please cite:

```
@article{rrn_qa_model,
  title={Retroactive Reasoning Networks for Question Answering},
  author={[Authors]},
  journal={[Journal]},
  year={2025}
}
```

base_model/config.json ADDED
{
  "architectures": [
    "BertModel"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "output_hidden_states": true,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "torch_dtype": "float32",
  "transformers_version": "4.51.2",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}

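Note that `output_hidden_states` is set to `true`: the RRN consumes every layer's hidden states as its reasoning trace (see `code/model.py`). A quick sanity check that the config loads as expected (a minimal sketch; the path assumes a local copy of this folder):

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("base_model")
assert config.output_hidden_states  # required for the reasoning trace T
print(config.hidden_size)           # 768
```
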
base_model/model.safetensors ADDED
version https://git-lfs.github.com/spec/v1
oid sha256:a0e649b208cf08b8748542d4303a944dfe14f9aeefc6bfe2bed4fca9dbb7c0ba
size 437951328

code/config.py ADDED
# config.py (Updated to Disable PEFT)
import torch

# --- Model Configuration ---
# Base model from Hugging Face (ensure it's suitable for QA)
# Example: 'bert-base-uncased', 'roberta-base', 'bert-large-uncased-whole-word-masking-finetuned-squad'
BASE_MODEL_NAME = "bert-base-uncased"

# --- RRN Specific Configuration ---
# Coherence loss weight
LAMBDA_COHERENCE = 0.1  # Hyperparameter to tune

# --- Delta Constraint Configuration ---
DELTA_TARGET_RATIO = 0.2  # Target ratio of delta norm to h0 norm
LAMBDA_DELTA_REG = 0.5  # Weight for delta regularization loss

# --- Multi-step Reasoning Configuration ---
NUM_REASONING_STEPS = 3  # Default number of reasoning steps (used when dynamic steps are disabled)

# --- Dynamic Reasoning Steps Configuration ---
USE_DYNAMIC_STEPS = True  # Enable/disable dynamic reasoning steps
MAX_REASONING_STEPS = 5  # Maximum number of reasoning steps
MIN_REASONING_STEPS = 1  # Minimum number of reasoning steps
REASONING_STEP_TYPE = "learned"  # Options: "fixed", "confidence", "learned"
EARLY_STOP_THRESHOLD = 0.01  # Delta magnitude threshold for early stopping (used with "confidence")

# --- Mixed Precision Configuration ---
USE_MIXED_PRECISION = False  # Enable/disable mixed precision training

# --- Memory Configuration ---
MEMORY_MAX_SIZE = 50  # Max number of entries in the memory
MEMORY_USE_DURING_TRAINING = False  # Whether to use memory during training
MEMORY_RETRIEVAL_K = 3  # Number of examples to retrieve from memory

# --- PEFT (LoRA) Configuration ---
USE_PEFT = False  # <--- SET TO False TO DISABLE PEFT ---

# --- Optional: Comment out or leave the LoRA-specific settings ---
# LORA_R = 8
# LORA_ALPHA = 16
# LORA_DROPOUT = 0.1
# LORA_TARGET_MODULES = ["query", "value"]

# --- Testing Configuration ---
BYPASS_DELTA_CALCULATION = False  # Set to True to bypass delta calculation for testing

# --- Training Configuration ---
# <<< --- Device Detection (CUDA prioritized over MPS) --- >>>
if torch.cuda.is_available():
    DEVICE = "cuda"
    print("CUDA GPU acceleration is available.")
elif torch.backends.mps.is_available():
    DEVICE = "mps"
    print("Apple Silicon MPS acceleration is available.")
else:
    DEVICE = "cpu"
    print("No GPU or MPS acceleration available, using CPU.")
# <<< --- End of Device Detection --- >>>

LEARNING_RATE = 1e-5  # Full fine-tuning often uses a smaller LR than PEFT
EPOCHS = 3
# --- Adjust Batch Size for Full Fine-tuning ---
# Full fine-tuning requires significantly more memory
BATCH_SIZE = 4  # Start small; adjust based on your GPU memory
GRADIENT_ACCUMULATION_STEPS = 8  # Increase to compensate for the smaller batch size

# --- Dataset Configuration ---
# Example for SQuAD
MAX_SEQ_LENGTH = 320  # Max input sequence length for QA
DOC_STRIDE = 128  # Stride for overlapping chunks of long documents

print(f"Using device: {DEVICE}")
print(f"Base model: {BASE_MODEL_NAME}")
# Update print statement to reflect PEFT status
print(f"Using PEFT (LoRA): {USE_PEFT} - Full Fine-tuning Enabled")

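With `BATCH_SIZE = 4` and `GRADIENT_ACCUMULATION_STEPS = 8`, the effective batch size is 32. A minimal, self-contained illustration of how a training loop typically consumes these two settings (toy model and data; the repository's real loop is in `code/train.py`):

```python
import torch
import torch.nn as nn
import config

model = nn.Linear(10, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=config.LEARNING_RATE)
data = [(torch.randn(config.BATCH_SIZE, 10), torch.randn(config.BATCH_SIZE, 1))] * 16

optimizer.zero_grad()
for step, (x, y) in enumerate(data):
    loss = nn.functional.mse_loss(model(x), y)
    # Scale so the accumulated gradients average rather than sum
    (loss / config.GRADIENT_ACCUMULATION_STEPS).backward()
    if (step + 1) % config.GRADIENT_ACCUMULATION_STEPS == 0:
        optimizer.step()   # one update per 4 * 8 = 32 examples
        optimizer.zero_grad()
```
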
code/memory.py ADDED
# memory.py
from collections import deque
import torch
import torch.nn.functional as F
import config

class ActiveMemory:
    """
    An active memory module that stores and retrieves examples to enhance reasoning.
    Supports both logging for analysis and retrieval for improved predictions.
    """
    def __init__(self, max_size=config.MEMORY_MAX_SIZE, retrieval_k=config.MEMORY_RETRIEVAL_K):
        self.max_size = max_size
        self.retrieval_k = retrieval_k
        self.memory = deque(maxlen=max_size)
        self.device = config.DEVICE
        print(f"Initialized ActiveMemory with max size {self.max_size}, retrieval_k={self.retrieval_k}")

    def add(self, input_data, hidden_states, output, reasoning_trace, final_hidden_states=None, final_output=None):
        """
        Adds a new entry to the memory.

        Args:
            input_data: The input to the model (tokenized IDs, attention masks, etc.)
            hidden_states (H0): Initial hidden states from the base model
            output (y0): Initial prediction from the model
            reasoning_trace (T): Reasoning trace (all hidden states)
            final_hidden_states (H1, optional): Final hidden states after retroactive update
            final_output (y1, optional): Final prediction after retroactive update
        """
        # Create a memory entry with detached tensors moved to CPU
        entry = {
            'input_ids': input_data.get('input_ids', None).cpu().detach() if input_data.get('input_ids', None) is not None else None,
            'attention_mask': input_data.get('attention_mask', None).cpu().detach() if input_data.get('attention_mask', None) is not None else None,
            'token_type_ids': input_data.get('token_type_ids', None).cpu().detach() if input_data.get('token_type_ids', None) is not None else None,
            'hidden_states': hidden_states.cpu().detach(),
            'output': {k: v.cpu().detach() for k, v in output.items()} if isinstance(output, dict) else output.cpu().detach(),
            'reasoning_trace': tuple(h.cpu().detach() for h in reasoning_trace) if isinstance(reasoning_trace, tuple) else reasoning_trace.cpu().detach(),
        }

        # Add final states if provided
        if final_hidden_states is not None:
            entry['final_hidden_states'] = final_hidden_states.cpu().detach()
        if final_output is not None:
            entry['final_output'] = {k: v.cpu().detach() for k, v in final_output.items()} if isinstance(final_output, dict) else final_output.cpu().detach()

        # Compute and store a summary vector for efficient retrieval
        # Use mean pooling of the hidden states as the summary vector
        if entry['hidden_states'] is not None and entry['attention_mask'] is not None:
            # Mean pooling with the attention mask
            mask = entry['attention_mask'].unsqueeze(-1).float()
            masked_embeddings = entry['hidden_states'] * mask
            sum_embeddings = torch.sum(masked_embeddings, dim=1)
            sum_mask = torch.sum(mask, dim=1)
            sum_mask = torch.clamp(sum_mask, min=1e-9)
            entry['summary_vector'] = (sum_embeddings / sum_mask).squeeze(0)
        else:
            # Fall back to a simple mean if the attention mask is not available
            entry['summary_vector'] = entry['hidden_states'].mean(dim=1).squeeze(0)

        self.memory.append(entry)

    def retrieve(self, query_hidden_states, query_attention_mask=None, k=None):
        """
        Retrieves the k most similar examples from memory based on hidden state similarity.

        Args:
            query_hidden_states: Hidden states to compare against memory
            query_attention_mask: Attention mask for the query
            k: Number of examples to retrieve (defaults to self.retrieval_k)

        Returns:
            List of retrieved memory entries, ordered by similarity (most similar first)
        """
        if len(self.memory) == 0:
            return []

        if k is None:
            k = self.retrieval_k

        k = min(k, len(self.memory))

        # Compute the query summary vector (mean pooling with the attention mask)
        if query_attention_mask is not None:
            mask = query_attention_mask.unsqueeze(-1).float()
            masked_embeddings = query_hidden_states * mask
            sum_embeddings = torch.sum(masked_embeddings, dim=1)
            sum_mask = torch.sum(mask, dim=1)
            sum_mask = torch.clamp(sum_mask, min=1e-9)
            query_vector = (sum_embeddings / sum_mask).squeeze(0)
        else:
            query_vector = query_hidden_states.mean(dim=1).squeeze(0)

        # Move the query vector to CPU for comparison with memory
        # (the squeeze(0) above assumes a query batch of size 1, matching how entries are summarized)
        query_vector = query_vector.cpu().detach()

        # Compute similarities with all memory entries
        similarities = []
        for i, entry in enumerate(self.memory):
            memory_vector = entry['summary_vector']
            # Compute cosine similarity
            similarity = F.cosine_similarity(query_vector, memory_vector, dim=0)
            similarities.append((i, similarity.item()))

        # Sort by similarity (descending) and take the top k
        similarities.sort(key=lambda x: x[1], reverse=True)
        top_k_indices = [idx for idx, _ in similarities[:k]]

        # Retrieve the top k entries
        retrieved_entries = [self.memory[idx] for idx in top_k_indices]

        # Move retrieved entries to the same device as the query
        device = query_hidden_states.device
        for entry in retrieved_entries:
            # Only move the tensors we'll actually use (hidden_states and final_hidden_states).
            # Note: this mutates the stored entries in place, so these tensors remain on the
            # query device inside the memory as well.
            if 'hidden_states' in entry:
                entry['hidden_states'] = entry['hidden_states'].to(device)
            if 'final_hidden_states' in entry:
                entry['final_hidden_states'] = entry['final_hidden_states'].to(device)

        return retrieved_entries

    def get_memory_context(self, query_hidden_states, query_attention_mask=None):
        """
        Retrieves and processes memory entries to create a context tensor for the model.

        Args:
            query_hidden_states: Hidden states to compare against memory
            query_attention_mask: Attention mask for the query

        Returns:
            memory_context: Tensor of shape (batch_size, seq_len, hidden_dim) containing
                processed memory information, or None if memory is empty
        """
        # Retrieve similar examples from memory
        retrieved = self.retrieve(query_hidden_states, query_attention_mask)

        if not retrieved:
            return None

        # Use the device of the query
        device = query_hidden_states.device
        batch_size, seq_len, hidden_dim = query_hidden_states.shape

        # Process retrieved examples to create the memory context
        # Strategy: average the final hidden states of the retrieved examples
        memory_tensors = []
        for entry in retrieved:
            # Prefer final hidden states if available, otherwise use the initial hidden states
            if 'final_hidden_states' in entry and entry['final_hidden_states'] is not None:
                memory_tensors.append(entry['final_hidden_states'])
            elif 'hidden_states' in entry:
                memory_tensors.append(entry['hidden_states'])

        if not memory_tensors:
            return None

        # Average the memory tensors
        # First ensure all tensors have the same sequence length by padding or truncating
        padded_tensors = []
        for tensor in memory_tensors:
            if tensor.size(1) < seq_len:
                # Pad
                padding = torch.zeros(1, seq_len - tensor.size(1), hidden_dim, device=device)
                padded_tensor = torch.cat([tensor, padding], dim=1)
                padded_tensors.append(padded_tensor)
            elif tensor.size(1) > seq_len:
                # Truncate
                padded_tensors.append(tensor[:, :seq_len, :])
            else:
                padded_tensors.append(tensor)

        # Stack and average
        memory_context = torch.stack(padded_tensors).mean(dim=0)

        # Expand to match the batch size if needed
        if memory_context.size(0) == 1 and batch_size > 1:
            memory_context = memory_context.expand(batch_size, -1, -1)

        return memory_context

    def clear(self):
        """Clears all entries from memory."""
        self.memory.clear()

    def __len__(self):
        return len(self.memory)

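A minimal, self-contained sketch of the store/retrieve cycle with toy tensors (shapes only; a real caller passes the base model's hidden states, as `EnhancedRRN_QA_Model` does):

```python
import torch
from memory import ActiveMemory

mem = ActiveMemory(max_size=10, retrieval_k=2)

# Store one toy example: batch of 1, sequence length 8, hidden size 768
inputs = {"input_ids": torch.randint(0, 100, (1, 8)), "attention_mask": torch.ones(1, 8)}
h0 = torch.randn(1, 8, 768)
out = {"start_logits": torch.randn(1, 8), "end_logits": torch.randn(1, 8)}
mem.add(inputs, h0, out, reasoning_trace=(h0,))

# Retrieve by cosine similarity of mean-pooled summary vectors
hits = mem.retrieve(torch.randn(1, 8, 768), torch.ones(1, 8))
context = mem.get_memory_context(torch.randn(1, 8, 768), torch.ones(1, 8))
print(len(hits), context.shape)  # 1 torch.Size([1, 8, 768])
```
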
code/model.py ADDED
# model.py (Enhanced RRN Implementation)
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForQuestionAnswering, AutoConfig, AutoModel
from transformers.modeling_outputs import QuestionAnsweringModelOutput

import config
from modules import CrossAttentionDelta, GatingMechanism, EnhancedQAHead
from memory import ActiveMemory

class EnhancedRRN_QA_Model(nn.Module):
    """
    Enhanced Retroactive Reasoning Network for Question Answering.
    Improvements:
    1. Delta magnitude constraint
    2. Gating mechanism
    3. Multi-step reasoning
    4. Active memory usage
    5. Enhanced QA head
    6. Improved cross-attention
    """
    def __init__(self, model_name=config.BASE_MODEL_NAME):
        super().__init__()
        self.model_name = model_name

        # --- Configuration ---
        self.num_reasoning_steps = config.NUM_REASONING_STEPS
        self.delta_target_ratio = config.DELTA_TARGET_RATIO

        # --- Dynamic Reasoning Steps Configuration ---
        self.use_dynamic_steps = config.USE_DYNAMIC_STEPS
        self.max_reasoning_steps = config.MAX_REASONING_STEPS
        self.min_reasoning_steps = config.MIN_REASONING_STEPS
        self.reasoning_step_type = config.REASONING_STEP_TYPE
        self.early_stop_threshold = config.EARLY_STOP_THRESHOLD

        # --- Load Base Model Configuration ---
        self.base_config = AutoConfig.from_pretrained(
            self.model_name,
            output_hidden_states=True,  # Crucial for the Reasoning Trace (T)
        )
        self.hidden_dim = self.base_config.hidden_size

        # Add the step controller for the learned approach (after hidden_dim is defined)
        if self.use_dynamic_steps and self.reasoning_step_type == "learned":
            self.step_controller = nn.Sequential(
                nn.Linear(self.hidden_dim, 128),
                nn.ReLU(),
                nn.Linear(128, self.max_reasoning_steps - self.min_reasoning_steps + 1)
            )
            print(f"Using learned dynamic reasoning steps (min={self.min_reasoning_steps}, max={self.max_reasoning_steps})")

        # --- Load Base Model ---
        self.base_model = AutoModel.from_pretrained(
            self.model_name,
            config=self.base_config
        )
        print(f"Loaded base model: {self.model_name}")
        print(f"Hidden dimension: {self.hidden_dim}")
        print(f"Using {self.num_reasoning_steps} reasoning steps")

        # --- Enhanced RRN Components ---
        # Improved cross-attention delta mechanism
        self.retroactive_update_layer = CrossAttentionDelta(self.hidden_dim)

        # Gating mechanism for selective updates
        self.gating_mechanism = GatingMechanism(self.hidden_dim)

        # Enhanced QA head with a deeper architecture and bilinear scoring
        self.qa_head = EnhancedQAHead(self.hidden_dim)

        # --- Active Memory Module ---
        self.memory = ActiveMemory(
            max_size=config.MEMORY_MAX_SIZE,
            retrieval_k=config.MEMORY_RETRIEVAL_K
        )

        # --- Loss Functions ---
        self.coherence_loss_fn = nn.MSELoss()
        self.delta_reg_loss_fn = nn.MSELoss()

    def _apply_delta_constraint(self, delta, h0, is_training=False):
        """
        Apply the delta magnitude constraint to prevent destabilizing updates.

        Args:
            delta: The computed delta
            h0: The initial hidden states
            is_training: Whether we're in training mode

        Returns:
            constrained_delta: The constrained delta
            delta_reg_loss: Regularization loss for the delta magnitude (if training)
        """
        # Compute the delta and h0 norms
        delta_norm = delta.norm(dim=-1, keepdim=True)
        h0_norm = h0.norm(dim=-1, keepdim=True).detach()

        # Compute the ratio
        ratio = delta_norm / (h0_norm + 1e-9)

        # Compute the regularization loss if training
        delta_reg_loss = None
        if is_training:
            # Target ratio tensor (same shape as ratio)
            target_ratio = torch.ones_like(ratio) * self.delta_target_ratio
            delta_reg_loss = self.delta_reg_loss_fn(ratio, target_ratio)

        # Apply the direct constraint (both during training and inference)
        # Only scale down deltas that are too large
        scale_factor = torch.ones_like(ratio)
        too_large = ratio > self.delta_target_ratio
        if too_large.any():
            scale_factor[too_large] = self.delta_target_ratio / ratio[too_large]

        # Apply the scaling
        constrained_delta = delta * scale_factor

        return constrained_delta, delta_reg_loss

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        start_positions=None,
        end_positions=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        use_memory=True
    ):
        return_dict = return_dict if return_dict is not None else self.base_config.use_return_dict
        is_training = self.training

        # === 1. Initial Forward Pass ===
        # Determine if token_type_ids should be passed
        include_token_type_ids = token_type_ids is not None

        if include_token_type_ids:
            outputs = self.base_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                output_hidden_states=True,
                output_attentions=output_attentions if output_attentions is not None else self.base_config.output_attentions,
                return_dict=True
            )
        else:
            outputs = self.base_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True,
                output_attentions=output_attentions if output_attentions is not None else self.base_config.output_attentions,
                return_dict=True
            )

        # H(0): Last hidden state from the base model
        h0 = outputs.last_hidden_state

        # T: Reasoning Trace (all hidden states)
        reasoning_trace_T = outputs.hidden_states

        # y^(0): Initial QA prediction using H(0)
        y0_output = self.qa_head(h0)
        y0_start_logits, y0_end_logits = y0_output["start_logits"], y0_output["end_logits"]

        # === 2. Memory Integration (if enabled) ===
        memory_context = None
        if use_memory and ((is_training and config.MEMORY_USE_DURING_TRAINING) or not is_training):
            if len(self.memory) > 0:
                memory_context = self.memory.get_memory_context(h0, attention_mask)

        # === 3. Multi-step Reasoning ===
        # Initialize the current hidden state
        h_current = h0

        # Store all deltas and gates for loss calculation and analysis
        all_deltas = []
        all_gates = []
        all_hidden_states = [h0]

        # Determine the number of reasoning steps to use
        actual_steps_taken = 0

        if hasattr(self, 'use_dynamic_steps') and self.use_dynamic_steps:
            if self.reasoning_step_type == "learned":
                # Pool the sequence dimension to get a single vector per example
                pooled_h0 = h0.mean(dim=1)

                # Get step logits from the controller
                step_logits = self.step_controller(pooled_h0)

                if is_training:
                    # During training, sample from the distribution (exploration)
                    step_probs = F.softmax(step_logits, dim=-1)
                    steps_idx = torch.multinomial(step_probs, 1).squeeze(-1)
                    num_steps = steps_idx + self.min_reasoning_steps
                else:
                    # During inference, take the argmax (exploitation)
                    steps_idx = torch.argmax(step_logits, dim=-1)
                    num_steps = steps_idx + self.min_reasoning_steps

                # Store step probabilities for analysis
                step_probs = F.softmax(step_logits, dim=-1)

                # Get the maximum number of steps across the batch
                max_num_steps = num_steps.max().item()
            elif self.reasoning_step_type == "confidence":
                # For confidence-based stopping, the step count is determined during the loop
                max_num_steps = self.max_reasoning_steps
            else:
                # Fall back to fixed steps
                max_num_steps = self.num_reasoning_steps
        else:
            # Use a fixed number of steps
            max_num_steps = self.num_reasoning_steps

        # Perform reasoning steps
        for step in range(max_num_steps):
            # For confidence-based stopping, check whether each example should continue
            if hasattr(self, 'use_dynamic_steps') and self.use_dynamic_steps and self.reasoning_step_type == "confidence" and step >= self.min_reasoning_steps:
                # Check the delta magnitude from the previous step
                if len(all_deltas) > 0:
                    prev_delta = all_deltas[-1]
                    delta_norm = prev_delta.norm(dim=-1).mean().item()
                    if delta_norm < self.early_stop_threshold:
                        break

            # For the learned approach, check if we've reached the determined number of steps
            if hasattr(self, 'use_dynamic_steps') and self.use_dynamic_steps and self.reasoning_step_type == "learned":
                # Create a mask for examples that should continue
                if step > 0:  # Skip the first-step check since all examples need at least 1 step
                    # Check which examples should continue
                    continue_mask = (step < num_steps).float().unsqueeze(-1).unsqueeze(-1)

                    # If no examples need more steps, break
                    if continue_mask.sum() == 0:
                        break

            # Compute the delta using the current hidden state and reasoning trace
            if config.BYPASS_DELTA_CALCULATION:
                # Bypass delta calculation for testing
                delta = torch.zeros_like(h_current)
                attn_weights = None
            else:
                delta, attn_weights = self.retroactive_update_layer(h_current, reasoning_trace_T)

            # Apply the delta magnitude constraint
            # (delta_reg_loss ends up holding the term from the last step taken)
            constrained_delta, delta_reg_loss = self._apply_delta_constraint(delta, h0, is_training)

            # For the learned approach with continue_mask, apply the mask to the delta
            if hasattr(self, 'use_dynamic_steps') and self.use_dynamic_steps and self.reasoning_step_type == "learned" and step > 0:
                constrained_delta = constrained_delta * continue_mask

            # Compute gate values for the selective update
            gate = self.gating_mechanism(h_current, constrained_delta)

            # Apply the gated update
            h_current = h_current + gate * constrained_delta

            # Store for later use
            all_deltas.append(constrained_delta)
            all_gates.append(gate)
            all_hidden_states.append(h_current)
            actual_steps_taken = step + 1

        # Final hidden state after all reasoning steps
        h_final = h_current

        # === 4. Final Prediction ===
        y_final_output = self.qa_head(h_final)
        y_final_start_logits, y_final_end_logits = y_final_output["start_logits"], y_final_output["end_logits"]

        # === 5. Loss Calculation ===
        total_loss = None
        loss_components = {}

        if start_positions is not None and end_positions is not None:
            # Prepare ground-truth positions
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)

            ignored_index = y_final_start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            # Task Loss (QA Loss)
            loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(y_final_start_logits, start_positions)
            end_loss = loss_fct(y_final_end_logits, end_positions)
            task_loss = (start_loss + end_loss) / 2
            loss_components["task_loss"] = task_loss.item()

            # Coherence Loss
            coherence_loss_start = self.coherence_loss_fn(y0_start_logits, y_final_start_logits.detach())
            coherence_loss_end = self.coherence_loss_fn(y0_end_logits, y_final_end_logits.detach())
            coherence_loss = (coherence_loss_start + coherence_loss_end) / 2
            loss_components["coherence_loss"] = coherence_loss.item()

            # Delta Regularization Loss (if computed)
            if delta_reg_loss is not None:
                loss_components["delta_reg_loss"] = delta_reg_loss.item()

            # Total Loss
            total_loss = task_loss + config.LAMBDA_COHERENCE * coherence_loss

            # Add delta regularization if computed
            if delta_reg_loss is not None:
                total_loss = total_loss + config.LAMBDA_DELTA_REG * delta_reg_loss

        # === 6. Memory Update ===
        if use_memory:
            # Prepare input data
            input_data = {'input_ids': input_ids, 'attention_mask': attention_mask}
            if token_type_ids is not None:
                input_data['token_type_ids'] = token_type_ids

            # Prepare outputs
            initial_output = {'start_logits': y0_start_logits, 'end_logits': y0_end_logits}
            final_output = {'start_logits': y_final_start_logits, 'end_logits': y_final_end_logits}

            # Add to memory (during both training and inference if enabled)
            if (is_training and config.MEMORY_USE_DURING_TRAINING) or not is_training:
                self.memory.add(
                    input_data=input_data,
                    hidden_states=h0,
                    output=initial_output,
                    reasoning_trace=reasoning_trace_T,
                    final_hidden_states=h_final,
                    final_output=final_output
                )

        # === 7. Return Outputs ===
        if not return_dict:
            output = (y_final_start_logits, y_final_end_logits) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        # Store custom outputs as instance attributes for later access if needed
        # This avoids passing them to QuestionAnsweringModelOutput, which doesn't accept them
        self.custom_outputs = {
            "initial_hidden_states": h0,
            "final_hidden_states": h_final,
            "all_hidden_states": all_hidden_states,
            "all_deltas": all_deltas,
            "all_gates": all_gates,
            "y0_start_logits": y0_start_logits,
            "y0_end_logits": y0_end_logits,
            "loss_components": loss_components if total_loss is not None else None,
            "steps_taken": actual_steps_taken
        }

        # Add step controller outputs if using the learned approach
        if self.use_dynamic_steps and self.reasoning_step_type == "learned":
            self.custom_outputs["step_probs"] = step_probs
            self.custom_outputs["num_steps"] = num_steps

        # Return a standard QuestionAnsweringModelOutput without custom fields
        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=y_final_start_logits,
            end_logits=y_final_end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions
        )

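The scaling rule in `_apply_delta_constraint` can be sanity-checked in isolation. A minimal sketch (a standalone re-implementation of the same inference-path rule, not an import from this file):

```python
import torch

def apply_delta_constraint(delta, h0, target_ratio=0.2):
    # Same rule as EnhancedRRN_QA_Model._apply_delta_constraint (inference path)
    delta_norm = delta.norm(dim=-1, keepdim=True)
    h0_norm = h0.norm(dim=-1, keepdim=True)
    ratio = delta_norm / (h0_norm + 1e-9)
    scale = torch.ones_like(ratio)
    too_large = ratio > target_ratio
    scale[too_large] = target_ratio / ratio[too_large]
    return delta * scale

h0 = torch.randn(2, 8, 768)
delta = 10.0 * torch.randn(2, 8, 768)            # deliberately oversized update
constrained = apply_delta_constraint(delta, h0)
ratios = constrained.norm(dim=-1) / h0.norm(dim=-1)
print(ratios.max())                               # ≈ 0.2, never meaningfully above
```
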
code/modules.py ADDED
# modules.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import config

class CrossAttentionDelta(nn.Module):
    """
    Enhanced version of CrossAttentionDelta that computes the update delta (Δ) using cross-attention.
    Improvements:
    1. Pre-norm architecture (layer norm before attention)
    2. More sophisticated attention patterns
    3. Ability to incorporate the reasoning trace
    """
    def __init__(self, hidden_dim, num_heads=8, dropout=0.1):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads

        # Pre-norm layer normalization (applied before attention)
        self.pre_norm = nn.LayerNorm(hidden_dim)

        # Cross-attention mechanism
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True
        )

        # Post-attention layer normalization
        self.post_norm = nn.LayerNorm(hidden_dim)

        # Trace integration module (to incorporate the reasoning trace T)
        self.trace_integration = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # Enhanced MLP for delta computation
        self.delta_mlp = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim * 4),  # Larger intermediate expansion
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 4, hidden_dim * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 2, hidden_dim)
        )

        # Final layer normalization
        self.final_norm = nn.LayerNorm(hidden_dim)

    def forward(self, h0, reasoning_trace=None):
        """
        Args:
            h0 (torch.Tensor): Initial hidden states (batch_size, seq_len, hidden_dim).
            reasoning_trace (tuple of torch.Tensor, optional): Reasoning trace from the base model.
                Each tensor has shape (batch_size, seq_len, hidden_dim).

        Returns:
            delta (torch.Tensor): The computed update delta (batch_size, seq_len, hidden_dim).
            attn_weights (torch.Tensor): Attention weights from the cross-attention step.
        """
        batch_size, seq_len, _ = h0.shape

        # --- Pre-norm Architecture ---
        # Apply layer normalization before attention (pre-norm)
        h0_norm = self.pre_norm(h0)

        # --- Enhanced Cross-Attention ---
        # Request attention weights to visualize attention patterns
        attn_output, attn_weights = self.cross_attn(
            query=h0_norm,
            key=h0_norm,
            value=h0_norm,
            need_weights=True
        )

        # Residual connection and post-norm
        c = self.post_norm(h0 + attn_output)

        # --- Reasoning Trace Integration (if provided) ---
        if reasoning_trace is not None and len(reasoning_trace) > 0:
            # Use the last layer from the reasoning trace (most semantic)
            last_layer = reasoning_trace[-1]

            # Integrate the reasoning trace with the current context
            trace_info = self.trace_integration(
                torch.cat([c, last_layer], dim=-1)
            )

            # Add the trace information to the context
            c = c + trace_info

        # --- Enhanced MLP for Delta ---
        # Concatenate the original h0 with the context c
        mlp_input = torch.cat((h0, c), dim=-1)

        # Compute the delta through the enhanced MLP
        delta = self.delta_mlp(mlp_input)

        # Apply final normalization
        delta = self.final_norm(delta)

        return delta, attn_weights

class GatingMechanism(nn.Module):
    """
    Gating mechanism to selectively apply updates.
    Learns when to apply the delta update based on the hidden state and delta.
    """
    def __init__(self, hidden_dim, dropout=0.1):
        super().__init__()
        self.gate_network = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()  # Output between 0 and 1
        )

    def forward(self, h0, delta):
        """
        Args:
            h0 (torch.Tensor): Initial hidden states (batch_size, seq_len, hidden_dim).
            delta (torch.Tensor): Computed delta (batch_size, seq_len, hidden_dim).

        Returns:
            gate (torch.Tensor): Gate values between 0 and 1 (batch_size, seq_len, 1).
        """
        # Concatenate h0 and delta
        gate_input = torch.cat([h0, delta], dim=-1)

        # Compute gate values
        gate = self.gate_network(gate_input)

        return gate

class EnhancedQAHead(nn.Module):
    """
    Enhanced Question Answering head with a deeper architecture and bilinear scoring.
    """
    def __init__(self, hidden_dim, dropout=0.1):
        super().__init__()

        # Deeper representation before prediction
        self.start_transform = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim)
        )

        self.end_transform = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # Bilinear layer for start position scoring
        self.start_bilinear = nn.Bilinear(hidden_dim, hidden_dim, 1)

        # Bilinear layer for end position scoring
        self.end_bilinear = nn.Bilinear(hidden_dim, hidden_dim, 1)

        # Global representation for bilinear scoring
        self.global_rep = nn.Parameter(torch.randn(hidden_dim))

    def forward(self, hidden_states):
        """
        Args:
            hidden_states (torch.Tensor): Hidden states (batch_size, seq_len, hidden_dim).

        Returns:
            dict: Dictionary with start_logits and end_logits.
        """
        batch_size, seq_len, hidden_dim = hidden_states.shape

        # Transform hidden states
        start_rep = self.start_transform(hidden_states)
        end_rep = self.end_transform(hidden_states)

        # Expand the global representation for batch processing
        global_rep = self.global_rep.expand(batch_size, seq_len, -1)

        # Compute start and end logits using bilinear scoring
        start_logits = self.start_bilinear(start_rep, global_rep).squeeze(-1)
        end_logits = self.end_bilinear(end_rep, global_rep).squeeze(-1)

        return {"start_logits": start_logits, "end_logits": end_logits}

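A minimal smoke test for `CrossAttentionDelta` with toy shapes (assumes this file and `config.py` are importable), showing the delta and the head-averaged attention weights it returns:

```python
import torch
from modules import CrossAttentionDelta

layer = CrossAttentionDelta(hidden_dim=768)
h0 = torch.randn(2, 16, 768)
delta, attn = layer(h0, reasoning_trace=(h0, h0))  # trace: tuple of per-layer states

print(delta.shape)  # torch.Size([2, 16, 768]) - same shape as h0, applied residually upstream
print(attn.shape)   # torch.Size([2, 16, 16])  - head-averaged attention weights
```
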
code/test_model.py ADDED
1
+ # test_model.py - RRN QA Model evaluation script with multi-step reasoning support
2
+ import torch
3
+ from torch.utils.data import DataLoader
4
+ from transformers import AutoTokenizer, AutoModel, default_data_collator
5
+ from datasets import load_dataset
6
+ from tqdm.auto import tqdm
7
+ import os
8
+ import evaluate as hf_evaluate # Import with alias to avoid naming conflict
9
+ import collections
10
+ import numpy as np
11
+ import logging
12
+ import multiprocessing # For Windows multiprocessing support
13
+ import json
14
+ import argparse
15
+ import matplotlib.pyplot as plt
16
+ from collections import defaultdict
17
+
18
+ # Import custom modules and config
19
+ import config
20
+ from model import EnhancedRRN_QA_Model # Import the enhanced model
21
+ # Make sure memory.py and modules.py are accessible
22
+
23
+ # --- Configuration ---
24
+ logging.basicConfig(level=logging.INFO)
25
+ logger = logging.getLogger(__name__)
26
+
27
+ def main():
28
+ # Parse command line arguments
29
+ parser = argparse.ArgumentParser(description="Test RRN QA Model")
30
+ parser.add_argument("--checkpoint", type=str, default="./rrn_qa_model_epoch_3",
31
+ help="Path to checkpoint directory (default: ./rrn_qa_model_epoch_3)")
32
+ parser.add_argument("--batch_size", type=int, default=8,
33
+ help="Evaluation batch size (default: 8)")
34
+ parser.add_argument("--fixed_steps", type=int, default=None,
35
+ help="Override to use fixed number of reasoning steps (default: None, use model's dynamic steps)")
36
+ parser.add_argument("--use_memory", action="store_true",
37
+ help="Enable active memory during evaluation")
38
+ parser.add_argument("--output_dir", type=str, default="./eval_results",
39
+ help="Directory to save evaluation results (default: ./eval_results)")
40
+ parser.add_argument("--visualize", action="store_true",
41
+ help="Generate visualizations of reasoning steps")
42
+ args = parser.parse_args()
43
+
44
+ CHECKPOINT_DIR = args.checkpoint
45
+ EVAL_BATCH_SIZE = args.batch_size
46
+ DEVICE = config.DEVICE
47
+ USE_MEMORY = args.use_memory
48
+ OUTPUT_DIR = args.output_dir
49
+
50
+ # Create output directory if it doesn't exist
51
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
52
+
53
+ logger.info(f"Evaluation configuration:")
54
+ logger.info(f" Checkpoint: {CHECKPOINT_DIR}")
55
+ logger.info(f" Batch size: {EVAL_BATCH_SIZE}")
56
+ logger.info(f" Device: {DEVICE}")
57
+ logger.info(f" Use memory: {USE_MEMORY}")
58
+ logger.info(f" Output directory: {OUTPUT_DIR}")
59
+ if args.fixed_steps is not None:
60
+ logger.info(f" Using fixed {args.fixed_steps} reasoning steps (overriding model config)")
61
+
62
+ # --- 1. Load Tokenizer and Model from Checkpoint ---
63
+ logger.info(f"Loading tokenizer from {CHECKPOINT_DIR}...")
64
+ tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR)
65
+
66
+ logger.info(f"Loading Enhanced RRN QA Model architecture...")
67
+ # Instantiate the enhanced model architecture
68
+ model = EnhancedRRN_QA_Model(config.BASE_MODEL_NAME)
69
+
70
+ # Check if we're loading from a checkpoint with the enhanced architecture
71
+ base_model_path = os.path.join(CHECKPOINT_DIR, "base_model")
72
+ qa_head_path = os.path.join(CHECKPOINT_DIR, "qa_head.pth")
73
+ retroactive_layer_path = os.path.join(CHECKPOINT_DIR, "retroactive_layer.pth")
74
+ gating_mechanism_path = os.path.join(CHECKPOINT_DIR, "gating_mechanism.pth")
75
+ step_controller_path = os.path.join(CHECKPOINT_DIR, "step_controller.pth")
76
+
77
+ # Check for required components
78
+ if not os.path.exists(base_model_path):
79
+ logger.error(f"Base model directory not found at: {base_model_path}")
80
+ exit()
81
+ if not os.path.exists(qa_head_path):
82
+ logger.error(f"QA head weights not found at: {qa_head_path}")
83
+ exit()
84
+ if not os.path.exists(retroactive_layer_path):
85
+ logger.error(f"Retroactive layer weights not found at: {retroactive_layer_path}")
86
+ exit()
87
+
88
+ # Load base model weights
89
+ logger.info(f"Loading base model weights from {base_model_path}...")
90
+ model.base_model = AutoModel.from_pretrained(base_model_path)
91
+
92
+ # Check if we're loading from an enhanced checkpoint or a legacy checkpoint
93
+ is_enhanced_checkpoint = os.path.exists(gating_mechanism_path)
94
+
95
+ if is_enhanced_checkpoint:
96
+ # Load all enhanced components
97
+ logger.info("Loading enhanced model components...")
98
+ model.qa_head.load_state_dict(torch.load(qa_head_path, map_location='cpu'))
99
+ model.retroactive_update_layer.load_state_dict(torch.load(retroactive_layer_path, map_location='cpu'))
100
+ model.gating_mechanism.load_state_dict(torch.load(gating_mechanism_path, map_location='cpu'))
101
+
102
+ # Load step controller if available (for learned dynamic steps)
103
+ if os.path.exists(step_controller_path) and hasattr(model, "step_controller"):
104
+ logger.info("Loading step controller for learned dynamic steps...")
105
+ model.step_controller.load_state_dict(torch.load(step_controller_path, map_location='cpu'))
106
+
107
+ logger.info("Enhanced model loaded successfully.")
108
+ else:
109
+ # We're loading from a legacy checkpoint - need to adapt the weights
110
+ logger.info("Loading from legacy checkpoint - adapting weights to enhanced architecture...")
111
+
112
+ # For the QA head, we need to initialize the enhanced QA head from scratch
113
+ # since the architectures are different
114
+ logger.info("Initializing enhanced QA head with random weights...")
115
+
116
+ # For the retroactive layer, we can try to load the weights but might need adjustments
117
+ logger.warning("Note: The enhanced model uses a different architecture than the checkpoint.")
118
+ logger.warning("Some components will use random initialization.")
119
+
120
+ # Load enhanced config if available
121
+ enhanced_config_path = os.path.join(CHECKPOINT_DIR, "enhanced_config.json")
122
+ if os.path.exists(enhanced_config_path):
123
+ logger.info(f"Loading enhanced configuration from {enhanced_config_path}")
124
+ with open(enhanced_config_path, 'r') as f:
125
+ enhanced_config = json.load(f)
126
+
127
+ # Override model configuration with saved values
128
+ if "num_reasoning_steps" in enhanced_config:
129
+ model.num_reasoning_steps = enhanced_config["num_reasoning_steps"]
130
+ logger.info(f"Using {model.num_reasoning_steps} reasoning steps from config")
131
+
132
+ if "use_dynamic_steps" in enhanced_config:
133
+ model.use_dynamic_steps = enhanced_config["use_dynamic_steps"]
134
+ if model.use_dynamic_steps:
135
+ model.max_reasoning_steps = enhanced_config.get("max_reasoning_steps", config.MAX_REASONING_STEPS)
136
+ model.min_reasoning_steps = enhanced_config.get("min_reasoning_steps", config.MIN_REASONING_STEPS)
137
+ model.reasoning_step_type = enhanced_config.get("reasoning_step_type", config.REASONING_STEP_TYPE)
138
+ model.early_stop_threshold = enhanced_config.get("early_stop_threshold", config.EARLY_STOP_THRESHOLD)
139
+ logger.info(f"Using dynamic reasoning steps (type: {model.reasoning_step_type})")
140
+ logger.info(f"Min steps: {model.min_reasoning_steps}, Max steps: {model.max_reasoning_steps}")
141
+
142
+ # Override with fixed steps if specified
143
+ if args.fixed_steps is not None:
144
+ logger.info(f"Overriding with fixed {args.fixed_steps} reasoning steps")
145
+ model.use_dynamic_steps = False
146
+ model.num_reasoning_steps = args.fixed_steps
147
+
148
+ model.to(DEVICE)
149
+ model.eval() # Set model to evaluation mode
150
+ logger.info("Model loaded successfully and set to evaluation mode.")
151
+
152
+
153
+ # --- 2. Load and Preprocess Validation Dataset ---
154
+ logger.info("Loading SQuAD validation dataset...")
155
+ raw_datasets = load_dataset("squad", split="validation")
156
+
157
+
158
+ question_column_name = "question"
159
+ context_column_name = "context"
160
+ answer_column_name = "answers"
161
+ pad_on_right = tokenizer.padding_side == "right"
162
+
163
+ # Validation preprocessing: Keep example_id and offset_mapping
164
+ def prepare_validation_features(examples):
165
+ examples[question_column_name] = [q.strip() for q in examples[question_column_name]]
166
+ tokenized_examples = tokenizer(
167
+ examples[question_column_name if pad_on_right else context_column_name],
168
+ examples[context_column_name if pad_on_right else question_column_name],
169
+ truncation="only_second" if pad_on_right else "only_first",
170
+ max_length=config.MAX_SEQ_LENGTH,
171
+ stride=config.DOC_STRIDE,
172
+ return_overflowing_tokens=True,
173
+ return_offsets_mapping=True,
174
+ padding="max_length",
175
+ )
176
+
177
+ # Keep track of which feature belongs to which example
178
+ sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
179
+
180
+ # Add the example_id to link features to original examples
181
+ tokenized_examples["example_id"] = []
182
+ for i in range(len(tokenized_examples["input_ids"])):
183
+ sequence_ids = tokenized_examples.sequence_ids(i)
184
+ context_index = 1 if pad_on_right else 0
185
+ sample_index = sample_mapping[i]
186
+ tokenized_examples["example_id"].append(examples["id"][sample_index])
187
+
188
+ # Set offset mapping to None for question tokens to avoid predicting answers there
189
+ tokenized_examples["offset_mapping"][i] = [
190
+ (o if sequence_ids[k] == context_index else None)
191
+ for k, o in enumerate(tokenized_examples["offset_mapping"][i])
192
+ ]
193
+
194
+ return tokenized_examples
195
+
196
+ logger.info("Preprocessing validation dataset...")
197
+ # Disable multiprocessing which can hang on some systems
198
+ logger.info("Using single process for preprocessing to prevent hanging")
199
+ eval_dataset = raw_datasets.map(
200
+ prepare_validation_features,
201
+ batched=True,
202
+ remove_columns=raw_datasets.column_names,
203
+ num_proc=1, # Disable multiprocessing to avoid hanging
204
+ )
205
+
206
+ # Custom collator to handle None values in offset_mapping
207
+ def custom_data_collator(features):
208
+ # First, remove offset_mapping which contains None values that can't be batched
209
+ offset_mappings = [f.pop("offset_mapping") for f in features]
210
+
211
+ # Use default collator for everything else
212
+ batch = default_data_collator(features)
213
+
214
+ # Add offset_mapping back as a list since it can't be converted to a tensor
215
+ batch["offset_mapping"] = offset_mappings
216
+
217
+ return batch
218
+
219
+ # Use custom data collator
220
+ data_collator = custom_data_collator
221
+
222
+ eval_dataloader = DataLoader(
223
+ eval_dataset,
224
+ collate_fn=data_collator,
225
+ batch_size=EVAL_BATCH_SIZE
226
+ )
227
+
228
+ # --- 3. Run Inference ---
229
+ logger.info("***** Running Evaluation *****")
230
+ logger.info(f" Num examples = {len(eval_dataset)}")
231
+ logger.info(f" Batch size = {EVAL_BATCH_SIZE}")
232
+
233
+ all_start_logits = []
234
+ all_end_logits = []
235
+ feature_indices = [] # Keep track of the order
236
+
237
+ # Track multi-step reasoning metrics
238
+ reasoning_steps_taken = []
239
+ delta_magnitudes = []
240
+ gate_values = []
241
+ initial_vs_final_changes = []
242
+
243
+ with torch.no_grad():
244
+ for step, batch in enumerate(tqdm(eval_dataloader, desc="Evaluating")):
245
+ # Move batch to device
246
+ batch_on_device = {k: v.to(DEVICE) for k, v in batch.items() if isinstance(v, torch.Tensor)}
247
+ # Store feature indices corresponding to this batch
248
+ # Assuming 'input_ids' or similar key represents features in order
249
+ current_indices = list(range(step * EVAL_BATCH_SIZE, step * EVAL_BATCH_SIZE + len(batch_on_device['input_ids'])))
250
+ feature_indices.extend(current_indices)
251
+
252
+ # Forward pass - pass only inputs needed by model.forward
253
+ outputs = model(
254
+ input_ids=batch_on_device.get("input_ids"),
255
+ attention_mask=batch_on_device.get("attention_mask"),
256
+ token_type_ids=batch_on_device.get("token_type_ids"),
257
+ use_memory=USE_MEMORY, # Use memory if enabled
258
+ return_dict=True
259
+ )
260
+
261
+ # Get the final logits (y1)
262
+ start_logits = outputs.start_logits
263
+ end_logits = outputs.end_logits
264
+
265
+ all_start_logits.append(start_logits.cpu().numpy())
266
+ all_end_logits.append(end_logits.cpu().numpy())
267
+
268
+ # Collect multi-step reasoning metrics from custom_outputs
269
+ if hasattr(model, 'custom_outputs'):
270
+ # Number of reasoning steps taken
271
+ if 'steps_taken' in model.custom_outputs:
272
+ reasoning_steps_taken.append(model.custom_outputs['steps_taken'])
273
+
274
+ # Delta magnitudes (how much the model updates at each step)
275
+ if 'all_deltas' in model.custom_outputs and len(model.custom_outputs['all_deltas']) > 0:
276
+ batch_deltas = []
277
+ for delta in model.custom_outputs['all_deltas']:
278
+ # Calculate mean delta magnitude across sequence dimension
279
+ delta_norm = delta.norm(dim=-1).mean().cpu().item()
280
+ batch_deltas.append(delta_norm)
281
+ delta_magnitudes.append(batch_deltas)
282
+
283
+ # Gate values (how selective the updates are)
284
+ if 'all_gates' in model.custom_outputs and len(model.custom_outputs['all_gates']) > 0:
285
+ batch_gates = []
286
+ for gate in model.custom_outputs['all_gates']:
287
+ # Calculate mean gate value across sequence dimension
288
+ gate_mean = gate.mean().cpu().item()
289
+ batch_gates.append(gate_mean)
290
+ gate_values.append(batch_gates)
291
+
292
+ # Compare initial vs final predictions
293
+ if 'y0_start_logits' in model.custom_outputs and 'y0_end_logits' in model.custom_outputs:
294
+ y0_start = model.custom_outputs['y0_start_logits']
295
+ y0_end = model.custom_outputs['y0_end_logits']
296
+
297
+ # Calculate how much the predictions changed
298
+ start_change = (start_logits - y0_start).abs().mean().cpu().item()
299
+ end_change = (end_logits - y0_end).abs().mean().cpu().item()
300
+ initial_vs_final_changes.append((start_change + end_change) / 2)
301
+
302
+ # Concatenate all results
303
+ all_start_logits = np.concatenate(all_start_logits, axis=0)
304
+ all_end_logits = np.concatenate(all_end_logits, axis=0)
305
+
306
+ # Ensure the number of predictions matches the number of features
307
+ if len(all_start_logits) != len(eval_dataset):
308
+ logger.warning(f"Mismatch in prediction count ({len(all_start_logits)}) and feature count ({len(eval_dataset)}). Check dataloader/inference loop.")
309
+ # Attempt to slice if predictions exceed features (might happen if last batch wasn't full)
310
+ all_start_logits = all_start_logits[:len(eval_dataset)]
311
+ all_end_logits = all_end_logits[:len(eval_dataset)]
312
+
313
+
314
+ # Create dictionary mapping feature index to its logits
315
+ predictions_dict = {
316
+ feature_index: (start_logit, end_logit)
317
+ for feature_index, (start_logit, end_logit) in zip(feature_indices, zip(all_start_logits, all_end_logits))
318
+ }
319
+
320
+
321
+ # --- 4. Post-Processing ---
322
+ # (Adapted from Hugging Face run_qa.py example script)
323
+ def postprocess_qa_predictions(examples, features, raw_predictions, n_best_size=20, max_answer_length=30, tokenizer=tokenizer):
324
+ all_start_logits, all_end_logits = zip(*raw_predictions.values())
325
+
326
+ # Build a map from example ID to list of related feature indices
327
+ example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
328
+ features_per_example = collections.defaultdict(list)
329
+ for i, feature in enumerate(features):
330
+ features_per_example[example_id_to_index[feature["example_id"]]].append(i)
331
+
332
+ # Dictionary to store predictions
333
+ predictions = collections.OrderedDict()
334
+
335
+ logger.info(f"Post-processing {len(examples)} example predictions split into {len(features)} features.")
336
+
337
+ # Loop over all examples
338
+ for example_index, example in enumerate(tqdm(examples, desc="Post-processing")):
339
+ feature_indices = features_per_example[example_index] # Indices of features related to this example
340
+
341
+ min_null_score = None # Used to identify impossible answers
342
+ valid_answers = []
343
+ context = example["context"]
344
+
345
+ # Loop through features associated with the current example
346
+ for feature_index in feature_indices:
347
+ start_logits = all_start_logits[feature_index]
348
+ end_logits = all_end_logits[feature_index]
349
+ offset_mapping = features[feature_index]["offset_mapping"]
350
+
351
+ # Update minimum null prediction score
352
+ cls_index = features[feature_index]["input_ids"].index(tokenizer.cls_token_id)
353
+ feature_null_score = start_logits[cls_index] + end_logits[cls_index]
354
+ if min_null_score is None or min_null_score < feature_null_score:
355
+ min_null_score = feature_null_score
356
+
357
+ # Go through all possibilities for start/end positions
358
+ start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()
359
+ end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()
360
+ for start_index in start_indexes:
361
+ for end_index in end_indexes:
362
+ # Skip invalid pairs (start > end, index out of bounds, answer in question part)
363
+ if start_index >= len(offset_mapping) or end_index >= len(offset_mapping) or \
364
+ offset_mapping[start_index] is None or offset_mapping[end_index] is None or \
365
+ end_index < start_index:
366
+ continue
367
+
368
+ # Check answer length
369
+ if end_index - start_index + 1 > max_answer_length:
370
+ continue
371
+
372
+ # Extract text and score
373
+ start_char = offset_mapping[start_index][0]
374
+ end_char = offset_mapping[end_index][1]
375
+ score = start_logits[start_index] + end_logits[end_index]
376
+
377
+ valid_answers.append({
378
+ "score": score,
379
+ "text": context[start_char: end_char]
380
+ })
381
+
382
+ # Select the best answer across all features for this example
383
+ if len(valid_answers) > 0:
384
+ best_answer = sorted(valid_answers, key=lambda x: x["score"], reverse=True)[0]
385
+ else:
386
+ # Fallback for no valid answers found
387
+ best_answer = {"text": "", "score": min_null_score} # Fall back to an empty answer carrying the null (CLS) score
388
+
389
+ # Assign final prediction for this example
390
+ # Simple version: always take the best scoring valid answer
391
+ # More sophisticated versions might compare best_answer["score"] vs min_null_score
392
+ predictions[example["id"]] = best_answer["text"]
393
+
394
+
395
+ return predictions
396
+
397
+ logger.info("Starting post-processing...")
398
+ final_predictions = postprocess_qa_predictions(raw_datasets, eval_dataset, predictions_dict)
399
+
400
+
401
+ # --- 5. Compute Metrics ---
402
+ logger.info("Calculating SQuAD metrics...")
403
+ metric = hf_evaluate.load("squad")
404
+
405
+ # Format predictions and references for the metric
406
+ formatted_predictions = [{"id": k, "prediction_text": v} for k, v in final_predictions.items()]
407
+ formatted_references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in raw_datasets]
408
+
409
+ results = metric.compute(predictions=formatted_predictions, references=formatted_references)
410
+
411
+ logger.info("***** Evaluation Results *****")
412
+ print(results)
413
+
414
+ # --- 6. Analyze Multi-step Reasoning Metrics ---
415
+ logger.info("\n***** Multi-step Reasoning Analysis *****")
416
+
417
+ # Calculate average number of reasoning steps
418
+ if reasoning_steps_taken:
419
+ avg_steps = sum(reasoning_steps_taken) / len(reasoning_steps_taken)
420
+ logger.info(f"Average reasoning steps: {avg_steps:.2f}")
421
+
422
+ # Count frequency of each step count
423
+ step_counts = collections.Counter(reasoning_steps_taken)
424
+ logger.info(f"Step count distribution: {dict(sorted(step_counts.items()))}")
425
+
426
+ # Calculate average delta magnitudes per step
427
+ if delta_magnitudes:
428
+ # Transpose to get step-wise averages
429
+ steps_delta_magnitudes = defaultdict(list)
430
+ for batch_deltas in delta_magnitudes:
431
+ for step_idx, delta in enumerate(batch_deltas):
432
+ steps_delta_magnitudes[step_idx].append(delta)
433
+
434
+ avg_delta_by_step = {step: sum(deltas)/len(deltas) for step, deltas in steps_delta_magnitudes.items()}
435
+ logger.info(f"Average delta magnitude by step: {avg_delta_by_step}")
436
+
437
+ # Calculate average gate values per step
438
+ if gate_values:
439
+ # Transpose to get step-wise averages
440
+ steps_gate_values = defaultdict(list)
441
+ for batch_gates in gate_values:
442
+ for step_idx, gate in enumerate(batch_gates):
443
+ steps_gate_values[step_idx].append(gate)
444
+
445
+ avg_gate_by_step = {step: sum(gates)/len(gates) for step, gates in steps_gate_values.items()}
446
+ logger.info(f"Average gate value by step: {avg_gate_by_step}")
447
+
448
+ # Calculate average change from initial to final predictions
449
+ if initial_vs_final_changes:
450
+ avg_change = sum(initial_vs_final_changes) / len(initial_vs_final_changes)
451
+ logger.info(f"Average change from initial to final predictions: {avg_change:.4f}")
452
+
453
+ # --- 7. Save Results ---
454
+ results_file = os.path.join(OUTPUT_DIR, "eval_results.json")
455
+ with open(results_file, 'w') as f:
456
+ # Combine SQuAD metrics with multi-step reasoning metrics
457
+ full_results = {
458
+ "squad_metrics": results,
459
+ "multi_step_metrics": {
460
+ "avg_reasoning_steps": avg_steps if reasoning_steps_taken else None,
461
+ "step_count_distribution": dict(sorted(step_counts.items())) if reasoning_steps_taken else None,
462
+ "avg_delta_by_step": avg_delta_by_step if delta_magnitudes else None,
463
+ "avg_gate_by_step": avg_gate_by_step if gate_values else None,
464
+ "avg_prediction_change": avg_change if initial_vs_final_changes else None
465
+ }
466
+ }
467
+ json.dump(full_results, f, indent=2)
468
+
469
+ logger.info(f"Results saved to {results_file}")
470
+
471
+ # --- 8. Generate Visualizations (if requested) ---
472
+ if args.visualize and (delta_magnitudes or gate_values or reasoning_steps_taken):
473
+ logger.info("Generating visualizations...")
474
+
475
+ # Create visualization directory
476
+ viz_dir = os.path.join(OUTPUT_DIR, "visualizations")
477
+ os.makedirs(viz_dir, exist_ok=True)
478
+
479
+ # Plot step distribution
480
+ if reasoning_steps_taken:
481
+ plt.figure(figsize=(10, 6))
482
+ plt.bar(step_counts.keys(), step_counts.values())
483
+ plt.xlabel('Number of Reasoning Steps')
484
+ plt.ylabel('Frequency')
485
+ plt.title('Distribution of Reasoning Steps')
486
+ plt.savefig(os.path.join(viz_dir, 'step_distribution.png'))
487
+ plt.close()
488
+
489
+ # Plot delta magnitudes by step
490
+ if delta_magnitudes and steps_delta_magnitudes:
491
+ plt.figure(figsize=(10, 6))
492
+ steps = sorted(steps_delta_magnitudes.keys())
493
+ values = [avg_delta_by_step[step] for step in steps]
494
+ plt.plot(steps, values, marker='o')
495
+ plt.xlabel('Reasoning Step')
496
+ plt.ylabel('Average Delta Magnitude')
497
+ plt.title('Delta Magnitude by Reasoning Step')
498
+ plt.grid(True)
499
+ plt.savefig(os.path.join(viz_dir, 'delta_magnitudes.png'))
500
+ plt.close()
501
+
502
+ # Plot gate values by step
503
+ if gate_values and steps_gate_values:
504
+ plt.figure(figsize=(10, 6))
505
+ steps = sorted(steps_gate_values.keys())
506
+ values = [avg_gate_by_step[step] for step in steps]
507
+ plt.plot(steps, values, marker='o')
508
+ plt.xlabel('Reasoning Step')
509
+ plt.ylabel('Average Gate Value')
510
+ plt.title('Gate Value by Reasoning Step')
511
+ plt.grid(True)
512
+ plt.savefig(os.path.join(viz_dir, 'gate_values.png'))
513
+ plt.close()
514
+
515
+ logger.info(f"Visualizations saved to {viz_dir}")
516
+
517
+ if __name__ == "__main__":
518
+ # This is required for Windows to properly handle multiprocessing
519
+ multiprocessing.freeze_support()
520
+ main()
521
+
522
+ # Example usage:
523
+ # Test with default settings (epoch 3 checkpoint):
524
+ # python test_model.py
525
+
526
+ # Test with specific checkpoint:
527
+ # python test_model.py --checkpoint ./rrn_qa_model_epoch_2
528
+
529
+ # Test with fixed number of reasoning steps:
530
+ # python test_model.py --fixed_steps 3
531
+
532
+ # Test with active memory:
533
+ # python test_model.py --use_memory
534
+
535
+ # Test with visualizations:
536
+ # python test_model.py --visualize
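For reference, the heart of `postprocess_qa_predictions` above is the n-best span search over start/end logits. Below is a minimal, self-contained sketch of that selection step on toy inputs; the `pick_best_span` name and all values are illustrative, not part of the repository.

```python
# Minimal sketch of the n-best span selection used in postprocess_qa_predictions.
# Toy logits and offsets stand in for real model outputs (illustrative only).
import numpy as np

def pick_best_span(start_logits, end_logits, offset_mapping, context,
                   n_best_size=20, max_answer_length=30):
    start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()
    end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()
    candidates = []
    for s in start_indexes:
        for e in end_indexes:
            # Skip pairs that are out of bounds, outside the context (offset None),
            # reversed, or longer than the allowed answer length.
            if s >= len(offset_mapping) or e >= len(offset_mapping):
                continue
            if offset_mapping[s] is None or offset_mapping[e] is None or e < s:
                continue
            if e - s + 1 > max_answer_length:
                continue
            candidates.append({
                "score": float(start_logits[s] + end_logits[e]),
                "text": context[offset_mapping[s][0] : offset_mapping[e][1]],
            })
    return max(candidates, key=lambda x: x["score"]) if candidates else {"score": None, "text": ""}

context = "Paris is the capital of France."
# One token per word; offsets are (start_char, end_char) into context. [CLS] maps to None.
offsets = [None, (0, 5), (6, 8), (9, 12), (13, 20), (21, 23), (24, 30)]
start = np.array([0.1, 3.0, 0.0, 0.0, 0.2, 0.0, 0.5])
end   = np.array([0.1, 2.5, 0.0, 0.0, 0.3, 0.0, 0.4])
print(pick_best_span(start, end, offsets, context))  # -> {'score': 5.5, 'text': 'Paris'}
```

Using `max` with a key avoids sorting the full candidate list, but is otherwise equivalent to the `sorted(...)[0]` call in the script.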
code/train.py ADDED
@@ -0,0 +1,389 @@
1
+ # train.py (Updated for Full Fine-tuning)
2
+ import torch
3
+ import torch.optim as optim
4
+ from torch.utils.data import DataLoader
5
+ from torch.amp import autocast, GradScaler # For mixed precision training (updated import)
6
+ from transformers import AutoTokenizer, default_data_collator
7
+ from datasets import load_dataset
8
+ from tqdm.auto import tqdm # Progress bar
9
+ import os
10
+ import evaluate # For metrics
11
+ import logging # Optional: Better logging
12
+ import multiprocessing # For Windows multiprocessing support
13
+ import argparse # For command line arguments
+ import json # For saving the enhanced configuration
14
+
15
+ # Import our custom modules and config
16
+ import config
17
+ from model import EnhancedRRN_QA_Model
18
+
19
+ # Setup basic logging
20
+ logging.basicConfig(level=logging.INFO)
21
+ logger = logging.getLogger(__name__)
22
+
23
+ def main():
24
+ # Parse command line arguments
25
+ parser = argparse.ArgumentParser(description="Train RRN QA Model")
26
+ parser.add_argument("--checkpoint", type=str, help="Path to checkpoint directory to resume from")
27
+ parser.add_argument("--start_epoch", type=int, default=0, help="Epoch to start training from")
28
+ parser.add_argument(
29
+ "--subset_percentage",
30
+ type=float,
31
+ default=100.0,
32
+ help="Percentage of training data to use (0.1-100.0). Default: 100.0 (full dataset)"
33
+ )
34
+ parser.add_argument(
35
+ "--bypass_delta",
36
+ action="store_true",
37
+ help="Bypass RRN delta calculation (sets delta = torch.zeros_like(h0))"
38
+ )
39
+ args = parser.parse_args()
40
+
41
+ # Set bypass delta calculation flag if specified
42
+ if args.bypass_delta:
43
+ logger.info("BYPASS_DELTA_CALCULATION enabled: Setting delta = torch.zeros_like(h0)")
44
+ config.BYPASS_DELTA_CALCULATION = True
45
+ else:
46
+ config.BYPASS_DELTA_CALCULATION = False
47
+
48
+ # --- 1. Load Tokenizer and Model ---
49
+ if args.checkpoint:
50
+ logger.info(f"Loading tokenizer from checkpoint: {args.checkpoint}")
51
+ tokenizer = AutoTokenizer.from_pretrained(args.checkpoint)
52
+
53
+ logger.info(f"Loading model from checkpoint: {args.checkpoint}")
54
+ # Initialize the model with base architecture
55
+ model = EnhancedRRN_QA_Model(os.path.join(args.checkpoint, "base_model"))
56
+
57
+ # Check for enhanced model components
58
+ gating_mechanism_path = os.path.join(args.checkpoint, "gating_mechanism.pth")
59
+ is_enhanced_checkpoint = os.path.exists(gating_mechanism_path)
60
+
61
+ # Load custom module weights
62
+ logger.info("Loading model components...")
63
+ model.qa_head.load_state_dict(torch.load(os.path.join(args.checkpoint, "qa_head.pth")))
64
+ model.retroactive_update_layer.load_state_dict(torch.load(os.path.join(args.checkpoint, "retroactive_layer.pth")))
65
+
66
+ # Load gating mechanism if available
67
+ if is_enhanced_checkpoint:
68
+ logger.info("Loading gating mechanism...")
69
+ model.gating_mechanism.load_state_dict(torch.load(gating_mechanism_path))
70
+
71
+ # Load step controller if available (for learned dynamic steps)
72
+ step_controller_path = os.path.join(args.checkpoint, "step_controller.pth")
73
+ if os.path.exists(step_controller_path) and hasattr(model, "step_controller"):
74
+ logger.info("Loading step controller for learned dynamic steps...")
75
+ model.step_controller.load_state_dict(torch.load(step_controller_path))
76
+ else:
77
+ logger.info("Loading tokenizer...")
78
+ tokenizer = AutoTokenizer.from_pretrained(config.BASE_MODEL_NAME)
79
+
80
+ logger.info("Instantiating Enhanced RRN QA Model for Full Fine-tuning...")
81
+ model = EnhancedRRN_QA_Model(config.BASE_MODEL_NAME)
82
+
83
+ model.to(config.DEVICE)
84
+
85
+ # --- 2. Load and Preprocess Dataset ---
86
+ logger.info("Loading SQuAD dataset...")
87
+ raw_datasets = load_dataset("squad")
88
+
89
+ # Handle dataset subsetting
90
+ subset_percentage = args.subset_percentage
91
+ if subset_percentage < 100.0:
92
+ original_train_size = len(raw_datasets["train"])
93
+
94
+ # Calculate subset size and validate
95
+ subset_percentage = max(0.1, min(100.0, subset_percentage)) # Clamp between 0.1% and 100%
96
+ train_subset_size = int(original_train_size * subset_percentage / 100)
97
+ train_subset_size = max(100, min(original_train_size, train_subset_size)) # Ensure reasonable bounds
98
+
99
+ # Create reproducible subset with fixed seed for consistency
100
+ subset_indices = torch.randperm(original_train_size, generator=torch.Generator().manual_seed(42))[:train_subset_size].tolist()
101
+ raw_datasets["train"] = raw_datasets["train"].select(subset_indices)
102
+
103
+ logger.info(f"Using {subset_percentage:.1f}% of training data ({train_subset_size}/{original_train_size} examples)")
104
+ else:
105
+ logger.info(f"Using full training dataset ({len(raw_datasets['train'])} examples)")
106
+
107
+ question_column_name = "question"
108
+ context_column_name = "context"
109
+ answer_column_name = "answers"
110
+ pad_on_right = tokenizer.padding_side == "right"
111
+
112
+ def prepare_train_features(examples):
113
+ examples[question_column_name] = [q.strip() for q in examples[question_column_name]]
114
+ tokenized_examples = tokenizer(
115
+ examples[question_column_name if pad_on_right else context_column_name],
116
+ examples[context_column_name if pad_on_right else question_column_name],
117
+ truncation="only_second" if pad_on_right else "only_first",
118
+ max_length=config.MAX_SEQ_LENGTH,
119
+ stride=config.DOC_STRIDE,
120
+ return_overflowing_tokens=True,
121
+ return_offsets_mapping=True,
122
+ padding="max_length",
123
+ )
124
+ sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
125
+ offset_mapping = tokenized_examples.pop("offset_mapping")
126
+ tokenized_examples["start_positions"] = []
127
+ tokenized_examples["end_positions"] = []
128
+
129
+ for i, offsets in enumerate(offset_mapping):
130
+ input_ids = tokenized_examples["input_ids"][i]
131
+ cls_index = input_ids.index(tokenizer.cls_token_id)
132
+ sequence_ids = tokenized_examples.sequence_ids(i)
133
+ sample_index = sample_mapping[i]
134
+ answers = examples[answer_column_name][sample_index]
135
+
136
+ if len(answers["answer_start"]) == 0:
137
+ tokenized_examples["start_positions"].append(cls_index)
138
+ tokenized_examples["end_positions"].append(cls_index)
139
+ else:
140
+ start_char = answers["answer_start"][0]
141
+ end_char = start_char + len(answers["text"][0])
142
+ token_start_index = 0
143
+ while sequence_ids[token_start_index] != (1 if pad_on_right else 0):
144
+ token_start_index += 1
145
+ token_end_index = len(input_ids) - 1
146
+ while sequence_ids[token_end_index] != (1 if pad_on_right else 0):
147
+ token_end_index -= 1
148
+ if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):
149
+ tokenized_examples["start_positions"].append(cls_index)
150
+ tokenized_examples["end_positions"].append(cls_index)
151
+ else:
152
+ while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
153
+ token_start_index += 1
154
+ tokenized_examples["start_positions"].append(token_start_index - 1)
155
+ while offsets[token_end_index][1] >= end_char:
156
+ token_end_index -= 1
157
+ tokenized_examples["end_positions"].append(token_end_index + 1)
158
+ return tokenized_examples
159
+
160
+ logger.info("Preprocessing datasets...")
161
+ # Use single process on Windows to avoid multiprocessing issues
162
+ tokenized_datasets = raw_datasets.map(
163
+ prepare_train_features,
164
+ batched=True,
165
+ remove_columns=raw_datasets["train"].column_names,
166
+ num_proc=1 # Use single process to avoid Windows multiprocessing issues
167
+ )
168
+
169
+ data_collator = default_data_collator
170
+ train_dataloader = DataLoader(
171
+ tokenized_datasets["train"],
172
+ shuffle=True,
173
+ collate_fn=data_collator,
174
+ batch_size=config.BATCH_SIZE
175
+ )
176
+ # Consider adding validation dataloader setup here as well
177
+ # eval_dataloader = DataLoader(...)
178
+
179
+ # --- 3. Setup Optimizer ---
180
+ logger.info("Setting up optimizer for FULL model fine-tuning...")
181
+ # Optimize all parameters since PEFT is disabled
182
+ optimizer = optim.AdamW(model.parameters(), lr=config.LEARNING_RATE)
183
+
184
+ logger.info(f"Optimizer: AdamW with LR={config.LEARNING_RATE}")
185
+ # Calculate total steps considering gradient accumulation
186
+ num_update_steps_per_epoch = len(train_dataloader) // config.GRADIENT_ACCUMULATION_STEPS
187
+ num_training_steps = config.EPOCHS * num_update_steps_per_epoch
188
+ logger.info(f"Total optimization steps: {num_training_steps}")
189
+
190
+
191
+ # --- 4. Initialize Mixed Precision Training ---
192
+ # Initialize gradient scaler for mixed precision training
193
+ scaler = GradScaler('cuda', enabled=config.USE_MIXED_PRECISION) # Updated to fix deprecation warning
194
+
195
+ # Log mixed precision and dynamic steps status
196
+ if config.USE_MIXED_PRECISION:
197
+ logger.info("Mixed precision training (FP16) enabled")
198
+ if config.USE_DYNAMIC_STEPS:
199
+ logger.info(f"Dynamic reasoning steps enabled (type: {config.REASONING_STEP_TYPE})")
200
+ logger.info(f"Min steps: {config.MIN_REASONING_STEPS}, Max steps: {config.MAX_REASONING_STEPS}")
201
+
202
+ # Log bypass delta calculation status
203
+ if config.BYPASS_DELTA_CALCULATION:
204
+ logger.info("BYPASS_DELTA_CALCULATION enabled: Delta calculation is bypassed (delta = torch.zeros_like(h0))")
205
+
206
+ # --- 5. Training Loop ---
207
+ logger.info("***** Starting Training *****")
208
+ logger.info(f" Num examples = {len(tokenized_datasets['train'])}")
209
+ logger.info(f" Num Epochs = {config.EPOCHS}")
210
+ logger.info(f" Instantaneous batch size per device = {config.BATCH_SIZE}")
211
+ logger.info(f" Gradient Accumulation steps = {config.GRADIENT_ACCUMULATION_STEPS}")
212
+ logger.info(f" Total optimization steps = {num_training_steps}")
213
+
214
+ # Add note about subset training if applicable
215
+ if subset_percentage < 100.0:
216
+ logger.info(f" NOTE: Training on {subset_percentage:.1f}% of data - metrics may not represent full dataset performance")
217
+
218
+
219
+ model.train() # Set model to training mode
220
+ global_step = 0
221
+ total_loss = 0.0 # Use float for accumulated loss
222
+
223
+ # Start from specified epoch (default is 0 if not provided)
224
+ start_epoch = args.start_epoch
225
+
226
+ for epoch in range(start_epoch, config.EPOCHS):
227
+ logger.info(f"\n--- Starting Epoch {epoch+1}/{config.EPOCHS} ---")
228
+ progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}", unit="batch")
229
+
230
+ for step, batch in enumerate(progress_bar):
231
+ # Move batch to device
232
+ # Ensure only tensors are moved, handle potential non-tensor data if any
233
+ batch_on_device = {}
234
+ for k, v in batch.items():
235
+ if isinstance(v, torch.Tensor):
236
+ batch_on_device[k] = v.to(config.DEVICE)
237
+ # else: # Handle or skip non-tensor items if necessary
238
+ # batch_on_device[k] = v
239
+
240
+ try:
241
+ # Forward pass with autocast for mixed precision
242
+ with autocast('cuda', enabled=config.USE_MIXED_PRECISION): # Updated to fix deprecation warning
243
+ outputs = model(
244
+ input_ids=batch_on_device.get("input_ids"),
245
+ attention_mask=batch_on_device.get("attention_mask"),
246
+ token_type_ids=batch_on_device.get("token_type_ids"),
247
+ start_positions=batch_on_device.get("start_positions"),
248
+ end_positions=batch_on_device.get("end_positions"),
249
+ use_memory=False # Disable memory during training steps
250
+ )
251
+ loss = outputs.loss
252
+
253
+ if loss is None:
254
+ logger.warning(f"Step {step}: Loss is None. Skipping batch.")
255
+ continue
256
+
257
+ # Scale loss for gradient accumulation
258
+ loss = loss / config.GRADIENT_ACCUMULATION_STEPS
259
+
260
+ # Accumulate loss value for logging (before backward)
261
+ total_loss += loss.item()
262
+
263
+ # Scale loss and perform backward pass with AMP
264
+ scaler.scale(loss).backward()
265
+
266
+ except Exception as e:
267
+ logger.error(f"Error during forward/backward pass at step {step}: {e}")
268
+ # Optional: Add more detailed error handling or debugging info
269
+ # logger.error(f"Batch keys: {batch.keys()}")
270
+ # logger.error(f"Input IDs shape: {batch_on_device.get('input_ids').shape if batch_on_device.get('input_ids') is not None else 'None'}")
271
+ raise # Re-raise the exception to stop training
272
+
273
+ # Optimizer step (perform step only after accumulating gradients)
274
+ if (step + 1) % config.GRADIENT_ACCUMULATION_STEPS == 0 or step == len(train_dataloader) - 1:
275
+ # Unscale before optimizer step (to check for infs/NaNs)
276
+ scaler.unscale_(optimizer)
277
+
278
+ # Clip gradients to avoid explosion (optional but recommended with mixed precision)
279
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
280
+
281
+ # Step with scaler
282
+ scaler.step(optimizer)
283
+ scaler.update()
284
+ optimizer.zero_grad() # Reset gradients for the next accumulation cycle
285
+ global_step += 1
286
+
287
+ # Log progress periodically
288
+ if global_step % 50 == 0: # Log every 50 optimization steps
289
+ avg_loss = total_loss / 50 # Average loss over the last 50 steps
290
+ logger.info(f"Step: {global_step}, Avg Loss: {avg_loss:.4f}")
291
+ total_loss = 0.0 # Reset loss accumulator
292
+
293
+ # Update progress bar description with current step loss and steps info
294
+ postfix = {
295
+ "Loss": f"{loss.item()*config.GRADIENT_ACCUMULATION_STEPS:.4f}",
296
+ "Step": global_step
297
+ }
298
+
299
+ # Add steps info if using dynamic steps
300
+ if config.USE_DYNAMIC_STEPS and hasattr(model, 'custom_outputs'):
301
+ if 'steps_taken' in model.custom_outputs:
302
+ postfix["Steps"] = model.custom_outputs['steps_taken']
303
+
304
+ progress_bar.set_postfix(postfix)
305
+
306
+
307
+ # --- (Optional) Evaluation at the end of each epoch ---
308
+ # logger.info(f"\n--- Evaluating after Epoch {epoch+1} ---")
309
+ # model.eval()
310
+ # # Add evaluation loop here (requires validation dataloader, postprocessing, metrics)
311
+ # model.train() # Set back to train mode
312
+
313
+ # --- Save Model Checkpoint ---
314
+ output_dir = f"./rrn_qa_model_epoch_{epoch+1}"
315
+ os.makedirs(output_dir, exist_ok=True)
316
+ logger.info(f"--- Saving model checkpoint to {output_dir} ---")
317
+
318
+ # --- Saving Logic for Enhanced Model ---
319
+ try:
320
+ logger.info(f"Saving enhanced model components to {output_dir}")
321
+ # Save base model using its save_pretrained
322
+ model.base_model.save_pretrained(os.path.join(output_dir, "base_model"))
323
+
324
+ # Save all custom modules' state dicts
325
+ torch.save(model.qa_head.state_dict(), os.path.join(output_dir, "qa_head.pth"))
326
+ torch.save(model.retroactive_update_layer.state_dict(), os.path.join(output_dir, "retroactive_layer.pth"))
327
+ torch.save(model.gating_mechanism.state_dict(), os.path.join(output_dir, "gating_mechanism.pth"))
328
+
329
+ # Save step controller if using learned dynamic steps
330
+ if config.USE_DYNAMIC_STEPS and config.REASONING_STEP_TYPE == "learned" and hasattr(model, "step_controller"):
331
+ torch.save(model.step_controller.state_dict(), os.path.join(output_dir, "step_controller.pth"))
332
+ logger.info("Saved step controller for learned dynamic steps")
333
+
334
+ # Save tokenizer
335
+ tokenizer.save_pretrained(output_dir)
336
+
337
+ # Save configuration
338
+ with open(os.path.join(output_dir, "enhanced_config.json"), "w") as f:
340
+ config_dict = {
341
+ "num_reasoning_steps": config.NUM_REASONING_STEPS,
342
+ "delta_target_ratio": config.DELTA_TARGET_RATIO,
343
+ "lambda_coherence": config.LAMBDA_COHERENCE,
344
+ "lambda_delta_reg": config.LAMBDA_DELTA_REG,
345
+ "memory_max_size": config.MEMORY_MAX_SIZE,
346
+ "memory_retrieval_k": config.MEMORY_RETRIEVAL_K,
347
+ "use_mixed_precision": config.USE_MIXED_PRECISION,
348
+ "bypass_delta_calculation": config.BYPASS_DELTA_CALCULATION
349
+ }
350
+
351
+ # Add dynamic steps configuration if enabled
352
+ if config.USE_DYNAMIC_STEPS:
353
+ config_dict.update({
354
+ "use_dynamic_steps": config.USE_DYNAMIC_STEPS,
355
+ "max_reasoning_steps": config.MAX_REASONING_STEPS,
356
+ "min_reasoning_steps": config.MIN_REASONING_STEPS,
357
+ "reasoning_step_type": config.REASONING_STEP_TYPE,
358
+ "early_stop_threshold": config.EARLY_STOP_THRESHOLD
359
+ })
360
+
361
+ json.dump(config_dict, f, indent=2)
362
+
363
+ logger.info("Enhanced model checkpoint saved successfully.")
364
+ except Exception as e:
365
+ logger.error(f"Error saving checkpoint at epoch {epoch+1}: {e}")
366
+
367
+
368
+ logger.info("\n***** Training finished *****")
369
+
370
+ if __name__ == "__main__":
371
+ # This is required for Windows to properly handle multiprocessing
372
+ multiprocessing.freeze_support()
373
+ main()
374
+
375
+ # Example usage:
376
+ # Train on full dataset (default):
377
+ # python train.py
378
+
379
+ # Train on 10% of data for faster iterations:
380
+ # python train.py --subset_percentage 10.0
381
+
382
+ # Train on 1% for very quick testing:
383
+ # python train.py --subset_percentage 1.0
384
+
385
+ # Resume training from checkpoint with subset:
386
+ # python train.py --checkpoint ./rrn_qa_model_epoch_1 --start_epoch 1 --subset_percentage 25.0
387
+
388
+ # Test with bypassed delta calculation (sets delta = torch.zeros_like(h0)):
389
+ # python train.py --bypass_delta --subset_percentage 1.0
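The subtlest part of `prepare_train_features` above is mapping a character-level answer span to token indices via the offset mapping. The sketch below isolates that alignment with hand-built offsets; `char_span_to_token_span` and all example values are illustrative, not repository code.

```python
# Sketch of the start/end token alignment in prepare_train_features.
# Hand-built offsets replace real tokenizer output (illustrative only).

def char_span_to_token_span(offsets, sequence_ids, start_char, end_char, cls_index=0):
    # cls_index would normally be input_ids.index(tokenizer.cls_token_id).
    # Find the first and last tokens that belong to the context (sequence id 1).
    token_start = 0
    while sequence_ids[token_start] != 1:
        token_start += 1
    token_end = len(offsets) - 1
    while sequence_ids[token_end] != 1:
        token_end -= 1
    # If the answer is not fully inside this feature's context window, label CLS.
    if not (offsets[token_start][0] <= start_char and offsets[token_end][1] >= end_char):
        return cls_index, cls_index
    # Walk inward to the tightest token span covering [start_char, end_char).
    while token_start < len(offsets) and offsets[token_start][0] <= start_char:
        token_start += 1
    while offsets[token_end][1] >= end_char:
        token_end -= 1
    return token_start - 1, token_end + 1

# "[CLS] what is it [SEP] Paris is nice [SEP]" with per-token character offsets
offsets      = [(0, 0), (0, 4), (5, 7), (8, 10), (0, 0), (0, 5), (6, 8), (9, 13), (0, 0)]
sequence_ids = [None,   0,      0,      0,       None,   1,      1,      1,       None]
print(char_span_to_token_span(offsets, sequence_ids, start_char=0, end_char=5))  # (5, 5) -> "Paris"
```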
enhanced_config.json ADDED
@@ -0,0 +1,15 @@
1
+ {
2
+ "num_reasoning_steps": 3,
3
+ "delta_target_ratio": 0.2,
4
+ "lambda_coherence": 0.1,
5
+ "lambda_delta_reg": 0.5,
6
+ "memory_max_size": 50,
7
+ "memory_retrieval_k": 3,
8
+ "use_mixed_precision": false,
9
+ "bypass_delta_calculation": false,
10
+ "use_dynamic_steps": true,
11
+ "max_reasoning_steps": 5,
12
+ "min_reasoning_steps": 1,
13
+ "reasoning_step_type": "learned",
14
+ "early_stop_threshold": 0.01
15
+ }
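`delta_target_ratio` and `lambda_delta_reg` suggest a regularizer that keeps each retroactive update small relative to the hidden state it modifies. Since `model.py` is not part of this commit, the following is only one plausible reading of those two fields, not the repository's implementation:

```python
# Hypothetical sketch: penalize delta updates whose norm strays from
# delta_target_ratio * ||h0||. Mirrors the config fields above but is
# NOT taken from the repository's model.py.
import torch

def delta_regularizer(h0, delta, delta_target_ratio=0.2, lambda_delta_reg=0.5):
    h_norm = h0.norm(dim=-1)                               # per-token hidden-state norms
    delta_ratio = delta.norm(dim=-1) / (h_norm + 1e-8)     # relative update magnitude
    # Quadratic penalty on the deviation from the target ratio, averaged over tokens.
    return lambda_delta_reg * ((delta_ratio - delta_target_ratio) ** 2).mean()

h0 = torch.randn(2, 8, 16)           # (batch, seq_len, hidden)
delta = 0.2 * h0                     # an update already at the target ratio
print(delta_regularizer(h0, delta))  # ~0: no penalty at the target ratio
```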
gating_mechanism.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5a1b4605a1297d5d738dda9e6b739764e6412e27053f75df64ebe99f5f49188c
3
+ size 7090146
model-index.json ADDED
@@ -0,0 +1,50 @@
1
+ {
2
+ "name": "Retroactive Reasoning Network for Question Answering",
3
+ "language": "en",
4
+ "license": "apache-2.0",
5
+ "task_categories": [
6
+ "question-answering"
7
+ ],
8
+ "tags": [
9
+ "question-answering",
10
+ "multi-step-reasoning",
11
+ "retroactive-reasoning",
12
+ "squad"
13
+ ],
14
+ "datasets": [
15
+ "squad"
16
+ ],
17
+ "metrics": [
18
+ "exact_match",
19
+ "f1"
20
+ ],
21
+ "model-index": [
22
+ {
23
+ "name": "RRN QA Model",
24
+ "results": [
25
+ {
26
+ "task": {
27
+ "type": "question-answering",
28
+ "name": "Question Answering"
29
+ },
30
+ "dataset": {
31
+ "name": "SQuAD",
32
+ "type": "squad"
33
+ },
34
+ "metrics": [
35
+ {
36
+ "type": "exact_match",
37
+ "value": "TBD",
38
+ "name": "Exact Match"
39
+ },
40
+ {
41
+ "type": "f1",
42
+ "value": "TBD",
43
+ "name": "F1"
44
+ }
45
+ ]
46
+ }
47
+ ]
48
+ }
49
+ ]
50
+ }
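The `exact_match` and `f1` entries above are the two values returned by the Hugging Face `squad` metric, in the prediction/reference format that `test_model.py` builds before calling `metric.compute`. A quick end-to-end check (the example strings are illustrative):

```python
# The SQuAD metric expects predictions/references in this exact shape,
# matching the formatting done in test_model.py. Example data is illustrative.
import evaluate

squad = evaluate.load("squad")
predictions = [{"id": "q1", "prediction_text": "Paris"}]
references = [{"id": "q1", "answers": {"text": ["Paris"], "answer_start": [0]}}]
print(squad.compute(predictions=predictions, references=references))
# -> {'exact_match': 100.0, 'f1': 100.0}
```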
qa_head.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:70288115fb5d3176e59d873418d1d1e7fd4dc9c3d6911f8edeee93354227b82b
3
+ size 14175663
retroactive_layer.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1faa8764e567c3b5575c5aff1b8c4c3ed7486935730923c692dae3d4122cefbd
3
+ size 59048138
special_tokens_map.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "cls_token": "[CLS]",
3
+ "mask_token": "[MASK]",
4
+ "pad_token": "[PAD]",
5
+ "sep_token": "[SEP]",
6
+ "unk_token": "[UNK]"
7
+ }
step_controller.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cae0d7b2423071c8e36f0d07965a540193fb739638351115bfa8045b9844aac7
3
+ size 398480
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,56 @@
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "[PAD]",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "100": {
12
+ "content": "[UNK]",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "101": {
20
+ "content": "[CLS]",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "102": {
28
+ "content": "[SEP]",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "103": {
36
+ "content": "[MASK]",
37
+ "lstrip": false,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ }
43
+ },
44
+ "clean_up_tokenization_spaces": false,
45
+ "cls_token": "[CLS]",
46
+ "do_lower_case": true,
47
+ "extra_special_tokens": {},
48
+ "mask_token": "[MASK]",
49
+ "model_max_length": 512,
50
+ "pad_token": "[PAD]",
51
+ "sep_token": "[SEP]",
52
+ "strip_accents": null,
53
+ "tokenize_chinese_chars": true,
54
+ "tokenizer_class": "BertTokenizer",
55
+ "unk_token": "[UNK]"
56
+ }
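Both `train.py` and `test_model.py` look up `tokenizer.cls_token_id` at runtime; for a standard BERT-uncased vocabulary this resolves to the ids listed in `added_tokens_decoder` above. A quick sanity check (assumes the stock `bert-base-uncased` checkpoint, which this config matches; swap in the local checkpoint directory if preferred):

```python
# Sanity check: special-token ids should match added_tokens_decoder above.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("bert-base-uncased")  # assumption: standard BERT-uncased vocab
print(tok.pad_token_id, tok.unk_token_id, tok.cls_token_id, tok.sep_token_id, tok.mask_token_id)
# -> 0 100 101 102 103
```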
vocab.txt ADDED
The diff for this file is too large to render. See raw diff