from __future__ import annotations import json from dataclasses import dataclass from pathlib import Path from typing import Any, Dict, Optional def _safe_get(d: Dict[str, Any], *keys: str, default=None): cur: Any = d for k in keys: if not isinstance(cur, dict) or k not in cur: return default cur = cur[k] return cur def load_coherence_stats(path: str = "artifacts/coherence_stats.json") -> Dict[str, Any]: p = Path(path) if not p.exists(): return {} try: return json.loads(p.read_text(encoding="utf-8")) except Exception: return {} def normalize_metric(stats: Dict[str, Any], name: str, value: float) -> float: """ Normalize metric to [0,1] using robust percentiles if available. Expected stats formats we support: stats[name]["p05"], stats[name]["p95"] OR stats["metrics"][name]["p05"], ... Falls back to (value) clipped to [0,1] if no stats found. """ p05 = _safe_get(stats, name, "p05") p95 = _safe_get(stats, name, "p95") if p05 is None or p95 is None: p05 = _safe_get(stats, "metrics", name, "p05") p95 = _safe_get(stats, "metrics", name, "p95") if p05 is None or p95 is None or p95 == p05: return max(0.0, min(1.0, float(value))) v = (float(value) - float(p05)) / (float(p95) - float(p05)) return max(0.0, min(1.0, v)) @dataclass class CoherenceScoringConfig: w_msci: float = 0.35 w_st_i: float = 0.20 w_st_a: float = 0.20 w_si_a: float = 0.25 global_drift_penalty: float = 0.18 visual_drift_penalty: float = 0.10 audio_drift_penalty: float = 0.10 weakness_floor: float = 0.35 weakness_max_extra: float = 0.12 def compute_base_score( scores: Dict[str, float], stats: Dict[str, Any], cfg: CoherenceScoringConfig, ) -> Dict[str, Any]: msci = normalize_metric(stats, "msci", scores.get("msci", 0.0)) st_i = normalize_metric(stats, "st_i", scores.get("st_i", 0.0)) st_a = normalize_metric(stats, "st_a", scores.get("st_a", 0.0)) si_a = normalize_metric(stats, "si_a", scores.get("si_a", 0.0)) weights = [cfg.w_msci, cfg.w_st_i, cfg.w_st_a, cfg.w_si_a] wsum = sum(weights) if sum(weights) > 0 else 1.0 w_msci, w_st_i, w_st_a, w_si_a = [w / wsum for w in weights] base = w_msci * msci + w_st_i * st_i + w_st_a * st_a + w_si_a * si_a return { "base_score": float(max(0.0, min(1.0, base))), "normalized": {"msci": msci, "st_i": st_i, "st_a": st_a, "si_a": si_a}, "weights": {"msci": w_msci, "st_i": w_st_i, "st_a": w_st_a, "si_a": w_si_a}, } def compute_drift_penalties( normalized: Dict[str, float], drift: Dict[str, bool], cfg: CoherenceScoringConfig, ) -> Dict[str, Any]: penalties: Dict[str, float] = {} if drift.get("global_drift", False): penalties["global_drift"] = cfg.global_drift_penalty if drift.get("visual_drift", False): penalties["visual_drift"] = cfg.visual_drift_penalty if drift.get("audio_drift", False): penalties["audio_drift"] = cfg.audio_drift_penalty weakest = min(normalized.values()) if normalized else 1.0 if weakest < cfg.weakness_floor: ratio = (cfg.weakness_floor - weakest) / max(1e-6, cfg.weakness_floor) penalties["weakest_modality"] = float( min(cfg.weakness_max_extra, cfg.weakness_max_extra * ratio) ) total = float(sum(penalties.values())) return {"penalties": penalties, "total_penalty": total, "weakest": float(weakest)} def compute_final_coherence( scores: Dict[str, float], drift: Dict[str, bool], stats_path: str = "artifacts/coherence_stats.json", cfg: Optional[CoherenceScoringConfig] = None, ) -> Dict[str, Any]: cfg = cfg or CoherenceScoringConfig() stats = load_coherence_stats(stats_path) base_pack = compute_base_score(scores, stats, cfg) drift_pack = compute_drift_penalties(base_pack["normalized"], drift, cfg) final = base_pack["base_score"] - drift_pack["total_penalty"] final = float(max(0.0, min(1.0, final))) return { "base_score": base_pack["base_score"], "final_score": final, "normalized": base_pack["normalized"], "weights": base_pack["weights"], "penalties": drift_pack["penalties"], "total_penalty": drift_pack["total_penalty"], "weakest_modality": drift_pack["weakest"], "used_stats_file": stats_path, "stats_loaded": bool(stats), }