|
|
import base64
import html
import os
import re
from pathlib import Path

import numpy as np
import pandas as pd
import streamlit as st
|
|
|
|
|
|
|
|
def parse_filename(fname: str):
    """Parse the experiment parameters encoded in a GIF filename.

    The name must contain a fragment of the form
    ``nf<F>_nh<H>_ni<N>_instance_<I>``; text before or after it is ignored.

    Returns a dict with:
        features        -- F, the number of features (nf)
        hidden_dim      -- H, the number of hidden states (nh)
        sparsity        -- I + 1, the 1-based sparsity index (log scale)
        sparsity_out_of -- N, the number of sparsity levels (ni)
    or ``None`` if the fragment is absent.

    Example:
        5panel_nf7_nh4_ni10_instance_0_feature.gif
        -> features=7, hidden=4, sparsity=(1, 10)
    """
    match = re.search(r"nf(\d+)_nh(\d+)_ni(\d+)_instance_(\d+)", fname)
    if match is None:
        return None

    n_features, n_hidden, n_levels, instance_idx = map(int, match.groups())
    return {
        "features": n_features,
        "hidden_dim": n_hidden,
        # The instance index is 0-based; expose it as 1-based.
        "sparsity": instance_idx + 1,
        "sparsity_out_of": n_levels,
    }
|
|
|
|
|
|
|
|
# Page-level setup: Streamlit's wide layout, then the app title.
_PAGE_TITLE = "GIF Browser"

st.set_page_config(layout="wide")
st.header(_PAGE_TITLE)
|
|
|
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
# Asset folders are resolved relative to this script, not the current
# working directory.
SRC_DIR = Path(__file__).resolve().parent
GIF_DIR = SRC_DIR / "hf_gifs"


def _collect_gif_rows(gif_dir):
    """Return one record per ``*.gif`` in *gif_dir* whose name parses.

    Each record contains ``path`` (str), ``filename`` and the fields produced
    by ``parse_filename``. Files are visited in sorted order; names that
    ``parse_filename`` rejects are skipped.
    """
    rows = []
    for p in sorted(gif_dir.glob("*.gif")):
        meta = parse_filename(p.name)
        if meta is None:
            continue
        rows.append({"path": str(p), "filename": p.name, **meta})
    return rows


def _gif_figure_html(gif_path, caption):
    """Return an HTML <figure> embedding the GIF at *gif_path* as a base64 data URI.

    *caption* is HTML-escaped because it is interpolated into raw HTML.
    NOTE(review): the GIF is inlined rather than passed to ``st.image`` —
    presumably so the animation plays; confirm before changing.
    """
    b64 = base64.b64encode(Path(gif_path).read_bytes()).decode("ascii")
    return f"""
<figure style="margin:0;">
<img src="data:image/gif;base64,{b64}" style="max-width:100%; height:auto;" />
<figcaption style="font-size:0.9em; color: #666;">{html.escape(caption)}</figcaption>
</figure>
"""


if not GIF_DIR.exists():
    st.error(f"Could not find hf_gifs folder at: {GIF_DIR}")
else:
    rows = _collect_gif_rows(GIF_DIR)

    if not rows:
        st.warning(f"No matching .gif files found in {GIF_DIR}")
    else:
        gifs_df = pd.DataFrame(rows)

        # Unique configurations; sparsity levels of one (nf, nh, ni) group are
        # adjacent and ascending.
        tuples_df = (
            gifs_df[["features", "hidden_dim", "sparsity", "sparsity_out_of"]]
            .drop_duplicates()
            .sort_values(["features", "hidden_dim", "sparsity_out_of", "sparsity"])
            .reset_index(drop=True)
        )

        labels = [
            f"nf={r.features}, nh={r.hidden_dim}, sparsity={r.sparsity}/{r.sparsity_out_of}"
            for r in tuples_df.itertuples(index=False)
        ]

        # Options are row positions; format_func shows the human-readable
        # label, so the selection maps straight back to its tuples_df row.
        idx = st.selectbox(
            "Select (nf, nh, sparsity/out_of)",
            list(range(len(labels))),
            format_func=lambda i: labels[i],
        )
        sel = tuples_df.iloc[idx]

        selected_features = int(sel["features"])
        selected_hidden = int(sel["hidden_dim"])
        selected_sparsity = int(sel["sparsity"])
        selected_outof = int(sel["sparsity_out_of"])

        filtered = gifs_df[
            (gifs_df["features"] == selected_features)
            & (gifs_df["hidden_dim"] == selected_hidden)
            & (gifs_df["sparsity"] == selected_sparsity)
            & (gifs_df["sparsity_out_of"] == selected_outof)
        ]

        if filtered.empty:
            # Defensive only: every option is derived from gifs_df, so a match
            # should always exist.
            st.info("No GIFs found for that (nf, nh, sparsity) combination.")
        else:
            filtered = filtered.sort_values("filename")
            st.caption(f"Showing {len(filtered)} GIF(s) from: {GIF_DIR}")

            for item in filtered.to_dict(orient="records"):
                caption = (
                    f"{item['sparsity']} out of {item['sparsity_out_of']} for "
                    f"{item['features']} features and {item['hidden_dim']} hidden dimensions"
                )
                st.markdown(
                    _gif_figure_html(item["path"], caption),
                    unsafe_allow_html=True,
                )
|
|
|
|
|
st.subheader("Phase Tracer Viewer")

PHASE_TRACER_DIR = SRC_DIR / "phase_diagram_tracers"

if not PHASE_TRACER_DIR.exists():
    # Take the folder name from the path itself so the message always names
    # the directory that was actually searched.
    st.error(f"Could not find {PHASE_TRACER_DIR.name} folder at: {PHASE_TRACER_DIR}")
else:
    phase_gifs = sorted(PHASE_TRACER_DIR.glob("*.gif"))

    if not phase_gifs:
        st.warning(f"No .gif files found in {PHASE_TRACER_DIR}")
    else:
        # The Path objects are the options; only their file names are shown,
        # and selectbox returns the chosen Path directly.
        chosen_path = st.selectbox(
            "Select a phase tracer gif",
            phase_gifs,
            format_func=lambda p: p.name,
        )
        st.caption(f"Showing: {chosen_path.name}")

        b64 = base64.b64encode(chosen_path.read_bytes()).decode("ascii")
        # The file name is escaped because it is interpolated into raw HTML.
        st.markdown(
            f"""
<figure style="margin:0;">
<img src="data:image/gif;base64,{b64}" style="max-width:100%; height:auto;" />
<figcaption style="font-size:0.9em; color: #666;">{html.escape(chosen_path.name)}</figcaption>
</figure>
""",
            unsafe_allow_html=True,
        )

st.divider()
|
|
|