"""
Model family configuration registry.

Maps specific model names to families, and families to canonical
module/parameter patterns. This allows automatic selection of appropriate
modules and parameters based on model architecture.
"""

from typing import Dict, List, Optional, Any

# Model family specifications.
#
# Every family entry has:
#   "description"    -- human-readable label (returned as ``family_description``
#                       by ``get_auto_selections``)
#   "templates"      -- module-path templates in which ``{N}`` stands for the
#                       layer index
#   "norm_parameter" -- name of the final normalization weight parameter
#   "norm_type"      -- "rmsnorm" or "layernorm"
# Some families carry extra keys (e.g. "logit_lens_pattern" for gpt2).
MODEL_FAMILIES: Dict[str, Dict[str, Any]] = {
    # LLaMA-like models (LLaMA, Mistral, Qwen2)
    "llama_like": {
        "description": "LLaMA, Mistral, Qwen2 architecture",
        "templates": {
            "attention_pattern": "model.layers.{N}.self_attn",
            "mlp_pattern": "model.layers.{N}.mlp",
            "block_pattern": "model.layers.{N}",
        },
        "norm_parameter": "model.norm.weight",
        "norm_type": "rmsnorm",
    },
    # GPT-2 family
    "gpt2": {
        "description": "GPT-2 architecture",
        "templates": {
            "attention_pattern": "transformer.h.{N}.attn",
            "mlp_pattern": "transformer.h.{N}.mlp",
            "block_pattern": "transformer.h.{N}",
        },
        "norm_parameter": "transformer.ln_f.weight",
        "logit_lens_pattern": "lm_head.weight",
        "norm_type": "layernorm",
    },
    # GPT-Neo (EleutherAI) — similar to GPT-2 but with local attention
    "gpt_neo": {
        "description": "GPT-Neo architecture (EleutherAI)",
        "templates": {
            "attention_pattern": "transformer.h.{N}.attn.attention",
            "mlp_pattern": "transformer.h.{N}.mlp",
            "block_pattern": "transformer.h.{N}",
        },
        "norm_parameter": "transformer.ln_f.weight",
        "norm_type": "layernorm",
    },
    # OPT
    "opt": {
        "description": "OPT architecture",
        "templates": {
            "attention_pattern": "model.decoder.layers.{N}.self_attn",
            "mlp_pattern": "model.decoder.layers.{N}.fc2",
            "block_pattern": "model.decoder.layers.{N}",
        },
        "norm_parameter": "model.decoder.final_layer_norm.weight",
        "norm_type": "layernorm",
    },
    # GPT-NeoX
    "gpt_neox": {
        "description": "GPT-NeoX architecture",
        "templates": {
            "attention_pattern": "gpt_neox.layers.{N}.attention",
            "mlp_pattern": "gpt_neox.layers.{N}.mlp",
            "block_pattern": "gpt_neox.layers.{N}",
        },
        "norm_parameter": "gpt_neox.final_layer_norm.weight",
        "norm_type": "layernorm",
    },
    # BLOOM
    "bloom": {
        "description": "BLOOM architecture",
        "templates": {
            "attention_pattern": "transformer.h.{N}.self_attention",
            "mlp_pattern": "transformer.h.{N}.mlp",
            "block_pattern": "transformer.h.{N}",
        },
        "norm_parameter": "transformer.ln_f.weight",
        "norm_type": "layernorm",
    },
    # Falcon
    "falcon": {
        "description": "Falcon architecture",
        "templates": {
            "attention_pattern": "transformer.h.{N}.self_attention",
            "mlp_pattern": "transformer.h.{N}.mlp",
            "block_pattern": "transformer.h.{N}",
        },
        "norm_parameter": "transformer.ln_f.weight",
        "norm_type": "layernorm",
    },
    # MPT
    "mpt": {
        "description": "MPT architecture",
        "templates": {
            "attention_pattern": "transformer.blocks.{N}.attn",
            "mlp_pattern": "transformer.blocks.{N}.ffn",
            "block_pattern": "transformer.blocks.{N}",
        },
        "norm_parameter": "transformer.norm_f.weight",
        "norm_type": "layernorm",
    },
}

