File size: 4,840 Bytes
d6c8a4f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ebf9b32
d6c8a4f
 
 
ebf9b32
d6c8a4f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ebf9b32
d6c8a4f
 
 
ebf9b32
d6c8a4f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
from __future__ import annotations

from itertools import combinations
from typing import Iterable

import numpy as np

from src.cka.compute import linear_cka
from src.hackathon.data import get_dummy_model_embeddings, list_dummy_stimuli, resolve_stimulus_indices
from src.hackathon.modal_client import (
    is_modal_enabled,
    score_blue_with_pairwise as modal_score_blue_with_pairwise,
    score_red_with_pairwise as modal_score_red_with_pairwise,
)


def _validate_models(model_names: Iterable[str], embeddings_by_model: dict[str, np.ndarray]) -> list[str]:
    names = [name.strip() for name in model_names if name.strip()]
    if len(names) < 2:
        raise ValueError("Select at least two models.")
    if len(names) != len(set(names)):
        raise ValueError("Model selections must be unique.")

    missing = [name for name in names if name not in embeddings_by_model]
    if missing:
        missing_str = ", ".join(missing)
        raise ValueError(f"Unknown models requested: {missing_str}")

    return names


def _format_score(score: float) -> float:
    return round(float(score), 4)


def _pairwise_scores(
    model_names: Iterable[str],
    embeddings_by_model: dict[str, np.ndarray],
) -> tuple[float, list[dict[str, float | str]]]:
    scores = []
    pairwise: list[dict[str, float | str]] = []
    for model_a, model_b in combinations(model_names, 2):
        score = linear_cka(embeddings_by_model[model_a], embeddings_by_model[model_b])
        scores.append(score)
        pairwise.append(
            {
                "Model A": model_a,
                "Model B": model_b,
                "CKA": _format_score(score),
            }
        )

    if not scores:
        return 0.0, []
    return float(np.mean(scores)), pairwise


def score_blue_with_pairwise(
    model_names: Iterable[str],
    *,
    embeddings_by_model: dict[str, np.ndarray] | None = None,
    submission_id: str | None = None,
    submitter: str | None = None,
    hf_link: str | None = None,
) -> tuple[float, list[dict[str, float | str]]]:
    """Return the blue score and its per-pair breakdown for a model selection.

    When no embeddings are supplied and ``is_modal_enabled()`` is true, the
    work is delegated to ``modal_score_blue_with_pairwise`` and its pair
    entries (read via the keys ``"model_a"``, ``"model_b"``, ``"cka"``) are
    converted to display rows. Otherwise the score is computed locally, using
    ``get_dummy_model_embeddings()`` when no embeddings were supplied.

    Args:
        model_names: Models to compare.
        embeddings_by_model: Optional mapping from model name to embeddings;
            supplying it forces the local computation.
        submission_id: Forwarded to the remote scorer only.
        submitter: Forwarded to the remote scorer only.
        hf_link: Forwarded to the remote scorer only.

    Returns:
        ``(score, rows)`` where each row has the keys ``"Model A"``,
        ``"Model B"`` and ``"CKA"``.

    Raises:
        ValueError: On the local path, if ``_validate_models`` rejects the
            selection.
    """
    use_remote = embeddings_by_model is None and is_modal_enabled()
    if use_remote:
        remote_avg, remote_pairs = modal_score_blue_with_pairwise(
            model_names,
            submission_id=submission_id,
            submitter=submitter,
            hf_link=hf_link,
        )
        rows = []
        for entry in remote_pairs:
            rows.append(
                {
                    "Model A": entry["model_a"],
                    "Model B": entry["model_b"],
                    "CKA": _format_score(entry["cka"]),
                }
            )
        return float(remote_avg), rows

    # Local path: fall back to the dummy embeddings when none were passed in.
    table = get_dummy_model_embeddings() if embeddings_by_model is None else embeddings_by_model
    validated = _validate_models(model_names, table)
    mean_cka, rows = _pairwise_scores(validated, table)
    return float(mean_cka), rows


