# Wildnerve-tlm01_Hybrid_Model / utils / dual_encoder_utils.py
# WildnerveAI's picture
# Upload 20 files
# 0861a59 verified
"""
Utilities for dual encoder configuration and initialization.
"""
import logging
from typing import Dict, Any, Optional, Union
import torch
import torch.nn as nn
from config import load_config, app_config
logger = logging.getLogger(__name__)
class DualEncoderConfig:
    """Attribute-style settings holder for the dual encoder setup.

    Values are resolved in three layers, later layers winning:

    1. the hard-coded defaults assigned in ``__init__``;
    2. the ``DUAL_ENCODER_CONFIG`` mapping on the object returned by
       ``load_config()``, when that attribute exists;
    3. the ``config_dict`` passed by the caller, when it is non-empty.
    """

    def __init__(self, config_dict: Optional[Dict[str, Any]] = None):
        """
        Build the configuration by layering overrides over the defaults.

        Args:
            config_dict: Optional mapping of attribute name to value. Every
                entry is set on the instance as-is; keys are not checked
                against the known defaults.
        """
        # The application configuration is loaded before anything else.
        app_settings = load_config()

        # Layer 1: built-in defaults.
        self.USE_PRETRAINED_ENCODER = True
        self.USE_CUSTOM_ENCODER = True
        self.FUSION_METHOD = "concat"  # Options: concat, add, weighted_sum
        self.FUSION_WEIGHTS = [0.5, 0.5]  # Weights for pretrained and custom encoders
        self.TRAINING_MODE = "joint"  # Options: joint, alternating, pretrained_first

        # Collect the override layers in precedence order (lowest first).
        override_layers = []
        if hasattr(app_settings, "DUAL_ENCODER_CONFIG"):
            override_layers.append(app_settings.DUAL_ENCODER_CONFIG)
        if config_dict:
            override_layers.append(config_dict)

        # Apply every layer; a later layer overwrites an earlier one.
        for layer in override_layers:
            for attr_name, attr_value in layer.items():
                setattr(self, attr_name, attr_value)
