""" Image loading, heatmap rendering, and HTML builder helpers. All functions here are pure (no Bokeh widget dependencies) so they can be called from worker threads or tested in isolation. """ import base64 import io import os from concurrent.futures import ThreadPoolExecutor import cv2 import matplotlib matplotlib.use('Agg') import matplotlib.colors as mcolors import matplotlib.pyplot as plt import numpy as np from PIL import Image from .args import args from .state import active_ds # ---------- Thread pool for parallel image loading ---------- _img_pool = ThreadPoolExecutor(max_workers=8) # ---------- Constants ---------- THUMB = args.thumb_size # Jet colormap with alpha ramp so low-activation regions are transparent. def _make_alpha_jet() -> mcolors.LinearSegmentedColormap: base = plt.cm.get_cmap('jet') colors = base(np.arange(base.N)) colors[:, -1] = np.linspace(0.0, 1.0, base.N) return mcolors.LinearSegmentedColormap.from_list('alpha_jet', colors) ALPHA_JET = _make_alpha_jet() # ---------- Image loading ---------- def resolve_img_path(stored_path: str) -> str | None: """Find a stored image path, searching --image-dir and --extra-image-dir.""" if os.path.isabs(stored_path) and os.path.exists(stored_path): return stored_path basename = os.path.basename(stored_path) for base_dir in filter(None, [args.image_dir] + (args.extra_image_dir or [])): candidate = os.path.join(base_dir, basename) if os.path.exists(candidate): return candidate if os.path.exists(stored_path): return stored_path return None def load_image_by_path(path: str) -> Image.Image: """Open an image file, searching image dirs first.""" resolved = resolve_img_path(path) or path return Image.open(resolved).convert("RGB") def load_image(img_idx: int) -> Image.Image: """Load image by dataset index using the active dataset's image_paths.""" return load_image_by_path(active_ds()['image_paths'][img_idx]) def parse_img_label(value: str) -> int: """Parse an image label into an integer dataset index. 
# ---------- Heatmap rendering ----------

def _heatmap_to_numpy(heatmap) -> np.ndarray:
    """Return *heatmap* as a NumPy array.

    Accepts NumPy arrays, nested sequences, and tensor-like objects exposing
    ``.numpy()``.  Objects that also expose ``.detach()`` and ``.cpu()`` are
    detached and moved to host memory first, since ``.numpy()`` otherwise
    fails for tensors that require grad or live on another device.
    """
    if hasattr(heatmap, 'detach') and hasattr(heatmap, 'cpu'):
        heatmap = heatmap.detach().cpu()
    if hasattr(heatmap, 'numpy'):
        heatmap = heatmap.numpy()
    return np.asarray(heatmap)


def render_heatmap_overlay(img_idx: int, heatmap_16x16, size: int = THUMB,
                           cmap=ALPHA_JET, alpha: float = 1.0) -> Image.Image:
    """Blend a patch-grid heatmap over an image.

    Args:
        img_idx: dataset index of the image to load.
        heatmap_16x16: 2-D activation grid (array or tensor-like).
            NOTE(review): the name suggests a 16x16 grid; nothing here
            enforces that.
        size: side length in pixels of the square output.
        cmap: colormap called on the normalised heatmap; its alpha channel
            decides how strongly each pixel is tinted.
        alpha: global multiplier applied to the colormap's alpha.

    Returns:
        An RGB image of ``size`` x ``size`` pixels.
    """
    img = load_image(img_idx).resize((size, size), Image.BILINEAR)
    base = np.asarray(img, dtype=np.float32) / 255.0

    hmap = _heatmap_to_numpy(heatmap_16x16).astype(np.float32)
    hmap_up = cv2.resize(hmap, (size, size), interpolation=cv2.INTER_CUBIC)

    # Scale so the peak maps to the top of the colormap; a heatmap with no
    # positive value is passed through unscaled.
    hmax = hmap_up.max()
    hmap_norm = hmap_up / hmax if hmax > 0 else hmap_up

    overlay = cmap(hmap_norm)                 # (size, size, 4) RGBA floats
    ov_alpha = overlay[:, :, 3:4] * alpha     # keep last axis for broadcasting
    blended = base * (1 - ov_alpha) + overlay[:, :, :3] * ov_alpha
    return Image.fromarray(np.clip(blended * 255, 0, 255).astype(np.uint8))


