Spaces:
Running
Running
add buttons
Browse files- src/streamlit_app.py +66 -7
src/streamlit_app.py
CHANGED
|
@@ -1,7 +1,8 @@
|
|
| 1 |
"""
|
| 2 |
-
Streamlit app for interactive
|
| 3 |
"""
|
| 4 |
|
|
|
|
| 5 |
import gc
|
| 6 |
import os
|
| 7 |
import sys
|
|
@@ -19,13 +20,33 @@ from torch import nn
|
|
| 19 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 20 |
|
| 21 |
# Add current directory to path for JCBScope_utils
|
| 22 |
-
|
|
|
|
| 23 |
import JCBScope_utils
|
| 24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
# Device configuration: use CPU to match notebook and avoid device_map complexity
|
| 26 |
device = torch.device("cpu")
|
| 27 |
|
| 28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
@st.cache_resource
|
| 30 |
def load_model(model_name: str = "meta-llama/Llama-3.2-1B"):
|
| 31 |
"""Load and cache the tokenizer and model."""
|
|
@@ -347,13 +368,50 @@ def render_attribution_barplot(result, log_color: bool = False, cmap_name: str =
|
|
| 347 |
|
| 348 |
|
| 349 |
def main():
|
| 350 |
-
st.set_page_config(page_title="Jacobian
|
| 351 |
-
st.title("π Jacobian
|
|
|
|
| 352 |
st.markdown(
|
| 353 |
-
"
|
| 354 |
-
"
|
| 355 |
-
|
| 356 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 357 |
|
| 358 |
model_choice = st.selectbox(
|
| 359 |
"Model",
|
|
@@ -374,6 +432,7 @@ def main():
|
|
| 374 |
options=["Semantic Scope", "Temperature Scope"],
|
| 375 |
index=0,
|
| 376 |
horizontal=True,
|
|
|
|
| 377 |
help="Semantic Scope: attribute toward a target token. Temperature Scope: use hidden-state norm.",
|
| 378 |
)
|
| 379 |
mode = "Semantic" if attribution_type == "Semantic Scope" else "Temperature"
|
|
|
|
| 1 |
"""
|
| 2 |
+
Streamlit app for interactive Semantic and Temperature Scope visualizations.
|
| 3 |
"""
|
| 4 |
|
| 5 |
+
import base64
|
| 6 |
import gc
|
| 7 |
import os
|
| 8 |
import sys
|
|
|
|
| 20 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 21 |
|
| 22 |
# Add current directory to path for JCBScope_utils
|
| 23 |
+
_APP_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 24 |
+
sys.path.insert(0, _APP_DIR)
|
| 25 |
import JCBScope_utils
|
| 26 |
|
| 27 |
+
DESIGN_DIR = os.path.join(_APP_DIR, "design")
|
| 28 |
+
if not os.path.exists(DESIGN_DIR):
|
| 29 |
+
DESIGN_DIR = os.path.join(os.path.dirname(_APP_DIR), "design")
|
| 30 |
+
|
| 31 |
# Device configuration: use CPU to match notebook and avoid device_map complexity
|
| 32 |
device = torch.device("cpu")
|
| 33 |
|
| 34 |
|
| 35 |
+
@st.cache_data
|
| 36 |
+
def _load_svg(path: str) -> str | None:
|
| 37 |
+
"""Load SVG file content; returns None if not found."""
|
| 38 |
+
if not os.path.exists(path):
|
| 39 |
+
return None
|
| 40 |
+
with open(path, encoding="utf-8") as f:
|
| 41 |
+
return f.read()
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def _render_svg_html(svg_content: str, max_width: int = 140) -> str:
|
| 45 |
+
"""Return HTML to render SVG via base64 (reliable in Streamlit)."""
|
| 46 |
+
b64 = base64.b64encode(svg_content.encode("utf-8")).decode("utf-8")
|
| 47 |
+
return f'<img src="data:image/svg+xml;base64,{b64}" style="max-width:{max_width}px;height:auto;"/>'
|
| 48 |
+
|
| 49 |
+
|
| 50 |
@st.cache_resource
|
| 51 |
def load_model(model_name: str = "meta-llama/Llama-3.2-1B"):
|
| 52 |
"""Load and cache the tokenizer and model."""
|
|
|
|
| 368 |
|
| 369 |
|
| 370 |
def main():
|
| 371 |
+
st.set_page_config(page_title="Jacobian Scopes Demo", page_icon="π", layout="centered")
|
| 372 |
+
st.title("π Jacobian Scopes Demo")
|
| 373 |
+
# Keep scope columns on one line (Streamlit stacks them below ~640px by default)
|
| 374 |
st.markdown(
|
| 375 |
+
'<style>div[data-testid="stHorizontalBlock"]{flex-wrap:nowrap!important}'
|
| 376 |
+
'[data-testid="column"]{min-width:120px!important}</style>',
|
| 377 |
+
unsafe_allow_html=True,
|
| 378 |
)
|
| 379 |
+
scope_col1, scope_div1, scope_col2, scope_div2, scope_col3 = st.columns([1, 0.02, 1, 0.02, 1])
|
| 380 |
+
semantic_svg = _load_svg(os.path.join(DESIGN_DIR, "semantic_scope_button.svg"))
|
| 381 |
+
temp_svg = _load_svg(os.path.join(DESIGN_DIR, "temperature_scope_button.svg"))
|
| 382 |
+
fisher_svg = _load_svg(os.path.join(DESIGN_DIR, "fisher_scope_button.svg"))
|
| 383 |
+
with scope_col1:
|
| 384 |
+
if semantic_svg:
|
| 385 |
+
st.markdown(_render_svg_html(semantic_svg), unsafe_allow_html=True)
|
| 386 |
+
st.markdown(
|
| 387 |
+
"**Semantic Scope** β explains the predicted logit for a specific target token. "
|
| 388 |
+
"Enter your input passage along with a target token."
|
| 389 |
+
)
|
| 390 |
+
with scope_div1:
|
| 391 |
+
st.markdown(
|
| 392 |
+
'<div style="border-left: 5px solid #888; min-height: 200px; margin: 0;"></div>',
|
| 393 |
+
# '<div style="border-left: 5px solid steelblue; min-height: 160px; margin: 0;"></div>',
|
| 394 |
+
unsafe_allow_html=True,
|
| 395 |
+
)
|
| 396 |
+
with scope_col2:
|
| 397 |
+
if temp_svg:
|
| 398 |
+
st.markdown(_render_svg_html(temp_svg), unsafe_allow_html=True)
|
| 399 |
+
st.markdown(
|
| 400 |
+
"**Temperature Scope** β explains the confidence (effective inverse temperature) of the predictive distribution. "
|
| 401 |
+
"Target token not required."
|
| 402 |
+
)
|
| 403 |
+
with scope_div2:
|
| 404 |
+
st.markdown(
|
| 405 |
+
'<div style="border-left: 5px solid #888; min-height: 200px; margin: 0;"></div>',
|
| 406 |
+
unsafe_allow_html=True,
|
| 407 |
+
)
|
| 408 |
+
with scope_col3:
|
| 409 |
+
if fisher_svg:
|
| 410 |
+
st.markdown(_render_svg_html(fisher_svg), unsafe_allow_html=True)
|
| 411 |
+
st.markdown(
|
| 412 |
+
"**Fisher Scope** β A more refined attribution tool that explains the overall predictive distribution, motivated by information geometry. "
|
| 413 |
+
"Not shown in this demo due to limited compute."
|
| 414 |
+
)
|
| 415 |
|
| 416 |
model_choice = st.selectbox(
|
| 417 |
"Model",
|
|
|
|
| 432 |
options=["Semantic Scope", "Temperature Scope"],
|
| 433 |
index=0,
|
| 434 |
horizontal=True,
|
| 435 |
+
key="attribution_type",
|
| 436 |
help="Semantic Scope: attribute toward a target token. Temperature Scope: use hidden-state norm.",
|
| 437 |
)
|
| 438 |
mode = "Semantic" if attribution_type == "Semantic Scope" else "Temperature"
|