"""
Utilities for dual encoder configuration and initialization.
"""
|
|
|
import logging
|
|
|
from typing import Dict, Any, Optional, Union
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
|
|
|
from config import load_config, app_config
|
|
|
|
|
|
# Module-level logger named after this module; nothing in the visible part of
# this file emits through it yet (the smoke test below uses print instead).
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class DualEncoderConfig:
    """Settings container for the dual encoder setup.

    Values are resolved in three layers, later layers winning: the built-in
    defaults, the ``DUAL_ENCODER_CONFIG`` mapping on the object returned by
    ``load_config()`` (only when that attribute exists), and finally the
    ``config_dict`` supplied by the caller.
    """

    def __init__(self, config_dict: Optional[Dict[str, Any]] = None):
        """
        Build the configuration by layering overrides on top of the defaults.

        Args:
            config_dict: Optional mapping of attribute name to value. Its
                entries are applied last, so they take precedence over both
                the defaults and the project-level settings. ``None`` or an
                empty mapping is ignored.
        """
        project_settings = load_config()

        # Layer 1: built-in defaults.
        self.USE_PRETRAINED_ENCODER = True
        self.USE_CUSTOM_ENCODER = True
        self.FUSION_METHOD = "concat"
        self.FUSION_WEIGHTS = [0.5, 0.5]
        self.TRAINING_MODE = "joint"

        # Layers 2 and 3: collect the override mappings in precedence order
        # (project settings first, caller overrides last).
        override_layers = []
        if hasattr(project_settings, "DUAL_ENCODER_CONFIG"):
            override_layers.append(project_settings.DUAL_ENCODER_CONFIG)
        if config_dict:
            override_layers.append(config_dict)

        # Every key becomes an attribute, including keys that have no default.
        for layer in override_layers:
            for name, value in layer.items():
                setattr(self, name, value)
|
|
|
|
|
|
class DualEncoderFusion(nn.Module):
    """
    Module that combines outputs from pretrained and custom encoders.

    Behaviour is driven by attributes read from ``self.config``:
    ``USE_PRETRAINED_ENCODER``, ``USE_CUSTOM_ENCODER``, ``FUSION_METHOD``
    (``"concat"``, ``"add"`` or ``"weighted_sum"``) and, for the weighted
    sum only, ``FUSION_WEIGHTS``.
    """

    def __init__(self, config: "Optional[Union[Dict[str, Any], DualEncoderConfig]]" = None):
        """
        Initialize fusion module.

        Args:
            config: Configuration for fusion. A dict is wrapped in a
                DualEncoderConfig, ``None`` yields a default
                DualEncoderConfig, and any other object is stored as-is
                (it is not copied, so later changes to it are visible here).

        Raises:
            ValueError: If the fusion method is ``"weighted_sum"`` and
                ``FUSION_WEIGHTS`` does not hold exactly two values or its
                sum is zero or non-finite.
        """
        super().__init__()

        if isinstance(config, dict):
            self.config = DualEncoderConfig(config)
        elif config is None:
            self.config = DualEncoderConfig()
        else:
            self.config = config

        if self.config.FUSION_METHOD == "weighted_sum":
            # Validate and normalise once, up front. The buffer is only
            # registered in this mode so that the state_dict keys of
            # "concat"/"add" modules stay unchanged.
            self.register_buffer(
                "fusion_weights",
                self._normalized_fusion_weights(self.config.FUSION_WEIGHTS),
            )

    @staticmethod
    def _normalized_fusion_weights(raw_weights) -> torch.Tensor:
        """Return the two fusion weights as a float32 tensor that sums to 1.

        Raises:
            ValueError: If ``raw_weights`` is not a flat pair of values, or
                if the sum is zero or non-finite (normalising would then
                yield NaN/inf weights and silently poison the output).
        """
        weights = torch.tensor(raw_weights, dtype=torch.float32)
        if weights.dim() != 1 or weights.numel() != 2:
            raise ValueError(
                f"FUSION_WEIGHTS must contain exactly two values, got: {raw_weights}"
            )

        total = weights.sum()
        if not torch.isfinite(total) or total == 0:
            raise ValueError(
                f"FUSION_WEIGHTS must have a finite, non-zero sum, got: {raw_weights}"
            )
        return weights / total

    def forward(self, pretrained_output: torch.Tensor, custom_output: torch.Tensor) -> torch.Tensor:
        """
        Combine encoder outputs based on fusion method.

        Args:
            pretrained_output: Output from pretrained encoder
            custom_output: Output from custom encoder

        Returns:
            ``custom_output`` alone when the pretrained encoder is disabled,
            ``pretrained_output`` alone when the custom encoder is disabled,
            otherwise the fused tensor: a concatenation along the last
            dimension, an element-wise sum, or a normalised weighted sum.

        Raises:
            ValueError: If the fusion method is unknown, if "add" or
                "weighted_sum" receive tensors of different shapes, or if
                the weights have to be derived here and are invalid.
        """
        # With one encoder disabled there is nothing to fuse.
        if not self.config.USE_PRETRAINED_ENCODER:
            return custom_output
        if not self.config.USE_CUSTOM_ENCODER:
            return pretrained_output

        method = self.config.FUSION_METHOD

        if method == "concat":
            return torch.cat([pretrained_output, custom_output], dim=-1)

        if method == "add":
            if pretrained_output.shape != custom_output.shape:
                raise ValueError(f"Cannot add tensors with different shapes: {pretrained_output.shape} and {custom_output.shape}")
            return pretrained_output + custom_output

        if method == "weighted_sum":
            if pretrained_output.shape != custom_output.shape:
                raise ValueError(f"Cannot use weighted sum with different shapes: {pretrained_output.shape} and {custom_output.shape}")

            # The buffer only exists when the module was built in
            # weighted_sum mode. If the (shared, mutable) config was switched
            # to this mode afterwards, derive the weights from the config
            # instead of failing with an AttributeError. They are not
            # registered here so forward() never alters the state_dict.
            weights = getattr(self, "fusion_weights", None)
            if weights is None:
                weights = self._normalized_fusion_weights(
                    self.config.FUSION_WEIGHTS
                ).to(pretrained_output.device)

            w1, w2 = weights
            return w1 * pretrained_output + w2 * custom_output

        raise ValueError(f"Unknown fusion method: {method}")