# Hard-coded mapping of specific model names to families.
# Lookup is an exact string match on the key (see ``get_model_family``).
MODEL_TO_FAMILY: Dict[str, str] = {
    # Qwen models
    "Qwen/Qwen2.5-0.5B": "llama_like",
    "Qwen/Qwen2.5-1.5B": "llama_like",
    "Qwen/Qwen2.5-3B": "llama_like",
    "Qwen/Qwen2.5-7B": "llama_like",
    "Qwen/Qwen2.5-14B": "llama_like",
    "Qwen/Qwen2.5-32B": "llama_like",
    "Qwen/Qwen2.5-72B": "llama_like",
    "Qwen/Qwen2-0.5B": "llama_like",
    "Qwen/Qwen2-1.5B": "llama_like",
    "Qwen/Qwen2-7B": "llama_like",
    # LLaMA models
    "meta-llama/Llama-2-7b-hf": "llama_like",
    "meta-llama/Llama-2-13b-hf": "llama_like",
    "meta-llama/Llama-2-70b-hf": "llama_like",
    "meta-llama/Llama-3.1-8B": "llama_like",
    "meta-llama/Llama-3.1-70B": "llama_like",
    "meta-llama/Llama-3.2-1B": "llama_like",
    "meta-llama/Llama-3.2-3B": "llama_like",
    # Mistral models
    "mistralai/Mistral-7B-v0.1": "llama_like",
    "mistralai/Mistral-7B-v0.3": "llama_like",
    # NOTE(review): Mixtral is a mixture-of-experts model; its feed-forward
    # submodule may not be named "mlp", so the llama_like mlp_pattern may not
    # match — verify before relying on it (attention/block templates are what
    # get_auto_selections uses).
    "mistralai/Mixtral-8x7B-v0.1": "llama_like",
    "mistralai/Mixtral-8x22B-v0.1": "llama_like",
    # GPT-2 models
    "gpt2": "gpt2",
    "gpt2-medium": "gpt2",
    "gpt2-large": "gpt2",
    "gpt2-xl": "gpt2",
    "openai-community/gpt2": "gpt2",
    "openai-community/gpt2-medium": "gpt2",
    "openai-community/gpt2-large": "gpt2",
    "openai-community/gpt2-xl": "gpt2",
    # OPT models
    "facebook/opt-125m": "opt",
    "facebook/opt-350m": "opt",
    "facebook/opt-1.3b": "opt",
    "facebook/opt-2.7b": "opt",
    "facebook/opt-6.7b": "opt",
    "facebook/opt-13b": "opt",
    "facebook/opt-30b": "opt",
    # GPT-Neo models (EleutherAI)
    "EleutherAI/gpt-neo-125M": "gpt_neo",
    "EleutherAI/gpt-neo-1.3B": "gpt_neo",
    "EleutherAI/gpt-neo-2.7B": "gpt_neo",
    # GPT-NeoX / Pythia models (EleutherAI)
    "EleutherAI/gpt-neox-20b": "gpt_neox",
    "EleutherAI/pythia-70m": "gpt_neox",
    "EleutherAI/pythia-160m": "gpt_neox",
    "EleutherAI/pythia-410m": "gpt_neox",
    "EleutherAI/pythia-1b": "gpt_neox",
    "EleutherAI/pythia-1.4b": "gpt_neox",
    "EleutherAI/pythia-2.8b": "gpt_neox",
    "EleutherAI/pythia-6.9b": "gpt_neox",
    "EleutherAI/pythia-12b": "gpt_neox",
    # BLOOM models
    "bigscience/bloom-560m": "bloom",
    "bigscience/bloom-1b1": "bloom",
    "bigscience/bloom-1b7": "bloom",
    "bigscience/bloom-3b": "bloom",
    "bigscience/bloom-7b1": "bloom",
    # Falcon models
    "tiiuae/falcon-7b": "falcon",
    "tiiuae/falcon-40b": "falcon",
    # MPT models
    "mosaicml/mpt-7b": "mpt",
    "mosaicml/mpt-30b": "mpt",
}


