"""Scoring helpers for blue (model-selection) and red (stimulus-selection) submissions.

Both scores are built from the pairwise ``linear_cka`` value between model
embeddings.  When no embeddings are supplied and the Modal backend is enabled,
scoring is delegated to the Modal client instead of being computed locally.
"""

from __future__ import annotations

from itertools import combinations
from typing import Iterable

import numpy as np

from src.cka.compute import linear_cka
from src.hackathon.data import (
    get_dummy_model_embeddings,
    list_dummy_stimuli,
    resolve_stimulus_indices,
)
from src.hackathon.modal_client import (
    is_modal_enabled,
    score_blue_with_pairwise as modal_score_blue_with_pairwise,
    score_red_with_pairwise as modal_score_red_with_pairwise,
)


def _validate_models(
    model_names: Iterable[str],
    embeddings_by_model: dict[str, np.ndarray],
) -> list[str]:
    """Return the cleaned list of requested model names, or raise.

    Names are stripped of surrounding whitespace and blank entries are dropped.

    Raises:
        ValueError: if fewer than two names remain, if any name is repeated,
            or if a name is not a key of ``embeddings_by_model``.
    """
    stripped = (name.strip() for name in model_names)
    names = [name for name in stripped if name]
    if len(names) < 2:
        raise ValueError("Select at least two models.")
    if len(names) != len(set(names)):
        raise ValueError("Model selections must be unique.")
    missing = [name for name in names if name not in embeddings_by_model]
    if missing:
        missing_str = ", ".join(missing)
        raise ValueError(f"Unknown models requested: {missing_str}")
    return names


def _format_score(score: float) -> float:
    """Round ``score`` to four decimal places as a plain ``float``."""
    return round(float(score), 4)


def _format_modal_pairwise(pairwise: Iterable[dict]) -> list[dict[str, float | str]]:
    """Convert Modal pairwise rows into the local display row format.

    Each input item is read through its ``"model_a"``, ``"model_b"`` and
    ``"cka"`` keys; the output rows use the same ``"Model A"`` / ``"Model B"``
    / ``"CKA"`` keys that ``_pairwise_scores`` produces.
    """
    return [
        {
            "Model A": item["model_a"],
            "Model B": item["model_b"],
            "CKA": _format_score(item["cka"]),
        }
        for item in pairwise
    ]


def _pairwise_scores(
    model_names: Iterable[str],
    embeddings_by_model: dict[str, np.ndarray],
) -> tuple[float, list[dict[str, float | str]]]:
    """Compute ``linear_cka`` for every unordered pair of models.

    Returns:
        ``(mean_cka, rows)`` where ``rows`` holds one dict per pair with the
        rounded score.  The mean is taken over the unrounded scores.  When
        there are no pairs the result is ``(0.0, [])``.
    """
    scores = []
    pairwise: list[dict[str, float | str]] = []
    for model_a, model_b in combinations(model_names, 2):
        score = linear_cka(embeddings_by_model[model_a], embeddings_by_model[model_b])
        scores.append(score)
        pairwise.append(
            {
                "Model A": model_a,
                "Model B": model_b,
                "CKA": _format_score(score),
            }
        )
    if not scores:
        # np.mean([]) would yield NaN with a warning; report a neutral 0.0.
        return 0.0, []
    return float(np.mean(scores)), pairwise