|
|
|
|
|
|
def get_dual_encoder_config() -> DualEncoderConfig:
    """
    Build a dual encoder configuration with no caller-supplied overrides.

    The returned object is whatever ``DualEncoderConfig()`` produces when it
    is constructed without arguments.

    Returns:
        A newly constructed DualEncoderConfig object
    """
    default_config = DualEncoderConfig()
    return default_config
|
|
|
|
|
|
|
|
|
def test_fusion_methods():
    """
    Run every fusion method once on random inputs and print the output shapes.

    One DualEncoderConfig instance is reused and mutated between steps, so
    each fusion module is called immediately after it is built.

    Returns:
        Dict keyed by "concat", "add" and "weighted_sum" holding the fused
        tensor produced by each method.
    """
    shared_config = DualEncoderConfig()

    # Two random inputs of identical shape.
    pretrained_features = torch.randn(2, 10, 768)
    custom_features = torch.randn(2, 10, 768)

    outputs = {}

    # Concatenation
    shared_config.FUSION_METHOD = "concat"
    concat_module = DualEncoderFusion(shared_config)
    fused = concat_module(pretrained_features, custom_features)
    print(f"Concat output shape: {fused.shape}")
    outputs["concat"] = fused

    # Element-wise addition
    shared_config.FUSION_METHOD = "add"
    add_module = DualEncoderFusion(shared_config)
    fused = add_module(pretrained_features, custom_features)
    print(f"Add output shape: {fused.shape}")
    outputs["add"] = fused

    # Weighted sum; the weights are set before the module is constructed.
    shared_config.FUSION_METHOD = "weighted_sum"
    shared_config.FUSION_WEIGHTS = [0.7, 0.3]
    weighted_module = DualEncoderFusion(shared_config)
    fused = weighted_module(pretrained_features, custom_features)
    print(f"Weighted sum output shape: {fused.shape}")
    outputs["weighted_sum"] = fused

    return outputs
|
|
|
|
|
|
if __name__ == "__main__":
    # Running this file as a script runs the fusion smoke test once; the
    # result dict is kept in a module-level name for interactive inspection.
    test_results = test_fusion_methods()
|
|
|
|