File size: 10,046 Bytes
c6cb681
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f33c95a
c6cb681
 
 
 
 
 
 
 
 
 
 
f33c95a
c6cb681
 
 
 
f478beb
 
 
 
 
 
 
 
 
 
 
 
c6cb681
 
 
 
 
 
 
 
f33c95a
c6cb681
 
 
 
 
 
 
 
 
 
 
f33c95a
c6cb681
 
 
 
 
 
 
 
 
 
 
f33c95a
c6cb681
 
 
 
 
 
 
 
 
 
 
f33c95a
c6cb681
 
 
 
 
 
 
 
 
 
 
f33c95a
c6cb681
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f478beb
 
 
 
 
 
c6cb681
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a4e6de8
c6cb681
 
 
 
 
 
f33c95a
b98d237
c6cb681
 
 
 
 
 
 
f33c95a
b98d237
c6cb681
 
 
 
 
f33c95a
 
c6cb681
 
 
 
 
 
 
f33c95a
 
c6cb681
f33c95a
 
c6cb681
f33c95a
b98d237
f33c95a
 
c6cb681
f33c95a
b98d237
f33c95a
c6cb681
 
 
f33c95a
b98d237
c6cb681
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a4e6de8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
"""
Model family configuration registry.

Maps specific model names to families, and families to canonical module/parameter patterns.
This allows automatic selection of appropriate modules and parameters based on model architecture.
"""

from typing import Dict, List, Optional, Any

# Model family specifications.
# Each family entry provides:
#   - "description":    human-readable architecture summary
#   - "templates":      dotted module-path templates with {N} as the layer-index wildcard
#   - "norm_parameter": dotted path of the final normalization weight
#   - "norm_type":      "rmsnorm" or "layernorm"
# NOTE(review): only the "gpt2" family defines "logit_lens_pattern"; other
# families omit it — confirm whether that key is intentionally gpt2-only.
MODEL_FAMILIES: Dict[str, Dict[str, Any]] = {
    # LLaMA-like models (LLaMA, Mistral, Qwen2)
    "llama_like": {
        "description": "LLaMA, Mistral, Qwen2 architecture",
        "templates": {
            "attention_pattern": "model.layers.{N}.self_attn",
            "mlp_pattern": "model.layers.{N}.mlp",
            "block_pattern": "model.layers.{N}",
        },
        "norm_parameter": "model.norm.weight",
        "norm_type": "rmsnorm",
    },
    
    # GPT-2 family
    "gpt2": {
        "description": "GPT-2 architecture",
        "templates": {
            "attention_pattern": "transformer.h.{N}.attn",
            "mlp_pattern": "transformer.h.{N}.mlp",
            "block_pattern": "transformer.h.{N}",
        },
        "norm_parameter": "transformer.ln_f.weight",
        "logit_lens_pattern": "lm_head.weight",
        "norm_type": "layernorm",
    },
    
    # GPT-Neo (EleutherAI) — similar to GPT-2 but with local attention
    "gpt_neo": {
        "description": "GPT-Neo architecture (EleutherAI)",
        "templates": {
            "attention_pattern": "transformer.h.{N}.attn.attention",
            "mlp_pattern": "transformer.h.{N}.mlp",
            "block_pattern": "transformer.h.{N}",
        },
        "norm_parameter": "transformer.ln_f.weight",
        "norm_type": "layernorm",
    },
    
    # OPT
    "opt": {
        "description": "OPT architecture",
        "templates": {
            "attention_pattern": "model.decoder.layers.{N}.self_attn",
            # NOTE(review): "fc2" is a single linear layer, not the full MLP
            # sub-module as in the other families — confirm this is intended.
            "mlp_pattern": "model.decoder.layers.{N}.fc2",
            "block_pattern": "model.decoder.layers.{N}",
        },
        "norm_parameter": "model.decoder.final_layer_norm.weight",
        "norm_type": "layernorm",
    },
    
    # GPT-NeoX
    "gpt_neox": {
        "description": "GPT-NeoX architecture",
        "templates": {
            "attention_pattern": "gpt_neox.layers.{N}.attention",
            "mlp_pattern": "gpt_neox.layers.{N}.mlp",
            "block_pattern": "gpt_neox.layers.{N}",
        },
        "norm_parameter": "gpt_neox.final_layer_norm.weight",
        "norm_type": "layernorm",
    },
    
    # BLOOM
    "bloom": {
        "description": "BLOOM architecture",
        "templates": {
            "attention_pattern": "transformer.h.{N}.self_attention",
            "mlp_pattern": "transformer.h.{N}.mlp",
            "block_pattern": "transformer.h.{N}",
        },
        "norm_parameter": "transformer.ln_f.weight",
        "norm_type": "layernorm",
    },
    
    # Falcon
    "falcon": {
        "description": "Falcon architecture",
        "templates": {
            "attention_pattern": "transformer.h.{N}.self_attention",
            "mlp_pattern": "transformer.h.{N}.mlp",
            "block_pattern": "transformer.h.{N}",
        },
        "norm_parameter": "transformer.ln_f.weight",
        "norm_type": "layernorm",
    },
    
    # MPT
    "mpt": {
        "description": "MPT architecture",
        "templates": {
            "attention_pattern": "transformer.blocks.{N}.attn",
            "mlp_pattern": "transformer.blocks.{N}.ffn",
            "block_pattern": "transformer.blocks.{N}",
        },
        "norm_parameter": "transformer.norm_f.weight",
        "norm_type": "layernorm",
    },
}

