import html
import json
import os
from typing import Any, Dict, List, Optional

import streamlit as st
import torch
from datasets import load_dataset
from huggingface_hub import HfApi
from transformers import AutoTokenizer

_HF_TOKEN = os.getenv("HF_TOKEN")
_DEFAULT_MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B-Instruct"
_DEFAULT_DATASET_NAME = "layer_31_10104940_longfact_test"
# Hugging Face dataset repo; each top-level directory of parquet shards is one subset.
_CACHE_REPO_ID = "obalcells/probe-activations-cache"

# Stylesheet injected by display_token_values. Every token is a ``token-span``;
# tokens whose classification label is 1 additionally get ``label-span``.
_TOKEN_CSS = """
<style>
.token-span { border-radius: 2px; }
.token-span:hover { outline: 1px solid #555555; }
.label-span { border-bottom: 2px solid #d62728; }
</style>
"""

st.set_page_config(
    page_title="Cached Probe Viewer",
    layout="wide",
)

st.title("Cached Probe Viewer")


def value_to_color_hex(value, vmin=0.0, vmax=1.0, power=3.0):
    """Map *value* to a white-to-green hex colour string such as ``#ffffff``.

    The value is normalised linearly over [vmin, vmax] and clamped to [0, 1].
    It is then raised to *power*, so only values near vmax get strong colour.
    The result fades the red and blue channels from 255 down to 0 while green
    stays at 255.

    A degenerate range (vmax == vmin) yields white instead of raising
    ZeroDivisionError.
    """
    span = vmax - vmin
    if span == 0:
        normalized = 0.0
    else:
        normalized = min(1, max(0, (value - vmin) / span))
    normalized = normalized ** power
    r = int(255 * (1 - normalized))
    g = 255
    b = int(255 * (1 - normalized))
    return f"#{r:02x}{g:02x}{b:02x}"


def should_highlight_token(token: str) -> bool:
    """Return False for tokens containing a newline, True otherwise."""
    return "\n" not in token


def display_token_values(
    values: List[float],
    tokens: List[str],
    vmin=0.0,
    vmax=0.0,
    classification_labels: Optional[List[float]] = None,
) -> str:
    """Build HTML that colours each token by its value.

    Side effect: injects the token stylesheet with ``st.markdown``. The returned
    HTML is not rendered here; the caller renders it with
    ``st.markdown(..., unsafe_allow_html=True)``.

    Args:
        values: Per-token scalar, aligned index-for-index with *tokens*.
        tokens: Decoded token strings.
        vmin: Lower end of the colour range (see value_to_color_hex).
        vmax: Upper end of the colour range. The default 0.0 equals the
            default vmin, which renders every token white, so pass it explicitly.
        classification_labels: Optional per-token labels. Tokens whose label
            equals 1 get the ``label-span`` class.

    Returns:
        A single-line ``<div>`` of ``<span>`` elements. Hovering a span shows
        the token's raw value through the native ``title`` tooltip.
    """
    st.markdown(_TOKEN_CSS, unsafe_allow_html=True)

    parts = ["<div style='white-space: pre-wrap; line-height: 1.8;'>"]
    for idx, (token, raw_value) in enumerate(zip(tokens, values)):
        # Tokens containing a newline are drawn with value 0 (no highlight).
        shown_value = raw_value if should_highlight_token(token) else 0.0
        color = value_to_color_hex(shown_value, vmin=vmin, vmax=vmax)

        label_class = ""
        if classification_labels is not None and classification_labels[idx] == 1:
            label_class = " label-span"

        # "$" is dropped, presumably so Streamlit's markdown does not read it as a
        # LaTeX math delimiter. Escaping keeps tokens like "<|begin_of_text|>"
        # from being parsed as HTML tags. Newlines become <br> so the markup stays
        # on one line: a blank line would end the markdown HTML block early.
        text = html.escape(token.replace("$", "")).replace("\n", "<br>")
        parts.append(
            f"<span class='token-span{label_class}' "
            f"style='background-color: {color};' "
            f"title='{raw_value:.4f}'>{text}</span>"
        )
    parts.append("</div>")
    return "".join(parts)


@st.cache_data(ttl=600, show_spinner=False)
def list_cached_datasets() -> List[str]:
    """Return the sorted subset names available in the cache repo.

    The repo is laid out as one directory per subset, e.g.::

        layer_31_10104940_longfact_test/train-00000-of-00001.parquet
        layer_31_10104940_longfact_train/train-00000-of-00001.parquet

    The subset name is the first path component of every ``.parquet`` file.
    Results are cached for 10 minutes, so the Hub is not queried on every
    Streamlit rerun. Errors from the Hub propagate to the caller.
    """
    api = HfApi(token=_HF_TOKEN)
    repo_files = api.list_repo_files(repo_id=_CACHE_REPO_ID, repo_type="dataset", revision="main")
    # Sorted so the selectbox order is stable across runs (set order is not).
    return sorted({path.split("/")[0] for path in repo_files if path.endswith(".parquet")})


