cdpearlman committed on
Commit
c19d5a8
·
1 Parent(s): 4c3d673

Fixed errors for beam search refactor

Browse files
app.py CHANGED
@@ -827,8 +827,13 @@ def create_layer_accordions(activation_data, activation_data2, original_activati
827
  tracking_data = compute_layer_wise_summaries(layer_data, activation_data)
828
  layer_wise_probs = tracking_data.get('layer_wise_top5_probs', {})
829
  significant_layers = tracking_data.get('significant_layers', [])
 
830
  global_top5 = activation_data.get('global_top5_tokens', [])
831
 
 
 
 
 
832
  # If in ablation mode, also extract original layer data
833
  original_layer_data = None
834
  original_layer_wise_probs = {}
@@ -855,6 +860,10 @@ def create_layer_accordions(activation_data, activation_data2, original_activati
855
  layer_wise_probs2 = tracking_data2.get('layer_wise_top5_probs', {})
856
  significant_layers2 = tracking_data2.get('significant_layers', [])
857
  global_top5_2 = activation_data2.get('global_top5_tokens', [])
 
 
 
 
858
 
859
  # Create accordion panels (reversed to show final layer first)
860
  accordions = []
 
827
  tracking_data = compute_layer_wise_summaries(layer_data, activation_data)
828
  layer_wise_probs = tracking_data.get('layer_wise_top5_probs', {})
829
  significant_layers = tracking_data.get('significant_layers', [])
830
+ # Get global top 5 tokens from activation data
831
  global_top5 = activation_data.get('global_top5_tokens', [])
832
 
833
+ # Ensure global_top5 is list of dicts (handle legacy tuples/lists from old sessions)
834
+ if global_top5 and isinstance(global_top5[0], (list, tuple)):
835
+ global_top5 = [{'token': t, 'probability': p} for t, p in global_top5]
836
+
837
  # If in ablation mode, also extract original layer data
838
  original_layer_data = None
839
  original_layer_wise_probs = {}
 
860
  layer_wise_probs2 = tracking_data2.get('layer_wise_top5_probs', {})
861
  significant_layers2 = tracking_data2.get('significant_layers', [])
862
  global_top5_2 = activation_data2.get('global_top5_tokens', [])
863
+
864
+ # Ensure global_top5_2 is list of dicts (handle legacy tuples)
865
+ if global_top5_2 and isinstance(global_top5_2[0], (list, tuple)):
866
+ global_top5_2 = [{'token': t, 'probability': p} for t, p in global_top5_2]
867
 
868
  # Create accordion panels (reversed to show final layer first)
869
  accordions = []
components/__pycache__/__init__.cpython-311.pyc DELETED
Binary file (177 Bytes)
 
components/__pycache__/glossary.cpython-311.pyc DELETED
Binary file (4.69 kB)
 
components/__pycache__/main_panel.cpython-311.pyc DELETED
Binary file (5.08 kB)
 
components/__pycache__/model_selector.cpython-311.pyc DELETED
Binary file (2.87 kB)
 
components/__pycache__/sidebar.cpython-311.pyc DELETED
Binary file (3.34 kB)
 
components/__pycache__/tokenization_panel.cpython-311.pyc DELETED
Binary file (9.78 kB)
 
utils/__pycache__/__init__.cpython-311.pyc DELETED
Binary file (1.86 kB)
 
utils/__pycache__/beam_search.cpython-311.pyc DELETED
Binary file (7 kB)
 
utils/__pycache__/head_detection.cpython-311.pyc DELETED
Binary file (16.1 kB)
 
utils/__pycache__/model_config.cpython-311.pyc DELETED
Binary file (8.45 kB)
 
utils/__pycache__/model_patterns.cpython-311.pyc DELETED
Binary file (59.4 kB)
 
utils/__pycache__/prompt_comparison.cpython-311.pyc DELETED
Binary file (13.9 kB)
 
utils/model_patterns.py CHANGED
@@ -92,7 +92,7 @@ def merge_token_probabilities(token_probs: List[Tuple[str, float]]) -> List[Tupl
92
  return result
93
 
94
 
95
- def compute_global_top5_tokens(model_output, tokenizer, top_k: int = 5) -> List[Tuple[str, float]]:
96
  """
97
  Compute the global top-5 tokens from model's final output with merged probabilities.
98
 
@@ -102,7 +102,7 @@ def compute_global_top5_tokens(model_output, tokenizer, top_k: int = 5) -> List[
102
  top_k: Number of top tokens to return (default: 5)
103
 
104
  Returns:
105
- List of (token_string, probability) tuples for top K tokens with merged probabilities
106
  """
107
  with torch.no_grad():
108
  # Get probabilities for next token (last position)
@@ -121,8 +121,8 @@ def compute_global_top5_tokens(model_output, tokenizer, top_k: int = 5) -> List[
121
  # Merge tokens with/without leading space
122
  merged = merge_token_probabilities(candidates)
123
 
124
- # Return top K after merging
125
- return merged[:top_k]
126
 
127
 
128
  def get_actual_model_output(model_output, tokenizer) -> Tuple[str, float]:
@@ -1048,7 +1048,12 @@ def extract_layer_data(activation_data: Dict[str, Any], model, tokenizer) -> Lis
1048
 
1049
  # Get global top 5 tokens from final output
1050
  global_top5_tokens = activation_data.get('global_top5_tokens', [])
1051
- global_top5_token_names = [token for token, _ in global_top5_tokens]
 
 
 
 
 
1052
 
1053
  layer_data = []
1054
  prev_token_probs = {} # Track previous layer's token probabilities (layer's own top 5)
 
92
  return result
93
 
94
 
95
+ def compute_global_top5_tokens(model_output, tokenizer, top_k: int = 5) -> List[Dict[str, Any]]:
96
  """
97
  Compute the global top-5 tokens from model's final output with merged probabilities.
98
 
 
102
  top_k: Number of top tokens to return (default: 5)
103
 
104
  Returns:
105
+ List of dicts {'token': str, 'probability': float} for top K tokens
106
  """
107
  with torch.no_grad():
108
  # Get probabilities for next token (last position)
 
121
  # Merge tokens with/without leading space
122
  merged = merge_token_probabilities(candidates)
123
 
124
+ # Return top K after merging, formatted as dicts
125
+ return [{'token': t, 'probability': p} for t, p in merged[:top_k]]
126
 
127
 
128
  def get_actual_model_output(model_output, tokenizer) -> Tuple[str, float]:
 
1048
 
1049
  # Get global top 5 tokens from final output
1050
  global_top5_tokens = activation_data.get('global_top5_tokens', [])
1051
+
1052
+ # Handle both dicts (new format) and tuples (legacy)
1053
+ if global_top5_tokens and isinstance(global_top5_tokens[0], dict):
1054
+ global_top5_token_names = [t.get('token') for t in global_top5_tokens]
1055
+ else:
1056
+ global_top5_token_names = [token for token, _ in global_top5_tokens]
1057
 
1058
  layer_data = []
1059
  prev_token_probs = {} # Track previous layer's token probabilities (layer's own top 5)