""" Optimize smartHybridAttention parameters for 256-dimensional models """ import os import json import logging import torch from typing import Dict, Any logger = logging.getLogger(__name__) def optimize_attention_for_small_dimensions( dim: int = 256, model_dir: str = None ) -> Dict[str, Any]: """ Creates optimized attention parameters for small-dimensional models Args: dim: Model dimension (default: 256) model_dir: Directory to save optimization settings Returns: Dictionary with optimized attention parameters """ # Base config with enhanced parameters for 256-dim models config = { "DIM": dim, "NUM_HEADS": 8, # 8 heads works well for 256-dim (32 dim per head) "WINDOW_SIZE": 512, # Larger window to capture more context "USE_SLIDING": True, "USE_GLOBAL": True, "USE_HIERARCHICAL": True, # Enable hierarchical attention for 256-dim "GLOBAL_TOKEN_RATIO": 0.12, # Increase global tokens (12% vs standard 5%) "MEMORY_TOKENS": 48, # More memory tokens (48 vs standard 32) "STRIDE": 256, # Stride = window_size / 2 "MAX_SEQ_LENGTH": 2048, # Support longer sequences with sparse attention "LAYER_SPECIALIZATION": True, # Each layer can have different attention types "ATTENTION_DROPOUT": 0.1, "RECENCY_BIAS": 0.3, # Add recency bias to prioritize recent context } # Special layer-specific optimizations for 256-dim models config["LAYER_CONFIG"] = { # Lower layers focus on local patterns "0": {"WINDOW_SIZE": 128, "GLOBAL_TOKEN_RATIO": 0.05, "USE_HIERARCHICAL": False}, "1": {"WINDOW_SIZE": 256, "GLOBAL_TOKEN_RATIO": 0.08, "USE_HIERARCHICAL": False}, # Middle layers use hybrid approach "2": {"WINDOW_SIZE": 384, "GLOBAL_TOKEN_RATIO": 0.10, "USE_HIERARCHICAL": True}, "3": {"WINDOW_SIZE": 512, "GLOBAL_TOKEN_RATIO": 0.12, "USE_HIERARCHICAL": True}, # Upper layers focus more on global connections "4": {"WINDOW_SIZE": 768, "GLOBAL_TOKEN_RATIO": 0.15, "USE_HIERARCHICAL": True}, "5": {"WINDOW_SIZE": 1024, "GLOBAL_TOKEN_RATIO": 0.18, "USE_HIERARCHICAL": True}, } if model_dir: os.makedirs(model_dir, exist_ok=True) config_path = os.path.join(model_dir, "attention_config_256dim.json") with open(config_path, "w") as f: json.dump(config, f, indent=2) logger.info(f"Saved optimized attention config to {config_path}") return config def apply_optimized_attention_to_model( model, dim: int = 256, config: Dict[str, Any] = None ) -> bool: """ Apply optimized attention parameters to existing model Args: model: The model to optimize dim: Model dimension (default: 256) config: Attention configuration (generated if None) Returns: Success status """ try: if config is None: config = optimize_attention_for_small_dimensions(dim) # Find attention modules in model attention_layers = [] for name, module in model.named_modules(): if "attention" in name.lower() or hasattr(module, 'smartHybridAttention'): attention_layers.append((name, module)) if not attention_layers: logger.warning("No attention layers found in model") return False logger.info(f"Found {len(attention_layers)} attention layers to optimize") # Apply configuration to each layer for i, (name, layer) in enumerate(attention_layers): layer_idx = str(i) layer_config = config["LAYER_CONFIG"].get(layer_idx, {}) # Apply layer-specific configs for key, value in layer_config.items(): if hasattr(layer, key.lower()): setattr(layer, key.lower(), value) logger.info(f"Set {key.lower()}={value} for layer {name}") # Apply global configs where specific isn't set for key, value in config.items(): if key != "LAYER_CONFIG" and hasattr(layer, key.lower()) and key not in layer_config: setattr(layer, key.lower(), value) logger.info("Successfully applied optimized attention parameters") return True except Exception as e: logger.error(f"Error applying attention optimization: {e}") return False if __name__ == "__main__": logging.basicConfig(level=logging.INFO) config = optimize_attention_for_small_dimensions() print("Generated optimized attention config for 256-dim model:") print(json.dumps(config, indent=2))