# Wildnerve-tlm01_Hybrid_Model/utils/attention_connector.py
# (uploaded by WildnerveAI via the Hugging Face web UI, revision 0861a59)
import os
import json
import logging
from typing import Dict, Any, Optional, List, Tuple, Union
import torch
# Conditional import for the attention profile selector: when the
# attention_trigger_system package is absent, the module degrades gracefully
# and content-aware attention is disabled (see _init_profile_selector).
try:
    from attention_trigger_system import AttentionProfileSelector

    ATTENTION_SELECTOR_AVAILABLE = True
except ImportError:
    ATTENTION_SELECTOR_AVAILABLE = False
    logging.warning(
        "AttentionProfileSelector not available - content-aware attention disabled"
    )
class AttentionConnector:
    """
    Connects the core architecture with content-aware attention mechanisms.

    This class serves as the integration layer between:
    1. Original text from user input
    2. The attention configuration system
    3. The SmartHybridAttention implementation
    """

    def __init__(self, config_path: Optional[str] = None):
        """Initialize the connector with configuration.

        Args:
            config_path: Path to the attention configuration JSON file.
                Defaults to ``attention_configuration.json`` next to this
                module when omitted.
        """
        self.logger = logging.getLogger(__name__)
        # Set up default config path if none provided
        if config_path is None:
            self.config_path = os.path.join(
                os.path.dirname(__file__), "attention_configuration.json"
            )
        else:
            self.config_path = config_path
        # None when attention_trigger_system is missing or its setup failed;
        # every method below degrades gracefully in that case.
        self.profile_selector = self._init_profile_selector()
        # NOTE(review): the original comment called this "thread-local
        # storage", but these are plain instance attributes shared across
        # threads — confirm callers use one connector per thread if
        # concurrent requests are possible.
        self.current_input_text: Optional[str] = None
        self.current_context: Dict[str, Any] = {}
        self.active_profile_id = "standard"
        self.profile_confidence = 1.0

    def _init_profile_selector(self) -> Optional[Any]:
        """Create the AttentionProfileSelector, or None when unavailable."""
        if not ATTENTION_SELECTOR_AVAILABLE:
            self.logger.warning(
                "AttentionProfileSelector not available - using default attention"
            )
            return None
        try:
            selector = AttentionProfileSelector(self.config_path)
            self.logger.info(
                "Initialized AttentionProfileSelector with %d profiles",
                len(selector.profiles),
            )
            return selector
        except Exception as e:
            # Best-effort: a broken/missing config must not take the model
            # down, so fall back to default attention instead of raising.
            self.logger.error("Error initializing AttentionProfileSelector: %s", e)
            return None

    def set_input_text(self, text: str, context: Optional[Dict[str, Any]] = None):
        """Record the current input text and (re)select the attention profile.

        Args:
            text: Raw user input text.
            context: Optional auxiliary context passed to profile selection.
        """
        self.current_input_text = text
        self.current_context = context or {}
        # If we have a profile selector, determine the appropriate profile
        if self.profile_selector:
            self.active_profile_id, self.profile_confidence = (
                self.profile_selector.select_profile(text, self.current_context)
            )
            self.logger.info(
                "Selected attention profile: %s (confidence: %.2f)",
                self.active_profile_id,
                self.profile_confidence,
            )

    def get_attention_parameters(self) -> Dict[str, Any]:
        """Return parameters for the active profile ({} when no selector)."""
        if not self.profile_selector:
            return {}
        return self.profile_selector.get_profile_parameters(self.active_profile_id)

    def inject_attention_parameters(self, attention_module: Any) -> Any:
        """Inject content-aware parameters into an attention module.

        The module must expose a ``set_parameters(**kwargs)`` method;
        modules without it are returned unchanged.
        """
        if not hasattr(attention_module, 'set_parameters'):
            # Original used an f-string with no placeholders; plain string.
            self.logger.warning("Attention module does not support parameter injection")
            return attention_module
        params = self.get_attention_parameters()
        attention_module.set_parameters(**params)
        return attention_module

    def get_input_context(self) -> Dict[str, Any]:
        """Return the current input context for the attention mechanism."""
        return {
            "input_text": self.current_input_text,
            "context": self.current_context,
            "profile_id": self.active_profile_id,
            "confidence": self.profile_confidence,
        }
# Global singleton instance, created lazily on first access.
_connector_instance: Optional[AttentionConnector] = None


def get_attention_connector() -> AttentionConnector:
    """Return the process-wide AttentionConnector, creating it on first use."""
    global _connector_instance
    if _connector_instance is None:
        _connector_instance = AttentionConnector()
    return _connector_instance
# Hook functions to integrate with existing architecture
def inject_input_text(input_text: str, model_forward_kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """
    Inject input text into model forward kwargs.

    This function should be called in the communicator before forwarding
    to the model. It also notifies the global attention connector so that
    profile selection can run on the new text.
    """
    get_attention_connector().set_input_text(input_text)
    # Expose the raw text to models that accept it directly.
    model_forward_kwargs["original_text"] = input_text
    return model_forward_kwargs
def prepare_attention_context(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Dict[str, Any]:
    """
    Prepare attention context from connector for attention mechanism.

    This should be called within the attention module's forward method.
    """
    context = get_attention_connector().get_input_context()
    # Record the tensor shapes for the attention mechanism's reference.
    context["query_shape"] = query.shape
    context["key_shape"] = key.shape
    context["value_shape"] = value.shape
    return context