"""
Analyze error distribution to identify problematic regions in Munsell space.

This script:

1. Runs the best model on all REAL Munsell colors
2. Computes Delta-E (CIE 2000) for each sample
3. Identifies samples with high error (Delta-E > threshold)
4. Analyzes patterns: which hue families, value ranges and chroma ranges have
   issues
5. Logs summary statistics and per-group breakdown tables
"""

import logging
from collections import defaultdict

import numpy as np
import onnxruntime as ort
from colour import XYZ_to_Lab, xyY_to_XYZ
from colour.difference import delta_E_CIE2000
from colour.notation.datasets.munsell import MUNSELL_COLOURS_REAL
from colour.notation.munsell import (
    CCS_ILLUMINANT_MUNSELL,
    munsell_colour_to_munsell_specification,
    munsell_specification_to_xyY,
)

from learning_munsell import PROJECT_ROOT

logging.basicConfig(level=logging.INFO, format="%(message)s")
LOGGER = logging.getLogger(__name__)

# Hue code (4th component of a colour-science Munsell specification) -> hue
# family letters. colour-science numbers the families B=1, BG=2, G=3, GY=4,
# Y=5, YR=6, R=7, RP=8, P=9, PB=10 (``MUNSELL_HUE_LETTER_CODES``); e.g.
# "3.0 YR 4.0/14.0" parses to a specification whose code is 6.
HUE_NAMES = {
    1: "B",
    2: "BG",
    3: "G",
    4: "GY",
    5: "Y",
    6: "YR",
    7: "R",
    8: "RP",
    9: "P",
    10: "PB",
}

# Order in which hue families are listed in the per-hue table.
_HUE_DISPLAY_ORDER = ("R", "YR", "Y", "GY", "G", "BG", "B", "PB", "P", "RP")

# Bin edges for the per-value and per-chroma tables. Bins are half-open
# ``[low, high)`` except the last one, which also includes its upper edge so
# that samples sitting exactly on the final edge (e.g. value 9) are counted.
_VALUE_BINS = ((1, 3), (3, 5), (5, 7), (7, 9))
_CHROMA_BINS = ((0, 4), (4, 8), (8, 12), (12, 20), (20, 50))

# Delta-E cut-offs for the cumulative distribution summary.
_DISTRIBUTION_THRESHOLDS = (1.0, 2.0, 3.0, 5.0, 10.0)

# Output normalization keys, in the order of the model's output components
# (hue, value, chroma, hue code).
_OUTPUT_RANGE_KEYS = ("hue_range", "value_range", "chroma_range", "code_range")


def load_model_and_params(model_name: str) -> tuple:
    """
    Load an ONNX model and its normalization parameters.

    Parameters
    ----------
    model_name
        Base file name (without extension) under ``models/from_xyY``.

    Returns
    -------
    tuple
        ``(session, input_parameters, output_parameters)`` where ``session``
        is an ``onnxruntime.InferenceSession`` and the two parameter objects
        are the dicts stored in ``<model_name>_normalization_parameters.npz``.

    Raises
    ------
    FileNotFoundError
        If the ``.onnx`` model or the ``.npz`` parameter file is missing.
    """
    model_dir = PROJECT_ROOT / "models" / "from_xyY"
    model_path = model_dir / f"{model_name}.onnx"
    params_path = model_dir / f"{model_name}_normalization_parameters.npz"

    if not model_path.exists():
        msg = f"Model not found: {model_path}"
        raise FileNotFoundError(msg)
    if not params_path.exists():
        msg = f"Params not found: {params_path}"
        raise FileNotFoundError(msg)

    session = ort.InferenceSession(str(model_path))

    # The parameter dicts are read back with ``.item()``, presumably because
    # they were saved as 0-d object arrays (hence ``allow_pickle``). Only load
    # archives from the project's own model directory. Using the NpzFile as a
    # context manager closes the underlying file handle.
    with np.load(params_path, allow_pickle=True) as params:
        input_parameters = params["input_parameters"].item()
        output_parameters = params["output_parameters"].item()

    return session, input_parameters, output_parameters


