""" Image loading, heatmap rendering, and HTML builder helpers. All functions here are pure (no Bokeh widget dependencies) so they can be called from worker threads or tested in isolation. """ import base64 import io import os from concurrent.futures import ThreadPoolExecutor import cv2 import matplotlib matplotlib.use('Agg') import matplotlib.colors as mcolors import matplotlib.pyplot as plt import numpy as np from PIL import Image from .args import args from .state import active_ds # ---------- Thread pool for parallel image loading ---------- _img_pool = ThreadPoolExecutor(max_workers=8) # ---------- Constants ---------- THUMB = args.thumb_size # Jet colormap with alpha ramp so low-activation regions are transparent. def _make_alpha_jet() -> mcolors.LinearSegmentedColormap: base = plt.cm.get_cmap('jet') colors = base(np.arange(base.N)) colors[:, -1] = np.linspace(0.0, 1.0, base.N) return mcolors.LinearSegmentedColormap.from_list('alpha_jet', colors) ALPHA_JET = _make_alpha_jet() # ---------- Image loading ---------- def resolve_img_path(stored_path: str) -> str | None: """Find a stored image path, searching --image-dir and --extra-image-dir.""" if os.path.isabs(stored_path) and os.path.exists(stored_path): return stored_path basename = os.path.basename(stored_path) for base_dir in filter(None, [args.image_dir] + (args.extra_image_dir or [])): candidate = os.path.join(base_dir, basename) if os.path.exists(candidate): return candidate if os.path.exists(stored_path): return stored_path return None def load_image_by_path(path: str) -> Image.Image: """Open an image file, searching image dirs first.""" resolved = resolve_img_path(path) or path return Image.open(resolved).convert("RGB") def load_image(img_idx: int) -> Image.Image: """Load image by dataset index using the active dataset's image_paths.""" return load_image_by_path(active_ds()['image_paths'][img_idx]) def parse_img_label(value: str) -> int: """Parse an image label into an integer dataset index. Accepts: exact filename ('nsd_31215.jpg'), bare int ('42'), or ImageNet-style synset ('n02655020_475'). """ val = value.strip() basename_index = active_ds()['basename_index'] key = os.path.splitext(val)[0] if key in basename_index: return basename_index[key] if val in basename_index: return basename_index[val] try: return int(val) except ValueError: pass return int(val.rsplit('_', 1)[-1]) # ---------- Heatmap rendering ---------- def render_heatmap_overlay(img_idx: int, heatmap_16x16, size: int = THUMB, cmap=ALPHA_JET, alpha: float = 1.0) -> Image.Image: """Blend a patch-grid heatmap over an image.""" img = load_image(img_idx).resize((size, size), Image.BILINEAR) base = np.array(img).astype(np.float32) / 255.0 hmap = heatmap_16x16.numpy() if hasattr(heatmap_16x16, 'numpy') else heatmap_16x16 hmap = hmap.astype(np.float32) hmap_up = cv2.resize(hmap, (size, size), interpolation=cv2.INTER_CUBIC) hmax = hmap_up.max() hmap_norm = hmap_up / hmax if hmax > 0 else hmap_up overlay = cmap(hmap_norm) ov_alpha = overlay[:, :, 3:4] * alpha blended = base * (1 - ov_alpha) + overlay[:, :, :3] * ov_alpha return Image.fromarray(np.clip(blended * 255, 0, 255).astype(np.uint8)) def render_zoomed_overlay(img_idx: int, heatmap_16x16, size: int = THUMB, pg: int | None = None, alpha: float = 1.0, zoom_patches: int | None = None, center: str = 'peak') -> Image.Image: """Heatmap overlay cropped to a zoom window. zoom_patches controls the neighbourhood size (in patches). At full zoom (zoom_patches >= pg) the whole image is returned. center='peak' — window centred on the argmax patch. center='centroid' — window centred on the activation-weighted centroid. """ ds = active_ds() if pg is None: pg = ds['heatmap_patch_grid'] if zoom_patches is None: zoom_patches = pg hmap = heatmap_16x16.numpy() if hasattr(heatmap_16x16, 'numpy') else heatmap_16x16 # Render at native resolution so the crop is high quality image_size = ds['image_size'] overlay = render_heatmap_overlay(img_idx, hmap, size=image_size, alpha=alpha) if zoom_patches >= pg: return overlay.resize((size, size), Image.BILINEAR) # Find crop centre if center == 'centroid': total = hmap.sum() if total > 0: peak_row = int(np.average(np.arange(pg), weights=hmap.sum(axis=1))) peak_col = int(np.average(np.arange(pg), weights=hmap.sum(axis=0))) else: peak_row = peak_col = pg // 2 else: peak_idx = np.argmax(hmap) peak_row, peak_col = divmod(int(peak_idx), pg) patch_px = image_size // pg half = (zoom_patches * patch_px) // 2 cy = peak_row * patch_px + patch_px // 2 cx = peak_col * patch_px + patch_px // 2 y0 = max(0, cy - half); y1 = min(image_size, cy + half) x0 = max(0, cx - half); x1 = min(image_size, cx + half) return overlay.crop((x0, y0, x1, y1)).resize((size, size), Image.BILINEAR) def pil_to_data_url(img: Image.Image) -> str: buf = io.BytesIO() img.save(buf, format="JPEG", quality=85) b64 = base64.b64encode(buf.getvalue()).decode("utf-8") return f"data:image/jpeg;base64,{b64}" # ---------- Thumbnail cache ---------- _thumb_cache: dict[tuple[int, int], str] = {} # (img_idx, size) → data URL _THUMB_CACHE_MAX = 4096 def _get_thumb_url(img_idx: int, size: int) -> str | None: """Return a cached data URL for a plain (no heatmap) thumbnail, or compute and cache it.""" key = (img_idx, size) url = _thumb_cache.get(key) if url is not None: return url try: pil = load_image(img_idx).resize((size, size), Image.BILINEAR) url = pil_to_data_url(pil) except Exception: return None if len(_thumb_cache) >= _THUMB_CACHE_MAX: # Evict oldest quarter for k in list(_thumb_cache)[:_THUMB_CACHE_MAX // 4]: del _thumb_cache[k] _thumb_cache[key] = url return url def pil_to_bokeh_rgba(pil_img: Image.Image, size: int) -> np.ndarray: """Convert PIL image to a uint32 RGBA array suitable for Bokeh image_rgba.""" pil_img = pil_img.resize((size, size), Image.BILINEAR).convert("RGBA") arr = np.array(pil_img, dtype=np.uint8) out = np.empty((size, size), dtype=np.uint32) view = out.view(dtype=np.uint8).reshape((size, size, 4)) view[:, :, :] = arr return out[::-1].copy() # ---------- HTML builders ---------- def status_html(state: str, msg: str) -> str: """Styled status banner. state ∈ {'idle', 'loading', 'ok', 'dead'}.""" styles = { 'idle': 'background:#f9fafb;border-left:3px solid #d1d5db;color:#6b7280', 'loading': 'background:#fffbeb;border-left:3px solid #f59e0b;color:#92400e', 'ok': 'background:#ecfdf5;border-left:3px solid #10b981;color:#065f46', 'dead': 'background:#fef2f2;border-left:3px solid #ef4444;color:#991b1b', } style = styles.get(state, styles['idle']) return f'
No examples available
') tw = min(THUMB, 224) if cols is not None: grid_style = (f'display:grid;grid-template-columns:repeat({cols},{tw}px);' f'gap:10px;padding:4px 0 10px 0') else: grid_style = 'display:flex;flex-wrap:wrap;gap:10px;padding:4px 0 10px 0' html = (f'