File size: 5,086 Bytes
678f3d4
 
 
2694f4f
0c0f584
2694f4f
 
3fbdd2b
2694f4f
 
678f3d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f90485e
 
678f3d4
 
bdf7fdf
678f3d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e15c23c
 
 
 
 
 
00d17a9
 
e15c23c
 
 
 
 
00d17a9
e15c23c
 
 
 
 
 
 
 
 
 
 
 
678f3d4
 
 
3713d67
e15c23c
 
 
00d17a9
678f3d4
 
 
3713d67
678f3d4
 
 
3efaf3b
678f3d4
e61e9cc
80451ed
0c0f584
3fbdd2b
 
 
 
 
 
 
 
 
 
 
2271f05
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
678f3d4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import base64
import html
import os
import re
from pathlib import Path

import numpy as np
import pandas as pd
import streamlit as st


def parse_filename(fname: str):
    """
    Pull the experiment parameters out of a GIF file name.

    The name must contain the fragment ``nf<F>_nh<H>_ni<N>_instance_<I>``
    somewhere (prefixes and suffixes are ignored), where
      - F is the number of features
      - H is the number of hidden states
      - N is the total number of sparsity instances
      - I is the 0-based instance index

    Returns a dict with keys ``features``, ``hidden_dim``, ``sparsity``
    (the instance index shifted to be 1-based) and ``sparsity_out_of``,
    or ``None`` when the name does not contain the fragment.

    Example:
      5panel_nf7_nh4_ni10_instance_0_feature.gif
      -> features=7, hidden_dim=4, sparsity=1, sparsity_out_of=10
    """
    match = re.search(r"nf(\d+)_nh(\d+)_ni(\d+)_instance_(\d+)", fname)
    if match is None:
        # Not one of ours; the caller skips files that yield None.
        return None

    n_features, n_hidden, n_instances, instance_idx = map(int, match.groups())

    parsed = {}
    parsed["features"] = n_features
    parsed["hidden_dim"] = n_hidden
    parsed["sparsity"] = instance_idx + 1  # shown 1-indexed
    parsed["sparsity_out_of"] = n_instances
    return parsed


st.set_page_config(layout="wide")

# ----------------------------
# GIF Browser
# ----------------------------
# Scans hf_gifs/ for GIFs whose names parse_filename() understands, offers
# one dropdown entry per distinct (nf, nh, sparsity, out_of) tuple, and
# renders every GIF that matches the selected tuple.
st.header("GIF Browser")

# hf_gifs/ is looked up inside the directory that contains this script.
SRC_DIR = Path(__file__).resolve().parent
GIF_DIR = SRC_DIR / "hf_gifs"

if not GIF_DIR.exists():
    st.error(f"Could not find hf_gifs folder at: {GIF_DIR}")
