# Provenance: initial commit 3c7db92 by KelSolaar (web-page residue removed so the module parses).
"""
Analyze error distribution to identify problematic regions in Munsell space.
This script:
1. Runs the best model on all REAL Munsell colors
2. Computes Delta-E for each sample
3. Identifies samples with high error (Delta-E > threshold)
4. Analyzes patterns: which hue families, value ranges, chroma ranges have issues
5. Outputs statistics and visualizations
"""
import logging
from collections import defaultdict
import numpy as np
import onnxruntime as ort
from colour import XYZ_to_Lab, xyY_to_XYZ
from colour.difference import delta_E_CIE2000
from colour.notation.datasets.munsell import MUNSELL_COLOURS_REAL
from colour.notation.munsell import (
CCS_ILLUMINANT_MUNSELL,
munsell_colour_to_munsell_specification,
munsell_specification_to_xyY,
)
from learning_munsell import PROJECT_ROOT
logging.basicConfig(level=logging.INFO, format="%(message)s")
LOGGER = logging.getLogger(__name__)
# Map the integer hue letter code of a colour-science Munsell specification
# to its hue family name. colour-science follows ASTM D1535 ordering
# (see colour.notation.munsell.MUNSELL_HUE_LETTER_CODES):
# B=1, BG=2, G=3, GY=4, Y=5, YR=6, R=7, RP=8, P=9, PB=10.
# The previous mapping assumed 1=R .. 10=RP and mislabelled every family.
HUE_NAMES = {
    1: "B",
    2: "BG",
    3: "G",
    4: "GY",
    5: "Y",
    6: "YR",
    7: "R",
    8: "RP",
    9: "P",
    10: "PB",
    # Code 0 can only appear from a wrapped specification; alias it to 10.
    0: "PB",
}
def load_model_and_params(model_name: str) -> tuple:
    """Load ONNX model and normalization parameters."""
    base_dir = PROJECT_ROOT / "models" / "from_xyY"
    onnx_file = base_dir / f"{model_name}.onnx"
    npz_file = base_dir / f"{model_name}_normalization_parameters.npz"
    # Fail fast with an explicit message if either artefact is missing;
    # the model is checked first, then its parameters.
    for label, path in (("Model", onnx_file), ("Params", npz_file)):
        if not path.exists():
            msg = f"{label} not found: {path}"
            raise FileNotFoundError(msg)
    session = ort.InferenceSession(str(onnx_file))
    # The .npz stores two pickled dicts; ".item()" unwraps the 0-d arrays.
    archive = np.load(npz_file, allow_pickle=True)
    return (
        session,
        archive["input_parameters"].item(),
        archive["output_parameters"].item(),
    )
def normalize_input(xyY: np.ndarray, params: dict) -> np.ndarray:
    """Normalize *CIE xyY* input for the model.

    Parameters
    ----------
    xyY
        Array whose last axis holds ``(x, y, Y)`` with ``Y`` in the 0-100
        domain.
    params
        Normalization parameters with ``"x_range"``, ``"y_range"`` and
        ``"Y_range"`` ``(min, max)`` pairs.

    Returns
    -------
    np.ndarray
        ``float32`` array of the same shape, min-max normalized per channel.
    """
    # "astype" already returns a copy, so the previous np.copy was redundant.
    normalized = xyY.astype(np.float32)
    x_min, x_max = params["x_range"]
    y_min, y_max = params["y_range"]
    Y_min, Y_max = params["Y_range"]
    normalized[..., 0] = (xyY[..., 0] - x_min) / (x_max - x_min)
    normalized[..., 1] = (xyY[..., 1] - y_min) / (y_max - y_min)
    # Y is stored in 0-100; scale to 0-1 before applying the min-max range.
    normalized[..., 2] = (xyY[..., 2] / 100.0 - Y_min) / (Y_max - Y_min)
    return normalized
