File size: 5,086 Bytes
678f3d4 2694f4f 0c0f584 2694f4f 3fbdd2b 2694f4f 678f3d4 f90485e 678f3d4 bdf7fdf 678f3d4 e15c23c 00d17a9 e15c23c 00d17a9 e15c23c 678f3d4 3713d67 e15c23c 00d17a9 678f3d4 3713d67 678f3d4 3efaf3b 678f3d4 e61e9cc 80451ed 0c0f584 3fbdd2b 2271f05 678f3d4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 |
import base64
import html
import os
import re
from pathlib import Path

import numpy as np
import pandas as pd
import streamlit as st
# Compiled once at import time; located anywhere in the name via ``search``.
_FILENAME_PATTERN = re.compile(r"nf(\d+)_nh(\d+)_ni(\d+)_instance_(\d+)")


def parse_filename(fname: str):
    """Extract run metadata from a GIF file name.

    The name must contain a token run of the form
    ``nf<features>_nh<hidden>_ni<count>_instance_<index>``; any text before
    or after that run is ignored.

    Args:
        fname: File name (not a full path is required, only the name is
            searched) to parse.

    Returns:
        ``None`` when the name does not contain the expected token run.
        Otherwise a dict with integer values under these keys:

        - ``"features"``: the number after ``nf``.
        - ``"hidden_dim"``: the number after ``nh``.
        - ``"sparsity"``: the instance index plus one (1-indexed for display).
        - ``"sparsity_out_of"``: the number after ``ni``.

    Example:
        >>> parse_filename("5panel_nf7_nh4_ni10_instance_0_feature.gif")
        {'features': 7, 'hidden_dim': 4, 'sparsity': 1, 'sparsity_out_of': 10}
    """
    m = _FILENAME_PATTERN.search(fname)
    if m is None:
        return None  # ignore non-matching files

    nf, nh, ni, instance = (int(group) for group in m.groups())
    return {
        "features": nf,
        "hidden_dim": nh,
        "sparsity": instance + 1,  # 1-indexed display
        "sparsity_out_of": ni,
    }
st.set_page_config(layout="wide")

# ----------------------------
# GIF Browser
# ----------------------------
# Lists every GIF in hf_gifs whose name parse_filename understands, lets the
# user pick one (nf, nh, sparsity, out_of) combination, and shows the matching
# GIFs with a caption.
st.header("GIF Browser")

# Data folders are resolved relative to the directory containing this script.
SRC_DIR = Path(__file__).resolve().parent
GIF_DIR = SRC_DIR / "hf_gifs"

if not GIF_DIR.exists():
    st.error(f"Could not find hf_gifs folder at: {GIF_DIR}")
else:
    # Collect gif metadata; files whose names do not parse are skipped.
    rows = []
    for p in sorted(GIF_DIR.glob("*.gif")):
        meta = parse_filename(p.name)
        if meta is None:
            continue
        rows.append(
            {
                "path": str(p),
                "filename": p.name,
                **meta,
            }
        )

    if not rows:
        st.warning(f"No matching .gif files found in {GIF_DIR}")
    else:
        gifs_df = pd.DataFrame(rows)

        # Single dropdown selecting a 4-tuple (nf, nh, sparsity, out_of)
        tuples_df = (
            gifs_df[["features", "hidden_dim", "sparsity", "sparsity_out_of"]]
            .drop_duplicates()
            .sort_values(["features", "hidden_dim", "sparsity_out_of", "sparsity"])
            .reset_index(drop=True)
        )

        # Readable labels, in the same order as the rows of tuples_df so a
        # label's position can be used to look its tuple back up.
        labels = [
            f"nf={r.features}, nh={r.hidden_dim}, sparsity={r.sparsity}/{r.sparsity_out_of}"
            for r in tuples_df.itertuples(index=False)
        ]
        selected_label = st.selectbox("Select (nf, nh, sparsity/out_of)", labels)

        # Decode the selection back into its four integer components.
        idx = labels.index(selected_label)
        sel = tuples_df.iloc[idx]
        selected_features = int(sel["features"])
        selected_hidden = int(sel["hidden_dim"])
        selected_sparsity = int(sel["sparsity"])
        selected_outof = int(sel["sparsity_out_of"])

        # Filter on full 4-tuple
        filtered = gifs_df[
            (gifs_df["features"] == selected_features)
            & (gifs_df["hidden_dim"] == selected_hidden)
            & (gifs_df["sparsity"] == selected_sparsity)
            & (gifs_df["sparsity_out_of"] == selected_outof)
        ].copy()

        if filtered.empty:
            st.info("No GIFs found for that (nf, nh, sparsity/out_of) combination.")
        else:
            filtered = filtered.sort_values(["filename"])
            st.caption(f"Showing {len(filtered)} GIF(s) from: {GIF_DIR}")

            # One <figure> per GIF, emitted one after another.
            items = filtered.to_dict(orient="records")
            for item in items:
                caption = f"{item['sparsity']} out of {item['sparsity_out_of']} for {item['features']} features and {item['hidden_dim']} hidden dimensions"
                # The GIF is embedded as a base64 data URI inside an <img> tag.
                gif_bytes = Path(item["path"]).read_bytes()
                b64 = base64.b64encode(gif_bytes).decode("ascii")
                # Assembled line by line so the markup carries no source-code
                # indentation into the Markdown body.
                figure_html = "\n".join(
                    [
                        '<figure style="margin:0;">',
                        f'<img src="data:image/gif;base64,{b64}" style="max-width:100%; height:auto;" />',
                        f'<figcaption style="font-size:0.9em; color: #666;">{caption}</figcaption>',
                        "</figure>",
                    ]
                )
                st.markdown(figure_html, unsafe_allow_html=True)
# ----------------------------
# Phase Tracer Viewer
# ----------------------------
# Lets the user pick one GIF from the phase_diagram_tracers folder by file
# name and shows it with the file name as its caption.
st.subheader("Phase Tracer Viewer")

PHASE_TRACER_DIR = SRC_DIR / "phase_diagram_tracers"

if not PHASE_TRACER_DIR.exists():
    # The message names the folder that is actually looked up.
    st.error(f"Could not find phase_diagram_tracers folder at: {PHASE_TRACER_DIR}")
else:
    phase_gifs = sorted(PHASE_TRACER_DIR.glob("*.gif"))
    if not phase_gifs:
        st.warning(f"No .gif files found in {PHASE_TRACER_DIR}")
    else:
        # Labels share the order of phase_gifs, so the chosen label's position
        # identifies the chosen path.
        phase_labels = [p.name for p in phase_gifs]
        selected_phase = st.selectbox("Select a phase tracer gif", phase_labels)
        chosen_path = phase_gifs[phase_labels.index(selected_phase)]
        st.caption(f"Showing: {chosen_path.name}")

        # The GIF is embedded as a base64 data URI inside an <img> tag.
        gif_bytes = chosen_path.read_bytes()
        b64 = base64.b64encode(gif_bytes).decode("ascii")
        # The file name goes into raw HTML, so escape characters such as
        # "<" and "&" to keep them from being interpreted as markup.
        safe_name = html.escape(chosen_path.name)
        # Assembled line by line so the markup carries no source-code
        # indentation into the Markdown body.
        figure_html = "\n".join(
            [
                '<figure style="margin:0;">',
                f'<img src="data:image/gif;base64,{b64}" style="max-width:100%; height:auto;" />',
                f'<figcaption style="font-size:0.9em; color: #666;">{safe_name}</figcaption>',
                "</figure>",
            ]
        )
        st.markdown(figure_html, unsafe_allow_html=True)

st.divider()
|