import streamlit as st
import json
from typing import List, Dict, Any, Optional
from datasets import load_dataset
from transformers import AutoTokenizer
import torch
from huggingface_hub import HfApi
import os
# Optional Hugging Face auth token; forwarded to load_dataset for private/gated repos.
_HF_TOKEN = os.getenv("HF_TOKEN")
# Default tokenizer used to decode the cached input_ids back into display tokens.
_DEFAULT_MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B-Instruct"
# Cache subset selected by default in the dataset dropdown (if present in the repo).
_DEFAULT_DATASET_NAME = "layer_31_10104940_longfact_test"

# Page chrome — set_page_config must be the first Streamlit call in the script.
st.set_page_config(
    page_title="Cached Probe Viewer",
    layout="wide",
)
st.title("Cached Probe Viewer")
def value_to_color_hex(value, vmin=0.0, vmax=1.0, power=3.0):
    """Map a scalar to a white -> green '#rrggbb' hex color.

    Args:
        value: Scalar to colorize; clamped into [vmin, vmax].
        vmin: Value that maps to white ('#ffffff').
        vmax: Value that maps to saturated green ('#00ff00').
        power: Exponent applied after normalization; powers > 1 de-emphasize
            small values so only large values appear strongly colored.

    Returns:
        Hex color string of the form '#rrggbb'.
    """
    # Guard against a degenerate range: the original raised ZeroDivisionError
    # when vmax == vmin (display_token_values even defaults to vmin=vmax=0.0).
    if vmax <= vmin:
        return '#ffffff'
    normalized = min(1.0, max(0.0, (value - vmin) / (vmax - vmin)))
    normalized = normalized ** power
    # Green channel is fixed at 255; red and blue fade out as the value grows,
    # interpolating white -> green.
    r = int(255 * (1 - normalized))
    g = 255
    b = int(255 * (1 - normalized))
    return f'#{r:02x}{g:02x}{b:02x}'
def should_highlight_token(token: str) -> bool:
    """Return True if *token* may be color-highlighted.

    Tokens containing a newline are excluded (they would break the inline
    HTML rendering); everything else is eligible.
    """
    return "\n" not in token
def display_token_values(values: List[float], tokens: List[str], vmin=0.0, vmax=0.0, classification_labels: Optional[List[float]] = None) -> str:
    """Build an HTML string coloring each token by its associated value.

    NOTE(review): default vmax=0.0 equals vmin, which makes value_to_color_hex
    divide by zero — callers must pass an explicit vmax (main() does).
    NOTE(review): the CSS/HTML literal bodies below look stripped/sanitized;
    confirm the original markup before relying on this rendering.
    """
    # Add CSS for hover functionality
    st.markdown("""
""", unsafe_allow_html=True)
    html_content = "
"
    for idx in range(len(values)):
        displayed_token = tokens[idx]
        value = values[idx]
        # Determine if this token should be highlighted
        should_highlight = should_highlight_token(displayed_token)
        if not should_highlight:
            # Force non-highlightable tokens (newlines) to the neutral color.
            value = 0.0
        color = value_to_color_hex(value, vmin=vmin, vmax=vmax)
        # Check if this token has a positive classification label
        label_class = ""
        if classification_labels is not None and classification_labels[idx] == 1:
            label_class = " label-span"
        # '$' is stripped so Streamlit's markdown does not treat it as LaTeX math.
        html_content += f"""{displayed_token.replace("$", "")}"""
    html_content += "
"
    return html_content
def list_cached_datasets():
    """List the dataset subsets available in the probe-activations cache repo.

    The HF dataset repo stores one directory per subset, each holding parquet
    shards, e.g.:
        'layer_31_10104940_longfact_test/train-00000-of-00001.parquet'
        'layer_31_10104940_longfact_train/train-00000-of-00001.parquet'
    (plus top-level files like '.gitattributes' and 'README.md').
    The subset name is therefore the leading path component of every
    parquet file.
    """
    repo_files = HfApi().list_repo_files(
        repo_id="obalcells/probe-activations-cache",
        repo_type="dataset",
        revision="main",
    )
    # Deduplicate the leading directory of each parquet shard.
    subsets = {path.split('/')[0] for path in repo_files if path.endswith('.parquet')}
    return list(subsets)
# Load dataset from Hugging Face
@st.cache_data(show_spinner=True)
def load_hf_dataset(subset: str) -> Optional[Any]:
    """Download one subset of the probe-activations cache from the HF Hub.

    Returns the rows as a list in reversed storage order, or None when the
    download fails (the error is shown in the Streamlit UI).
    """
    try:
        with st.spinner("Loading dataset from Hugging Face..."):
            rows = load_dataset(
                'obalcells/probe-activations-cache',
                subset,
                split='train',
                token=_HF_TOKEN,
            )
            return list(reversed(list(rows)))
    except Exception as e:
        st.error(f"Failed to load dataset: {str(e)}")
        return None
