""" Utilities for dual encoder configuration and initialization. """ import logging from typing import Dict, Any, Optional, Union import torch import torch.nn as nn from config import load_config, app_config logger = logging.getLogger(__name__) class DualEncoderConfig: """Configuration object for dual encoders""" def __init__(self, config_dict: Optional[Dict[str, Any]] = None): """ Initialize dual encoder configuration. Args: config_dict: Optional configuration dictionary to override defaults """ # Set defaults from app_config config = load_config() # Default configuration self.USE_PRETRAINED_ENCODER = True self.USE_CUSTOM_ENCODER = True self.FUSION_METHOD = "concat" # Options: concat, add, weighted_sum self.FUSION_WEIGHTS = [0.5, 0.5] # Weights for pretrained and custom encoders self.TRAINING_MODE = "joint" # Options: joint, alternating, pretrained_first # Override defaults with app_config if available if hasattr(config, "DUAL_ENCODER_CONFIG"): dual_config = config.DUAL_ENCODER_CONFIG for key, value in dual_config.items(): setattr(self, key, value) # Override with provided config_dict if available if config_dict: for key, value in config_dict.items(): setattr(self, key, value) class DualEncoderFusion(nn.Module): """ Module that combines outputs from pretrained and custom encoders. """ def __init__(self, config: Optional[Union[Dict[str, Any], DualEncoderConfig]] = None): """ Initialize fusion module. Args: config: Configuration for fusion (dict or DualEncoderConfig object) """ super().__init__() # Convert dict to DualEncoderConfig if needed if isinstance(config, dict): self.config = DualEncoderConfig(config) elif config is None: self.config = DualEncoderConfig() else: self.config = config # Initialize fusion weights if using weighted sum if self.config.FUSION_METHOD == "weighted_sum": weights = torch.tensor(self.config.FUSION_WEIGHTS, dtype=torch.float32) self.register_buffer('fusion_weights', weights / weights.sum()) def forward(self, pretrained_output: torch.Tensor, custom_output: torch.Tensor) -> torch.Tensor: """ Combine encoder outputs based on fusion method. Args: pretrained_output: Output from pretrained encoder custom_output: Output from custom encoder Returns: Combined tensor """ # Handle the case where one encoder is disabled if not self.config.USE_PRETRAINED_ENCODER: return custom_output if not self.config.USE_CUSTOM_ENCODER: return pretrained_output # Apply fusion method if self.config.FUSION_METHOD == "concat": return torch.cat([pretrained_output, custom_output], dim=-1) elif self.config.FUSION_METHOD == "add": # Ensure dimensions match if pretrained_output.shape != custom_output.shape: raise ValueError(f"Cannot add tensors with different shapes: {pretrained_output.shape} and {custom_output.shape}") return pretrained_output + custom_output elif self.config.FUSION_METHOD == "weighted_sum": # Ensure dimensions match if pretrained_output.shape != custom_output.shape: raise ValueError(f"Cannot use weighted sum with different shapes: {pretrained_output.shape} and {custom_output.shape}") # Apply weights w1, w2 = self.fusion_weights return w1 * pretrained_output + w2 * custom_output else: raise ValueError(f"Unknown fusion method: {self.config.FUSION_METHOD}") def get_dual_encoder_config() -> DualEncoderConfig: """ Get dual encoder configuration from app_config. Returns: DualEncoderConfig object """ return DualEncoderConfig() # Testing function def test_fusion_methods(): """Test different fusion methods""" config = DualEncoderConfig() # Create test tensors x1 = torch.randn(2, 10, 768) x2 = torch.randn(2, 10, 768) # Test concat fusion config.FUSION_METHOD = "concat" fusion_concat = DualEncoderFusion(config) output_concat = fusion_concat(x1, x2) print(f"Concat output shape: {output_concat.shape}") # Should be [2, 10, 1536] # Test add fusion config.FUSION_METHOD = "add" fusion_add = DualEncoderFusion(config) output_add = fusion_add(x1, x2) print(f"Add output shape: {output_add.shape}") # Should be [2, 10, 768] # Test weighted sum fusion config.FUSION_METHOD = "weighted_sum" config.FUSION_WEIGHTS = [0.7, 0.3] fusion_weighted = DualEncoderFusion(config) output_weighted = fusion_weighted(x1, x2) print(f"Weighted sum output shape: {output_weighted.shape}") # Should be [2, 10, 768] return { "concat": output_concat, "add": output_add, "weighted_sum": output_weighted } if __name__ == "__main__": # Run tests test_results = test_fusion_methods()