| """ |
| Analyze error distribution to identify problematic regions in Munsell space. |
| |
| This script: |
| 1. Runs the best model on all REAL Munsell colors |
| 2. Computes Delta-E for each sample |
| 3. Identifies samples with high error (Delta-E > threshold) |
| 4. Analyzes patterns: which hue families, value ranges, chroma ranges have issues |
| 5. Outputs statistics and visualizations |
| """ |
|
|
| import logging |
| from collections import defaultdict |
|
|
| import numpy as np |
| import onnxruntime as ort |
| from colour import XYZ_to_Lab, xyY_to_XYZ |
| from colour.difference import delta_E_CIE2000 |
| from colour.notation.datasets.munsell import MUNSELL_COLOURS_REAL |
| from colour.notation.munsell import ( |
| CCS_ILLUMINANT_MUNSELL, |
| munsell_colour_to_munsell_specification, |
| munsell_specification_to_xyY, |
| ) |
|
|
| from learning_munsell import PROJECT_ROOT |
|
|
# Emit bare messages (no timestamps/levels): the log output IS the report.
logging.basicConfig(level=logging.INFO, format="%(message)s")
LOGGER = logging.getLogger(__name__)


# Hue-family letter for each integer hue code found in a Munsell
# specification's fourth component. Codes 10 and 0 both map to "RP" —
# presumably because hue 0 of one family wraps to hue 10 of the previous
# family (0R == 10RP); confirm against colour-science's code convention.
HUE_NAMES = {
    1: "R",
    2: "YR",
    3: "Y",
    4: "GY",
    5: "G",
    6: "BG",
    7: "B",
    8: "PB",
    9: "P",
    10: "RP",
    0: "RP",
}
|
|
|
|
def load_model_and_params(model_name: str) -> tuple:
    """Load an ONNX model and its normalization parameters.

    Looks for ``<model_name>.onnx`` and
    ``<model_name>_normalization_parameters.npz`` under
    ``PROJECT_ROOT/models/from_xyY``.

    Returns:
        Tuple of (onnxruntime session, input-normalization dict,
        output-normalization dict).

    Raises:
        FileNotFoundError: If either the model or the parameter file
            is missing.
    """
    model_dir = PROJECT_ROOT / "models" / "from_xyY"
    model_path = model_dir / f"{model_name}.onnx"
    params_path = model_dir / f"{model_name}_normalization_parameters.npz"

    # Validate both artifacts up front so the caller gets one clear error.
    for label, path in (("Model", model_path), ("Params", params_path)):
        if not path.exists():
            msg = f"{label} not found: {path}"
            raise FileNotFoundError(msg)

    session = ort.InferenceSession(str(model_path))
    # allow_pickle is required: the dicts are stored as object arrays.
    archive = np.load(params_path, allow_pickle=True)
    return (
        session,
        archive["input_parameters"].item(),
        archive["output_parameters"].item(),
    )
|
|
|
|
def normalize_input(xyY: np.ndarray, params: dict) -> np.ndarray:
    """Normalize xyY coordinates into the model's input space.

    x and y are min-max scaled with ``params["x_range"]`` /
    ``params["y_range"]``; Y is first divided by 100 and then min-max
    scaled with ``params["Y_range"]``. The input array is not modified.

    Returns:
        A new float32 array with the same shape as *xyY*.
    """
    x_lo, x_hi = params["x_range"]
    y_lo, y_hi = params["y_range"]
    Y_lo, Y_hi = params["Y_range"]

    scaled = np.copy(xyY).astype(np.float32)
    scaled[..., 0] = (xyY[..., 0] - x_lo) / (x_hi - x_lo)
    scaled[..., 1] = (xyY[..., 1] - y_lo) / (y_hi - y_lo)
    # Y arrives in [0, 100]; bring it to [0, 1] before min-max scaling.
    scaled[..., 2] = xyY[..., 2] / 100.0
    scaled[..., 2] = (scaled[..., 2] - Y_lo) / (Y_hi - Y_lo)
    return scaled
