Marlin Lee
UI redesign: consistent design system, card layout, unified colors and typography
4a5ed93
"""
Image loading, heatmap rendering, and HTML builder helpers.
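None of these functions touch Bokeh widgets, so they can be called from worker
threads or tested without building any UI. They are not strictly pure, though:
several read the active dataset (``active_ds()``) or CLI ``args``, and the
thumbnail helpers share a module-level cache.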
"""
import base64
import io
import os
from concurrent.futures import ThreadPoolExecutor
import cv2
import matplotlib
matplotlib.use('Agg')
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from .args import args
from .state import active_ds
# ---------- Thread pool for parallel image loading ----------
# Shared by the HTML builders below to fetch and encode thumbnails concurrently.
_IMG_POOL_WORKERS = 8
_img_pool = ThreadPoolExecutor(max_workers=_IMG_POOL_WORKERS)
# ---------- Constants ----------
# Default thumbnail edge length in pixels, taken from the parsed command-line args.
THUMB = args.thumb_size
# Jet colormap with alpha ramp so low-activation regions are transparent.
def _make_alpha_jet() -> mcolors.LinearSegmentedColormap:
    """Return a copy of Matplotlib's 'jet' colormap whose alpha ramps 0 -> 1.

    Low values map to fully transparent colours and high values to opaque
    ones, so an overlay only hides the underlying image where activation is
    strong.

    ``plt.get_cmap`` is used rather than ``plt.cm.get_cmap``: the latter
    (``matplotlib.cm.get_cmap``) was deprecated in Matplotlib 3.7 and removed
    in 3.9, which would make this module fail at import time.
    """
    base = plt.get_cmap('jet')
    colors = base(np.arange(base.N))                  # (N, 4) RGBA lookup table
    colors[:, -1] = np.linspace(0.0, 1.0, base.N)     # linear alpha ramp
    return mcolors.LinearSegmentedColormap.from_list('alpha_jet', colors)


ALPHA_JET = _make_alpha_jet()
# ---------- Image loading ----------
def resolve_img_path(stored_path: str) -> str | None:
"""Find a stored image path, searching --image-dir and --extra-image-dir."""
if os.path.isabs(stored_path) and os.path.exists(stored_path):
return stored_path
basename = os.path.basename(stored_path)
for base_dir in filter(None, [args.image_dir] + (args.extra_image_dir or [])):
candidate = os.path.join(base_dir, basename)
if os.path.exists(candidate):
return candidate
if os.path.exists(stored_path):
return stored_path
return None
def load_image_by_path(path: str) -> Image.Image:
    """Open an image file as an RGB PIL image, searching the image dirs first.

    ``path`` is resolved with ``resolve_img_path``; if nothing is found the
    raw path is tried as-is, so a missing file raises whatever
    ``Image.open`` raises.

    The file is opened in a ``with`` block so its handle is released even
    when decoding fails part-way or for formats Pillow keeps open after
    loading (e.g. multi-frame images).
    """
    resolved = resolve_img_path(path) or path
    with Image.open(resolved) as img:
        # convert() returns a new, fully loaded image that does not depend on
        # the file handle closed by the with-block.
        return img.convert("RGB")
def load_image(img_idx: int) -> Image.Image:
    """Return image number ``img_idx`` of the active dataset as an RGB PIL image.

    The path is looked up in the active dataset's ``image_paths`` and then
    opened via ``load_image_by_path`` (which searches the image dirs).
    """
    paths = active_ds()['image_paths']
    return load_image_by_path(paths[img_idx])
def parse_img_label(value: str) -> int:
    """Turn a user-supplied image label into an integer dataset index.

    Interpretations are tried in this order:
      * a filename, with or without extension ('nsd_31215.jpg' or
        'nsd_31215'), looked up in the active dataset's ``basename_index``;
      * a bare integer ('42');
      * an ImageNet-style synset id ('n02655020_475'): the number after the
        last underscore is returned.

    Raises ValueError when none of these applies.
    """
    text = value.strip()
    lookup = active_ds()['basename_index']
    # Extension-stripped form first, then the label verbatim.
    for candidate in (os.path.splitext(text)[0], text):
        if candidate in lookup:
            return lookup[candidate]
    try:
        return int(text)
    except ValueError:
        pass
    return int(text.rsplit('_', 1)[-1])
# ---------- Heatmap rendering ----------
def render_heatmap_overlay(img_idx: int, heatmap_16x16,
                           size: int = THUMB,
                           cmap=ALPHA_JET,
                           alpha: float = 1.0) -> Image.Image:
    """Alpha-blend a colour-mapped patch-grid heatmap over a dataset image.

    Args:
        img_idx: index of the image in the active dataset.
        heatmap_16x16: 2-D patch-grid heatmap. Objects exposing ``.numpy()``
            (e.g. CPU torch tensors) are converted first; despite the name,
            the grid size is not checked here.
        size: output edge length in pixels; both the image and the heatmap
            are resized to ``size`` x ``size``.
        cmap: RGBA colormap whose alpha channel sets per-pixel opacity.
        alpha: global multiplier applied to the colormap's alpha.

    The heatmap is bicubically upsampled, then divided by its maximum so it
    peaks at 1 (left unscaled when the maximum is <= 0).
    """
    photo = np.asarray(load_image(img_idx).resize((size, size), Image.BILINEAR),
                       dtype=np.float32) / 255.0
    grid = heatmap_16x16.numpy() if hasattr(heatmap_16x16, 'numpy') else heatmap_16x16
    upsampled = cv2.resize(grid.astype(np.float32), (size, size),
                           interpolation=cv2.INTER_CUBIC)
    peak = upsampled.max()
    if peak > 0:
        upsampled = upsampled / peak
    rgba = cmap(upsampled)
    weight = alpha * rgba[..., 3:4]
    mixed = (1 - weight) * photo + weight * rgba[..., :3]
    return Image.fromarray(np.clip(mixed * 255, 0, 255).astype(np.uint8))
def render_zoomed_overlay(img_idx: int, heatmap_16x16,
                          size: int = THUMB,
                          pg: int | None = None,
                          alpha: float = 1.0,
                          zoom_patches: int | None = None,
                          center: str = 'peak') -> Image.Image:
    """Heatmap overlay cropped to a square zoom window.

    Args:
        img_idx: index of the image in the active dataset.
        heatmap_16x16: 2-D heatmap with ``pg`` columns (row-major argmax is
            decoded with ``divmod(idx, pg)``); objects exposing ``.numpy()``
            are converted first.
        size: edge length of the returned square image.
        pg: patches per side; defaults to the active dataset's
            ``heatmap_patch_grid``.
        alpha: global opacity multiplier for the overlay.
        zoom_patches: window edge length in patches; defaults to ``pg``.
            At full zoom (``zoom_patches >= pg``) the whole image is
            returned. Values below 1 give the smallest possible window
            instead of an empty crop.
        center: 'peak' centres the window on the argmax patch; 'centroid'
            centres it on the activation-weighted centroid (grid centre when
            the heatmap sums to <= 0). Any other value behaves like 'peak'.

    The window keeps a fixed size and is shifted back inside the image when
    it would cross a border, so the crop stays square and the final square
    resize does not distort the aspect ratio.
    """
    ds = active_ds()
    if pg is None:
        pg = ds['heatmap_patch_grid']
    if zoom_patches is None:
        zoom_patches = pg
    hmap = heatmap_16x16.numpy() if hasattr(heatmap_16x16, 'numpy') else heatmap_16x16
    # Render at native resolution so the crop is high quality.
    image_size = ds['image_size']
    overlay = render_heatmap_overlay(img_idx, hmap, size=image_size, alpha=alpha)
    if zoom_patches >= pg:
        return overlay.resize((size, size), Image.BILINEAR)

    patch_px = image_size // pg
    # Window centre, in pixels.
    if center == 'centroid':
        grid = np.asarray(hmap, dtype=np.float64)
        if grid.sum() > 0:
            coords = np.arange(pg)
            row_c = float(np.average(coords, weights=grid.sum(axis=1)))
            col_c = float(np.average(coords, weights=grid.sum(axis=0)))
            # Negative activations can pull the weighted mean off the grid.
            row_c = min(max(row_c, 0.0), pg - 1.0)
            col_c = min(max(col_c, 0.0), pg - 1.0)
        else:
            row_c = col_c = float(pg // 2)
        # Keep the fractional centroid; truncating it to a patch index biased
        # the window up/left by up to one patch.
        cy = int(round((row_c + 0.5) * patch_px))
        cx = int(round((col_c + 0.5) * patch_px))
    else:
        peak_row, peak_col = divmod(int(np.argmax(hmap)), pg)
        cy = peak_row * patch_px + patch_px // 2
        cx = peak_col * patch_px + patch_px // 2

    # Fixed-size square window slid inside the image. Clipping it at the
    # border instead would yield a non-square crop that the resize stretches.
    win = max(1, zoom_patches * patch_px)
    half = win // 2
    y0 = min(max(0, cy - half), image_size - win)
    x0 = min(max(0, cx - half), image_size - win)
    return overlay.crop((x0, y0, x0 + win, y0 + win)).resize((size, size), Image.BILINEAR)
def pil_to_data_url(img: Image.Image, quality: int = 85) -> str:
    """Encode ``img`` as a base64 JPEG ``data:`` URL for inline ``<img src>``.

    Args:
        img: any PIL image. Modes the JPEG encoder cannot write (e.g. RGBA,
            LA, P) are converted to RGB first instead of making ``save``
            raise; any alpha channel is dropped, not composited.
        quality: JPEG quality passed to Pillow (default 85).

    Returns:
        A string of the form ``data:image/jpeg;base64,<payload>``.
    """
    # Modes Pillow's JPEG writer accepts directly.
    if img.mode not in ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr"):
        img = img.convert("RGB")
    buf = io.BytesIO()
    img.save(buf, format="JPEG", quality=quality)
    b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
    return f"data:image/jpeg;base64,{b64}"
# ---------- Thumbnail cache ----------
_thumb_cache: dict[tuple[int, int], str] = {} # (img_idx, size) → data URL
_THUMB_CACHE_MAX = 4096
def _get_thumb_url(img_idx: int, size: int) -> str | None:
"""Return a cached data URL for a plain (no heatmap) thumbnail, or compute and cache it."""
key = (img_idx, size)
url = _thumb_cache.get(key)
if url is not None:
return url
try:
pil = load_image(img_idx).resize((size, size), Image.BILINEAR)
url = pil_to_data_url(pil)
except Exception:
return None
if len(_thumb_cache) >= _THUMB_CACHE_MAX:
# Evict oldest quarter
for k in list(_thumb_cache)[:_THUMB_CACHE_MAX // 4]:
del _thumb_cache[k]
_thumb_cache[key] = url
return url
def pil_to_bokeh_rgba(pil_img: Image.Image, size: int) -> np.ndarray:
    """Pack a PIL image into a ``(size, size)`` uint32 array for Bokeh ``image_rgba``.

    The image is resized to ``size`` x ``size`` and converted to RGBA. Each
    pixel's four bytes are then reinterpreted as a single uint32. Rows are
    flipped vertically (presumably because Bokeh draws row 0 at the bottom —
    confirm against the plotting code), and a fresh contiguous copy is
    returned.
    """
    rgba = pil_img.resize((size, size), Image.BILINEAR).convert("RGBA")
    packed = np.empty((size, size), dtype=np.uint32)
    # Byte-level view of `packed`, so writing RGBA bytes fills the uint32s.
    packed.view(dtype=np.uint8).reshape((size, size, 4))[...] = np.asarray(rgba, dtype=np.uint8)
    return np.flipud(packed).copy()
# ---------- HTML builders ----------
def status_html(state: str, msg: str) -> str:
    """Render ``msg`` inside a coloured status banner ``<div>``.

    ``state`` selects the palette: 'idle' (grey), 'loading' (amber),
    'ok' (green) or 'dead' (red). Unknown states fall back to 'idle'.
    ``msg`` is inserted as raw HTML.
    """
    banner_css = {
        'idle': 'background:#f9fafb;border-left:3px solid #d1d5db;color:#6b7280',
        'loading': 'background:#fffbeb;border-left:3px solid #f59e0b;color:#92400e',
        'ok': 'background:#ecfdf5;border-left:3px solid #10b981;color:#065f46',
        'dead': 'background:#fef2f2;border-left:3px solid #ef4444;color:#991b1b',
    }
    css = banner_css[state] if state in banner_css else banner_css['idle']
    return f'<div style="{css};padding:7px 12px;border-radius:6px;font-size:13px">{msg}</div>'
def make_image_grid_html(images_info: list, title: str,
img_indices: list | None = None,
cols: int | None = None) -> str:
"""Flex-wrap grid of thumbnail images with captions.
If img_indices is provided (same length as images_info), each image gets an
onclick that calls window._sae_load_patch_image(idx) to load it into the
patch explorer.
If cols is given, a fixed CSS grid with that many columns is used.
"""
if not images_info:
return (f'<div class="sae-section-title">{title}</div>'
f'<p style="color:#9ca3af;font-style:italic;margin:4px 0">No examples available</p>')
tw = min(THUMB, 224)
if cols is not None:
grid_style = (f'display:grid;grid-template-columns:repeat({cols},{tw}px);'
f'gap:10px;padding:4px 0 10px 0')
else:
grid_style = 'display:flex;flex-wrap:wrap;gap:10px;padding:4px 0 10px 0'
html = (f'<div class="sae-section-title">{title}</div>'
f'<div style="{grid_style}">')
for i, (img, caption) in enumerate(images_info):
url = pil_to_data_url(img)
cap_html = ''.join(f'<div>{p}</div>' for p in caption.split('<br>'))
if img_indices is not None and i < len(img_indices):
idx = img_indices[i]
onclick = (f' onclick="window._sae_load_patch_image({idx})" '
f'style="border:1px solid #e2e5ea;border-radius:8px;'
f'display:block;cursor:pointer;box-shadow:0 1px 2px rgba(0,0,0,0.04);'
f'transition:border-color 0.15s,box-shadow 0.15s"'
f' onmouseover="this.style.borderColor=\'#2563eb\';'
f'this.style.boxShadow=\'0 2px 8px rgba(37,99,235,0.15)\'"'
f' onmouseout="this.style.borderColor=\'#e2e5ea\';'
f'this.style.boxShadow=\'0 1px 2px rgba(0,0,0,0.04)\'"')
else:
onclick = (' style="border:1px solid #e2e5ea;border-radius:8px;display:block;'
'box-shadow:0 1px 2px rgba(0,0,0,0.04)"')
html += (f'<div style="text-align:center;width:{tw}px">'
f'<img src="{url}" width="{tw}" height="{tw}"{onclick}/>'
f'<div style="font-size:10px;color:#6b7280;margin-top:3px;line-height:1.4">'
f'{cap_html}</div></div>')
html += '</div>'
return html
# ---------- Layout helpers ----------
def make_search_result_html(features: list, ds: dict,
                            n_meis: int = 3, size: int = 80,
                            max_height: int = 270) -> str:
    """Feature search results: one clickable card per feature with MEI thumbnails.

    Args:
        features: feature indices to show, in display order.
        ds: active dataset dict. Reads ``nsd_top_img_idx`` (NSD sub01 images,
            preferred when present and not None) or else ``top_img_idx``.
            Either is a 2-D per-feature table of top image indices where a
            negative entry means "no image". Also reads the
            ``feature_names`` / ``auto_interp_names`` label dicts.
        n_meis: maximum thumbnails per card (capped at the table's column
            count).
        size: thumbnail edge length in pixels.
        max_height: maximum height in pixels of the scrollable container.

    Thumbnails are fetched in parallel through ``_img_pool`` and the shared
    thumbnail cache. A grey placeholder stands in for any that fail to load,
    and features without a single valid image are omitted. Each card calls
    ``window._sae_select_feature(feat)`` on click. Labels are free text and
    are HTML-escaped.
    """
    from html import escape

    if not features:
        return '<div style="color:#9ca3af;font-style:italic;font-size:12px;padding:8px">No results.</div>'
    idx_key = 'nsd_top_img_idx' if ds.get('nsd_top_img_idx') is not None else 'top_img_idx'
    top_idx = ds[idx_key]
    # Never index past the table's columns, even if n_meis asks for more.
    n_cols = min(n_meis, top_idx.shape[1]) if top_idx is not None else 0

    def _mei_indices(feat) -> list:
        """Leading non-negative top-image indices for ``feat`` (stops at the first negative)."""
        found = []
        for j in range(n_cols):
            img_idx = int(top_idx[feat, j].item())
            if img_idx < 0:
                break
            found.append(img_idx)
        return found

    # Resolve each feature's image indices once; reused for fetching and rendering.
    mei_idx = {feat: _mei_indices(feat) for feat in features}
    work_items = [((feat, j), img_idx)
                  for feat, idxs in mei_idx.items()
                  for j, img_idx in enumerate(idxs)]

    def _load_one(item):
        key, img_idx = item
        return key, _get_thumb_url(img_idx, size)

    # Load thumbnails in parallel using the cache: (feat, j) → url or None.
    thumb_urls = dict(_img_pool.map(_load_one, work_items)) if work_items else {}

    cards = []
    for feat in features:
        idxs = mei_idx[feat]
        if not idxs:
            continue
        imgs_html = []
        for j in range(len(idxs)):
            url = thumb_urls.get((feat, j))
            if url is not None:
                imgs_html.append(
                    f'<img src="{url}" width="{size}" height="{size}" '
                    f'style="border-radius:6px;display:block;flex-shrink:0"/>')
            else:
                # The image exists but failed to load: keep its slot.
                imgs_html.append(
                    f'<div style="width:{size}px;height:{size}px;background:#f3f4f6;'
                    f'border-radius:6px;flex-shrink:0"></div>')
        human_label = ds['feature_names'].get(feat) or ''
        auto_label = ds['auto_interp_names'].get(feat) or ''
        # Free-text label: escape so '<', '&' or quotes cannot break the markup.
        label = escape(str(human_label or auto_label))
        label_color = '#2563eb' if human_label else '#059669'
        meta_html = (
            f'<div style="display:flex;align-items:baseline;gap:5px;margin-top:4px;'
            f'max-width:{n_meis*size + 8}px;min-width:0">'
            f'<span class="sae-feat-num">#{feat}</span>'
            + (f'<span style="font-size:11px;font-style:italic;color:{label_color};'
               f'overflow:hidden;text-overflow:ellipsis;white-space:nowrap">{label}</span>'
               if label else '')
            + '</div>'
        )
        cards.append(
            f'<div onclick="window._sae_select_feature({feat})" '
            f'style="cursor:pointer;display:flex;flex-direction:column;'
            f'padding:8px 10px;border-radius:8px;border:1px solid #e2e5ea;'
            f'margin-bottom:6px;background:#fff;'
            f'transition:border-color 0.15s,background 0.15s,box-shadow 0.15s" '
            f'onmouseover="this.style.borderColor=\'#2563eb\';this.style.background=\'#eff4ff\';'
            f'this.style.boxShadow=\'0 2px 6px rgba(37,99,235,0.1)\'" '
            f'onmouseout="this.style.borderColor=\'#e2e5ea\';this.style.background=\'#fff\';'
            f'this.style.boxShadow=\'none\'">'
            f'<div style="display:flex;gap:4px">{"".join(imgs_html)}</div>'
            f'{meta_html}'
            f'</div>'
        )
    return (f'<div style="overflow-y:auto;max-height:{max_height}px;padding:2px">'
            f'{"".join(cards)}</div>')
def make_feature_thumb_gallery_html(features: list, ds: dict,
                                    size: int = 72,
                                    page: int = 0,
                                    page_size: int = 50) -> str:
    """Paginated, scrollable grid of MEI thumbnail tiles for the feature gallery.

    Args:
        features: feature indices to show, in display order.
        ds: active dataset dict. Reads ``top_img_idx`` (2-D table; column 0
            is each feature's top image index, negative = none), the
            ``feature_names`` / ``auto_interp_names`` label dicts and the
            per-feature ``freq`` sequence.
        size: thumbnail edge length in pixels.
        page: 0-based page to render; clamped into the valid range.
        page_size: tiles per page; values below 1 are treated as 1.

    Each tile has an onclick that calls window._sae_select_feature(feat_idx),
    which must be installed via the JS bridge in feature_list.py. The pager
    buttons call ``window._sae_gallery_page(page)``. Labels are free text and
    are HTML-escaped, both in the tile body and in the ``title`` tooltip.
    """
    from html import escape

    page_size = max(1, page_size)  # avoid ZeroDivisionError on bad input
    total = len(features)
    n_pages = max(1, (total + page_size - 1) // page_size)
    page = max(0, min(page, n_pages - 1))
    slice_ = features[page * page_size: (page + 1) * page_size]

    top_idx = ds.get('top_img_idx')
    first_mei = {feat: (int(top_idx[feat, 0].item()) if top_idx is not None else -1)
                 for feat in slice_}
    # Only features with a real top image need a thumbnail fetched.
    pending = [(feat, img_idx) for feat, img_idx in first_mei.items() if img_idx >= 0]

    def _load_gallery_thumb(item):
        feat, img_idx = item
        return feat, _get_thumb_url(img_idx, size)

    # Pre-load all thumbnails for this page in parallel: feat → url or None.
    gallery_urls = dict(_img_pool.map(_load_gallery_thumb, pending)) if pending else {}

    placeholder = (f'<div style="width:{size}px;height:{size}px;background:#f3f4f6;'
                   f'border-radius:6px"></div>')
    tiles = []
    for feat in slice_:
        url = gallery_urls.get(feat)
        if url is not None:
            img_html = (f'<img src="{url}" width="{size}" height="{size}" '
                        f'style="border-radius:6px;display:block;'
                        f'border:2px solid transparent"/>')
        else:
            # No image, or it failed to load.
            img_html = placeholder
        raw_label = ds['feature_names'].get(feat) or ds['auto_interp_names'].get(feat) or ''
        label = escape(str(raw_label))
        label_html = (f'<div style="font-size:9px;color:#6b7280;max-width:{size}px;'
                      f'overflow:hidden;text-overflow:ellipsis;white-space:nowrap;'
                      f'text-align:center">{label}</div>') if label else ''
        freq = int(ds['freq'][feat]) if feat < len(ds['freq']) else 0
        # Second tooltip line only when there is a label (no dangling newline).
        tooltip = f'Feature {feat} | freq={freq:,}' + (f'\n{raw_label}' if raw_label else '')
        tiles.append(
            f'<div onclick="window._sae_select_feature({feat})" '
            f'title="{escape(tooltip)}" '
            f'style="cursor:pointer;display:inline-block;margin:4px;vertical-align:top;'
            f'border-radius:8px;padding:4px;transition:background 0.15s,box-shadow 0.15s" '
            f'onmouseover="this.style.background=\'#eff4ff\';'
            f'this.style.boxShadow=\'0 1px 4px rgba(37,99,235,0.12)\'" '
            f'onmouseout="this.style.background=\'transparent\';'
            f'this.style.boxShadow=\'none\'">'
            f'{img_html}'
            f'<div style="font-family:\'SF Mono\',\'Fira Code\',monospace;font-size:9px;'
            f'color:#9ca3af;text-align:center;margin-top:2px">{feat}</div>'
            f'{label_html}'
            f'</div>'
        )
    tiles_html = ''.join(tiles)

    # Pagination strip
    if n_pages > 1:
        _pager_btn = ('cursor:pointer;padding:3px 10px;color:#2563eb;'
                      'border-radius:4px;font-weight:500;transition:background 0.15s')
        if page > 0:
            prev_btn = (f'<span onclick="window._sae_gallery_page({page - 1})" '
                        f'style="{_pager_btn}" '
                        f'onmouseover="this.style.background=\'#eff4ff\'" '
                        f'onmouseout="this.style.background=\'transparent\'">'
                        f'&#9664; Prev</span>')
        else:
            prev_btn = '<span style="color:#d1d5db;padding:3px 10px">&#9664; Prev</span>'
        if page < n_pages - 1:
            next_btn = (f'<span onclick="window._sae_gallery_page({page + 1})" '
                        f'style="{_pager_btn}" '
                        f'onmouseover="this.style.background=\'#eff4ff\'" '
                        f'onmouseout="this.style.background=\'transparent\'">'
                        f'Next &#9654;</span>')
        else:
            next_btn = '<span style="color:#d1d5db;padding:3px 10px">Next &#9654;</span>'
        pager = (f'<div style="text-align:center;font-size:11px;color:#6b7280;'
                 f'padding:6px 0;border-top:1px solid #e2e5ea;margin-top:6px;'
                 f'display:flex;align-items:center;justify-content:center;gap:4px">'
                 f'{prev_btn}'
                 f'<span style="padding:2px 8px">Page {page + 1} / {n_pages}'
                 f' &nbsp;({total} features)</span>'
                 f'{next_btn}</div>')
    else:
        pager = (f'<div style="font-size:10px;color:#9ca3af;text-align:center;padding:4px 0">'
                 f'{total} feature{"s" if total != 1 else ""}</div>')
    return (
        f'<div style="overflow-y:auto;max-height:580px;padding:2px">'
        f'<div style="display:flex;flex-wrap:wrap">{tiles_html}</div>'
        f'{pager}'
        f'</div>'
    )
def make_active_features_tile_html(feats: list, ds: dict, mei_size: int = 72,
removable: bool = False,
lams: list | None = None) -> str:
"""Steering list: one card per feature showing brain phi map + top 3 MEIs.
If removable=True, each card has remove (✕) and negate (±) buttons, plus a
lambda number input. lams is a parallel list of current lambda values for
pre-filling the inputs.
"""
if not feats:
return ('<div style="color:#9ca3af;font-style:italic;font-size:12px;padding:8px">'
'No active features — select a feature and click + Add to Steer.</div>')
from .brain import _render_phi_map_b64_compact
idx_key = 'nsd_top_img_idx' if ds.get('nsd_top_img_idx') is not None else 'top_img_idx'
lam_map = {feat: lams[i] for i, feat in enumerate(feats)} if lams else {}
cards = []
for feat in feats:
# Brain phi map — same size as each MEI
brain_b64 = _render_phi_map_b64_compact(feat)
if brain_b64:
brain_html = (f'<img src="data:image/png;base64,{brain_b64}" '
f'width="{mei_size}" height="{mei_size}" '
f'style="object-fit:cover;border-radius:6px;'
f'border:1px solid #e2e5ea;flex-shrink:0"/>')
else:
brain_html = (f'<div style="width:{mei_size}px;height:{mei_size}px;'
f'background:#f9fafb;border-radius:6px;border:1px dashed #d1d5db;'
f'display:flex;align-items:center;justify-content:center;'
f'color:#d1d5db;font-size:10px;flex-shrink:0">no φ</div>')
# Top 3 MEIs in a row at the same size (use thumbnail cache)
meis_html = []
for j in range(3):
img_idx = int(ds[idx_key][feat, j].item()) if ds[idx_key] is not None else -1
if img_idx >= 0:
url = _get_thumb_url(img_idx, mei_size)
if url:
meis_html.append(
f'<img src="{url}" width="{mei_size}" height="{mei_size}" '
f'style="border-radius:6px;flex-shrink:0;border:1px solid #e2e5ea"/>')
else:
meis_html.append(
f'<div style="width:{mei_size}px;height:{mei_size}px;'
f'background:#f3f4f6;border-radius:6px;flex-shrink:0"></div>')
else:
break
feat_num = str(feat)
label = ds['feature_names'].get(feat) or ds['auto_interp_names'].get(feat) or ''
label_html = (f'<div style="font-size:10px;font-style:italic;color:#2563eb;'
f'margin-top:3px;text-align:center;overflow:hidden;'
f'text-overflow:ellipsis;white-space:nowrap;max-width:'
f'{mei_size * 4 + 20}px">{label}</div>') if label else ''
# Left column: ✕ and λ stacked (only when removable)
if removable:
lam_val = lam_map.get(feat, 3.0)
left_col = (
f'<div style="display:flex;flex-direction:column;align-items:center;'
f'gap:6px;justify-content:center;padding-right:4px">'
f'<span onclick="event.stopPropagation();window._dd_feat_action(\'remove_feat:{feat}\')" '
f'style="color:#dc2626;cursor:pointer;font-size:14px;font-weight:bold;'
f'line-height:1;user-select:none;width:22px;height:22px;'
f'display:flex;align-items:center;justify-content:center;'
f'border-radius:4px;transition:background 0.15s" '
f'onmouseover="this.style.background=\'#fef2f2\'" '
f'onmouseout="this.style.background=\'transparent\'" '
f'title="Remove">✕</span>'
f'<div style="display:flex;flex-direction:column;align-items:center;gap:1px">'
f'<span style="font-size:10px;font-weight:600;color:#9ca3af">λ</span>'
f'<input type="text" value="{lam_val:.2g}" '
f'style="width:38px;font-size:14px;font-weight:600;text-align:center;'
f'color:#2563eb;border:1px solid transparent;border-radius:4px;padding:2px;'
f'background:transparent;outline:none;cursor:default;'
f'transition:border-color 0.15s,background 0.15s" '
f'onfocus="this.style.background=\'#fff\';this.style.borderColor=\'#2563eb\';this.style.cursor=\'text\'" '
f'onblur="this.style.background=\'transparent\';this.style.borderColor=\'transparent\';this.style.cursor=\'default\'" '
f'onclick="event.stopPropagation()" '
f'onchange="window._dd_feat_action(\'set_lam:{feat}:\'+this.value)" '
f'title="Steering strength λ"/>'
f'</div>'
f'</div>'
)
else:
left_col = ''
cards.append(
f'<div onclick="window._sae_select_feature({feat})" '
f'style="cursor:pointer;display:inline-flex;flex-direction:column;'
f'align-items:center;gap:4px;'
f'padding:8px;border-radius:8px;border:1px solid #e2e5ea;'
f'margin:4px;background:#fff;vertical-align:top;'
f'box-shadow:0 1px 2px rgba(0,0,0,0.04);'
f'transition:border-color 0.15s,background 0.15s,box-shadow 0.15s" '
f'onmouseover="this.style.borderColor=\'#2563eb\';this.style.background=\'#eff4ff\';'
f'this.style.boxShadow=\'0 2px 8px rgba(37,99,235,0.12)\'" '
f'onmouseout="this.style.borderColor=\'#e2e5ea\';this.style.background=\'#fff\';'
f'this.style.boxShadow=\'0 1px 2px rgba(0,0,0,0.04)\'">'
f'<span class="sae-feat-num">#{feat_num}</span>'
f'<div style="display:flex;gap:4px;align-items:center">'
f'{left_col}{brain_html}{"".join(meis_html)}'
f'</div>'
f'{label_html}'
f'</div>'
)
return f'<div style="display:flex;flex-wrap:wrap;padding:2px">{"".join(cards)}</div>'