LLMVis / utils /model_config.py
cdpearlman's picture
New models chosen and rag docs updated
f478beb
"""
Model family configuration registry.
Maps specific model names to families, and families to canonical module/parameter patterns.
This allows automatic selection of appropriate modules and parameters based on model architecture.
"""
from typing import Dict, List, Optional, Any
# Registry of model families.
#
# Each entry is keyed by a family identifier and carries:
#   description    - human-readable summary of the architecture(s) covered
#   templates      - module-path templates for the attention module, the MLP
#                    module and the whole block; "{N}" is a wildcard
#                    (presumably the layer index - confirm against the code
#                    that produces the patterns)
#   norm_parameter - parameter name used to pre-select the normalization weight
#   norm_type      - normalization flavour tag ("rmsnorm" or "layernorm")
#
# NOTE(review): only the "gpt2" entry declares "logit_lens_pattern"; its
# consumers are not visible in this file, so confirm whether the other
# families should provide it as well.
MODEL_FAMILIES: Dict[str, Dict[str, Any]] = {
    # LLaMA-style decoders: LLaMA, Mistral and Qwen2 share this layout.
    "llama_like": {
        "description": "LLaMA, Mistral, Qwen2 architecture",
        "templates": {
            "attention_pattern": "model.layers.{N}.self_attn",
            "mlp_pattern": "model.layers.{N}.mlp",
            "block_pattern": "model.layers.{N}",
        },
        "norm_parameter": "model.norm.weight",
        "norm_type": "rmsnorm",
    },

    # GPT-2
    "gpt2": {
        "description": "GPT-2 architecture",
        "templates": {
            "attention_pattern": "transformer.h.{N}.attn",
            "mlp_pattern": "transformer.h.{N}.mlp",
            "block_pattern": "transformer.h.{N}",
        },
        "norm_parameter": "transformer.ln_f.weight",
        "logit_lens_pattern": "lm_head.weight",
        "norm_type": "layernorm",
    },

    # GPT-Neo (EleutherAI): GPT-2-like paths, but the attention module sits
    # one level deeper (".attn.attention").
    "gpt_neo": {
        "description": "GPT-Neo architecture (EleutherAI)",
        "templates": {
            "attention_pattern": "transformer.h.{N}.attn.attention",
            "mlp_pattern": "transformer.h.{N}.mlp",
            "block_pattern": "transformer.h.{N}",
        },
        "norm_parameter": "transformer.ln_f.weight",
        "norm_type": "layernorm",
    },

    # OPT: the MLP template points at the "fc2" module of each decoder layer.
    "opt": {
        "description": "OPT architecture",
        "templates": {
            "attention_pattern": "model.decoder.layers.{N}.self_attn",
            "mlp_pattern": "model.decoder.layers.{N}.fc2",
            "block_pattern": "model.decoder.layers.{N}",
        },
        "norm_parameter": "model.decoder.final_layer_norm.weight",
        "norm_type": "layernorm",
    },

    # GPT-NeoX
    "gpt_neox": {
        "description": "GPT-NeoX architecture",
        "templates": {
            "attention_pattern": "gpt_neox.layers.{N}.attention",
            "mlp_pattern": "gpt_neox.layers.{N}.mlp",
            "block_pattern": "gpt_neox.layers.{N}",
        },
        "norm_parameter": "gpt_neox.final_layer_norm.weight",
        "norm_type": "layernorm",
    },

    # BLOOM
    "bloom": {
        "description": "BLOOM architecture",
        "templates": {
            "attention_pattern": "transformer.h.{N}.self_attention",
            "mlp_pattern": "transformer.h.{N}.mlp",
            "block_pattern": "transformer.h.{N}",
        },
        "norm_parameter": "transformer.ln_f.weight",
        "norm_type": "layernorm",
    },

    # Falcon (same templates as BLOOM, kept as its own family)
    "falcon": {
        "description": "Falcon architecture",
        "templates": {
            "attention_pattern": "transformer.h.{N}.self_attention",
            "mlp_pattern": "transformer.h.{N}.mlp",
            "block_pattern": "transformer.h.{N}",
        },
        "norm_parameter": "transformer.ln_f.weight",
        "norm_type": "layernorm",
    },

    # MPT
    "mpt": {
        "description": "MPT architecture",
        "templates": {
            "attention_pattern": "transformer.blocks.{N}.attn",
            "mlp_pattern": "transformer.blocks.{N}.ffn",
            "block_pattern": "transformer.blocks.{N}",
        },
        "norm_parameter": "transformer.norm_f.weight",
        "norm_type": "layernorm",
    },
}
# Explicit model-name -> family lookup table.
# Keys are exact model identifiers; lookups are plain dict lookups, so a name
# must match one of these strings exactly. Values are keys of MODEL_FAMILIES.
MODEL_TO_FAMILY: Dict[str, str] = {
    # --- Qwen ---
    "Qwen/Qwen2.5-0.5B": "llama_like",
    "Qwen/Qwen2.5-1.5B": "llama_like",
    "Qwen/Qwen2.5-3B": "llama_like",
    "Qwen/Qwen2.5-7B": "llama_like",
    "Qwen/Qwen2.5-14B": "llama_like",
    "Qwen/Qwen2.5-32B": "llama_like",
    "Qwen/Qwen2.5-72B": "llama_like",
    "Qwen/Qwen2-0.5B": "llama_like",
    "Qwen/Qwen2-1.5B": "llama_like",
    "Qwen/Qwen2-7B": "llama_like",

    # --- LLaMA ---
    "meta-llama/Llama-2-7b-hf": "llama_like",
    "meta-llama/Llama-2-13b-hf": "llama_like",
    "meta-llama/Llama-2-70b-hf": "llama_like",
    "meta-llama/Llama-3.1-8B": "llama_like",
    "meta-llama/Llama-3.1-70B": "llama_like",
    "meta-llama/Llama-3.2-1B": "llama_like",
    "meta-llama/Llama-3.2-3B": "llama_like",

    # --- Mistral / Mixtral ---
    # NOTE(review): the Mixtral checkpoints are mixture-of-experts models;
    # confirm that their decoder layers expose the "mlp" submodule that the
    # llama_like "mlp_pattern" template expects.
    "mistralai/Mistral-7B-v0.1": "llama_like",
    "mistralai/Mistral-7B-v0.3": "llama_like",
    "mistralai/Mixtral-8x7B-v0.1": "llama_like",
    "mistralai/Mixtral-8x22B-v0.1": "llama_like",

    # --- GPT-2 (bare names and their "openai-community/" equivalents) ---
    "gpt2": "gpt2",
    "gpt2-medium": "gpt2",
    "gpt2-large": "gpt2",
    "gpt2-xl": "gpt2",
    "openai-community/gpt2": "gpt2",
    "openai-community/gpt2-medium": "gpt2",
    "openai-community/gpt2-large": "gpt2",
    "openai-community/gpt2-xl": "gpt2",

    # --- OPT ---
    "facebook/opt-125m": "opt",
    "facebook/opt-350m": "opt",
    "facebook/opt-1.3b": "opt",
    "facebook/opt-2.7b": "opt",
    "facebook/opt-6.7b": "opt",
    "facebook/opt-13b": "opt",
    "facebook/opt-30b": "opt",

    # --- GPT-Neo (EleutherAI) ---
    "EleutherAI/gpt-neo-125M": "gpt_neo",
    "EleutherAI/gpt-neo-1.3B": "gpt_neo",
    "EleutherAI/gpt-neo-2.7B": "gpt_neo",

    # --- GPT-NeoX / Pythia (EleutherAI) ---
    "EleutherAI/gpt-neox-20b": "gpt_neox",
    "EleutherAI/pythia-70m": "gpt_neox",
    "EleutherAI/pythia-160m": "gpt_neox",
    "EleutherAI/pythia-410m": "gpt_neox",
    "EleutherAI/pythia-1b": "gpt_neox",
    "EleutherAI/pythia-1.4b": "gpt_neox",
    "EleutherAI/pythia-2.8b": "gpt_neox",
    "EleutherAI/pythia-6.9b": "gpt_neox",
    "EleutherAI/pythia-12b": "gpt_neox",

    # --- BLOOM ---
    "bigscience/bloom-560m": "bloom",
    "bigscience/bloom-1b1": "bloom",
    "bigscience/bloom-1b7": "bloom",
    "bigscience/bloom-3b": "bloom",
    "bigscience/bloom-7b1": "bloom",

    # --- Falcon ---
    "tiiuae/falcon-7b": "falcon",
    "tiiuae/falcon-40b": "falcon",

    # --- MPT ---
    "mosaicml/mpt-7b": "mpt",
    "mosaicml/mpt-30b": "mpt",
}
def get_model_family(model_name: str) -> Optional[str]:
    """
    Look up which model family a model name belongs to.

    The lookup is an exact-key match against MODEL_TO_FAMILY; no
    normalization of the name is performed.

    Args:
        model_name: HuggingFace model name/path

    Returns:
        The family name registered for ``model_name``, or None when the
        name is not present in MODEL_TO_FAMILY.
    """
    family_name = MODEL_TO_FAMILY.get(model_name, None)
    return family_name
