Upload folder using huggingface_hub
- README.md +79 -0
- base_model/config.json +26 -0
- base_model/model.safetensors +3 -0
- code/config.py +75 -0
- code/memory.py +187 -0
- code/model.py +368 -0
- code/modules.py +196 -0
- code/test_model.py +536 -0
- code/train.py +389 -0
- enhanced_config.json +15 -0
- gating_mechanism.pth +3 -0
- model-index.json +50 -0
- qa_head.pth +3 -0
- retroactive_layer.pth +3 -0
- special_tokens_map.json +7 -0
- step_controller.pth +3 -0
- tokenizer.json +0 -0
- tokenizer_config.json +56 -0
- vocab.txt +0 -0
README.md
ADDED
@@ -0,0 +1,79 @@
# Retroactive Reasoning Network (RRN) for Question Answering

## Model Description

This model implements an Enhanced Retroactive Reasoning Network (RRN) for question answering. The RRN architecture enables multi-step reasoning through an iterative refinement process that retroactively updates hidden states.

### Key Features

- **Multi-step Reasoning**: The model performs 3 reasoning steps to iteratively refine its predictions.
- **Dynamic Reasoning Steps**: Enabled; a learned controller chooses the number of steps (min: 1, max: 5).
- **Gating Mechanism**: Selectively applies updates to hidden states.
- **Delta Magnitude Constraint**: Prevents destabilizing updates with a target delta-to-hidden-state norm ratio of 0.2 (see the sketch after this list).
- **Active Memory**: Stores and retrieves past examples to enhance reasoning.
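The per-step update combines the delta constraint and the gate. The following is a minimal sketch of that rule; `delta_fn` and `gate_fn` are hypothetical stand-ins for the repository's `CrossAttentionDelta` and `GatingMechanism` modules, and the full logic lives in `code/model.py`:

```python
import torch

def reasoning_step(h, h0, delta_fn, gate_fn, target_ratio=0.2):
    """One retroactive update: cap the delta norm relative to H(0), then gate it."""
    delta = delta_fn(h)                                         # candidate update
    ratio = delta.norm(dim=-1, keepdim=True) / (h0.norm(dim=-1, keepdim=True) + 1e-9)
    delta = delta * torch.clamp(target_ratio / ratio, max=1.0)  # shrink oversized deltas only
    gate = gate_fn(h, delta)                                    # per-token gate in (0, 1)
    return h + gate * delta
```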
## Usage

```python
from transformers import AutoTokenizer
from model import EnhancedRRN_QA_Model

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("[MODEL_REPO_ID]")
model = EnhancedRRN_QA_Model("[MODEL_REPO_ID]/base_model")

# Load custom components
import torch
import os

model.qa_head.load_state_dict(torch.load(os.path.join("[MODEL_REPO_ID]", "qa_head.pth"), map_location="cpu"))
model.retroactive_update_layer.load_state_dict(torch.load(os.path.join("[MODEL_REPO_ID]", "retroactive_layer.pth"), map_location="cpu"))
model.gating_mechanism.load_state_dict(torch.load(os.path.join("[MODEL_REPO_ID]", "gating_mechanism.pth"), map_location="cpu"))

# If using learned dynamic steps
if os.path.exists(os.path.join("[MODEL_REPO_ID]", "step_controller.pth")) and hasattr(model, "step_controller"):
    model.step_controller.load_state_dict(torch.load(os.path.join("[MODEL_REPO_ID]", "step_controller.pth"), map_location="cpu"))

# Example usage
inputs = tokenizer("What is the capital of France?", "Paris is the capital of France.", return_tensors="pt")
outputs = model(**inputs)
```
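
After a forward pass, the model also stashes per-step diagnostics on `model.custom_outputs` (see `code/model.py`). A quick, optional way to inspect the reasoning behavior:

```python
diag = model.custom_outputs
print("reasoning steps taken:", diag["steps_taken"])
print("mean gate value per step:", [g.mean().item() for g in diag["all_gates"]])
print("mean delta norm per step:", [d.norm(dim=-1).mean().item() for d in diag["all_deltas"]])
```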

## Training

This model was trained on the SQuAD dataset using a multi-step reasoning approach. The training code is included in the `code` directory of this repository.

To train your own model:

```bash
python code/train.py
```

To evaluate the model:

```bash
python code/test_model.py
```

## Model Architecture

The RRN architecture consists of:

1. A base language model (BERT)
2. A retroactive update layer that computes delta updates
3. A gating mechanism for selective updates
4. An enhanced QA head for answer prediction
5. A step controller for dynamic reasoning steps (if enabled)

A sketch of how these components fit together appears after this list.
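The following is a simplified sketch of the dataflow using the repository's own modules. It assumes `code/` is on the Python path, and it omits the delta magnitude constraint, active memory, and loss terms handled in `code/model.py`:

```python
from transformers import AutoModel, AutoTokenizer
from modules import CrossAttentionDelta, GatingMechanism, EnhancedQAHead

tok = AutoTokenizer.from_pretrained("bert-base-uncased")
base = AutoModel.from_pretrained("bert-base-uncased", output_hidden_states=True)
dim = base.config.hidden_size

delta_layer = CrossAttentionDelta(dim)  # computes the retroactive delta
gate_layer = GatingMechanism(dim)       # decides how much of the delta to apply
qa_head = EnhancedQAHead(dim)           # maps hidden states to start/end logits

enc = tok("What is the capital of France?", "Paris is the capital of France.",
          return_tensors="pt")
out = base(**enc)
h, trace = out.last_hidden_state, out.hidden_states  # H(0) and reasoning trace T

for _ in range(3):                        # fixed 3 reasoning steps
    delta, _ = delta_layer(h, trace)
    h = h + gate_layer(h, delta) * delta  # gated retroactive update

logits = qa_head(h)                       # {"start_logits": ..., "end_logits": ...}
```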

## Citation

If you use this model in your research, please cite:

```
@article{rrn_qa_model,
  title={Retroactive Reasoning Networks for Question Answering},
  author={[Authors]},
  journal={[Journal]},
  year={2025}
}
```
base_model/config.json
ADDED
@@ -0,0 +1,26 @@
{
  "architectures": [
    "BertModel"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "output_hidden_states": true,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "torch_dtype": "float32",
  "transformers_version": "4.51.2",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}
base_model/model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a0e649b208cf08b8748542d4303a944dfe14f9aeefc6bfe2bed4fca9dbb7c0ba
size 437951328
code/config.py
ADDED
@@ -0,0 +1,75 @@
# config.py (Updated to Disable PEFT)
import torch

# --- Model Configuration ---
# Base model from Hugging Face (ensure it's suitable for QA)
# Example: 'bert-base-uncased', 'roberta-base', 'bert-large-uncased-whole-word-masking-finetuned-squad'
BASE_MODEL_NAME = "bert-base-uncased"

# --- RRN Specific Configuration ---
# Coherence loss weight
LAMBDA_COHERENCE = 0.1  # Hyperparameter to tune

# --- Delta Constraint Configuration ---
DELTA_TARGET_RATIO = 0.2  # Target ratio of delta norm to h0 norm
LAMBDA_DELTA_REG = 0.5  # Weight for delta regularization loss

# --- Multi-step Reasoning Configuration ---
NUM_REASONING_STEPS = 3  # Default number of reasoning steps (used when dynamic steps disabled)

# --- Dynamic Reasoning Steps Configuration ---
USE_DYNAMIC_STEPS = True  # Enable/disable dynamic reasoning steps
MAX_REASONING_STEPS = 5  # Maximum number of reasoning steps
MIN_REASONING_STEPS = 1  # Minimum number of reasoning steps
REASONING_STEP_TYPE = "learned"  # Options: "fixed", "confidence", "learned"
EARLY_STOP_THRESHOLD = 0.01  # Delta magnitude threshold for early stopping (used with "confidence")

# --- Mixed Precision Configuration ---
USE_MIXED_PRECISION = False  # Enable/disable mixed precision training

# --- Memory Configuration ---
MEMORY_MAX_SIZE = 50  # Max number of entries in the memory
MEMORY_USE_DURING_TRAINING = False  # Whether to use memory during training
MEMORY_RETRIEVAL_K = 3  # Number of examples to retrieve from memory

# --- PEFT (LoRA) Configuration ---
USE_PEFT = False  # <--- SET TO False TO DISABLE PEFT ---

# --- Optional: Comment out or leave the LoRA specific settings ---
# LORA_R = 8
# LORA_ALPHA = 16
# LORA_DROPOUT = 0.1
# LORA_TARGET_MODULES = ["query", "value"]

# --- Testing Configuration ---
BYPASS_DELTA_CALCULATION = False  # Set to True to bypass delta calculation for testing

# --- Training Configuration ---
# <<< --- Device Detection (CUDA prioritized over MPS) --- >>>
if torch.cuda.is_available():
    DEVICE = "cuda"
    print("CUDA GPU acceleration is available.")
elif torch.backends.mps.is_available():
    DEVICE = "mps"
    print("Apple Silicon MPS acceleration is available.")
else:
    DEVICE = "cpu"
    print("No GPU or MPS acceleration available, using CPU.")
# <<< --- End of Device Detection --- >>>

LEARNING_RATE = 1e-5  # Full fine-tuning often uses a smaller LR than PEFT
EPOCHS = 3
# --- Adjust Batch Size for Full Fine-tuning ---
# Full fine-tuning requires significantly more memory
BATCH_SIZE = 4  # Start smaller, adjust based on your CUDA memory
GRADIENT_ACCUMULATION_STEPS = 8  # Increase to compensate for smaller batch size

# --- Dataset Configuration ---
# Example for SQuAD
MAX_SEQ_LENGTH = 320  # Max input sequence length for QA
DOC_STRIDE = 128  # Stride for overlapping chunks for long documents

print(f"Using device: {DEVICE}")
print(f"Base model: {BASE_MODEL_NAME}")
# Update print statement to reflect PEFT status
print(f"Using PEFT (LoRA): {USE_PEFT} - Full Fine-tuning Enabled")
code/memory.py
ADDED
@@ -0,0 +1,187 @@
# memory.py
from collections import deque
import torch
import torch.nn.functional as F
import config

class ActiveMemory:
    """
    An active memory module that stores and retrieves examples to enhance reasoning.
    Supports both logging for analysis and retrieval for improved predictions.
    """
    def __init__(self, max_size=config.MEMORY_MAX_SIZE, retrieval_k=config.MEMORY_RETRIEVAL_K):
        self.max_size = max_size
        self.retrieval_k = retrieval_k
        self.memory = deque(maxlen=max_size)
        self.device = config.DEVICE
        print(f"Initialized ActiveMemory with max size {self.max_size}, retrieval_k={self.retrieval_k}")

    def add(self, input_data, hidden_states, output, reasoning_trace, final_hidden_states=None, final_output=None):
        """
        Adds a new entry to the memory.

        Args:
            input_data: The input to the model (tokenized IDs, attention masks, etc.)
            hidden_states (H0): Initial hidden states from the base model
            output (y0): Initial prediction from the model
            reasoning_trace (T): Reasoning trace (all hidden states)
            final_hidden_states (H1, optional): Final hidden states after retroactive update
            final_output (y1, optional): Final prediction after retroactive update
        """
        # Create a memory entry with detached tensors moved to CPU
        entry = {
            'input_ids': input_data.get('input_ids', None).cpu().detach() if input_data.get('input_ids', None) is not None else None,
            'attention_mask': input_data.get('attention_mask', None).cpu().detach() if input_data.get('attention_mask', None) is not None else None,
            'token_type_ids': input_data.get('token_type_ids', None).cpu().detach() if input_data.get('token_type_ids', None) is not None else None,
            'hidden_states': hidden_states.cpu().detach(),
            'output': {k: v.cpu().detach() for k, v in output.items()} if isinstance(output, dict) else output.cpu().detach(),
            'reasoning_trace': tuple(h.cpu().detach() for h in reasoning_trace) if isinstance(reasoning_trace, tuple) else reasoning_trace.cpu().detach(),
        }

        # Add final states if provided
        if final_hidden_states is not None:
            entry['final_hidden_states'] = final_hidden_states.cpu().detach()
        if final_output is not None:
            entry['final_output'] = {k: v.cpu().detach() for k, v in final_output.items()} if isinstance(final_output, dict) else final_output.cpu().detach()

        # Compute and store a summary vector for efficient retrieval
        # Use mean pooling of hidden states as the summary vector
        if entry['hidden_states'] is not None and entry['attention_mask'] is not None:
            # Mean pooling with attention mask
            mask = entry['attention_mask'].unsqueeze(-1).float()
            masked_embeddings = entry['hidden_states'] * mask
            sum_embeddings = torch.sum(masked_embeddings, dim=1)
            sum_mask = torch.sum(mask, dim=1)
            sum_mask = torch.clamp(sum_mask, min=1e-9)
            entry['summary_vector'] = (sum_embeddings / sum_mask).squeeze(0)
        else:
            # Fallback to simple mean if attention mask is not available
            entry['summary_vector'] = entry['hidden_states'].mean(dim=1).squeeze(0)

        self.memory.append(entry)

    def retrieve(self, query_hidden_states, query_attention_mask=None, k=None):
        """
        Retrieves the k most similar examples from memory based on hidden state similarity.

        Args:
            query_hidden_states: Hidden states to compare against memory
            query_attention_mask: Attention mask for the query
            k: Number of examples to retrieve (defaults to self.retrieval_k)

        Returns:
            List of retrieved memory entries, ordered by similarity (most similar first)
        """
        if len(self.memory) == 0:
            return []

        if k is None:
            k = self.retrieval_k

        k = min(k, len(self.memory))

        # Compute query summary vector (mean pooling with attention mask)
        if query_attention_mask is not None:
            mask = query_attention_mask.unsqueeze(-1).float()
            masked_embeddings = query_hidden_states * mask
            sum_embeddings = torch.sum(masked_embeddings, dim=1)
            sum_mask = torch.sum(mask, dim=1)
            sum_mask = torch.clamp(sum_mask, min=1e-9)
            query_vector = (sum_embeddings / sum_mask).squeeze(0)
        else:
            query_vector = query_hidden_states.mean(dim=1).squeeze(0)

        # Move query vector to CPU for comparison with memory
        query_vector = query_vector.cpu().detach()

        # Compute similarities with all memory entries
        similarities = []
        for i, entry in enumerate(self.memory):
            memory_vector = entry['summary_vector']
            # Compute cosine similarity
            similarity = F.cosine_similarity(query_vector, memory_vector, dim=0)
            similarities.append((i, similarity.item()))

        # Sort by similarity (descending) and get top k
        similarities.sort(key=lambda x: x[1], reverse=True)
        top_k_indices = [idx for idx, _ in similarities[:k]]

        # Retrieve the top k entries
        retrieved_entries = [self.memory[idx] for idx in top_k_indices]

        # Move retrieved entries to the same device as the query
        device = query_hidden_states.device
        for entry in retrieved_entries:
            # Only move the tensors we'll actually use (hidden_states and final_hidden_states)
            if 'hidden_states' in entry:
                entry['hidden_states'] = entry['hidden_states'].to(device)
            if 'final_hidden_states' in entry:
                entry['final_hidden_states'] = entry['final_hidden_states'].to(device)

        return retrieved_entries

    def get_memory_context(self, query_hidden_states, query_attention_mask=None):
        """
        Retrieves and processes memory entries to create a context tensor for the model.

        Args:
            query_hidden_states: Hidden states to compare against memory
            query_attention_mask: Attention mask for the query

        Returns:
            memory_context: Tensor of shape (batch_size, seq_len, hidden_dim) containing
                processed memory information, or None if memory is empty
        """
        # Retrieve similar examples from memory
        retrieved = self.retrieve(query_hidden_states, query_attention_mask)

        if not retrieved:
            return None

        # Use the device of the query
        device = query_hidden_states.device
        batch_size, seq_len, hidden_dim = query_hidden_states.shape

        # Process retrieved examples to create memory context
        # Strategy: Average the final hidden states of retrieved examples
        memory_tensors = []
        for entry in retrieved:
            # Prefer final hidden states if available, otherwise use initial hidden states
            if 'final_hidden_states' in entry and entry['final_hidden_states'] is not None:
                memory_tensors.append(entry['final_hidden_states'])
            elif 'hidden_states' in entry:
                memory_tensors.append(entry['hidden_states'])

        if not memory_tensors:
            return None

        # Average the memory tensors
        # First ensure all tensors have the same sequence length by padding or truncating
        padded_tensors = []
        for tensor in memory_tensors:
            if tensor.size(1) < seq_len:
                # Pad
                padding = torch.zeros(1, seq_len - tensor.size(1), hidden_dim, device=device)
                padded_tensor = torch.cat([tensor, padding], dim=1)
                padded_tensors.append(padded_tensor)
            elif tensor.size(1) > seq_len:
                # Truncate
                padded_tensors.append(tensor[:, :seq_len, :])
            else:
                padded_tensors.append(tensor)

        # Stack and average
        memory_context = torch.stack(padded_tensors).mean(dim=0)

        # Expand to match batch size if needed
        if memory_context.size(0) == 1 and batch_size > 1:
            memory_context = memory_context.expand(batch_size, -1, -1)

        return memory_context

    def clear(self):
        """Clears all entries from memory."""
        self.memory.clear()

    def __len__(self):
        return len(self.memory)
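
A minimal usage sketch of `ActiveMemory` (hypothetical shapes; assumes `code/` is on the Python path so that `config.py` is importable):

```python
import torch
from memory import ActiveMemory

mem = ActiveMemory(max_size=10, retrieval_k=2)
h0 = torch.randn(1, 8, 768)                       # (batch, seq_len, hidden_dim)
mask = torch.ones(1, 8, dtype=torch.long)

mem.add(
    input_data={'input_ids': torch.randint(0, 30522, (1, 8)), 'attention_mask': mask},
    hidden_states=h0,
    output={'start_logits': torch.randn(1, 8), 'end_logits': torch.randn(1, 8)},
    reasoning_trace=(h0,),
)

# Average of the most similar stored hidden states, aligned to the query length
ctx = mem.get_memory_context(torch.randn(1, 8, 768), mask)
```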
code/model.py
ADDED
@@ -0,0 +1,368 @@
# model.py (Enhanced RRN Implementation)
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForQuestionAnswering, AutoConfig, AutoModel
from transformers.modeling_outputs import QuestionAnsweringModelOutput

import config
from modules import CrossAttentionDelta, GatingMechanism, EnhancedQAHead
from memory import ActiveMemory

class EnhancedRRN_QA_Model(nn.Module):
    """
    Enhanced Retroactive Reasoning Network for Question Answering.
    Improvements:
    1. Delta magnitude constraint
    2. Gating mechanism
    3. Multi-step reasoning
    4. Active memory usage
    5. Enhanced QA head
    6. Improved cross-attention
    """
    def __init__(self, model_name=config.BASE_MODEL_NAME):
        super().__init__()
        self.model_name = model_name

        # --- Configuration ---
        self.num_reasoning_steps = config.NUM_REASONING_STEPS
        self.delta_target_ratio = config.DELTA_TARGET_RATIO

        # --- Dynamic Reasoning Steps Configuration ---
        self.use_dynamic_steps = config.USE_DYNAMIC_STEPS
        self.max_reasoning_steps = config.MAX_REASONING_STEPS
        self.min_reasoning_steps = config.MIN_REASONING_STEPS
        self.reasoning_step_type = config.REASONING_STEP_TYPE
        self.early_stop_threshold = config.EARLY_STOP_THRESHOLD

        # --- Load Base Model Configuration ---
        self.base_config = AutoConfig.from_pretrained(
            self.model_name,
            output_hidden_states=True,  # Crucial for Reasoning Trace (T)
        )
        self.hidden_dim = self.base_config.hidden_size

        # Add step controller for learned approach (after hidden_dim is defined)
        if self.use_dynamic_steps and self.reasoning_step_type == "learned":
            self.step_controller = nn.Sequential(
                nn.Linear(self.hidden_dim, 128),
                nn.ReLU(),
                nn.Linear(128, self.max_reasoning_steps - self.min_reasoning_steps + 1)
            )
            print(f"Using learned dynamic reasoning steps (min={self.min_reasoning_steps}, max={self.max_reasoning_steps})")

        # --- Load Base Model ---
        self.base_model = AutoModel.from_pretrained(
            self.model_name,
            config=self.base_config
        )
        print(f"Loaded base model: {self.model_name}")
        print(f"Hidden dimension: {self.hidden_dim}")
        print(f"Using {self.num_reasoning_steps} reasoning steps")

        # --- Enhanced RRN Components ---
        # Improved cross-attention delta mechanism
        self.retroactive_update_layer = CrossAttentionDelta(self.hidden_dim)

        # Gating mechanism for selective updates
        self.gating_mechanism = GatingMechanism(self.hidden_dim)

        # Enhanced QA head with deeper architecture and bilinear scoring
        self.qa_head = EnhancedQAHead(self.hidden_dim)

        # --- Active Memory Module ---
        self.memory = ActiveMemory(
            max_size=config.MEMORY_MAX_SIZE,
            retrieval_k=config.MEMORY_RETRIEVAL_K
        )

        # --- Loss Functions ---
        self.coherence_loss_fn = nn.MSELoss()
        self.delta_reg_loss_fn = nn.MSELoss()

    def _apply_delta_constraint(self, delta, h0, is_training=False):
        """
        Apply delta magnitude constraint to prevent destabilizing updates.

        Args:
            delta: The computed delta
            h0: The initial hidden states
            is_training: Whether we're in training mode

        Returns:
            constrained_delta: The constrained delta
            delta_reg_loss: Regularization loss for delta magnitude (if training)
        """
        # Compute delta and h0 norms
        delta_norm = delta.norm(dim=-1, keepdim=True)
        h0_norm = h0.norm(dim=-1, keepdim=True).detach()

        # Compute ratio
        ratio = delta_norm / (h0_norm + 1e-9)

        # Compute regularization loss if in training
        delta_reg_loss = None
        if is_training:
            # Target ratio tensor (same shape as ratio)
            target_ratio = torch.ones_like(ratio) * self.delta_target_ratio
            delta_reg_loss = self.delta_reg_loss_fn(ratio, target_ratio)

        # Apply direct constraint (both during training and inference)
        # Only scale down deltas that are too large
        scale_factor = torch.ones_like(ratio)
        too_large = ratio > self.delta_target_ratio
        if too_large.any():
            scale_factor[too_large] = self.delta_target_ratio / ratio[too_large]

        # Apply scaling
        constrained_delta = delta * scale_factor

        return constrained_delta, delta_reg_loss

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        start_positions=None,
        end_positions=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        use_memory=True
    ):
        return_dict = return_dict if return_dict is not None else self.base_config.use_return_dict
        is_training = self.training

        # === 1. Initial Forward Pass ===
        # Determine if token_type_ids should be passed
        include_token_type_ids = token_type_ids is not None

        if include_token_type_ids:
            outputs = self.base_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                output_hidden_states=True,
                output_attentions=output_attentions if output_attentions is not None else self.base_config.output_attentions,
                return_dict=True
            )
        else:
            outputs = self.base_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True,
                output_attentions=output_attentions if output_attentions is not None else self.base_config.output_attentions,
                return_dict=True
            )

        # H(0): Last hidden state from the base model
        h0 = outputs.last_hidden_state

        # T: Reasoning Trace (all hidden states)
        reasoning_trace_T = outputs.hidden_states

        # y^(0): Initial QA prediction using H(0)
        y0_output = self.qa_head(h0)
        y0_start_logits, y0_end_logits = y0_output["start_logits"], y0_output["end_logits"]

        # === 2. Memory Integration (if enabled) ===
        memory_context = None
        if use_memory and (is_training and config.MEMORY_USE_DURING_TRAINING or not is_training):
            if len(self.memory) > 0:
                memory_context = self.memory.get_memory_context(h0, attention_mask)

        # === 3. Multi-step Reasoning ===
        # Initialize current hidden state
        h_current = h0

        # Store all deltas and gates for loss calculation and analysis
        all_deltas = []
        all_gates = []
        all_hidden_states = [h0]

        # Determine number of reasoning steps to use
        actual_steps_taken = 0

        if hasattr(self, 'use_dynamic_steps') and self.use_dynamic_steps:
            if self.reasoning_step_type == "learned":
                # Pool sequence dimension to get a single vector per example
                pooled_h0 = h0.mean(dim=1)

                # Get step logits from controller
                step_logits = self.step_controller(pooled_h0)

                if is_training:
                    # During training, sample from distribution (exploration)
                    step_probs = F.softmax(step_logits, dim=-1)
                    steps_idx = torch.multinomial(step_probs, 1).squeeze(-1)
                    num_steps = steps_idx + self.min_reasoning_steps
                else:
                    # During inference, take argmax (exploitation)
                    steps_idx = torch.argmax(step_logits, dim=-1)
                    num_steps = steps_idx + self.min_reasoning_steps

                # Store step logits for analysis
                step_probs = F.softmax(step_logits, dim=-1)

                # Get the maximum number of steps across the batch
                max_num_steps = num_steps.max().item()
            elif self.reasoning_step_type == "confidence":
                # For confidence-based, we'll determine dynamically during the loop
                max_num_steps = self.max_reasoning_steps
            else:
                # Fallback to fixed steps
                max_num_steps = self.num_reasoning_steps
        else:
            # Use fixed number of steps
            max_num_steps = self.num_reasoning_steps

        # Perform reasoning steps
        for step in range(max_num_steps):
            # For confidence-based, check if we should continue for each example
            if hasattr(self, 'use_dynamic_steps') and self.use_dynamic_steps and self.reasoning_step_type == "confidence" and step >= self.min_reasoning_steps:
                # Check delta magnitude from previous step
                if len(all_deltas) > 0:
                    prev_delta = all_deltas[-1]
                    delta_norm = prev_delta.norm(dim=-1).mean().item()
                    if delta_norm < self.early_stop_threshold:
                        break

            # For learned approach, check if we've reached the determined number of steps
            if hasattr(self, 'use_dynamic_steps') and self.use_dynamic_steps and self.reasoning_step_type == "learned":
                # Create a mask for examples that should continue
                if step > 0:  # Skip first step check since all examples need at least 1 step
                    # Check which examples should continue
                    continue_mask = (step < num_steps).float().unsqueeze(-1).unsqueeze(-1)

                    # If no examples need more steps, break
                    if continue_mask.sum() == 0:
                        break

            # Compute delta using the current hidden state and reasoning trace
            if config.BYPASS_DELTA_CALCULATION:
                # Bypass delta calculation for testing
                delta = torch.zeros_like(h_current)
                attn_weights = None
            else:
                delta, attn_weights = self.retroactive_update_layer(h_current, reasoning_trace_T)

            # Apply delta magnitude constraint
            constrained_delta, delta_reg_loss = self._apply_delta_constraint(delta, h0, is_training)

            # For learned approach with continue_mask, apply mask to delta
            if hasattr(self, 'use_dynamic_steps') and self.use_dynamic_steps and self.reasoning_step_type == "learned" and step > 0:
                constrained_delta = constrained_delta * continue_mask

            # Compute gate values for selective update
            gate = self.gating_mechanism(h_current, constrained_delta)

            # Apply gated update
            h_current = h_current + gate * constrained_delta

            # Store for later use
            all_deltas.append(constrained_delta)
            all_gates.append(gate)
            all_hidden_states.append(h_current)
            actual_steps_taken = step + 1

        # Final hidden state after all reasoning steps
        h_final = h_current

        # === 4. Final Prediction ===
        y_final_output = self.qa_head(h_final)
        y_final_start_logits, y_final_end_logits = y_final_output["start_logits"], y_final_output["end_logits"]

        # === 5. Loss Calculation ===
        total_loss = None
        loss_components = {}

        if start_positions is not None and end_positions is not None:
            # Prepare ground truth positions
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)

            ignored_index = y_final_start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            # Task Loss (QA Loss)
            loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(y_final_start_logits, start_positions)
            end_loss = loss_fct(y_final_end_logits, end_positions)
            task_loss = (start_loss + end_loss) / 2
            loss_components["task_loss"] = task_loss.item()

            # Coherence Loss
            coherence_loss_start = self.coherence_loss_fn(y0_start_logits, y_final_start_logits.detach())
            coherence_loss_end = self.coherence_loss_fn(y0_end_logits, y_final_end_logits.detach())
            coherence_loss = (coherence_loss_start + coherence_loss_end) / 2
            loss_components["coherence_loss"] = coherence_loss.item()

            # Delta Regularization Loss (if computed)
            if delta_reg_loss is not None:
                loss_components["delta_reg_loss"] = delta_reg_loss.item()

            # Total Loss
            total_loss = task_loss + config.LAMBDA_COHERENCE * coherence_loss

            # Add delta regularization if computed
            if delta_reg_loss is not None:
                total_loss = total_loss + config.LAMBDA_DELTA_REG * delta_reg_loss

        # === 6. Memory Update ===
        if use_memory:
            # Prepare input data
            input_data = {'input_ids': input_ids, 'attention_mask': attention_mask}
            if token_type_ids is not None:
                input_data['token_type_ids'] = token_type_ids

            # Prepare outputs
            initial_output = {'start_logits': y0_start_logits, 'end_logits': y0_end_logits}
            final_output = {'start_logits': y_final_start_logits, 'end_logits': y_final_end_logits}

            # Add to memory (during both training and inference if enabled)
            if is_training and config.MEMORY_USE_DURING_TRAINING or not is_training:
                self.memory.add(
                    input_data=input_data,
                    hidden_states=h0,
                    output=initial_output,
                    reasoning_trace=reasoning_trace_T,
                    final_hidden_states=h_final,
                    final_output=final_output
                )

        # === 7. Return Outputs ===
        if not return_dict:
            output = (y_final_start_logits, y_final_end_logits) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        # Store custom outputs as instance attributes for later access if needed
        # This avoids passing them to QuestionAnsweringModelOutput which doesn't accept them
        self.custom_outputs = {
            "initial_hidden_states": h0,
            "final_hidden_states": h_final,
            "all_hidden_states": all_hidden_states,
            "all_deltas": all_deltas,
            "all_gates": all_gates,
            "y0_start_logits": y0_start_logits,
            "y0_end_logits": y0_end_logits,
            "loss_components": loss_components if total_loss is not None else None,
            "steps_taken": actual_steps_taken
        }

        # Add step controller outputs if using learned approach
        if self.use_dynamic_steps and self.reasoning_step_type == "learned":
            self.custom_outputs["step_probs"] = step_probs
            self.custom_outputs["num_steps"] = num_steps

        # Return standard QuestionAnsweringModelOutput without custom fields
        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=y_final_start_logits,
            end_logits=y_final_end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions
        )
code/modules.py
ADDED
@@ -0,0 +1,196 @@
# modules.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import config

class CrossAttentionDelta(nn.Module):
    """
    Enhanced version of CrossAttentionDelta that computes the update delta (Δ) using cross-attention.
    Improvements:
    1. Pre-norm architecture (layer norm before attention)
    2. More sophisticated attention patterns
    3. Ability to incorporate reasoning trace
    """
    def __init__(self, hidden_dim, num_heads=8, dropout=0.1):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads

        # Pre-norm layer normalization (applied before attention)
        self.pre_norm = nn.LayerNorm(hidden_dim)

        # Cross-attention mechanism
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True
        )

        # Post-attention layer normalization
        self.post_norm = nn.LayerNorm(hidden_dim)

        # Trace integration module (to incorporate reasoning trace T)
        self.trace_integration = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # Enhanced MLP for delta computation
        self.delta_mlp = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim * 4),  # Larger intermediate expansion
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 4, hidden_dim * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 2, hidden_dim)
        )

        # Final layer normalization
        self.final_norm = nn.LayerNorm(hidden_dim)

    def forward(self, h0, reasoning_trace=None):
        """
        Args:
            h0 (torch.Tensor): Initial hidden states (batch_size, seq_len, hidden_dim).
            reasoning_trace (tuple of torch.Tensor, optional): Reasoning trace from base model.
                Each tensor has shape (batch_size, seq_len, hidden_dim).

        Returns:
            delta (torch.Tensor): The computed update delta (batch_size, seq_len, hidden_dim).
            attn_weights (torch.Tensor): Attention weights from the attention layer.
        """
        batch_size, seq_len, _ = h0.shape

        # --- Pre-norm Architecture ---
        # Apply layer normalization before attention (pre-norm)
        h0_norm = self.pre_norm(h0)

        # --- Enhanced Cross-Attention ---
        # Get attention weights to visualize attention patterns
        attn_output, attn_weights = self.cross_attn(
            query=h0_norm,
            key=h0_norm,
            value=h0_norm,
            need_weights=True
        )

        # Residual connection and post-norm
        c = self.post_norm(h0 + attn_output)

        # --- Reasoning Trace Integration (if provided) ---
        if reasoning_trace is not None and len(reasoning_trace) > 0:
            # Use the last layer from the reasoning trace (most semantic)
            last_layer = reasoning_trace[-1]

            # Integrate the reasoning trace with the current context
            trace_info = self.trace_integration(
                torch.cat([c, last_layer], dim=-1)
            )

            # Add the trace information to the context
            c = c + trace_info

        # --- Enhanced MLP for Delta ---
        # Concatenate original h0 with context c
        mlp_input = torch.cat((h0, c), dim=-1)

        # Compute delta through enhanced MLP
        delta = self.delta_mlp(mlp_input)

        # Apply final normalization
        delta = self.final_norm(delta)

        return delta, attn_weights

class GatingMechanism(nn.Module):
    """
    Gating mechanism to selectively apply updates.
    Learns when to apply the delta update based on the hidden state and delta.
    """
    def __init__(self, hidden_dim, dropout=0.1):
        super().__init__()
        self.gate_network = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()  # Output between 0 and 1
        )

    def forward(self, h0, delta):
        """
        Args:
            h0 (torch.Tensor): Initial hidden states (batch_size, seq_len, hidden_dim).
            delta (torch.Tensor): Computed delta (batch_size, seq_len, hidden_dim).

        Returns:
            gate (torch.Tensor): Gate values between 0 and 1 (batch_size, seq_len, 1).
        """
        # Concatenate h0 and delta
        gate_input = torch.cat([h0, delta], dim=-1)

        # Compute gate values
        gate = self.gate_network(gate_input)

        return gate

class EnhancedQAHead(nn.Module):
    """
    Enhanced Question Answering head with deeper architecture and bilinear scoring.
    """
    def __init__(self, hidden_dim, dropout=0.1):
        super().__init__()

        # Deeper representation before prediction
        self.start_transform = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim)
        )

        self.end_transform = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # Bilinear layer for start position scoring
        self.start_bilinear = nn.Bilinear(hidden_dim, hidden_dim, 1)

        # Bilinear layer for end position scoring
        self.end_bilinear = nn.Bilinear(hidden_dim, hidden_dim, 1)

        # Global representation for bilinear scoring
        self.global_rep = nn.Parameter(torch.randn(hidden_dim))

    def forward(self, hidden_states):
        """
        Args:
            hidden_states (torch.Tensor): Hidden states (batch_size, seq_len, hidden_dim).

        Returns:
            dict: Dictionary with start_logits and end_logits.
        """
        batch_size, seq_len, hidden_dim = hidden_states.shape

        # Transform hidden states
        start_rep = self.start_transform(hidden_states)
        end_rep = self.end_transform(hidden_states)

        # Expand global representation for batch processing
        global_rep = self.global_rep.expand(batch_size, seq_len, -1)

        # Compute start and end logits using bilinear scoring
        start_logits = self.start_bilinear(start_rep, global_rep).squeeze(-1)
        end_logits = self.end_bilinear(end_rep, global_rep).squeeze(-1)

        return {"start_logits": start_logits, "end_logits": end_logits}
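
For context, a minimal sketch of turning the head's start/end logits into an answer span. This is greedy decoding under the assumption that the start precedes the end; the full SQuAD post-processing in code/test_model.py is more careful:

```python
import torch
from modules import EnhancedQAHead  # assumes code/ is on the Python path

head = EnhancedQAHead(hidden_dim=768)
hidden = torch.randn(1, 16, 768)          # e.g., H(final) for one example
logits = head(hidden)

start = int(torch.argmax(logits["start_logits"], dim=-1))
end = int(torch.argmax(logits["end_logits"], dim=-1))
if start <= end:
    answer_token_span = (start, end)      # map back to text via offset_mapping
```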
code/test_model.py
ADDED
@@ -0,0 +1,536 @@
| 1 |
+
# test_model.py - RRN QA Model evaluation script with multi-step reasoning support
|
| 2 |
+
import torch
|
| 3 |
+
from torch.utils.data import DataLoader
|
| 4 |
+
from transformers import AutoTokenizer, AutoModel, default_data_collator
|
| 5 |
+
from datasets import load_dataset
|
| 6 |
+
from tqdm.auto import tqdm
|
| 7 |
+
import os
|
| 8 |
+
import evaluate as hf_evaluate # Import with alias to avoid naming conflict
|
| 9 |
+
import collections
|
| 10 |
+
import numpy as np
|
| 11 |
+
import logging
|
| 12 |
+
import multiprocessing # For Windows multiprocessing support
|
| 13 |
+
import json
|
| 14 |
+
import argparse
|
| 15 |
+
import matplotlib.pyplot as plt
|
| 16 |
+
from collections import defaultdict
|
| 17 |
+
|
| 18 |
+
# Import custom modules and config
|
| 19 |
+
import config
|
| 20 |
+
from model import EnhancedRRN_QA_Model # Import the enhanced model
|
| 21 |
+
# Make sure memory.py and modules.py are accessible
|
| 22 |
+
|
| 23 |
+
# --- Configuration ---
|
| 24 |
+
logging.basicConfig(level=logging.INFO)
|
| 25 |
+
logger = logging.getLogger(__name__)
|
| 26 |
+
|
| 27 |
+
def main():
|
| 28 |
+
# Parse command line arguments
|
| 29 |
+
parser = argparse.ArgumentParser(description="Test RRN QA Model")
|
| 30 |
+
parser.add_argument("--checkpoint", type=str, default="./rrn_qa_model_epoch_3",
|
| 31 |
+
help="Path to checkpoint directory (default: ./rrn_qa_model_epoch_3)")
|
| 32 |
+
parser.add_argument("--batch_size", type=int, default=8,
|
| 33 |
+
help="Evaluation batch size (default: 8)")
|
| 34 |
+
parser.add_argument("--fixed_steps", type=int, default=None,
|
| 35 |
+
help="Override to use fixed number of reasoning steps (default: None, use model's dynamic steps)")
|
| 36 |
+
parser.add_argument("--use_memory", action="store_true",
|
| 37 |
+
help="Enable active memory during evaluation")
|
| 38 |
+
parser.add_argument("--output_dir", type=str, default="./eval_results",
|
| 39 |
+
help="Directory to save evaluation results (default: ./eval_results)")
|
| 40 |
+
parser.add_argument("--visualize", action="store_true",
|
| 41 |
+
help="Generate visualizations of reasoning steps")
|
| 42 |
+
args = parser.parse_args()
|
| 43 |
+
|
| 44 |
+
CHECKPOINT_DIR = args.checkpoint
|
| 45 |
+
EVAL_BATCH_SIZE = args.batch_size
|
| 46 |
+
DEVICE = config.DEVICE
|
| 47 |
+
USE_MEMORY = args.use_memory
|
| 48 |
+
OUTPUT_DIR = args.output_dir
|
| 49 |
+
|
| 50 |
+
# Create output directory if it doesn't exist
|
| 51 |
+
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
| 52 |
+
|
| 53 |
+
logger.info(f"Evaluation configuration:")
|
| 54 |
+
logger.info(f" Checkpoint: {CHECKPOINT_DIR}")
|
| 55 |
+
logger.info(f" Batch size: {EVAL_BATCH_SIZE}")
|
| 56 |
+
logger.info(f" Device: {DEVICE}")
|
| 57 |
+
logger.info(f" Use memory: {USE_MEMORY}")
|
| 58 |
+
logger.info(f" Output directory: {OUTPUT_DIR}")
|
| 59 |
+
if args.fixed_steps is not None:
|
| 60 |
+
logger.info(f" Using fixed {args.fixed_steps} reasoning steps (overriding model config)")
|
| 61 |
+
|
| 62 |
+
    # --- 1. Load Tokenizer and Model from Checkpoint ---
    logger.info(f"Loading tokenizer from {CHECKPOINT_DIR}...")
    tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR)

    logger.info("Loading Enhanced RRN QA Model architecture...")
    # Instantiate the enhanced model architecture
    model = EnhancedRRN_QA_Model(config.BASE_MODEL_NAME)

    # Check if we're loading from a checkpoint with the enhanced architecture
    base_model_path = os.path.join(CHECKPOINT_DIR, "base_model")
    qa_head_path = os.path.join(CHECKPOINT_DIR, "qa_head.pth")
    retroactive_layer_path = os.path.join(CHECKPOINT_DIR, "retroactive_layer.pth")
    gating_mechanism_path = os.path.join(CHECKPOINT_DIR, "gating_mechanism.pth")
    step_controller_path = os.path.join(CHECKPOINT_DIR, "step_controller.pth")

    # Check for required components
    if not os.path.exists(base_model_path):
        logger.error(f"Base model directory not found at: {base_model_path}")
        exit()
    if not os.path.exists(qa_head_path):
        logger.error(f"QA head weights not found at: {qa_head_path}")
        exit()
    if not os.path.exists(retroactive_layer_path):
        logger.error(f"Retroactive layer weights not found at: {retroactive_layer_path}")
        exit()

    # Load base model weights
    logger.info(f"Loading base model weights from {base_model_path}...")
    model.base_model = AutoModel.from_pretrained(base_model_path)

    # Check if we're loading from an enhanced checkpoint or a legacy checkpoint
    is_enhanced_checkpoint = os.path.exists(gating_mechanism_path)

    if is_enhanced_checkpoint:
        # Load all enhanced components
        logger.info("Loading enhanced model components...")
        model.qa_head.load_state_dict(torch.load(qa_head_path, map_location='cpu'))
        model.retroactive_update_layer.load_state_dict(torch.load(retroactive_layer_path, map_location='cpu'))
        model.gating_mechanism.load_state_dict(torch.load(gating_mechanism_path, map_location='cpu'))

        # Load step controller if available (for learned dynamic steps)
        if os.path.exists(step_controller_path) and hasattr(model, "step_controller"):
            logger.info("Loading step controller for learned dynamic steps...")
            model.step_controller.load_state_dict(torch.load(step_controller_path, map_location='cpu'))

        logger.info("Enhanced model loaded successfully.")
    else:
        # We're loading from a legacy checkpoint - need to adapt the weights
        logger.info("Loading from legacy checkpoint - adapting weights to enhanced architecture...")

        # For the QA head, we need to initialize the enhanced QA head from scratch,
        # since the architectures are different
        logger.info("Initializing enhanced QA head with random weights...")

        # For the retroactive layer, we can try to load the weights, but adjustments may be needed
        logger.warning("Note: The enhanced model uses a different architecture than the checkpoint.")
        logger.warning("Some components will use random initialization.")

    # Load enhanced config if available
    enhanced_config_path = os.path.join(CHECKPOINT_DIR, "enhanced_config.json")
    if os.path.exists(enhanced_config_path):
        logger.info(f"Loading enhanced configuration from {enhanced_config_path}")
        with open(enhanced_config_path, 'r') as f:
            enhanced_config = json.load(f)

        # Override model configuration with saved values
        if "num_reasoning_steps" in enhanced_config:
            model.num_reasoning_steps = enhanced_config["num_reasoning_steps"]
            logger.info(f"Using {model.num_reasoning_steps} reasoning steps from config")

        if "use_dynamic_steps" in enhanced_config:
            model.use_dynamic_steps = enhanced_config["use_dynamic_steps"]
            if model.use_dynamic_steps:
                model.max_reasoning_steps = enhanced_config.get("max_reasoning_steps", config.MAX_REASONING_STEPS)
                model.min_reasoning_steps = enhanced_config.get("min_reasoning_steps", config.MIN_REASONING_STEPS)
                model.reasoning_step_type = enhanced_config.get("reasoning_step_type", config.REASONING_STEP_TYPE)
                model.early_stop_threshold = enhanced_config.get("early_stop_threshold", config.EARLY_STOP_THRESHOLD)
                logger.info(f"Using dynamic reasoning steps (type: {model.reasoning_step_type})")
                logger.info(f"Min steps: {model.min_reasoning_steps}, Max steps: {model.max_reasoning_steps}")

    # Override with fixed steps if specified
    if args.fixed_steps is not None:
        logger.info(f"Overriding with fixed {args.fixed_steps} reasoning steps")
        model.use_dynamic_steps = False
        model.num_reasoning_steps = args.fixed_steps

    model.to(DEVICE)
    model.eval()  # Set model to evaluation mode
    logger.info("Model loaded successfully and set to evaluation mode.")
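    # For reference, the checkpoint layout this loader expects, inferred from
    # the paths checked above (a sketch; only these names are looked for):
    #
    #   CHECKPOINT_DIR/
    #       base_model/             <- saved via save_pretrained
    #       qa_head.pth
    #       retroactive_layer.pth
    #       gating_mechanism.pth    <- presence marks an "enhanced" checkpoint
    #       step_controller.pth     <- optional; learned dynamic steps only
    #       enhanced_config.json    <- optional; overrides reasoning-step settings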
    # --- 2. Load and Preprocess Validation Dataset ---
    logger.info("Loading SQuAD validation dataset...")
    raw_datasets = load_dataset("squad", split="validation")

    question_column_name = "question"
    context_column_name = "context"
    answer_column_name = "answers"
    pad_on_right = tokenizer.padding_side == "right"

    # Validation preprocessing: keep example_id and offset_mapping
    def prepare_validation_features(examples):
        examples[question_column_name] = [q.strip() for q in examples[question_column_name]]
        tokenized_examples = tokenizer(
            examples[question_column_name if pad_on_right else context_column_name],
            examples[context_column_name if pad_on_right else question_column_name],
            truncation="only_second" if pad_on_right else "only_first",
            max_length=config.MAX_SEQ_LENGTH,
            stride=config.DOC_STRIDE,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            padding="max_length",
        )

        # Keep track of which feature belongs to which example
        sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")

        # Add the example_id to link features to original examples
        tokenized_examples["example_id"] = []
        for i in range(len(tokenized_examples["input_ids"])):
            sequence_ids = tokenized_examples.sequence_ids(i)
            context_index = 1 if pad_on_right else 0
            sample_index = sample_mapping[i]
            tokenized_examples["example_id"].append(examples["id"][sample_index])

            # Set offset mapping to None for question tokens to avoid predicting answers there
            tokenized_examples["offset_mapping"][i] = [
                (o if sequence_ids[k] == context_index else None)
                for k, o in enumerate(tokenized_examples["offset_mapping"][i])
            ]

        return tokenized_examples

    logger.info("Preprocessing validation dataset...")
    # Disable multiprocessing, which can hang on some systems
    logger.info("Using single process for preprocessing to prevent hanging")
    eval_dataset = raw_datasets.map(
        prepare_validation_features,
        batched=True,
        remove_columns=raw_datasets.column_names,
        num_proc=1,  # Disable multiprocessing to avoid hanging
    )

    # Custom collator to handle None values in offset_mapping
    def custom_data_collator(features):
        # First, remove offset_mapping, which contains None values that can't be batched
        offset_mappings = [f.pop("offset_mapping") for f in features]

        # Use default collator for everything else
        batch = default_data_collator(features)

        # Add offset_mapping back as a list since it can't be converted to a tensor
        batch["offset_mapping"] = offset_mappings

        return batch

    # Use custom data collator
    data_collator = custom_data_collator

    eval_dataloader = DataLoader(
        eval_dataset,
        collate_fn=data_collator,
        batch_size=EVAL_BATCH_SIZE
    )
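    # Illustration (hypothetical numbers, not the repo defaults): with
    # max_length=384 and stride=128, a question+context pair that tokenizes to
    # ~600 tokens is split into two overlapping features. overflow_to_sample_mapping
    # ties each feature back to its source example, and example_id lets the
    # post-processing step merge the per-feature predictions again.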
    # --- 3. Run Inference ---
    logger.info("***** Running Evaluation *****")
    logger.info(f"  Num examples = {len(eval_dataset)}")
    logger.info(f"  Batch size = {EVAL_BATCH_SIZE}")

    all_start_logits = []
    all_end_logits = []
    feature_indices = []  # Keep track of the order

    # Track multi-step reasoning metrics
    reasoning_steps_taken = []
    delta_magnitudes = []
    gate_values = []
    initial_vs_final_changes = []

    with torch.no_grad():
        for step, batch in enumerate(tqdm(eval_dataloader, desc="Evaluating")):
            # Move batch to device
            batch_on_device = {k: v.to(DEVICE) for k, v in batch.items() if isinstance(v, torch.Tensor)}
            # Store feature indices corresponding to this batch
            # Assuming 'input_ids' or a similar key represents features in order
            current_indices = list(range(step * EVAL_BATCH_SIZE, step * EVAL_BATCH_SIZE + len(batch_on_device['input_ids'])))
            feature_indices.extend(current_indices)

            # Forward pass - pass only the inputs needed by model.forward
            outputs = model(
                input_ids=batch_on_device.get("input_ids"),
                attention_mask=batch_on_device.get("attention_mask"),
                token_type_ids=batch_on_device.get("token_type_ids"),
                use_memory=USE_MEMORY,  # Use memory if enabled
                return_dict=True
            )

            # Get the final logits (y1)
            start_logits = outputs.start_logits
            end_logits = outputs.end_logits

            all_start_logits.append(start_logits.cpu().numpy())
            all_end_logits.append(end_logits.cpu().numpy())

            # Collect multi-step reasoning metrics from custom_outputs
            if hasattr(model, 'custom_outputs'):
                # Number of reasoning steps taken
                if 'steps_taken' in model.custom_outputs:
                    reasoning_steps_taken.append(model.custom_outputs['steps_taken'])

                # Delta magnitudes (how much the model updates at each step)
                if 'all_deltas' in model.custom_outputs and len(model.custom_outputs['all_deltas']) > 0:
                    batch_deltas = []
                    for delta in model.custom_outputs['all_deltas']:
                        # Calculate mean delta magnitude across the sequence dimension
                        delta_norm = delta.norm(dim=-1).mean().cpu().item()
                        batch_deltas.append(delta_norm)
                    delta_magnitudes.append(batch_deltas)

                # Gate values (how selective the updates are)
                if 'all_gates' in model.custom_outputs and len(model.custom_outputs['all_gates']) > 0:
                    batch_gates = []
                    for gate in model.custom_outputs['all_gates']:
                        # Calculate mean gate value across the sequence dimension
                        gate_mean = gate.mean().cpu().item()
                        batch_gates.append(gate_mean)
                    gate_values.append(batch_gates)

                # Compare initial vs final predictions
                if 'y0_start_logits' in model.custom_outputs and 'y0_end_logits' in model.custom_outputs:
                    y0_start = model.custom_outputs['y0_start_logits']
                    y0_end = model.custom_outputs['y0_end_logits']

                    # Calculate how much the predictions changed
                    start_change = (start_logits - y0_start).abs().mean().cpu().item()
                    end_change = (end_logits - y0_end).abs().mean().cpu().item()
                    initial_vs_final_changes.append((start_change + end_change) / 2)

    # Concatenate all results
    all_start_logits = np.concatenate(all_start_logits, axis=0)
    all_end_logits = np.concatenate(all_end_logits, axis=0)

    # Ensure the number of predictions matches the number of features
    if len(all_start_logits) != len(eval_dataset):
        logger.warning(f"Mismatch in prediction count ({len(all_start_logits)}) and feature count ({len(eval_dataset)}). Check dataloader/inference loop.")
        # Attempt to slice if predictions exceed features (might happen if the last batch wasn't full)
        all_start_logits = all_start_logits[:len(eval_dataset)]
        all_end_logits = all_end_logits[:len(eval_dataset)]

    # Create dictionary mapping feature index to its logits
    predictions_dict = {
        feature_index: (start_logit, end_logit)
        for feature_index, (start_logit, end_logit) in zip(feature_indices, zip(all_start_logits, all_end_logits))
    }
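    # Note: feature_indices assumes the DataLoader iterates eval_dataset in
    # order (shuffle defaults to False), so feature i in the dataset maps to
    # the i-th prediction collected above.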
    # --- 4. Post-Processing ---
    # (Adapted from the Hugging Face run_qa.py example script)
    def postprocess_qa_predictions(examples, features, raw_predictions, n_best_size=20, max_answer_length=30, tokenizer=tokenizer):
        all_start_logits, all_end_logits = zip(*raw_predictions.values())

        # Build a map from example ID to list of related feature indices
        example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
        features_per_example = collections.defaultdict(list)
        for i, feature in enumerate(features):
            features_per_example[example_id_to_index[feature["example_id"]]].append(i)

        # Dictionary to store predictions
        predictions = collections.OrderedDict()

        logger.info(f"Post-processing {len(examples)} example predictions split into {len(features)} features.")

        # Loop over all examples
        for example_index, example in enumerate(tqdm(examples, desc="Post-processing")):
            feature_indices = features_per_example[example_index]  # Indices of features related to this example

            min_null_score = None  # Used to identify impossible answers
            valid_answers = []
            context = example["context"]

            # Loop through features associated with the current example
            for feature_index in feature_indices:
                start_logits = all_start_logits[feature_index]
                end_logits = all_end_logits[feature_index]
                offset_mapping = features[feature_index]["offset_mapping"]

                # Update minimum null prediction score
                # (keep the smallest CLS score seen across this example's features;
                # the original comparison kept the largest, which contradicted the name)
                cls_index = features[feature_index]["input_ids"].index(tokenizer.cls_token_id)
                feature_null_score = start_logits[cls_index] + end_logits[cls_index]
                if min_null_score is None or min_null_score > feature_null_score:
                    min_null_score = feature_null_score

                # Go through all possibilities for start/end positions
                start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()
                end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()
                for start_index in start_indexes:
                    for end_index in end_indexes:
                        # Skip invalid pairs (start > end, index out of bounds, answer in question part)
                        if start_index >= len(offset_mapping) or end_index >= len(offset_mapping) or \
                                offset_mapping[start_index] is None or offset_mapping[end_index] is None or \
                                end_index < start_index:
                            continue

                        # Check answer length
                        if end_index - start_index + 1 > max_answer_length:
                            continue

                        # Extract text and score
                        start_char = offset_mapping[start_index][0]
                        end_char = offset_mapping[end_index][1]
                        score = start_logits[start_index] + end_logits[end_index]

                        valid_answers.append({
                            "score": score,
                            "text": context[start_char:end_char]
                        })

            # Select the best answer across all features for this example
            if len(valid_answers) > 0:
                best_answer = sorted(valid_answers, key=lambda x: x["score"], reverse=True)[0]
            else:
                # Fallback when no valid answers are found
                best_answer = {"text": "", "score": min_null_score}  # Assign CLS score if needed

            # Assign the final prediction (use empty string if the null score is best)
            # Simple version: always take the best-scoring valid answer
            # More sophisticated versions might compare best_answer["score"] vs min_null_score
            predictions[example["id"]] = best_answer["text"]

        return predictions

    logger.info("Starting post-processing...")
    final_predictions = postprocess_qa_predictions(raw_datasets, eval_dataset, predictions_dict)
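    # Complexity note: each feature contributes at most n_best_size ** 2 = 400
    # candidate (start, end) pairs before filtering; reversed spans, spans that
    # fall outside the context, and spans longer than max_answer_length tokens
    # are all discarded above.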
    # --- 5. Compute Metrics ---
    logger.info("Calculating SQuAD metrics...")
    metric = hf_evaluate.load("squad")

    # Format predictions and references for the metric
    formatted_predictions = [{"id": k, "prediction_text": v} for k, v in final_predictions.items()]
    formatted_references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in raw_datasets]

    results = metric.compute(predictions=formatted_predictions, references=formatted_references)

    logger.info("***** Evaluation Results *****")
    print(results)
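    # `results` is a plain dict, e.g. {"exact_match": ..., "f1": ...}, with
    # both values on a 0-100 scale.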
    # --- 6. Analyze Multi-step Reasoning Metrics ---
    logger.info("\n***** Multi-step Reasoning Analysis *****")

    # Calculate the average number of reasoning steps
    if reasoning_steps_taken:
        avg_steps = sum(reasoning_steps_taken) / len(reasoning_steps_taken)
        logger.info(f"Average reasoning steps: {avg_steps:.2f}")

        # Count frequency of each step count
        step_counts = collections.Counter(reasoning_steps_taken)
        logger.info(f"Step count distribution: {dict(sorted(step_counts.items()))}")

    # Calculate average delta magnitudes per step
    if delta_magnitudes:
        # Transpose to get step-wise averages
        steps_delta_magnitudes = collections.defaultdict(list)
        for batch_deltas in delta_magnitudes:
            for step_idx, delta in enumerate(batch_deltas):
                steps_delta_magnitudes[step_idx].append(delta)

        avg_delta_by_step = {step: sum(deltas) / len(deltas) for step, deltas in steps_delta_magnitudes.items()}
        logger.info(f"Average delta magnitude by step: {avg_delta_by_step}")

    # Calculate average gate values per step
    if gate_values:
        # Transpose to get step-wise averages
        steps_gate_values = collections.defaultdict(list)
        for batch_gates in gate_values:
            for step_idx, gate in enumerate(batch_gates):
                steps_gate_values[step_idx].append(gate)

        avg_gate_by_step = {step: sum(gates) / len(gates) for step, gates in steps_gate_values.items()}
        logger.info(f"Average gate value by step: {avg_gate_by_step}")

    # Calculate average change from initial to final predictions
    if initial_vs_final_changes:
        avg_change = sum(initial_vs_final_changes) / len(initial_vs_final_changes)
        logger.info(f"Average change from initial to final predictions: {avg_change:.4f}")
    # --- 7. Save Results ---
    results_file = os.path.join(OUTPUT_DIR, "eval_results.json")
    with open(results_file, 'w') as f:
        # Combine SQuAD metrics with multi-step reasoning metrics
        full_results = {
            "squad_metrics": results,
            "multi_step_metrics": {
                "avg_reasoning_steps": avg_steps if reasoning_steps_taken else None,
                "step_count_distribution": dict(sorted(step_counts.items())) if reasoning_steps_taken else None,
                "avg_delta_by_step": avg_delta_by_step if delta_magnitudes else None,
                "avg_gate_by_step": avg_gate_by_step if gate_values else None,
                "avg_prediction_change": avg_change if initial_vs_final_changes else None
            }
        }
        json.dump(full_results, f, indent=2)

    logger.info(f"Results saved to {results_file}")
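    # Shape of eval_results.json (a sketch; field values illustrative):
    # {
    #   "squad_metrics": {"exact_match": ..., "f1": ...},
    #   "multi_step_metrics": {
    #     "avg_reasoning_steps": ...,
    #     "step_count_distribution": {"1": ..., "2": ...},
    #     "avg_delta_by_step": {...},
    #     "avg_gate_by_step": {...},
    #     "avg_prediction_change": ...
    #   }
    # }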
    # --- 8. Generate Visualizations (if requested) ---
    if args.visualize and (delta_magnitudes or gate_values or reasoning_steps_taken):
        logger.info("Generating visualizations...")

        # Create visualization directory
        viz_dir = os.path.join(OUTPUT_DIR, "visualizations")
        os.makedirs(viz_dir, exist_ok=True)

        # Plot step distribution
        if reasoning_steps_taken:
            plt.figure(figsize=(10, 6))
            plt.bar(list(step_counts.keys()), list(step_counts.values()))
            plt.xlabel('Number of Reasoning Steps')
            plt.ylabel('Frequency')
            plt.title('Distribution of Reasoning Steps')
            plt.savefig(os.path.join(viz_dir, 'step_distribution.png'))
            plt.close()

        # Plot delta magnitudes by step
        if delta_magnitudes and steps_delta_magnitudes:
            plt.figure(figsize=(10, 6))
            steps = sorted(steps_delta_magnitudes.keys())
            values = [avg_delta_by_step[step] for step in steps]
            plt.plot(steps, values, marker='o')
            plt.xlabel('Reasoning Step')
            plt.ylabel('Average Delta Magnitude')
            plt.title('Delta Magnitude by Reasoning Step')
            plt.grid(True)
            plt.savefig(os.path.join(viz_dir, 'delta_magnitudes.png'))
            plt.close()

        # Plot gate values by step
        if gate_values and steps_gate_values:
            plt.figure(figsize=(10, 6))
            steps = sorted(steps_gate_values.keys())
            values = [avg_gate_by_step[step] for step in steps]
            plt.plot(steps, values, marker='o')
            plt.xlabel('Reasoning Step')
            plt.ylabel('Average Gate Value')
            plt.title('Gate Value by Reasoning Step')
            plt.grid(True)
            plt.savefig(os.path.join(viz_dir, 'gate_values.png'))
            plt.close()

        logger.info(f"Visualizations saved to {viz_dir}")


if __name__ == "__main__":
    # This is required for Windows to properly handle multiprocessing
    multiprocessing.freeze_support()
    main()

# Example usage:
# Test with default settings (epoch 3 checkpoint):
# python test_model.py

# Test with specific checkpoint:
# python test_model.py --checkpoint ./rrn_qa_model_epoch_2

# Test with fixed number of reasoning steps:
# python test_model.py --fixed_steps 3

# Test with active memory:
# python test_model.py --use_memory

# Test with visualizations:
# python test_model.py --visualize
code/train.py
ADDED
@@ -0,0 +1,389 @@
# train.py (Updated for Full Fine-tuning)
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.amp import autocast, GradScaler  # For mixed precision training (updated import)
from transformers import AutoTokenizer, default_data_collator
from datasets import load_dataset
from tqdm.auto import tqdm  # Progress bar
import os
import evaluate  # For metrics
import logging  # Optional: better logging
import multiprocessing  # For Windows multiprocessing support
import argparse  # For command line arguments

# Import our custom modules and config
import config
from model import EnhancedRRN_QA_Model

# Setup basic logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def main():
    # Parse command line arguments
    parser = argparse.ArgumentParser(description="Train RRN QA Model")
    parser.add_argument("--checkpoint", type=str, help="Path to checkpoint directory to resume from")
    parser.add_argument("--start_epoch", type=int, default=0, help="Epoch to start training from")
    parser.add_argument(
        "--subset_percentage",
        type=float,
        default=100.0,
        help="Percentage of training data to use (1.0-100.0). Default: 100.0 (full dataset)"
    )
    parser.add_argument(
        "--bypass_delta",
        action="store_true",
        help="Bypass RRN delta calculation (sets delta = torch.zeros_like(h0))"
    )
    args = parser.parse_args()

    # Set the bypass-delta flag if specified
    if args.bypass_delta:
        logger.info("BYPASS_DELTA_CALCULATION enabled: Setting delta = torch.zeros_like(h0)")
        config.BYPASS_DELTA_CALCULATION = True
    else:
        config.BYPASS_DELTA_CALCULATION = False

    # --- 1. Load Tokenizer and Model ---
    if args.checkpoint:
        logger.info(f"Loading tokenizer from checkpoint: {args.checkpoint}")
        tokenizer = AutoTokenizer.from_pretrained(args.checkpoint)

        logger.info(f"Loading model from checkpoint: {args.checkpoint}")
        # Initialize the model with the base architecture
        model = EnhancedRRN_QA_Model(os.path.join(args.checkpoint, "base_model"))

        # Check for enhanced model components
        gating_mechanism_path = os.path.join(args.checkpoint, "gating_mechanism.pth")
        is_enhanced_checkpoint = os.path.exists(gating_mechanism_path)

        # Load custom module weights
        logger.info("Loading model components...")
        model.qa_head.load_state_dict(torch.load(os.path.join(args.checkpoint, "qa_head.pth")))
        model.retroactive_update_layer.load_state_dict(torch.load(os.path.join(args.checkpoint, "retroactive_layer.pth")))

        # Load gating mechanism if available
        if is_enhanced_checkpoint:
            logger.info("Loading gating mechanism...")
            model.gating_mechanism.load_state_dict(torch.load(gating_mechanism_path))

        # Load step controller if available (for learned dynamic steps)
        step_controller_path = os.path.join(args.checkpoint, "step_controller.pth")
        if os.path.exists(step_controller_path) and hasattr(model, "step_controller"):
            logger.info("Loading step controller for learned dynamic steps...")
            model.step_controller.load_state_dict(torch.load(step_controller_path))
    else:
        logger.info("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(config.BASE_MODEL_NAME)

        logger.info("Instantiating Enhanced RRN QA Model for full fine-tuning...")
        model = EnhancedRRN_QA_Model(config.BASE_MODEL_NAME)

    model.to(config.DEVICE)

    # --- 2. Load and Preprocess Dataset ---
    logger.info("Loading SQuAD dataset...")
    raw_datasets = load_dataset("squad")

    # Handle dataset subsetting
    subset_percentage = args.subset_percentage
    if subset_percentage < 100.0:
        original_train_size = len(raw_datasets["train"])

        # Calculate subset size and validate
        subset_percentage = max(0.1, min(100.0, subset_percentage))  # Clamp between 0.1% and 100%
        train_subset_size = int(original_train_size * subset_percentage / 100)
        train_subset_size = max(100, min(original_train_size, train_subset_size))  # Ensure reasonable bounds

        # Create a reproducible subset with a fixed seed for consistency
        subset_indices = torch.randperm(original_train_size, generator=torch.Generator().manual_seed(42))[:train_subset_size].tolist()
        raw_datasets["train"] = raw_datasets["train"].select(subset_indices)

        logger.info(f"Using {subset_percentage:.1f}% of training data ({train_subset_size}/{original_train_size} examples)")
    else:
        logger.info(f"Using full training dataset ({len(raw_datasets['train'])} examples)")

    question_column_name = "question"
    context_column_name = "context"
    answer_column_name = "answers"
    pad_on_right = tokenizer.padding_side == "right"

    def prepare_train_features(examples):
        examples[question_column_name] = [q.strip() for q in examples[question_column_name]]
        tokenized_examples = tokenizer(
            examples[question_column_name if pad_on_right else context_column_name],
            examples[context_column_name if pad_on_right else question_column_name],
            truncation="only_second" if pad_on_right else "only_first",
            max_length=config.MAX_SEQ_LENGTH,
            stride=config.DOC_STRIDE,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            padding="max_length",
        )
        sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
        offset_mapping = tokenized_examples.pop("offset_mapping")
        tokenized_examples["start_positions"] = []
        tokenized_examples["end_positions"] = []

        for i, offsets in enumerate(offset_mapping):
            input_ids = tokenized_examples["input_ids"][i]
            cls_index = input_ids.index(tokenizer.cls_token_id)
            sequence_ids = tokenized_examples.sequence_ids(i)
            sample_index = sample_mapping[i]
            answers = examples[answer_column_name][sample_index]

            if len(answers["answer_start"]) == 0:
                tokenized_examples["start_positions"].append(cls_index)
                tokenized_examples["end_positions"].append(cls_index)
            else:
                start_char = answers["answer_start"][0]
                end_char = start_char + len(answers["text"][0])
                token_start_index = 0
                while sequence_ids[token_start_index] != (1 if pad_on_right else 0):
                    token_start_index += 1
                token_end_index = len(input_ids) - 1
                while sequence_ids[token_end_index] != (1 if pad_on_right else 0):
                    token_end_index -= 1
                if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):
                    tokenized_examples["start_positions"].append(cls_index)
                    tokenized_examples["end_positions"].append(cls_index)
                else:
                    while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
                        token_start_index += 1
                    tokenized_examples["start_positions"].append(token_start_index - 1)
                    while offsets[token_end_index][1] >= end_char:
                        token_end_index -= 1
                    tokenized_examples["end_positions"].append(token_end_index + 1)
        return tokenized_examples

    logger.info("Preprocessing datasets...")
    # Use a single process on Windows to avoid multiprocessing issues
    tokenized_datasets = raw_datasets.map(
        prepare_train_features,
        batched=True,
        remove_columns=raw_datasets["train"].column_names,
        num_proc=1  # Single process avoids Windows multiprocessing issues
    )

    data_collator = default_data_collator
    train_dataloader = DataLoader(
        tokenized_datasets["train"],
        shuffle=True,
        collate_fn=data_collator,
        batch_size=config.BATCH_SIZE
    )
    # Consider adding validation dataloader setup here as well
    # eval_dataloader = DataLoader(...)
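    # Label convention recap: if an answer does not survive truncation into a
    # given feature, both start and end positions point at the [CLS] token, so
    # every feature still carries a valid training target.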
    # --- 3. Setup Optimizer ---
    logger.info("Setting up optimizer for FULL model fine-tuning...")
    # Optimize all parameters since PEFT is disabled
    optimizer = optim.AdamW(model.parameters(), lr=config.LEARNING_RATE)

    logger.info(f"Optimizer: AdamW with LR={config.LEARNING_RATE}")
    # Calculate total steps, accounting for gradient accumulation
    num_update_steps_per_epoch = len(train_dataloader) // config.GRADIENT_ACCUMULATION_STEPS
    num_training_steps = config.EPOCHS * num_update_steps_per_epoch
    logger.info(f"Total optimization steps: {num_training_steps}")

    # --- 4. Initialize Mixed Precision Training ---
    # Initialize the gradient scaler for mixed precision training
    scaler = GradScaler('cuda', enabled=config.USE_MIXED_PRECISION)  # Updated to fix deprecation warning

    # Log mixed precision and dynamic steps status
    if config.USE_MIXED_PRECISION:
        logger.info("Mixed precision training (FP16) enabled")
    if config.USE_DYNAMIC_STEPS:
        logger.info(f"Dynamic reasoning steps enabled (type: {config.REASONING_STEP_TYPE})")
        logger.info(f"Min steps: {config.MIN_REASONING_STEPS}, Max steps: {config.MAX_REASONING_STEPS}")

    # Log bypass delta calculation status
    if config.BYPASS_DELTA_CALCULATION:
        logger.info("BYPASS_DELTA_CALCULATION enabled: Delta calculation is bypassed (delta = torch.zeros_like(h0))")
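    # Example arithmetic (illustrative values, not the repo defaults): with
    # BATCH_SIZE = 8 and GRADIENT_ACCUMULATION_STEPS = 4, each optimizer step
    # sees an effective batch of 8 * 4 = 32 examples, and
    # num_update_steps_per_epoch = len(train_dataloader) // 4.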
    # --- 5. Training Loop ---
    logger.info("***** Starting Training *****")
    logger.info(f"  Num examples = {len(tokenized_datasets['train'])}")
    logger.info(f"  Num Epochs = {config.EPOCHS}")
    logger.info(f"  Instantaneous batch size per device = {config.BATCH_SIZE}")
    logger.info(f"  Gradient Accumulation steps = {config.GRADIENT_ACCUMULATION_STEPS}")
    logger.info(f"  Total optimization steps = {num_training_steps}")

    # Add a note about subset training if applicable
    if subset_percentage < 100.0:
        logger.info(f"  NOTE: Training on {subset_percentage:.1f}% of data - metrics may not represent full dataset performance")

    model.train()  # Set model to training mode
    global_step = 0
    total_loss = 0.0  # Use float for accumulated loss

    # Start from the specified epoch (default is 0 if not provided)
    start_epoch = args.start_epoch

    for epoch in range(start_epoch, config.EPOCHS):
        logger.info(f"\n--- Starting Epoch {epoch+1}/{config.EPOCHS} ---")
        progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}", unit="batch")

        for step, batch in enumerate(progress_bar):
            # Move batch to device
            # Ensure only tensors are moved; handle potential non-tensor data if any
            batch_on_device = {}
            for k, v in batch.items():
                if isinstance(v, torch.Tensor):
                    batch_on_device[k] = v.to(config.DEVICE)
                # else: # Handle or skip non-tensor items if necessary
                #     batch_on_device[k] = v

            try:
                # Forward pass with autocast for mixed precision
                with autocast('cuda', enabled=config.USE_MIXED_PRECISION):  # Updated to fix deprecation warning
                    outputs = model(
                        input_ids=batch_on_device.get("input_ids"),
                        attention_mask=batch_on_device.get("attention_mask"),
                        token_type_ids=batch_on_device.get("token_type_ids"),
                        start_positions=batch_on_device.get("start_positions"),
                        end_positions=batch_on_device.get("end_positions"),
                        use_memory=False  # Disable memory during training steps
                    )
                    loss = outputs.loss

                if loss is None:
                    logger.warning(f"Step {step}: Loss is None. Skipping batch.")
                    continue

                # Scale loss for gradient accumulation
                loss = loss / config.GRADIENT_ACCUMULATION_STEPS

                # Accumulate loss value for logging (before backward)
                total_loss += loss.item()

                # Scale loss and perform backward pass with AMP
                scaler.scale(loss).backward()

            except Exception as e:
                logger.error(f"Error during forward/backward pass at step {step}: {e}")
                # Optional: add more detailed error handling or debugging info
                # logger.error(f"Batch keys: {batch.keys()}")
                # logger.error(f"Input IDs shape: {batch_on_device.get('input_ids').shape if batch_on_device.get('input_ids') is not None else 'None'}")
                raise e  # Re-raise the exception to stop training

            # Optimizer step (perform step only after accumulating gradients)
            if (step + 1) % config.GRADIENT_ACCUMULATION_STEPS == 0 or step == len(train_dataloader) - 1:
                # Unscale before the optimizer step (to check for infs/NaNs)
                scaler.unscale_(optimizer)

                # Clip gradients to avoid explosion (optional but recommended with mixed precision)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

                # Step with scaler
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()  # Reset gradients for the next accumulation cycle
                global_step += 1

                # Log progress periodically
                if global_step % 50 == 0:  # Log every 50 optimization steps
                    avg_loss = total_loss / 50  # Average loss over the last 50 steps
                    logger.info(f"Step: {global_step}, Avg Loss: {avg_loss:.4f}")
                    total_loss = 0.0  # Reset loss accumulator

            # Update progress bar description with current step loss and steps info
            postfix = {
                "Loss": f"{loss.item() * config.GRADIENT_ACCUMULATION_STEPS:.4f}",
                "Step": global_step
            }

            # Add steps info if using dynamic steps
            if config.USE_DYNAMIC_STEPS and hasattr(model, 'custom_outputs'):
                if 'steps_taken' in model.custom_outputs:
                    postfix["Steps"] = model.custom_outputs['steps_taken']

            progress_bar.set_postfix(postfix)

        # --- (Optional) Evaluation at the end of each epoch ---
        # logger.info(f"\n--- Evaluating after Epoch {epoch+1} ---")
        # model.eval()
        # # Add evaluation loop here (requires validation dataloader, postprocessing, metrics)
        # model.train()  # Set back to train mode

        # --- Save Model Checkpoint ---
        output_dir = f"./rrn_qa_model_epoch_{epoch+1}"
        os.makedirs(output_dir, exist_ok=True)
        logger.info(f"--- Saving model checkpoint to {output_dir} ---")

        # --- Saving Logic for Enhanced Model ---
        try:
            logger.info(f"Saving enhanced model components to {output_dir}")
            # Save base model using its save_pretrained
            model.base_model.save_pretrained(os.path.join(output_dir, "base_model"))

            # Save all custom modules' state dicts
            torch.save(model.qa_head.state_dict(), os.path.join(output_dir, "qa_head.pth"))
            torch.save(model.retroactive_update_layer.state_dict(), os.path.join(output_dir, "retroactive_layer.pth"))
            torch.save(model.gating_mechanism.state_dict(), os.path.join(output_dir, "gating_mechanism.pth"))

            # Save step controller if using learned dynamic steps
            if config.USE_DYNAMIC_STEPS and config.REASONING_STEP_TYPE == "learned" and hasattr(model, "step_controller"):
                torch.save(model.step_controller.state_dict(), os.path.join(output_dir, "step_controller.pth"))
                logger.info("Saved step controller for learned dynamic steps")

            # Save tokenizer
            tokenizer.save_pretrained(output_dir)

            # Save configuration
            import json
            with open(os.path.join(output_dir, "enhanced_config.json"), "w") as f:
                config_dict = {
                    "num_reasoning_steps": config.NUM_REASONING_STEPS,
                    "delta_target_ratio": config.DELTA_TARGET_RATIO,
                    "lambda_coherence": config.LAMBDA_COHERENCE,
                    "lambda_delta_reg": config.LAMBDA_DELTA_REG,
                    "memory_max_size": config.MEMORY_MAX_SIZE,
                    "memory_retrieval_k": config.MEMORY_RETRIEVAL_K,
                    "use_mixed_precision": config.USE_MIXED_PRECISION,
                    "bypass_delta_calculation": config.BYPASS_DELTA_CALCULATION
                }

                # Add dynamic steps configuration if enabled
                if config.USE_DYNAMIC_STEPS:
                    config_dict.update({
                        "use_dynamic_steps": config.USE_DYNAMIC_STEPS,
                        "max_reasoning_steps": config.MAX_REASONING_STEPS,
                        "min_reasoning_steps": config.MIN_REASONING_STEPS,
                        "reasoning_step_type": config.REASONING_STEP_TYPE,
                        "early_stop_threshold": config.EARLY_STOP_THRESHOLD
                    })

                json.dump(config_dict, f, indent=2)

            logger.info("Enhanced model checkpoint saved successfully.")
        except Exception as e:
            logger.error(f"Error saving checkpoint at epoch {epoch+1}: {e}")

    logger.info("\n***** Training finished *****")


if __name__ == "__main__":
    # This is required for Windows to properly handle multiprocessing
    multiprocessing.freeze_support()
    main()

# Example usage:
# Train on full dataset (default):
# python train.py

# Train on 10% of data for faster iterations:
# python train.py --subset_percentage 10.0

# Train on 1% for very quick testing:
# python train.py --subset_percentage 1.0

# Resume training from checkpoint with subset:
# python train.py --checkpoint ./rrn_qa_model_epoch_1 --start_epoch 1 --subset_percentage 25.0

# Test with bypassed delta calculation (sets delta = torch.zeros_like(h0)):
# python train.py --bypass_delta --subset_percentage 1.0
enhanced_config.json
ADDED
@@ -0,0 +1,15 @@
{
  "num_reasoning_steps": 3,
  "delta_target_ratio": 0.2,
  "lambda_coherence": 0.1,
  "lambda_delta_reg": 0.5,
  "memory_max_size": 50,
  "memory_retrieval_k": 3,
  "use_mixed_precision": false,
  "bypass_delta_calculation": false,
  "use_dynamic_steps": true,
  "max_reasoning_steps": 5,
  "min_reasoning_steps": 1,
  "reasoning_step_type": "learned",
  "early_stop_threshold": 0.01
}
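As a convenience, here is a minimal sketch of reading this file back and applying it to a loaded model, mirroring the override logic in `code/test_model.py`; the attribute names (`num_reasoning_steps`, `use_dynamic_steps`, etc.) are the ones that script sets and may need adjusting for other model classes:

```python
import json

def apply_enhanced_config(model, path="enhanced_config.json"):
    # Sketch only: restores the reasoning-step settings saved by train.py.
    with open(path) as f:
        cfg = json.load(f)
    model.num_reasoning_steps = cfg.get("num_reasoning_steps", 3)
    model.use_dynamic_steps = cfg.get("use_dynamic_steps", False)
    if model.use_dynamic_steps:
        model.max_reasoning_steps = cfg.get("max_reasoning_steps", 5)
        model.min_reasoning_steps = cfg.get("min_reasoning_steps", 1)
        model.reasoning_step_type = cfg.get("reasoning_step_type", "learned")
        model.early_stop_threshold = cfg.get("early_stop_threshold", 0.01)
    return model
```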
gating_mechanism.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5a1b4605a1297d5d738dda9e6b739764e6412e27053f75df64ebe99f5f49188c
size 7090146
model-index.json
ADDED
@@ -0,0 +1,50 @@
{
  "name": "Retroactive Reasoning Network for Question Answering",
  "language": "en",
  "license": "apache-2.0",
  "task_categories": [
    "question-answering"
  ],
  "tags": [
    "question-answering",
    "multi-step-reasoning",
    "retroactive-reasoning",
    "squad"
  ],
  "datasets": [
    "squad"
  ],
  "metrics": [
    "exact_match",
    "f1"
  ],
  "model-index": [
    {
      "name": "RRN QA Model",
      "results": [
        {
          "task": {
            "type": "question-answering",
            "name": "Question Answering"
          },
          "dataset": {
            "name": "SQuAD",
            "type": "squad"
          },
          "metrics": [
            {
              "type": "exact_match",
              "value": "TBD",
              "name": "Exact Match"
            },
            {
              "type": "f1",
              "value": "TBD",
              "name": "F1"
            }
          ]
        }
      ]
    }
  ]
}
qa_head.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:70288115fb5d3176e59d873418d1d1e7fd4dc9c3d6911f8edeee93354227b82b
size 14175663
retroactive_layer.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1faa8764e567c3b5575c5aff1b8c4c3ed7486935730923c692dae3d4122cefbd
size 59048138
special_tokens_map.json
ADDED
@@ -0,0 +1,7 @@
{
  "cls_token": "[CLS]",
  "mask_token": "[MASK]",
  "pad_token": "[PAD]",
  "sep_token": "[SEP]",
  "unk_token": "[UNK]"
}
step_controller.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cae0d7b2423071c8e36f0d07965a540193fb739638351115bfa8045b9844aac7
size 398480
tokenizer.json
ADDED
The diff for this file is too large to render. See raw diff.
tokenizer_config.json
ADDED
@@ -0,0 +1,56 @@
{
  "added_tokens_decoder": {
    "0": {
      "content": "[PAD]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "100": {
      "content": "[UNK]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "101": {
      "content": "[CLS]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "102": {
      "content": "[SEP]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "103": {
      "content": "[MASK]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    }
  },
  "clean_up_tokenization_spaces": false,
  "cls_token": "[CLS]",
  "do_lower_case": true,
  "extra_special_tokens": {},
  "mask_token": "[MASK]",
  "model_max_length": 512,
  "pad_token": "[PAD]",
  "sep_token": "[SEP]",
  "strip_accents": null,
  "tokenize_chinese_chars": true,
  "tokenizer_class": "BertTokenizer",
  "unk_token": "[UNK]"
}
vocab.txt
ADDED
The diff for this file is too large to render. See raw diff.