File size: 1,888 Bytes
6835659
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from dataclasses import dataclass
from typing import Dict

from src.coherence.thresholds import AdaptiveThresholds


@dataclass
class CoherenceResult:
    """Outcome of a coherence classification.

    Instances are built by ``CoherenceClassifier.classify``.
    """

    # Classification label; the classifier in this file emits "UNKNOWN",
    # "GLOBAL_FAILURE", "MODALITY_FAILURE", "LOCAL_MODALITY_WEAKNESS" or
    # "HIGH_COHERENCE".
    label: str
    # Human-readable explanation accompanying the label.
    reason: str
    # Name of the metric singled out as weakest, or None when no scores
    # were available to classify.
    weakest_metric: str | None
    # Penalty name -> penalty magnitude; empty when no penalty applies.
    penalties: Dict[str, float]


class CoherenceClassifier:
    """Collapse per-metric coherence scores into one ``CoherenceResult``.

    Every score is bucketed via ``thresholds.classify_value(metric, value)``.
    Only the ``"FAIL"`` and ``"WEAK"`` statuses are inspected here; any other
    status is treated as passing.
    """

    # Penalty magnitudes attached to each non-passing outcome. They live on
    # the class so each value is named once and a subclass can override it.
    GLOBAL_FAILURE_PENALTY: float = 0.18
    MODALITY_FAILURE_PENALTY: float = 0.12
    MODALITY_WEAKNESS_PENALTY: float = 0.06

    def __init__(self, thresholds: AdaptiveThresholds):
        """Store the thresholds object used to bucket individual scores."""
        self.t = thresholds

    def classify(self, scores: Dict[str, float], global_drift: bool) -> CoherenceResult:
        """Classify a set of metric scores.

        Decision order (first match wins):

        1. No scores at all            -> ``UNKNOWN``
        2. ``"msci"`` is FAIL, or two or more metrics are FAIL
                                       -> ``GLOBAL_FAILURE``
        3. Exactly one metric is FAIL  -> ``MODALITY_FAILURE``
        4. At least one metric is WEAK -> ``LOCAL_MODALITY_WEAKNESS``
        5. Otherwise                   -> ``HIGH_COHERENCE``

        Args:
            scores: Metric name -> score.
            global_drift: Accepted for interface compatibility; this method
                does not currently read it.
                NOTE(review): confirm whether a true value is meant to force
                ``GLOBAL_FAILURE``.

        Returns:
            A ``CoherenceResult``. ``weakest_metric`` is the metric with the
            lowest raw score, except for ``MODALITY_FAILURE`` where it is the
            single failing metric, and ``UNKNOWN`` where it is ``None``.
        """
        if not scores:
            return CoherenceResult(
                label="UNKNOWN",
                reason="Insufficient metrics for classification",
                weakest_metric=None,
                penalties={},
            )

        statuses = {m: self.t.classify_value(m, v) for m, v in scores.items()}

        fails = [m for m, s in statuses.items() if s == "FAIL"]
        weaks = [m for m, s in statuses.items() if s == "WEAK"]

        penalties: Dict[str, float] = {}
        # `scores` is guaranteed non-empty here (empty input returned above),
        # so no fallback is needed. Ties resolve to the first metric in dict
        # order, as `min` returns the first minimal item.
        weakest = min(scores, key=scores.get)

        if statuses.get("msci") == "FAIL" or len(fails) >= 2:
            # A failing "msci" is enough on its own to count as a global
            # failure, regardless of how the other metrics fared.
            penalties["global_drift"] = self.GLOBAL_FAILURE_PENALTY
            label = "GLOBAL_FAILURE"
            reason = "Semantic alignment failed across modalities"
        elif len(fails) == 1:
            penalties["weak_modality"] = self.MODALITY_FAILURE_PENALTY
            label = "MODALITY_FAILURE"
            reason = f"Failure in modality: {fails[0]}"
            # Report the failing metric itself rather than the lowest raw
            # score, which may belong to a metric that passed its threshold.
            weakest = fails[0]
        elif len(weaks) >= 1:
            penalties["weak_modality"] = self.MODALITY_WEAKNESS_PENALTY
            label = "LOCAL_MODALITY_WEAKNESS"
            # NOTE(review): the reason names the first WEAK metric in dict
            # order while `weakest` stays the lowest raw score overall; the
            # two can differ. Left as-is to preserve existing output.
            reason = f"Weak coherence in modality: {weaks[0]}"
        else:
            label = "HIGH_COHERENCE"
            reason = "Strong cross-modal semantic agreement"

        return CoherenceResult(
            label=label,
            reason=reason,
            weakest_metric=weakest,
            penalties=penalties,
        )