def get_family_config(family_name: str) -> Optional[Dict[str, Any]]:
    """
    Fetch the configuration entry registered for a model family.

    Args:
        family_name: Name of the model family (a key of MODEL_FAMILIES)

    Returns:
        The family's configuration dict (the same object stored in
        MODEL_FAMILIES, not a copy), or None when the family is unknown.
    """
    family_config = MODEL_FAMILIES.get(family_name, None)
    return family_config
def get_auto_selections(model_name: str, module_patterns: Dict[str, List[str]],
                        param_patterns: Dict[str, List[str]]) -> Dict[str, Any]:
    """
    Get automatic dropdown selections based on model family.

    The model name is resolved to a family, and the family's templates are
    compared against the keys of the available patterns. Only the keys of
    ``module_patterns`` / ``param_patterns`` are inspected; their values are
    not read.

    Args:
        model_name: HuggingFace model name
        module_patterns: Available module patterns from the model
            (None is treated as empty)
        param_patterns: Available parameter patterns from the model
            (None is treated as empty)

    Returns:
        Dict that always contains the same five keys:
            attention_selection: list of module pattern keys matching the
                family's attention template
            block_selection: list of module pattern keys matching the
                family's block template
            norm_selection: list holding the first parameter pattern key
                matching the family's norm parameter, or an empty list
                (always a list because the dropdown is multi-select)
            family_name: resolved family name, or None when the model is
                unknown or the family has no configuration
            family_description: the family's description, or '' when
                unavailable
    """
    family = get_model_family(model_name)
    config = get_family_config(family) if family else None
    if not config:
        # Unknown model (or family without config): nothing to pre-select.
        # The result keeps the same keys as the success path so callers can
        # index it without checking which path was taken.
        return {
            'attention_selection': [],
            'block_selection': [],
            'norm_selection': [],  # Empty list for multi-select dropdown
            'family_name': None,
            'family_description': '',
        }

    # Tolerate a family entry without templates and callers passing None.
    templates = config.get('templates') or {}
    module_keys = list(module_patterns or {})

    # Match attention patterns
    attention_template = templates.get('attention_pattern', '')
    attention_matches = [
        key for key in module_keys
        if _pattern_matches_template(key, attention_template)
    ]

    # Match block patterns (full layer outputs - residual stream)
    block_template = templates.get('block_pattern', '')
    block_matches = [
        key for key in module_keys
        if _pattern_matches_template(key, block_template)
    ]

    # Match normalization parameter: only the first matching key is kept,
    # wrapped in a list because norm-params-dropdown has multi=True.
    norm_selection: List[str] = []
    norm_parameter = config.get('norm_parameter', '')
    if norm_parameter:
        norm_selection = next(
            (
                [key] for key in (param_patterns or {})
                if _pattern_matches_template(key, norm_parameter)
            ),
            [],
        )

    return {
        'attention_selection': attention_matches,
        'block_selection': block_matches,
        'norm_selection': norm_selection,
        'family_name': family,
        'family_description': config.get('description', ''),
    }
def _pattern_matches_template(pattern: str, template: str) -> bool:
"""
Check if a pattern string matches a template.
Templates use {N} as wildcard, patterns use {N} for the same purpose.
Args:
pattern: Pattern string like "model.layers.{N}.mlp"
template: Template string like "model.layers.{N}.mlp"
Returns:
True if pattern matches template
"""
if not template:
return False
# Simple check: remove {N} from both and see if they match
pattern_normalized = pattern.replace('{N}', '').replace('.', '_')
template_normalized = template.replace('{N}', '').replace('.', '_')
# Exact match
return pattern_normalized == template_normalized