def score_blue_with_pairwise(
    model_names: Iterable[str],
    *,
    embeddings_by_model: dict[str, np.ndarray] | None = None,
    submission_id: str | None = None,
    submitter: str | None = None,
    hf_link: str | None = None,
) -> tuple[float, list[dict[str, float | str]]]:
    """Score a blue submission and return the per-pair breakdown.

    Args:
        model_names: Names of the selected models.
        embeddings_by_model: Optional mapping of model name to embeddings.
            When given, scoring is always done locally.
        submission_id, submitter, hf_link: Submission metadata; only forwarded
            to the Modal client, unused on the local path.

    Returns:
        ``(average_cka, pairwise_rows)``.

    Raises:
        ValueError: on the local path, if the model selection is invalid
            (see ``_validate_models``).
    """
    if embeddings_by_model is None and is_modal_enabled():
        # Remote path: names are passed through as-is, without local validation.
        avg_cka, pairwise = modal_score_blue_with_pairwise(
            model_names,
            submission_id=submission_id,
            submitter=submitter,
            hf_link=hf_link,
        )
        return float(avg_cka), _format_modal_pairwise(pairwise)

    if embeddings_by_model is None:
        embeddings_by_model = get_dummy_model_embeddings()
    model_names = _validate_models(model_names, embeddings_by_model)
    avg_cka, pairwise = _pairwise_scores(model_names, embeddings_by_model)
    return float(avg_cka), pairwise


def score_blue(
    model_names: Iterable[str],
    *,
    embeddings_by_model: dict[str, np.ndarray] | None = None,
) -> float:
    """Return only the average score from ``score_blue_with_pairwise``."""
    avg_cka, _ = score_blue_with_pairwise(model_names, embeddings_by_model=embeddings_by_model)
    return float(avg_cka)


def score_red_with_pairwise(
    selected_stimuli: Iterable[dict[str, str] | str],
    *,
    embeddings_by_model: dict[str, np.ndarray] | None = None,
    stimuli_catalog: Iterable[dict[str, str]] | None = None,
    submission_id: str | None = None,
    submitter: str | None = None,
    hf_link: str | None = None,
) -> tuple[float, list[dict[str, float | str]]]:
    """Score a red submission and return the per-pair breakdown.

    On the local path every model in ``embeddings_by_model`` is used, each
    model's embeddings are restricted to the selected stimuli, and the score
    is ``1.0 - mean pairwise CKA``.

    Args:
        selected_stimuli: The chosen stimuli, resolved to indices through
            ``resolve_stimulus_indices``.
        embeddings_by_model: Optional mapping of model name to embeddings.
            When given, scoring is always done locally.
        stimuli_catalog: Optional catalog used to resolve the selection;
            defaults to ``list_dummy_stimuli()`` on the local path.
        submission_id, submitter, hf_link: Submission metadata; only forwarded
            to the Modal client, unused on the local path.

    Returns:
        ``(score, pairwise_rows)``.

    Raises:
        ValueError: on the local path, if fewer than two stimuli resolve or
            the available models fail ``_validate_models``.
    """
    if embeddings_by_model is None and is_modal_enabled():
        # Remote path: the score from the Modal client is returned unchanged
        # (no ``1.0 - x`` is applied here).
        score, pairwise = modal_score_red_with_pairwise(
            selected_stimuli,
            submission_id=submission_id,
            submitter=submitter,
            hf_link=hf_link,
        )
        return float(score), _format_modal_pairwise(pairwise)

    if embeddings_by_model is None:
        embeddings_by_model = get_dummy_model_embeddings()
    if stimuli_catalog is None:
        stimuli_catalog = list_dummy_stimuli()

    model_names = _validate_models(embeddings_by_model.keys(), embeddings_by_model)
    stimulus_indices = resolve_stimulus_indices(selected_stimuli, stimuli_catalog)
    if len(stimulus_indices) < 2:
        raise ValueError("Select at least two stimuli.")

    # Keep only the rows for the selected stimuli in each model's embeddings.
    filtered = {name: embeddings_by_model[name][stimulus_indices] for name in model_names}
    avg_cka, pairwise = _pairwise_scores(model_names, filtered)
    return float(1.0 - avg_cka), pairwise


def score_red(
    selected_stimuli: Iterable[dict[str, str] | str],
    *,
    embeddings_by_model: dict[str, np.ndarray] | None = None,
    stimuli_catalog: Iterable[dict[str, str]] | None = None,
) -> float:
    """Return only the score from ``score_red_with_pairwise``."""
    score, _ = score_red_with_pairwise(
        selected_stimuli,
        embeddings_by_model=embeddings_by_model,
        stimuli_catalog=stimuli_catalog,
    )
    return float(score)