# Learning Munsell

Technical documentation covering performance benchmarks, training methodology, architecture design, and experimental findings.

## Overview

This project implements ML models for bidirectional conversion between CIE xyY color space values and Munsell specifications:

- **xyY to Munsell (from_xyY)**: 25+ architectures, best Delta-E 0.52
- **Munsell to xyY (to_xyY)**: 9 architectures, best Delta-E 0.48

### Delta-E Interpretation

- **< 1.0**: Not perceptible by the human eye
- **1-2**: Perceptible through close observation
- **2-10**: Perceptible at a glance
- **> 10**: Colors are perceived as completely different

Our best models achieve **Delta-E 0.48-0.52**, meaning the difference between the ML prediction and the iterative algorithm is **not perceptible by the human eye**.

## xyY to Munsell (from_xyY)

### Performance Benchmarks

Comprehensive comparison using all 2,734 REAL Munsell colors (speed is inference time per sample, in milliseconds):

| Model                                                    | Delta-E     | Speed (ms) |
|----------------------------------------------------------|-------------|------------|
| Colour Library (Baseline)                                | 0.00        | 111.90     |
| **Multi-ResNet + Multi-Error Predictor (Large Dataset)** | **0.52**    | 0.089      |
| Multi-MLP (W+B) + Multi-Error Predictor (W+B) Large      | 0.52        | 0.057      |
| Multi-MLP + Multi-Error Predictor (Large Dataset)        | 0.52        | 0.058      |
| Multi-MLP + Multi-Error Predictor                        | 0.53        | 0.058      |
| MLP + Error Predictor                                    | 0.53        | 0.030      |
| Multi-ResNet (Large Dataset)                             | 0.54        | 0.044      |
| Multi-Head + Multi-Error Predictor                       | 0.54        | 0.042      |
| Multi-Head + Multi-Error Predictor (Large Dataset)       | 0.56        | 0.043      |
| Deep + Wide                                              | 0.60        | 0.074      |
| Multi-Head (Large Dataset)                               | 0.66        | 0.013      |
| Mixture of Experts                                       | 0.80        | 0.020      |
| Transformer (Large Dataset)                              | 0.82        | 0.123      |
| Multi-MLP                                                | 0.86        | 0.027      |
| MLP + Self-Attention                                     | 0.88        | 0.173      |
| MLP (Base Only)                                          | 1.09        | **0.007**  |
| Unified MLP                                              | 1.12        | 0.072      |

Note: The Colour library baseline had 171 convergence failures out of 2,734 samples (6.3% failure rate).
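
The Delta-E bands and the speedup figures can be applied to the table rows with two small helpers. This is a minimal sketch: the function names are hypothetical, and treating each band as half-open (exactly 1.0 falls in the 1-2 band, exactly 10.0 in the last band) is an assumption, since the interpretation list does not specify boundary handling.

```python
# Hypothetical helpers for reading the benchmark table; illustrative only.

# Upper bounds of the Delta-E interpretation bands (half-open intervals assumed).
DELTA_E_BANDS = [
    (1.0, "not perceptible"),
    (2.0, "perceptible through close observation"),
    (10.0, "perceptible at a glance"),
]


def perceptibility(delta_e: float) -> str:
    """Map a Delta-E value to the interpretation band it falls in."""
    if delta_e < 0:
        raise ValueError("Delta-E must be non-negative")
    for upper, label in DELTA_E_BANDS:
        if delta_e < upper:
            return label
    return "perceived as completely different"


def speedup(baseline_ms: float, model_ms: float) -> float:
    """How many times faster a model is than the baseline, per sample."""
    return baseline_ms / model_ms


COLOUR_BASELINE_MS = 111.90  # Colour library, ms/sample (from the table)

# A few rows from the table: name -> (Delta-E, ms/sample).
rows = {
    "Multi-ResNet + Multi-Error Predictor (Large Dataset)": (0.52, 0.089),
    "MLP (Base Only)": (1.09, 0.007),
    "Unified MLP": (1.12, 0.072),
}

for name, (delta_e, ms) in rows.items():
    print(f"{name}: {perceptibility(delta_e)}, "
          f"~{speedup(COLOUR_BASELINE_MS, ms):,.0f}x faster than Colour")
```

Ratios computed from the rounded table values differ slightly from the quoted speedups: 111.90 / 0.007 gives roughly 16,000x for MLP (Base Only) rather than 15,492x, so the quoted figures were presumably computed from unrounded timings.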

**Best Models**:

- **Best Accuracy**: Multi-ResNet + Multi-Error Predictor (Large Dataset) - Delta-E 0.52
- **Fastest**: MLP (Base Only) (0.007 ms/sample) - 15,492x faster than the Colour library
- **Best Balance**: Multi-MLP (W+B: Weighted Boundary) + Multi-Error Predictor (W+B) Large - 1,951x faster with Delta-E 0.52

### Model Architectures

25+ architectures were systematically evaluated; the main variants are listed below.

**Single-Stage Models**

1. **MLP (Base Only)** - Simple MLP network, 3 inputs to 4 outputs
2. **Unified MLP** - Single large MLP with shared features
3. **Multi-Head** - Shared encoder with 4 independent decoder heads
4. **Multi-Head (Large Dataset)** - Multi-Head trained on 1.4M samples
5. **Multi-MLP** - 4 completely independent MLP branches (one per output)
6. **Multi-MLP (Large Dataset)** - Multi-MLP trained on 1.4M samples
7. **MLP + Self-Attention** - MLP with attention mechanism for feature weighting
8. **Deep + Wide** - Combined deep and wide network paths
9. **Mixture of Experts** - Gating network selecting specialized expert networks
10. **Transformer (Large Dataset)** - Feature Tokenizer Transformer for tabular data
11. **FT-Transformer** - Feature Tokenizer Transformer (standard size)

**Two-Stage Models**

12. **MLP + Error Predictor** - Base MLP with unified error correction
13. **Multi-Head + Multi-Error Predictor** - Multi-Head with 4 independent error predictors
14. **Multi-Head + Multi-Error Predictor (Large Dataset)** - Large dataset variant
15. **Multi-MLP + Multi-Error Predictor** - 4 independent branches with 4 independent error predictors
16. **Multi-MLP + Multi-Error Predictor (Large Dataset)** - Large dataset variant
17. **Multi-ResNet + Multi-Error Predictor (Large Dataset)** - Deep ResNet-style branches (BEST)

The **Multi-ResNet + Multi-Error Predictor (Large Dataset)** architecture achieved the best results with Delta-E 0.52.

### Training Methodology

**Data Generation**

1. **Dense xyY Grid** (~500K samples)
   - Regular grid in valid xyY space (MacAdam limits for Illuminant C)
   - Captures the general input distribution
2. **Boundary Refinement** (~700K samples)
   - Adaptive dense sampling near Munsell gamut boundaries
   - Uses `maximum_chroma_from_renotation` to detect edges
   - Focuses on regions where the iterative algorithm is most complex
   - Includes Y/GY/G hue regions with high value/chroma (challenging areas)
3. **Forward Augmentation** (~200K samples)
   - Dense Munsell space sampling via `munsell_specification_to_xyY`
   - Ensures coverage of known valid colors

Total: ~1.4M samples for large dataset training.

**Loss Functions**

Two loss function approaches were tested:

*Precision-Focused Loss* (Default):

```
total_loss = 1.0 * MSE + 0.5 * MAE + 0.3 * log_penalty + 0.5 * huber_loss
```

- MSE: Standard mean squared error
- MAE: Mean absolute error
- Log penalty: Heavily penalizes small errors (pushes toward high precision)
- Huber loss: Small delta (0.01) for precision on small errors

*Pure MSE Loss* (Optimized config):

```
total_loss = MSE
```

Interestingly, the precision-focused loss achieved better Delta-E despite higher validation MSE, suggesting the custom weighting correlates better with perceptual accuracy.

### Design Rationale

**Two-Stage Architecture**

The error predictor stage corrects systematic biases in the base model:

1. The base model learns the general xyY to Munsell mapping
2. The error predictor learns residual corrections specific to each component
3. Combined prediction: `final = base_prediction + error_correction`

This decomposition allows each stage to specialize and reduces the complexity each network must learn.
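
The two-stage data flow can be sketched framework-agnostically in NumPy. Everything here is illustrative: `base_model` and `error_predictor` are hypothetical stand-ins with random weights and a single 64-unit hidden layer, not the project's trained networks. What the sketch does follow from the description is the wiring: the error predictor sees the 3 xyY inputs concatenated with the 4 base outputs, and its 4 corrections are added to the base prediction.

```python
import numpy as np

rng = np.random.default_rng(0)


def init_mlp(sizes):
    """Random (weight, bias) pairs for a hypothetical MLP with the given layer sizes."""
    return [(rng.normal(0.0, 0.1, (m, n)), np.zeros(n))
            for m, n in zip(sizes[:-1], sizes[1:])]


def mlp_forward(params, x):
    """Apply the MLP: SiLU on hidden layers, linear output layer."""
    for i, (w, b) in enumerate(params):
        x = x @ w + b
        if i < len(params) - 1:
            x = x / (1.0 + np.exp(-x))  # SiLU(x) = x * sigmoid(x)
    return x


# Stage 1: xyY (3) -> Munsell components (4: hue, value, chroma, code).
base_model = init_mlp([3, 64, 4])
# Stage 2: [xyY (3), base prediction (4)] -> per-component correction (4).
error_predictor = init_mlp([7, 64, 4])


def two_stage_predict(xyY):
    base_prediction = mlp_forward(base_model, xyY)
    error_input = np.concatenate([xyY, base_prediction], axis=-1)
    error_correction = mlp_forward(error_predictor, error_input)
    return base_prediction + error_correction, base_prediction, error_correction


xyY = np.array([[0.31, 0.32, 0.20], [0.45, 0.41, 0.60]])
final, base, correction = two_stage_predict(xyY)
print(final.shape)  # (2, 4)
```

A residual corrector like this is typically fit on the target `target - base_prediction` with the base model frozen; this document does not spell out the project's exact training procedure for the second stage.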

**Independent Branch Design**

Munsell components have different characteristics:

- **Hue**: Circular (0-10, wrapping), most complex
- **Value**: Linear (0-10), easiest to predict
- **Chroma**: Highly variable range depending on hue/value
- **Code**: Discrete hue sector (0-9)

Shared encoders force compromises between these different prediction tasks. Independent branches allow full specialization.

**Architecture Details**

*MLP (Base Only)*

Simple feedforward network predicting all 4 outputs simultaneously:

```
Input (3) ──► Linear Layers ──► Output (4: hue, value, chroma, code)
```

- Smallest model (~8KB ONNX)
- Fastest inference (0.007 ms/sample)
- Baseline for comparison

*Unified MLP*

Single large MLP with shared internal features:

```
Input (3) ──► 128 ──► 256 ──► 512 ──► 256 ──► 128 ──► Output (4)
```

- Shared representations across all outputs
- Moderate size; 0.072 ms/sample inference

*Multi-Head MLP*

Shared encoder with specialized decoder heads:

```
Input (3) ──► SHARED ENCODER (3→128→256→512) ──┬──► Hue Head (512→256→128→1)
                                               ├──► Value Head (512→256→128→1)
                                               ├──► Chroma Head (512→384→256→128→1)
                                               └──► Code Head (512→256→128→1)
```

- Shared encoder learns common color space features
- 4 specialized decoder heads branch from the shared representation
- Parameter efficient (encoder weights shared)
- Fast inference (encoder computed once)

*Multi-MLP*

Fully independent branches with no weight sharing:

```
Input (3) ──► Hue Branch    (3→128→256→512→256→128→1)
Input (3) ──► Value Branch  (3→128→256→512→256→128→1)
Input (3) ──► Chroma Branch (3→256→512→1024→512→256→1) [2x wider]
Input (3) ──► Code Branch   (3→128→256→512→256→128→1)
```

- 4 completely independent MLPs
- Each branch learns its own features from scratch
- Chroma branch is 2x wider to handle its complexity
- Better accuracy than Multi-Head on the large dataset (Delta-E 0.52 vs 0.56 with error predictors)

*Multi-ResNet*

Deep branches with residual-style connections:

```
Input (3) ──► Hue Branch    (3→256→512→512→512→256→1)    [6 layers]
Input (3) ──► Value Branch  (3→256→512→512→512→256→1)    [6 layers]
Input (3) ──► Chroma Branch (3→512→1024→1024→1024→512→1) [6 layers, 2x wider]
Input (3) ──► Code Branch   (3→256→512→512→512→256→1)    [6 layers]
```

- Deeper architecture than Multi-MLP
- BatchNorm + SiLU activation
- Best accuracy when combined with an error predictor (Delta-E 0.52)
- Largest model (~14MB base, ~28MB with error predictor)

*Deep + Wide*

Combined deep and wide network paths:

```
Input (3) ──┬──► Deep Path (multiple layers) ──┬──► Concat ──► Output (4)
            └──► Wide Path (direct connection)─┘
```

- Deep path captures complex patterns
- Wide path preserves direct input information
- Good for mixed linear/nonlinear relationships

*MLP + Self-Attention*

MLP with attention mechanism for feature weighting:

```
Input (3) ──► MLP ──► Self-Attention ──► Output (4)
```

- Attention weights learn feature importance
- Slower due to attention computation (0.173 ms/sample)
- Did not improve over Multi-MLP (Delta-E 0.88 vs 0.86)

*Mixture of Experts*

Gating network selecting specialized expert networks:

```
Input (3) ──► Gating Network ──► Weighted sum of Expert outputs ──► Output (4)
```

- Multiple expert networks specialize in different input regions
- Gating network learns which expert to use
- More complex, and did not outperform the two-stage Multi-MLP models (Delta-E 0.80 vs 0.52-0.53)

*FT-Transformer*

Feature Tokenizer Transformer for tabular data:

```
Input (3) ──► Feature Tokenizer ──► Transformer Blocks ──► Output (4)
```

- Each input feature is tokenized separately
- Self-attention across feature tokens
- Good for tabular data with feature interactions
- Slower inference due to attention computation (0.123 ms/sample for the large dataset variant)

*Error Predictor (Two-Stage)*

Second-stage network that corrects base model errors:

```
Stage 1: Input (3) ──► Base Model ──► Base Prediction (4)
Stage 2: [Input (3), Base Prediction (4)] ──► Error Predictor ──► Error Correction (4)
Final:   Base Prediction + Error Correction = Final Output
```

- Learns residual corrections for each component
- Can use a unified (1 network) or multi (4 networks) error predictor
- Consistently improved accuracy across all from_xyY base architectures tested
- Best results: Multi-ResNet + Multi-Error Predictor (Delta-E 0.52)

**Loss-Metric Mismatch**

An important finding: **optimizing MSE does not optimize Delta-E**. The Optuna hyperparameter search minimized validation MSE, but the best MSE configuration did not achieve the best Delta-E. This is because:

- MSE treats all component errors equally
- Delta-E (CIE2000) weights errors based on human perception
- The precision-focused loss with custom weights better approximates perceptual importance

**Weighted Boundary Loss (Experimental)**

Analysis of model errors revealed systematic underperformance on Y/GY/G hues (Yellow/Green-Yellow/Green) with high value and chroma. The weighted boundary loss approach was explored to address this by:

1. Applying a 3x loss weight to samples in challenging regions:
   - Hue: 0.18-0.35 (normalized range covering Y/GY/G)
   - Value > 0.7 (high brightness)
   - Chroma > 0.5 (high saturation)
2. Adding a boundary penalty to prevent predictions exceeding Munsell gamut limits

**Finding**: The large dataset approach (~1.4M samples with dense boundary sampling) naturally provides sufficient coverage of these challenging regions. The weighted boundary loss model (Multi-MLP W+B + Multi-Error Predictor W+B Large, Delta-E 0.524) and the standard large dataset model (Multi-MLP + Multi-Error Predictor Large, Delta-E 0.525) achieve nearly identical results, making explicit loss weighting optional. The best overall model is Multi-ResNet + Multi-Error Predictor (Large Dataset) with Delta-E 0.52.

### Experimental Findings

The following experiments were conducted but did not improve results:

**Delta-E Training**

Training with a differentiable Delta-E CIE2000 loss, computed via a round trip through the Munsell-to-xyY approximator.

*Hypothesis*: A perceptual Delta-E loss might outperform MSE-trained models.

*Implementation*: JAX/Flax model with combined MSE + Delta-E loss.
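
A combined MSE + Delta-E objective of this kind can be sketched as follows. This is a simplified stand-in, not the project's JAX/Flax code: it uses the Euclidean CIE 1976 Delta-E instead of CIE2000 to stay short, assumes Illuminant C as the reference white with Y normalized to [0, 1], and replaces the trained Munsell-to-xyY approximator with a hypothetical `toy_approximator`. Written with `jax.numpy` instead of NumPy, the same functions could be differentiated with `jax.grad`.

```python
import numpy as np

# Illuminant C (CIE 1931 2°) white point chromaticity; Y normalized to 1 (assumed).
WHITE_xy = np.array([0.31006, 0.31616])


def xyY_to_XYZ(xyY):
    x, y, Y = xyY[..., 0], xyY[..., 1], xyY[..., 2]
    return np.stack([x * Y / y, Y, (1.0 - x - y) * Y / y], axis=-1)


def XYZ_to_Lab(XYZ, white_xy=WHITE_xy):
    white = xyY_to_XYZ(np.array([white_xy[0], white_xy[1], 1.0]))
    t = XYZ / white
    delta = 6.0 / 29.0
    f = np.where(t > delta**3, np.cbrt(t), t / (3.0 * delta**2) + 4.0 / 29.0)
    L = 116.0 * f[..., 1] - 16.0
    a = 500.0 * (f[..., 0] - f[..., 1])
    b = 200.0 * (f[..., 1] - f[..., 2])
    return np.stack([L, a, b], axis=-1)


def delta_E_76(xyY_a, xyY_b):
    """Euclidean Lab distance (CIE 1976), a simpler stand-in for CIE2000."""
    lab_a = XYZ_to_Lab(xyY_to_XYZ(xyY_a))
    lab_b = XYZ_to_Lab(xyY_to_XYZ(xyY_b))
    return np.linalg.norm(lab_a - lab_b, axis=-1)


def combined_loss(pred_spec, target_spec, approximator, delta_e_weight=1.0):
    """MSE on the Munsell specification plus a perceptual term.

    The perceptual term round-trips both specifications through the
    approximator (Munsell -> xyY) before comparing them in Lab.
    """
    mse = np.mean((pred_spec - target_spec) ** 2)
    perceptual = np.mean(delta_E_76(approximator(pred_spec), approximator(target_spec)))
    return mse + delta_e_weight * perceptual


def toy_approximator(spec):
    """Hypothetical stand-in for the trained Munsell -> xyY MLP."""
    hue, value, chroma, code = np.moveaxis(spec, -1, 0)
    x = 0.31 + 0.01 * chroma * np.cos(hue + code)
    y = 0.32 + 0.01 * chroma * np.sin(hue + code)
    Y = (value / 10.0) ** 2
    return np.stack([x, y, Y], axis=-1)


target = np.array([[5.0, 6.0, 8.0, 4.0]])  # hue, value, chroma, code
pred = target + np.array([[0.1, -0.05, 0.2, 0.0]])
print(combined_loss(target, target, toy_approximator))  # 0.0
print(combined_loss(pred, target, toy_approximator) > 0)  # True
```

The `delta_e_weight` knob is where such a setup trades specification accuracy against perceptual closeness, which is the trade-off the results below illustrate.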

Delta-E training requires a lower learning rate (1e-4 vs 3e-4) for stability; higher rates cause NaN gradients.

*Results*: While Delta-E is comparable, **hue accuracy is ~10x worse**:

| Metric (Normalized MAE)  | Delta-E Model | MSE Model |
|--------------------------|---------------|-----------|
| Hue MAE                  | 0.30          | 0.03      |
| Value MAE                | 0.002         | 0.004     |
| Chroma MAE               | 0.007         | 0.008     |
| Code MAE                 | 0.07          | 0.01      |
| **Delta-E (perceptual)** | **0.52**      | **0.50**  |

*Key Takeaway*: **Perceptual similarity != specification accuracy**. The Delta-E model reaches a Delta-E comparable to the MSE model (0.52 vs 0.50), but at the cost of ~10x worse hue accuracy, making it unsuitable for specification prediction. Delta-E is too permissive for hue, allowing the model to find "shortcuts" that minimize perceptual difference without correctly predicting the Munsell specification.

**Classical Interpolation**

Classical interpolation methods were tested on 4,995 reference Munsell colors (80% train / 20% test split). ML was evaluated on the 2,734 REAL Munsell colors.

*Results (Validation MAE)*:

| Component | RBF  | KD-Tree | Delaunay | ML (Best) |
|-----------|------|---------|----------|-----------|
| Hue       | 1.40 | 1.40    | 1.29     | **0.03**  |
| Value     | 0.01 | 0.10    | 0.02     | 0.05      |
| Chroma    | 0.22 | 0.99    | 0.35     | **0.11**  |
| Code      | 0.33 | 0.28    | 0.28     | **0.00**  |

*Key Insight*: The reference dataset (4,995 colors) is too sparse for 3D xyY interpolation. Classical methods fail on hue prediction (MAE ~1.3-1.4), while ML achieves roughly 43-47x better hue accuracy, at least 2x better chroma accuracy, and near-zero code error. Only on value do RBF and Delaunay interpolation outperform ML.

**Circular Hue Loss**

Circular distance metrics for hue prediction, accounting for its cyclic nature (0-10 wraps).

*Results*: The circular loss model performed **21x worse** on hue MAE (5.14 vs 0.24).

*Key Takeaway*: **Mathematical correctness != training effectiveness**. The circular distance creates gradient discontinuities that harm optimization.

**REAL-Only Refinement**

Fine-tuning using only REAL Munsell colors (2,734) instead of ALL colors (4,995).

*Results*: Essentially identical performance (Delta-E 1.5233 vs 1.5191).

*Key Takeaway*: **Data quality is not the bottleneck**. Both REAL and extrapolated colors are sufficiently accurate.

**Gamma Normalization**

Gamma correction applied to the Y (luminance) channel during normalization.

*Results*: No consistent improvement across gamma values 1.0-3.0:

| Gamma          | Median ΔE (± std) |
|----------------|-------------------|
| 1.0 (baseline) | 0.730 ± 0.054     |
| 2.5 (best)     | 0.683 ± 0.132     |

![Gamma sweep results](_static/gamma_sweep_plot.png)

*Key Takeaway*: **Gamma normalization does not provide consistent improvement**. The standard deviations overlap, so the differences are within noise.

## Munsell to xyY (to_xyY)

### Performance Benchmarks

Comprehensive comparison using all 2,734 REAL Munsell colors (speed is inference time per sample, in milliseconds):

| Model                                         | Delta-E     | Speed (ms) |
|-----------------------------------------------|-------------|------------|
| Colour Library (Baseline)                     | 0.00        | 1.27       |
| **Multi-MLP (Optimized)**                     | **0.48**    | 0.008      |
| Multi-MLP (Opt) + Multi-Error Predictor (Opt) | 0.48        | 0.025      |
| Multi-MLP + Multi-Error Predictor             | 0.65        | 0.030      |
| Multi-MLP                                     | 0.66        | 0.016      |
| Multi-MLP + Error Predictor                   | 0.67        | 0.018      |
| Multi-Head (Optimized)                        | 0.71        | 0.015      |
| Multi-Head                                    | 0.78        | 0.008      |
| Multi-Head + Multi-Error Predictor            | 1.11        | 0.028      |
| Simple MLP                                    | 1.42        | **0.0008** |

**Best Models**:

- **Best Accuracy**: Multi-MLP (Optimized) - Delta-E 0.48
- **Fastest**: Simple MLP (0.0008 ms/sample) - 1,654x faster than the Colour library
- **Best Balance**: Multi-MLP (Optimized) - 154x faster with Delta-E 0.48

### Model Architectures

9 architectures were evaluated for the Munsell to xyY direction:

**Single-Stage Models**

1. **Simple MLP** - Basic MLP network, 4 inputs to 3 outputs
2. **Multi-Head** - Shared encoder with 3 independent decoder heads (x, y, Y)
3. **Multi-Head (Optimized)** - Hyperparameter-optimized variant
4. **Multi-MLP** - 3 completely independent MLP branches
5. **Multi-MLP (Optimized)** - Hyperparameter-optimized variant (BEST)

**Two-Stage Models**

6. **Multi-MLP + Error Predictor** - Base Multi-MLP with unified error correction
7. **Multi-MLP + Multi-Error Predictor** - 3 independent error predictors
8. **Multi-MLP (Opt) + Multi-Error Predictor (Opt)** - Optimized two-stage
9. **Multi-Head + Multi-Error Predictor** - Multi-Head with error correction

The **Multi-MLP (Optimized)** architecture achieved the best results with Delta-E 0.48.

### Differentiable Approximator

A small MLP (68K parameters) trained to approximate the Munsell to xyY conversion for use in the differentiable Delta-E loss:

- **Architecture**: 4 -> 128 -> 256 -> 128 -> 3 with LayerNorm + SiLU
- **Accuracy**: MAE ~0.0006 for the x, y, and Y components
- **Output formats**: PyTorch (.pth), ONNX, and JAX-compatible weights (.npz)

This makes the Munsell to xyY conversion differentiable; previously it was only available through non-differentiable lookup tables.

## Shared Infrastructure

### Hyperparameter Optimization

Optuna was used for a systematic hyperparameter search over:

- Learning rate (1e-4 to 1e-3)
- Batch size (256, 512, 1024)
- Dropout rate (0.0 to 0.2)
- Chroma branch width multiplier (1.0 to 2.0)
- Loss function weights (MSE, Huber)

Key finding: **No dropout (0.0)** consistently performed best across all models in both conversion directions, contrary to typical deep learning recommendations for regularization.
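
The search space above can be expressed against Optuna's `Trial` API roughly as follows. This is a sketch, not the project's configuration: the parameter names, the log-uniform learning-rate sampling, and the placeholder `train_and_validate` function are assumptions, and the loss-weight dimension is omitted because its ranges are not documented.

```python
def suggest_hyperparameters(trial):
    """Sample one configuration from the documented search space.

    `trial` is expected to behave like `optuna.trial.Trial`.
    Loss-weight search is omitted because its ranges are not documented.
    """
    return {
        # Log-uniform sampling of the learning rate is an assumption.
        "learning_rate": trial.suggest_float("learning_rate", 1e-4, 1e-3, log=True),
        "batch_size": trial.suggest_categorical("batch_size", [256, 512, 1024]),
        "dropout": trial.suggest_float("dropout", 0.0, 0.2),
        "chroma_width_multiplier": trial.suggest_float("chroma_width_multiplier", 1.0, 2.0),
    }


def train_and_validate(config):
    """Hypothetical placeholder: train a model with `config`, return validation MSE."""
    raise NotImplementedError


def objective(trial):
    # The search minimized validation MSE (see Loss-Metric Mismatch).
    return train_and_validate(suggest_hyperparameters(trial))
```

With Optuna installed, the search would run via `study = optuna.create_study(direction="minimize")` followed by `study.optimize(objective, n_trials=...)`; the number of trials used is not stated in this document.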

### Training Infrastructure

- **Optimizer**: AdamW with weight decay
- **Scheduler**: ReduceLROnPlateau (patience=10, factor=0.5)
- **Early stopping**: Patience=20 epochs
- **Checkpointing**: Best model saved based on validation loss
- **Logging**: MLflow for experiment tracking

### JAX Delta-E Implementation

Located in `learning_munsell/losses/jax_delta_e.py`:

- Differentiable xyY -> XYZ -> Lab color space conversions
- Full CIE 2000 Delta-E implementation with gradient support
- JIT-compiled functions for performance

Usage:

```python
from learning_munsell.losses import delta_E_loss, delta_E_CIE2000

# Compute perceptual loss between predicted and target xyY
loss = delta_E_loss(pred_xyY, target_xyY)
```

## Limitations

### BatchNorm Instability on MPS

Models using `BatchNorm1d` layers exhibit numerical instability when trained on Apple Silicon GPUs via the MPS backend:

1. **Validation loss spikes** during training
2. **Occasional extreme outputs** during inference (e.g., 20M instead of ~0.1)
3. **Non-reproducible behavior**

**Affected Models**: Large dataset error predictors using BatchNorm.

**Workarounds**:

1. Use the CPU for training
2. Replace BatchNorm with LayerNorm
3. Train on a smaller dataset (300K samples vs 2M)
4. Skip the error predictor stage for affected models

The recommended production model (`multi_resnet_error_predictor_large.onnx`) was trained on the large dataset and does not exhibit this instability.

**References**:

- [BatchNorm non-trainable exception](https://github.com/pytorch/pytorch/issues/98602)
- [ONNX export incorrect on MPS](https://github.com/pytorch/pytorch/issues/83230)
- [MPS kernel bugs](https://elanapearl.github.io/blog/2025/the-bug-that-taught-me-pytorch/)