# File size: 5,604 Bytes
# 0861a59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
"""

Utilities for dual encoder configuration and initialization.

"""
import logging
from typing import Dict, Any, Optional, Union
import torch
import torch.nn as nn

from config import load_config, app_config

logger = logging.getLogger(__name__)

class DualEncoderConfig:
    """Settings container for the dual-encoder setup.

    Values are resolved in three layers, each one overriding the previous:

    1. the built-in defaults assigned in ``__init__``;
    2. the ``DUAL_ENCODER_CONFIG`` mapping of the object returned by
       ``load_config()``, when that attribute exists;
    3. the ``config_dict`` passed by the caller, when it is non-empty.

    Keys from layers 2 and 3 are applied with ``setattr`` as-is, so any key
    (including ones without a built-in default) becomes an attribute.
    """

    def __init__(self, config_dict: Optional[Dict[str, Any]] = None):
        """Build the configuration.

        Args:
            config_dict: Optional mapping of attribute names to values that
                takes precedence over every other source.
        """
        app_settings = load_config()

        # Layer 1: built-in defaults.
        self.USE_PRETRAINED_ENCODER = True
        self.USE_CUSTOM_ENCODER = True
        # Options: concat, add, weighted_sum
        self.FUSION_METHOD = "concat"
        # Weights for the pretrained and custom encoders, in that order.
        self.FUSION_WEIGHTS = [0.5, 0.5]
        # Options: joint, alternating, pretrained_first
        self.TRAINING_MODE = "joint"

        # Layer 2: application-level overrides, if the loaded config has any.
        app_overrides = (
            app_settings.DUAL_ENCODER_CONFIG
            if hasattr(app_settings, "DUAL_ENCODER_CONFIG")
            else {}
        )

        # Layer 3: caller overrides; a missing or empty dict contributes nothing.
        caller_overrides = config_dict if config_dict else {}

        # Apply in precedence order so later layers win.
        for overrides in (app_overrides, caller_overrides):
            for name, setting in overrides.items():
                setattr(self, name, setting)

class DualEncoderFusion(nn.Module):
    """Fuse the outputs of the pretrained and the custom encoder.

    Behaviour is driven by attributes of ``self.config``, which are read on
    every ``forward`` call:

    * ``USE_PRETRAINED_ENCODER`` / ``USE_CUSTOM_ENCODER``: when one is falsy
      the other encoder's output is returned unchanged.
    * ``FUSION_METHOD``: ``"concat"`` (concatenate along the last dimension),
      ``"add"`` (element-wise sum) or ``"weighted_sum"``.
    * ``FUSION_WEIGHTS``: two weights ``(pretrained, custom)``; only used by
      ``"weighted_sum"`` and normalised so they sum to 1.

    A config object passed to the constructor is stored by reference, not
    copied, so later changes to it are visible to this module.
    """

    def __init__(self, config: Optional[Union[Dict[str, Any], "DualEncoderConfig"]] = None):
        """Initialize the fusion module.

        Args:
            config: Fusion settings. A ``dict`` is wrapped in a
                ``DualEncoderConfig``; ``None`` builds a default
                ``DualEncoderConfig``; any other object is used as-is.

        Raises:
            ValueError: If ``FUSION_METHOD`` is ``"weighted_sum"`` and
                ``FUSION_WEIGHTS`` is not two values with a non-zero sum.
        """
        super().__init__()

        # Normalise the accepted input types to a single config object.
        if isinstance(config, dict):
            self.config = DualEncoderConfig(config)
        elif config is None:
            self.config = DualEncoderConfig()
        else:
            self.config = config

        # The buffer is registered only for "weighted_sum" so the state_dict
        # layout of the other fusion methods stays exactly as before.
        if self.config.FUSION_METHOD == "weighted_sum":
            self.register_buffer('fusion_weights', self._normalized_weights())

    def _normalized_weights(self) -> torch.Tensor:
        """Return ``FUSION_WEIGHTS`` as a float32 tensor of two entries summing to 1."""
        raw_weights = self.config.FUSION_WEIGHTS
        weights = torch.tensor(raw_weights, dtype=torch.float32)
        if weights.dim() != 1 or weights.numel() != 2:
            raise ValueError(
                "FUSION_WEIGHTS must contain exactly two values "
                f"(pretrained, custom), got: {raw_weights!r}"
            )
        total = weights.sum()
        # Dividing by a zero sum would silently yield inf/NaN weights.
        if total.item() == 0:
            raise ValueError(
                f"FUSION_WEIGHTS must not sum to zero, got: {raw_weights!r}"
            )
        return weights / total

    def forward(self, pretrained_output: torch.Tensor, custom_output: torch.Tensor) -> torch.Tensor:
        """Combine the encoder outputs according to the configured fusion method.

        Args:
            pretrained_output: Output from the pretrained encoder.
            custom_output: Output from the custom encoder.

        Returns:
            The fused tensor, or one of the inputs unchanged when the other
            encoder is disabled. If both encoders are disabled,
            ``custom_output`` is returned because that check runs first.

        Raises:
            ValueError: If ``"add"`` or ``"weighted_sum"`` receives tensors of
                different shapes, if the fusion weights are invalid, or if
                ``FUSION_METHOD`` is not a known method.
        """
        # Handle the case where one encoder is disabled.
        if not self.config.USE_PRETRAINED_ENCODER:
            return custom_output
        if not self.config.USE_CUSTOM_ENCODER:
            return pretrained_output

        method = self.config.FUSION_METHOD

        if method == "concat":
            return torch.cat([pretrained_output, custom_output], dim=-1)

        if method == "add":
            # Element-wise addition requires identical shapes.
            if pretrained_output.shape != custom_output.shape:
                raise ValueError(f"Cannot add tensors with different shapes: {pretrained_output.shape} and {custom_output.shape}")
            return pretrained_output + custom_output

        if method == "weighted_sum":
            # Element-wise weighting requires identical shapes.
            if pretrained_output.shape != custom_output.shape:
                raise ValueError(f"Cannot use weighted sum with different shapes: {pretrained_output.shape} and {custom_output.shape}")
            # Prefer the registered buffer. It is absent when the config was
            # switched to "weighted_sum" after construction; in that case
            # derive the weights from the live config instead of failing
            # with an AttributeError.
            weights = getattr(self, "fusion_weights", None)
            if weights is None:
                weights = self._normalized_weights().to(pretrained_output.device)
            w1, w2 = weights
            return w1 * pretrained_output + w2 * custom_output

        raise ValueError(f"Unknown fusion method: {self.config.FUSION_METHOD}")

