import os
import json
import re
from pathlib import Path
from typing import Dict, Any, Optional, List, Tuple


class AttentionProfileSelector:
    """
    Selects appropriate attention profiles based on input characteristics
    and configuration specified in the JSON dataset.
    """

    def __init__(self, config_path: Optional[str] = None):
        """
        Initialize the selector with the provided configuration.

        Args:
            config_path: Path to the attention configuration JSON. Defaults
                to ``attention_configuration.json`` next to this module.
        """
        if config_path is None:
            # Default to the standard location
            config_path = os.path.join(
                os.path.dirname(__file__), "attention_configuration.json"
            )
        self.config = self._load_config(config_path)
        # Index profiles by id for O(1) lookup during selection.
        self.profiles = {
            p["profile_id"]: p for p in self.config.get("attention_profiles", [])
        }
        self.default_profile_id = self.config.get("default_profile", "standard")
        self.selection_strategy = self.config.get("profile_selection_strategy", {})

    def _load_config(self, config_path: str) -> Dict[str, Any]:
        """Load configuration from JSON file; degrade to ``{}`` on failure."""
        try:
            with open(config_path, 'r') as f:
                return json.load(f)
        except (FileNotFoundError, json.JSONDecodeError) as e:
            # Best-effort: a missing or broken config falls back to defaults.
            print(f"Error loading attention configuration: {e}")
            return {}

    def _strategy_params(self) -> Dict[str, Any]:
        """Return the ``strategy_parameters`` section of the selection strategy."""
        return self.selection_strategy.get("strategy_parameters", {})

    @staticmethod
    def _signal_match_ratio(signals: List[str], lowered_text: str) -> float:
        """
        Fraction of ``signals`` found case-insensitively in the (pre-lowered)
        text. Returns 0.0 when no signals are configured.
        """
        if not signals:
            return 0.0
        matched = sum(1 for signal in signals if signal.lower() in lowered_text)
        return matched / len(signals)

    def select_profile(self, input_text: str,
                       context: Optional[Dict[str, Any]] = None) -> Tuple[str, float]:
        """
        Select the most appropriate attention profile based on input
        characteristics.

        Args:
            input_text: The user's input text
            context: Additional context about the interaction; may contain an
                explicit ``"requested_attention"`` profile id.

        Returns:
            Tuple of (profile_id, confidence). Falls back to the default
            profile when no profiles are configured or when the best score
            is below the configured minimum confidence.
        """
        if not self.profiles:
            return self.default_profile_id, 1.0

        params = self._strategy_params()
        length_weight = params.get("input_length_weight", 0.3)
        content_weight = params.get("content_type_weight", 0.5)

        # Hoist case-folding out of the per-signal membership tests.
        lowered = input_text.lower()
        input_length = len(input_text)

        # Score every profile in a single pass: length threshold, content-type
        # signals, and structure indicators (the latter two share a weight).
        scores = {profile_id: 0.0 for profile_id in self.profiles}
        for profile_id, profile in self.profiles.items():
            activation = profile.get("activation_signals", {})

            # Only count length when a positive threshold is exceeded.
            threshold = activation.get("document_length_threshold", 0)
            if 0 < threshold < input_length:
                scores[profile_id] += length_weight

            scores[profile_id] += content_weight * self._signal_match_ratio(
                activation.get("content_type_signals", []), lowered
            )
            scores[profile_id] += content_weight * self._signal_match_ratio(
                activation.get("structure_indicators", []), lowered
            )

        # An explicit request in the context dominates the heuristics.
        if context and "requested_attention" in context:
            requested = context["requested_attention"]
            if requested in self.profiles:
                scores[requested] += params.get("explicit_request_weight", 1.0)

        # Highest-scoring profile; ties resolve to the first profile seen,
        # matching dict insertion order.
        best_profile_id = max(scores, key=scores.get)
        confidence = scores[best_profile_id]

        # Apply minimum confidence threshold
        min_confidence = params.get("minimum_confidence", 0.65)
        if confidence < min_confidence:
            return self.default_profile_id, confidence

        return best_profile_id, confidence

    def get_profile_parameters(self, profile_id: str) -> Dict[str, Any]:
        """
        Get the parameters for the specified attention profile.

        Args:
            profile_id: ID of the attention profile

        Returns:
            Dictionary of attention parameters; ``{}`` for unknown profiles.
        """
        profile = self.profiles.get(profile_id)
        return profile.get("parameters", {}) if profile else {}

    def get_attention_type(self, profile_id: str) -> str:
        """
        Get the attention mechanism type for the specified profile.

        Args:
            profile_id: ID of the attention profile

        Returns:
            String identifying the attention type; ``"standard"`` for
            unknown profiles or when the profile omits the field.
        """
        profile = self.profiles.get(profile_id)
        return profile.get("attention_type", "standard") if profile else "standard"


# Factory method to create appropriate attention mechanism
def create_attention_mechanism(profile_id: str, model_dim: int,
                               selector: AttentionProfileSelector):
    """
    Factory function to create an attention mechanism based on the
    selected profile.

    Args:
        profile_id: ID of the selected attention profile
        model_dim: Model hidden dimension
        selector: AttentionProfileSelector instance

    Returns:
        Configured attention mechanism, or a plain ``nn.MultiheadAttention``
        placeholder when ``smartHybridAttention`` is not importable.
    """
    attention_type = selector.get_attention_type(profile_id)
    parameters = selector.get_profile_parameters(profile_id)

    try:
        # Import here to avoid circular imports
        from smartHybridAttention import EnhancedSmartHybridAttention, create_smart_hybrid_attention

        # Map parameters from JSON to the attention mechanism
        attention_params = {
            "dim": model_dim,
            "num_heads": parameters.get("num_heads", 8),
            "window_size": parameters.get("window_size", 256),
            "use_sliding": parameters.get("use_sliding_window", True),
            "use_global": parameters.get("use_global_tokens", True),
            "global_token_ratio": parameters.get("global_token_ratio", 0.05),
            "memory_tokens": parameters.get("memory_token_count", 16),
        }

        # Create appropriate attention mechanism based on type
        if attention_type == "hierarchical":
            attention_params["use_hierarchical"] = True

        return create_smart_hybrid_attention(**attention_params)
    except ImportError:
        print("Warning: smartHybridAttention not found. Using placeholder.")
        # Return a placeholder if the module is not available
        import torch.nn as nn
        return nn.MultiheadAttention(model_dim, 8)


# Usage example:
if __name__ == "__main__":
    selector = AttentionProfileSelector()

    # Example inputs
    code_input = "def calculate_fibonacci(n):\n    if n <= 1:\n        return n\n    else:\n        return calculate_fibonacci(n-1) + calculate_fibonacci(n-2)"

    document_input = """# Chapter 1: Introduction This technical document covers the architecture of our system. ## Section 1.1: Overview The system consists of multiple components working together. """

    conversation_input = "How did you like the movie we saw yesterday? I thought the ending was unexpected."

    # Test profile selection
    code_profile, code_conf = selector.select_profile(code_input)
    doc_profile, doc_conf = selector.select_profile(document_input)
    conv_profile, conv_conf = selector.select_profile(conversation_input)

    print(f"Code input → {code_profile} (confidence: {code_conf:.2f})")
    print(f"Document input → {doc_profile} (confidence: {doc_conf:.2f})")
    print(f"Conversation input → {conv_profile} (confidence: {conv_conf:.2f})")