from dataclasses import dataclass
from typing import Optional

import numpy as np
from loguru import logger

# Optional PyTorch imports
try:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    TORCH_AVAILABLE = True
except ImportError:
    TORCH_AVAILABLE = False
    logger.warning("PyTorch not available. Using numpy-based fusion only.")


# PyTorch-based models (only available if torch is installed)
if TORCH_AVAILABLE:

    @dataclass
    class FusionInput:
        """Input container for fusion layer."""

        lm_logits: torch.Tensor
        retrieval_scores: torch.Tensor
        retrieval_features: Optional[torch.Tensor] = None

    @dataclass
    class FusionOutput:
        """Output container for fusion layer (built by ConfidenceAwareFusion.forward)."""

        final_probs: torch.Tensor  # softmax(fused_logits, dim=-1)
        fused_logits: torch.Tensor  # beta * lm_logits + (1 - beta) * retrieval logits
        lm_weight: torch.Tensor  # beta, detached from the graph
        retrieval_weight: torch.Tensor  # 1 - beta, detached from the graph

    class RetrievalMLP(nn.Module):
        """
        Two-layer MLP that projects retrieval scores to label space.

        As described in paper: "MLP is a two-layer network that projects
        retrieval scores to the label space".

        Architecture: Linear(input_dim, hidden_dim) -> ReLU -> Dropout ->
        Linear(hidden_dim, output_dim). Linear weights use Xavier-uniform
        initialization and biases start at zero.
        """

        def __init__(
            self,
            input_dim: int = 64,
            hidden_dim: int = 128,
            output_dim: int = 1,
            dropout: float = 0.1,
        ):
            super().__init__()
            self.layers = nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(hidden_dim, output_dim),
            )
            self._init_weights()

        def _init_weights(self) -> None:
            """Initialize weights using Xavier initialization (biases to zero)."""
            for module in self.modules():
                if isinstance(module, nn.Linear):
                    nn.init.xavier_uniform_(module.weight)
                    if module.bias is not None:
                        nn.init.zeros_(module.bias)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            """Map features of shape (..., input_dim) to (..., output_dim)."""
            return self.layers(x)

    class ConfidenceAwareFusion(nn.Module):
        """
        Implements Equation (2) from the paper:

            pfinal(y|q, D) = σ(β · pLM + (1 − β) · MLP(pret))

        In this implementation the mix is applied to *logits* and σ is a
        softmax over the label dimension:

        * Binary (num_classes=2): the MLP outputs a single logit ``r``, which
          is expanded to ``[r, -r]`` in label order [positive, negative].
          A 2-class softmax equals a sigmoid of the logit difference, so this
          matches the sigmoid formulation for binary classification.
        * Multi-class: the MLP outputs ``num_classes`` logits directly.

        β is stored as an unconstrained logit and passed through a sigmoid, so
        it always lies in [0, 1].

        With contrastive loss from Equation (3):

            L = -log(e^sp / (e^sp + Σ e^sn)) + λ||β||²
        """

        def __init__(
            self,
            retrieval_input_dim: int = 64,
            hidden_dim: int = 128,
            num_classes: int = 2,
            initial_beta: float = 0.5,
            lambda_reg: float = 0.01,
            learn_beta: bool = True,
        ):
            super().__init__()
            self.num_classes = num_classes
            self.lambda_reg = lambda_reg
            self.is_binary = num_classes == 2

            # Trainable gating parameter β.
            # We use a logit parameter and apply sigmoid in `beta` to ensure
            # β ∈ [0, 1]. The dtype is pinned to torch's default float dtype:
            # building the tensor from a numpy float64 would create a float64
            # parameter. That breaks float64-less devices and promotes the
            # loss to float64.
            self._beta_logit = nn.Parameter(
                torch.tensor(
                    self._inverse_sigmoid(initial_beta),
                    dtype=torch.get_default_dtype(),
                ),
                requires_grad=learn_beta,
            )

            # MLP for projecting retrieval scores.
            # For binary classification output 1 logit; else num_classes logits.
            mlp_output_dim = 1 if self.is_binary else num_classes
            self.retrieval_mlp = RetrievalMLP(
                input_dim=retrieval_input_dim,
                hidden_dim=hidden_dim,
                output_dim=mlp_output_dim,
            )

            activation_type = "sigmoid" if self.is_binary else "softmax"
            logger.info(
                f"ConfidenceAwareFusion initialized: β={initial_beta}, λ={lambda_reg}, num_classes={num_classes}, activation={activation_type}"
            )

        def _inverse_sigmoid(self, x: float) -> float:
            """Return logit(x) as a Python float, clipping x into [1e-6, 1 - 1e-6]."""
            x = np.clip(x, 1e-6, 1 - 1e-6)
            return float(np.log(x / (1 - x)))

        @property
        def beta(self) -> torch.Tensor:
            """Get the current gating parameter β. Guaranteed to be in [0, 1]."""
            return torch.sigmoid(self._beta_logit)

        def forward(
            self, lm_logits: torch.Tensor, retrieval_features: torch.Tensor
        ) -> FusionOutput:
            """Forward pass implementing Equation (2): β·pLM + (1-β)·MLP(pret).

            Args:
                lm_logits: LM logits of shape (B, num_classes). In binary mode
                    the columns are in label order [positive, negative].
                retrieval_features: Input to ``retrieval_mlp``, shape
                    (B, retrieval_input_dim).

            Returns:
                FusionOutput containing:
                * ``fused_logits = β·lm_logits + (1-β)·retrieval_logits``;
                * ``final_probs = softmax(fused_logits, dim=-1)``;
                * the detached weights β and 1-β.

            Raises:
                AssertionError: if lm_logits or the MLP output has an
                    unexpected shape.
            """
            batch_size = lm_logits.size(0)
            beta = self.beta

            # Project retrieval features to label space
            retrieval_logits = self.retrieval_mlp(retrieval_features)

            if self.is_binary:
                # Binary classification: treat as 2-class softmax.
                # lm_logits: [B, 2] in LABEL_LIST order [positive, negative]
                assert lm_logits.size(1) == 2, (
                    f"Binary mode: lm_logits should be [B, 2], got {lm_logits.shape}"
                )
                # MLP outputs 1 logit per sample, treated as the positive logit.
                assert retrieval_logits.size() == (batch_size, 1), (
                    f"Binary mode: retrieval_logits should be [B, 1], got {retrieval_logits.shape}"
                )
                # Expand retrieval to 2 logits in label order: [r, -r]
                retrieval_logits_2 = torch.cat(
                    [retrieval_logits, -retrieval_logits], dim=-1
                )  # [B, 2]

                # Fuse: β·pLM + (1-β)·MLP(pret)
                fused_logits = (
                    beta * lm_logits + (1 - beta) * retrieval_logits_2
                )  # [B, 2]

                # final_probs[:, 0] = P(positive), final_probs[:, 1] = P(negative)
                final_probs = torch.softmax(fused_logits, dim=-1)  # [B, 2]
            else:
                # Multi-class: use softmax
                assert lm_logits.size(1) == self.num_classes, (
                    f"lm_logits shape mismatch: expected [B, {self.num_classes}], got {lm_logits.shape}"
                )
                assert retrieval_logits.size() == (batch_size, self.num_classes), (
                    f"retrieval_logits shape mismatch: expected [{batch_size}, {self.num_classes}], got {retrieval_logits.shape}"
                )
                fused_logits = beta * lm_logits + (1 - beta) * retrieval_logits
                final_probs = torch.softmax(fused_logits, dim=-1)

            return FusionOutput(
                final_probs=final_probs,
                fused_logits=fused_logits,
                lm_weight=beta.detach(),
                retrieval_weight=(1 - beta).detach(),
            )

        def compute_contrastive_loss(
            self,
            positive_scores: torch.Tensor,
            negative_scores: torch.Tensor,
            temperature: float = 1.0,
        ) -> torch.Tensor:
            """Compute contrastive loss from Equation (3).

            loss = mean(-log(e^sp / (e^sp + Σ_n e^sn))) + λ·β²

            where sp = positive_scores / temperature and
            sn = negative_scores / temperature.
            The sum over negatives runs along the last dimension of
            ``negative_scores`` with keepdim, then broadcasts against
            ``positive_scores``. For example, positives of shape (B, 1) and
            negatives of shape (B, N) give one loss term per row.

            The log-denominator is computed in log-space
            (logaddexp / logsumexp). Large scores or small temperatures
            therefore do not overflow ``exp`` into NaN.

            Raises:
                ValueError: if a scalar ``temperature`` is not positive.
            """
            if isinstance(temperature, (int, float)) and temperature <= 0:
                raise ValueError(f"temperature must be positive, got {temperature}")

            sp = positive_scores / temperature
            sn = negative_scores / temperature

            # log(e^sp + Σ e^sn), with the same broadcasting as
            # `exp(sp) + sum(exp(sn), dim=-1, keepdim=True)`.
            log_denominator = torch.logaddexp(
                sp, torch.logsumexp(sn, dim=-1, keepdim=True)
            )
            contrastive_loss = (log_denominator - sp).mean()

            beta_reg = self.lambda_reg * (self.beta**2)
            return contrastive_loss + beta_reg

    class RetrievalFeatureEncoder(nn.Module):
        """Encodes retrieval results into features for fusion.

        An attention MLP scores each retrieved item from its
        ``score_features`` values. The scores are softmax-normalized over
        dim 1 and used to reweight the items. The reweighted tensor is
        flattened and passed through a 2-layer encoder that produces
        ``output_dim`` features.
        """

        def __init__(
            self,
            num_retrieved: int = 5,
            score_features: int = 4,
            hidden_dim: int = 64,
            output_dim: int = 64,
        ):
            super().__init__()
            self.num_retrieved = num_retrieved
            input_dim = num_retrieved * score_features
            self.encoder = nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, output_dim),
            )
            self.attention = nn.Sequential(
                nn.Linear(score_features, 16), nn.Tanh(), nn.Linear(16, 1)
            )

        def forward(
            self,
            retrieval_scores: torch.Tensor,
            retrieval_features: Optional[torch.Tensor] = None,
        ) -> torch.Tensor:
            """Encode retrieval scores into a (B, output_dim) feature tensor.

            Args:
                retrieval_scores: Expected shape
                    (B, num_retrieved, score_features). The attention layer
                    consumes the last dim, and the encoder consumes the
                    flattened num_retrieved * score_features values.
                retrieval_features: Currently unused; accepted for interface
                    compatibility.
            """
            batch_size = retrieval_scores.size(0)
            attn_logits = self.attention(retrieval_scores)
            attn_weights = F.softmax(attn_logits, dim=1)  # normalize over retrieved items
            weighted = retrieval_scores * attn_weights
            # reshape (not view) so non-contiguous inputs are also accepted
            flat = weighted.reshape(batch_size, -1)
            encoded = self.encoder(flat)
            return encoded