# LLMVis / utils / __init__.py
# Commit 7fa8fb4 (cdpearlman): "Attention refactor, better categorization and explanation"
from .model_patterns import (load_model_and_get_patterns, execute_forward_pass,
logit_lens_transformation, extract_layer_data,
generate_bertviz_html,
execute_forward_pass_with_head_ablation,
execute_forward_pass_with_multi_layer_head_ablation,
merge_token_probabilities,
compute_global_top5_tokens, compute_per_position_top5,
detect_significant_probability_increases,
evaluate_sequence_ablation, generate_bertviz_model_view_html)
from .model_config import get_model_family, get_family_config, get_auto_selections, MODEL_TO_FAMILY, MODEL_FAMILIES
from .head_detection import load_head_categories, verify_head_activation, get_active_head_summary
from .beam_search import perform_beam_search
from .ablation_metrics import compute_kl_divergence, score_sequence, get_token_probability_deltas
from .token_attribution import compute_integrated_gradients, compute_simple_gradient_attribution, create_attribution_visualization_data
# Public API of this package: the names exported by ``from <package> import *``.
# Each group below lists the names re-exported from one submodule. The groups
# are joined with ``+=``, a form static analysers accept for ``__all__``. The
# resulting list keeps the same order as the imports above.

# From .model_patterns
__all__ = [
    'load_model_and_get_patterns',
    'execute_forward_pass',
    'execute_forward_pass_with_head_ablation',
    'execute_forward_pass_with_multi_layer_head_ablation',
    'evaluate_sequence_ablation',
    'logit_lens_transformation',
    'extract_layer_data',
    'generate_bertviz_html',
    'merge_token_probabilities',
    'compute_global_top5_tokens',
    'compute_per_position_top5',
    'detect_significant_probability_increases',
    'generate_bertviz_model_view_html',
]

# From .model_config
__all__ += [
    'get_model_family',
    'get_family_config',
    'get_auto_selections',
    'MODEL_TO_FAMILY',
    'MODEL_FAMILIES',
]

# From .head_detection
__all__ += [
    'load_head_categories',
    'verify_head_activation',
    'get_active_head_summary',
]

# From .beam_search
__all__ += [
    'perform_beam_search',
]

# From .ablation_metrics
__all__ += [
    'compute_kl_divergence',
    'score_sequence',
    'get_token_probability_deltas',
]

# From .token_attribution
__all__ += [
    'compute_integrated_gradients',
    'compute_simple_gradient_attribution',
    'create_attribution_visualization_data',
]