Typony committed on
Commit
cb30405
Β·
verified Β·
1 Parent(s): 0d820e3

add buttons

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +66 -7
src/streamlit_app.py CHANGED
@@ -1,7 +1,8 @@
1
  """
2
- Streamlit app for interactive Jacobian and Temperature Scope visualizations.
3
  """
4
 
 
5
  import gc
6
  import os
7
  import sys
@@ -19,13 +20,33 @@ from torch import nn
19
  from transformers import AutoModelForCausalLM, AutoTokenizer
20
 
21
  # Add current directory to path for JCBScope_utils
22
- sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
 
23
  import JCBScope_utils
24
 
 
 
 
 
25
  # Device configuration: use CPU to match notebook and avoid device_map complexity
26
  device = torch.device("cpu")
27
 
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  @st.cache_resource
30
  def load_model(model_name: str = "meta-llama/Llama-3.2-1B"):
31
  """Load and cache the tokenizer and model."""
@@ -347,13 +368,50 @@ def render_attribution_barplot(result, log_color: bool = False, cmap_name: str =
347
 
348
 
349
  def main():
350
- st.set_page_config(page_title="Jacobian Scope Demo", page_icon="πŸ”", layout="centered")
351
- st.title("πŸ” Jacobian & Temperature Scopes")
 
352
  st.markdown(
353
- "**Semantic Scope** explains the predicted logit for a specific target token: enter your input "
354
- "passage along with a target token.\n\n"
355
- "**Temperature Scope** explains the overall predictive distribution and does not require a target."
356
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
357
 
358
  model_choice = st.selectbox(
359
  "Model",
@@ -374,6 +432,7 @@ def main():
374
  options=["Semantic Scope", "Temperature Scope"],
375
  index=0,
376
  horizontal=True,
 
377
  help="Semantic Scope: attribute toward a target token. Temperature Scope: use hidden-state norm.",
378
  )
379
  mode = "Semantic" if attribution_type == "Semantic Scope" else "Temperature"
 
1
  """
2
+ Streamlit app for interactive Semantic and Temperature Scope visualizations.
3
  """
4
 
5
+ import base64
6
  import gc
7
  import os
8
  import sys
 
20
  from transformers import AutoModelForCausalLM, AutoTokenizer
21
 
22
# Add current directory to path for JCBScope_utils
# (must run before the JCBScope_utils import below — order matters).
_APP_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, _APP_DIR)
import JCBScope_utils

# Directory holding the SVG button artwork. Falls back to a "design"
# directory one level above the app dir when the sibling one is absent.
# NOTE(review): assumes at least one of the two locations exists — verify
# against the deployed layout.
DESIGN_DIR = os.path.join(_APP_DIR, "design")
if not os.path.exists(DESIGN_DIR):
    DESIGN_DIR = os.path.join(os.path.dirname(_APP_DIR), "design")

# Device configuration: use CPU to match notebook and avoid device_map complexity
device = torch.device("cpu")
33
 
34
 
35
@st.cache_data
def _load_svg(path: str) -> str | None:
    """Read the SVG file at *path* and return its text.

    Returns None when the file does not exist, so callers can skip
    rendering a missing asset. The result is cached by Streamlit per path.
    """
    if os.path.exists(path):
        with open(path, encoding="utf-8") as handle:
            return handle.read()
    return None
42
+
43
+
44
+ def _render_svg_html(svg_content: str, max_width: int = 140) -> str:
45
+ """Return HTML to render SVG via base64 (reliable in Streamlit)."""
46
+ b64 = base64.b64encode(svg_content.encode("utf-8")).decode("utf-8")
47
+ return f'<img src="data:image/svg+xml;base64,{b64}" style="max-width:{max_width}px;height:auto;"/>'
48
+
49
+
50
  @st.cache_resource
51
  def load_model(model_name: str = "meta-llama/Llama-3.2-1B"):
52
  """Load and cache the tokenizer and model."""
 
368
 
369
 
370
  def main():
371
+ st.set_page_config(page_title="Jacobian Scopes Demo", page_icon="πŸ”", layout="centered")
372
+ st.title("πŸ” Jacobian Scopes Demo")
373
+ # Keep scope columns on one line (Streamlit stacks them below ~640px by default)
374
  st.markdown(
375
+ '<style>div[data-testid="stHorizontalBlock"]{flex-wrap:nowrap!important}'
376
+ '[data-testid="column"]{min-width:120px!important}</style>',
377
+ unsafe_allow_html=True,
378
  )
379
+ scope_col1, scope_div1, scope_col2, scope_div2, scope_col3 = st.columns([1, 0.02, 1, 0.02, 1])
380
+ semantic_svg = _load_svg(os.path.join(DESIGN_DIR, "semantic_scope_button.svg"))
381
+ temp_svg = _load_svg(os.path.join(DESIGN_DIR, "temperature_scope_button.svg"))
382
+ fisher_svg = _load_svg(os.path.join(DESIGN_DIR, "fisher_scope_button.svg"))
383
+ with scope_col1:
384
+ if semantic_svg:
385
+ st.markdown(_render_svg_html(semantic_svg), unsafe_allow_html=True)
386
+ st.markdown(
387
+ "**Semantic Scope** β€” explains the predicted logit for a specific target token. "
388
+ "Enter your input passage along with a target token."
389
+ )
390
+ with scope_div1:
391
+ st.markdown(
392
+ '<div style="border-left: 5px solid #888; min-height: 200px; margin: 0;"></div>',
393
+ # '<div style="border-left: 5px solid steelblue; min-height: 160px; margin: 0;"></div>',
394
+ unsafe_allow_html=True,
395
+ )
396
+ with scope_col2:
397
+ if temp_svg:
398
+ st.markdown(_render_svg_html(temp_svg), unsafe_allow_html=True)
399
+ st.markdown(
400
+ "**Temperature Scope** β€” explains the confidence (effective inverse temperature) of the predictive distribution. "
401
+ "Target token not required."
402
+ )
403
+ with scope_div2:
404
+ st.markdown(
405
+ '<div style="border-left: 5px solid #888; min-height: 200px; margin: 0;"></div>',
406
+ unsafe_allow_html=True,
407
+ )
408
+ with scope_col3:
409
+ if fisher_svg:
410
+ st.markdown(_render_svg_html(fisher_svg), unsafe_allow_html=True)
411
+ st.markdown(
412
+ "**Fisher Scope** β€” A more refined attribution tool that explains the overall predictive distribution, motivated by information geometry. "
413
+ "Not shown in this demo due to limited compute."
414
+ )
415
 
416
  model_choice = st.selectbox(
417
  "Model",
 
432
  options=["Semantic Scope", "Temperature Scope"],
433
  index=0,
434
  horizontal=True,
435
+ key="attribution_type",
436
  help="Semantic Scope: attribute toward a target token. Temperature Scope: use hidden-state norm.",
437
  )
438
  mode = "Semantic" if attribution_type == "Semantic Scope" else "Temperature"