# Wildnerve-tlm01_Hybrid_Model/utils/attention_trigger_system.py
# Source: WildnerveAI upload ("Upload 20 files", commit 0861a59, verified)
import os
import json
from typing import Dict, Any, Optional, List, Tuple
import re
from pathlib import Path
class AttentionProfileSelector:
    """
    Selects appropriate attention profiles based on input characteristics
    and configuration specified in the JSON dataset.

    The configuration file is expected to provide:
      - "attention_profiles": list of profile dicts, each with a "profile_id",
        optional "activation_signals", "parameters", and "attention_type".
      - "default_profile": fallback profile id (defaults to "standard").
      - "profile_selection_strategy": dict with "strategy_parameters" weights.
    """

    def __init__(self, config_path: Optional[str] = None):
        """
        Initialize the selector with the provided configuration.

        Args:
            config_path: Path to the attention configuration JSON. When None,
                falls back to "attention_configuration.json" next to this file.
        """
        if config_path is None:
            # Default to the standard location alongside this module.
            config_path = os.path.join(os.path.dirname(__file__), "attention_configuration.json")
        self.config = self._load_config(config_path)
        # Index profiles by id for O(1) lookup; assumes each profile dict
        # carries a "profile_id" key — entries without one would raise here.
        self.profiles = {p["profile_id"]: p for p in self.config.get("attention_profiles", [])}
        self.default_profile_id = self.config.get("default_profile", "standard")
        self.selection_strategy = self.config.get("profile_selection_strategy", {})

    def _load_config(self, config_path: str) -> Dict[str, Any]:
        """
        Load configuration from a JSON file.

        Returns an empty dict (and logs to stdout) on any read or parse
        failure, so the selector degrades to default-profile behavior.
        """
        try:
            with open(config_path, 'r', encoding='utf-8') as f:
                return json.load(f)
        except (OSError, json.JSONDecodeError) as e:
            # OSError (a superclass of FileNotFoundError) also covers
            # permission/path errors, which previously propagated.
            print(f"Error loading attention configuration: {e}")
            return {}

    def select_profile(self,
                       input_text: str,
                       context: Optional[Dict[str, Any]] = None) -> Tuple[str, float]:
        """
        Select the most appropriate attention profile based on input characteristics.

        Scoring is additive per profile: a length-threshold bonus, a fraction
        of matched content-type signals, a fraction of matched structure
        indicators, and a large bonus for an explicit request in ``context``.
        Note the resulting "confidence" is an unnormalized score and may
        exceed 1.0.

        Args:
            input_text: The user's input text.
            context: Additional context about the interaction; an optional
                "requested_attention" key forces a strong preference.

        Returns:
            Tuple of (profile_id, confidence). Falls back to the default
            profile when no profiles are configured or the best score is
            below the configured minimum confidence.
        """
        if not self.profiles:
            return self.default_profile_id, 1.0

        # Hoist invariant lookups out of the scoring loop.
        params = self.selection_strategy.get("strategy_parameters", {})
        length_weight = params.get("input_length_weight", 0.3)
        content_weight = params.get("content_type_weight", 0.5)
        lowered_text = input_text.lower()
        input_length = len(input_text)

        scores: Dict[str, float] = {}
        for profile_id, profile in self.profiles.items():
            signals = profile.get("activation_signals", {})
            score = 0.0

            # Length bonus only when a positive threshold is configured.
            threshold = signals.get("document_length_threshold", 0)
            if 0 < threshold < input_length:
                score += length_weight

            # Both signal lists contribute the fraction of matched markers.
            # NOTE(review): structure indicators reuse content_type_weight —
            # the config schema defines no separate structure weight; confirm
            # this is intentional.
            for signal_key in ("content_type_signals", "structure_indicators"):
                markers = signals.get(signal_key, [])
                if markers:
                    matched = sum(1 for m in markers if m.lower() in lowered_text)
                    score += (matched / len(markers)) * content_weight

            scores[profile_id] = score

        # Honor an explicit request in context with a heavy weight.
        if context and "requested_attention" in context:
            requested = context["requested_attention"]
            if requested in scores:
                scores[requested] += params.get("explicit_request_weight", 1.0)

        best_profile_id = max(scores.items(), key=lambda item: item[1])[0]
        confidence = scores[best_profile_id]

        # Apply minimum confidence threshold; below it, fall back to default
        # but still report the achieved score.
        if confidence < params.get("minimum_confidence", 0.65):
            return self.default_profile_id, confidence
        return best_profile_id, confidence

    def get_profile_parameters(self, profile_id: str) -> Dict[str, Any]:
        """
        Get the parameters for the specified attention profile.

        Args:
            profile_id: ID of the attention profile.

        Returns:
            Dictionary of attention parameters; empty for unknown profiles.
        """
        if profile_id in self.profiles:
            return self.profiles[profile_id].get("parameters", {})
        return {}

    def get_attention_type(self, profile_id: str) -> str:
        """
        Get the attention mechanism type for the specified profile.

        Args:
            profile_id: ID of the attention profile.

        Returns:
            String identifying the attention type ("standard" when unknown).
        """
        if profile_id in self.profiles:
            return self.profiles[profile_id].get("attention_type", "standard")
        return "standard"