def normalize_input(xyY: np.ndarray, params: dict) -> np.ndarray:
    """
    Normalize xyY input for the model.

    ``x`` and ``y`` are min-max scaled with ``params["x_range"]`` and
    ``params["y_range"]``. ``Y`` is first divided by 100 (0-100 -> 0-1) and
    then min-max scaled with ``params["Y_range"]``.

    Returns
    -------
    np.ndarray
        A new float32 array with the same shape as ``xyY``.
    """
    normalized = np.copy(xyY).astype(np.float32)

    # Scale Y from 0-100 to 0-1 range before normalization
    normalized[..., 2] = xyY[..., 2] / 100.0

    normalized[..., 0] = (xyY[..., 0] - params["x_range"][0]) / (
        params["x_range"][1] - params["x_range"][0]
    )
    normalized[..., 1] = (xyY[..., 1] - params["y_range"][0]) / (
        params["y_range"][1] - params["y_range"][0]
    )
    normalized[..., 2] = (normalized[..., 2] - params["Y_range"][0]) / (
        params["Y_range"][1] - params["Y_range"][0]
    )

    return normalized


def denormalize_output(pred: np.ndarray, params: dict) -> np.ndarray:
    """
    Map normalized model output back to a Munsell specification.

    Component ``i`` of the last axis is rescaled as
    ``pred * (high - low) + low`` with ``(low, high)`` taken from
    ``params[_OUTPUT_RANGE_KEYS[i]]`` (hue, value, chroma, hue code).

    Returns
    -------
    np.ndarray
        A new array with the same shape and dtype as ``pred``.
    """
    denorm = np.copy(pred)
    for index, key in enumerate(_OUTPUT_RANGE_KEYS):
        low, high = params[key][0], params[key][1]
        denorm[..., index] = pred[..., index] * (high - low) + low
    return denorm


def compute_delta_e(pred_spec: np.ndarray, gt_xyY: np.ndarray) -> float:
    """
    Compute CIE 2000 Delta-E between a predicted spec and ground-truth xyY.

    The predicted Munsell specification is converted to xyY, then both
    colours go through XYZ to CIE Lab under ``CCS_ILLUMINANT_MUNSELL``.
    ``gt_xyY`` is expected with Y in the 0-100 range; it is rescaled to 0-1.

    Returns
    -------
    float
        The Delta-E value, or NaN if any conversion fails (e.g. the predicted
        specification cannot be converted). This is a deliberate best-effort
        so that one bad prediction does not abort the whole analysis.
    """
    try:
        pred_xyY = munsell_specification_to_xyY(pred_spec)
        pred_XYZ = xyY_to_XYZ(pred_xyY)
        pred_Lab = XYZ_to_Lab(pred_XYZ, CCS_ILLUMINANT_MUNSELL)

        # Ground truth Y is in 0-100 range, need to scale to 0-1. Work on a
        # float copy so integer input is not truncated by the division.
        gt_xyY_scaled = np.array(gt_xyY, dtype=np.float64)
        gt_xyY_scaled[2] = gt_xyY_scaled[2] / 100.0
        gt_XYZ = xyY_to_XYZ(gt_xyY_scaled)
        gt_Lab = XYZ_to_Lab(gt_XYZ, CCS_ILLUMINANT_MUNSELL)

        return float(delta_E_CIE2000(gt_Lab, pred_Lab))
    except Exception as error:  # noqa: BLE001
        LOGGER.debug("Delta-E computation failed for %s: %s", pred_spec, error)
        return np.nan


def _clamp_specification(spec: np.ndarray) -> np.ndarray:
    """Clamp a predicted spec in place to valid ranges and round the hue code."""
    spec[0] = np.clip(spec[0], 0.5, 10.0)
    spec[1] = np.clip(spec[1], 1.0, 9.0)
    spec[2] = np.clip(spec[2], 0.0, 50.0)
    spec[3] = np.round(np.clip(spec[3], 1.0, 10.0))
    return spec