def denormalize_output(pred: np.ndarray, params: dict) -> np.ndarray:
    """Denormalize Munsell output."""
    # Undo the per-component min-max scaling. The last axis holds
    # (hue, value, chroma, code), matching the "<name>_range" keys.
    denorm = np.copy(pred)
    for index, component in enumerate(("hue", "value", "chroma", "code")):
        low, high = params[f"{component}_range"]
        denorm[..., index] = pred[..., index] * (high - low) + low
    return denorm
def compute_delta_e(pred_spec: np.ndarray, gt_xyY: np.ndarray) -> float:
    """Compute Delta-E between predicted spec (via xyY) and ground truth xyY."""
    try:
        # Predicted specification -> xyY -> XYZ -> Lab.
        predicted_Lab = XYZ_to_Lab(
            xyY_to_XYZ(munsell_specification_to_xyY(pred_spec)),
            CCS_ILLUMINANT_MUNSELL,
        )
        # Ground truth Y is in the 0-100 domain; scale it to 0-1 first.
        reference_xyY = gt_xyY.copy()
        reference_xyY[2] = gt_xyY[2] / 100.0
        reference_Lab = XYZ_to_Lab(
            xyY_to_XYZ(reference_xyY), CCS_ILLUMINANT_MUNSELL
        )
        return delta_E_CIE2000(reference_Lab, predicted_Lab)
    except Exception:  # noqa: BLE001
        # Deliberate best-effort: invalid specifications simply yield NaN.
        return np.nan
def analyze_errors(
    model_name: str = "multi_head_large", threshold: float = 3.0
) -> list:
    """Analyze error distribution for a model.

    Runs the model on all real Munsell colours, computes per-sample
    Delta-E, and logs overall, per-hue, per-value, per-chroma and
    worst-sample statistics.

    Parameters
    ----------
    model_name
        Name of the ONNX model under ``models/from_xyY``.
    threshold
        Delta-E above which a sample counts as "high error".

    Returns
    -------
    list
        Per-sample result dicts (one per successfully evaluated colour).

    Raises
    ------
    FileNotFoundError
        If the model or its normalization parameters are missing.
    """
    LOGGER.info("=" * 80)
    LOGGER.info("Error Analysis for %s", model_name)
    LOGGER.info("=" * 80)
    # Load model
    session, input_parameters, output_parameters = load_model_and_params(model_name)
    input_name = session.get_inputs()[0].name
    # Collect data
    results = _evaluate_samples(
        session, input_name, input_parameters, output_parameters
    )
    LOGGER.info("\nTotal samples evaluated: %d", len(results))
    if not results:
        # Guard: the statistics below divide by len(results) and take
        # means of the Delta-E list, which fail on an empty run.
        LOGGER.warning("No samples could be evaluated; skipping statistics.")
        return results
    delta_es = [r["delta_e"] for r in results]
    _log_overall_statistics(delta_es)
    # High error samples
    high_error = [r for r in results if r["delta_e"] > threshold]
    LOGGER.info(
        "\nSamples with Delta-E > %.1f: %d (%.1f%%)",
        threshold,
        len(high_error),
        100 * len(high_error) / len(results),
    )
    _log_hue_family_statistics(results)
    _log_value_range_statistics(results)
    _log_chroma_range_statistics(results)
    _log_worst_samples(results)
    _log_component_errors(high_error, threshold)
    return results


