File size: 5,179 Bytes
6835659
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
from __future__ import annotations

import logging
from typing import Any, Dict, Optional

from src.coherence.drift_detector import detect_drift
from src.coherence.scorer import CoherenceScorer
from src.embeddings.aligned_embeddings import AlignedEmbedder
from src.embeddings.similarity import cosine_similarity

# Module-level logger named after this module, so handlers and levels can be
# configured per-module by the application.
logger = logging.getLogger(__name__)


class CoherenceEngine:
    """
    Evaluates multimodal coherence using correct embedding spaces.

    Text-Image similarity (st_i): CLIP shared space
    Text-Audio similarity (st_a): CLAP shared space
    Image-Audio similarity (si_a): cross-space (CLIP vs CLAP) — not directly
        comparable. We omit si_a from scoring by default since it would
        compare embeddings from different model spaces without a trained
        bridge.

    If a trained CrossSpaceBridge is loaded via load_bridge(), si_a is computed
    in the learned bridge space and the full MSCI formula activates:
        MSCI = 0.45 * st_i + 0.45 * st_a + 0.10 * si_a

    When only some pairwise scores are available, MSCI is the weighted mean of
    those scores with the weights renormalised over the available subset; a
    single available score is used as-is, and no scores yields ``None``.
    """

    # Relative weight of each pairwise similarity in the MSCI formula.
    _MSCI_WEIGHTS: Dict[str, float] = {"st_i": 0.45, "st_a": 0.45, "si_a": 0.10}

    def __init__(self, target_dim: int = 512):
        """
        Build the embedder and scorer used by :meth:`evaluate`.

        Args:
            target_dim: Forwarded to ``AlignedEmbedder`` as its ``target_dim``.
        """
        self.embedder = AlignedEmbedder(target_dim=target_dim)
        self.scorer = CoherenceScorer()
        self._bridge = None  # Optional CrossSpaceBridge for si_a

    def load_bridge(self, path: str) -> None:
        """
        Load a trained CrossSpaceBridge to enable image-audio similarity.

        Once loaded, si_a will be computed via the bridge's shared space
        instead of being set to None. If the file does not exist, a warning
        is logged and the engine is left unchanged (si_a stays disabled).

        Args:
            path: Path to saved bridge weights (.pt file)
        """
        from pathlib import Path

        bridge_path = Path(path)
        if not bridge_path.exists():
            logger.warning("Bridge file not found: %s — si_a remains disabled", path)
            return

        # Imported lazily, and only once the weights file is known to exist,
        # so a missing file never triggers (or fails on) the bridge import.
        from src.embeddings.cross_space_bridge import CrossSpaceBridge

        self._bridge = CrossSpaceBridge.load(bridge_path)
        logger.info("Cross-space bridge loaded — si_a enabled")

    @classmethod
    def _compute_msci(cls, available: Dict[str, float]) -> Optional[float]:
        """Aggregate the non-None pairwise scores into MSCI (rounded to 4 dp)."""
        if not available:
            return None
        if len(available) == 1:
            # A lone score is reported directly rather than re-weighted.
            return float(round(next(iter(available.values())), 4))

        weights = cls._MSCI_WEIGHTS
        keys = [k for k in available if k in weights]
        total = sum(weights[k] for k in keys)
        # Renormalise over the weights actually present; the floor guards
        # against division by zero.
        msci = sum(available[k] * weights[k] for k in keys) / max(total, 1e-6)
        return float(round(msci, 4))

    def evaluate(
        self,
        text: str,
        image_path: Optional[str] = None,
        audio_path: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Score how well ``text`` agrees with an optional image and/or audio clip.

        Args:
            text: The text to compare against the other modalities.
            image_path: Path of the image to embed; skipped when falsy.
            audio_path: Path of the audio to embed; skipped when falsy.

        Returns:
            A dict with:
                ``scores``: pairwise similarities rounded to 4 decimals.
                    ``st_i`` / ``st_a`` are present only when computed;
                    ``si_a`` and ``msci`` are always present (possibly None).
                ``drift``: the result of ``detect_drift`` for these scores.
                ``coherence``: the result of ``CoherenceScorer.score``.
                ``classification`` / ``final_score``: copied from ``coherence``.
        """
        # Each text embedding is only useful alongside its paired modality, so
        # skip the encoder call entirely when that modality is absent.
        # CLIP text embedding (for text-image comparison)
        emb_text_clip = self.embedder.embed_text(text) if image_path else None
        # CLAP text embedding (for text-audio comparison)
        emb_text_clap = self.embedder.embed_text_for_audio(text) if audio_path else None

        emb_image = self.embedder.embed_image(image_path) if image_path else None
        emb_audio = self.embedder.embed_audio(audio_path) if audio_path else None

        scores: Dict[str, float | None] = {}

        # Text-Image: CLIP shared space (meaningful)
        if emb_text_clip is not None and emb_image is not None:
            scores["st_i"] = float(round(cosine_similarity(emb_text_clip, emb_image), 4))
            logger.debug("st_i = %.4f", scores["st_i"])

        # Text-Audio: CLAP shared space (meaningful)
        if emb_text_clap is not None and emb_audio is not None:
            scores["st_a"] = float(round(cosine_similarity(emb_text_clap, emb_audio), 4))
            logger.debug("st_a = %.4f", scores["st_a"])

        # Image-Audio: cross-space (CLIP image vs CLAP audio)
        # Without a bridge, these live in different spaces — similarity is meaningless.
        # With a trained bridge, project both into a shared space for si_a.
        if self._bridge is not None and emb_image is not None and emb_audio is not None:
            scores["si_a"] = float(round(
                self._bridge.compute_similarity(emb_image, emb_audio), 4
            ))
            logger.debug("si_a = %.4f (via bridge)", scores["si_a"])
        else:
            scores["si_a"] = None

        # Compute MSCI from available scores
        available = {k: v for k, v in scores.items() if v is not None}
        scores["msci"] = self._compute_msci(available)

        logger.info("MSCI = %s (from %d pairwise scores)", scores["msci"], len(available))

        drift = detect_drift(
            scores.get("msci"),
            scores.get("st_i"),
            scores.get("st_a"),
            scores.get("si_a"),
        )
        coherence = self.scorer.score(scores=scores, global_drift=drift["global_drift"])

        return {
            "scores": scores,
            "drift": drift,
            "coherence": coherence,
            "classification": coherence["classification"],
            "final_score": coherence["final_score"],
        }


def evaluate_coherence(
    text: str,
    image_path: Optional[str] = None,
    audio_path: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Convenience wrapper: evaluate coherence with a default CoherenceEngine.

    A new engine is constructed with default arguments on every call and no
    bridge is loaded on it.

    Args:
        text: The text to compare against the other modalities.
        image_path: Optional path of the image to evaluate.
        audio_path: Optional path of the audio to evaluate.

    Returns:
        Whatever ``CoherenceEngine.evaluate`` returns for these inputs.
    """
    request = {
        "text": text,
        "image_path": image_path,
        "audio_path": audio_path,
    }
    return CoherenceEngine().evaluate(**request)