| | from __future__ import annotations |
| |
|
| | from dataclasses import dataclass |
| | from typing import Optional, Dict |
| |
|
| | import numpy as np |
| |
|
| | from src.embeddings.similarity import cosine_similarity |
| |
|
| |
|
@dataclass(frozen=True)
class MSCIResult:
    """Immutable record of the scores produced by ``compute_msci_v0``.

    Attributes:
        st_i: Text-to-image cosine similarity.
        st_a: Text-to-audio cosine similarity.
        si_a: Image-to-audio cosine similarity, or ``None`` when that pair
            was left out of the computation.
        msci: Weighted average of the similarities that were included.
        weights: The weights that went into ``msci``, keyed by ``"w_ti"``,
            ``"w_ta"`` and (only when the image-audio pair is included)
            ``"w_ia"``.
    """

    st_i: float
    st_a: float
    si_a: Optional[float]
    msci: float
    weights: Dict[str, float]
| |
|
| |
|
def compute_msci_v0(
    emb_text: np.ndarray,
    emb_image: np.ndarray,
    emb_audio: np.ndarray,
    include_image_audio: bool = True,
    w_ti: float = 0.45,
    w_ta: float = 0.45,
    w_ia: float = 0.10,
) -> MSCIResult:
    """Combine pairwise cosine similarities into a single MSCI score.

    The score is the weighted average of the text-image and text-audio
    similarities and, when ``include_image_audio`` is true, the image-audio
    similarity as well. Weights are normalised by their sum, so they do not
    need to add up to 1.

    Args:
        emb_text: Text embedding passed to ``cosine_similarity``.
        emb_image: Image embedding passed to ``cosine_similarity``.
        emb_audio: Audio embedding passed to ``cosine_similarity``.
        include_image_audio: Whether the image-audio pair contributes to the
            score. When false, ``w_ia`` is ignored and ``si_a`` is ``None``.
        w_ti: Weight of the text-image similarity.
        w_ta: Weight of the text-audio similarity.
        w_ia: Weight of the image-audio similarity.

    Returns:
        An ``MSCIResult`` whose similarity and score fields are rounded to
        four decimal places, together with the weights that were applied.

    Raises:
        ValueError: If the weights that take part in the score sum to zero,
            which would make the normalisation undefined.
    """
    if include_image_audio:
        weights = {"w_ti": w_ti, "w_ta": w_ta, "w_ia": w_ia}
    else:
        weights = {"w_ti": w_ti, "w_ta": w_ta}

    # Validate before doing any similarity work: a zero sum would otherwise
    # surface as a bare ZeroDivisionError (or NaN/inf with NumPy scalars).
    total = sum(weights.values())
    if total == 0:
        raise ValueError(
            f"MSCI weights must not sum to zero, got {weights}"
        )

    # NOTE(review): cosine_similarity is assumed to return a scalar that
    # supports round() -- confirm against src.embeddings.similarity.
    st_i = cosine_similarity(emb_text, emb_image)
    st_a = cosine_similarity(emb_text, emb_audio)
    weighted_sum = w_ti * st_i + w_ta * st_a

    si_a = None
    if include_image_audio:
        si_a = cosine_similarity(emb_image, emb_audio)
        weighted_sum += w_ia * si_a

    msci = weighted_sum / total

    return MSCIResult(
        st_i=float(round(st_i, 4)),
        st_a=float(round(st_a, 4)),
        si_a=None if si_a is None else float(round(si_a, 4)),
        msci=float(round(msci, 4)),
        weights=weights,
    )
| |
|