import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForQuestionAnswering, AutoConfig, AutoModel
from transformers.modeling_outputs import QuestionAnsweringModelOutput

import config
from modules import CrossAttentionDelta, GatingMechanism, EnhancedQAHead
from memory import ActiveMemory


class EnhancedRRN_QA_Model(nn.Module):
    """
    Enhanced Retroactive Reasoning Network for Question Answering.

    Improvements:
    1. Delta magnitude constraint
    2. Gating mechanism
    3. Multi-step reasoning
    4. Active memory usage
    5. Enhanced QA head
    6. Improved cross-attention
    """

    def __init__(self, model_name=config.BASE_MODEL_NAME):
        super().__init__()
        self.model_name = model_name

        # Reasoning-step settings
        self.num_reasoning_steps = config.NUM_REASONING_STEPS
        self.delta_target_ratio = config.DELTA_TARGET_RATIO

        # Dynamic-step settings
        self.use_dynamic_steps = config.USE_DYNAMIC_STEPS
        self.max_reasoning_steps = config.MAX_REASONING_STEPS
        self.min_reasoning_steps = config.MIN_REASONING_STEPS
        self.reasoning_step_type = config.REASONING_STEP_TYPE
        self.early_stop_threshold = config.EARLY_STOP_THRESHOLD

        # Base encoder configuration
        self.base_config = AutoConfig.from_pretrained(
            self.model_name,
            output_hidden_states=True,
        )
        self.hidden_dim = self.base_config.hidden_size

        # Optional controller that predicts how many reasoning steps to take
        if self.use_dynamic_steps and self.reasoning_step_type == "learned":
            self.step_controller = nn.Sequential(
                nn.Linear(self.hidden_dim, 128),
                nn.ReLU(),
                nn.Linear(128, self.max_reasoning_steps - self.min_reasoning_steps + 1)
            )
            print(f"Using learned dynamic reasoning steps (min={self.min_reasoning_steps}, max={self.max_reasoning_steps})")

        # Pretrained transformer encoder
        self.base_model = AutoModel.from_pretrained(
            self.model_name,
            config=self.base_config
        )
        print(f"Loaded base model: {self.model_name}")
        print(f"Hidden dimension: {self.hidden_dim}")
        print(f"Using {self.num_reasoning_steps} reasoning steps")

        # Cross-attention layer that produces the retroactive delta updates
        self.retroactive_update_layer = CrossAttentionDelta(self.hidden_dim)

        # Gate controlling how much of each delta is applied
        self.gating_mechanism = GatingMechanism(self.hidden_dim)

        # QA output head
        self.qa_head = EnhancedQAHead(self.hidden_dim)

        # Active memory of previously processed examples
        self.memory = ActiveMemory(
            max_size=config.MEMORY_MAX_SIZE,
            retrieval_k=config.MEMORY_RETRIEVAL_K
        )

        # Auxiliary loss functions
        self.coherence_loss_fn = nn.MSELoss()
        self.delta_reg_loss_fn = nn.MSELoss()

    def _apply_delta_constraint(self, delta, h0, is_training=False):
        """
        Apply delta magnitude constraint to prevent destabilizing updates.

        Args:
            delta: The computed delta
            h0: The initial hidden states
            is_training: Whether we're in training mode

        Returns:
            constrained_delta: The constrained delta
            delta_reg_loss: Regularization loss for delta magnitude (if training)
        """
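        # In effect, the update is rescaled so that its norm never exceeds a fixed
        # fraction of the norm of the initial hidden states:
        #   ratio = ||delta|| / (||h0|| + eps)
        #   delta <- delta * min(1, delta_target_ratio / ratio)
        # During training, an MSE penalty additionally pushes the ratio toward
        # delta_target_ratio.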
        delta_norm = delta.norm(dim=-1, keepdim=True)
        h0_norm = h0.norm(dim=-1, keepdim=True).detach()

        ratio = delta_norm / (h0_norm + 1e-9)

        delta_reg_loss = None
        if is_training:
            target_ratio = torch.ones_like(ratio) * self.delta_target_ratio
            delta_reg_loss = self.delta_reg_loss_fn(ratio, target_ratio)

        scale_factor = torch.ones_like(ratio)
        too_large = ratio > self.delta_target_ratio
        if too_large.any():
            scale_factor[too_large] = self.delta_target_ratio / ratio[too_large]

        constrained_delta = delta * scale_factor

        return constrained_delta, delta_reg_loss

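    # Forward pass overview:
    #   1. Encode the input with the base model, keeping every layer's hidden states
    #      as the reasoning trace.
    #   2. Produce an initial QA prediction from the final encoder layer (h0).
    #   3. Optionally retrieve context from the active memory.
    #   4. Iteratively refine the hidden states with constrained, gated retroactive deltas.
    #   5. Produce the final QA prediction and, when labels are given, compute the task,
    #      coherence, and delta-regularization losses.
    #   6. Optionally store the example in the active memory.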
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        start_positions=None,
        end_positions=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        use_memory=True
    ):
        return_dict = return_dict if return_dict is not None else self.base_config.use_return_dict
        is_training = self.training

        # Only pass token_type_ids when they are provided; some backbones do not use them
        include_token_type_ids = token_type_ids is not None
        if include_token_type_ids:
            outputs = self.base_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                output_hidden_states=True,
                output_attentions=output_attentions if output_attentions is not None else self.base_config.output_attentions,
                return_dict=True
            )
        else:
            outputs = self.base_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True,
                output_attentions=output_attentions if output_attentions is not None else self.base_config.output_attentions,
                return_dict=True
            )

        # Initial hidden states and the full reasoning trace (hidden states of every layer)
        h0 = outputs.last_hidden_state
        reasoning_trace_T = outputs.hidden_states

        # Initial QA prediction before any retroactive updates
        y0_output = self.qa_head(h0)
        y0_start_logits, y0_end_logits = y0_output["start_logits"], y0_output["end_logits"]

        # Optionally retrieve context from the active memory
        memory_context = None
        if use_memory and ((is_training and config.MEMORY_USE_DURING_TRAINING) or not is_training):
            if len(self.memory) > 0:
                memory_context = self.memory.get_memory_context(h0, attention_mask)

        # State for the iterative refinement loop
        h_current = h0
        all_deltas = []
        all_gates = []
        all_hidden_states = [h0]
        actual_steps_taken = 0
        delta_reg_loss = None  # defensive default in case the loop body never runs

        # Decide how many reasoning steps to run
        if hasattr(self, 'use_dynamic_steps') and self.use_dynamic_steps:
            if self.reasoning_step_type == "learned":
                # Predict the step count from the mean-pooled initial representation
                pooled_h0 = h0.mean(dim=1)
                step_logits = self.step_controller(pooled_h0)

                if is_training:
                    # Sample the step count during training
                    step_probs = F.softmax(step_logits, dim=-1)
                    steps_idx = torch.multinomial(step_probs, 1).squeeze(-1)
                    num_steps = steps_idx + self.min_reasoning_steps
                else:
                    # Take the most likely step count at inference time
                    steps_idx = torch.argmax(step_logits, dim=-1)
                    num_steps = steps_idx + self.min_reasoning_steps
                    step_probs = F.softmax(step_logits, dim=-1)

                # Run the loop for the largest predicted step count in the batch
                max_num_steps = num_steps.max().item()
            elif self.reasoning_step_type == "confidence":
                # Confidence-based early stopping: run up to the maximum and break early
                max_num_steps = self.max_reasoning_steps
            else:
                max_num_steps = self.num_reasoning_steps
        else:
            max_num_steps = self.num_reasoning_steps

        for step in range(max_num_steps):
            # Confidence-based early stopping once the minimum number of steps is done
            if hasattr(self, 'use_dynamic_steps') and self.use_dynamic_steps and self.reasoning_step_type == "confidence" and step >= self.min_reasoning_steps:
                if len(all_deltas) > 0:
                    prev_delta = all_deltas[-1]
                    delta_norm = prev_delta.norm(dim=-1).mean().item()
                    if delta_norm < self.early_stop_threshold:
                        break

            # Per-example continuation mask for the learned step controller
            if hasattr(self, 'use_dynamic_steps') and self.use_dynamic_steps and self.reasoning_step_type == "learned":
                if step > 0:
                    continue_mask = (step < num_steps).float().unsqueeze(-1).unsqueeze(-1)
                    if continue_mask.sum() == 0:
                        break

            # Compute the retroactive delta from the reasoning trace
            if config.BYPASS_DELTA_CALCULATION:
                delta = torch.zeros_like(h_current)
                attn_weights = None
            else:
                delta, attn_weights = self.retroactive_update_layer(h_current, reasoning_trace_T)

            # Constrain the delta magnitude relative to the initial hidden states
            constrained_delta, delta_reg_loss = self._apply_delta_constraint(delta, h0, is_training)

            # Zero out updates for examples that have already used up their steps
            if hasattr(self, 'use_dynamic_steps') and self.use_dynamic_steps and self.reasoning_step_type == "learned" and step > 0:
                constrained_delta = constrained_delta * continue_mask

            # Gate and apply the update
            gate = self.gating_mechanism(h_current, constrained_delta)
            h_current = h_current + gate * constrained_delta

            all_deltas.append(constrained_delta)
            all_gates.append(gate)
            all_hidden_states.append(h_current)
            actual_steps_taken = step + 1

        h_final = h_current

        # Final QA prediction after retroactive refinement
        y_final_output = self.qa_head(h_final)
        y_final_start_logits, y_final_end_logits = y_final_output["start_logits"], y_final_output["end_logits"]

        total_loss = None
        loss_components = {}
        if start_positions is not None and end_positions is not None:
            # Positions may arrive with an extra dimension (e.g. from DataParallel)
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)

            # Positions outside the sequence are clamped and ignored in the loss
            ignored_index = y_final_start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            # Task loss on the final (refined) logits
            loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(y_final_start_logits, start_positions)
            end_loss = loss_fct(y_final_end_logits, end_positions)
            task_loss = (start_loss + end_loss) / 2
            loss_components["task_loss"] = task_loss.item()

            # Coherence loss: pull the initial prediction toward the detached final prediction
            coherence_loss_start = self.coherence_loss_fn(y0_start_logits, y_final_start_logits.detach())
            coherence_loss_end = self.coherence_loss_fn(y0_end_logits, y_final_end_logits.detach())
            coherence_loss = (coherence_loss_start + coherence_loss_end) / 2
            loss_components["coherence_loss"] = coherence_loss.item()

            if delta_reg_loss is not None:
                loss_components["delta_reg_loss"] = delta_reg_loss.item()

            total_loss = task_loss + config.LAMBDA_COHERENCE * coherence_loss
            if delta_reg_loss is not None:
                total_loss = total_loss + config.LAMBDA_DELTA_REG * delta_reg_loss

        # Store this example in the active memory
        if use_memory:
            input_data = {'input_ids': input_ids, 'attention_mask': attention_mask}
            if token_type_ids is not None:
                input_data['token_type_ids'] = token_type_ids

            initial_output = {'start_logits': y0_start_logits, 'end_logits': y0_end_logits}
            final_output = {'start_logits': y_final_start_logits, 'end_logits': y_final_end_logits}

            if (is_training and config.MEMORY_USE_DURING_TRAINING) or not is_training:
                self.memory.add(
                    input_data=input_data,
                    hidden_states=h0,
                    output=initial_output,
                    reasoning_trace=reasoning_trace_T,
                    final_hidden_states=h_final,
                    final_output=final_output
                )

        if not return_dict:
            output = (y_final_start_logits, y_final_end_logits) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        # Diagnostics that do not fit into QuestionAnsweringModelOutput are kept on the module
        self.custom_outputs = {
            "initial_hidden_states": h0,
            "final_hidden_states": h_final,
            "all_hidden_states": all_hidden_states,
            "all_deltas": all_deltas,
            "all_gates": all_gates,
            "y0_start_logits": y0_start_logits,
            "y0_end_logits": y0_end_logits,
            "loss_components": loss_components if total_loss is not None else None,
            "steps_taken": actual_steps_taken
        }
        if self.use_dynamic_steps and self.reasoning_step_type == "learned":
            self.custom_outputs["step_probs"] = step_probs
            self.custom_outputs["num_steps"] = num_steps

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=y_final_start_logits,
            end_logits=y_final_end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions
        )
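

# Minimal usage sketch (illustrative only): it assumes the local config module defines the
# values referenced above and that a tokenizer matching config.BASE_MODEL_NAME can be loaded.
# The question/context strings below are placeholders.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(config.BASE_MODEL_NAME)
    model = EnhancedRRN_QA_Model()
    model.eval()

    question = "Who wrote the report?"
    context = "The report was written by the data engineering team in March."
    inputs = tokenizer(question, context, return_tensors="pt", truncation=True)

    with torch.no_grad():
        outputs = model(**inputs)

    # Decode the most likely answer span from the refined logits
    start_idx = outputs.start_logits.argmax(dim=-1).item()
    end_idx = outputs.end_logits.argmax(dim=-1).item()
    answer = tokenizer.decode(inputs["input_ids"][0][start_idx:end_idx + 1])
    print("Predicted answer:", answer)
    print("Reasoning steps taken:", model.custom_outputs["steps_taken"])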