def _evaluate_samples(
    session, input_name: str, input_parameters: dict, output_parameters: dict
) -> list:
    """Run the model on every real Munsell colour; return per-sample dicts."""
    results = []
    for munsell_spec_tuple, xyY_gt in MUNSELL_COLOURS_REAL:
        hue_code_str, value, chroma = munsell_spec_tuple
        munsell_str = f"{hue_code_str} {value}/{chroma}"
        try:
            gt_spec = munsell_colour_to_munsell_specification(munsell_str)
            gt_xyY = np.array(xyY_gt)
            # Predict
            xyY_norm = normalize_input(gt_xyY.reshape(1, 3), input_parameters)
            pred_norm = session.run(None, {input_name: xyY_norm})[0]
            pred_spec = denormalize_output(pred_norm, output_parameters)[0]
            # Clamp to valid ranges; the hue letter code is rounded to an
            # integer after clipping.
            pred_spec[0] = np.clip(pred_spec[0], 0.5, 10.0)
            pred_spec[1] = np.clip(pred_spec[1], 1.0, 9.0)
            pred_spec[2] = np.clip(pred_spec[2], 0.0, 50.0)
            pred_spec[3] = np.clip(pred_spec[3], 1.0, 10.0)
            pred_spec[3] = np.round(pred_spec[3])
            # Compute Delta-E; NaN marks a conversion failure and is dropped.
            delta_e = compute_delta_e(pred_spec, gt_xyY)
            if not np.isnan(delta_e):
                results.append(
                    {
                        "munsell_str": munsell_str,
                        "gt_spec": gt_spec,
                        "pred_spec": pred_spec,
                        "delta_e": delta_e,
                        "hue": gt_spec[0],
                        "value": gt_spec[1],
                        "chroma": gt_spec[2],
                        "code": int(gt_spec[3]),
                        "gt_xyY": gt_xyY,
                    }
                )
        except Exception as e:  # noqa: BLE001
            # Best-effort over the whole dataset: log and keep going.
            LOGGER.warning("Failed for %s: %s", munsell_str, e)
    return results


def _log_overall_statistics(delta_es: list) -> None:
    """Log aggregate Delta-E statistics and the cumulative distribution."""
    LOGGER.info("\nOverall Delta-E Statistics:")
    LOGGER.info(" Mean: %.4f", np.mean(delta_es))
    LOGGER.info(" Median: %.4f", np.median(delta_es))
    LOGGER.info(" Std: %.4f", np.std(delta_es))
    LOGGER.info(" Min: %.4f", np.min(delta_es))
    LOGGER.info(" Max: %.4f", np.max(delta_es))
    # Distribution
    LOGGER.info("\nDelta-E Distribution:")
    for thresh in [1.0, 2.0, 3.0, 5.0, 10.0]:
        count = sum(1 for d in delta_es if d <= thresh)
        pct = 100 * count / len(delta_es)
        LOGGER.info(" <= %.1f: %4d (%.1f%%)", thresh, count, pct)


def _log_hue_family_statistics(results: list) -> None:
    """Log Delta-E statistics grouped by Munsell hue family."""
    LOGGER.info("\n%s", "=" * 40)
    LOGGER.info("Analysis by Hue Family")
    LOGGER.info("=" * 40)
    by_hue = defaultdict(list)
    for r in results:
        hue_name = HUE_NAMES.get(r["code"], f"?{r['code']}")
        by_hue[hue_name].append(r["delta_e"])
    LOGGER.info(
        "\n%-4s %5s %6s %6s %6s %s",
        "Hue",
        "Count",
        "Mean",
        "Median",
        "Max",
        ">3.0",
    )
    for hue_name in ["R", "YR", "Y", "GY", "G", "BG", "B", "PB", "P", "RP"]:
        if hue_name in by_hue:
            des = by_hue[hue_name]
            high = sum(1 for d in des if d > 3.0)
            LOGGER.info(
                "%-4s %5d %6.2f %6.2f %6.2f %d (%.0f%%)",
                hue_name,
                len(des),
                np.mean(des),
                np.median(des),
                np.max(des),
                high,
                100 * high / len(des),
            )


