File size: 4,789 Bytes
d6c8a4f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
from __future__ import annotations

from functools import lru_cache
from typing import Iterable, Mapping

import numpy as np

# A stimulus is a read-only mapping carrying at least "dataset_name" and
# "image_identifier" (see stimulus_key below for the fields actually read).
Stimulus = Mapping[str, str]

# Hard-coded model names used to fabricate deterministic demo embeddings.
# The family grouping for these names lives in get_dummy_model_embeddings.
_DUMMY_MODELS = [
    "vit_base_patch16_224",
    "vit_large_patch16_224",
    "resnet50",
    "resnet101",
    "convnext_base",
    "convnext_large",
    "deit_base_patch16_224",
    "clip_vit_b32",
    "swin_tiny_patch4_window7_224",
]

# Fixed demo stimuli: 8 CIFAR-100 test images and 8 ImageNet-1k val images.
# Each entry provides exactly the two fields stimulus_key requires.
_DUMMY_STIMULI = [
    {"dataset_name": "cifar100", "image_identifier": "test/airplane/image_0001.png"},
    {"dataset_name": "cifar100", "image_identifier": "test/bear/image_0007.png"},
    {"dataset_name": "cifar100", "image_identifier": "test/bottle/image_0012.png"},
    {"dataset_name": "cifar100", "image_identifier": "test/bus/image_0021.png"},
    {"dataset_name": "cifar100", "image_identifier": "test/girl/image_0033.png"},
    {"dataset_name": "cifar100", "image_identifier": "test/keyboard/image_0044.png"},
    {"dataset_name": "cifar100", "image_identifier": "test/rocket/image_0051.png"},
    {"dataset_name": "cifar100", "image_identifier": "test/whale/image_0068.png"},
    {"dataset_name": "imagenet1k", "image_identifier": "val/n01440764/ILSVRC2012_val_00000964.JPEG"},
    {"dataset_name": "imagenet1k", "image_identifier": "val/n02123159/ILSVRC2012_val_00001459.JPEG"},
    {"dataset_name": "imagenet1k", "image_identifier": "val/n03255030/ILSVRC2012_val_00001903.JPEG"},
    {"dataset_name": "imagenet1k", "image_identifier": "val/n03445777/ILSVRC2012_val_00003572.JPEG"},
    {"dataset_name": "imagenet1k", "image_identifier": "val/n03729826/ILSVRC2012_val_00005336.JPEG"},
    {"dataset_name": "imagenet1k", "image_identifier": "val/n03902125/ILSVRC2012_val_00006614.JPEG"},
    {"dataset_name": "imagenet1k", "image_identifier": "val/n04254777/ILSVRC2012_val_00007190.JPEG"},
    {"dataset_name": "imagenet1k", "image_identifier": "val/n04557648/ILSVRC2012_val_00009024.JPEG"},
]


def list_dummy_models() -> list[str]:
    """Return a fresh copy of the dummy model names (safe to mutate)."""
    return _DUMMY_MODELS.copy()


def list_dummy_stimuli() -> list[dict[str, str]]:
    """Return per-call copies of the dummy stimuli so callers cannot mutate the originals."""
    copies: list[dict[str, str]] = []
    for stimulus in _DUMMY_STIMULI:
        copies.append({**stimulus})
    return copies


def stimulus_key(stimulus: Stimulus) -> str:
    """Build the canonical ``<dataset>::<image>`` lookup key for *stimulus*.

    Both fields are whitespace-stripped before use.

    Raises:
        ValueError: if either ``dataset_name`` or ``image_identifier`` is
            absent or blank after stripping.
    """
    dataset = stimulus.get("dataset_name", "").strip()
    image = stimulus.get("image_identifier", "").strip()
    if dataset and image:
        return f"{dataset}::{image}"
    raise ValueError("Stimulus must include dataset_name and image_identifier.")


def resolve_stimulus_indices(
    selected_stimuli: Iterable[Stimulus | str],
    available_stimuli: Iterable[Stimulus],
) -> list[int]:
    """Map each selection to its position within *available_stimuli*.

    Selections may be given either as stimulus mappings or as pre-built
    key strings (the ``stimulus_key`` format).

    Raises:
        ValueError: if the selection is empty, contains duplicates, or
            names a stimulus not present in *available_stimuli*.
    """
    index_by_key: dict[str, int] = {}
    for position, stimulus in enumerate(available_stimuli):
        index_by_key[stimulus_key(stimulus)] = position

    # Normalize every selection to its string key.
    requested = [
        item if isinstance(item, str) else stimulus_key(item)
        for item in selected_stimuli
    ]

    if not requested:
        raise ValueError("Select at least one stimulus.")
    if len(set(requested)) != len(requested):
        raise ValueError("Stimulus selections must be unique.")

    unknown = [key for key in requested if key not in index_by_key]
    if unknown:
        unknown_str = ", ".join(unknown)
        raise ValueError(f"Unknown stimuli requested: {unknown_str}")

    return [index_by_key[key] for key in requested]


@lru_cache(maxsize=1)
def get_dummy_model_embeddings() -> dict[str, np.ndarray]:
    """Fabricate a deterministic (num_stimuli, 64) float32 embedding per dummy model.

    Embeddings are built from a fixed-seed RNG (2026): a shared global
    component, a per-family component, and a small per-model noise term,
    so models in the same family produce more similar embeddings.

    NOTE(review): the output depends on the exact order of rng.normal()
    draws below — do not reorder these calls. Also, lru_cache returns the
    SAME dict/array objects on every call; callers presumably must not
    mutate them — confirm with call sites.
    """
    rng = np.random.default_rng(2026)
    models = list_dummy_models()
    stimuli = list_dummy_stimuli()
    num_samples = len(stimuli)
    dim = 64  # embedding width shared by every dummy model

    # Model -> family grouping; families share a common base signal.
    family_by_model = {
        "vit_base_patch16_224": "vit",
        "vit_large_patch16_224": "vit",
        "resnet50": "resnet",
        "resnet101": "resnet",
        "convnext_base": "convnext",
        "convnext_large": "convnext",
        "deit_base_patch16_224": "vit",
        "clip_vit_b32": "vit",
        "swin_tiny_patch4_window7_224": "swin",
    }

    # Component shared by all models regardless of family.
    global_base = rng.normal(size=(num_samples, dim)).astype(np.float32)
    family_bases = {}
    # sorted() makes the family iteration (and hence RNG draw order) deterministic.
    for family in sorted(set(family_by_model.values())):
        family_noise = rng.normal(size=(num_samples, dim)).astype(np.float32)
        family_stimulus = rng.normal(size=(num_samples, dim)).astype(np.float32)
        # Separate family-level structure so cross-family CKA drops more.
        family_bases[family] = 0.15 * global_base + 0.55 * family_noise + 0.30 * family_stimulus

    # Per-model noise magnitude grows linearly across the model list.
    scales = np.linspace(0.02, 0.08, len(models))

    embeddings: dict[str, np.ndarray] = {}
    for model_name, scale in zip(models, scales):
        # Unknown models would fall back to the global base (defensive default).
        family = family_by_model.get(model_name, "other")
        base = family_bases.get(family, global_base)
        noise = rng.normal(size=(num_samples, dim)).astype(np.float32)
        embeddings[model_name] = (base + scale * noise).astype(np.float32)

    return embeddings