File size: 4,520 Bytes
6835659
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
from __future__ import annotations

import json
import math
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Optional


def _safe_get(d: Dict[str, Any], *keys: str, default=None):
    cur: Any = d
    for k in keys:
        if not isinstance(cur, dict) or k not in cur:
            return default
        cur = cur[k]
    return cur


def load_coherence_stats(path: str = "artifacts/coherence_stats.json") -> Dict[str, Any]:
    """Load the coherence statistics JSON file on a best-effort basis.

    Args:
        path: Location of the JSON statistics file.

    Returns:
        The parsed JSON object, or an empty dict when the file is missing,
        unreadable, not valid JSON, or its top-level value is not a JSON
        object. Callers can therefore always treat the result as a mapping.
    """
    try:
        # EAFP: a missing file, a directory, or a permission problem all
        # surface as OSError, so no separate exists() check is needed.
        text = Path(path).read_text(encoding="utf-8")
        data = json.loads(text)
    except (OSError, ValueError, RecursionError):
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError;
        # RecursionError covers pathologically nested documents.
        return {}

    # Valid JSON can still be a list, number, or string; those carry no
    # usable per-metric statistics, so treat them as "no stats".
    return data if isinstance(data, dict) else {}


def normalize_metric(stats: Dict[str, Any], name: str, value: float) -> float:
    """
    Normalize a metric to [0, 1] using robust percentiles if available.

    Expected stats formats we support:
      stats[name]["p05"], stats[name]["p95"]
      OR stats["metrics"][name]["p05"], ...

    The value is mapped linearly so that p05 -> 0 and p95 -> 1, then clipped.
    Falls back to ``value`` clipped to [0, 1] when no usable percentiles are
    found (missing, non-numeric, non-finite, or p05 == p95).

    Args:
        stats: Parsed statistics mapping (may be empty).
        name: Metric name to look up in ``stats``.
        value: Raw metric value; must be convertible with ``float()``.

    Returns:
        A float in [0, 1]. A NaN ``value`` yields 0.0.

    Raises:
        TypeError, ValueError: if ``value`` cannot be converted to float.
    """
    raw = float(value)
    if math.isnan(raw):
        # Comparisons with NaN are always False, so max(0, min(1, nan))
        # evaluates to 1.0 -- a broken metric would score as perfect.
        # Report the worst score instead.
        return 0.0

    p05 = _safe_get(stats, name, "p05")
    p95 = _safe_get(stats, name, "p95")
    if p05 is None or p95 is None:
        p05 = _safe_get(stats, "metrics", name, "p05")
        p95 = _safe_get(stats, "metrics", name, "p95")

    try:
        low, high = float(p05), float(p95)
    except (TypeError, ValueError):
        # Missing (None) or non-numeric percentiles: no usable stats.
        return max(0.0, min(1.0, raw))

    span = high - low
    if span == 0.0 or not math.isfinite(span):
        # Degenerate or non-finite range cannot define a scale.
        return max(0.0, min(1.0, raw))

    return max(0.0, min(1.0, (raw - low) / span))


@dataclass
class CoherenceScoringConfig:
    """Weights and penalty sizes used by the coherence scoring functions.

    The ``w_*`` fields weight the normalized metrics in
    ``compute_base_score``; the remaining fields control the penalties
    produced by ``compute_drift_penalties``.
    """

    # Relative weights of the four normalized metrics. They are divided by
    # their sum before use, so they do not need to add up to 1.
    # NOTE(review): the metric names suggest a combined index plus pairwise
    # modality scores -- confirm the exact meaning with the metric producers.
    w_msci: float = 0.35
    w_st_i: float = 0.20
    w_st_a: float = 0.20
    w_si_a: float = 0.25

    # Flat penalties added when the matching flag in the drift mapping is
    # truthy ("global_drift", "visual_drift", "audio_drift").
    global_drift_penalty: float = 0.18
    visual_drift_penalty: float = 0.10
    audio_drift_penalty: float = 0.10

    # When the lowest normalized metric is below weakness_floor, an extra
    # penalty is added that grows with the shortfall and is capped at
    # weakness_max_extra.
    weakness_floor: float = 0.35
    weakness_max_extra: float = 0.12