def render_zoomed_overlay(img_idx: int, heatmap_16x16, size: int = THUMB,
                          pg: int | None = None, alpha: float = 1.0,
                          zoom_patches: int | None = None,
                          center: str = 'peak') -> Image.Image:
    """Heatmap overlay cropped to a zoom window.

    zoom_patches controls the neighbourhood size (in patches). At full zoom
    (zoom_patches >= pg) the whole image is returned.

    center='peak'     — window centred on the argmax patch.
    center='centroid' — window centred on the activation-weighted centroid
                        (falls back to the grid centre when the heatmap has
                        no positive total).

    Near an image border the window is shifted inwards rather than cut off,
    so the crop always keeps its full size and square aspect ratio and is
    not stretched when resized to ``size`` x ``size``.

    Args:
        img_idx: dataset index of the image to load.
        heatmap_16x16: activation grid (array or tensor-like).
        size: side length in pixels of the square output.
        pg: patches per side; defaults to the active dataset's
            ``heatmap_patch_grid``.
        alpha: global multiplier for the overlay's alpha.
        zoom_patches: window side length in patches; defaults to ``pg``.
        center: ``'centroid'`` or anything else for peak mode.

    Returns:
        An RGB image of ``size`` x ``size`` pixels.
    """
    ds = active_ds()
    if pg is None:
        pg = ds['heatmap_patch_grid']
    if zoom_patches is None:
        zoom_patches = pg
    hmap = _heatmap_to_numpy(heatmap_16x16)

    # Render at native resolution so the crop is high quality
    image_size = ds['image_size']
    overlay = render_heatmap_overlay(img_idx, hmap, size=image_size, alpha=alpha)
    if zoom_patches >= pg:
        return overlay.resize((size, size), Image.BILINEAR)

    # Find the crop centre in (possibly fractional) patch coordinates.
    if center == 'centroid':
        # Accept a flattened grid as well as a 2-D one.
        grid = hmap.reshape(pg, pg) if hmap.size == pg * pg else hmap
        if grid.sum() > 0:
            patch_ids = np.arange(pg)
            centre_row = float(np.average(patch_ids, weights=grid.sum(axis=1)))
            centre_col = float(np.average(patch_ids, weights=grid.sum(axis=0)))
        else:
            centre_row = centre_col = float(pg // 2)
    else:
        centre_row, centre_col = divmod(int(np.argmax(hmap)), pg)

    # Negative weights can push a centroid outside the grid; keep it inside.
    centre_row = min(max(centre_row, 0), pg - 1)
    centre_col = min(max(centre_col, 0), pg - 1)

    patch_px = image_size // pg
    # Window side in pixels; at least one pixel so the crop is never empty.
    win = max(1, min(image_size, zoom_patches * patch_px))

    # Patch index i has its centre at (i + 0.5) * patch_px.
    cy = int((centre_row + 0.5) * patch_px)
    cx = int((centre_col + 0.5) * patch_px)

    # Slide the fixed-size window so it stays inside the image.
    y0 = min(max(0, cy - win // 2), image_size - win)
    x0 = min(max(0, cx - win // 2), image_size - win)
    crop = overlay.crop((x0, y0, x0 + win, y0 + win))
    return crop.resize((size, size), Image.BILINEAR)
""" ds = active_ds() if pg is None: pg = ds['heatmap_patch_grid'] if zoom_patches is None: zoom_patches = pg hmap = heatmap_16x16.numpy() if hasattr(heatmap_16x16, 'numpy') else heatmap_16x16 # Render at native resolution so the crop is high quality image_size = ds['image_size'] overlay = render_heatmap_overlay(img_idx, hmap, size=image_size, alpha=alpha) if zoom_patches >= pg: return overlay.resize((size, size), Image.BILINEAR) # Find crop centre if center == 'centroid': total = hmap.sum() if total > 0: peak_row = int(np.average(np.arange(pg), weights=hmap.sum(axis=1))) peak_col = int(np.average(np.arange(pg), weights=hmap.sum(axis=0))) else: peak_row = peak_col = pg // 2 else: peak_idx = np.argmax(hmap) peak_row, peak_col = divmod(int(peak_idx), pg) patch_px = image_size // pg half = (zoom_patches * patch_px) // 2 cy = peak_row * patch_px + patch_px // 2 cx = peak_col * patch_px + patch_px // 2 y0 = max(0, cy - half); y1 = min(image_size, cy + half) x0 = max(0, cx - half); x1 = min(image_size, cx + half) return overlay.crop((x0, y0, x1, y1)).resize((size, size), Image.BILINEAR) def pil_to_data_url(img: Image.Image) -> str: buf = io.BytesIO() img.save(buf, format="JPEG", quality=85) b64 = base64.b64encode(buf.getvalue()).decode("utf-8") return f"data:image/jpeg;base64,{b64}" # ---------- Thumbnail cache ---------- _thumb_cache: dict[tuple[int, int], str] = {} # (img_idx, size) → data URL _THUMB_CACHE_MAX = 4096 def _get_thumb_url(img_idx: int, size: int) -> str | None: """Return a cached data URL for a plain (no heatmap) thumbnail, or compute and cache it.""" key = (img_idx, size) url = _thumb_cache.get(key) if url is not None: return url try: pil = load_image(img_idx).resize((size, size), Image.BILINEAR) url = pil_to_data_url(pil) except Exception: return None if len(_thumb_cache) >= _THUMB_CACHE_MAX: # Evict oldest quarter for k in list(_thumb_cache)[:_THUMB_CACHE_MAX // 4]: del _thumb_cache[k] _thumb_cache[key] = url return url def pil_to_bokeh_rgba(pil_img: 
Image.Image, size: int) -> np.ndarray: """Convert PIL image to a uint32 RGBA array suitable for Bokeh image_rgba.""" pil_img = pil_img.resize((size, size), Image.BILINEAR).convert("RGBA") arr = np.array(pil_img, dtype=np.uint8) out = np.empty((size, size), dtype=np.uint32) view = out.view(dtype=np.uint8).reshape((size, size, 4)) view[:, :, :] = arr return out[::-1].copy() # ---------- HTML builders ---------- def status_html(state: str, msg: str) -> str: """Styled status banner. state ∈ {'idle', 'loading', 'ok', 'dead'}.""" styles = { 'idle': 'background:#f9fafb;border-left:3px solid #d1d5db;color:#6b7280', 'loading': 'background:#fffbeb;border-left:3px solid #f59e0b;color:#92400e', 'ok': 'background:#ecfdf5;border-left:3px solid #10b981;color:#065f46', 'dead': 'background:#fef2f2;border-left:3px solid #ef4444;color:#991b1b', } style = styles.get(state, styles['idle']) return f'
{msg}
' def make_image_grid_html(images_info: list, title: str, img_indices: list | None = None, cols: int | None = None) -> str: """Flex-wrap grid of thumbnail images with captions. If img_indices is provided (same length as images_info), each image gets an onclick that calls window._sae_load_patch_image(idx) to load it into the patch explorer. If cols is given, a fixed CSS grid with that many columns is used. """ if not images_info: return (f'
{title}
' f'

No examples available

') tw = min(THUMB, 224) if cols is not None: grid_style = (f'display:grid;grid-template-columns:repeat({cols},{tw}px);' f'gap:10px;padding:4px 0 10px 0') else: grid_style = 'display:flex;flex-wrap:wrap;gap:10px;padding:4px 0 10px 0' html = (f'
{title}
' f'
') for i, (img, caption) in enumerate(images_info): url = pil_to_data_url(img) cap_html = ''.join(f'
{p}
' for p in caption.split('
')) if img_indices is not None and i < len(img_indices): idx = img_indices[i] onclick = (f' onclick="window._sae_load_patch_image({idx})" ' f'style="border:1px solid #e2e5ea;border-radius:8px;' f'display:block;cursor:pointer;box-shadow:0 1px 2px rgba(0,0,0,0.04);' f'transition:border-color 0.15s,box-shadow 0.15s"' f' onmouseover="this.style.borderColor=\'#2563eb\';' f'this.style.boxShadow=\'0 2px 8px rgba(37,99,235,0.15)\'"' f' onmouseout="this.style.borderColor=\'#e2e5ea\';' f'this.style.boxShadow=\'0 1px 2px rgba(0,0,0,0.04)\'"') else: onclick = (' style="border:1px solid #e2e5ea;border-radius:8px;display:block;' 'box-shadow:0 1px 2px rgba(0,0,0,0.04)"') html += (f'
' f'' f'
' f'{cap_html}
') html += '
' return html # ---------- Layout helpers ---------- def make_search_result_html(features: list, ds: dict, n_meis: int = 3, size: int = 80, max_height: int = 270) -> str: """Feature search results: one card per feature with N MEI thumbnails. Uses NSD sub01 images when available, falling back to full-dataset images. Each card is clickable and calls window._sae_select_feature(feat). """ if not features: return '
No results.
' idx_key = 'nsd_top_img_idx' if ds.get('nsd_top_img_idx') is not None else 'top_img_idx' # Collect all (feat, j, img_idx) pairs that need thumbnails work_items = [] for feat in features: for j in range(n_meis): img_idx = int(ds[idx_key][feat, j].item()) if ds[idx_key] is not None else -1 if img_idx < 0: break work_items.append((feat, j, img_idx)) # Load thumbnails in parallel using the cache def _load_one(item): _feat, _j, _img_idx = item return (item, _get_thumb_url(_img_idx, size)) thumb_urls = {} # (feat, j) → url for item, url in _img_pool.map(_load_one, work_items): feat_i, j_i, _ = item if url is not None: thumb_urls[(feat_i, j_i)] = url cards = [] for feat in features: imgs_html = [] for j in range(n_meis): url = thumb_urls.get((feat, j)) if url is not None: imgs_html.append( f'') elif (int(ds[idx_key][feat, j].item()) if ds[idx_key] is not None else -1) >= 0: imgs_html.append( f'
') else: break if not imgs_html: continue human_label = ds['feature_names'].get(feat) or '' auto_label = ds['auto_interp_names'].get(feat) or '' label = human_label or auto_label label_color = '#2563eb' if human_label else '#059669' meta_html = ( f'
' f'#{feat}' + (f'{label}' if label else '') + f'
' ) cards.append( f'
' f'
{"".join(imgs_html)}
' f'{meta_html}' f'
' ) return (f'
' f'{"".join(cards)}
') def make_feature_thumb_gallery_html(features: list, ds: dict, size: int = 72, page: int = 0, page_size: int = 50) -> str: """Paginated, scrollable grid of MEI thumbnail tiles for the feature gallery. Each tile has an onclick that calls window._sae_select_feature(feat_idx), which must be installed via the JS bridge in feature_list.py. """ total = len(features) n_pages = max(1, (total + page_size - 1) // page_size) page = max(0, min(page, n_pages - 1)) slice_ = features[page * page_size: (page + 1) * page_size] # Pre-load all thumbnails for this page in parallel gallery_items = [] for feat in slice_: img_idx = int(ds['top_img_idx'][feat, 0].item()) if ds['top_img_idx'] is not None else -1 gallery_items.append((feat, img_idx)) def _load_gallery_thumb(item): _feat, _img_idx = item if _img_idx >= 0: return (_feat, _get_thumb_url(_img_idx, size)) return (_feat, None) gallery_urls = {} # feat → url for feat_i, url in _img_pool.map(_load_gallery_thumb, gallery_items): gallery_urls[feat_i] = url tiles = [] for feat in slice_: img_idx = int(ds['top_img_idx'][feat, 0].item()) if ds['top_img_idx'] is not None else -1 url = gallery_urls.get(feat) if url is not None: img_html = (f'') elif img_idx >= 0: img_html = (f'
') else: img_html = (f'
') label = ds['feature_names'].get(feat) or ds['auto_interp_names'].get(feat) or '' label_html = (f'
{label}
') if label else '' freq = int(ds['freq'][feat]) if feat < len(ds['freq']) else 0 tiles.append( f'
' f'{img_html}' f'
{feat}
' f'{label_html}' f'
' ) tiles_html = ''.join(tiles) # Pagination strip if n_pages > 1: _pager_btn = ('cursor:pointer;padding:3px 10px;color:#2563eb;' 'border-radius:4px;font-weight:500;transition:background 0.15s') if page > 0: prev_btn = (f'' f'◀ Prev') else: prev_btn = '◀ Prev' if page < n_pages - 1: next_btn = (f'' f'Next ▶') else: next_btn = 'Next ▶' pager = (f'
' f'{prev_btn}' f'Page {page + 1} / {n_pages}' f'  ({total} features)' f'{next_btn}
') else: pager = (f'
' f'{total} feature{"s" if total != 1 else ""}
') return ( f'
' f'
{tiles_html}
' f'{pager}' f'
' ) def make_active_features_tile_html(feats: list, ds: dict, mei_size: int = 72, removable: bool = False, lams: list | None = None) -> str: """Steering list: one card per feature showing brain phi map + top 3 MEIs. If removable=True, each card has remove (✕) and negate (±) buttons, plus a lambda number input. lams is a parallel list of current lambda values for pre-filling the inputs. """ if not feats: return ('
' 'No active features — select a feature and click + Add to Steer.
') from .brain import _render_phi_map_b64_compact idx_key = 'nsd_top_img_idx' if ds.get('nsd_top_img_idx') is not None else 'top_img_idx' lam_map = {feat: lams[i] for i, feat in enumerate(feats)} if lams else {} cards = [] for feat in feats: # Brain phi map — same size as each MEI brain_b64 = _render_phi_map_b64_compact(feat) if brain_b64: brain_html = (f'') else: brain_html = (f'
no φ
') # Top 3 MEIs in a row at the same size (use thumbnail cache) meis_html = [] for j in range(3): img_idx = int(ds[idx_key][feat, j].item()) if ds[idx_key] is not None else -1 if img_idx >= 0: url = _get_thumb_url(img_idx, mei_size) if url: meis_html.append( f'') else: meis_html.append( f'
') else: break feat_num = str(feat) label = ds['feature_names'].get(feat) or ds['auto_interp_names'].get(feat) or '' label_html = (f'
{label}
') if label else '' # Left column: ✕ and λ stacked (only when removable) if removable: lam_val = lam_map.get(feat, 3.0) left_col = ( f'
' f'' f'
' f'λ' f'' f'
' f'
' ) else: left_col = '' cards.append( f'
' f'#{feat_num}' f'
' f'{left_col}{brain_html}{"".join(meis_html)}' f'
' f'{label_html}' f'
' ) return f'
{"".join(cards)}
'