# Streamlit app: Cached Probe Viewer (Hugging Face Space)
| import streamlit as st | |
| import json | |
| from typing import List, Dict, Any, Optional | |
| from datasets import load_dataset | |
| from transformers import AutoTokenizer | |
| import torch | |
| from huggingface_hub import HfApi | |
| import os | |
| _HF_TOKEN = os.getenv("HF_TOKEN") | |
| _DEFAULT_MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B-Instruct" | |
| _DEFAULT_DATASET_NAME = "layer_31_10104940_longfact_test" | |
| st.set_page_config( | |
| page_title="Cached Probe Viewer", | |
| layout="wide", | |
| ) | |
| st.title("Cached Probe Viewer") | |
| def value_to_color_hex(value, vmin=0.0, vmax=1.0, power=3.0): | |
| normalized = min(1, max(0, (value - vmin) / (vmax - vmin))) | |
| normalized = normalized ** power | |
| r = int(255 * (1 - normalized)) | |
| g = 255 | |
| b = int(255 * (1 - normalized)) | |
| return f'#{r:02x}{g:02x}{b:02x}' | |
| def should_highlight_token(token: str) -> bool: | |
| # Determine if a token should be highlighted based on its content | |
| if "\n" in token: | |
| return False | |
| return True | |
| def display_token_values(values: List[float], tokens: List[str], vmin=0.0, vmax=0.0, classification_labels: List[float] = None): | |
| # Add CSS for hover functionality | |
| st.markdown(""" | |
| <style> | |
| .token-span { | |
| position: relative; | |
| display: inline; | |
| padding: 0 1px; | |
| margin: 0; | |
| line-height: 1.5; | |
| } | |
| .token-span:hover::after { | |
| content: attr(data-value); | |
| position: absolute; | |
| bottom: 100%; | |
| left: 50%; | |
| transform: translateX(-50%); | |
| background-color: black; | |
| color: white; | |
| padding: 4px 8px; | |
| border-radius: 4px; | |
| font-size: 12px; | |
| white-space: nowrap; | |
| z-index: 1000; | |
| } | |
| .label-span { | |
| text-decoration: underline; | |
| text-decoration-thickness: 3px; | |
| text-decoration-color: #85081b; | |
| } | |
| .text-container { | |
| font-family: monospace; | |
| white-space: pre-wrap; | |
| line-height: 1.5; | |
| padding: 10px; | |
| background-color: #f8f9fa; | |
| border-radius: 4px; | |
| } | |
| </style> | |
| """, unsafe_allow_html=True) | |
| html_content = "<div class='text-container'>" | |
| for idx in range(len(values)): | |
| displayed_token = tokens[idx] | |
| value = values[idx] | |
| # Determine if this token should be highlighted | |
| should_highlight = should_highlight_token(displayed_token) | |
| if not should_highlight: | |
| value = 0.0 | |
| color = value_to_color_hex(value, vmin=vmin, vmax=vmax) | |
| # Check if this token has a positive classification label | |
| label_class = "" | |
| if classification_labels is not None and classification_labels[idx] == 1: | |
| label_class = " label-span" | |
| html_content += f"""<span class='token-span{label_class}' | |
| style='background-color: {color}; | |
| color: black;' | |
| data-value='Value: {value:.3f}'>{displayed_token.replace("$", "")}</span>""" | |
| html_content += "</div>" | |
| return html_content | |
| def list_cached_datasets(): | |
| api = HfApi() | |
| repo_id = "obalcells/probe-activations-cache" | |
| repo_files = api.list_repo_files(repo_id=repo_id, repo_type="dataset", revision="main") | |
| # Here's how the repo is structured: | |
| # [ | |
| # '.gitattributes', | |
| # 'README.md', | |
| # 'layer_31_10104940_longfact_test/train-00000-of-00001.parquet', | |
| # 'layer_31_10104940_longfact_train/train-00000-of-00001.parquet', | |
| # 'layer_31_10104940_longfact_validation/train-00000-of-00001.parquet', | |
| # 'layer_31_18222453_longfact_test/train-00000-of-00001.parquet' | |
| # ] | |
| # We want to extract the subsets from the repo_files | |
| # (e.g. 'layer_31_10104940_longfact_test', 'layer_31_10104940_longfact_train', ...) | |
| cached_datasets = [path.split('/')[0] for path in repo_files if path.endswith('.parquet')] | |
| cached_datasets = list(set(cached_datasets)) | |
| return cached_datasets | |
| # Load dataset from Hugging Face | |
| def load_hf_dataset(subset: str) -> Optional[Any]: | |
| try: | |
| with st.spinner("Loading dataset from Hugging Face..."): | |
| dataset = load_dataset('obalcells/probe-activations-cache', subset, split='train', token=_HF_TOKEN) | |
| return list(reversed(list(dataset))) | |
| except Exception as e: | |
| st.error(f"Failed to load dataset: {str(e)}") | |
| return None | |
| # Load tokenizer | |
| def load_tokenizer(model_name: str) -> Optional[AutoTokenizer]: | |
| try: | |
| tokenizer = AutoTokenizer.from_pretrained(model_name) | |
| tokenizer.padding_side = "left" | |
| tokenizer.pad_token = tokenizer.eos_token | |
| return tokenizer | |
| except Exception as e: | |
| st.error(f"Failed to load tokenizer: {str(e)}") | |
| return None | |
| # Main app | |
| def main(): | |
| # Load dataset | |
| cached_dataset_names: List[str] = list_cached_datasets() | |
| default_dataset = _DEFAULT_DATASET_NAME if _DEFAULT_DATASET_NAME in cached_dataset_names else cached_dataset_names[0] | |
| dataset_name = st.selectbox("Select dataset:", cached_dataset_names, index=cached_dataset_names.index(default_dataset)) | |
| dataset = load_hf_dataset(dataset_name) | |
| if dataset is None: | |
| st.error("Failed to load dataset. Please check your internet connection and try again.") | |
| st.stop() | |
| # Model selection | |
| model_name = st.text_input("Model name (for tokenizer):", _DEFAULT_MODEL_NAME) | |
| tokenizer = load_tokenizer(model_name) | |
| if tokenizer is None: | |
| st.error("Failed to load tokenizer. Please check the model name and try again.") | |
| st.stop() | |
| # Initialize session state for selected index if not exists | |
| if 'selected_idx' not in st.session_state: | |
| st.session_state.selected_idx = 0 | |
| # Select datapoint | |
| selected_idx = st.selectbox("Select datapoint:", range(len(dataset)), index=st.session_state.selected_idx) | |
| datapoint = dataset[selected_idx] | |
| # Decode tokens | |
| tokens = [tokenizer.decode(token_id) for token_id in datapoint['input_ids']] | |
| # Visualization type selection | |
| viz_type = st.radio( | |
| "Select visualization type:", | |
| ["Probe Probabilities", "BCE Loss", "Classification Labels", "Classification Weights"] | |
| ) | |
| # Display selected visualization | |
| if viz_type == "Probe Probabilities": | |
| values = datapoint['probe_probs'] | |
| vmax = 1.0 | |
| elif viz_type == "BCE Loss": | |
| values = datapoint['bce_loss'] | |
| vmax = max(values) | |
| elif viz_type == "Classification Labels": | |
| values = datapoint['classification_labels'] | |
| vmax = 1.0 | |
| else: # Classification Weights | |
| values = datapoint['classification_weights'] | |
| vmax = 1.0 | |
| st.write(f"### {viz_type}") | |
| highlighted_text = display_token_values( | |
| values, | |
| tokens, | |
| vmax=vmax, | |
| classification_labels=datapoint['classification_labels'] | |
| ) | |
| st.markdown(highlighted_text, unsafe_allow_html=True) | |
| # Navigation buttons at the bottom | |
| st.write("") # Add some space | |
| col1, col2, col3 = st.columns([1, 2, 1]) | |
| with col1: | |
| if st.button("← Previous", disabled=selected_idx == 0): | |
| st.session_state.selected_idx = selected_idx - 1 | |
| st.rerun() | |
| with col2: | |
| st.write(f"Example {selected_idx + 1} of {len(dataset)}") | |
| with col3: | |
| if st.button("Next →", disabled=selected_idx == len(dataset) - 1): | |
| st.session_state.selected_idx = selected_idx + 1 | |
| st.rerun() | |
| # Add keyboard navigation | |
| st.markdown(""" | |
| <script> | |
| document.addEventListener('keydown', function(e) { | |
| if (e.key === 'ArrowLeft') { | |
| document.querySelector('button:has-text("← Previous")').click(); | |
| } else if (e.key === 'ArrowRight') { | |
| document.querySelector('button:has-text("Next →")').click(); | |
| } | |
| }); | |
| </script> | |
| """, unsafe_allow_html=True) | |
| if __name__ == "__main__": | |
| main() |