def _log_value_range_statistics(results: list) -> None:
    """Log Delta-E statistics bucketed by Munsell value (half-open bins)."""
    LOGGER.info("\n%s", "=" * 40)
    LOGGER.info("Analysis by Value Range")
    LOGGER.info("=" * 40)
    value_ranges = [(1, 3), (3, 5), (5, 7), (7, 9)]
    LOGGER.info(
        "\n%-8s %5s %6s %6s %6s %s",
        "Value",
        "Count",
        "Mean",
        "Median",
        "Max",
        ">3.0",
    )
    for v_min, v_max in value_ranges:
        des = [r["delta_e"] for r in results if v_min <= r["value"] < v_max]
        if des:
            high = sum(1 for d in des if d > 3.0)
            LOGGER.info(
                "[%d-%d) %5d %6.2f %6.2f %6.2f %d (%.0f%%)",
                v_min,
                v_max,
                len(des),
                np.mean(des),
                np.median(des),
                np.max(des),
                high,
                100 * high / len(des),
            )


def _log_chroma_range_statistics(results: list) -> None:
    """Log Delta-E statistics bucketed by Munsell chroma (half-open bins)."""
    LOGGER.info("\n%s", "=" * 40)
    LOGGER.info("Analysis by Chroma Range")
    LOGGER.info("=" * 40)
    chroma_ranges = [(0, 4), (4, 8), (8, 12), (12, 20), (20, 50)]
    LOGGER.info(
        "\n%-8s %5s %6s %6s %6s %s",
        "Chroma",
        "Count",
        "Mean",
        "Median",
        "Max",
        ">3.0",
    )
    for c_min, c_max in chroma_ranges:
        des = [r["delta_e"] for r in results if c_min <= r["chroma"] < c_max]
        if des:
            high = sum(1 for d in des if d > 3.0)
            LOGGER.info(
                "[%2d-%2d) %5d %6.2f %6.2f %6.2f %d (%.0f%%)",
                c_min,
                c_max,
                len(des),
                np.mean(des),
                np.median(des),
                np.max(des),
                high,
                100 * high / len(des),
            )


def _log_worst_samples(results: list) -> None:
    """Log the 20 samples with the largest Delta-E."""
    LOGGER.info("\n%s", "=" * 40)
    LOGGER.info("Top 20 Worst Samples")
    LOGGER.info("=" * 40)
    worst = sorted(results, key=lambda r: r["delta_e"], reverse=True)[:20]
    LOGGER.info(
        "\n%-15s %6s %-20s %-20s", "Munsell", "DeltaE", "GT Spec", "Pred Spec"
    )
    for r in worst:
        gs = r["gt_spec"]
        ps = r["pred_spec"]
        gt = f"[{gs[0]:.1f}, {gs[1]:.1f}, {gs[2]:.1f}, {int(gs[3])}]"
        pred = f"[{ps[0]:.1f}, {ps[1]:.1f}, {ps[2]:.1f}, {int(ps[3])}]"
        LOGGER.info(
            "%-15s %6.2f %-20s %-20s", r["munsell_str"], r["delta_e"], gt, pred
        )


def _log_component_errors(high_error: list, threshold: float) -> None:
    """Log absolute per-component prediction errors for high-error samples."""
    LOGGER.info("\n%s", "=" * 40)
    LOGGER.info("Component Errors for High-Error Samples (Delta-E > %.1f)", threshold)
    LOGGER.info("=" * 40)
    if not high_error:
        return
    LOGGER.info("\n%-10s %6s %6s %6s", "Component", "Mean", "Median", "Max")
    # Components are indexed (hue, value, chroma, code) in the spec arrays.
    for index, component in enumerate(("Hue", "Value", "Chroma", "Code")):
        errors = [
            abs(r["pred_spec"][index] - r["gt_spec"][index]) for r in high_error
        ]
        LOGGER.info(
            "%-10s %6.2f %6.2f %6.2f",
            component,
            np.mean(errors),
            np.median(errors),
            np.max(errors),
        )
def main() -> None:
    """Run error analysis."""
    # Candidate models to analyze; missing ones are skipped with a warning.
    model_names = [
        "multi_head_large",
    ]
    for name in model_names:
        try:
            analyze_errors(name, threshold=3.0)
        except FileNotFoundError as error:
            LOGGER.warning("Skipping %s: %s", name, error)
        LOGGER.info("\n")


if __name__ == "__main__":
    main()