# Wildnerve-tlm01_Hybrid_Model / optimize_attention.py
"""
Optimize smartHybridAttention parameters for 256-dimensional models
"""
import json
import logging
import os
from typing import Any, Dict, Optional

import torch
logger = logging.getLogger(__name__)
def optimize_attention_for_small_dimensions(
    dim: int = 256,
    model_dir: Optional[str] = None
) -> Dict[str, Any]:
    """
    Create optimized attention parameters for small-dimensional models.

    Args:
        dim: Model dimension (default: 256).
        model_dir: Directory to save optimization settings; when None,
            nothing is written to disk.

    Returns:
        Dictionary with optimized attention parameters, including a
        per-layer override table under the "LAYER_CONFIG" key.
    """
    # Larger window to capture more context; stride is derived from it
    # below so the "stride = window_size / 2" invariant cannot drift.
    window_size = 512
    # Base config with enhanced parameters for 256-dim models
    config = {
        "DIM": dim,
        "NUM_HEADS": 8,  # 8 heads works well for 256-dim (32 dim per head)
        "WINDOW_SIZE": window_size,
        "USE_SLIDING": True,
        "USE_GLOBAL": True,
        "USE_HIERARCHICAL": True,  # Enable hierarchical attention for 256-dim
        "GLOBAL_TOKEN_RATIO": 0.12,  # Increase global tokens (12% vs standard 5%)
        "MEMORY_TOKENS": 48,  # More memory tokens (48 vs standard 32)
        "STRIDE": window_size // 2,  # Stride = window_size / 2
        "MAX_SEQ_LENGTH": 2048,  # Support longer sequences with sparse attention
        "LAYER_SPECIALIZATION": True,  # Each layer can have different attention types
        "ATTENTION_DROPOUT": 0.1,
        "RECENCY_BIAS": 0.3,  # Add recency bias to prioritize recent context
    }
    # Special layer-specific optimizations for 256-dim models
    config["LAYER_CONFIG"] = {
        # Lower layers focus on local patterns
        "0": {"WINDOW_SIZE": 128, "GLOBAL_TOKEN_RATIO": 0.05, "USE_HIERARCHICAL": False},
        "1": {"WINDOW_SIZE": 256, "GLOBAL_TOKEN_RATIO": 0.08, "USE_HIERARCHICAL": False},
        # Middle layers use hybrid approach
        "2": {"WINDOW_SIZE": 384, "GLOBAL_TOKEN_RATIO": 0.10, "USE_HIERARCHICAL": True},
        "3": {"WINDOW_SIZE": 512, "GLOBAL_TOKEN_RATIO": 0.12, "USE_HIERARCHICAL": True},
        # Upper layers focus more on global connections
        "4": {"WINDOW_SIZE": 768, "GLOBAL_TOKEN_RATIO": 0.15, "USE_HIERARCHICAL": True},
        "5": {"WINDOW_SIZE": 1024, "GLOBAL_TOKEN_RATIO": 0.18, "USE_HIERARCHICAL": True},
    }
    if model_dir:
        # Persist the config next to the model so it can be reloaded later.
        os.makedirs(model_dir, exist_ok=True)
        config_path = os.path.join(model_dir, "attention_config_256dim.json")
        with open(config_path, "w") as f:
            json.dump(config, f, indent=2)
        logger.info(f"Saved optimized attention config to {config_path}")
    return config
def apply_optimized_attention_to_model(
    model,
    dim: int = 256,
    config: Optional[Dict[str, Any]] = None
) -> bool:
    """
    Apply optimized attention parameters to an existing model in place.

    Scans ``model.named_modules()`` for attention layers (matched by
    "attention" in the module name, or a ``smartHybridAttention``
    attribute) and sets matching lowercase attributes from the config.
    Layer-specific entries in ``config["LAYER_CONFIG"]`` take precedence
    over the global settings.

    Args:
        model: The model to optimize (expected to expose ``named_modules``,
            as torch ``nn.Module`` does — confirm for other model types).
        dim: Model dimension used when generating a default config.
        config: Attention configuration (generated if None).

    Returns:
        True on success; False when no attention layers are found or an
        error occurs (errors are logged, never raised).
    """
    try:
        if config is None:
            config = optimize_attention_for_small_dimensions(dim)
        # Find attention modules by name or by marker attribute.
        attention_layers = [
            (name, module)
            for name, module in model.named_modules()
            if "attention" in name.lower() or hasattr(module, 'smartHybridAttention')
        ]
        if not attention_layers:
            logger.warning("No attention layers found in model")
            return False
        logger.info(f"Found {len(attention_layers)} attention layers to optimize")
        # Apply configuration to each layer, in discovery order.
        for i, (name, layer) in enumerate(attention_layers):
            layer_config = config["LAYER_CONFIG"].get(str(i), {})
            # Layer-specific overrides take precedence over globals.
            for key, value in layer_config.items():
                if hasattr(layer, key.lower()):
                    setattr(layer, key.lower(), value)
                    logger.info(f"Set {key.lower()}={value} for layer {name}")
            # Fall back to global settings where no layer-specific value is set.
            for key, value in config.items():
                if key != "LAYER_CONFIG" and hasattr(layer, key.lower()) and key not in layer_config:
                    setattr(layer, key.lower(), value)
        logger.info("Successfully applied optimized attention parameters")
        return True
    except Exception as e:
        # Best-effort by design: report the failure but never raise.
        logger.error(f"Error applying attention optimization: {e}")
        return False
if __name__ == "__main__":
    # Demo entry point: generate and display the default 256-dim config.
    logging.basicConfig(level=logging.INFO)
    demo_config = optimize_attention_for_small_dimensions()
    print("Generated optimized attention config for 256-dim model:")
    print(json.dumps(demo_config, indent=2))