| # Learning Munsell | |
| Technical documentation covering performance benchmarks, training methodology, architecture design, and experimental findings. | |
| ## Overview | |
| This project implements ML models for bidirectional conversion between CIE xyY colorspace values and Munsell specifications: | |
| - **xyY to Munsell (from_xyY)**: 25+ architectures, best Delta-E 0.52 | |
| - **Munsell to xyY (to_xyY)**: 9 architectures, best Delta-E 0.48 | |
| ### Delta-E Interpretation | |
| - **< 1.0**: Not perceptible by human eye | |
| - **1-2**: Perceptible through close observation | |
| - **2-10**: Perceptible at a glance | |
| - **> 10**: Colors are perceived as completely different | |
| Our best models achieve **Delta-E 0.48-0.52**, meaning the difference between the ML prediction and the output of the iterative algorithm is **not perceptible by the human eye**. | |
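
The bands above can be expressed as a small lookup. This is a minimal sketch: the function name is hypothetical, and because the listed bands touch at 2 and 10, assigning those boundary values to the lower band is an assumption.

```python
def describe_delta_e(delta_e: float) -> str:
    """Map a Delta-E value to the perceptibility bands listed above."""
    if delta_e < 0:
        raise ValueError("Delta-E is a distance and cannot be negative")
    if delta_e < 1.0:
        return "not perceptible"
    if delta_e <= 2.0:  # assumption: the shared boundary goes to the lower band
        return "perceptible through close observation"
    if delta_e <= 10.0:
        return "perceptible at a glance"
    return "perceived as completely different"


for value in (0.48, 0.52, 1.09):
    print(f"Delta-E {value}: {describe_delta_e(value)}")
```
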
| ## xyY to Munsell (from_xyY) | |
| ### Performance Benchmarks | |
| Comprehensive comparison using all 2,734 REAL Munsell colors: | |
| | Model | Delta-E | Speed (ms) | | |
| |----------------------------------------------------------|-------------|------------| | |
| | Colour Library (Baseline) | 0.00 | 111.90 | | |
| | **Multi-ResNet + Multi-Error Predictor (Large Dataset)** | **0.52** | 0.089 | | |
| | Multi-MLP (W+B) + Multi-Error Predictor (W+B) Large | 0.52 | 0.057 | | |
| | Multi-MLP + Multi-Error Predictor (Large Dataset) | 0.52 | 0.058 | | |
| | Multi-MLP + Multi-Error Predictor | 0.53 | 0.058 | | |
| | MLP + Error Predictor | 0.53 | 0.030 | | |
| | Multi-ResNet (Large Dataset) | 0.54 | 0.044 | | |
| | Multi-Head + Multi-Error Predictor | 0.54 | 0.042 | | |
| | Multi-Head + Multi-Error Predictor (Large Dataset) | 0.56 | 0.043 | | |
| | Deep + Wide | 0.60 | 0.074 | | |
| | Multi-Head (Large Dataset) | 0.66 | 0.013 | | |
| | Mixture of Experts | 0.80 | 0.020 | | |
| | Transformer (Large Dataset) | 0.82 | 0.123 | | |
| | Multi-MLP | 0.86 | 0.027 | | |
| | MLP + Self-Attention | 0.88 | 0.173 | | |
| | MLP (Base Only) | 1.09 | **0.007** | | |
| | Unified MLP | 1.12 | 0.072 | | |
| Note: The Colour library baseline had 171 convergence failures out of 2,734 samples (6.3% failure rate). | |
| **Best Models**: | |
| - **Best Accuracy**: Multi-ResNet + Multi-Error Predictor (Large Dataset) - Delta-E 0.52 | |
| - **Fastest**: MLP Base Only (0.007 ms/sample) - 15,492x faster than Colour library | |
| - **Best Balance**: Multi-MLP (W+B: Weighted Boundary) + Multi-Error Predictor (W+B) Large - 1,951x faster with Delta-E 0.52 | |
| ### Model Architectures | |
| 25+ architectures were systematically evaluated: | |
| **Single-Stage Models** | |
| 1. **MLP (Base Only)** - Simple MLP network, 3 inputs to 4 outputs | |
| 2. **Unified MLP** - Single large MLP with shared features | |
| 3. **Multi-Head** - Shared encoder with 4 independent decoder heads | |
| 4. **Multi-Head (Large Dataset)** - Multi-Head trained on 1.4M samples | |
| 5. **Multi-MLP** - 4 completely independent MLP branches (one per output) | |
| 6. **Multi-MLP (Large Dataset)** - Multi-MLP trained on 1.4M samples | |
| 7. **MLP + Self-Attention** - MLP with attention mechanism for feature weighting | |
| 8. **Deep + Wide** - Combined deep and wide network paths | |
| 9. **Mixture of Experts** - Gating network selecting specialized expert networks | |
| 10. **Transformer (Large Dataset)** - Feature Tokenizer Transformer for tabular data | |
| 11. **FT-Transformer** - Feature Tokenizer Transformer (standard size) | |
| **Two-Stage Models** | |
| 12. **MLP + Error Predictor** - Base MLP with unified error correction | |
| 13. **Multi-Head + Multi-Error Predictor** - Multi-Head with 4 independent error predictors | |
| 14. **Multi-Head + Multi-Error Predictor (Large Dataset)** - Large dataset variant | |
| 15. **Multi-MLP + Multi-Error Predictor** - 4 independent branches with 4 independent error predictors | |
| 16. **Multi-MLP + Multi-Error Predictor (Large Dataset)** - Large dataset variant | |
| 17. **Multi-ResNet + Multi-Error Predictor (Large Dataset)** - Deep ResNet-style branches (BEST) | |
| The **Multi-ResNet + Multi-Error Predictor (Large Dataset)** architecture achieved the best results with Delta-E 0.52. | |
| ### Training Methodology | |
| **Data Generation** | |
| 1. **Dense xyY Grid** (~500K samples) | |
| - Regular grid in valid xyY space (MacAdam limits for Illuminant C) | |
| - Captures general input distribution | |
| 2. **Boundary Refinement** (~700K samples) | |
| - Adaptive dense sampling near Munsell gamut boundaries | |
| - Uses `maximum_chroma_from_renotation` to detect edges | |
| - Focuses on regions where iterative algorithm is most complex | |
| - Includes Y/GY/G hue regions with high value/chroma (challenging areas) | |
| 3. **Forward Augmentation** (~200K samples) | |
| - Dense Munsell space sampling via `munsell_specification_to_xyY` | |
| - Ensures coverage of known valid colors | |
| Total: ~1.4M samples for large dataset training. | |
| **Loss Functions** | |
| Two loss function approaches were tested: | |
| *Precision-Focused Loss* (Default): | |
| ``` | |
| total_loss = 1.0 * MSE + 0.5 * MAE + 0.3 * log_penalty + 0.5 * huber_loss | |
| ``` | |
| - MSE: Standard mean squared error | |
| - MAE: Mean absolute error | |
| - Log penalty: Heavily penalizes small errors (pushes toward high precision) | |
| - Huber loss: Small delta (0.01) for precision on small errors | |
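
A NumPy sketch of how these four terms might be combined. The weights and the Huber delta come from the text above; the exact form of the log penalty is not given, so `log1p(|e| / eps)` and the `eps` value are assumptions, as is the function name.

```python
import numpy as np


def precision_focused_loss(pred, target, huber_delta=0.01, eps=1e-4):
    """Weighted sum of MSE, MAE, a log penalty, and a small-delta Huber term."""
    err = np.asarray(pred, dtype=float) - np.asarray(target, dtype=float)
    abs_err = np.abs(err)

    mse = np.mean(err**2)
    mae = np.mean(abs_err)
    # Assumed form: log1p(|e| / eps) rises steeply near zero, so even small
    # residuals keep contributing to the loss.
    log_penalty = np.mean(np.log1p(abs_err / eps))
    # Huber: quadratic below huber_delta, linear above it.
    huber = np.mean(
        np.where(
            abs_err <= huber_delta,
            0.5 * err**2,
            huber_delta * (abs_err - 0.5 * huber_delta),
        )
    )
    return 1.0 * mse + 0.5 * mae + 0.3 * log_penalty + 0.5 * huber


print(precision_focused_loss([0.10, 0.20], [0.11, 0.18]))
```
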
| *Pure MSE Loss* (Optimized config): | |
| ``` | |
| total_loss = MSE | |
| ``` | |
| Interestingly, the precision-focused loss achieved better Delta-E despite higher validation MSE, suggesting the custom weighting better correlates with perceptual accuracy. | |
| ### Design Rationale | |
| **Two-Stage Architecture** | |
| The error predictor stage corrects systematic biases in the base model: | |
| 1. Base model learns the general xyY to Munsell mapping | |
| 2. Error predictor learns residual corrections specific to each component | |
| 3. Combined prediction: `final = base_prediction + error_correction` | |
| This decomposition allows each stage to specialize and reduces the complexity each network must learn. | |
| **Independent Branch Design** | |
| Munsell components have different characteristics: | |
| - **Hue**: Circular (0-10, wrapping), most complex | |
| - **Value**: Linear (0-10), easiest to predict | |
| - **Chroma**: Highly variable range depending on hue/value | |
| - **Code**: Discrete hue sector (0-9) | |
| Shared encoders force compromises between these different prediction tasks. Independent branches allow full specialization. | |
| **Architecture Details** | |
| *MLP (Base Only)* | |
| Simple feedforward network predicting all 4 outputs simultaneously: | |
| Input (3) ──► Linear Layers ──► Output (4: hue, value, chroma, code) | |
| - Smallest model (~8KB ONNX) | |
| - Fastest inference (0.007 ms) | |
| - Baseline for comparison | |
| *Unified MLP* | |
| Single large MLP with shared internal features: | |
| Input (3) ──► 128 ──► 256 ──► 512 ──► 256 ──► 128 ──► Output (4) | |
| - Shared representations across all outputs | |
| - Moderate size, good speed | |
| *Multi-Head MLP* | |
| Shared encoder with specialized decoder heads: | |
| Input (3) ──► SHARED ENCODER (3→128→256→512) ──┬──► Hue Head (512→256→128→1) | |
| ├──► Value Head (512→256→128→1) | |
| ├──► Chroma Head (512→384→256→128→1) | |
| └──► Code Head (512→256→128→1) | |
| - Shared encoder learns common color space features | |
| - 4 specialized decoder heads branch from shared representation | |
| - Parameter efficient (encoder weights shared) | |
| - Fast inference (encoder computed once) | |
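
The Multi-Head layout can be sketched as a NumPy forward pass with random weights. Layer widths follow the diagram above; the ReLU activation, the initialization, and the helper names are assumptions made only to keep the example runnable.

```python
import numpy as np

rng = np.random.default_rng(0)


def init_mlp(sizes):
    """Random (weight, bias) pairs for consecutive layer sizes."""
    return [
        (rng.normal(0.0, 0.1, (n_in, n_out)), np.zeros(n_out))
        for n_in, n_out in zip(sizes[:-1], sizes[1:])
    ]


def run_mlp(layers, x, activate_last):
    """Apply linear layers with ReLU between them (ReLU is an assumption)."""
    for i, (w, b) in enumerate(layers):
        x = x @ w + b
        if activate_last or i < len(layers) - 1:
            x = np.maximum(x, 0.0)
    return x


encoder = init_mlp([3, 128, 256, 512])
heads = {
    "hue": init_mlp([512, 256, 128, 1]),
    "value": init_mlp([512, 256, 128, 1]),
    "chroma": init_mlp([512, 384, 256, 128, 1]),
    "code": init_mlp([512, 256, 128, 1]),
}

xyY = rng.uniform(0.1, 0.6, size=(8, 3))  # batch of 8 hypothetical inputs
features = run_mlp(encoder, xyY, activate_last=True)  # encoder runs once
prediction = np.concatenate(
    [run_mlp(heads[name], features, activate_last=False) for name in heads],
    axis=1,
)
print(prediction.shape)  # (8, 4)
```

The encoder output is reused by all four heads, which is where the parameter and inference savings come from.
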
| *Multi-MLP* | |
| Fully independent branches with no weight sharing: | |
| Input (3) ──► Hue Branch (3→128→256→512→256→128→1) | |
| Input (3) ──► Value Branch (3→128→256→512→256→128→1) | |
| Input (3) ──► Chroma Branch (3→256→512→1024→512→256→1) [2x wider] | |
| Input (3) ──► Code Branch (3→128→256→512→256→128→1) | |
| - 4 completely independent MLPs | |
| - Each branch learns its own features from scratch | |
| - Chroma branch is wider (2x) to handle its complexity | |
| - Better accuracy than Multi-Head on large dataset (Delta-E 0.52 vs 0.56 with error predictors) | |
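
To give a sense of scale, the branch widths above translate into the following parameter counts. The sketch counts only fully connected weights and biases; any normalization layers in the real models are not included, and the helper name is hypothetical.

```python
def linear_params(sizes):
    """Weights plus biases for a stack of fully connected layers."""
    return sum(n_in * n_out + n_out for n_in, n_out in zip(sizes[:-1], sizes[1:]))


branches = {
    "hue": [3, 128, 256, 512, 256, 128, 1],
    "value": [3, 128, 256, 512, 256, 128, 1],
    "chroma": [3, 256, 512, 1024, 512, 256, 1],
    "code": [3, 128, 256, 512, 256, 128, 1],
}

counts = {name: linear_params(sizes) for name, sizes in branches.items()}
for name, count in counts.items():
    print(f"{name:>6}: {count:,} parameters")
print(f" total: {sum(counts.values()):,} parameters")
print(f"chroma / hue ratio: {counts['chroma'] / counts['hue']:.2f}")
```

Doubling the width of every hidden layer roughly quadruples the parameter count, so the chroma branch holds more parameters than the other three branches combined.
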
| *Multi-ResNet* | |
| Deep branches with residual-style connections: | |
| Input (3) ──► Hue Branch (3→256→512→512→512→256→1) [6 layers] | |
| Input (3) ──► Value Branch (3→256→512→512→512→256→1) [6 layers] | |
| Input (3) ──► Chroma Branch (3→512→1024→1024→1024→512→1) [6 layers, 2x wider] | |
| Input (3) ──► Code Branch (3→256→512→512→512→256→1) [6 layers] | |
| - Deeper architecture than Multi-MLP | |
| - BatchNorm + SiLU activation | |
| - Best accuracy when combined with error predictor (Delta-E 0.52) | |
| - Largest model (~14MB base, ~28MB with error predictor) | |
| *Deep + Wide* | |
| Combined deep and wide network paths: | |
| Input (3) ──┬──► Deep Path (multiple layers) ──┬──► Concat ──► Output (4) | |
| └──► Wide Path (direct connection) ─┘ | |
| - Deep path captures complex patterns | |
| - Wide path preserves direct input information | |
| - Good for mixed linear/nonlinear relationships | |
| *MLP + Self-Attention* | |
| MLP with attention mechanism for feature weighting: | |
| Input (3) ──► MLP ──► Self-Attention ──► Output (4) | |
| - Attention weights learn feature importance | |
| - Slower due to attention computation (0.173 ms) | |
| - Did not improve over simpler MLPs | |
| *Mixture of Experts* | |
| Gating network selecting specialized expert networks: | |
| Input (3) ──► Gating Network ──► Weighted sum of Expert outputs ──► Output (4) | |
| - Multiple expert networks specialize in different input regions | |
| - Gating network learns which expert to use | |
| - More complex but did not outperform Multi-MLP | |
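
A minimal NumPy sketch of softmax gating over experts. The number of experts is a hypothetical choice, and each expert is reduced to a single linear map for brevity; the actual expert and gating networks are not described in this document.

```python
import numpy as np

rng = np.random.default_rng(0)
n_experts, n_in, n_out = 4, 3, 4  # expert count is a hypothetical choice

# Each expert is reduced to one linear map to keep the sketch short.
expert_w = rng.normal(0.0, 0.1, (n_experts, n_in, n_out))
gate_w = rng.normal(0.0, 0.1, (n_in, n_experts))


def softmax(z):
    """Row-wise softmax with max subtraction for numerical stability."""
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def mixture_of_experts(x):
    """Return the gate-weighted sum of expert outputs and the gate values."""
    gates = softmax(x @ gate_w)  # (batch, experts)
    expert_out = np.einsum("bi,eio->beo", x, expert_w)  # (batch, experts, out)
    return np.einsum("be,beo->bo", gates, expert_out), gates


xyY = rng.uniform(0.1, 0.6, size=(5, 3))
output, gates = mixture_of_experts(xyY)
print(output.shape, gates.sum(axis=1))
```
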
| *FT-Transformer* | |
| Feature Tokenizer Transformer for tabular data: | |
| Input (3) ──► Feature Tokenizer ──► Transformer Blocks ──► Output (4) | |
| - Each input feature tokenized separately | |
| - Self-attention across feature tokens | |
| - Good for tabular data with feature interactions | |
| - Slower inference due to attention computation | |
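
The tokenization step can be illustrated in NumPy: each scalar input gets its own embedding vector and bias, and attention then mixes the three feature tokens. The token width, the single attention head, and the omission of learned query/key/value projections are simplifying assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
n_features, d_token = 3, 16  # token width is a hypothetical choice

# One embedding vector and one bias per input feature.
token_w = rng.normal(0.0, 0.1, (n_features, d_token))
token_b = np.zeros((n_features, d_token))


def tokenize(x):
    """Turn (batch, 3) scalars into (batch, 3, d_token) feature tokens."""
    return x[:, :, None] * token_w[None] + token_b[None]


def self_attention(tokens):
    """Single-head scaled dot-product attention without learned projections."""
    scores = tokens @ tokens.transpose(0, 2, 1) / np.sqrt(tokens.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ tokens


xyY = rng.uniform(0.1, 0.6, size=(2, 3))
tokens = tokenize(xyY)
mixed = self_attention(tokens)
print(tokens.shape, mixed.shape)
```
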
| *Error Predictor (Two-Stage)* | |
| Second-stage network that corrects base model errors: | |
| Stage 1: Input (3) ──► Base Model ──► Base Prediction (4) | |
| Stage 2: [Input (3), Base Prediction (4)] ──► Error Predictor ──► Error Correction (4) | |
| Final: Base Prediction + Error Correction = Final Output | |
| - Learns residual corrections for each component | |
| - Can have unified (1 network) or multi (4 networks) error predictors | |
| - Consistently improves accuracy across all base architectures | |
| - Best results: Multi-ResNet + Multi-Error Predictor (Delta-E 0.52) | |
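
The two-stage data flow can be sketched as follows. The function name is hypothetical, and `toy_base` and `toy_error` are stand-ins so the example runs without trained weights; only the concatenation of the 3 inputs with the 4 base outputs and the final addition come from the description above.

```python
import numpy as np


def two_stage_predict(xyY, base_model, error_predictor):
    """Apply the base model, then add the predicted residual correction."""
    base = base_model(xyY)  # (batch, 4)
    stage2_input = np.concatenate([xyY, base], axis=1)  # (batch, 3 + 4)
    correction = error_predictor(stage2_input)  # (batch, 4)
    return base + correction


# Hypothetical stand-ins for the trained networks.
def toy_base(x):
    return np.tile(x.sum(axis=1, keepdims=True), (1, 4))


def toy_error(z):
    return 0.01 * z[:, 3:]  # small correction derived from the base prediction


xyY = np.array([[0.31, 0.32, 0.50]])
final = two_stage_predict(xyY, toy_base, toy_error)
print(final)
```
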
| **Loss-Metric Mismatch** | |
| An important finding: **optimizing MSE does not optimize Delta-E**. | |
| The Optuna hyperparameter search minimized validation MSE, but the best MSE configuration did not achieve the best Delta-E. This is because: | |
| - MSE treats all component errors equally | |
| - Delta-E (CIE2000) weights errors based on human perception | |
| - The precision-focused loss with custom weights better approximates perceptual importance | |
| **Weighted Boundary Loss (Experimental)** | |
| Analysis of model errors revealed systematic underperformance on Y/GY/G hues (Yellow/Green-Yellow/Green) with high value and chroma. The weighted boundary loss approach was explored to address this by: | |
| 1. Applying 3x loss weight to samples in challenging regions: | |
| - Hue: 0.18-0.35 (normalized range covering Y/GY/G) | |
| - Value > 0.7 (high brightness) | |
| - Chroma > 0.5 (high saturation) | |
| 2. Adding boundary penalty to prevent predictions exceeding Munsell gamut limits | |
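
A NumPy sketch of the sample weighting in step 1. The thresholds and the 3x factor come from the list above; requiring all three conditions at once is an assumption, and the function names are hypothetical. The boundary penalty from step 2 is not shown.

```python
import numpy as np


def boundary_sample_weights(hue, value, chroma, boost=3.0):
    """Return per-sample loss weights: `boost` inside the hard region, else 1."""
    hue, value, chroma = (np.asarray(a, dtype=float) for a in (hue, value, chroma))
    # Assumption: a sample must satisfy all three conditions to be boosted.
    hard = (hue >= 0.18) & (hue <= 0.35) & (value > 0.7) & (chroma > 0.5)
    return np.where(hard, boost, 1.0)


def weighted_mse(pred, target, weights):
    """Mean squared error with each sample scaled by its weight."""
    per_sample = np.mean((pred - target) ** 2, axis=1)
    return np.sum(weights * per_sample) / np.sum(weights)


# Normalized (hue, value, chroma) for three hypothetical samples.
weights = boundary_sample_weights([0.2, 0.2, 0.5], [0.8, 0.5, 0.8], [0.6, 0.6, 0.6])
print(weights)
```
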
| **Finding**: The large dataset approach (~1.4M samples with dense boundary sampling) naturally provides sufficient coverage of these challenging regions. Both the weighted boundary loss model (Multi-MLP W+B + Multi-Error Predictor W+B Large, Delta-E 0.524) and the standard large dataset model (Multi-MLP + Multi-Error Predictor Large, Delta-E 0.525) achieve nearly identical results, making explicit loss weighting optional. The best overall model is Multi-ResNet + Multi-Error Predictor (Large Dataset) with Delta-E 0.52. | |
| ### Experimental Findings | |
| The following experiments were conducted but did not improve results: | |
| **Delta-E Training** | |
| Training with differentiable Delta-E CIE2000 loss via round-trip through the Munsell-to-xyY approximator. | |
| *Hypothesis*: Perceptual Delta-E loss might outperform MSE-trained models. | |
| *Implementation*: JAX/Flax model with combined MSE + Delta-E loss. Requires lower learning rate (1e-4 vs 3e-4) for stability; higher rates cause NaN gradients. | |
| *Results*: While Delta-E is comparable, the Delta-E-trained model's **hue accuracy is ~10x worse**: | |
| | Metric (Normalized MAE) | Delta-E Model | MSE Model | | |
| |--------------------------|---------------|-----------| | |
| | Hue MAE | 0.30 | 0.03 | | |
| | Value MAE | 0.002 | 0.004 | | |
| | Chroma MAE | 0.007 | 0.008 | | |
| | Code MAE | 0.07 | 0.01 | | |
| | **Delta-E (perceptual)** | **0.52** | **0.50** | | |
| *Key Takeaway*: **Perceptual similarity != specification accuracy**. The Delta-E-trained model reaches a comparable Delta-E (0.52 vs 0.50 for the MSE model) but with ~10x worse hue accuracy (normalized MAE 0.30 vs 0.03), making it unsuitable for specification prediction. Delta-E is too permissive for hue, allowing the model to find "shortcuts" that minimize perceptual difference without correctly predicting the Munsell specification. | |
| **Classical Interpolation** | |
| Classical interpolation methods were tested on 4,995 reference Munsell colors (80% train / 20% test split). The ML model was evaluated on the 2,734 REAL Munsell colors, so the two evaluation sets differ. | |
| *Results (Validation MAE)*: | |
| | Component | RBF | KD-Tree | Delaunay | ML (Best) | | |
| |-----------|------|---------|----------|-----------| | |
| | Hue | 1.40 | 1.40 | 1.29 | **0.03** | | |
| | Value | 0.01 | 0.10 | 0.02 | 0.05 | | |
| | Chroma | 0.22 | 0.99 | 0.35 | **0.11** | | |
| | Code | 0.33 | 0.28 | 0.28 | **0.00** | | |
| *Key Insight*: The reference dataset (4,995 colors) is too sparse for 3D xyY interpolation. Classical methods fail on hue prediction (MAE ~1.3-1.4), while ML achieves roughly 43-47x better hue accuracy, 2-9x better chroma accuracy, and near-zero code error. RBF and Delaunay remain slightly more accurate on value (0.01-0.02 vs 0.05). | |
| **Circular Hue Loss** | |
| Circular distance metrics for hue prediction, accounting for cyclic nature (0-10 wraps). | |
| *Results*: The circular loss model performed **21x worse** on hue MAE (5.14 vs 0.24). | |
| *Key Takeaway*: **Mathematical correctness != training effectiveness**. The circular distance creates gradient discontinuities that harm optimization. | |
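
For reference, a circular distance on a 0-10 cycle can be written as below. This is a generic formulation with a hypothetical function name, not necessarily the exact loss used in the experiment.

```python
import numpy as np


def circular_hue_distance(pred, target, period=10.0):
    """Shortest distance between two hue values on a cycle of length `period`."""
    diff = np.abs(np.asarray(pred, dtype=float) - np.asarray(target, dtype=float)) % period
    return np.minimum(diff, period - diff)


print(circular_hue_distance(9.8, 0.2))  # wraps around: about 0.4, not 9.6
print(circular_hue_distance(2.0, 7.0))  # opposite side of the cycle: 5.0
```

The `minimum` has a kink where the two paths around the cycle are equally long (a difference of 5), and the gradient direction flips there.
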
| **REAL-Only Refinement** | |
| Fine-tuning using only REAL Munsell colors (2,734) instead of ALL colors (4,995). | |
| *Results*: Essentially identical performance (Delta-E 1.5233 vs 1.5191). | |
| *Key Takeaway*: **Data quality is not the bottleneck**. Both REAL and extrapolated colors are sufficiently accurate. | |
| **Gamma Normalization** | |
| Gamma correction to the Y (luminance) channel during normalization. | |
| *Results*: No consistent improvement across gamma values 1.0-3.0: | |
| | Gamma | Median ΔE (± std) | | |
| |----------------|-------------------| | |
| | 1.0 (baseline) | 0.730 ± 0.054 | | |
| | 2.5 (best) | 0.683 ± 0.132 | | |
| *Key Takeaway*: **Gamma normalization does not provide consistent improvement**. Standard deviations overlap - differences are within noise. | |
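
One plausible form of the normalization tested here is a power law applied to Y after scaling. The exact transform is not specified in this document, so the `1 / gamma` exponent, the `Y_max` scaling, and the function name are all assumptions.

```python
import numpy as np


def gamma_normalize_Y(Y, gamma=2.5, Y_max=1.0):
    """Scale Y to [0, 1], then apply a power law with exponent 1 / gamma."""
    scaled = np.clip(np.asarray(Y, dtype=float) / Y_max, 0.0, 1.0)
    return scaled ** (1.0 / gamma)


Y = np.array([0.05, 0.25, 0.80])
print(gamma_normalize_Y(Y, gamma=1.0))  # gamma 1.0 leaves Y unchanged
print(gamma_normalize_Y(Y, gamma=2.5))  # expands the dark end of the range
```
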
| ## Munsell to xyY (to_xyY) | |
| ### Performance Benchmarks | |
| Comprehensive comparison using all 2,734 REAL Munsell colors: | |
| | Model | Delta-E | Speed (ms) | | |
| |-----------------------------------------------|-------------|------------| | |
| | Colour Library (Baseline) | 0.00 | 1.27 | | |
| | **Multi-MLP (Optimized)** | **0.48** | 0.008 | | |
| | Multi-MLP (Opt) + Multi-Error Predictor (Opt) | 0.48 | 0.025 | | |
| | Multi-MLP + Multi-Error Predictor | 0.65 | 0.030 | | |
| | Multi-MLP | 0.66 | 0.016 | | |
| | Multi-MLP + Error Predictor | 0.67 | 0.018 | | |
| | Multi-Head (Optimized) | 0.71 | 0.015 | | |
| | Multi-Head | 0.78 | 0.008 | | |
| | Multi-Head + Multi-Error Predictor | 1.11 | 0.028 | | |
| | Simple MLP | 1.42 | **0.0008** | | |
| **Best Models**: | |
| - **Best Accuracy**: Multi-MLP (Optimized) - Delta-E 0.48 | |
| - **Fastest**: Simple MLP (0.0008 ms/sample) - 1,654x faster than Colour library | |
| - **Best Balance**: Multi-MLP (Optimized) - 154x faster with Delta-E 0.48 | |
| ### Model Architectures | |
| 9 architectures were evaluated for the Munsell to xyY direction: | |
| **Single-Stage Models** | |
| 1. **Simple MLP** - Basic MLP network, 4 inputs to 3 outputs | |
| 2. **Multi-Head** - Shared encoder with 3 independent decoder heads (x, y, Y) | |
| 3. **Multi-Head (Optimized)** - Hyperparameter-optimized variant | |
| 4. **Multi-MLP** - 3 completely independent MLP branches | |
| 5. **Multi-MLP (Optimized)** - Hyperparameter-optimized variant (BEST) | |
| **Two-Stage Models** | |
| 6. **Multi-MLP + Error Predictor** - Base Multi-MLP with unified error correction | |
| 7. **Multi-MLP + Multi-Error Predictor** - 3 independent error predictors | |
| 8. **Multi-MLP (Opt) + Multi-Error Predictor (Opt)** - Optimized two-stage | |
| 9. **Multi-Head + Multi-Error Predictor** - Multi-Head with error correction | |
| The **Multi-MLP (Optimized)** architecture achieved the best results with Delta-E 0.48. | |
| ### Differentiable Approximator | |
| A small MLP (68K parameters) trained to approximate the Munsell to xyY conversion for use in differentiable Delta-E loss: | |
| - **Architecture**: 4 -> 128 -> 256 -> 128 -> 3 with LayerNorm + SiLU | |
| - **Accuracy**: MAE ~0.0006 for x, y, and Y components | |
| - **Output formats**: PyTorch (.pth), ONNX, and JAX-compatible weights (.npz) | |
| This enables differentiable Munsell to xyY conversion, which was previously only possible through non-differentiable lookup tables. | |
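
The 68K figure can be checked against the listed layer sizes. The sketch assumes one LayerNorm with learnable scale and shift after each hidden layer; the function name is hypothetical.

```python
def count_parameters(sizes, layer_norm=True):
    """Count weights and biases, plus LayerNorm scale/shift on hidden layers."""
    total = 0
    for i, (n_in, n_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        total += n_in * n_out + n_out  # linear weights + biases
        is_hidden = i < len(sizes) - 2
        if layer_norm and is_hidden:
            total += 2 * n_out  # LayerNorm scale and shift
    return total


print(count_parameters([4, 128, 256, 128, 3]))  # 67971
print(count_parameters([4, 128, 256, 128, 3], layer_norm=False))  # 66947
```

Under that assumption the total is 67,971, consistent with the stated ~68K parameters.
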
| ## Shared Infrastructure | |
| ### Hyperparameter Optimization | |
| Optuna was used for systematic hyperparameter search over: | |
| - Learning rate (1e-4 to 1e-3) | |
| - Batch size (256, 512, 1024) | |
| - Dropout rate (0.0 to 0.2) | |
| - Chroma branch width multiplier (1.0 to 2.0) | |
| - Loss function weights (MSE, Huber) | |
| Key finding: **No dropout (0.0)** consistently performed better across all models in both conversion directions, contrary to typical deep learning recommendations for regularization. | |
| ### Training Infrastructure | |
| - **Optimizer**: AdamW with weight decay | |
| - **Scheduler**: ReduceLROnPlateau (patience=10, factor=0.5) | |
| - **Early stopping**: Patience=20 epochs | |
| - **Checkpointing**: Best model saved based on validation loss | |
| - **Logging**: MLflow for experiment tracking | |
| ### JAX Delta-E Implementation | |
| Located in `learning_munsell/losses/jax_delta_e.py`: | |
| - Differentiable xyY -> XYZ -> Lab color space conversions | |
| - Full CIE 2000 Delta-E implementation with gradient support | |
| - JIT-compiled functions for performance | |
| Usage: | |
| ```python | |
| from learning_munsell.losses import delta_E_loss, delta_E_CIE2000 | |
| # Compute perceptual loss between predicted and target xyY | |
| loss = delta_E_loss(pred_xyY, target_xyY) | |
| ``` | |
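
The conversion chain feeding the Delta-E computation uses the standard CIE formulas, sketched here in NumPy for illustration. This is not the project's JAX code: the function names are hypothetical, and the Illuminant C white point (2 degree observer) is an assumption.

```python
import numpy as np

# Assumed white point: CIE Illuminant C (2 degree observer), Y normalized to 1.
WHITE_XY = (0.31006, 0.31616)


def xyY_to_XYZ(xyY):
    """Convert (..., 3) xyY to XYZ; y must be non-zero."""
    x, y, Y = np.moveaxis(np.asarray(xyY, dtype=float), -1, 0)
    return np.stack([x * Y / y, Y, (1.0 - x - y) * Y / y], axis=-1)


def XYZ_to_Lab(XYZ, white_XYZ):
    """Convert XYZ to CIE L*a*b* relative to the given white point."""
    t = np.asarray(XYZ, dtype=float) / white_XYZ
    delta = 6.0 / 29.0
    f = np.where(t > delta**3, np.cbrt(t), t / (3.0 * delta**2) + 4.0 / 29.0)
    fx, fy, fz = np.moveaxis(f, -1, 0)
    return np.stack([116.0 * fy - 16.0, 500.0 * (fx - fy), 200.0 * (fy - fz)], axis=-1)


white_XYZ = xyY_to_XYZ([*WHITE_XY, 1.0])
lab_white = XYZ_to_Lab(xyY_to_XYZ([*WHITE_XY, 1.0]), white_XYZ)
print(lab_white)  # the white point maps to L* = 100, a* = b* = 0
```

Both functions use only smooth array operations apart from the piecewise branch, so the same structure carries over to `jax.numpy` when gradients are needed.
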
| ## Limitations | |
| ### BatchNorm Instability on MPS | |
| Models using `BatchNorm1d` layers exhibit numerical instability when trained on Apple Silicon GPUs via the MPS backend: | |
| 1. **Validation loss spikes** during training | |
| 2. **Occasional extreme outputs** during inference (e.g., 20M instead of ~0.1) | |
| 3. **Non-reproducible behavior** | |
| **Affected Models**: Large dataset error predictors using BatchNorm. | |
| **Workarounds**: | |
| 1. Use CPU for training | |
| 2. Replace BatchNorm with LayerNorm | |
| 3. Train on a smaller dataset (300K samples vs 2M) | |
| 4. Skip error predictor stage for affected models | |
| The recommended production model (`multi_resnet_error_predictor_large.onnx`) was trained on the large dataset and does not exhibit this instability. | |
| **References**: | |
| - [BatchNorm non-trainable exception](https://github.com/pytorch/pytorch/issues/98602) | |
| - [ONNX export incorrect on MPS](https://github.com/pytorch/pytorch/issues/83230) | |
| - [MPS kernel bugs](https://elanapearl.github.io/blog/2025/the-bug-that-taught-me-pytorch/) | |