|
|
|
|
def denormalize_output(pred: np.ndarray, params: dict) -> np.ndarray:
    """Map normalized model output back to Munsell specification units.

    Each of the four components (hue, value, chroma, code) is rescaled
    from [0, 1] to its original range via the corresponding
    ``params["*_range"]`` pair. The input array is not modified.

    Returns:
        A new array with the same shape and dtype as *pred*.
    """
    component_keys = ("hue_range", "value_range", "chroma_range", "code_range")
    denorm = np.copy(pred)
    for axis, key in enumerate(component_keys):
        lo, hi = params[key]
        denorm[..., axis] = pred[..., axis] * (hi - lo) + lo
    return denorm
|
|
|
|
def compute_delta_e(pred_spec: np.ndarray, gt_xyY: np.ndarray) -> float:
    """Compute CIE2000 Delta-E between a predicted spec and ground-truth xyY.

    The predicted Munsell specification is converted to Lab via
    xyY -> XYZ; the ground-truth xyY has its Y component rescaled from
    [0, 100] to [0, 1] before the same conversion.

    Returns:
        The Delta-E value, or NaN when the predicted specification
        cannot be converted (e.g. outside the renotation gamut).
    """
    try:
        pred_Lab = XYZ_to_Lab(
            xyY_to_XYZ(munsell_specification_to_xyY(pred_spec)),
            CCS_ILLUMINANT_MUNSELL,
        )

        # Ground truth stores Y on a 0-100 scale; convert on a copy.
        reference = gt_xyY.copy()
        reference[2] /= 100.0
        gt_Lab = XYZ_to_Lab(xyY_to_XYZ(reference), CCS_ILLUMINANT_MUNSELL)

        return delta_E_CIE2000(gt_Lab, pred_Lab)
    except Exception:
        # Deliberate best-effort: unconvertible predictions become NaN
        # so the caller can filter them out.
        return np.nan
|
|
|
|
def _evaluate_samples(session, input_parameters: dict, output_parameters: dict) -> list:
    """Run the model on every real Munsell colour and collect per-sample results.

    Samples that fail conversion or yield a NaN Delta-E are skipped (with a
    warning for hard failures).
    """
    input_name = session.get_inputs()[0].name
    results = []

    for munsell_spec_tuple, xyY_gt in MUNSELL_COLOURS_REAL:
        hue_code_str, value, chroma = munsell_spec_tuple
        munsell_str = f"{hue_code_str} {value}/{chroma}"

        try:
            gt_spec = munsell_colour_to_munsell_specification(munsell_str)
            gt_xyY = np.array(xyY_gt)

            xyY_norm = normalize_input(gt_xyY.reshape(1, 3), input_parameters)
            pred_norm = session.run(None, {input_name: xyY_norm})[0]
            pred_spec = denormalize_output(pred_norm, output_parameters)[0]

            # Clamp predictions into plausible Munsell ranges; the hue-family
            # code is discrete, so round it after clipping.
            pred_spec[0] = np.clip(pred_spec[0], 0.5, 10.0)
            pred_spec[1] = np.clip(pred_spec[1], 1.0, 9.0)
            pred_spec[2] = np.clip(pred_spec[2], 0.0, 50.0)
            pred_spec[3] = np.round(np.clip(pred_spec[3], 1.0, 10.0))

            delta_e = compute_delta_e(pred_spec, gt_xyY)

            if not np.isnan(delta_e):
                results.append(
                    {
                        "munsell_str": munsell_str,
                        "gt_spec": gt_spec,
                        "pred_spec": pred_spec,
                        "delta_e": delta_e,
                        "hue": gt_spec[0],
                        "value": gt_spec[1],
                        "chroma": gt_spec[2],
                        "code": int(gt_spec[3]),
                        "gt_xyY": gt_xyY,
                    }
                )
        except Exception as e:
            LOGGER.warning("Failed for %s: %s", munsell_str, e)

    return results