# Factory method to create appropriate attention mechanism
def create_attention_mechanism(profile_id: str, model_dim: int, selector: "AttentionProfileSelector"):
    """
    Factory function to create an attention mechanism based on the selected profile.

    Args:
        profile_id: ID of the selected attention profile.
        model_dim: Model hidden dimension.
        selector: AttentionProfileSelector instance supplying the profile's
            attention type and parameters.

    Returns:
        Configured attention mechanism from smartHybridAttention, or a
        torch.nn.MultiheadAttention placeholder when that module is absent.
    """
    attention_type = selector.get_attention_type(profile_id)
    parameters = selector.get_profile_parameters(profile_id)
    num_heads = parameters.get("num_heads", 8)

    try:
        # Import here to avoid circular imports. Only the import lives in the
        # try-block so an ImportError raised *inside* the factory call is not
        # silently misreported as "module not found".
        from smartHybridAttention import create_smart_hybrid_attention
    except ImportError:
        print("Warning: smartHybridAttention not found. Using placeholder.")
        import torch.nn as nn
        # Fix: honor the profile's configured head count instead of
        # hardcoding 8 in the fallback path.
        return nn.MultiheadAttention(model_dim, num_heads)

    # Map parameters from the JSON profile to the attention mechanism.
    attention_params = {
        "dim": model_dim,
        "num_heads": num_heads,
        "window_size": parameters.get("window_size", 256),
        "use_sliding": parameters.get("use_sliding_window", True),
        "use_global": parameters.get("use_global_tokens", True),
        "global_token_ratio": parameters.get("global_token_ratio", 0.05),
        "memory_tokens": parameters.get("memory_token_count", 16),
    }
    # Hierarchical profiles opt into the hierarchical variant; other types
    # use the factory defaults.
    if attention_type == "hierarchical":
        attention_params["use_hierarchical"] = True
    return create_smart_hybrid_attention(**attention_params)
# Usage example: run this module directly to see which profile each kind of
# input selects.
if __name__ == "__main__":
    profile_selector = AttentionProfileSelector()

    # Three representative inputs: source code, a structured document, and
    # casual conversation.
    sample_code = "def calculate_fibonacci(n):\n    if n <= 1:\n        return n\n    else:\n        return calculate_fibonacci(n-1) + calculate_fibonacci(n-2)"
    sample_document = """# Chapter 1: Introduction
This technical document covers the architecture of our system.
## Section 1.1: Overview
The system consists of multiple components working together.
"""
    sample_conversation = "How did you like the movie we saw yesterday? I thought the ending was unexpected."

    # Run selection on each sample and report profile + confidence.
    code_profile, code_conf = profile_selector.select_profile(sample_code)
    doc_profile, doc_conf = profile_selector.select_profile(sample_document)
    conv_profile, conv_conf = profile_selector.select_profile(sample_conversation)

    print(f"Code input → {code_profile} (confidence: {code_conf:.2f})")
    print(f"Document input → {doc_profile} (confidence: {doc_conf:.2f})")
    print(f"Conversation input → {conv_profile} (confidence: {conv_conf:.2f})")