def _collect_results(session, input_parameters: dict, output_parameters: dict) -> list:
    """Run the model on every real Munsell colour and score each prediction."""
    input_name = session.get_inputs()[0].name
    results = []
    skipped = 0

    for munsell_spec_tuple, xyY_gt in MUNSELL_COLOURS_REAL:
        hue_code_str, value, chroma = munsell_spec_tuple
        munsell_str = f"{hue_code_str} {value}/{chroma}"

        try:
            gt_spec = munsell_colour_to_munsell_specification(munsell_str)
            gt_xyY = np.array(xyY_gt)

            # Predict
            xyY_norm = normalize_input(gt_xyY.reshape(1, 3), input_parameters)
            pred_norm = session.run(None, {input_name: xyY_norm})[0]
            pred_spec = denormalize_output(pred_norm, output_parameters)[0]
            _clamp_specification(pred_spec)

            delta_e = compute_delta_e(pred_spec, gt_xyY)
            if np.isnan(delta_e):
                skipped += 1
                continue

            results.append(
                {
                    "munsell_str": munsell_str,
                    "gt_spec": gt_spec,
                    "pred_spec": pred_spec,
                    "delta_e": delta_e,
                    "hue": gt_spec[0],
                    "value": gt_spec[1],
                    "chroma": gt_spec[2],
                    "code": int(gt_spec[3]),
                    "gt_xyY": gt_xyY,
                }
            )
        except Exception as e:  # noqa: BLE001
            LOGGER.warning("Failed for %s: %s", munsell_str, e)

    if skipped:
        LOGGER.warning(
            "Skipped %d samples whose Delta-E could not be computed", skipped
        )

    return results


def _log_section(title: str) -> None:
    """Log a section title framed by separator lines."""
    LOGGER.info("\n%s", "=" * 40)
    LOGGER.info("%s", title)
    LOGGER.info("=" * 40)


def _log_overall_statistics(delta_es: list) -> None:
    """Log summary statistics and the cumulative Delta-E distribution."""
    LOGGER.info("\nOverall Delta-E Statistics:")
    LOGGER.info("  Mean:   %.4f", np.mean(delta_es))
    LOGGER.info("  Median: %.4f", np.median(delta_es))
    LOGGER.info("  Std:    %.4f", np.std(delta_es))
    LOGGER.info("  Min:    %.4f", np.min(delta_es))
    LOGGER.info("  Max:    %.4f", np.max(delta_es))

    LOGGER.info("\nDelta-E Distribution:")
    for thresh in _DISTRIBUTION_THRESHOLDS:
        count = sum(1 for d in delta_es if d <= thresh)
        pct = 100 * count / len(delta_es)
        LOGGER.info("  <= %.1f: %4d (%.1f%%)", thresh, count, pct)


def _binned_groups(results: list, key: str, bins: tuple) -> list:
    """Group Delta-E values by ``result[key]`` into ``bins``; last bin is closed."""
    groups = []
    last_index = len(bins) - 1
    for index, (low, high) in enumerate(bins):
        is_last = index == last_index
        des = [
            r["delta_e"]
            for r in results
            if low <= r[key] < high or (is_last and r[key] == high)
        ]
        label = f"[{low}-{high}{']' if is_last else ')'}"
        groups.append((label, des))
    return groups


def _log_group_table(
    label_header: str, label_width: int, groups: list, threshold: float
) -> None:
    """Log count/mean/median/max and high-error share per (label, values) group.

    Groups with no values are skipped.
    """
    LOGGER.info(
        "\n%-*s %5s %6s %6s %6s %s",
        label_width,
        label_header,
        "Count",
        "Mean",
        "Median",
        "Max",
        f">{threshold:.1f}",
    )
    for label, des in groups:
        if not des:
            continue
        high = sum(1 for d in des if d > threshold)
        LOGGER.info(
            "%-*s %5d %6.2f %6.2f %6.2f %d (%.0f%%)",
            label_width,
            label,
            len(des),
            np.mean(des),
            np.median(des),
            np.max(des),
            high,
            100 * high / len(des),
        )


def _format_spec(spec) -> str:
    """Format a Munsell specification as ``[hue, value, chroma, code]``."""
    return f"[{spec[0]:.1f}, {spec[1]:.1f}, {spec[2]:.1f}, {int(spec[3])}]"


def _log_worst_samples(results: list, count: int = 20) -> None:
    """Log the ``count`` samples with the highest Delta-E."""
    worst = sorted(results, key=lambda r: r["delta_e"], reverse=True)[:count]
    LOGGER.info(
        "\n%-15s %6s %-20s %-20s", "Munsell", "DeltaE", "GT Spec", "Pred Spec"
    )
    for r in worst:
        LOGGER.info(
            "%-15s %6.2f %-20s %-20s",
            r["munsell_str"],
            r["delta_e"],
            _format_spec(r["gt_spec"]),
            _format_spec(r["pred_spec"]),
        )


