File size: 8,729 Bytes
0861a59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
import os
import json
from typing import Dict, Any, Optional, List, Tuple
import re
from pathlib import Path

class AttentionProfileSelector:
    """Select attention profiles based on input characteristics.

    Profiles, the default profile, and the selection strategy are read
    from a JSON configuration file (``attention_configuration.json`` by
    default). Scoring weights come from the configuration's
    ``profile_selection_strategy.strategy_parameters`` section.
    """

    def __init__(self, config_path: Optional[str] = None):
        """Initialize the selector with the provided configuration.

        Args:
            config_path: Path to the attention configuration JSON. When
                omitted, ``attention_configuration.json`` next to this
                module is used.
        """
        if config_path is None:
            # Default to the standard location next to this module.
            config_path = os.path.join(
                os.path.dirname(__file__), "attention_configuration.json"
            )

        self.config = self._load_config(config_path)
        # Index profiles by id for O(1) lookup during selection.
        self.profiles = {
            p["profile_id"]: p for p in self.config.get("attention_profiles", [])
        }
        self.default_profile_id = self.config.get("default_profile", "standard")
        self.selection_strategy = self.config.get("profile_selection_strategy", {})

    def _load_config(self, config_path: str) -> Dict[str, Any]:
        """Load configuration from a JSON file.

        Returns an empty dict on a missing or malformed file so the
        selector degrades gracefully to the default profile instead of
        raising at construction time.
        """
        try:
            with open(config_path, 'r') as f:
                return json.load(f)
        except (FileNotFoundError, json.JSONDecodeError) as e:
            print(f"Error loading attention configuration: {e}")
            return {}

    def _strategy_parameters(self) -> Dict[str, Any]:
        """Return the strategy_parameters mapping (possibly empty)."""
        return self.selection_strategy.get("strategy_parameters", {})

    @staticmethod
    def _match_ratio(signals: List[str], lowered_text: str) -> float:
        """Fraction of signals found (case-insensitively) in the text.

        Returns 0.0 when no signals are configured, matching the
        behavior of skipping the scoring term entirely.
        """
        if not signals:
            return 0.0
        matched = sum(1 for signal in signals if signal.lower() in lowered_text)
        return matched / len(signals)

    def select_profile(self,
                       input_text: str,
                       context: Optional[Dict[str, Any]] = None) -> Tuple[str, float]:
        """Select the most appropriate attention profile for the input.

        Args:
            input_text: The user's input text.
            context: Additional context about the interaction; a
                ``requested_attention`` key adds explicit weight to the
                named profile.

        Returns:
            Tuple of (profile_id, confidence). Falls back to the default
            profile when no profiles are configured or the best score is
            below the configured minimum confidence.
        """
        if not self.profiles:
            return self.default_profile_id, 1.0

        params = self._strategy_parameters()
        length_weight = params.get("input_length_weight", 0.3)
        # NOTE(review): structure indicators reuse "content_type_weight";
        # no separate structure weight exists in the config — confirm
        # this sharing is intentional.
        content_weight = params.get("content_type_weight", 0.5)

        # Lowercase once instead of once per signal per profile.
        lowered = input_text.lower()
        input_length = len(input_text)

        scores: Dict[str, float] = {}
        for profile_id, profile in self.profiles.items():
            activation = profile.get("activation_signals", {})
            score = 0.0

            # Length score: only when a positive threshold is exceeded.
            threshold = activation.get("document_length_threshold", 0)
            if threshold > 0 and input_length > threshold:
                score += length_weight

            # Content-type and structure signal match scores.
            score += content_weight * self._match_ratio(
                activation.get("content_type_signals", []), lowered)
            score += content_weight * self._match_ratio(
                activation.get("structure_indicators", []), lowered)

            scores[profile_id] = score

        # An explicit request in context outweighs heuristic signals.
        if context and "requested_attention" in context:
            requested = context["requested_attention"]
            if requested in self.profiles:
                scores[requested] += params.get("explicit_request_weight", 1.0)

        best_profile_id = max(scores, key=scores.get)
        confidence = scores[best_profile_id]

        # Fall back to the default profile below the confidence floor.
        min_confidence = params.get("minimum_confidence", 0.65)
        if confidence < min_confidence:
            return self.default_profile_id, confidence

        return best_profile_id, confidence

    def get_profile_parameters(self, profile_id: str) -> Dict[str, Any]:
        """Get the parameters for the specified attention profile.

        Args:
            profile_id: ID of the attention profile.

        Returns:
            Dictionary of attention parameters; empty for unknown ids.
        """
        if profile_id in self.profiles:
            return self.profiles[profile_id].get("parameters", {})
        return {}

    def get_attention_type(self, profile_id: str) -> str:
        """Get the attention mechanism type for the specified profile.

        Args:
            profile_id: ID of the attention profile.

        Returns:
            String identifying the attention type; "standard" for
            unknown ids or profiles without an explicit type.
        """
        if profile_id in self.profiles:
            return self.profiles[profile_id].get("attention_type", "standard")
        return "standard"


# Factory method to create appropriate attention mechanism
def create_attention_mechanism(profile_id: str, model_dim: int, selector: AttentionProfileSelector):
    """Create an attention mechanism configured from the selected profile.

    Args:
        profile_id: ID of the selected attention profile.
        model_dim: Model hidden dimension.
        selector: AttentionProfileSelector supplying type and parameters.

    Returns:
        A configured attention mechanism; falls back to a plain
        ``nn.MultiheadAttention`` when smartHybridAttention is missing.
    """
    attention_type = selector.get_attention_type(profile_id)
    parameters = selector.get_profile_parameters(profile_id)

    try:
        # Deferred import: avoids a circular dependency at module load.
        from smartHybridAttention import EnhancedSmartHybridAttention, create_smart_hybrid_attention

        # Translate JSON parameter names into constructor keywords.
        kwargs = {
            "dim": model_dim,
            "num_heads": parameters.get("num_heads", 8),
            "window_size": parameters.get("window_size", 256),
            "use_sliding": parameters.get("use_sliding_window", True),
            "use_global": parameters.get("use_global_tokens", True),
            "global_token_ratio": parameters.get("global_token_ratio", 0.05),
            "memory_tokens": parameters.get("memory_token_count", 16),
        }

        # Hierarchical profiles request the hierarchical variant.
        if attention_type == "hierarchical":
            kwargs["use_hierarchical"] = True

        return create_smart_hybrid_attention(**kwargs)

    except ImportError:
        print("Warning: smartHybridAttention not found. Using placeholder.")
        # Placeholder keeps callers working without the optional module.
        import torch.nn as nn
        return nn.MultiheadAttention(model_dim, 8)


# Usage example:
if __name__ == "__main__":
    selector = AttentionProfileSelector()

    # Representative inputs for three content domains.
    code_input = (
        "def calculate_fibonacci(n):\n"
        "    if n <= 1:\n"
        "        return n\n"
        "    else:\n"
        "        return calculate_fibonacci(n-1) + calculate_fibonacci(n-2)"
    )

    document_input = """# Chapter 1: Introduction

    This technical document covers the architecture of our system.

    ## Section 1.1: Overview

    The system consists of multiple components working together.

    """

    conversation_input = "How did you like the movie we saw yesterday? I thought the ending was unexpected."

    # Run profile selection over each sample and report the results.
    for label, sample in (
        ("Code", code_input),
        ("Document", document_input),
        ("Conversation", conversation_input),
    ):
        profile, confidence = selector.select_profile(sample)
        print(f"{label} input → {profile} (confidence: {confidence:.2f})")