def get_model_family(model_name: str) -> Optional[str]:
    """
    Get the model family for a given model name.

    Args:
        model_name: HuggingFace model name/path. Matched exactly against the
            keys of ``MODEL_TO_FAMILY``.

    Returns:
        Family name if found, None otherwise.
    """
    return MODEL_TO_FAMILY.get(model_name)


def get_family_config(family_name: str) -> Optional[Dict[str, Any]]:
    """
    Get the configuration for a model family.

    Args:
        family_name: Name of the model family (a key of ``MODEL_FAMILIES``).

    Returns:
        Family configuration dict if found, None otherwise. The returned dict
        is the registry entry itself, not a copy.
    """
    return MODEL_FAMILIES.get(family_name)


def _empty_selections() -> Dict[str, Any]:
    """Build the result returned when no family configuration is available."""
    return {
        'attention_selection': [],
        'block_selection': [],
        'norm_selection': [],  # Empty list for multi-select dropdown
        'family_name': None,
        # Present so the result always has the same keys as the success path.
        'family_description': None,
    }


def get_auto_selections(model_name: str,
                        module_patterns: Dict[str, List[str]],
                        param_patterns: Dict[str, List[str]]) -> Dict[str, Any]:
    """
    Get automatic dropdown selections based on model family.

    Args:
        model_name: HuggingFace model name.
        module_patterns: Available module patterns from the model; only the
            keys are inspected.
        param_patterns: Available parameter patterns from the model; only the
            keys are inspected.

    Returns:
        Dict with keys:
            attention_selection: module pattern keys matching the family's
                attention template (list).
            block_selection: module pattern keys matching the family's block
                template (list).
            norm_selection: at most one parameter pattern key matching the
                family's norm parameter (list, for the multi-select dropdown).
            family_name: detected family name, or None when unknown.
            family_description: the family's description, or None when the
                family is unknown.
    """
    family = get_model_family(model_name)
    config = get_family_config(family) if family else None
    if not config:
        return _empty_selections()

    templates = config.get('templates', {})
    attention_template = templates.get('attention_pattern', '')
    # Block = full layer output (residual stream).
    block_template = templates.get('block_pattern', '')

    attention_matches = [
        key for key in module_patterns
        if _pattern_matches_template(key, attention_template)
    ]
    block_matches = [
        key for key in module_patterns
        if _pattern_matches_template(key, block_template)
    ]

    # Only the first matching parameter is selected; an empty norm_parameter
    # never matches because _pattern_matches_template rejects empty templates.
    norm_parameter = config.get('norm_parameter', '')
    norm_match = next(
        (key for key in param_patterns
         if _pattern_matches_template(key, norm_parameter)),
        None,
    )

    return {
        'attention_selection': attention_matches,
        'block_selection': block_matches,
        # norm-params-dropdown is multi-select, so always return a list.
        'norm_selection': [norm_match] if norm_match is not None else [],
        'family_name': family,
        'family_description': config.get('description', ''),
    }


def _pattern_matches_template(pattern: str, template: str) -> bool:
    """
    Check whether a pattern string matches a template.

    Both strings are normalized by deleting every literal ``{N}`` placeholder
    and replacing ``.`` with ``_``; they match when the normalized forms are
    equal. ``{N}`` is therefore NOT a wildcard for concrete indices:
    "model.layers.0.mlp" does not match "model.layers.{N}.mlp".

    Args:
        pattern: Pattern string like "model.layers.{N}.mlp".
        template: Template string like "model.layers.{N}.mlp".

    Returns:
        True if the normalized pattern equals the normalized template; False
        when the template is empty.
    """
    if not template:
        return False

    # NOTE(review): treating "." and "_" as equivalent (and dropping {N})
    # tolerates sanitized keys but can also produce false positives, e.g.
    # "gpt.neox" vs "gpt_neox". Kept as-is to preserve existing matching.
    pattern_normalized = pattern.replace('{N}', '').replace('.', '_')
    template_normalized = template.replace('{N}', '').replace('.', '_')

    return pattern_normalized == template_normalized