def _log_overall_stats(delta_es: list) -> None:
    """Log summary statistics and the cumulative Delta-E distribution."""
    LOGGER.info("\nOverall Delta-E Statistics:")
    LOGGER.info(" Mean: %.4f", np.mean(delta_es))
    LOGGER.info(" Median: %.4f", np.median(delta_es))
    LOGGER.info(" Std: %.4f", np.std(delta_es))
    LOGGER.info(" Min: %.4f", np.min(delta_es))
    LOGGER.info(" Max: %.4f", np.max(delta_es))

    LOGGER.info("\nDelta-E Distribution:")
    for thresh in [1.0, 2.0, 3.0, 5.0, 10.0]:
        count = sum(1 for d in delta_es if d <= thresh)
        pct = 100 * count / len(delta_es)
        LOGGER.info(" <= %.1f: %4d (%.1f%%)", thresh, count, pct)


def _log_hue_analysis(results: list) -> None:
    """Log per-hue-family Delta-E statistics as an aligned table."""
    LOGGER.info("\n%s", "=" * 40)
    LOGGER.info("Analysis by Hue Family")
    LOGGER.info("=" * 40)

    by_hue = defaultdict(list)
    for r in results:
        hue_name = HUE_NAMES.get(r["code"], f"?{r['code']}")
        by_hue[hue_name].append(r["delta_e"])

    LOGGER.info(
        "\n%-4s %5s %6s %6s %6s %s",
        "Hue",
        "Count",
        "Mean",
        "Median",
        "Max",
        ">3.0",
    )
    for hue_name in ["R", "YR", "Y", "GY", "G", "BG", "B", "PB", "P", "RP"]:
        if hue_name not in by_hue:
            continue
        des = by_hue[hue_name]
        high = sum(1 for d in des if d > 3.0)
        LOGGER.info(
            "%-4s %5d %6.2f %6.2f %6.2f %d (%.0f%%)",
            hue_name,
            len(des),
            np.mean(des),
            np.median(des),
            np.max(des),
            high,
            100 * high / len(des),
        )


def _log_range_analysis(
    results: list, key: str, ranges: list, title: str, label: str, row_fmt: str
) -> None:
    """Log Delta-E statistics bucketed by half-open ranges of *key*.

    Args:
        results: Per-sample result dicts.
        key: Result-dict key to bucket on ("value" or "chroma").
        ranges: List of (min, max) half-open bucket bounds.
        title: Section title.
        label: First column header.
        row_fmt: %-style format for each table row
            (lo, hi, count, mean, median, max, high-count, high-pct).
    """
    LOGGER.info("\n%s", "=" * 40)
    LOGGER.info(title)
    LOGGER.info("=" * 40)

    LOGGER.info(
        "\n%-8s %5s %6s %6s %6s %s",
        label,
        "Count",
        "Mean",
        "Median",
        "Max",
        ">3.0",
    )
    for lo, hi in ranges:
        des = [r["delta_e"] for r in results if lo <= r[key] < hi]
        if not des:
            continue  # empty bucket: nothing to report, avoid div-by-zero
        high = sum(1 for d in des if d > 3.0)
        LOGGER.info(
            row_fmt,
            lo,
            hi,
            len(des),
            np.mean(des),
            np.median(des),
            np.max(des),
            high,
            100 * high / len(des),
        )