def score_blue(
    model_names: Iterable[str],
    *,
    embeddings_by_model: dict[str, np.ndarray] | None = None,
) -> float:
    """Return only the blue score from ``score_blue_with_pairwise``.

    Args:
        model_names: Models to compare.
        embeddings_by_model: Optional mapping from model name to embeddings,
            forwarded unchanged.

    Returns:
        The first element of the ``score_blue_with_pairwise`` result, as a
        built-in float.
    """
    mean_similarity, _rows = score_blue_with_pairwise(
        model_names,
        embeddings_by_model=embeddings_by_model,
    )
    return float(mean_similarity)


def score_red_with_pairwise(
    selected_stimuli: Iterable[dict[str, str] | str],
    *,
    embeddings_by_model: dict[str, np.ndarray] | None = None,
    stimuli_catalog: Iterable[dict[str, str]] | None = None,
    submission_id: str | None = None,
    submitter: str | None = None,
    hf_link: str | None = None,
) -> tuple[float, list[dict[str, float | str]]]:
    """Return the red score and its per-pair breakdown for a stimulus selection.

    When no embeddings are supplied and ``is_modal_enabled()`` is true, the
    work is delegated to ``modal_score_red_with_pairwise``; its score is
    returned as a float and its pair entries (read via the keys
    ``"model_a"``, ``"model_b"``, ``"cka"``) are converted to display rows.
    Otherwise every model in the embeddings mapping is restricted to the
    selected stimuli and the score is ``1.0`` minus the mean pairwise CKA.

    Args:
        selected_stimuli: Stimuli to score, resolved to indices by
            ``resolve_stimulus_indices``.
        embeddings_by_model: Optional mapping from model name to embeddings;
            defaults to ``get_dummy_model_embeddings()`` on the local path.
        stimuli_catalog: Optional catalog used to resolve the selection;
            defaults to ``list_dummy_stimuli()`` on the local path.
        submission_id: Forwarded to the remote scorer only.
        submitter: Forwarded to the remote scorer only.
        hf_link: Forwarded to the remote scorer only.

    Returns:
        ``(score, rows)`` where each row has the keys ``"Model A"``,
        ``"Model B"`` and ``"CKA"``.

    Raises:
        ValueError: On the local path, if ``_validate_models`` rejects the
            embeddings' model names or fewer than two stimuli are resolved.
    """
    if embeddings_by_model is None and is_modal_enabled():
        remote_score, remote_pairs = modal_score_red_with_pairwise(
            selected_stimuli,
            submission_id=submission_id,
            submitter=submitter,
            hf_link=hf_link,
        )
        rows = []
        for entry in remote_pairs:
            rows.append(
                {
                    "Model A": entry["model_a"],
                    "Model B": entry["model_b"],
                    "CKA": _format_score(entry["cka"]),
                }
            )
        return float(remote_score), rows

    # Local path: fill in the dummy fixtures for anything not supplied.
    table = embeddings_by_model if embeddings_by_model is not None else get_dummy_model_embeddings()
    catalog = stimuli_catalog if stimuli_catalog is not None else list_dummy_stimuli()

    validated = _validate_models(table.keys(), table)
    picked = resolve_stimulus_indices(selected_stimuli, catalog)
    if len(picked) < 2:
        raise ValueError("Select at least two stimuli.")

    # Index each model's embeddings with the resolved stimulus indices.
    subset = {}
    for model_name in validated:
        subset[model_name] = table[model_name][picked]

    mean_cka, rows = _pairwise_scores(validated, subset)
    return float(1.0 - mean_cka), rows


def score_red(
    selected_stimuli: Iterable[dict[str, str] | str],
    *,
    embeddings_by_model: dict[str, np.ndarray] | None = None,
    stimuli_catalog: Iterable[dict[str, str]] | None = None,
) -> float:
    """Return only the red score from ``score_red_with_pairwise``.

    Args:
        selected_stimuli: Stimuli to score, forwarded unchanged.
        embeddings_by_model: Optional mapping from model name to embeddings,
            forwarded unchanged.
        stimuli_catalog: Optional stimulus catalog, forwarded unchanged.

    Returns:
        The first element of the ``score_red_with_pairwise`` result, as a
        built-in float.
    """
    forwarded = {
        "embeddings_by_model": embeddings_by_model,
        "stimuli_catalog": stimuli_catalog,
    }
    red_score, _rows = score_red_with_pairwise(selected_stimuli, **forwarded)
    return float(red_score)