# Initial commit: ICLR 2026 Representational Alignment Challenge (d6c8a4f)
from __future__ import annotations
from functools import lru_cache
from typing import Iterable, Mapping
import numpy as np
# A stimulus is a mapping with at least "dataset_name" and "image_identifier" keys.
Stimulus = Mapping[str, str]

# Placeholder model identifiers used before real submissions are wired in.
_DUMMY_MODELS = [
    "vit_base_patch16_224",
    "vit_large_patch16_224",
    "resnet50",
    "resnet101",
    "convnext_base",
    "convnext_large",
    "deit_base_patch16_224",
    "clip_vit_b32",
    "swin_tiny_patch4_window7_224",
]

# Placeholder stimuli: a small mix of CIFAR-100 test images and ImageNet-1k
# validation images, identified by dataset-relative paths.
_DUMMY_STIMULI = [
    {"dataset_name": "cifar100", "image_identifier": "test/airplane/image_0001.png"},
    {"dataset_name": "cifar100", "image_identifier": "test/bear/image_0007.png"},
    {"dataset_name": "cifar100", "image_identifier": "test/bottle/image_0012.png"},
    {"dataset_name": "cifar100", "image_identifier": "test/bus/image_0021.png"},
    {"dataset_name": "cifar100", "image_identifier": "test/girl/image_0033.png"},
    {"dataset_name": "cifar100", "image_identifier": "test/keyboard/image_0044.png"},
    {"dataset_name": "cifar100", "image_identifier": "test/rocket/image_0051.png"},
    {"dataset_name": "cifar100", "image_identifier": "test/whale/image_0068.png"},
    {"dataset_name": "imagenet1k", "image_identifier": "val/n01440764/ILSVRC2012_val_00000964.JPEG"},
    {"dataset_name": "imagenet1k", "image_identifier": "val/n02123159/ILSVRC2012_val_00001459.JPEG"},
    {"dataset_name": "imagenet1k", "image_identifier": "val/n03255030/ILSVRC2012_val_00001903.JPEG"},
    {"dataset_name": "imagenet1k", "image_identifier": "val/n03445777/ILSVRC2012_val_00003572.JPEG"},
    {"dataset_name": "imagenet1k", "image_identifier": "val/n03729826/ILSVRC2012_val_00005336.JPEG"},
    {"dataset_name": "imagenet1k", "image_identifier": "val/n03902125/ILSVRC2012_val_00006614.JPEG"},
    {"dataset_name": "imagenet1k", "image_identifier": "val/n04254777/ILSVRC2012_val_00007190.JPEG"},
    {"dataset_name": "imagenet1k", "image_identifier": "val/n04557648/ILSVRC2012_val_00009024.JPEG"},
]
def list_dummy_models() -> list[str]:
    """Return a fresh copy of the dummy model names, so callers can't mutate the module table."""
    return [*_DUMMY_MODELS]
def list_dummy_stimuli() -> list[dict[str, str]]:
    """Return per-call copies of the dummy stimulus records (each dict is new)."""
    copies: list[dict[str, str]] = []
    for record in _DUMMY_STIMULI:
        copies.append({**record})
    return copies
def stimulus_key(stimulus: Stimulus) -> str:
    """Build the canonical ``dataset::image`` lookup key for a stimulus mapping.

    Both fields are whitespace-stripped; a missing or blank field raises ValueError.
    """
    dataset = stimulus.get("dataset_name", "").strip()
    image = stimulus.get("image_identifier", "").strip()
    if dataset and image:
        return f"{dataset}::{image}"
    raise ValueError("Stimulus must include dataset_name and image_identifier.")
def resolve_stimulus_indices(
    selected_stimuli: Iterable[Stimulus | str],
    available_stimuli: Iterable[Stimulus],
) -> list[int]:
    """Map each selection (a stimulus mapping or a ready-made key string) to its
    positional index within ``available_stimuli``.

    Raises ValueError when the selection is empty, contains duplicates, or
    references a stimulus not present in ``available_stimuli``.
    """
    index_by_key = {
        stimulus_key(candidate): position
        for position, candidate in enumerate(available_stimuli)
    }
    requested = [
        entry if isinstance(entry, str) else stimulus_key(entry)
        for entry in selected_stimuli
    ]
    if not requested:
        raise ValueError("Select at least one stimulus.")
    if len(set(requested)) != len(requested):
        raise ValueError("Stimulus selections must be unique.")
    unknown = [key for key in requested if key not in index_by_key]
    if unknown:
        unknown_str = ", ".join(unknown)
        raise ValueError(f"Unknown stimuli requested: {unknown_str}")
    return [index_by_key[key] for key in requested]
@lru_cache(maxsize=1)
def get_dummy_model_embeddings() -> dict[str, np.ndarray]:
    """Build deterministic synthetic embeddings for every dummy model.

    Returns a dict mapping model name -> float32 array of shape
    (num_stimuli, 64). Embeddings are seeded (rng seed 2026), so repeated
    processes produce identical values; models in the same architecture
    family share a common base signal so within-family similarity exceeds
    cross-family similarity.

    NOTE(review): the output depends on the exact sequence of ``rng.normal``
    draws below — reordering any draw changes every embedding. Also,
    ``lru_cache`` hands the *same* dict and arrays to every caller; mutating
    them would corrupt the cache — copy before editing.
    """
    rng = np.random.default_rng(2026)
    models = list_dummy_models()
    stimuli = list_dummy_stimuli()
    num_samples = len(stimuli)
    dim = 64  # embedding dimensionality for all synthetic models
    # Architecture family per model; families share a base representation.
    family_by_model = {
        "vit_base_patch16_224": "vit",
        "vit_large_patch16_224": "vit",
        "resnet50": "resnet",
        "resnet101": "resnet",
        "convnext_base": "convnext",
        "convnext_large": "convnext",
        "deit_base_patch16_224": "vit",
        "clip_vit_b32": "vit",
        "swin_tiny_patch4_window7_224": "swin",
    }
    # Signal shared by *all* models, keeping cross-family similarity above chance.
    global_base = rng.normal(size=(num_samples, dim)).astype(np.float32)
    family_bases = {}
    # sorted() fixes the family iteration order, which fixes the RNG draw order.
    for family in sorted(set(family_by_model.values())):
        family_noise = rng.normal(size=(num_samples, dim)).astype(np.float32)
        family_stimulus = rng.normal(size=(num_samples, dim)).astype(np.float32)
        # Separate family-level structure so cross-family CKA drops more.
        family_bases[family] = 0.15 * global_base + 0.55 * family_noise + 0.30 * family_stimulus
    # Per-model noise scale grows linearly across the model list.
    scales = np.linspace(0.02, 0.08, len(models))
    embeddings: dict[str, np.ndarray] = {}
    for model_name, scale in zip(models, scales):
        family = family_by_model.get(model_name, "other")
        # Unknown families fall back to the global base signal.
        base = family_bases.get(family, global_base)
        noise = rng.normal(size=(num_samples, dim)).astype(np.float32)
        embeddings[model_name] = (base + scale * noise).astype(np.float32)
    return embeddings