def _log_worst_samples(results: list) -> None:
    """Log the 20 samples with the largest Delta-E."""
    LOGGER.info("\n%s", "=" * 40)
    LOGGER.info("Top 20 Worst Samples")
    LOGGER.info("=" * 40)

    worst = sorted(results, key=lambda r: r["delta_e"], reverse=True)[:20]
    LOGGER.info(
        "\n%-15s %6s %-20s %-20s", "Munsell", "DeltaE", "GT Spec", "Pred Spec"
    )
    for r in worst:
        gs = r["gt_spec"]
        ps = r["pred_spec"]
        gt = f"[{gs[0]:.1f}, {gs[1]:.1f}, {gs[2]:.1f}, {int(gs[3])}]"
        pred = f"[{ps[0]:.1f}, {ps[1]:.1f}, {ps[2]:.1f}, {int(ps[3])}]"
        LOGGER.info(
            "%-15s %6.2f %-20s %-20s", r["munsell_str"], r["delta_e"], gt, pred
        )


def _log_component_errors(high_error: list, threshold: float) -> None:
    """Log absolute per-component errors for the high-error samples."""
    LOGGER.info("\n%s", "=" * 40)
    LOGGER.info("Component Errors for High-Error Samples (Delta-E > %.1f)", threshold)
    LOGGER.info("=" * 40)

    if not high_error:
        return

    LOGGER.info("\n%-10s %6s %6s %6s", "Component", "Mean", "Median", "Max")
    # Components share the spec ordering: hue, value, chroma, code.
    for axis, name in enumerate(("Hue", "Value", "Chroma", "Code")):
        errors = [abs(r["pred_spec"][axis] - r["gt_spec"][axis]) for r in high_error]
        LOGGER.info(
            "%-10s %6.2f %6.2f %6.2f",
            name,
            np.mean(errors),
            np.median(errors),
            np.max(errors),
        )


def analyze_errors(
    model_name: str = "multi_head_large", threshold: float = 3.0
) -> list:
    """Analyze error distribution for a model.

    Runs the model on all real Munsell colours, then logs overall,
    per-hue, per-value-range, and per-chroma-range Delta-E statistics,
    the 20 worst samples, and component errors for samples whose
    Delta-E exceeds *threshold*.

    Returns:
        The list of per-sample result dicts (possibly empty).

    Raises:
        FileNotFoundError: If the model or its parameters are missing.
    """
    LOGGER.info("=" * 80)
    LOGGER.info("Error Analysis for %s", model_name)
    LOGGER.info("=" * 80)

    session, input_parameters, output_parameters = load_model_and_params(model_name)
    results = _evaluate_samples(session, input_parameters, output_parameters)

    LOGGER.info("\nTotal samples evaluated: %d", len(results))
    if not results:
        # Guard: every statistic below divides by the sample count.
        LOGGER.warning("No samples were evaluated successfully; skipping analysis.")
        return results

    delta_es = [r["delta_e"] for r in results]
    _log_overall_stats(delta_es)

    high_error = [r for r in results if r["delta_e"] > threshold]
    LOGGER.info(
        "\nSamples with Delta-E > %.1f: %d (%.1f%%)",
        threshold,
        len(high_error),
        100 * len(high_error) / len(results),
    )

    _log_hue_analysis(results)
    _log_range_analysis(
        results,
        "value",
        [(1, 3), (3, 5), (5, 7), (7, 9)],
        "Analysis by Value Range",
        "Value",
        "[%d-%d) %5d %6.2f %6.2f %6.2f %d (%.0f%%)",
    )
    _log_range_analysis(
        results,
        "chroma",
        [(0, 4), (4, 8), (8, 12), (12, 20), (20, 50)],
        "Analysis by Chroma Range",
        "Chroma",
        "[%2d-%2d) %5d %6.2f %6.2f %6.2f %d (%.0f%%)",
    )
    _log_worst_samples(results)
    _log_component_errors(high_error, threshold)

    return results
|
|
|
|
def main() -> None:
    """Run error analysis for each configured model."""
    for model_name in ("multi_head_large",):
        try:
            analyze_errors(model_name, threshold=3.0)
        except FileNotFoundError as e:
            # A missing model is not fatal; move on to the next one.
            LOGGER.warning("Skipping %s: %s", model_name, e)
        LOGGER.info("\n")
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|