# Learning Munsell
Technical documentation covering performance benchmarks, training methodology, architecture design, and experimental findings.
## Overview
This project implements ML models for bidirectional conversion between CIE xyY colorspace values and Munsell specifications:
- **xyY to Munsell (from_xyY)**: 25+ architectures, best Delta-E 0.52
- **Munsell to xyY (to_xyY)**: 9 architectures, best Delta-E 0.48
### Delta-E Interpretation
- **< 1.0**: Not perceptible to the human eye
- **1-2**: Perceptible through close observation
- **2-10**: Perceptible at a glance
- **> 10**: Colors are perceived as completely different
Our best models achieve **Delta-E 0.48-0.52**, meaning the difference between the ML prediction and the iterative reference algorithm is **not perceptible to the human eye**.
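As a quick reference, the perceptibility bands above can be encoded as a small helper. This is an illustrative sketch; the band descriptions are paraphrased from the list.

```python
def describe_delta_e(delta_e: float) -> str:
    """Map a CIE 2000 Delta-E value to its perceptibility band."""
    if delta_e < 1.0:
        return "not perceptible to the human eye"
    if delta_e < 2.0:
        return "perceptible through close observation"
    if delta_e <= 10.0:
        return "perceptible at a glance"
    return "perceived as completely different"
```

By this mapping, the best models' Delta-E of 0.48-0.52 falls squarely in the imperceptible band.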
## xyY to Munsell (from_xyY)
### Performance Benchmarks
Comprehensive comparison using all 2,734 REAL Munsell colors:
| Model | Delta-E | Speed (ms) |
|----------------------------------------------------------|-------------|------------|
| Colour Library (Baseline) | 0.00 | 111.90 |
| **Multi-ResNet + Multi-Error Predictor (Large Dataset)** | **0.52** | 0.089 |
| Multi-MLP (W+B) + Multi-Error Predictor (W+B) Large | 0.52 | 0.057 |
| Multi-MLP + Multi-Error Predictor (Large Dataset) | 0.52 | 0.058 |
| Multi-MLP + Multi-Error Predictor | 0.53 | 0.058 |
| MLP + Error Predictor | 0.53 | 0.030 |
| Multi-ResNet (Large Dataset) | 0.54 | 0.044 |
| Multi-Head + Multi-Error Predictor | 0.54 | 0.042 |
| Multi-Head + Multi-Error Predictor (Large Dataset) | 0.56 | 0.043 |
| Deep + Wide | 0.60 | 0.074 |
| Multi-Head (Large Dataset) | 0.66 | 0.013 |
| Mixture of Experts | 0.80 | 0.020 |
| Transformer (Large Dataset) | 0.82 | 0.123 |
| Multi-MLP | 0.86 | 0.027 |
| MLP + Self-Attention | 0.88 | 0.173 |
| MLP (Base Only) | 1.09 | **0.007** |
| Unified MLP | 1.12 | 0.072 |
Note: The Colour library baseline had 171 convergence failures out of 2,734 samples (6.3% failure rate).
**Best Models**:
- **Best Accuracy**: Multi-ResNet + Multi-Error Predictor (Large Dataset) - Delta-E 0.52
- **Fastest**: MLP (Base Only) at 0.007 ms/sample - 15,492x faster than the Colour library
- **Best Balance**: Multi-MLP (W+B: Weighted Boundary) + Multi-Error Predictor (W+B) Large - 1,951x faster with Delta-E 0.52
### Model Architectures
25+ architectures were systematically evaluated; the most significant are listed below:
**Single-Stage Models**
1. **MLP (Base Only)** - Simple MLP network, 3 inputs to 4 outputs
2. **Unified MLP** - Single large MLP with shared features
3. **Multi-Head** - Shared encoder with 4 independent decoder heads
4. **Multi-Head (Large Dataset)** - Multi-Head trained on 1.4M samples
5. **Multi-MLP** - 4 completely independent MLP branches (one per output)
6. **Multi-MLP (Large Dataset)** - Multi-MLP trained on 1.4M samples
7. **MLP + Self-Attention** - MLP with attention mechanism for feature weighting
8. **Deep + Wide** - Combined deep and wide network paths
9. **Mixture of Experts** - Gating network selecting specialized expert networks
10. **Transformer (Large Dataset)** - Feature Tokenizer Transformer for tabular data
11. **FT-Transformer** - Feature Tokenizer Transformer (standard size)
**Two-Stage Models**
12. **MLP + Error Predictor** - Base MLP with unified error correction
13. **Multi-Head + Multi-Error Predictor** - Multi-Head with 4 independent error predictors
14. **Multi-Head + Multi-Error Predictor (Large Dataset)** - Large dataset variant
15. **Multi-MLP + Multi-Error Predictor** - 4 independent branches with 4 independent error predictors
16. **Multi-MLP + Multi-Error Predictor (Large Dataset)** - Large dataset variant
17. **Multi-ResNet + Multi-Error Predictor (Large Dataset)** - Deep ResNet-style branches (BEST)
The **Multi-ResNet + Multi-Error Predictor (Large Dataset)** architecture achieved the best results with Delta-E 0.52.
### Training Methodology
**Data Generation**
1. **Dense xyY Grid** (~500K samples)
- Regular grid in valid xyY space (MacAdam limits for Illuminant C)
- Captures general input distribution
2. **Boundary Refinement** (~700K samples)
- Adaptive dense sampling near Munsell gamut boundaries
- Uses `maximum_chroma_from_renotation` to detect edges
- Focuses on regions where iterative algorithm is most complex
- Includes Y/GY/G hue regions with high value/chroma (challenging areas)
3. **Forward Augmentation** (~200K samples)
- Dense Munsell space sampling via `munsell_specification_to_xyY`
- Ensures coverage of known valid colors
Total: ~1.4M samples for large dataset training.
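The dense-grid step can be sketched as follows. This is a simplified illustration: the real pipeline clips to the MacAdam limits for Illuminant C via colour-science tooling, whereas here a crude `x + y <= 1` chromaticity filter stands in for that boundary check, and the grid resolution is arbitrary.

```python
import numpy as np

def make_xyY_grid(n=50):
    """Regular grid over a box enclosing the chromaticity diagram.

    Sketch only: a real implementation would keep only points inside
    the MacAdam limits for Illuminant C instead of the crude filter below.
    """
    x = np.linspace(0.0, 0.75, n)
    y = np.linspace(0.0, 0.85, n)
    Y = np.linspace(0.01, 1.0, n)
    grid = np.stack(np.meshgrid(x, y, Y, indexing="ij"), axis=-1).reshape(-1, 3)
    # Crude validity filter: chromaticity coordinates must sum to <= 1.
    return grid[grid[:, 0] + grid[:, 1] <= 1.0]
```

With `n = 80` this yields roughly the ~500K-sample scale described above, before boundary refinement and forward augmentation are added.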
**Loss Functions**
Two loss function approaches were tested:
*Precision-Focused Loss* (Default):
```
total_loss = 1.0 * MSE + 0.5 * MAE + 0.3 * log_penalty + 0.5 * huber_loss
```
- MSE: Standard mean squared error
- MAE: Mean absolute error
- Log penalty: Heavily penalizes small errors (pushes toward high precision)
- Huber loss: Small delta (0.01) for precision on small errors
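A PyTorch sketch of this weighted sum is below. The exact form of the log penalty is an assumption (the source only states that it heavily penalizes small errors): `log1p(|err| / eps)` has gradient roughly `1/eps` near zero error, so it keeps pushing even tiny residuals toward zero.

```python
import torch
import torch.nn.functional as F

def precision_focused_loss(pred, target, eps=1e-3):
    """total = 1.0*MSE + 0.5*MAE + 0.3*log_penalty + 0.5*huber(delta=0.01).

    The log-penalty form is an assumption; its steep gradient near zero
    error is what pushes the model toward high precision.
    """
    err = pred - target
    mse = err.pow(2).mean()
    mae = err.abs().mean()
    log_penalty = torch.log1p(err.abs() / eps).mean()
    huber = F.huber_loss(pred, target, delta=0.01)
    return 1.0 * mse + 0.5 * mae + 0.3 * log_penalty + 0.5 * huber
```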
*Pure MSE Loss* (Optimized config):
```
total_loss = MSE
```
Interestingly, the precision-focused loss achieved better Delta-E despite higher validation MSE, suggesting the custom weighting better correlates with perceptual accuracy.
### Design Rationale
**Two-Stage Architecture**
The error predictor stage corrects systematic biases in the base model:
1. Base model learns the general xyY to Munsell mapping
2. Error predictor learns residual corrections specific to each component
3. Combined prediction: `final = base_prediction + error_correction`
This decomposition allows each stage to specialize and reduces the complexity each network must learn.
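The combination step can be sketched in a few lines; per the stage diagram later in this section, the error predictor sees both the original xyY input and the base prediction. Model internals are placeholders here.

```python
import torch
import torch.nn as nn

@torch.no_grad()
def two_stage_predict(base_model, error_predictor, xyY):
    """final = base_prediction + error_correction.

    Illustrative shapes: 3 xyY inputs, 4 Munsell outputs
    (hue, value, chroma, code).
    """
    base = base_model(xyY)                                      # Stage 1
    correction = error_predictor(torch.cat([xyY, base], dim=-1))  # Stage 2
    return base + correction
```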
**Independent Branch Design**
Munsell components have different characteristics:
- **Hue**: Circular (0-10, wrapping), most complex
- **Value**: Linear (0-10), easiest to predict
- **Chroma**: Highly variable range depending on hue/value
- **Code**: Discrete hue sector (0-9)
Shared encoders force compromises between these different prediction tasks. Independent branches allow full specialization.
**Architecture Details**
*MLP (Base Only)*
Simple feedforward network predicting all 4 outputs simultaneously:
```
Input (3) ──► Linear Layers ──► Output (4: hue, value, chroma, code)
```
- Smallest model (~8KB ONNX)
- Fastest inference (0.007 ms)
- Baseline for comparison
*Unified MLP*
Single large MLP with shared internal features:
```
Input (3) ──► 128 ──► 256 ──► 512 ──► 256 ──► 128 ──► Output (4)
```
- Shared representations across all outputs
- Moderate size, good speed
*Multi-Head MLP*
Shared encoder with specialized decoder heads:
```
Input (3) ──► SHARED ENCODER (3→128→256→512) ──┬──► Hue Head    (512→256→128→1)
                                               ├──► Value Head  (512→256→128→1)
                                               ├──► Chroma Head (512→384→256→128→1)
                                               └──► Code Head   (512→256→128→1)
```
- Shared encoder learns common color space features
- 4 specialized decoder heads branch from shared representation
- Parameter efficient (encoder weights shared)
- Fast inference (encoder computed once)
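A PyTorch sketch of the Multi-Head layout with the dimensions above. The ReLU activation and the absence of normalization layers are assumptions for illustration.

```python
import torch
import torch.nn as nn

def mlp(dims, final_activation=False):
    """Stack of Linear layers with ReLU between them (activation assumed)."""
    layers = []
    for i, (a, b) in enumerate(zip(dims[:-1], dims[1:])):
        layers.append(nn.Linear(a, b))
        if final_activation or i < len(dims) - 2:
            layers.append(nn.ReLU())
    return nn.Sequential(*layers)

class MultiHead(nn.Module):
    """Shared encoder (3->128->256->512) feeding four decoder heads."""
    def __init__(self):
        super().__init__()
        self.encoder = mlp([3, 128, 256, 512], final_activation=True)
        self.hue = mlp([512, 256, 128, 1])
        self.value = mlp([512, 256, 128, 1])
        self.chroma = mlp([512, 384, 256, 128, 1])  # deeper head for chroma
        self.code = mlp([512, 256, 128, 1])

    def forward(self, xyY):
        z = self.encoder(xyY)  # computed once, shared by all heads
        return torch.cat(
            [self.hue(z), self.value(z), self.chroma(z), self.code(z)], dim=-1
        )
```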
*Multi-MLP*
Fully independent branches with no weight sharing:
```
Input (3) ──► Hue Branch    (3→128→256→512→256→128→1)
Input (3) ──► Value Branch  (3→128→256→512→256→128→1)
Input (3) ──► Chroma Branch (3→256→512→1024→512→256→1) [2x wider]
Input (3) ──► Code Branch   (3→128→256→512→256→128→1)
```
- 4 completely independent MLPs
- Each branch learns its own features from scratch
- Chroma branch is wider (2x) to handle its complexity
- Better accuracy than Multi-Head on large dataset (Delta-E 0.52 vs 0.56 with error predictors)
*Multi-ResNet*
Deep branches with residual-style connections:
```
Input (3) ──► Hue Branch    (3→256→512→512→512→256→1)    [6 layers]
Input (3) ──► Value Branch  (3→256→512→512→512→256→1)    [6 layers]
Input (3) ──► Chroma Branch (3→512→1024→1024→1024→512→1) [6 layers, 2x wider]
Input (3) ──► Code Branch   (3→256→512→512→512→256→1)    [6 layers]
```
- Deeper architecture than Multi-MLP
- BatchNorm + SiLU activation
- Best accuracy when combined with error predictor (Delta-E 0.52)
- Largest model (~14MB base, ~28MB with error predictor)
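One branch of the Multi-ResNet can be sketched as below, matching the 3→256→512→512→512→256→1 layout with BatchNorm + SiLU. The exact placement of the skip connections is an assumption; "ResNet-style" is taken to mean width-preserving residual blocks in the 512-wide middle section.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResidualBlock(nn.Module):
    """Width-preserving Linear -> BatchNorm -> SiLU block with a skip
    connection (skip placement is an assumption)."""
    def __init__(self, width):
        super().__init__()
        self.fc = nn.Linear(width, width)
        self.bn = nn.BatchNorm1d(width)

    def forward(self, x):
        return x + F.silu(self.bn(self.fc(x)))

class ResNetBranch(nn.Module):
    """One per-component branch, e.g. hue: 3 -> 256 -> 512 -> 512 -> 512 -> 256 -> 1."""
    def __init__(self, narrow=256, wide=512):
        super().__init__()
        self.stem = nn.Sequential(
            nn.Linear(3, narrow), nn.SiLU(),
            nn.Linear(narrow, wide), nn.SiLU(),
        )
        self.blocks = nn.Sequential(ResidualBlock(wide), ResidualBlock(wide))
        self.head = nn.Sequential(
            nn.Linear(wide, narrow), nn.SiLU(),
            nn.Linear(narrow, 1),
        )

    def forward(self, x):
        return self.head(self.blocks(self.stem(x)))
```

The chroma branch would use `narrow=512, wide=1024` for the 2x wider variant. Note that the BatchNorm layers here are exactly those implicated in the MPS instability described under Limitations.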
*Deep + Wide*
Combined deep and wide network paths:
```
Input (3) ──┬──► Deep Path (multiple layers) ──┬──► Concat ──► Output (4)
            └──► Wide Path (direct connection) ─┘
```
- Deep path captures complex patterns
- Wide path preserves direct input information
- Good for mixed linear/nonlinear relationships
*MLP + Self-Attention*
MLP with attention mechanism for feature weighting:
```
Input (3) ──► MLP ──► Self-Attention ──► Output (4)
```
- Attention weights learn feature importance
- Slower due to attention computation (0.173 ms)
- Did not improve over simpler MLPs
*Mixture of Experts*
Gating network selecting specialized expert networks:
```
Input (3) ──► Gating Network ──► Weighted sum of Expert outputs ──► Output (4)
```
- Multiple expert networks specialize in different input regions
- Gating network learns which expert to use
- More complex but did not outperform Multi-MLP
*FT-Transformer*
Feature Tokenizer Transformer for tabular data:
```
Input (3) ──► Feature Tokenizer ──► Transformer Blocks ──► Output (4)
```
- Each input feature tokenized separately
- Self-attention across feature tokens
- Good for tabular data with feature interactions
- Slower inference due to attention computation
*Error Predictor (Two-Stage)*
Second-stage network that corrects base model errors:
```
Stage 1: Input (3) ──► Base Model ──► Base Prediction (4)
Stage 2: [Input (3), Base Prediction (4)] ──► Error Predictor ──► Error Correction (4)
Final:   Base Prediction + Error Correction = Final Output
```
- Learns residual corrections for each component
- Can have unified (1 network) or multi (4 networks) error predictors
- Consistently improves accuracy across all base architectures
- Best results: Multi-ResNet + Multi-Error Predictor (Delta-E 0.52)
**Loss-Metric Mismatch**
An important finding: **optimizing MSE does not optimize Delta-E**.
The Optuna hyperparameter search minimized validation MSE, but the best MSE configuration did not achieve the best Delta-E. This is because:
- MSE treats all component errors equally
- Delta-E (CIE2000) weights errors based on human perception
- The precision-focused loss with custom weights better approximates perceptual importance
**Weighted Boundary Loss (Experimental)**
Analysis of model errors revealed systematic underperformance on Y/GY/G hues (Yellow/Green-Yellow/Green) with high value and chroma. The weighted boundary loss approach was explored to address this by:
1. Applying 3x loss weight to samples in challenging regions:
- Hue: 0.18-0.35 (normalized range covering Y/YG/G)
- Value > 0.7 (high brightness)
- Chroma > 0.5 (high saturation)
2. Adding boundary penalty to prevent predictions exceeding Munsell gamut limits
**Finding**: The large dataset approach (~1.4M samples with dense boundary sampling) naturally provides sufficient coverage of these challenging regions. Both the weighted boundary loss model (Multi-MLP W+B + Multi-Error Predictor W+B Large, Delta-E 0.524) and the standard large dataset model (Multi-MLP + Multi-Error Predictor Large, Delta-E 0.525) achieve nearly identical results, making explicit loss weighting optional. The best overall model is Multi-ResNet + Multi-Error Predictor (Large Dataset) with Delta-E 0.52.
### Experimental Findings
The following experiments were conducted but did not improve results:
**Delta-E Training**
Training with differentiable Delta-E CIE2000 loss via round-trip through the Munsell-to-xyY approximator.
*Hypothesis*: Perceptual Delta-E loss might outperform MSE-trained models.
*Implementation*: JAX/Flax model with combined MSE + Delta-E loss. Requires lower learning rate (1e-4 vs 3e-4) for stability; higher rates cause NaN gradients.
*Results*: While Delta-E is comparable, **hue accuracy is ~10x worse**:
| Metric (Normalized MAE) | Delta-E Model | MSE Model |
|--------------------------|---------------|-----------|
| Hue MAE | 0.30 | 0.03 |
| Value MAE | 0.002 | 0.004 |
| Chroma MAE | 0.007 | 0.008 |
| Code MAE | 0.07 | 0.01 |
| **Delta-E (perceptual)** | **0.52** | **0.50** |
*Key Takeaway*: **Perceptual similarity != specification accuracy**. The MSE model's slightly better Delta-E (0.50 vs 0.52) comes at the cost of ~10x worse hue accuracy, making it unsuitable for specification prediction. Delta-E is too permissive for hue, allowing the model to find "shortcuts" that minimize perceptual difference without correctly predicting the Munsell specification.
**Classical Interpolation**
Classical interpolation methods were tested on the 4,995 reference Munsell colors (80% train / 20% test split); the ML model was evaluated on the 2,734 REAL Munsell colors.
*Results (Validation MAE)*:
| Component | RBF | KD-Tree | Delaunay | ML (Best) |
|-----------|------|---------|----------|-----------|
| Hue | 1.40 | 1.40 | 1.29 | **0.03** |
| Value | 0.01 | 0.10 | 0.02 | 0.05 |
| Chroma | 0.22 | 0.99 | 0.35 | **0.11** |
| Code | 0.33 | 0.28 | 0.28 | **0.00** |
*Key Insight*: The reference dataset (4,995 colors) is too sparse for 3D xyY interpolation. Classical methods fail on hue prediction (MAE ~1.3-1.4), while ML achieves 47x better hue accuracy and 2-3x better chroma/code accuracy.
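The RBF baseline pattern can be illustrated with SciPy's `RBFInterpolator` (a real SciPy class, default thin-plate-spline kernel). The data here is synthetic, not the Munsell renotation set; it only demonstrates the fit/query workflow used for the comparison.

```python
import numpy as np
from scipy.interpolate import RBFInterpolator

rng = np.random.default_rng(0)
xyY = rng.uniform(size=(200, 3))           # stand-in training coordinates
value = xyY @ np.array([1.0, 2.0, 3.0])    # stand-in per-component target

# Fit one interpolator per Munsell component, then query test points.
interp = RBFInterpolator(xyY, value)       # thin-plate-spline kernel by default
pred = interp(np.array([[0.4, 0.3, 0.5]]))
```

With only ~4,000 training points scattered in 3D, the interpolant is heavily underdetermined in sparse regions, which is consistent with the large hue MAE observed above.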
**Circular Hue Loss**
Circular distance metrics for hue prediction, accounting for cyclic nature (0-10 wraps).
*Results*: The circular loss model performed **21x worse** on hue MAE (5.14 vs 0.24).
*Key Takeaway*: **Mathematical correctness != training effectiveness**. The circular distance creates gradient discontinuities that harm optimization.
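A minimal sketch of the wrap-aware distance makes the discontinuity concrete. The exact loss used in the experiment is not documented; this is one standard formulation.

```python
import torch

def circular_hue_distance(pred, target, period=10.0):
    """Wrap-aware distance on the 0-10 hue circle.

    The minimum over the two arc directions creates a gradient
    discontinuity at half a period, which is the optimization
    hazard described above.
    """
    diff = torch.remainder(pred - target, period)
    return torch.minimum(diff, period - diff)
```

For example, hues 9.9 and 0.1 are 0.2 apart on the circle rather than 9.8 apart, but the gradient flips sign abruptly when the raw difference crosses 5.0.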
**REAL-Only Refinement**
Fine-tuning using only REAL Munsell colors (2,734) instead of ALL colors (4,995).
*Results*: Essentially identical performance (Delta-E 1.5233 vs 1.5191).
*Key Takeaway*: **Data quality is not the bottleneck**. Both REAL and extrapolated colors are sufficiently accurate.
**Gamma Normalization**
Gamma correction to the Y (luminance) channel during normalization.
*Results*: No consistent improvement across gamma values 1.0-3.0:
| Gamma | Median ΔE (± std) |
|----------------|-------------------|
| 1.0 (baseline) | 0.730 ± 0.054 |
| 2.5 (best) | 0.683 ± 0.132 |
![Gamma sweep results](_static/gamma_sweep_plot.png)
*Key Takeaway*: **Gamma normalization does not provide consistent improvement**. The standard deviations overlap, so the differences are within noise.
## Munsell to xyY (to_xyY)
### Performance Benchmarks
Comprehensive comparison using all 2,734 REAL Munsell colors:
| Model | Delta-E | Speed (ms) |
|-----------------------------------------------|-------------|------------|
| Colour Library (Baseline) | 0.00 | 1.27 |
| **Multi-MLP (Optimized)** | **0.48** | 0.008 |
| Multi-MLP (Opt) + Multi-Error Predictor (Opt) | 0.48 | 0.025 |
| Multi-MLP + Multi-Error Predictor | 0.65 | 0.030 |
| Multi-MLP | 0.66 | 0.016 |
| Multi-MLP + Error Predictor | 0.67 | 0.018 |
| Multi-Head (Optimized) | 0.71 | 0.015 |
| Multi-Head | 0.78 | 0.008 |
| Multi-Head + Multi-Error Predictor | 1.11 | 0.028 |
| Simple MLP | 1.42 | **0.0008** |
**Best Models**:
- **Best Accuracy**: Multi-MLP (Optimized) - Delta-E 0.48
- **Fastest**: Simple MLP (0.0008 ms/sample) - 1,654x faster than Colour library
- **Best Balance**: Multi-MLP (Optimized) - 154x faster with Delta-E 0.48
### Model Architectures
9 architectures were evaluated for the Munsell to xyY direction:
**Single-Stage Models**
1. **Simple MLP** - Basic MLP network, 4 inputs to 3 outputs
2. **Multi-Head** - Shared encoder with 3 independent decoder heads (x, y, Y)
3. **Multi-Head (Optimized)** - Hyperparameter-optimized variant
4. **Multi-MLP** - 3 completely independent MLP branches
5. **Multi-MLP (Optimized)** - Hyperparameter-optimized variant (BEST)
**Two-Stage Models**
6. **Multi-MLP + Error Predictor** - Base Multi-MLP with unified error correction
7. **Multi-MLP + Multi-Error Predictor** - 3 independent error predictors
8. **Multi-MLP (Opt) + Multi-Error Predictor (Opt)** - Optimized two-stage
9. **Multi-Head + Multi-Error Predictor** - Multi-Head with error correction
The **Multi-MLP (Optimized)** architecture achieved the best results with Delta-E 0.48.
### Differentiable Approximator
A small MLP (68K parameters) trained to approximate the Munsell to xyY conversion for use in differentiable Delta-E loss:
- **Architecture**: 4 -> 128 -> 256 -> 128 -> 3 with LayerNorm + SiLU
- **Accuracy**: MAE ~0.0006 for x, y, and Y components
- **Output formats**: PyTorch (.pth), ONNX, and JAX-compatible weights (.npz)
This enables differentiable Munsell to xyY conversion, which was previously only possible through non-differentiable lookup tables.
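A PyTorch sketch of the stated layout is below; the exact activation placement is an assumption, but this arrangement of 4 -> 128 -> 256 -> 128 -> 3 with LayerNorm + SiLU lands at roughly the stated 68K parameters.

```python
import torch
import torch.nn as nn

class MunsellToXYY(nn.Module):
    """4 -> 128 -> 256 -> 128 -> 3 with LayerNorm + SiLU (~68K parameters)."""
    def __init__(self):
        super().__init__()
        dims = [4, 128, 256, 128]
        layers = []
        for a, b in zip(dims[:-1], dims[1:]):
            layers += [nn.Linear(a, b), nn.LayerNorm(b), nn.SiLU()]
        layers.append(nn.Linear(dims[-1], 3))  # output: (x, y, Y)
        self.net = nn.Sequential(*layers)

    def forward(self, spec):
        # spec = (hue, value, chroma, code), normalized
        return self.net(spec)
```

Because every layer is differentiable, gradients of a Delta-E loss can flow back through this network to the from_xyY model during round-trip training.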
## Shared Infrastructure
### Hyperparameter Optimization
Optuna was used for systematic hyperparameter search over:
- Learning rate (1e-4 to 1e-3)
- Batch size (256, 512, 1024)
- Dropout rate (0.0 to 0.2)
- Chroma branch width multiplier (1.0 to 2.0)
- Loss function weights (MSE, Huber)
Key finding: **No dropout (0.0)** consistently performed better across all models in both conversion directions, contrary to typical deep learning recommendations for regularization.
### Training Infrastructure
- **Optimizer**: AdamW with weight decay
- **Scheduler**: ReduceLROnPlateau (patience=10, factor=0.5)
- **Early stopping**: Patience=20 epochs
- **Checkpointing**: Best model saved based on validation loss
- **Logging**: MLflow for experiment tracking
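The optimizer/scheduler setup above maps directly onto the PyTorch APIs; the learning rate and weight decay values below are illustrative defaults, not the project's exact configuration.

```python
import torch

def make_training_tools(model, lr=3e-4, weight_decay=1e-4):
    """AdamW + ReduceLROnPlateau(patience=10, factor=0.5) as described above."""
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, factor=0.5, patience=10
    )
    return optimizer, scheduler

# Early stopping (patience=20) driven by validation loss -- loop skeleton:
#   best, wait = float("inf"), 0
#   for epoch in range(max_epochs):
#       val_loss = validate(model)        # hypothetical validation helper
#       scheduler.step(val_loss)          # halves LR after 10 stagnant epochs
#       if val_loss < best:
#           best, wait = val_loss, 0      # checkpoint best model here
#       else:
#           wait += 1
#           if wait >= 20:
#               break                     # early stop
```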
### JAX Delta-E Implementation
Located in `learning_munsell/losses/jax_delta_e.py`:
- Differentiable xyY -> XYZ -> Lab color space conversions
- Full CIE 2000 Delta-E implementation with gradient support
- JIT-compiled functions for performance
Usage:
```python
from learning_munsell.losses import delta_E_loss, delta_E_CIE2000
# Compute perceptual loss between predicted and target xyY
loss = delta_E_loss(pred_xyY, target_xyY)
```
## Limitations
### BatchNorm Instability on MPS
Models using `BatchNorm1d` layers exhibit numerical instability when trained on Apple Silicon GPUs via the MPS backend:
1. **Validation loss spikes** during training
2. **Occasional extreme outputs** during inference (e.g., 20M instead of ~0.1)
3. **Non-reproducible behavior**
**Affected Models**: Large dataset error predictors using BatchNorm.
**Workarounds**:
1. Use CPU for training
2. Replace BatchNorm with LayerNorm
3. Use smaller models (300K samples vs 2M)
4. Skip error predictor stage for affected models
The recommended production model (`multi_resnet_error_predictor_large.onnx`) was trained on the large dataset and does not exhibit this instability.
**References**:
- [BatchNorm non-trainable exception](https://github.com/pytorch/pytorch/issues/98602)
- [ONNX export incorrect on MPS](https://github.com/pytorch/pytorch/issues/83230)
- [MPS kernel bugs](https://elanapearl.github.io/blog/2025/the-bug-that-taught-me-pytorch/)