else:
    # Collect metadata for each GIF; parse_filename() returns None for names
    # that do not follow the expected pattern, and those files are skipped.
    rows = []
    for p in sorted(GIF_DIR.glob("*.gif")):
        meta = parse_filename(p.name)
        if meta is None:
            continue
        rows.append(
            {
                "path": str(p),
                "filename": p.name,
                **meta,
            }
        )

    if not rows:
        st.warning(f"No matching .gif files found in {GIF_DIR}")
    else:
        gifs_df = pd.DataFrame(rows)

        # One dropdown entry per distinct (nf, nh, sparsity, out_of) tuple,
        # ordered by features, hidden size, then sparsity scale and index.
        tuples_df = (
            gifs_df[["features", "hidden_dim", "sparsity", "sparsity_out_of"]]
            .drop_duplicates()
            .sort_values(["features", "hidden_dim", "sparsity_out_of", "sparsity"])
            .reset_index(drop=True)
        )

        # Human-readable labels, in the same order as the rows of tuples_df.
        labels = [
            f"nf={r.features}, nh={r.hidden_dim}, sparsity={r.sparsity}/{r.sparsity_out_of}"
            for r in tuples_df.itertuples(index=False)
        ]

        selected_label = st.selectbox("Select (nf, nh, sparsity/out_of)", labels)

        # Map the chosen label back to its row: labels and tuples_df share
        # the same positional order.
        idx = labels.index(selected_label)
        sel = tuples_df.iloc[idx]

        selected_features = int(sel["features"])
        selected_hidden = int(sel["hidden_dim"])
        selected_sparsity = int(sel["sparsity"])
        selected_outof = int(sel["sparsity_out_of"])

        # Keep only the GIFs that match the full 4-tuple.
        filtered = gifs_df[
            (gifs_df["features"] == selected_features)
            & (gifs_df["hidden_dim"] == selected_hidden)
            & (gifs_df["sparsity"] == selected_sparsity)
            & (gifs_df["sparsity_out_of"] == selected_outof)
        ].copy()

        if filtered.empty:
            # Defensive: every dropdown entry is derived from gifs_df itself,
            # so a selection should always match at least one row.
            st.info("No GIFs found for that (nf, nh, sparsity/out_of) combination.")
        else:
            filtered = filtered.sort_values(["filename"])

            st.caption(f"Showing {len(filtered)} GIF(s) from: {GIF_DIR}")

            # Render the matching GIFs one below the other. Each file is
            # embedded as a base64 data URI inside a raw-HTML <figure>, which
            # is why unsafe_allow_html is required.
            items = filtered.to_dict(orient="records")
            for item in items:
                # The caption is built only from integers parsed out of the
                # file name, so it needs no HTML escaping.
                caption = (
                    f"{item['sparsity']} out of {item['sparsity_out_of']} "
                    f"for {item['features']} features "
                    f"and {item['hidden_dim']} hidden dimensions"
                )
                gif_bytes = Path(item["path"]).read_bytes()
                b64 = base64.b64encode(gif_bytes).decode("ascii")

                st.markdown(
                    f"""
                    <figure style="margin:0;">
                      <img src="data:image/gif;base64,{b64}" style="max-width:100%; height:auto;" />
                      <figcaption style="font-size:0.9em; color: #666;">{caption}</figcaption>
                    </figure>
                    """,
                    unsafe_allow_html=True,
                )

# ----------------------------
# Phase Tracer Viewer
# ----------------------------
# Lets the user pick any GIF from phase_diagram_tracers/ by file name and
# displays it.
st.subheader("Phase Tracer Viewer")

# phase_diagram_tracers/ is looked up inside the script's directory.
PHASE_TRACER_DIR = SRC_DIR / "phase_diagram_tracers"

if not PHASE_TRACER_DIR.exists():
    # Name the folder that is actually looked up, so the message matches
    # the path printed after it.
    st.error(f"Could not find phase_diagram_tracers folder at: {PHASE_TRACER_DIR}")
else:
    phase_gifs = sorted(PHASE_TRACER_DIR.glob("*.gif"))

    if not phase_gifs:
        st.warning(f"No .gif files found in {PHASE_TRACER_DIR}")
    else:
        # Labels are the bare file names, in the same order as phase_gifs.
        phase_labels = [p.name for p in phase_gifs]
        selected_phase = st.selectbox("Select a phase tracer gif", phase_labels)

        chosen_path = phase_gifs[phase_labels.index(selected_phase)]
        st.caption(f"Showing: {chosen_path.name}")

        # Embed the GIF as a base64 data URI inside a raw-HTML <figure>.
        gif_bytes = chosen_path.read_bytes()
        b64 = base64.b64encode(gif_bytes).decode("ascii")

        # The file name goes into raw HTML, so escape it: names containing
        # &, < or > would otherwise be interpreted as markup.
        safe_name = html.escape(chosen_path.name)

        st.markdown(
            f"""
            <figure style="margin:0;">
              <img src="data:image/gif;base64,{b64}" style="max-width:100%; height:auto;" />
              <figcaption style="font-size:0.9em; color: #666;">{safe_name}</figcaption>
            </figure>
            """,
            unsafe_allow_html=True,
        )

st.divider()