# Source: rrn-qa/code/model.py — uploaded by will4381 via huggingface_hub (commit 3451ca0, verified)
# model.py (Enhanced RRN Implementation)
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForQuestionAnswering, AutoConfig, AutoModel
from transformers.modeling_outputs import QuestionAnsweringModelOutput
import config
from modules import CrossAttentionDelta, GatingMechanism, EnhancedQAHead
from memory import ActiveMemory
class EnhancedRRN_QA_Model(nn.Module):
    """
    Enhanced Retroactive Reasoning Network for Question Answering.

    Runs a pretrained transformer encoder once, then iteratively refines its
    final hidden states H(0) with gated, magnitude-constrained "delta" updates
    computed by cross-attending over the reasoning trace T (the encoder's
    per-layer hidden states). The refined states feed an enhanced QA head.

    Improvements over a vanilla encoder + QA head:
    1. Delta magnitude constraint (prevents destabilizing updates)
    2. Gating mechanism (selective, per-position updates)
    3. Multi-step reasoning (fixed, confidence-based, or learned step count)
    4. Active memory usage across examples
    5. Enhanced QA head
    6. Improved cross-attention
    """

    def __init__(self, model_name=config.BASE_MODEL_NAME):
        """
        Args:
            model_name: HuggingFace model identifier for the base encoder.
                Defaults to ``config.BASE_MODEL_NAME``.
        """
        super().__init__()
        self.model_name = model_name

        # --- Configuration ---
        self.num_reasoning_steps = config.NUM_REASONING_STEPS
        self.delta_target_ratio = config.DELTA_TARGET_RATIO

        # --- Dynamic Reasoning Steps Configuration ---
        # reasoning_step_type is one of "learned", "confidence", or anything
        # else (treated as a fixed step count).
        self.use_dynamic_steps = config.USE_DYNAMIC_STEPS
        self.max_reasoning_steps = config.MAX_REASONING_STEPS
        self.min_reasoning_steps = config.MIN_REASONING_STEPS
        self.reasoning_step_type = config.REASONING_STEP_TYPE
        self.early_stop_threshold = config.EARLY_STOP_THRESHOLD

        # --- Load Base Model Configuration ---
        self.base_config = AutoConfig.from_pretrained(
            self.model_name,
            output_hidden_states=True,  # Crucial for Reasoning Trace (T)
        )
        self.hidden_dim = self.base_config.hidden_size

        # Step controller for the "learned" approach: maps the mean-pooled
        # H(0) to a distribution over {min_steps, ..., max_steps}.
        # (Defined after hidden_dim so the Linear input size is known.)
        if self.use_dynamic_steps and self.reasoning_step_type == "learned":
            self.step_controller = nn.Sequential(
                nn.Linear(self.hidden_dim, 128),
                nn.ReLU(),
                nn.Linear(128, self.max_reasoning_steps - self.min_reasoning_steps + 1)
            )
            print(f"Using learned dynamic reasoning steps (min={self.min_reasoning_steps}, max={self.max_reasoning_steps})")

        # --- Load Base Model ---
        self.base_model = AutoModel.from_pretrained(
            self.model_name,
            config=self.base_config
        )
        print(f"Loaded base model: {self.model_name}")
        print(f"Hidden dimension: {self.hidden_dim}")
        print(f"Using {self.num_reasoning_steps} reasoning steps")

        # --- Enhanced RRN Components ---
        # Improved cross-attention delta mechanism
        self.retroactive_update_layer = CrossAttentionDelta(self.hidden_dim)
        # Gating mechanism for selective updates
        self.gating_mechanism = GatingMechanism(self.hidden_dim)
        # Enhanced QA head with deeper architecture and bilinear scoring
        self.qa_head = EnhancedQAHead(self.hidden_dim)

        # --- Active Memory Module ---
        self.memory = ActiveMemory(
            max_size=config.MEMORY_MAX_SIZE,
            retrieval_k=config.MEMORY_RETRIEVAL_K
        )

        # --- Loss Functions ---
        self.coherence_loss_fn = nn.MSELoss()
        self.delta_reg_loss_fn = nn.MSELoss()

    def _apply_delta_constraint(self, delta, h0, is_training=False):
        """
        Apply delta magnitude constraint to prevent destabilizing updates.

        The per-position ratio ||delta|| / ||h0|| is capped at
        ``self.delta_target_ratio``: deltas exceeding the cap are scaled
        down; smaller deltas pass through unchanged.

        Args:
            delta: The computed delta, same shape as ``h0``.
            h0: The initial hidden states (batch, seq, hidden).
            is_training: Whether we're in training mode.

        Returns:
            constrained_delta: The constrained delta.
            delta_reg_loss: MSE of the ratio against the target ratio
                (``None`` unless ``is_training``).
        """
        # Per-position L2 norms; h0's norm is detached so the constraint
        # does not push gradients into the base representation.
        delta_norm = delta.norm(dim=-1, keepdim=True)
        h0_norm = h0.norm(dim=-1, keepdim=True).detach()
        ratio = delta_norm / (h0_norm + 1e-9)

        # Regularization loss pulls the ratio toward the target (training only).
        delta_reg_loss = None
        if is_training:
            target_ratio = torch.ones_like(ratio) * self.delta_target_ratio
            delta_reg_loss = self.delta_reg_loss_fn(ratio, target_ratio)

        # Hard constraint applied in both training and inference:
        # only scale DOWN deltas that exceed the target ratio.
        scale_factor = torch.ones_like(ratio)
        too_large = ratio > self.delta_target_ratio
        if too_large.any():
            scale_factor[too_large] = self.delta_target_ratio / ratio[too_large]
        constrained_delta = delta * scale_factor

        return constrained_delta, delta_reg_loss

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        start_positions=None,
        end_positions=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        use_memory=True
    ):
        """
        Full RRN forward pass: encode, optionally consult memory, run the
        multi-step retroactive reasoning loop, predict, and compute losses.

        Args:
            input_ids, attention_mask, token_type_ids: Standard encoder inputs.
            start_positions, end_positions: Gold answer span indices; when
                provided, the total loss is computed.
            output_attentions, output_hidden_states, return_dict: Standard
                HuggingFace output controls.
            use_memory: Enable the active-memory read/write for this call.

        Returns:
            ``QuestionAnsweringModelOutput`` (or a tuple when
            ``return_dict`` is falsy). Extra diagnostics (deltas, gates,
            per-step hidden states, loss components, steps taken) are stored
            on ``self.custom_outputs``.
        """
        return_dict = return_dict if return_dict is not None else self.base_config.use_return_dict
        is_training = self.training
        # Hoisted predicate: learned dynamic-step control is active.
        use_learned_steps = self.use_dynamic_steps and self.reasoning_step_type == "learned"

        # === 1. Initial Forward Pass ===
        # Some encoders (e.g. RoBERTa-like) don't accept token_type_ids,
        # so only pass them when the caller supplied them.
        if token_type_ids is not None:
            outputs = self.base_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                output_hidden_states=True,
                output_attentions=output_attentions if output_attentions is not None else self.base_config.output_attentions,
                return_dict=True
            )
        else:
            outputs = self.base_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True,
                output_attentions=output_attentions if output_attentions is not None else self.base_config.output_attentions,
                return_dict=True
            )

        # H(0): Last hidden state from the base model
        h0 = outputs.last_hidden_state
        # T: Reasoning Trace (all hidden states)
        reasoning_trace_T = outputs.hidden_states

        # y^(0): Initial QA prediction using H(0), kept for the coherence loss.
        y0_output = self.qa_head(h0)
        y0_start_logits, y0_end_logits = y0_output["start_logits"], y0_output["end_logits"]

        # === 2. Memory Integration (if enabled) ===
        # Reads memory during inference always; during training only when
        # MEMORY_USE_DURING_TRAINING is set.
        # NOTE(review): memory_context is retrieved but not consumed below —
        # presumably the retrieval has side effects or this is future work.
        memory_context = None
        if use_memory and (is_training and config.MEMORY_USE_DURING_TRAINING or not is_training):
            if len(self.memory) > 0:
                memory_context = self.memory.get_memory_context(h0, attention_mask)

        # === 3. Multi-step Reasoning ===
        h_current = h0
        # Bookkeeping for loss calculation and analysis.
        all_deltas = []
        all_gates = []
        all_hidden_states = [h0]
        actual_steps_taken = 0
        # FIX: initialize before the loop — previously this was only assigned
        # inside the loop body, so a zero-step configuration raised NameError
        # at the loss-composition stage.
        delta_reg_loss = None

        # Determine the number of reasoning steps to run.
        if self.use_dynamic_steps:
            if self.reasoning_step_type == "learned":
                # Pool the sequence dimension to one vector per example.
                pooled_h0 = h0.mean(dim=1)
                step_logits = self.step_controller(pooled_h0)
                if is_training:
                    # Training: sample from the distribution (exploration).
                    step_probs = F.softmax(step_logits, dim=-1)
                    steps_idx = torch.multinomial(step_probs, 1).squeeze(-1)
                    num_steps = steps_idx + self.min_reasoning_steps
                else:
                    # Inference: take the argmax (exploitation).
                    steps_idx = torch.argmax(step_logits, dim=-1)
                    num_steps = steps_idx + self.min_reasoning_steps
                    # Store step probabilities for analysis.
                    step_probs = F.softmax(step_logits, dim=-1)
                # Loop bound: the largest per-example step count in the batch.
                max_num_steps = num_steps.max().item()
            elif self.reasoning_step_type == "confidence":
                # Confidence-based: upper-bound the loop; early-stop inside it.
                max_num_steps = self.max_reasoning_steps
            else:
                # Unknown type: fall back to the fixed step count.
                max_num_steps = self.num_reasoning_steps
        else:
            max_num_steps = self.num_reasoning_steps

        # Perform reasoning steps
        for step in range(max_num_steps):
            # Confidence-based early stop: once past the minimum step count,
            # halt when the previous delta's mean norm is below threshold.
            if self.use_dynamic_steps and self.reasoning_step_type == "confidence" and step >= self.min_reasoning_steps:
                if len(all_deltas) > 0:
                    prev_delta = all_deltas[-1]
                    delta_norm = prev_delta.norm(dim=-1).mean().item()
                    if delta_norm < self.early_stop_threshold:
                        break

            # Learned approach: per-example mask of who still needs steps.
            if use_learned_steps:
                if step > 0:  # every example runs at least one step
                    continue_mask = (step < num_steps).float().unsqueeze(-1).unsqueeze(-1)
                    # If no examples need more steps, stop the whole loop.
                    if continue_mask.sum() == 0:
                        break

            # Compute delta from the current state and the reasoning trace.
            if config.BYPASS_DELTA_CALCULATION:
                # Debug/testing path: skip the cross-attention entirely.
                delta = torch.zeros_like(h_current)
                attn_weights = None
            else:
                delta, attn_weights = self.retroactive_update_layer(h_current, reasoning_trace_T)

            # Constrain delta magnitude (also yields the reg loss when training).
            # NOTE: only the LAST step's delta_reg_loss survives the loop.
            constrained_delta, delta_reg_loss = self._apply_delta_constraint(delta, h0, is_training)

            # Learned approach: zero out deltas for finished examples.
            if use_learned_steps and step > 0:
                constrained_delta = constrained_delta * continue_mask

            # Gated retroactive update.
            gate = self.gating_mechanism(h_current, constrained_delta)
            h_current = h_current + gate * constrained_delta

            # Store for later use
            all_deltas.append(constrained_delta)
            all_gates.append(gate)
            all_hidden_states.append(h_current)
            actual_steps_taken = step + 1

        # Final hidden state after all reasoning steps
        h_final = h_current

        # === 4. Final Prediction ===
        y_final_output = self.qa_head(h_final)
        y_final_start_logits, y_final_end_logits = y_final_output["start_logits"], y_final_output["end_logits"]

        # === 5. Loss Calculation ===
        total_loss = None
        loss_components = {}
        if start_positions is not None and end_positions is not None:
            # Squeeze possible trailing label dims; clamp out-of-range
            # positions onto the ignore index (standard HF QA convention).
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            ignored_index = y_final_start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            # Task Loss (QA Loss): mean of start and end cross-entropies.
            loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(y_final_start_logits, start_positions)
            end_loss = loss_fct(y_final_end_logits, end_positions)
            task_loss = (start_loss + end_loss) / 2
            loss_components["task_loss"] = task_loss.item()

            # Coherence Loss: pull the initial prediction y^(0) toward the
            # (detached) refined prediction, so H(0) learns from reasoning.
            coherence_loss_start = self.coherence_loss_fn(y0_start_logits, y_final_start_logits.detach())
            coherence_loss_end = self.coherence_loss_fn(y0_end_logits, y_final_end_logits.detach())
            coherence_loss = (coherence_loss_start + coherence_loss_end) / 2
            loss_components["coherence_loss"] = coherence_loss.item()

            # Delta Regularization Loss (if the loop computed one).
            if delta_reg_loss is not None:
                loss_components["delta_reg_loss"] = delta_reg_loss.item()

            # Total Loss: task + weighted coherence (+ weighted delta reg).
            total_loss = task_loss + config.LAMBDA_COHERENCE * coherence_loss
            if delta_reg_loss is not None:
                total_loss = total_loss + config.LAMBDA_DELTA_REG * delta_reg_loss

        # === 6. Memory Update ===
        if use_memory:
            input_data = {'input_ids': input_ids, 'attention_mask': attention_mask}
            if token_type_ids is not None:
                input_data['token_type_ids'] = token_type_ids
            initial_output = {'start_logits': y0_start_logits, 'end_logits': y0_end_logits}
            final_output = {'start_logits': y_final_start_logits, 'end_logits': y_final_end_logits}
            # Write to memory during inference always; during training only
            # when MEMORY_USE_DURING_TRAINING is set.
            if is_training and config.MEMORY_USE_DURING_TRAINING or not is_training:
                self.memory.add(
                    input_data=input_data,
                    hidden_states=h0,
                    output=initial_output,
                    reasoning_trace=reasoning_trace_T,
                    final_hidden_states=h_final,
                    final_output=final_output
                )

        # === 7. Return Outputs ===
        if not return_dict:
            output = (y_final_start_logits, y_final_end_logits) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        # Store custom outputs as instance attributes for later access;
        # QuestionAnsweringModelOutput doesn't accept extra fields.
        self.custom_outputs = {
            "initial_hidden_states": h0,
            "final_hidden_states": h_final,
            "all_hidden_states": all_hidden_states,
            "all_deltas": all_deltas,
            "all_gates": all_gates,
            "y0_start_logits": y0_start_logits,
            "y0_end_logits": y0_end_logits,
            "loss_components": loss_components if total_loss is not None else None,
            "steps_taken": actual_steps_taken
        }
        # Step-controller diagnostics (learned approach only).
        if use_learned_steps:
            self.custom_outputs["step_probs"] = step_probs
            self.custom_outputs["num_steps"] = num_steps

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=y_final_start_logits,
            end_logits=y_final_end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions
        )