# model.py (Enhanced RRN Implementation)
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForQuestionAnswering, AutoConfig, AutoModel
from transformers.modeling_outputs import QuestionAnsweringModelOutput

import config
from modules import CrossAttentionDelta, GatingMechanism, EnhancedQAHead
from memory import ActiveMemory


class EnhancedRRN_QA_Model(nn.Module):
    """
    Enhanced Retroactive Reasoning Network for Question Answering.

    Pipeline (see ``forward``):
      1. Encode the input with a pretrained base model, keeping every layer's
         hidden states as the reasoning trace T.
      2. Predict an initial answer span y^(0) from the last hidden state H(0).
      3. Refine the hidden states for one or more reasoning steps by adding a
         gated, magnitude-constrained delta computed from (H, T).
      4. Predict the final answer span from the refined hidden states.

    Improvements:
    1. Delta magnitude constraint
    2. Gating mechanism
    3. Multi-step reasoning (fixed, learned, or confidence-based step count)
    4. Active memory usage
    5. Enhanced QA head
    6. Improved cross-attention
    """

    def __init__(self, model_name=config.BASE_MODEL_NAME):
        super().__init__()
        self.model_name = model_name

        # --- Configuration ---
        self.num_reasoning_steps = config.NUM_REASONING_STEPS
        self.delta_target_ratio = config.DELTA_TARGET_RATIO

        # --- Dynamic Reasoning Steps Configuration ---
        self.use_dynamic_steps = config.USE_DYNAMIC_STEPS
        self.max_reasoning_steps = config.MAX_REASONING_STEPS
        self.min_reasoning_steps = config.MIN_REASONING_STEPS
        self.reasoning_step_type = config.REASONING_STEP_TYPE
        self.early_stop_threshold = config.EARLY_STOP_THRESHOLD

        # --- Load Base Model Configuration ---
        self.base_config = AutoConfig.from_pretrained(
            self.model_name,
            output_hidden_states=True,  # Crucial for Reasoning Trace (T)
        )
        self.hidden_dim = self.base_config.hidden_size

        # Step controller for the learned approach: emits one logit per allowed
        # step count in [min_reasoning_steps, max_reasoning_steps].
        if self.use_dynamic_steps and self.reasoning_step_type == "learned":
            self.step_controller = nn.Sequential(
                nn.Linear(self.hidden_dim, 128),
                nn.ReLU(),
                nn.Linear(128, self.max_reasoning_steps - self.min_reasoning_steps + 1),
            )
            print(f"Using learned dynamic reasoning steps (min={self.min_reasoning_steps}, max={self.max_reasoning_steps})")

        # --- Load Base Model ---
        self.base_model = AutoModel.from_pretrained(
            self.model_name,
            config=self.base_config,
        )
        print(f"Loaded base model: {self.model_name}")
        print(f"Hidden dimension: {self.hidden_dim}")
        print(f"Using {self.num_reasoning_steps} reasoning steps")

        # --- Enhanced RRN Components ---
        # Cross-attention delta mechanism: computes the update from (H, T).
        self.retroactive_update_layer = CrossAttentionDelta(self.hidden_dim)
        # Gating mechanism for selective updates.
        self.gating_mechanism = GatingMechanism(self.hidden_dim)
        # QA head producing start/end logits from hidden states.
        self.qa_head = EnhancedQAHead(self.hidden_dim)

        # --- Active Memory Module ---
        self.memory = ActiveMemory(
            max_size=config.MEMORY_MAX_SIZE,
            retrieval_k=config.MEMORY_RETRIEVAL_K,
        )

        # --- Loss Functions ---
        self.coherence_loss_fn = nn.MSELoss()
        self.delta_reg_loss_fn = nn.MSELoss()

    def _apply_delta_constraint(self, delta, h0, is_training=False):
        """
        Apply delta magnitude constraint to prevent destabilizing updates.

        The per-position ratio ||delta|| / ||h0|| (norms over the last dim) is
        capped at ``self.delta_target_ratio`` by rescaling only the deltas that
        exceed it; smaller deltas are left untouched.

        Args:
            delta: The computed delta.
            h0: The initial hidden states (its norm is detached, so the cap
                does not backpropagate into the encoder through the denominator).
            is_training: Whether to also compute the regularization loss.

        Returns:
            constrained_delta: The rescaled delta (same shape as ``delta``).
            delta_reg_loss: MSE between the pre-constraint ratio and the target
                ratio when ``is_training``, else None.
        """
        # Compute delta and h0 norms
        delta_norm = delta.norm(dim=-1, keepdim=True)
        h0_norm = h0.norm(dim=-1, keepdim=True).detach()

        # Compute ratio (epsilon guards against all-zero h0 rows)
        ratio = delta_norm / (h0_norm + 1e-9)

        # Compute regularization loss if in training
        delta_reg_loss = None
        if is_training:
            target_ratio = torch.ones_like(ratio) * self.delta_target_ratio
            delta_reg_loss = self.delta_reg_loss_fn(ratio, target_ratio)

        # Apply direct constraint (both during training and inference):
        # only scale down deltas that are too large.
        scale_factor = torch.ones_like(ratio)
        too_large = ratio > self.delta_target_ratio
        if too_large.any():
            scale_factor[too_large] = self.delta_target_ratio / ratio[too_large]

        constrained_delta = delta * scale_factor
        return constrained_delta, delta_reg_loss

    def _encode(self, input_ids, attention_mask, token_type_ids, output_attentions):
        """Run the base model with hidden states always enabled (they form T)."""
        encoder_kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "output_hidden_states": True,
            "output_attentions": (
                output_attentions
                if output_attentions is not None
                else self.base_config.output_attentions
            ),
            "return_dict": True,
        }
        # Only forward token_type_ids when given; presumably some base models
        # do not accept this argument.
        if token_type_ids is not None:
            encoder_kwargs["token_type_ids"] = token_type_ids
        return self.base_model(**encoder_kwargs)

    def _plan_reasoning_steps(self, h0, is_training):
        """
        Decide the upper bound on reasoning steps for this batch.

        Returns:
            (max_num_steps, num_steps, step_probs). ``num_steps`` (one step
            budget per example) and ``step_probs`` are only produced by the
            learned controller; otherwise both are None.
        """
        if not self.use_dynamic_steps:
            return self.num_reasoning_steps, None, None

        if self.reasoning_step_type == "learned":
            # Mean over dim 1 (all positions, including any padding).
            pooled_h0 = h0.mean(dim=1)
            step_logits = self.step_controller(pooled_h0)
            step_probs = F.softmax(step_logits, dim=-1)
            if is_training:
                # Sample a budget (exploration). Not differentiable; step_probs
                # is exposed via custom_outputs for any external training signal.
                steps_idx = torch.multinomial(step_probs, 1).squeeze(-1)
            else:
                # Greedy budget (exploitation).
                steps_idx = torch.argmax(step_logits, dim=-1)
            num_steps = steps_idx + self.min_reasoning_steps
            return num_steps.max().item(), num_steps, step_probs

        if self.reasoning_step_type == "confidence":
            # The loop stops early on small deltas; this is just the ceiling.
            return self.max_reasoning_steps, None, None

        # Unknown strategy: fall back to the fixed step count.
        return self.num_reasoning_steps, None, None

    def _compute_loss(
        self,
        y0_start_logits,
        y0_end_logits,
        start_logits,
        end_logits,
        start_positions,
        end_positions,
        delta_reg_loss,
    ):
        """
        Combine task, coherence and (optional) delta-regularization losses.

        Returns:
            (total_loss, loss_components) where loss_components maps each loss
            name to its Python float value.
        """
        # Drop a trailing singleton dimension from the gold positions if present.
        if start_positions.dim() > 1:
            start_positions = start_positions.squeeze(-1)
        if end_positions.dim() > 1:
            end_positions = end_positions.squeeze(-1)

        # Out-of-range positions are clamped onto ignored_index and skipped.
        ignored_index = start_logits.size(1)
        start_positions = start_positions.clamp(0, ignored_index)
        end_positions = end_positions.clamp(0, ignored_index)

        # Task Loss (QA Loss) on the final predictions
        loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
        start_loss = loss_fct(start_logits, start_positions)
        end_loss = loss_fct(end_logits, end_positions)
        task_loss = (start_loss + end_loss) / 2
        loss_components = {"task_loss": task_loss.item()}

        # Coherence Loss: pull the initial logits toward the (detached) final ones.
        coherence_loss_start = self.coherence_loss_fn(y0_start_logits, start_logits.detach())
        coherence_loss_end = self.coherence_loss_fn(y0_end_logits, end_logits.detach())
        coherence_loss = (coherence_loss_start + coherence_loss_end) / 2
        loss_components["coherence_loss"] = coherence_loss.item()

        total_loss = task_loss + config.LAMBDA_COHERENCE * coherence_loss

        # Delta Regularization Loss (only computed in training mode)
        if delta_reg_loss is not None:
            loss_components["delta_reg_loss"] = delta_reg_loss.item()
            total_loss = total_loss + config.LAMBDA_DELTA_REG * delta_reg_loss

        return total_loss, loss_components

    def _update_memory(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        h0,
        reasoning_trace_T,
        h_final,
        y0_start_logits,
        y0_end_logits,
        final_start_logits,
        final_end_logits,
    ):
        """
        Store this batch in the active memory.

        Tensors are detached first, so entries persisting across batches hold
        no references into this step's autograd graph.
        """
        input_data = {'input_ids': input_ids, 'attention_mask': attention_mask}
        if token_type_ids is not None:
            input_data['token_type_ids'] = token_type_ids

        initial_output = {
            'start_logits': y0_start_logits.detach(),
            'end_logits': y0_end_logits.detach(),
        }
        final_output = {
            'start_logits': final_start_logits.detach(),
            'end_logits': final_end_logits.detach(),
        }

        self.memory.add(
            input_data=input_data,
            hidden_states=h0.detach(),
            output=initial_output,
            reasoning_trace=tuple(layer.detach() for layer in reasoning_trace_T),
            final_hidden_states=h_final.detach(),
            final_output=final_output,
        )

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        start_positions=None,
        end_positions=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        use_memory=True,
    ):
        """
        Run encoding, iterative refinement and span prediction.

        Args:
            input_ids, attention_mask, token_type_ids: Passed to the base model
                (token_type_ids only when not None).
            start_positions, end_positions: Gold span positions; when both are
                given, a loss is computed and returned.
            output_attentions: Forwarded to the base model (defaults to the
                base config's setting).
            output_hidden_states: Accepted for API compatibility but ignored;
                hidden states are always computed and returned because they
                form the reasoning trace.
            return_dict: Return a QuestionAnsweringModelOutput (default from
                the base config) instead of a tuple.
            use_memory: Enable memory retrieval/storage for this call (in
                training only if config.MEMORY_USE_DURING_TRAINING).

        Returns:
            QuestionAnsweringModelOutput, or the tuple
            ([loss,] start_logits, end_logits, hidden_states[, attentions]).
            Intermediate tensors are also stored in ``self.custom_outputs``,
            refreshed on every call regardless of ``return_dict``.
        """
        return_dict = return_dict if return_dict is not None else self.base_config.use_return_dict
        is_training = self.training
        use_learned_steps = self.use_dynamic_steps and self.reasoning_step_type == "learned"
        use_confidence_steps = self.use_dynamic_steps and self.reasoning_step_type == "confidence"

        # === 1. Initial Forward Pass ===
        outputs = self._encode(input_ids, attention_mask, token_type_ids, output_attentions)
        # H(0): Last hidden state from the base model
        h0 = outputs.last_hidden_state
        # T: Reasoning Trace (all hidden states)
        reasoning_trace_T = outputs.hidden_states

        # y^(0): Initial QA prediction using H(0)
        y0_output = self.qa_head(h0)
        y0_start_logits, y0_end_logits = y0_output["start_logits"], y0_output["end_logits"]

        # === 2. Memory Integration ===
        memory_active = use_memory and (not is_training or config.MEMORY_USE_DURING_TRAINING)
        memory_context = None
        if memory_active and len(self.memory) > 0:
            # NOTE(review): the retrieved context is not fed into the reasoning
            # loop or the QA head, so it does not affect predictions. It is only
            # exposed via custom_outputs. Confirm the intended integration point.
            memory_context = self.memory.get_memory_context(h0, attention_mask)

        # === 3. Multi-step Reasoning ===
        max_num_steps, num_steps, step_probs = self._plan_reasoning_steps(h0, is_training)

        h_current = h0
        all_deltas = []
        all_gates = []
        all_hidden_states = [h0]
        step_reg_losses = []
        actual_steps_taken = 0

        for step in range(max_num_steps):
            # Confidence-based early stop: once the minimum is done, stop when
            # the previous step's mean delta norm falls below the threshold.
            if use_confidence_steps and step >= self.min_reasoning_steps and all_deltas:
                prev_delta_norm = all_deltas[-1].norm(dim=-1).mean().item()
                if prev_delta_norm < self.early_stop_threshold:
                    break

            # Learned approach: per-example (batch, 1, 1) mask, 1 while the
            # example's sampled budget is not yet used up. Applied from step 0
            # so that examples with a 0-step budget are never updated. Some
            # example is always active because max_num_steps == max(num_steps).
            continue_mask = None
            if use_learned_steps:
                continue_mask = (step < num_steps).float().unsqueeze(-1).unsqueeze(-1)

            # Compute delta using the current hidden state and reasoning trace
            if config.BYPASS_DELTA_CALCULATION:
                # Bypass delta calculation for testing
                delta = torch.zeros_like(h_current)
            else:
                delta, _ = self.retroactive_update_layer(h_current, reasoning_trace_T)

            # Apply delta magnitude constraint
            constrained_delta, step_reg_loss = self._apply_delta_constraint(delta, h0, is_training)
            if step_reg_loss is not None:
                step_reg_losses.append(step_reg_loss)

            if continue_mask is not None:
                constrained_delta = constrained_delta * continue_mask

            # Gated update
            gate = self.gating_mechanism(h_current, constrained_delta)
            h_current = h_current + gate * constrained_delta

            all_deltas.append(constrained_delta)
            all_gates.append(gate)
            all_hidden_states.append(h_current)
            actual_steps_taken = step + 1

        # Average the regularizer over all executed steps. It is None when no
        # step ran or when not training.
        delta_reg_loss = torch.stack(step_reg_losses).mean() if step_reg_losses else None

        # Final hidden state after all reasoning steps
        h_final = h_current

        # === 4. Final Prediction ===
        y_final_output = self.qa_head(h_final)
        y_final_start_logits, y_final_end_logits = y_final_output["start_logits"], y_final_output["end_logits"]

        # === 5. Loss Calculation ===
        total_loss = None
        loss_components = {}
        if start_positions is not None and end_positions is not None:
            total_loss, loss_components = self._compute_loss(
                y0_start_logits,
                y0_end_logits,
                y_final_start_logits,
                y_final_end_logits,
                start_positions,
                end_positions,
                delta_reg_loss,
            )

        # === 6. Memory Update ===
        if memory_active:
            self._update_memory(
                input_ids,
                attention_mask,
                token_type_ids,
                h0,
                reasoning_trace_T,
                h_final,
                y0_start_logits,
                y0_end_logits,
                y_final_start_logits,
                y_final_end_logits,
            )

        # === 7. Return Outputs ===
        # Custom outputs are stored on the instance because
        # QuestionAnsweringModelOutput does not accept extra fields. They are
        # set before either return path so tuple-mode callers never read
        # values left over from a previous batch.
        self.custom_outputs = {
            "initial_hidden_states": h0,
            "final_hidden_states": h_final,
            "all_hidden_states": all_hidden_states,
            "all_deltas": all_deltas,
            "all_gates": all_gates,
            "y0_start_logits": y0_start_logits,
            "y0_end_logits": y0_end_logits,
            "loss_components": loss_components if total_loss is not None else None,
            "steps_taken": actual_steps_taken,
            "memory_context": memory_context,
        }
        if use_learned_steps:
            self.custom_outputs["step_probs"] = step_probs
            self.custom_outputs["num_steps"] = num_steps

        if not return_dict:
            # Select fields by name rather than slicing outputs[2:], which
            # assumes a pooler_output at index 1 and would drop hidden_states
            # for base models without a pooler.
            extras = tuple(t for t in (outputs.hidden_states, outputs.attentions) if t is not None)
            output = (y_final_start_logits, y_final_end_logits) + extras
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=y_final_start_logits,
            end_logits=y_final_end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )