"""Attention profile selection based on a JSON configuration file."""
import os
import json
from typing import Dict, Any, Optional, List, Tuple
import re
from pathlib import Path
class AttentionProfileSelector:
    """
    Selects appropriate attention profiles based on input characteristics
    and configuration specified in the JSON dataset.

    The configuration is expected to provide an ``attention_profiles`` list
    (entries keyed by ``profile_id``), an optional ``default_profile`` id,
    and a ``profile_selection_strategy`` dict with scoring weights under
    ``strategy_parameters``.
    """

    def __init__(self, config_path: Optional[str] = None):
        """
        Initialize the selector with the provided configuration.

        Args:
            config_path: Path to the attention configuration JSON. Defaults
                to ``attention_configuration.json`` next to this module.
        """
        if config_path is None:
            # Default to the standard location next to this file.
            config_path = os.path.join(os.path.dirname(__file__), "attention_configuration.json")
        self.config = self._load_config(config_path)
        # Index profiles by id for O(1) lookup.
        self.profiles: Dict[str, Dict[str, Any]] = {
            p["profile_id"]: p for p in self.config.get("attention_profiles", [])
        }
        self.default_profile_id: str = self.config.get("default_profile", "standard")
        self.selection_strategy: Dict[str, Any] = self.config.get("profile_selection_strategy", {})

    def _load_config(self, config_path: str) -> Dict[str, Any]:
        """Load configuration from a JSON file, returning {} on failure.

        Best-effort by design: a missing or corrupt config degrades the
        selector to its built-in defaults rather than crashing the caller.
        """
        try:
            with open(config_path, 'r') as f:
                return json.load(f)
        except (FileNotFoundError, json.JSONDecodeError) as e:
            print(f"Error loading attention configuration: {e}")
            return {}

    @staticmethod
    def _fraction_matched(signals: List[str], lowered_text: str) -> float:
        """Return the fraction of signal substrings found in *lowered_text*.

        Returns 0.0 when no signals are configured (avoids division by zero).
        """
        if not signals:
            return 0.0
        return sum(1 for signal in signals if signal.lower() in lowered_text) / len(signals)

    def select_profile(self,
                       input_text: str,
                       context: Optional[Dict[str, Any]] = None) -> Tuple[str, float]:
        """
        Select the most appropriate attention profile based on input characteristics.

        Args:
            input_text: The user's input text
            context: Additional context about the interaction; may carry an
                explicit ``requested_attention`` profile id.

        Returns:
            Tuple of (profile_id, confidence). Falls back to the default
            profile when no profiles are configured or the best score is
            below the configured minimum confidence.
        """
        if not self.profiles:
            return self.default_profile_id, 1.0

        # Hoist invariant lookups out of the per-profile loop: the original
        # re-read strategy_parameters and re-lowered the input per signal.
        params = self.selection_strategy.get("strategy_parameters", {})
        length_weight = params.get("input_length_weight", 0.3)
        content_weight = params.get("content_type_weight", 0.5)
        lowered_text = input_text.lower()
        input_length = len(input_text)

        # Single scoring pass over all profiles.
        scores: Dict[str, float] = {}
        for profile_id, profile in self.profiles.items():
            activation = profile.get("activation_signals", {})
            score = 0.0
            # Document length: only rewarded for a positive configured threshold.
            length_threshold = activation.get("document_length_threshold", 0)
            if 0 < length_threshold < input_length:
                score += length_weight
            # Content-type keyword matches.
            score += self._fraction_matched(
                activation.get("content_type_signals", []), lowered_text) * content_weight
            # NOTE(review): structure indicators reuse content_type_weight,
            # matching the original behavior — confirm a dedicated structure
            # weight isn't intended in the config schema.
            score += self._fraction_matched(
                activation.get("structure_indicators", []), lowered_text) * content_weight
            scores[profile_id] = score

        # An explicit request in the context gets the strongest boost.
        if context and "requested_attention" in context:
            requested = context["requested_attention"]
            if requested in scores:
                scores[requested] += params.get("explicit_request_weight", 1.0)

        best_profile_id = max(scores, key=scores.get)
        confidence = scores[best_profile_id]

        # Apply minimum confidence threshold: below it, fall back to default.
        min_confidence = params.get("minimum_confidence", 0.65)
        if confidence < min_confidence:
            return self.default_profile_id, confidence
        return best_profile_id, confidence

    def get_profile_parameters(self, profile_id: str) -> Dict[str, Any]:
        """
        Get the parameters for the specified attention profile.

        Args:
            profile_id: ID of the attention profile

        Returns:
            Dictionary of attention parameters ({} for unknown profiles).
        """
        if profile_id in self.profiles:
            return self.profiles[profile_id].get("parameters", {})
        return {}

    def get_attention_type(self, profile_id: str) -> str:
        """
        Get the attention mechanism type for the specified profile.

        Args:
            profile_id: ID of the attention profile

        Returns:
            String identifying the attention type ("standard" for unknown profiles).
        """
        if profile_id in self.profiles:
            return self.profiles[profile_id].get("attention_type", "standard")
        return "standard"
# Factory method to create appropriate attention mechanism
def create_attention_mechanism(profile_id: str, model_dim: int, selector: AttentionProfileSelector):
    """
    Factory function to create an attention mechanism based on the selected profile.

    Args:
        profile_id: ID of the selected attention profile
        model_dim: Model hidden dimension
        selector: AttentionProfileSelector instance

    Returns:
        Configured attention mechanism, or a plain nn.MultiheadAttention
        placeholder when smartHybridAttention is unavailable.
    """
    attention_type = selector.get_attention_type(profile_id)
    parameters = selector.get_profile_parameters(profile_id)
    # Import here to avoid circular imports
    try:
        from smartHybridAttention import create_smart_hybrid_attention
        # Map parameters from JSON to the attention mechanism
        attention_params = {
            "dim": model_dim,
            "num_heads": parameters.get("num_heads", 8),
            "window_size": parameters.get("window_size", 256),
            "use_sliding": parameters.get("use_sliding_window", True),
            "use_global": parameters.get("use_global_tokens", True),
            "global_token_ratio": parameters.get("global_token_ratio", 0.05),
            "memory_tokens": parameters.get("memory_token_count", 16)
        }
        if attention_type == "hierarchical":
            attention_params["use_hierarchical"] = True
        # BUG FIX: the original only returned for "hierarchical" profiles and
        # implicitly returned None for every other attention type.
        return create_smart_hybrid_attention(**attention_params)
    except ImportError:
        print("Warning: smartHybridAttention not found. Using placeholder.")
        # Return a placeholder if the module is not available; honor the
        # profile's head count instead of the original hardcoded 8.
        import torch.nn as nn
        return nn.MultiheadAttention(model_dim, parameters.get("num_heads", 8))
# Usage example:
if __name__ == "__main__":
    selector = AttentionProfileSelector()

    # Representative inputs, keyed by the label used in the report line.
    samples = {
        "Code": "def calculate_fibonacci(n):\n    if n <= 1:\n        return n\n    else:\n        return calculate_fibonacci(n-1) + calculate_fibonacci(n-2)",
        "Document": """# Chapter 1: Introduction
This technical document covers the architecture of our system.
## Section 1.1: Overview
The system consists of multiple components working together.
""",
        "Conversation": "How did you like the movie we saw yesterday? I thought the ending was unexpected.",
    }

    # Run profile selection on each sample and report the outcome.
    for label, text in samples.items():
        profile, confidence = selector.select_profile(text)
        print(f"{label} input → {profile} (confidence: {confidence:.2f})")
|