File size: 2,253 Bytes
3850656
 
2965a7d
ae08976
 
 
d60cfe2
 
2965a7d
c6cb681
7fa8fb4
2965a7d
2ad1c2e
3850656
2ad1c2e
1a93ca4
c6cb681
3850656
c6cb681
65dd6b7
ac1a7df
ae08976
2ad1c2e
44b60b5
5c05f37
c6cb681
ac1a7df
 
d60cfe2
ac1a7df
5cad664
3850656
 
c6cb681
 
 
 
24b78dc
3850656
 
7fa8fb4
 
 
3850656
 
2ad1c2e
3850656
 
2ad1c2e
 
3850656
 
 
 
 
 
c6cb681
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from .model_patterns import (load_model_and_get_patterns, execute_forward_pass, 
                             logit_lens_transformation, extract_layer_data, 
                             generate_bertviz_html,
                             execute_forward_pass_with_head_ablation,
                             execute_forward_pass_with_multi_layer_head_ablation,
                             merge_token_probabilities, 
                             compute_global_top5_tokens, compute_per_position_top5,
                             detect_significant_probability_increases, 
                             evaluate_sequence_ablation, generate_bertviz_model_view_html)
from .model_config import get_model_family, get_family_config, get_auto_selections, MODEL_TO_FAMILY, MODEL_FAMILIES
from .head_detection import load_head_categories, verify_head_activation, get_active_head_summary
from .beam_search import perform_beam_search
from .ablation_metrics import compute_kl_divergence, score_sequence, get_token_probability_deltas
from .token_attribution import compute_integrated_gradients, compute_simple_gradient_attribution, create_attribution_visualization_data


# Public API of this package: the exact set of names re-exported by
# ``from <package> import *``. Entries are grouped by the submodule they are
# imported from at the top of this file; when an import list changes, update
# the matching group here so the two stay in sync.
__all__ = [
    # --- from .model_patterns ---
    "load_model_and_get_patterns",
    "execute_forward_pass",
    "execute_forward_pass_with_head_ablation",
    "execute_forward_pass_with_multi_layer_head_ablation",
    "evaluate_sequence_ablation",
    "logit_lens_transformation",
    "extract_layer_data",
    "generate_bertviz_html",
    "merge_token_probabilities",
    "compute_global_top5_tokens",
    "compute_per_position_top5",
    "detect_significant_probability_increases",
    "generate_bertviz_model_view_html",
    # --- from .model_config ---
    "get_model_family",
    "get_family_config",
    "get_auto_selections",
    "MODEL_TO_FAMILY",
    "MODEL_FAMILIES",
    # --- from .head_detection ---
    "load_head_categories",
    "verify_head_activation",
    "get_active_head_summary",
    # --- from .beam_search ---
    "perform_beam_search",
    # --- from .ablation_metrics ---
    "compute_kl_divergence",
    "score_sequence",
    "get_token_probability_deltas",
    # --- from .token_attribution ---
    "compute_integrated_gradients",
    "compute_simple_gradient_attribution",
    "create_attribution_visualization_data",
]