# Hard-coded mapping of specific model names to families.
# Keys are exact HuggingFace repo IDs (lookup is literal, not fuzzy — a model
# not listed here resolves to no family; see get_model_family). Values must be
# keys of MODEL_FAMILIES.
MODEL_TO_FAMILY: Dict[str, str] = {
    # Qwen models
    "Qwen/Qwen2.5-0.5B": "llama_like",
    "Qwen/Qwen2.5-1.5B": "llama_like",
    "Qwen/Qwen2.5-3B": "llama_like",
    "Qwen/Qwen2.5-7B": "llama_like",
    "Qwen/Qwen2.5-14B": "llama_like",
    "Qwen/Qwen2.5-32B": "llama_like",
    "Qwen/Qwen2.5-72B": "llama_like",
    "Qwen/Qwen2-0.5B": "llama_like",
    "Qwen/Qwen2-1.5B": "llama_like",
    "Qwen/Qwen2-7B": "llama_like",
    
    # LLaMA models
    "meta-llama/Llama-2-7b-hf": "llama_like",
    "meta-llama/Llama-2-13b-hf": "llama_like",
    "meta-llama/Llama-2-70b-hf": "llama_like",
    "meta-llama/Llama-3.1-8B": "llama_like",
    "meta-llama/Llama-3.1-70B": "llama_like",
    "meta-llama/Llama-3.2-1B": "llama_like",
    "meta-llama/Llama-3.2-3B": "llama_like",
    
    # Mistral models
    "mistralai/Mistral-7B-v0.1": "llama_like",
    "mistralai/Mistral-7B-v0.3": "llama_like",
    "mistralai/Mixtral-8x7B-v0.1": "llama_like",
    "mistralai/Mixtral-8x22B-v0.1": "llama_like",
    
    # GPT-2 models (bare names and the openai-community mirrors)
    "gpt2": "gpt2",
    "gpt2-medium": "gpt2",
    "gpt2-large": "gpt2",
    "gpt2-xl": "gpt2",
    "openai-community/gpt2": "gpt2",
    "openai-community/gpt2-medium": "gpt2",
    "openai-community/gpt2-large": "gpt2",
    "openai-community/gpt2-xl": "gpt2",
    
    # OPT models
    "facebook/opt-125m": "opt",
    "facebook/opt-350m": "opt",
    "facebook/opt-1.3b": "opt",
    "facebook/opt-2.7b": "opt",
    "facebook/opt-6.7b": "opt",
    "facebook/opt-13b": "opt",
    "facebook/opt-30b": "opt",
    
    # GPT-Neo models (EleutherAI)
    "EleutherAI/gpt-neo-125M": "gpt_neo",
    "EleutherAI/gpt-neo-1.3B": "gpt_neo",
    "EleutherAI/gpt-neo-2.7B": "gpt_neo",
    
    # GPT-NeoX / Pythia models (EleutherAI)
    "EleutherAI/gpt-neox-20b": "gpt_neox",
    "EleutherAI/pythia-70m": "gpt_neox",
    "EleutherAI/pythia-160m": "gpt_neox",
    "EleutherAI/pythia-410m": "gpt_neox",
    "EleutherAI/pythia-1b": "gpt_neox",
    "EleutherAI/pythia-1.4b": "gpt_neox",
    "EleutherAI/pythia-2.8b": "gpt_neox",
    "EleutherAI/pythia-6.9b": "gpt_neox",
    "EleutherAI/pythia-12b": "gpt_neox",
    
    # BLOOM models
    "bigscience/bloom-560m": "bloom",
    "bigscience/bloom-1b1": "bloom",
    "bigscience/bloom-1b7": "bloom",
    "bigscience/bloom-3b": "bloom",
    "bigscience/bloom-7b1": "bloom",
    
    # Falcon models
    "tiiuae/falcon-7b": "falcon",
    "tiiuae/falcon-40b": "falcon",
    
    # MPT models
    "mosaicml/mpt-7b": "mpt",
    "mosaicml/mpt-30b": "mpt",
}


