# Uploaded by pratik-250620 via huggingface_hub ("Upload folder using huggingface_hub").
# Commit: 6835659 (verified).
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional, Dict
import numpy as np
from src.embeddings.similarity import cosine_similarity
@dataclass(eq=True, frozen=True)
class MSCIResult:
    """Immutable bundle of the similarities and score produced by ``compute_msci_v0``.

    Instances are frozen: assigning to a field raises
    ``dataclasses.FrozenInstanceError``. Because ``weights`` is a plain dict,
    calling ``hash()`` on an instance raises ``TypeError`` even though the
    class is frozen.
    """

    # Cosine similarity between the text and image embeddings.
    st_i: float
    # Cosine similarity between the text and audio embeddings.
    st_a: float
    # Cosine similarity between the image and audio embeddings, or None when
    # the image-audio term was not included.
    si_a: Optional[float]
    # Weighted average of the similarities above.
    msci: float
    # The weights that took part in the average, keyed "w_ti" / "w_ta" / "w_ia".
    weights: Dict[str, float]
def compute_msci_v0(
    emb_text: np.ndarray,
    emb_image: np.ndarray,
    emb_audio: np.ndarray,
    include_image_audio: bool = True,
    w_ti: float = 0.45,
    w_ta: float = 0.45,
    w_ia: float = 0.10,
) -> MSCIResult:
    """Compute the v0 multimodal semantic coherence index (MSCI).

    The score is the weighted average of pairwise cosine similarities between
    the text, image and audio embeddings. The image-audio similarity is only
    computed (and ``w_ia`` only used) when ``include_image_audio`` is true.
    Otherwise the average is taken over the two text pairs alone.

    Args:
        emb_text: Text embedding.
        emb_image: Image embedding.
        emb_audio: Audio embedding.
        include_image_audio: Whether to include the image-audio similarity.
        w_ti: Weight of the text-image similarity.
        w_ta: Weight of the text-audio similarity.
        w_ia: Weight of the image-audio similarity; ignored when
            ``include_image_audio`` is false.

    Returns:
        An ``MSCIResult`` holding the component similarities and the
        aggregate score (each rounded to 4 decimal places), plus the weights
        that took part in the average.

    Raises:
        ValueError: If the participating weights do not sum to a positive
            number. A zero (or NaN) sum would otherwise make the
            normalisation divide by zero or produce a meaningless score.
    """
    # Only the weights of the pairs that actually participate are normalised.
    weights: Dict[str, float] = {"w_ti": w_ti, "w_ta": w_ta}
    if include_image_audio:
        weights["w_ia"] = w_ia
    total = sum(weights.values())
    # `not total > 0` also rejects NaN, which `total <= 0` would let through.
    if not total > 0:
        raise ValueError(
            f"MSCI weights must sum to a positive value, got {weights} (sum={total})"
        )

    st_i = cosine_similarity(emb_text, emb_image)
    st_a = cosine_similarity(emb_text, emb_audio)
    # Accumulate in the same order as the original formula so results are
    # bit-identical: ((w_ti*st_i + w_ta*st_a) + w_ia*si_a) / total.
    weighted = w_ti * st_i + w_ta * st_a
    si_a: Optional[float] = None
    if include_image_audio:
        si_a = cosine_similarity(emb_image, emb_audio)
        weighted = weighted + w_ia * si_a
    msci = weighted / total

    return MSCIResult(
        st_i=float(round(st_i, 4)),
        st_a=float(round(st_a, 4)),
        si_a=float(round(si_a, 4)) if si_a is not None else None,
        msci=float(round(msci, 4)),
        weights=weights,
    )