def _log_component_errors(high_error: list) -> None:
    """Log absolute per-component spec errors for the given samples.

    NOTE: hue error is the raw absolute difference on the 0-10 hue scale; it
    does not account for wrap-around into the adjacent hue family.
    """
    if not high_error:
        return

    LOGGER.info("\n%-10s %6s %6s %6s", "Component", "Mean", "Median", "Max")
    for index, name in enumerate(("Hue", "Value", "Chroma", "Code")):
        errors = [
            abs(r["pred_spec"][index] - r["gt_spec"][index]) for r in high_error
        ]
        LOGGER.info(
            "%-10s %6.2f %6.2f %6.2f",
            name,
            np.mean(errors),
            np.median(errors),
            np.max(errors),
        )


def analyze_errors(
    model_name: str = "multi_head_large", threshold: float = 3.0
) -> list:
    """
    Analyze the Delta-E error distribution of a model on real Munsell colours.

    Parameters
    ----------
    model_name
        Base name of the ONNX model under ``models/from_xyY``.
    threshold
        Delta-E above which a sample counts as high error. Used for the
        high-error summary, the per-group ``>threshold`` columns and the
        component-error breakdown.

    Returns
    -------
    list
        One dict per successfully evaluated sample (keys ``munsell_str``,
        ``gt_spec``, ``pred_spec``, ``delta_e``, ``hue``, ``value``,
        ``chroma``, ``code``, ``gt_xyY``). Empty if nothing could be evaluated.

    Raises
    ------
    FileNotFoundError
        If the model or its normalization parameters are missing.
    """
    LOGGER.info("=" * 80)
    LOGGER.info("Error Analysis for %s", model_name)
    LOGGER.info("=" * 80)

    session, input_parameters, output_parameters = load_model_and_params(model_name)
    results = _collect_results(session, input_parameters, output_parameters)

    LOGGER.info("\nTotal samples evaluated: %d", len(results))
    if not results:
        # Nothing to summarise; np.min/np.max would raise on empty input and
        # the percentage computations would divide by zero.
        LOGGER.warning("No samples could be evaluated for %s", model_name)
        return results

    delta_es = [r["delta_e"] for r in results]
    _log_overall_statistics(delta_es)

    high_error = [r for r in results if r["delta_e"] > threshold]
    LOGGER.info(
        "\nSamples with Delta-E > %.1f: %d (%.1f%%)",
        threshold,
        len(high_error),
        100 * len(high_error) / len(results),
    )

    _log_section("Analysis by Hue Family")
    by_hue = defaultdict(list)
    for r in results:
        hue_name = HUE_NAMES.get(r["code"], f"?{r['code']}")
        by_hue[hue_name].append(r["delta_e"])
    hue_groups = [
        (hue_name, by_hue[hue_name])
        for hue_name in _HUE_DISPLAY_ORDER
        if hue_name in by_hue
    ]
    _log_group_table("Hue", 4, hue_groups, threshold)

    _log_section("Analysis by Value Range")
    _log_group_table(
        "Value", 8, _binned_groups(results, "value", _VALUE_BINS), threshold
    )

    _log_section("Analysis by Chroma Range")
    _log_group_table(
        "Chroma", 8, _binned_groups(results, "chroma", _CHROMA_BINS), threshold
    )

    _log_section("Top 20 Worst Samples")
    _log_worst_samples(results, count=20)

    _log_section(
        f"Component Errors for High-Error Samples (Delta-E > {threshold:.1f})"
    )
    _log_component_errors(high_error)

    return results


def main() -> None:
    """Run error analysis for each candidate model, skipping missing ones."""
    # Try the best models
    models = [
        "multi_head_large",
    ]

    for model_name in models:
        try:
            analyze_errors(model_name, threshold=3.0)
        except FileNotFoundError as e:
            LOGGER.warning("Skipping %s: %s", model_name, e)
        LOGGER.info("\n")


if __name__ == "__main__":
    main()