def get_model_family(model_name: str) -> Optional[str]:
    """
    Resolve a model name to its registered architecture family.

    Args:
        model_name: HuggingFace model name/path (exact match against the registry)

    Returns:
        Family name if the model is registered, None otherwise
    """
    try:
        return MODEL_TO_FAMILY[model_name]
    except KeyError:
        return None


def get_family_config(family_name: str) -> Optional[Dict[str, Any]]:
    """
    Look up the specification dict for a registered model family.

    Args:
        family_name: Name of the model family (a MODEL_FAMILIES key)

    Returns:
        Family configuration dict if registered, None otherwise
    """
    if family_name in MODEL_FAMILIES:
        return MODEL_FAMILIES[family_name]
    return None


def get_auto_selections(model_name: str, module_patterns: Dict[str, List[str]], 
                        param_patterns: Dict[str, List[str]]) -> Dict[str, Any]:
    """
    Get automatic dropdown selections based on model family.
    
    Args:
        model_name: HuggingFace model name
        module_patterns: Available module patterns from the model
        param_patterns: Available parameter patterns from the model
        
    Returns:
        Dict with keys: attention_selection, block_selection, norm_selection,
        family_name, family_description.
        The three *_selection values are lists of pattern keys that should be
        pre-selected (norm_selection is a list because the norm-params dropdown
        has multi=True). family_name is None and the selections are empty when
        the model or its family is unknown.
    """
    # Single fallback shared by both "unknown model" and "unregistered family"
    # paths. Includes 'family_description' so the return schema is identical
    # on every path (previously the fallbacks omitted it, so callers reading
    # that key would KeyError for unknown models).
    no_selections: Dict[str, Any] = {
        'attention_selection': [],
        'block_selection': [],
        'norm_selection': [],  # Empty list for multi-select dropdown
        'family_name': None,
        'family_description': '',
    }
    
    family = get_model_family(model_name)
    if not family:
        return no_selections
    
    config = get_family_config(family)
    if not config:
        return no_selections
    
    templates = config['templates']
    attention_template = templates.get('attention_pattern', '')
    block_template = templates.get('block_pattern', '')
    
    # Match attention patterns against what the model actually exposes.
    attention_matches = [key for key in module_patterns
                         if _pattern_matches_template(key, attention_template)]
    
    # Match block patterns (full layer outputs - residual stream).
    block_matches = [key for key in module_patterns
                     if _pattern_matches_template(key, block_template)]
    
    # Match the normalization parameter. The dropdown is multi-select, so the
    # result is a list; only the first match is kept (mirrors the original
    # break-on-first-hit behavior).
    norm_parameter = config.get('norm_parameter', '')
    norm_selection: List[str] = []
    if norm_parameter:
        norm_selection = [key for key in param_patterns
                          if _pattern_matches_template(key, norm_parameter)][:1]
    
    return {
        'attention_selection': attention_matches,
        'block_selection': block_matches,
        'norm_selection': norm_selection,
        'family_name': family,
        'family_description': config.get('description', ''),
    }


def _pattern_matches_template(pattern: str, template: str) -> bool:
    """
    Check if a pattern string matches a template.
    Templates use {N} as wildcard, patterns use {N} for the same purpose.
    
    Args:
        pattern: Pattern string like "model.layers.{N}.mlp"
        template: Template string like "model.layers.{N}.mlp"
        
    Returns:
        True if pattern matches template
    """
    if not template:
        return False
    
    # Simple check: remove {N} from both and see if they match
    pattern_normalized = pattern.replace('{N}', '').replace('.', '_')
    template_normalized = template.replace('{N}', '').replace('.', '_')
    
    # Exact match
    return pattern_normalized == template_normalized