File size: 3,659 Bytes
fb9c7be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4a5ed93
 
fb9c7be
4a5ed93
 
fb9c7be
4a5ed93
 
fb9c7be
 
4a5ed93
 
fb9c7be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
"""
Pure feature display logic — no Bokeh dependencies.

Builds the data needed to render a feature's detail view: stats HTML,
top MEI image list, and NSD sample basename.  Called by panels/feature.py.
"""

import os

from PIL import Image as _PILImage

from .rendering import (
    THUMB, load_image, render_zoomed_overlay, _img_pool,
)


def build_stats_html(feat: int, freq_val, feat_name: str,
                     auto_name: str) -> str:
    """Return the header HTML for a feature's detail view.

    The result is an ``<h2>`` "Feature N" title followed by at most one
    label, picked in priority order:

    1. ``freq_val == 0``   -> red "dead feature" badge;
    2. truthy ``feat_name`` -> blue italic line showing that name;
    3. truthy ``auto_name`` -> green italic line showing that name;
    4. otherwise            -> no label.

    NOTE(review): names are interpolated into the markup verbatim (not
    HTML-escaped); if they can come from user input, consider escaping.
    """
    header = ('<h2 style="margin:4px 0;font-size:22px;font-weight:700;'
              f'color:#1a1d23">Feature {feat}</h2>')

    if freq_val == 0:
        return header + ('<span style="color:#dc2626;font-size:13px;'
                         'margin-left:8px;font-weight:500">dead feature</span>')

    # A manual name takes precedence over an automatically generated one.
    for color, name in (('#2563eb', feat_name), ('#059669', auto_name)):
        if name:
            return header + (f'<div style="color:{color};font-style:italic;'
                             f'font-size:14px;margin-top:4px">{name}</div>')

    return header


def build_top_mei_items(feat: int, ds: dict, n_display: int = 9,
                        zoom_patches: int = 16,
                        heatmap_alpha: float = 1.0):
    """Load and render top MEI images for a feature.

    The NSD rankings (``ds['nsd_top_img_idx']`` / ``ds['nsd_top_heatmaps']``)
    are used when ``ds['nsd_top_img_idx']`` is present and not None;
    otherwise ``ds['top_img_idx']`` / ``ds['top_heatmaps']`` are used.
    A heatmap overlay is drawn only when heatmaps are available and
    ``ds['heatmap_patch_grid'] > 1``; the heatmap row is reshaped to
    ``grid x grid``. Without an overlay, the plain image is resized to
    ``THUMB x THUMB``.

    Args:
        feat: Feature index (first axis of the ranking / heatmap tables).
        ds: Dataset dict holding the ranking tables and heatmap settings.
        n_display: Maximum number of ranked images to consider.
        zoom_patches: Forwarded to ``render_zoomed_overlay``.
        heatmap_alpha: Forwarded to ``render_zoomed_overlay`` as ``alpha``.

    Returns (top_infos, top_img_is, subset_label) where:
      top_infos  — list of (PIL.Image, caption_str) tuples
      top_img_is — parallel list of dataset image indices
      subset_label — " [NSD sub01]" or ""

    Images whose rendering raises OSError (including FileNotFoundError)
    are dropped. Any other rendering error yields a gray placeholder tile
    captioned ``"Error: <exception>"``.
    """
    use_nsd = ds.get('nsd_top_img_idx') is not None
    top_idx = ds['nsd_top_img_idx'] if use_nsd else ds['top_img_idx']
    top_hm = ds.get('nsd_top_heatmaps') if use_nsd else ds.get('top_heatmaps')
    subset_label = " [NSD sub01]" if use_nsd else ""

    def _render_one(ranking_idx, img_i):
        # Render one ranked image; returns (image, caption) or None when
        # the image cannot be read.
        try:
            hm = None
            if top_hm is not None:
                # Read inside the try so a missing key still produces an
                # error tile instead of aborting the whole batch.
                grid = ds['heatmap_patch_grid']
                if grid > 1:
                    hm = top_hm[feat, ranking_idx].float().numpy()
                    hm = hm.reshape(grid, grid)
            if hm is None:
                plain = load_image(img_i).resize((THUMB, THUMB))
                return (plain, "")
            img_out = render_zoomed_overlay(
                img_i, hm, size=THUMB,
                zoom_patches=zoom_patches,
                alpha=heatmap_alpha,
                center='peak',
            )
            return (img_out, "")
        except OSError:
            # FileNotFoundError is a subclass of OSError; unreadable or
            # missing images are silently skipped.
            return None
        except Exception as e:
            # Any other failure is shown as a gray placeholder tile.
            return (_PILImage.new("RGB", (THUMB, THUMB), "gray"),
                    f"Error: {e}")

    # Collect (ranking position, image index) pairs. A negative index ends
    # the valid entries for this feature.
    work = []
    for j in range(min(n_display, top_idx.shape[1])):
        img_i = top_idx[feat, j].item()
        if img_i < 0:
            break
        work.append((j, img_i))

    # Render on the shared pool. The zip below relies on map() yielding
    # results in the same order as `work`.
    rendered = _img_pool.map(lambda w: _render_one(*w), work)

    top_infos = []
    top_img_is = []
    for (_, img_i), item in zip(work, rendered):
        if item is not None:
            top_infos.append(item)
            top_img_is.append(img_i)

    return top_infos, top_img_is, subset_label


def get_nsd_sample_basename(feat: int, ds: dict) -> str | None:
    """Return the NSD basename for the top image of a feature, or None."""
    if ds.get('nsd_top_img_idx') is None:
        return None
    top_i = ds['nsd_top_img_idx'][feat, 0].item()
    if top_i < 0:
        return None
    return os.path.splitext(os.path.basename(ds['image_paths'][top_i]))[0]