class DualEncoderFusion(nn.Module):
    """
    Module that combines outputs from pretrained and custom encoders.

    The fusion behaviour is driven by the attached configuration object,
    which is read on every ``forward`` call:

    * ``USE_PRETRAINED_ENCODER`` / ``USE_CUSTOM_ENCODER`` - when one of them
      is falsy the other encoder's output is returned unchanged.
    * ``FUSION_METHOD`` - ``"concat"``, ``"add"`` or ``"weighted_sum"``.
    * ``FUSION_WEIGHTS`` - two weights (pretrained, custom) used by
      ``"weighted_sum"``; they are normalized to sum to 1.

    When the module is constructed with ``FUSION_METHOD == "weighted_sum"``
    the normalized weights are stored in the ``fusion_weights`` buffer, so
    later changes to ``config.FUSION_WEIGHTS`` do not affect that buffer.
    """

    def __init__(self, config: Optional[Union[Dict[str, Any], "DualEncoderConfig"]] = None):
        """
        Initialize fusion module.

        Args:
            config: Configuration for fusion. A dict is wrapped in a
                DualEncoderConfig, ``None`` yields a default
                DualEncoderConfig, and any other object is stored as-is.

        Raises:
            ValueError: If the fusion method is ``"weighted_sum"`` and
                ``FUSION_WEIGHTS`` is not exactly two numbers with a finite,
                non-zero sum.
        """
        super().__init__()

        # Convert dict to DualEncoderConfig if needed
        if isinstance(config, dict):
            self.config = DualEncoderConfig(config)
        elif config is None:
            self.config = DualEncoderConfig()
        else:
            self.config = config

        # The buffer is only registered for weighted sums so that the
        # state_dict of concat/add modules carries no extra key.
        if self.config.FUSION_METHOD == "weighted_sum":
            self.register_buffer('fusion_weights', self._normalized_fusion_weights())

    def _normalized_fusion_weights(self) -> torch.Tensor:
        """Return ``config.FUSION_WEIGHTS`` as a validated float32 tensor summing to 1.

        Raises:
            ValueError: If there are not exactly two weights, or their sum is
                zero or non-finite (normalizing would yield inf/NaN weights).
        """
        weights = torch.tensor(self.config.FUSION_WEIGHTS, dtype=torch.float32)
        if weights.dim() != 1 or weights.numel() != 2:
            raise ValueError(
                f"FUSION_WEIGHTS must contain exactly two values, got: {self.config.FUSION_WEIGHTS}"
            )
        total = weights.sum()
        if not bool(torch.isfinite(total)) or float(total) == 0.0:
            raise ValueError(
                f"FUSION_WEIGHTS must have a finite, non-zero sum, got: {self.config.FUSION_WEIGHTS}"
            )
        return weights / total

    def forward(self, pretrained_output: torch.Tensor, custom_output: torch.Tensor) -> torch.Tensor:
        """
        Combine encoder outputs based on fusion method.

        Args:
            pretrained_output: Output from pretrained encoder
            custom_output: Output from custom encoder

        Returns:
            ``custom_output`` or ``pretrained_output`` unchanged when the
            other encoder is disabled; otherwise the concatenation along the
            last dimension ("concat"), the element-wise sum ("add"), or the
            weighted element-wise sum ("weighted_sum").

        Raises:
            ValueError: If "add"/"weighted_sum" receive tensors of different
                shapes, if the fusion weights are invalid, or if the fusion
                method is unknown.
        """
        # Handle the case where one encoder is disabled
        if not self.config.USE_PRETRAINED_ENCODER:
            return custom_output
        if not self.config.USE_CUSTOM_ENCODER:
            return pretrained_output

        # Apply fusion method
        if self.config.FUSION_METHOD == "concat":
            return torch.cat([pretrained_output, custom_output], dim=-1)

        elif self.config.FUSION_METHOD == "add":
            # Ensure dimensions match
            if pretrained_output.shape != custom_output.shape:
                raise ValueError(f"Cannot add tensors with different shapes: {pretrained_output.shape} and {custom_output.shape}")
            return pretrained_output + custom_output

        elif self.config.FUSION_METHOD == "weighted_sum":
            # Ensure dimensions match
            if pretrained_output.shape != custom_output.shape:
                raise ValueError(f"Cannot use weighted sum with different shapes: {pretrained_output.shape} and {custom_output.shape}")

            # The buffer exists only if the method was already "weighted_sum"
            # at construction time. The config is read dynamically, so it may
            # have been switched afterwards; derive the weights on demand in
            # that case instead of failing with AttributeError.
            weights = getattr(self, "fusion_weights", None)
            if weights is None:
                weights = self._normalized_fusion_weights().to(pretrained_output.device)

            # Apply weights
            w1, w2 = weights
            return w1 * pretrained_output + w2 * custom_output

        else:
            raise ValueError(f"Unknown fusion method: {self.config.FUSION_METHOD}")
def get_dual_encoder_config() -> DualEncoderConfig:
    """
    Build a dual encoder configuration without caller-supplied overrides.

    Returns:
        A freshly constructed DualEncoderConfig object.
    """
    dual_encoder_config = DualEncoderConfig()
    return dual_encoder_config
# Testing function
def test_fusion_methods():
    """Run every fusion method on random inputs and print the output shapes.

    A single DualEncoderConfig is reused and mutated between runs; a new
    DualEncoderFusion is built from it for each method.

    Returns:
        Dict mapping "concat", "add" and "weighted_sum" to the fused tensor
        produced by that method.
    """
    config = DualEncoderConfig()

    # Two random inputs of identical shape (2, 10, 768).
    x1 = torch.randn(2, 10, 768)
    x2 = torch.randn(2, 10, 768)

    def fuse_with(method_name):
        # Switch the shared config to the requested method, then build a
        # fusion module from it and apply it to the two inputs.
        config.FUSION_METHOD = method_name
        return DualEncoderFusion(config)(x1, x2)

    # Concatenation along the last dimension.
    output_concat = fuse_with("concat")
    print(f"Concat output shape: {output_concat.shape}")  # Should be [2, 10, 1536]

    # Element-wise addition.
    output_add = fuse_with("add")
    print(f"Add output shape: {output_add.shape}")  # Should be [2, 10, 768]

    # Weighted sum with explicit weights.
    config.FUSION_WEIGHTS = [0.7, 0.3]
    output_weighted = fuse_with("weighted_sum")
    print(f"Weighted sum output shape: {output_weighted.shape}")  # Should be [2, 10, 768]

    return {
        "concat": output_concat,
        "add": output_add,
        "weighted_sum": output_weighted
    }
if __name__ == "__main__":
    # Executed as a script: exercise every fusion method and keep the
    # resulting tensors in a module-level name.
    test_results = test_fusion_methods()