File size: 4,956 Bytes
4b1fd1d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
"""

Optimize smartHybridAttention parameters for 256-dimensional models

"""
import json
import logging
import os
from typing import Any, Dict, Optional

import torch

logger = logging.getLogger(__name__)

def optimize_attention_for_small_dimensions(
    dim: int = 256,
    model_dir: Optional[str] = None,
) -> Dict[str, Any]:
    """Create optimized sparse-attention parameters for small-dimensional models.

    Args:
        dim: Model dimension (default: 256).
        model_dir: If given, the directory is created as needed and the
            config is also written there as ``attention_config_256dim.json``.

    Returns:
        Dictionary with optimized attention parameters, including a
        per-layer override table under ``"LAYER_CONFIG"``.
    """
    window_size = 512  # larger window to capture more context
    config = {
        "DIM": dim,
        "NUM_HEADS": 8,  # 8 heads works well for 256-dim (32 dims per head)
        "WINDOW_SIZE": window_size,
        "USE_SLIDING": True,
        "USE_GLOBAL": True,
        "USE_HIERARCHICAL": True,  # hierarchical attention pays off at 256-dim
        "GLOBAL_TOKEN_RATIO": 0.12,  # 12% global tokens vs. the standard 5%
        "MEMORY_TOKENS": 48,  # more memory tokens (48 vs. standard 32)
        "STRIDE": window_size // 2,  # derived so it cannot drift from the window
        "MAX_SEQ_LENGTH": 2048,  # sparse attention allows longer sequences
        "LAYER_SPECIALIZATION": True,  # layers may use different attention types
        "ATTENTION_DROPOUT": 0.1,
        "RECENCY_BIAS": 0.3,  # prioritize recent context
    }

    # Layer-specific overrides: lower layers focus on local patterns,
    # middle layers use a hybrid approach, upper layers look more global.
    config["LAYER_CONFIG"] = {
        "0": {"WINDOW_SIZE": 128, "GLOBAL_TOKEN_RATIO": 0.05, "USE_HIERARCHICAL": False},
        "1": {"WINDOW_SIZE": 256, "GLOBAL_TOKEN_RATIO": 0.08, "USE_HIERARCHICAL": False},
        "2": {"WINDOW_SIZE": 384, "GLOBAL_TOKEN_RATIO": 0.10, "USE_HIERARCHICAL": True},
        "3": {"WINDOW_SIZE": 512, "GLOBAL_TOKEN_RATIO": 0.12, "USE_HIERARCHICAL": True},
        "4": {"WINDOW_SIZE": 768, "GLOBAL_TOKEN_RATIO": 0.15, "USE_HIERARCHICAL": True},
        "5": {"WINDOW_SIZE": 1024, "GLOBAL_TOKEN_RATIO": 0.18, "USE_HIERARCHICAL": True},
    }

    if model_dir:
        os.makedirs(model_dir, exist_ok=True)
        config_path = os.path.join(model_dir, "attention_config_256dim.json")
        with open(config_path, "w") as f:
            json.dump(config, f, indent=2)
        # Lazy %-style args so formatting is skipped when INFO is disabled.
        logger.info("Saved optimized attention config to %s", config_path)

    return config

def apply_optimized_attention_to_model(
    model,
    dim: int = 256,
    config: Optional[Dict[str, Any]] = None,
) -> bool:
    """Apply optimized attention parameters to an existing model.

    Args:
        model: The model to optimize; must expose ``named_modules()``.
        dim: Model dimension, used to generate a config when none is given.
        config: Attention configuration; when ``None`` one is generated via
            ``optimize_attention_for_small_dimensions(dim)``.

    Returns:
        True on success, False when no attention layers were found or an
        error occurred (this function is deliberately best-effort).
    """
    try:
        if config is None:
            config = optimize_attention_for_small_dimensions(dim)

        # Collect candidate attention modules by name or duck-typed attribute.
        attention_layers = [
            (name, module)
            for name, module in model.named_modules()
            if "attention" in name.lower() or hasattr(module, "smartHybridAttention")
        ]

        if not attention_layers:
            logger.warning("No attention layers found in model")
            return False

        logger.info("Found %d attention layers to optimize", len(attention_layers))

        # Externally supplied configs may omit the per-layer override table;
        # the original config["LAYER_CONFIG"] lookup raised KeyError for those.
        layer_table = config.get("LAYER_CONFIG", {})

        for i, (name, layer) in enumerate(attention_layers):
            layer_config = layer_table.get(str(i), {})

            # Layer-specific overrides take precedence over global values.
            for key, value in layer_config.items():
                if hasattr(layer, key.lower()):
                    setattr(layer, key.lower(), value)
                    logger.info("Set %s=%s for layer %s", key.lower(), value, name)

            # Fall back to global config where no layer override exists.
            for key, value in config.items():
                if key != "LAYER_CONFIG" and key not in layer_config and hasattr(layer, key.lower()):
                    setattr(layer, key.lower(), value)

        logger.info("Successfully applied optimized attention parameters")
        return True

    except Exception:
        # Best-effort: never crash the caller, but keep the full traceback
        # (logger.error(f"...{e}") discarded it).
        logger.exception("Error applying attention optimization")
        return False

if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    config = optimize_attention_for_small_dimensions()
    print("Generated optimized attention config for 256-dim model:")
    print(json.dumps(config, indent=2))