# Load tokenizer
@st.cache_resource
def load_tokenizer(model_name: str) -> Optional[AutoTokenizer]:
    """Load a tokenizer for *model_name*, configured for left padding.

    The pad token is aliased to the EOS token. Returns None (after surfacing
    the error in the UI) if loading fails.
    """
    try:
        tok = AutoTokenizer.from_pretrained(model_name)
        tok.padding_side = "left"
        tok.pad_token = tok.eos_token
        return tok
    except Exception as e:
        st.error(f"Failed to load tokenizer: {str(e)}")
        return None
# Main app
def main():
    """Streamlit entry point: pick a cached dataset/example and visualize
    per-token probe values as colored text."""
    # Load dataset
    cached_dataset_names: List[str] = list_cached_datasets()
    # Fall back to the first available subset if the preferred default is absent.
    default_dataset = _DEFAULT_DATASET_NAME if _DEFAULT_DATASET_NAME in cached_dataset_names else cached_dataset_names[0]
    dataset_name = st.selectbox("Select dataset:", cached_dataset_names, index=cached_dataset_names.index(default_dataset))
    dataset = load_hf_dataset(dataset_name)
    if dataset is None:
        st.error("Failed to load dataset. Please check your internet connection and try again.")
        st.stop()
    # Model selection
    model_name = st.text_input("Model name (for tokenizer):", _DEFAULT_MODEL_NAME)
    tokenizer = load_tokenizer(model_name)
    if tokenizer is None:
        st.error("Failed to load tokenizer. Please check the model name and try again.")
        st.stop()
    # Initialize session state for selected index if not exists
    # (session state survives st.rerun(), so Prev/Next clicks persist).
    if 'selected_idx' not in st.session_state:
        st.session_state.selected_idx = 0
    # Select datapoint
    selected_idx = st.selectbox("Select datapoint:", range(len(dataset)), index=st.session_state.selected_idx)
    datapoint = dataset[selected_idx]
    # Decode tokens
    # Each datapoint carries 'input_ids' (token ids) plus per-token arrays:
    # 'probe_probs', 'bce_loss', 'classification_labels', 'classification_weights'
    # — presumably all aligned 1:1 with input_ids; verify against the cache writer.
    tokens = [tokenizer.decode(token_id) for token_id in datapoint['input_ids']]
    # Visualization type selection
    viz_type = st.radio(
        "Select visualization type:",
        ["Probe Probabilities", "BCE Loss", "Classification Labels", "Classification Weights"]
    )
    # Display selected visualization
    # Probabilities/labels/weights use a fixed [0, 1] color scale; BCE loss is
    # unbounded, so its scale tops out at the per-example maximum.
    if viz_type == "Probe Probabilities":
        values = datapoint['probe_probs']
        vmax = 1.0
    elif viz_type == "BCE Loss":
        values = datapoint['bce_loss']
        vmax = max(values)
    elif viz_type == "Classification Labels":
        values = datapoint['classification_labels']
        vmax = 1.0
    else:  # Classification Weights
        values = datapoint['classification_weights']
        vmax = 1.0
    st.write(f"### {viz_type}")
    highlighted_text = display_token_values(
        values,
        tokens,
        vmax=vmax,
        classification_labels=datapoint['classification_labels']
    )
    st.markdown(highlighted_text, unsafe_allow_html=True)
    # Navigation buttons at the bottom
    st.write("")  # Add some space
    col1, col2, col3 = st.columns([1, 2, 1])
    with col1:
        # Buttons update session state then rerun so the selectbox above
        # reflects the new index on the next script pass.
        if st.button("← Previous", disabled=selected_idx == 0):
            st.session_state.selected_idx = selected_idx - 1
            st.rerun()
    with col2:
        st.write(f"Example {selected_idx + 1} of {len(dataset)}")
    with col3:
        if st.button("Next →", disabled=selected_idx == len(dataset) - 1):
            st.session_state.selected_idx = selected_idx + 1
            st.rerun()
    # Add keyboard navigation
    # NOTE(review): the markdown body below appears stripped/sanitized — the
    # original presumably embedded JS/CSS for keyboard shortcuts; confirm.
    st.markdown("""
""", unsafe_allow_html=True)
if __name__ == "__main__":
    # Script entry point when executed via `streamlit run`.
    main()