@st.cache_data(show_spinner="Loading dataset from Hugging Face...")
def _fetch_dataset_rows(subset: str) -> List[Dict[str, Any]]:
    """Download the 'train' split of *subset* and return its rows, last row first.

    Raises on failure. Streamlit does not cache a call that raises, so a
    transient error is retried on the next rerun.
    """
    dataset = load_dataset(_CACHE_REPO_ID, subset, split="train", token=_HF_TOKEN)
    return list(dataset)[::-1]


def load_hf_dataset(subset: str) -> Optional[Any]:
    """Return the (cached) reversed rows of *subset*, or None after showing an error."""
    try:
        return _fetch_dataset_rows(subset)
    except Exception as e:
        st.error(f"Failed to load dataset: {str(e)}")
        return None


@st.cache_resource
def _load_tokenizer_cached(model_name: str):
    """Load and configure the tokenizer for *model_name*. Raises on failure (not cached)."""
    tokenizer = AutoTokenizer.from_pretrained(model_name, token=_HF_TOKEN)
    tokenizer.padding_side = "left"
    tokenizer.pad_token = tokenizer.eos_token
    return tokenizer


def load_tokenizer(model_name: str) -> Optional[AutoTokenizer]:
    """Return the (cached) tokenizer for *model_name*, or None after showing an error."""
    try:
        return _load_tokenizer_cached(model_name)
    except Exception as e:
        st.error(f"Failed to load tokenizer: {str(e)}")
        return None


def main():
    """Render the viewer page.

    The user picks a cached dataset subset, a tokenizer and a datapoint. The
    chosen per-token quantity is then rendered as colour-highlighted text, with
    Previous/Next navigation.
    """
    try:
        cached_dataset_names = list_cached_datasets()
    except Exception as e:
        st.error(f"Failed to list cached datasets: {str(e)}")
        st.stop()
    if not cached_dataset_names:
        st.error("No cached datasets found in the repository.")
        st.stop()

    if _DEFAULT_DATASET_NAME in cached_dataset_names:
        default_dataset_idx = cached_dataset_names.index(_DEFAULT_DATASET_NAME)
    else:
        default_dataset_idx = 0
    dataset_name = st.selectbox("Select dataset:", cached_dataset_names, index=default_dataset_idx)

    dataset = load_hf_dataset(dataset_name)
    if dataset is None:
        st.error("Failed to load dataset. Please check your internet connection and try again.")
        st.stop()
    if not dataset:
        st.warning("The selected dataset contains no datapoints.")
        st.stop()

    model_name = st.text_input("Model name (for tokenizer):", _DEFAULT_MODEL_NAME)
    tokenizer = load_tokenizer(model_name)
    if tokenizer is None:
        st.error("Failed to load tokenizer. Please check the model name and try again.")
        st.stop()

    if "selected_idx" not in st.session_state:
        st.session_state.selected_idx = 0
    # The stored index may come from a previously selected, longer dataset,
    # so clamp it into range before using it as the selectbox default.
    start_idx = max(0, min(st.session_state.selected_idx, len(dataset) - 1))
    selected_idx = st.selectbox("Select datapoint:", range(len(dataset)), index=start_idx)
    datapoint = dataset[selected_idx]

    tokens = [tokenizer.decode(token_id) for token_id in datapoint["input_ids"]]

    # Radio label -> datapoint column holding the per-token values.
    viz_columns = {
        "Probe Probabilities": "probe_probs",
        "BCE Loss": "bce_loss",
        "Classification Labels": "classification_labels",
        "Classification Weights": "classification_weights",
    }
    viz_type = st.radio("Select visualization type:", list(viz_columns))
    values = datapoint[viz_columns[viz_type]]
    # Losses are scaled to this example's maximum. Everything else uses a fixed
    # 0-1 colour scale, where values above 1 saturate.
    if viz_type == "BCE Loss":
        vmax = max(values, default=1.0)
    else:
        vmax = 1.0

    st.write(f"### {viz_type}")
    highlighted_text = display_token_values(
        values,
        tokens,
        vmax=vmax,
        classification_labels=datapoint["classification_labels"],
    )
    st.markdown(highlighted_text, unsafe_allow_html=True)

    # Navigation buttons at the bottom.
    st.write("")  # Add some space
    col1, col2, col3 = st.columns([1, 2, 1])
    with col1:
        if st.button("← Previous", disabled=selected_idx == 0):
            st.session_state.selected_idx = selected_idx - 1
            st.rerun()
    with col2:
        st.write(f"Example {selected_idx + 1} of {len(dataset)}")
    with col3:
        if st.button("Next →", disabled=selected_idx == len(dataset) - 1):
            st.session_state.selected_idx = selected_idx + 1
            st.rerun()


if __name__ == "__main__":
    main()