def get_dual_encoder_config() -> DualEncoderConfig:
    """Create a ``DualEncoderConfig`` without caller-supplied overrides.

    Returns:
        A freshly constructed ``DualEncoderConfig``; every call builds a new
        instance.
    """
    fresh_config = DualEncoderConfig()
    return fresh_config

# Testing function
def test_fusion_methods():
    """Run every fusion method once on random inputs and print the shapes.

    A single ``DualEncoderConfig`` instance is reused for all cases and is
    mutated between them; each case builds its own ``DualEncoderFusion``.

    Returns:
        Dict mapping ``"concat"``, ``"add"`` and ``"weighted_sum"`` to the
        corresponding fused tensor.
    """
    shared_config = DualEncoderConfig()

    # Two random inputs of identical shape, as "add" and "weighted_sum" require.
    pretrained_feats = torch.randn(2, 10, 768)
    custom_feats = torch.randn(2, 10, 768)

    def fuse_with(method):
        # Select the method on the shared config, then build and apply a
        # fresh fusion module.
        shared_config.FUSION_METHOD = method
        fusion_module = DualEncoderFusion(shared_config)
        return fusion_module(pretrained_feats, custom_feats)

    # Concatenation doubles the last dimension.
    output_concat = fuse_with("concat")
    print(f"Concat output shape: {output_concat.shape}")  # Should be [2, 10, 1536]

    # Addition keeps the input shape.
    output_add = fuse_with("add")
    print(f"Add output shape: {output_add.shape}")  # Should be [2, 10, 768]

    # Weighted sum keeps the input shape; weights are set before construction.
    shared_config.FUSION_WEIGHTS = [0.7, 0.3]
    output_weighted = fuse_with("weighted_sum")
    print(f"Weighted sum output shape: {output_weighted.shape}")  # Should be [2, 10, 768]

    results = {
        "concat": output_concat,
        "add": output_add,
        "weighted_sum": output_weighted
    }
    return results

if __name__ == "__main__":
    # Script entry point: run the fusion smoke test, which prints each output
    # shape, and keep the returned dict of fused tensors.
    test_results = test_fusion_methods()