|
|
import os
|
|
|
import json
|
|
|
from typing import Dict, Any, Optional, List, Tuple
|
|
|
import re
|
|
|
from pathlib import Path
|
|
|
|
|
|
class AttentionProfileSelector:
    """
    Selects appropriate attention profiles based on input characteristics
    and configuration specified in the JSON dataset.
    """

    def __init__(self, config_path: Optional[str] = None):
        """
        Initialize the selector with the provided configuration.

        Args:
            config_path: Path to the attention configuration JSON. When None,
                defaults to "attention_configuration.json" next to this module.
        """
        if config_path is None:
            config_path = os.path.join(os.path.dirname(__file__), "attention_configuration.json")

        self.config = self._load_config(config_path)
        # Index profiles by id for O(1) lookup; missing key falls back to {}.
        self.profiles = {p["profile_id"]: p for p in self.config.get("attention_profiles", [])}
        self.default_profile_id = self.config.get("default_profile", "standard")
        self.selection_strategy = self.config.get("profile_selection_strategy", {})

    def _load_config(self, config_path: str) -> Dict[str, Any]:
        """Load configuration from a JSON file; return {} on any load error."""
        try:
            # Explicit encoding avoids platform-dependent default text encodings.
            with open(config_path, 'r', encoding='utf-8') as f:
                return json.load(f)
        except (FileNotFoundError, json.JSONDecodeError) as e:
            print(f"Error loading attention configuration: {e}")
            return {}

    def select_profile(self,
                       input_text: str,
                       context: Optional[Dict[str, Any]] = None) -> Tuple[str, float]:
        """
        Select the most appropriate attention profile based on input characteristics.

        Each profile is scored on three activation signals (document length,
        content-type keywords, structure indicators) plus an optional explicit
        request from the caller. When the best score falls below the configured
        minimum confidence, the default profile is returned instead (with the
        low confidence still reported).

        Args:
            input_text: The user's input text
            context: Additional context about the interaction; may carry a
                "requested_attention" profile id for an explicit request.

        Returns:
            Tuple of (profile_id, confidence)
        """
        if not self.profiles:
            return self.default_profile_id, 1.0

        # Hoist loop-invariant lookups: the strategy-parameter dict, the
        # signal weights, and the lowercased input (the original re-lowered
        # the full text for every keyword test).
        params = self.selection_strategy.get("strategy_parameters", {})
        length_weight = params.get("input_length_weight", 0.3)
        content_weight = params.get("content_type_weight", 0.5)
        lowered = input_text.lower()
        input_length = len(input_text)

        scores = {profile_id: 0.0 for profile_id in self.profiles}

        # Single pass over profiles (the original looped three times).
        for profile_id, profile in self.profiles.items():
            signals = profile.get("activation_signals", {})

            # Length signal: only meaningful for profiles that define a
            # positive threshold.
            length_threshold = signals.get("document_length_threshold", 0)
            if length_threshold > 0 and input_length > length_threshold:
                scores[profile_id] += length_weight

            # Content-type keyword signal: fraction of keywords present.
            content_signals = signals.get("content_type_signals", [])
            if content_signals:
                matched = sum(1 for s in content_signals if s.lower() in lowered)
                scores[profile_id] += (matched / len(content_signals)) * content_weight

            # Structure-indicator signal. NOTE(review): the original reuses
            # "content_type_weight" here rather than a dedicated structure
            # weight — behavior preserved; confirm against the config schema.
            structure_signals = signals.get("structure_indicators", [])
            if structure_signals:
                matched = sum(1 for s in structure_signals if s.lower() in lowered)
                scores[profile_id] += (matched / len(structure_signals)) * content_weight

        # An explicit caller request carries the largest weight.
        if context and "requested_attention" in context:
            requested = context["requested_attention"]
            if requested in self.profiles:
                scores[requested] += params.get("explicit_request_weight", 1.0)

        # scores is non-empty whenever self.profiles is (guarded above), so
        # the original's redundant emptiness re-check is dropped.
        best_profile_id = max(scores, key=scores.get)
        confidence = scores[best_profile_id]

        # Fall back to the default when no candidate is confident enough.
        min_confidence = params.get("minimum_confidence", 0.65)
        if confidence < min_confidence:
            return self.default_profile_id, confidence

        return best_profile_id, confidence

    def get_profile_parameters(self, profile_id: str) -> Dict[str, Any]:
        """
        Get the parameters for the specified attention profile.

        Args:
            profile_id: ID of the attention profile

        Returns:
            Dictionary of attention parameters ({} for an unknown profile).
        """
        if profile_id in self.profiles:
            return self.profiles[profile_id].get("parameters", {})
        return {}

    def get_attention_type(self, profile_id: str) -> str:
        """
        Get the attention mechanism type for the specified profile.

        Args:
            profile_id: ID of the attention profile

        Returns:
            String identifying the attention type ("standard" for an
            unknown profile or one with no explicit type).
        """
        if profile_id in self.profiles:
            return self.profiles[profile_id].get("attention_type", "standard")
        return "standard"
|
|
|
|
|
|
|
|
|
|
|
|
def create_attention_mechanism(profile_id: str, model_dim: int, selector: "AttentionProfileSelector"):
    """
    Factory function to create an attention mechanism based on the selected profile.

    Args:
        profile_id: ID of the selected attention profile
        model_dim: Model hidden dimension
        selector: AttentionProfileSelector instance

    Returns:
        Configured attention mechanism; a plain torch.nn.MultiheadAttention
        placeholder when smartHybridAttention is unavailable.
    """
    attention_type = selector.get_attention_type(profile_id)
    parameters = selector.get_profile_parameters(profile_id)

    try:
        from smartHybridAttention import EnhancedSmartHybridAttention, create_smart_hybrid_attention

        # Map profile parameters onto the hybrid-attention constructor,
        # supplying defaults for anything the profile omits.
        attention_params = {
            "dim": model_dim,
            "num_heads": parameters.get("num_heads", 8),
            "window_size": parameters.get("window_size", 256),
            "use_sliding": parameters.get("use_sliding_window", True),
            "use_global": parameters.get("use_global_tokens", True),
            "global_token_ratio": parameters.get("global_token_ratio", 0.05),
            "memory_tokens": parameters.get("memory_token_count", 16)
        }

        if attention_type == "hierarchical":
            attention_params["use_hierarchical"] = True

        return create_smart_hybrid_attention(**attention_params)

    except ImportError:
        print("Warning: smartHybridAttention not found. Using placeholder.")

        import torch.nn as nn
        # Honour the profile's configured head count in the fallback too
        # (the original hard-coded 8 heads regardless of the profile).
        return nn.MultiheadAttention(model_dim, parameters.get("num_heads", 8))
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Demo: classify three representative inputs and report which attention
    # profile the selector picks for each, with its confidence score.
    selector = AttentionProfileSelector()

    code_input = "def calculate_fibonacci(n):\n    if n <= 1:\n        return n\n    else:\n        return calculate_fibonacci(n-1) + calculate_fibonacci(n-2)"

    document_input = """# Chapter 1: Introduction

This technical document covers the architecture of our system.

## Section 1.1: Overview

The system consists of multiple components working together.

"""

    conversation_input = "How did you like the movie we saw yesterday? I thought the ending was unexpected."

    # One selector call per sample, printed in a fixed order.
    samples = (
        ("Code", code_input),
        ("Document", document_input),
        ("Conversation", conversation_input),
    )
    for label, sample in samples:
        profile, confidence = selector.select_profile(sample)
        print(f"{label} input → {profile} (confidence: {confidence:.2f})")
|
|
|
|