File size: 1,308 Bytes
6835659
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional, Dict

import numpy as np

from src.embeddings.similarity import cosine_similarity


@dataclass(frozen=True)
class MSCIResult:
    """Immutable record of pairwise similarity scores and the combined MSCI value.

    Instances are frozen, so fields cannot be reassigned after construction.
    Note that the ``weights`` dict itself is still a mutable object.

    The field descriptions below reflect how ``compute_msci_v0`` in this
    module populates the record; other producers may differ.
    """

    # Similarity between the text and image embeddings (rounded to 4 decimals
    # by compute_msci_v0).
    st_i: float
    # Similarity between the text and audio embeddings (rounded to 4 decimals
    # by compute_msci_v0).
    st_a: float
    # Similarity between the image and audio embeddings, or None when the
    # image-audio pair was not included in the computation.
    si_a: Optional[float]
    # Weighted average of the included pairwise similarities.
    msci: float
    # The weights that took part in the average, keyed "w_ti", "w_ta" and
    # (only when the image-audio pair is included) "w_ia".
    weights: Dict[str, float]


def compute_msci_v0(
    emb_text: np.ndarray,
    emb_image: np.ndarray,
    emb_audio: np.ndarray,
    include_image_audio: bool = True,
    w_ti: float = 0.45,
    w_ta: float = 0.45,
    w_ia: float = 0.10,
) -> MSCIResult:
    """Combine pairwise embedding similarities into a single weighted score.

    The text-image and text-audio similarities are always computed with
    ``cosine_similarity``; the image-audio similarity is computed only when
    ``include_image_audio`` is truthy. The score is the weighted sum of the
    included similarities divided by the sum of the weights that were used,
    so the weights do not need to add up to 1.

    Args:
        emb_text: Text embedding passed to ``cosine_similarity``.
        emb_image: Image embedding passed to ``cosine_similarity``.
        emb_audio: Audio embedding passed to ``cosine_similarity``.
        include_image_audio: Whether the image-audio pair contributes to the
            score and is reported in the result.
        w_ti: Weight of the text-image similarity.
        w_ta: Weight of the text-audio similarity.
        w_ia: Weight of the image-audio similarity (ignored when
            ``include_image_audio`` is falsy).

    Returns:
        An ``MSCIResult`` whose similarities and score are rounded to four
        decimal places and whose ``weights`` dict holds the weights exactly
        as passed in (not normalised), limited to the pairs that were used.

    NOTE(review): the embeddings' expected shapes and the return type of
    ``cosine_similarity`` are not visible here -- confirm against that helper.
    No guard exists for weights that sum to zero.
    """
    sim_text_image = cosine_similarity(emb_text, emb_image)
    sim_text_audio = cosine_similarity(emb_text, emb_audio)

    sim_image_audio = None
    if include_image_audio:
        sim_image_audio = cosine_similarity(emb_image, emb_audio)

    # Start from the two text-anchored pairs, which are always part of the score.
    weight_sum = w_ti + w_ta
    weighted_sims = w_ti * sim_text_image + w_ta * sim_text_audio
    used_weights = {"w_ti": w_ti, "w_ta": w_ta}

    if include_image_audio:
        # Fold in the optional image-audio pair; a falsy similarity counts as 0.0.
        weight_sum = weight_sum + w_ia
        weighted_sims = weighted_sims + w_ia * (sim_image_audio or 0.0)
        used_weights["w_ia"] = w_ia

    score = weighted_sims / weight_sum

    return MSCIResult(
        st_i=float(round(sim_text_image, 4)),
        st_a=float(round(sim_text_audio, 4)),
        si_a=None if sim_image_audio is None else float(round(sim_image_audio, 4)),
        msci=float(round(score, 4)),
        weights=used_weights,
    )