def compute_base_score(
    scores: Dict[str, float],
    stats: Dict[str, Any],
    cfg: CoherenceScoringConfig,
) -> Dict[str, Any]:
    """Combine the four normalized metrics into a weighted base score.

    Each of "msci", "st_i", "st_a" and "si_a" is read from ``scores``
    (defaulting to 0.0 when absent) and normalized with ``normalize_metric``.
    The configured weights are divided by their sum (left as-is when the sum
    is not positive), and the weighted total is clipped to [0, 1].

    Returns:
        A dict with "base_score", the per-metric "normalized" values, and the
        effective "weights".
    """
    metric_names = ("msci", "st_i", "st_a", "si_a")

    normalized = {
        metric: normalize_metric(stats, metric, scores.get(metric, 0.0))
        for metric in metric_names
    }

    raw_weights = (cfg.w_msci, cfg.w_st_i, cfg.w_st_a, cfg.w_si_a)
    weight_total = sum(raw_weights)
    if not weight_total > 0:
        # Avoid dividing by zero (or flipping signs) on degenerate configs.
        weight_total = 1.0
    weights = {
        metric: raw / weight_total
        for metric, raw in zip(metric_names, raw_weights)
    }

    # Accumulate left to right in metric order.
    blended = 0.0
    for metric in metric_names:
        blended += weights[metric] * normalized[metric]

    clipped = max(0.0, min(1.0, blended))
    return {
        "base_score": float(clipped),
        "normalized": normalized,
        "weights": weights,
    }


def compute_drift_penalties(
    normalized: Dict[str, float],
    drift: Dict[str, bool],
    cfg: CoherenceScoringConfig,
) -> Dict[str, Any]:
    """Collect the penalties to subtract from the base score.

    A flat penalty is recorded for each truthy "global_drift",
    "visual_drift" and "audio_drift" flag in ``drift``. In addition, when the
    smallest normalized metric is below ``cfg.weakness_floor``, a
    "weakest_modality" penalty proportional to the shortfall is added, capped
    at ``cfg.weakness_max_extra``.

    Returns:
        A dict with the individual "penalties", their "total_penalty", and
        the "weakest" normalized value (1.0 when ``normalized`` is empty).
    """
    penalties: Dict[str, float] = {}

    # The config attribute for each flag is "<flag>_penalty"; it is looked up
    # only when the flag is set.
    for flag in ("global_drift", "visual_drift", "audio_drift"):
        if drift.get(flag, False):
            penalties[flag] = getattr(cfg, f"{flag}_penalty")

    if normalized:
        weakest = min(normalized.values())
    else:
        weakest = 1.0

    if weakest < cfg.weakness_floor:
        floor = cfg.weakness_floor
        cap = cfg.weakness_max_extra
        # Fraction of the floor that is missing; the divisor is guarded so a
        # zero floor cannot divide by zero.
        shortfall = (floor - weakest) / max(1e-6, floor)
        penalties["weakest_modality"] = float(min(cap, cap * shortfall))

    return {
        "penalties": penalties,
        "total_penalty": float(sum(penalties.values())),
        "weakest": float(weakest),
    }


def compute_final_coherence(
    scores: Dict[str, float],
    drift: Dict[str, bool],
    stats_path: str = "artifacts/coherence_stats.json",
    cfg: Optional[CoherenceScoringConfig] = None,
) -> Dict[str, Any]:
    """Compute the final coherence score and a breakdown of how it was built.

    Loads normalization statistics from ``stats_path``, computes the weighted
    base score, subtracts the drift and weakness penalties, and clips the
    result to [0, 1].

    Returns:
        A dict containing the base and final scores, the normalized metrics,
        the effective weights, the individual and total penalties, the
        weakest normalized value, the stats path used, and whether any stats
        were loaded.
    """
    if not cfg:
        cfg = CoherenceScoringConfig()

    stats = load_coherence_stats(stats_path)
    base_pack = compute_base_score(scores, stats, cfg)
    drift_pack = compute_drift_penalties(base_pack["normalized"], drift, cfg)

    base_score = base_pack["base_score"]
    total_penalty = drift_pack["total_penalty"]
    penalized = base_score - total_penalty
    final_score = float(max(0.0, min(1.0, penalized)))

    report: Dict[str, Any] = {
        "base_score": base_score,
        "final_score": final_score,
        "normalized": base_pack["normalized"],
        "weights": base_pack["weights"],
        "penalties": drift_pack["penalties"],
        "total_penalty": total_penalty,
        "weakest_modality": drift_pack["weakest"],
        "used_stats_file": stats_path,
        "stats_loaded": bool(stats),
    }
    return report