| import os
|
| import json
|
| import logging
|
| from typing import Dict, Any, Optional, List, Tuple, Union
|
| import torch
|
|
|
|
|
# Optional dependency: AttentionProfileSelector drives content-aware
# attention. If the import fails, the flag below makes every consumer in
# this module fall back to default attention instead of crashing.
try:
    from attention_trigger_system import AttentionProfileSelector
    ATTENTION_SELECTOR_AVAILABLE = True
except ImportError:
    ATTENTION_SELECTOR_AVAILABLE = False
    logging.warning("AttentionProfileSelector not available - content-aware attention disabled")
|
|
|
class AttentionConnector:
    """Integration layer between user input and content-aware attention.

    Bridges three pieces:
      1. The original text from user input.
      2. The attention configuration system (profile selection).
      3. The SmartHybridAttention implementation that consumes the
         selected profile's parameters.

    All public methods degrade gracefully when the profile selector is
    unavailable (``self.profile_selector is None``).
    """

    def __init__(self, config_path: Optional[str] = None):
        """Initialize the connector with configuration.

        Args:
            config_path: Path to the attention configuration JSON file.
                Defaults to ``attention_configuration.json`` located next
                to this module.
        """
        self.logger = logging.getLogger(__name__)

        if config_path is None:
            self.config_path = os.path.join(os.path.dirname(__file__), "attention_configuration.json")
        else:
            self.config_path = config_path

        # None when the selector package is missing or failed to initialize.
        self.profile_selector = self._init_profile_selector()

        # State describing the most recent input seen by set_input_text().
        self.current_input_text: Optional[str] = None
        self.current_context: Dict[str, Any] = {}
        self.active_profile_id: str = "standard"
        self.profile_confidence: float = 1.0

    def _init_profile_selector(self) -> Optional[Any]:
        """Create the attention profile selector, or None if unavailable.

        Returns:
            An ``AttentionProfileSelector`` instance, or None when the
            package is absent or construction raised.
        """
        if not ATTENTION_SELECTOR_AVAILABLE:
            self.logger.warning("AttentionProfileSelector not available - using default attention")
            return None

        try:
            selector = AttentionProfileSelector(self.config_path)
            # Lazy %-style args: formatting is skipped when INFO is disabled.
            self.logger.info(
                "Initialized AttentionProfileSelector with %d profiles",
                len(selector.profiles),
            )
            return selector
        except Exception as e:
            # Construction reads the JSON config from disk; fall back to
            # default attention rather than crash the caller.
            self.logger.error("Error initializing AttentionProfileSelector: %s", e)
            return None

    def set_input_text(self, text: str, context: Optional[Dict[str, Any]] = None):
        """Record the current input text and (re)select an attention profile.

        Args:
            text: Raw user input text.
            context: Optional metadata forwarded to profile selection;
                defaults to an empty dict.
        """
        self.current_input_text = text
        self.current_context = context or {}

        if self.profile_selector:
            self.active_profile_id, self.profile_confidence = self.profile_selector.select_profile(
                text, self.current_context
            )
            self.logger.info(
                "Selected attention profile: %s (confidence: %.2f)",
                self.active_profile_id,
                self.profile_confidence,
            )

    def get_attention_parameters(self) -> Dict[str, Any]:
        """Return the parameter dict for the active profile.

        Returns:
            The selector's parameters for ``self.active_profile_id``, or
            an empty dict when no selector is available.
        """
        if not self.profile_selector:
            return {}

        return self.profile_selector.get_profile_parameters(self.active_profile_id)

    def inject_attention_parameters(self, attention_module: Any) -> Any:
        """Inject content-aware parameters into an attention module.

        The module must expose ``set_parameters(**kwargs)``; modules that
        do not are returned unchanged with a warning.

        Args:
            attention_module: The attention implementation to configure.

        Returns:
            The (possibly configured) attention module, for chaining.
        """
        if not hasattr(attention_module, 'set_parameters'):
            # Plain string: the original used an f-string with no placeholders.
            self.logger.warning("Attention module does not support parameter injection")
            return attention_module

        params = self.get_attention_parameters()
        attention_module.set_parameters(**params)
        return attention_module

    def get_input_context(self) -> Dict[str, Any]:
        """Return the current input state for the attention mechanism.

        Returns:
            Dict with keys ``input_text``, ``context``, ``profile_id`` and
            ``confidence`` reflecting the most recent ``set_input_text`` call
            (or the constructor defaults if it was never called).
        """
        return {
            "input_text": self.current_input_text,
            "context": self.current_context,
            "profile_id": self.active_profile_id,
            "confidence": self.profile_confidence
        }
|
|
|
|
|
# Process-wide singleton slot, managed exclusively by get_attention_connector().
_connector_instance: Optional[AttentionConnector] = None
|
|
|
def get_attention_connector() -> AttentionConnector:
    """Return the process-wide AttentionConnector singleton.

    The instance is created lazily on the first call and reused by every
    subsequent caller.
    """
    global _connector_instance
    if _connector_instance is not None:
        return _connector_instance
    _connector_instance = AttentionConnector()
    return _connector_instance
|
|
|
|
|
|
|
def inject_input_text(input_text: str, model_forward_kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Register *input_text* with the global connector and thread it through.

    Intended to be called in the communicator before forwarding to the model.
    Mutates ``model_forward_kwargs`` in place, attaching the raw text under
    the ``"original_text"`` key, and returns it for convenience.
    """
    get_attention_connector().set_input_text(input_text)
    model_forward_kwargs["original_text"] = input_text
    return model_forward_kwargs
|
|
|
def prepare_attention_context(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Dict[str, Any]:
    """Build the context dict handed to the attention mechanism.

    Call this from within the attention module's forward method. Combines
    the global connector's input context with the shapes of the Q/K/V
    tensors.
    """
    context = get_attention_connector().get_input_context()

    shape_info = {
        "query_shape": query.shape,
        "key_shape": key.shape,
        "value_shape": value.shape,
    }
    context.update(shape_info)

    return context
|
|
|