| from __future__ import annotations |
|
|
| from functools import lru_cache |
| from typing import Iterable, Mapping |
|
|
| import numpy as np |
|
|
# A stimulus record: maps string field names (here "dataset_name" and
# "image_identifier") to string values.
Stimulus = Mapping[str, str]


# Placeholder model names spanning several vision-architecture families
# (ViT/DeiT/CLIP, ResNet, ConvNeXt, Swin); consumed by list_dummy_models().
_DUMMY_MODELS = [
    "vit_base_patch16_224",
    "vit_large_patch16_224",
    "resnet50",
    "resnet101",
    "convnext_base",
    "convnext_large",
    "deit_base_patch16_224",
    "clip_vit_b32",
    "swin_tiny_patch4_window7_224",
]


# Placeholder stimuli: 8 CIFAR-100 test images and 8 ImageNet-1k validation
# images, identified by relative file paths; consumed by list_dummy_stimuli().
_DUMMY_STIMULI = [
    {"dataset_name": "cifar100", "image_identifier": "test/airplane/image_0001.png"},
    {"dataset_name": "cifar100", "image_identifier": "test/bear/image_0007.png"},
    {"dataset_name": "cifar100", "image_identifier": "test/bottle/image_0012.png"},
    {"dataset_name": "cifar100", "image_identifier": "test/bus/image_0021.png"},
    {"dataset_name": "cifar100", "image_identifier": "test/girl/image_0033.png"},
    {"dataset_name": "cifar100", "image_identifier": "test/keyboard/image_0044.png"},
    {"dataset_name": "cifar100", "image_identifier": "test/rocket/image_0051.png"},
    {"dataset_name": "cifar100", "image_identifier": "test/whale/image_0068.png"},
    {"dataset_name": "imagenet1k", "image_identifier": "val/n01440764/ILSVRC2012_val_00000964.JPEG"},
    {"dataset_name": "imagenet1k", "image_identifier": "val/n02123159/ILSVRC2012_val_00001459.JPEG"},
    {"dataset_name": "imagenet1k", "image_identifier": "val/n03255030/ILSVRC2012_val_00001903.JPEG"},
    {"dataset_name": "imagenet1k", "image_identifier": "val/n03445777/ILSVRC2012_val_00003572.JPEG"},
    {"dataset_name": "imagenet1k", "image_identifier": "val/n03729826/ILSVRC2012_val_00005336.JPEG"},
    {"dataset_name": "imagenet1k", "image_identifier": "val/n03902125/ILSVRC2012_val_00006614.JPEG"},
    {"dataset_name": "imagenet1k", "image_identifier": "val/n04254777/ILSVRC2012_val_00007190.JPEG"},
    {"dataset_name": "imagenet1k", "image_identifier": "val/n04557648/ILSVRC2012_val_00009024.JPEG"},
]
|
|
|
|
def list_dummy_models() -> list[str]:
    """Return a fresh copy of the dummy model-name registry."""
    return _DUMMY_MODELS.copy()
|
|
|
|
def list_dummy_stimuli() -> list[dict[str, str]]:
    """Return independent shallow copies of the dummy stimulus records."""
    copies: list[dict[str, str]] = []
    for record in _DUMMY_STIMULI:
        copies.append({**record})
    return copies
|
|
|
|
def stimulus_key(stimulus: Stimulus) -> str:
    """Build the canonical ``dataset::image`` lookup key for *stimulus*.

    Both fields are whitespace-stripped before use.

    Raises:
        ValueError: if either field is absent or blank after stripping.
    """
    dataset = stimulus.get("dataset_name", "").strip()
    image = stimulus.get("image_identifier", "").strip()
    if dataset and image:
        return f"{dataset}::{image}"
    raise ValueError("Stimulus must include dataset_name and image_identifier.")
|
|
|
|
def resolve_stimulus_indices(
    selected_stimuli: Iterable[Stimulus | str],
    available_stimuli: Iterable[Stimulus],
) -> list[int]:
    """Map each selected stimulus to its index within *available_stimuli*.

    Selections may be given either as stimulus records or as precomputed
    ``dataset::image`` key strings; the returned indices preserve the
    selection order.

    Raises:
        ValueError: if the selection is empty, contains duplicates, or
            references a stimulus not present in *available_stimuli*.
    """
    index_by_key: dict[str, int] = {}
    for position, stimulus in enumerate(available_stimuli):
        index_by_key[stimulus_key(stimulus)] = position

    requested = [
        item if isinstance(item, str) else stimulus_key(item)
        for item in selected_stimuli
    ]

    if not requested:
        raise ValueError("Select at least one stimulus.")
    if len(set(requested)) != len(requested):
        raise ValueError("Stimulus selections must be unique.")

    unknown = [key for key in requested if key not in index_by_key]
    if unknown:
        raise ValueError(f"Unknown stimuli requested: {', '.join(unknown)}")

    return [index_by_key[key] for key in requested]
|
|
|
|
@lru_cache(maxsize=1)
def get_dummy_model_embeddings() -> dict[str, np.ndarray]:
    """Synthesize deterministic per-model embedding matrices.

    Each model maps to a ``(num_stimuli, 64)`` float32 array composed of a
    shared global component, a family-level component, and a small
    model-specific noise term (noise scale grows linearly across the model
    list). The generator is seeded, so output is reproducible, and the
    ``lru_cache`` means repeated calls return the same dict object.
    """
    rng = np.random.default_rng(2026)
    model_names = list_dummy_models()
    stimuli = list_dummy_stimuli()
    n_stimuli = len(stimuli)
    embedding_dim = 64

    # Architecture family per model; families share a base embedding so that
    # same-family models end up more similar to each other.
    model_family = {
        "vit_base_patch16_224": "vit",
        "vit_large_patch16_224": "vit",
        "resnet50": "resnet",
        "resnet101": "resnet",
        "convnext_base": "convnext",
        "convnext_large": "convnext",
        "deit_base_patch16_224": "vit",
        "clip_vit_b32": "vit",
        "swin_tiny_patch4_window7_224": "swin",
    }

    def _draw() -> np.ndarray:
        # One (n_stimuli, embedding_dim) standard-normal sample. Draw order
        # must stay fixed for reproducibility: global base first, then two
        # draws per sorted family, then one draw per model.
        return rng.normal(size=(n_stimuli, embedding_dim)).astype(np.float32)

    shared_base = _draw()
    family_base: dict[str, np.ndarray] = {}
    for family in sorted(set(model_family.values())):
        family_noise = _draw()
        family_stimulus = _draw()
        family_base[family] = 0.15 * shared_base + 0.55 * family_noise + 0.30 * family_stimulus

    # Per-model noise magnitude, increasing with position in the model list.
    noise_scales = np.linspace(0.02, 0.08, len(model_names))

    result: dict[str, np.ndarray] = {}
    for name, scale in zip(model_names, noise_scales):
        base = family_base.get(model_family.get(name, "other"), shared_base)
        result[name] = (base + scale * _draw()).astype(np.float32)

    return result
|
|