"""Dummy model/stimulus fixtures with deterministic synthetic embeddings."""

from __future__ import annotations

from functools import lru_cache
from typing import Iterable, Mapping

import numpy as np

# A stimulus is a read-only mapping carrying at least the keys
# "dataset_name" and "image_identifier".
Stimulus = Mapping[str, str]

_DUMMY_MODELS = [
    "vit_base_patch16_224",
    "vit_large_patch16_224",
    "resnet50",
    "resnet101",
    "convnext_base",
    "convnext_large",
    "deit_base_patch16_224",
    "clip_vit_b32",
    "swin_tiny_patch4_window7_224",
]

_DUMMY_STIMULI = [
    {"dataset_name": "cifar100", "image_identifier": "test/airplane/image_0001.png"},
    {"dataset_name": "cifar100", "image_identifier": "test/bear/image_0007.png"},
    {"dataset_name": "cifar100", "image_identifier": "test/bottle/image_0012.png"},
    {"dataset_name": "cifar100", "image_identifier": "test/bus/image_0021.png"},
    {"dataset_name": "cifar100", "image_identifier": "test/girl/image_0033.png"},
    {"dataset_name": "cifar100", "image_identifier": "test/keyboard/image_0044.png"},
    {"dataset_name": "cifar100", "image_identifier": "test/rocket/image_0051.png"},
    {"dataset_name": "cifar100", "image_identifier": "test/whale/image_0068.png"},
    {"dataset_name": "imagenet1k", "image_identifier": "val/n01440764/ILSVRC2012_val_00000964.JPEG"},
    {"dataset_name": "imagenet1k", "image_identifier": "val/n02123159/ILSVRC2012_val_00001459.JPEG"},
    {"dataset_name": "imagenet1k", "image_identifier": "val/n03255030/ILSVRC2012_val_00001903.JPEG"},
    {"dataset_name": "imagenet1k", "image_identifier": "val/n03445777/ILSVRC2012_val_00003572.JPEG"},
    {"dataset_name": "imagenet1k", "image_identifier": "val/n03729826/ILSVRC2012_val_00005336.JPEG"},
    {"dataset_name": "imagenet1k", "image_identifier": "val/n03902125/ILSVRC2012_val_00006614.JPEG"},
    {"dataset_name": "imagenet1k", "image_identifier": "val/n04254777/ILSVRC2012_val_00007190.JPEG"},
    {"dataset_name": "imagenet1k", "image_identifier": "val/n04557648/ILSVRC2012_val_00009024.JPEG"},
]


def list_dummy_models() -> list[str]:
    """Return a fresh copy of the dummy model name list."""
    return _DUMMY_MODELS.copy()


def list_dummy_stimuli() -> list[dict[str, str]]:
    """Return independent copies of the dummy stimulus records."""
    return [record.copy() for record in _DUMMY_STIMULI]


def stimulus_key(stimulus: Stimulus) -> str:
    """Build the canonical ``dataset::identifier`` key for *stimulus*.

    Raises:
        ValueError: if either field is missing or blank after stripping.
    """
    dataset = stimulus.get("dataset_name", "").strip()
    identifier = stimulus.get("image_identifier", "").strip()
    if dataset and identifier:
        return f"{dataset}::{identifier}"
    raise ValueError("Stimulus must include dataset_name and image_identifier.")


def resolve_stimulus_indices(
    selected_stimuli: Iterable[Stimulus | str],
    available_stimuli: Iterable[Stimulus],
) -> list[int]:
    """Map each selection to its position within *available_stimuli*.

    Selections may be stimulus mappings or pre-built key strings.

    Raises:
        ValueError: if the selection is empty, contains duplicates, or
            references a stimulus not present in *available_stimuli*.
    """
    index_by_key = {
        stimulus_key(stimulus): position
        for position, stimulus in enumerate(available_stimuli)
    }
    keys = [
        entry if isinstance(entry, str) else stimulus_key(entry)
        for entry in selected_stimuli
    ]
    if not keys:
        raise ValueError("Select at least one stimulus.")
    if len(set(keys)) != len(keys):
        raise ValueError("Stimulus selections must be unique.")
    missing = [key for key in keys if key not in index_by_key]
    if missing:
        raise ValueError(f"Unknown stimuli requested: {', '.join(missing)}")
    return [index_by_key[key] for key in keys]


@lru_cache(maxsize=1)
def get_dummy_model_embeddings() -> dict[str, np.ndarray]:
    """Generate deterministic synthetic embeddings for every dummy model.

    Models in the same architecture family share a common base signal, so
    within-family similarity stays high while cross-family similarity drops.
    Returns a mapping of model name to a float32 array of shape
    (num_stimuli, 64). The result is cached; treat the arrays as read-only.
    """
    rng = np.random.default_rng(2026)
    model_names = list_dummy_models()
    sample_count = len(list_dummy_stimuli())
    feature_dim = 64
    families = {
        "vit_base_patch16_224": "vit",
        "vit_large_patch16_224": "vit",
        "resnet50": "resnet",
        "resnet101": "resnet",
        "convnext_base": "convnext",
        "convnext_large": "convnext",
        "deit_base_patch16_224": "vit",
        "clip_vit_b32": "vit",
        "swin_tiny_patch4_window7_224": "swin",
    }

    def draw() -> np.ndarray:
        # Every draw goes through the single seeded generator, so the call
        # order alone fixes the output.
        return rng.normal(size=(sample_count, feature_dim)).astype(np.float32)

    shared_base = draw()
    family_bases: dict[str, np.ndarray] = {}
    for family in sorted(set(families.values())):
        family_noise = draw()
        family_stimulus = draw()
        # Separate family-level structure so cross-family CKA drops more.
        family_bases[family] = 0.15 * shared_base + 0.55 * family_noise + 0.30 * family_stimulus

    per_model_scale = np.linspace(0.02, 0.08, len(model_names))
    embeddings: dict[str, np.ndarray] = {}
    for name, scale in zip(model_names, per_model_scale):
        base = family_bases.get(families.get(name, "other"), shared_base)
        embeddings[name] = (base + scale * draw()).astype(np.float32)
    return embeddings