diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..c898190874595004719079d6e20ada91fcd1fb31 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,37 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.onnx.data filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tar filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text +*.png filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..13f362e9895184eb774c2c4833149aaffc93158f --- /dev/null +++ b/.gitignore @@ -0,0 +1,37 @@ +# Common Files +*.egg-info +*.pyc +*.pyo +.DS_Store +.coverage* +uv.lock + +# Common Directories +.fleet/ +.idea/ +.ipynb_checkpoints/ +.python-version +.vs/ +.vscode/ +.sandbox/ +build/ +dist/ +docs/_build/ +docs/generated/ +node_modules/ +references/ + +__pycache__ + +.claude/settings.local.json +.claude/scratchpad.md + +# Project Directories +data/ +logs/ +mlartifacts/ +mlruns/ +mlruns.db +reports/ +results/ +runs/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6f7aac0b3503916ad2b6dff4dd680a8d0ef5efda --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,39 @@ +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: "v5.0.0" + hooks: + - id: check-added-large-files + - id: check-case-conflict + - id: check-merge-conflict + - id: check-symlinks + - id: check-yaml + - id: debug-statements + - id: end-of-file-fixer + - id: mixed-line-ending + - id: requirements-txt-fixer + - id: trailing-whitespace + - repo: https://github.com/codespell-project/codespell + rev: v2.4.1 + hooks: + - id: codespell + args: ["--ignore-words-list=colour"] + - repo: https://github.com/PyCQA/isort + rev: "6.0.1" + hooks: + - id: isort + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: "v0.12.4" + hooks: + - id: ruff-format + - id: ruff + args: [--fix] + - repo: https://github.com/pre-commit/mirrors-prettier + rev: "v4.0.0-alpha.8" + hooks: + - id: prettier + - repo: https://github.com/pre-commit/pygrep-hooks + rev: "v1.10.0" + hooks: + - id: rst-backticks + - id: rst-directive-colons + - id: rst-inline-touching-normal diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..48299407805ba086cebc5c2aa2881e51cef609c7 --- /dev/null +++ b/LICENSE @@ -0,0 +1,11 @@ +Copyright 2025 Colour Developers + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..9ac238472a0e35c800bc30ac4900168e02aa156d --- /dev/null +++ b/README.md @@ -0,0 +1,278 @@ +--- +license: bsd-3-clause +language: + - en +tags: + - python + - colour + - color + - colour-science + - color-science + - colour-spaces + - color-spaces + - colourspace + - colorspace +pipeline_tag: tabular-regression +library_name: onnxruntime +metrics: + - mae +model-index: + - name: from_xyY (CIE xyY to Munsell) + results: + - task: + type: tabular-regression + name: CIE xyY to Munsell Specification + dataset: + name: CIE xyY to Munsell Specification + type: munsell-renotation + metrics: + - type: delta-e + value: 0.52 + name: Delta-E CIE2000 + - type: inference_time_ms + value: 0.089 + name: Inference Time (ms/sample) + - name: to_xyY (Munsell to CIE xyY) + results: + - task: + type: tabular-regression + name: Munsell Specification to CIE xyY + dataset: + name: Munsell Specification to CIE xyY + type: munsell-renotation + metrics: + - type: delta-e + value: 0.48 + name: Delta-E CIE2000 + - type: inference_time_ms + value: 0.008 + name: Inference Time (ms/sample) +--- + +# Learning Munsell - Machine Learning for Munsell Color Conversions + +A project implementing machine learning-based methods for bidirectional conversion between CIE xyY colourspace values and Munsell specifications. + +**Two Conversion Directions:** + +- **from_xyY**: CIE xyY to Munsell specification +- **to_xyY**: Munsell specification to CIE xyY + +## Project Overview + +### Objective + +Provide 100-1000x speedup for batch Munsell conversions compared to colour-science routines while maintaining high accuracy. + +### Results + +**from_xyY** (CIE xyY to Munsell) — evaluated on all 2,734 REAL Munsell colors: + +| Model | Delta-E | Speed (ms) | +|----------------------------------------------------------| ---------- | ---------- | +| Colour Library (Baseline) | 0.00 | 111.90 | +| **Multi-ResNet + Multi-Error Predictor (Large Dataset)** | **0.52** | 0.089 | +| Multi-MLP (W+B) + Multi-Error Predictor (W+B) Large | 0.52 | 0.057 | +| Multi-MLP + Multi-Error Predictor (Large Dataset) | 0.52 | 0.058 | +| Multi-MLP + Multi-Error Predictor | 0.53 | 0.058 | +| MLP + Error Predictor | 0.53 | 0.030 | +| Multi-ResNet (Large Dataset) | 0.54 | 0.044 | +| Multi-Head + Multi-Error Predictor | 0.54 | 0.042 | +| Multi-Head + Multi-Error Predictor (Large Dataset) | 0.56 | 0.043 | +| Deep + Wide | 0.60 | 0.074 | +| Multi-Head (Large Dataset) | 0.66 | 0.013 | +| Mixture of Experts | 0.80 | 0.020 | +| Transformer (Large Dataset) | 0.82 | 0.123 | +| Multi-MLP | 0.86 | 0.027 | +| MLP + Self-Attention | 0.88 | 0.173 | +| MLP (Base Only) | 1.09 | **0.007** | +| Unified MLP | 1.12 | 0.072 | + +- **Best Accuracy**: Multi-ResNet + Multi-Error Predictor (Large Dataset) — Delta-E 0.52, 1,252x faster +- **Fastest**: MLP Base Only (0.007 ms/sample) — 15,492x faster than Colour library +- **Best Balance**: Multi-MLP (W+B: Weighted Boundary) + Multi-Error Predictor (W+B) Large — 1,951x faster with Delta-E 0.52 + +**to_xyY** (Munsell to CIE xyY) — evaluated on all 2,734 REAL Munsell colors: + +| Model | Delta-E | Speed (ms) | +| --------------------------------------------- | ---------- | ----------- | +| Colour Library (Baseline) | 0.00 | 1.27 | +| **Multi-MLP (Optimized)** | **0.48** | 0.008 | +| Multi-MLP (Opt) + Multi-Error Predictor (Opt) | 0.48 | 0.025 | +| Multi-MLP + Multi-Error Predictor | 0.65 | 0.030 | +| Multi-MLP | 0.66 | 0.016 | +| Multi-MLP + Error Predictor | 0.67 | 0.018 | +| Multi-Head (Optimized) | 0.71 | 0.015 | +| Multi-Head | 0.78 | 0.008 | +| Multi-Head + Multi-Error Predictor | 1.11 | 0.028 | +| Simple MLP | 1.42 | **0.0008** | + +- **Best Accuracy**: Multi-MLP (Optimized) — Delta-E 0.48, 154x faster +- **Fastest**: Simple MLP (0.0008 ms/sample) — 1,654x faster than Colour library + +### Approach + +- **25+ architectures** tested for from_xyY (MLP, Multi-Head, Multi-MLP, Multi-ResNet, Transformers, Mixture of Experts) +- **9 architectures** tested for to_xyY (Simple MLP, Multi-Head, Multi-MLP with error predictors) +- **Two-stage models** (base + error predictor) on large dataset proved most effective +- **Best model**: Multi-ResNet + Multi-Error Predictor (Large Dataset) with Delta-E 0.52 +- **Training data**: ~1.4M samples from dense xyY grid with boundary refinement and forward Munsell sampling +- **Deployment**: ONNX format with ONNX Runtime + +For detailed architecture comparisons, model benchmarks, training pipeline details, and experimental results, see [docs/learning_munsell.md](docs/learning_munsell.md). + +## Installation + +**Dependencies (Runtime)**: + +- numpy >= 2.0 +- onnxruntime >= 1.16 + +**Dependencies (Training)**: + +- torch >= 2.0 +- scikit-learn >= 1.3 +- matplotlib >= 3.9 +- mlflow >= 2.10 +- optuna >= 3.0 +- colour-science >= 0.4.7 +- click >= 8.0 +- onnx >= 1.15 +- onnxscript >= 0.5.6 +- tqdm >= 4.66 +- jax >= 0.4.20 +- jaxlib >= 0.4.20 +- flax >= 0.10.7 +- optax >= 0.2.6 +- scipy >= 1.12 +- tensorboard >= 2.20 + +From the project root: + +```bash +cd learning-munsell + +# Install all dependencies (creates virtual environment automatically) +uv sync +``` + +## Usage + +### Generate Training Data + +```bash +uv run python learning_munsell/data_generation/generate_training_data.py +``` + +**Note**: This step is computationally expensive (uses iterative algorithm for ground truth). + +### Train Models + +**xyY to Munsell (from_xyY)** + +Best performing model (Multi-ResNet + Multi-Error Predictor on Large Dataset): + +```bash +# Train base Multi-ResNet on large dataset (~1.4M samples) +uv run python learning_munsell/training/from_xyY/train_multi_resnet_large.py + +# Train multi-error predictor +uv run python learning_munsell/training/from_xyY/train_multi_resnet_error_predictor_large.py +``` + +Alternative (Multi-Head architecture): + +```bash +uv run python learning_munsell/training/from_xyY/train_multi_head_large.py +uv run python learning_munsell/training/from_xyY/train_multi_head_multi_error_predictor_large.py +``` + +Other architectures: + +```bash +uv run python learning_munsell/training/from_xyY/train_unified_mlp.py +uv run python learning_munsell/training/from_xyY/train_multi_mlp.py +uv run python learning_munsell/training/from_xyY/train_mlp_attention.py +uv run python learning_munsell/training/from_xyY/train_deep_wide.py +uv run python learning_munsell/training/from_xyY/train_ft_transformer.py +``` + +**Munsell to xyY (to_xyY)** + +Best performing model (Multi-MLP Optimized): + +```bash +uv run python learning_munsell/training/to_xyY/train_multi_mlp.py +uv run python learning_munsell/training/to_xyY/train_multi_head.py +uv run python learning_munsell/training/to_xyY/train_multi_mlp_multi_error_predictor.py +uv run python learning_munsell/training/to_xyY/train_multi_mlp_error_predictor.py +uv run python learning_munsell/training/to_xyY/train_multi_head_multi_error_predictor.py +``` + +Train the differentiable approximator for use in Delta-E loss: + +```bash +uv run python learning_munsell/training/to_xyY/train_munsell_to_xyY_approximator.py +``` + +### Hyperparameter Search + +```bash +uv run python learning_munsell/training/from_xyY/hyperparameter_search_multi_head.py +uv run python learning_munsell/training/from_xyY/hyperparameter_search_multi_head_error_predictor.py +``` + +### Compare All Models + +```bash +uv run python learning_munsell/comparison/from_xyY/compare_all_models.py +``` + +Generates comprehensive HTML report at `reports/from_xyY/model_comparison.html`. + +### Monitor Training + +**MLflow**: + +```bash +uv run mlflow ui --backend-store-uri "sqlite:///mlruns.db" --port=5000 +``` + +Open in your browser. + +## Directory Structure + +``` +learning-munsell/ ++-- data/ # Training data +| +-- training_data.npz # Generated training samples +| +-- training_data_large.npz # Large dataset (~1.4M samples) +| +-- training_data_params.json # Generation parameters +| +-- training_data_large_params.json ++-- models/ # Trained models (ONNX + PyTorch) +| +-- from_xyY/ # xyY to Munsell models (25+ ONNX models) +| | +-- multi_resnet_error_predictor_large.onnx # BEST +| | +-- ... (additional model variants) +| +-- to_xyY/ # Munsell to xyY models (9 ONNX models) +| +-- multi_mlp_optimized.onnx # BEST +| +-- ... (additional model variants) ++-- learning_munsell/ # Source code +| +-- analysis/ # Analysis scripts +| +-- comparison/ # Model comparison scripts +| +-- data_generation/ # Data generation scripts +| +-- interpolation/ # Classical interpolation methods +| +-- losses/ # Loss functions (JAX Delta-E) +| +-- models/ # Model architecture definitions +| +-- training/ # Model training scripts +| +-- utilities/ # Shared utilities ++-- docs/ # Documentation ++-- reports/ # HTML comparison reports ++-- logs/ # Script output logs ++-- mlruns.db # MLflow experiment tracking database +``` + +## About + +**Learning Munsell** by Colour Developers +Research project for the Colour library + diff --git a/docs/_static/gamma_sweep_plot.pdf b/docs/_static/gamma_sweep_plot.pdf new file mode 100644 index 0000000000000000000000000000000000000000..1c3993f59a3a2459b4a4bd29d615c573bdd5cc25 Binary files /dev/null and b/docs/_static/gamma_sweep_plot.pdf differ diff --git a/docs/_static/gamma_sweep_plot.png b/docs/_static/gamma_sweep_plot.png new file mode 100644 index 0000000000000000000000000000000000000000..3c14c7f25ba2569223c225517cf7c1fb65f6f718 --- /dev/null +++ b/docs/_static/gamma_sweep_plot.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e2a0d5dc57c0d37d5889cff4ac41a08b490387a54615d4372af5e5bd86018e36 +size 136568 diff --git a/docs/learning_munsell.md b/docs/learning_munsell.md new file mode 100644 index 0000000000000000000000000000000000000000..e571f1c672ba9a08375c8b337e9bcf7887541a34 --- /dev/null +++ b/docs/learning_munsell.md @@ -0,0 +1,478 @@ +# Learning Munsell + +Technical documentation covering performance benchmarks, training methodology, architecture design, and experimental findings. + +## Overview + +This project implements ML models for bidirectional conversion between CIE xyY colorspace values and Munsell specifications: + +- **xyY to Munsell (from_xyY)**: 25+ architectures, best Delta-E 0.52 +- **Munsell to xyY (to_xyY)**: 9 architectures, best Delta-E 0.48 + +### Delta-E Interpretation + +- **< 1.0**: Not perceptible by human eye +- **1-2**: Perceptible through close observation +- **2-10**: Perceptible at a glance +- **> 10**: Colors are perceived as completely different + +Our best models achieve **Delta-E 0.48-0.52**, meaning the difference between ML prediction and iterative algorithm is **not perceptible by the human eye**. + +## xyY to Munsell (from_xyY) + +### Performance Benchmarks + +Comprehensive comparison using all 2,734 REAL Munsell colors: + +| Model | Delta-E | Speed (ms) | +|----------------------------------------------------------|-------------|------------| +| Colour Library (Baseline) | 0.00 | 111.90 | +| **Multi-ResNet + Multi-Error Predictor (Large Dataset)** | **0.52** | 0.089 | +| Multi-MLP (W+B) + Multi-Error Predictor (W+B) Large | 0.52 | 0.057 | +| Multi-MLP + Multi-Error Predictor (Large Dataset) | 0.52 | 0.058 | +| Multi-MLP + Multi-Error Predictor | 0.53 | 0.058 | +| MLP + Error Predictor | 0.53 | 0.030 | +| Multi-ResNet (Large Dataset) | 0.54 | 0.044 | +| Multi-Head + Multi-Error Predictor | 0.54 | 0.042 | +| Multi-Head + Multi-Error Predictor (Large Dataset) | 0.56 | 0.043 | +| Deep + Wide | 0.60 | 0.074 | +| Multi-Head (Large Dataset) | 0.66 | 0.013 | +| Mixture of Experts | 0.80 | 0.020 | +| Transformer (Large Dataset) | 0.82 | 0.123 | +| Multi-MLP | 0.86 | 0.027 | +| MLP + Self-Attention | 0.88 | 0.173 | +| MLP (Base Only) | 1.09 | **0.007** | +| Unified MLP | 1.12 | 0.072 | + +Note: The Colour library baseline had 171 convergence failures out of 2,734 samples (6.3% failure rate). + +**Best Models**: + +- **Best Accuracy**: Multi-ResNet + Multi-Error Predictor (Large Dataset) - Delta-E 0.52 +- **Fastest**: MLP Base Only (0.007 ms/sample) - 15,492x faster than Colour library +- **Best Balance**: Multi-MLP (W+B: Weighted Boundary) + Multi-Error Predictor (W+B) Large - 1,951x faster with Delta-E 0.52 + +### Model Architectures + +25+ architectures were systematically evaluated: + +**Single-Stage Models** + +1. **MLP (Base Only)** - Simple MLP network, 3 inputs to 4 outputs +2. **Unified MLP** - Single large MLP with shared features +3. **Multi-Head** - Shared encoder with 4 independent decoder heads +4. **Multi-Head (Large Dataset)** - Multi-Head trained on 1.4M samples +5. **Multi-MLP** - 4 completely independent MLP branches (one per output) +6. **Multi-MLP (Large Dataset)** - Multi-MLP trained on 1.4M samples +7. **MLP + Self-Attention** - MLP with attention mechanism for feature weighting +8. **Deep + Wide** - Combined deep and wide network paths +9. **Mixture of Experts** - Gating network selecting specialized expert networks +10. **Transformer (Large Dataset)** - Feature Tokenizer Transformer for tabular data +11. **FT-Transformer** - Feature Tokenizer Transformer (standard size) + +**Two-Stage Models** + +12. **MLP + Error Predictor** - Base MLP with unified error correction +13. **Multi-Head + Multi-Error Predictor** - Multi-Head with 4 independent error predictors +14. **Multi-Head + Multi-Error Predictor (Large Dataset)** - Large dataset variant +15. **Multi-MLP + Multi-Error Predictor** - 4 independent branches with 4 independent error predictors +16. **Multi-MLP + Multi-Error Predictor (Large Dataset)** - Large dataset variant +17. **Multi-ResNet + Multi-Error Predictor (Large Dataset)** - Deep ResNet-style branches (BEST) + +The **Multi-ResNet + Multi-Error Predictor (Large Dataset)** architecture achieved the best results with Delta-E 0.52. + +### Training Methodology + +**Data Generation** + +1. **Dense xyY Grid** (~500K samples) + - Regular grid in valid xyY space (MacAdam limits for Illuminant C) + - Captures general input distribution +2. **Boundary Refinement** (~700K samples) + - Adaptive dense sampling near Munsell gamut boundaries + - Uses `maximum_chroma_from_renotation` to detect edges + - Focuses on regions where iterative algorithm is most complex + - Includes Y/GY/G hue regions with high value/chroma (challenging areas) +3. **Forward Augmentation** (~200K samples) + - Dense Munsell space sampling via `munsell_specification_to_xyY` + - Ensures coverage of known valid colors + +Total: ~1.4M samples for large dataset training. + +**Loss Functions** + +Two loss function approaches were tested: + +*Precision-Focused Loss* (Default): + +``` +total_loss = 1.0 * MSE + 0.5 * MAE + 0.3 * log_penalty + 0.5 * huber_loss +``` + +- MSE: Standard mean squared error +- MAE: Mean absolute error +- Log penalty: Heavily penalizes small errors (pushes toward high precision) +- Huber loss: Small delta (0.01) for precision on small errors + +*Pure MSE Loss* (Optimized config): + +``` +total_loss = MSE +``` + +Interestingly, the precision-focused loss achieved better Delta-E despite higher validation MSE, suggesting the custom weighting better correlates with perceptual accuracy. + +### Design Rationale + +**Two-Stage Architecture** + +The error predictor stage corrects systematic biases in the base model: + +1. Base model learns the general xyY to Munsell mapping +2. Error predictor learns residual corrections specific to each component +3. Combined prediction: `final = base_prediction + error_correction` + +This decomposition allows each stage to specialize and reduces the complexity each network must learn. + +**Independent Branch Design** + +Munsell components have different characteristics: + +- **Hue**: Circular (0-10, wrapping), most complex +- **Value**: Linear (0-10), easiest to predict +- **Chroma**: Highly variable range depending on hue/value +- **Code**: Discrete hue sector (0-9) + +Shared encoders force compromises between these different prediction tasks. Independent branches allow full specialization. + +**Architecture Details** + +*MLP (Base Only)* + +Simple feedforward network predicting all 4 outputs simultaneously: + + Input (3) ──► Linear Layers ──► Output (4: hue, value, chroma, code) + +- Smallest model (~8KB ONNX) +- Fastest inference (0.007 ms) +- Baseline for comparison + +*Unified MLP* + +Single large MLP with shared internal features: + + Input (3) ──► 128 ──► 256 ──► 512 ──► 256 ──► 128 ──► Output (4) + +- Shared representations across all outputs +- Moderate size, good speed + +*Multi-Head MLP* + +Shared encoder with specialized decoder heads: + + Input (3) ──► SHARED ENCODER (3→128→256→512) ──┬──► Hue Head (512→256→128→1) + ├──► Value Head (512→256→128→1) + ├──► Chroma Head (512→384→256→128→1) + └──► Code Head (512→256→128→1) + +- Shared encoder learns common color space features +- 4 specialized decoder heads branch from shared representation +- Parameter efficient (encoder weights shared) +- Fast inference (encoder computed once) + +*Multi-MLP* + +Fully independent branches with no weight sharing: + + Input (3) ──► Hue Branch (3→128→256→512→256→128→1) + Input (3) ──► Value Branch (3→128→256→512→256→128→1) + Input (3) ──► Chroma Branch (3→256→512→1024→512→256→1) [2x wider] + Input (3) ──► Code Branch (3→128→256→512→256→128→1) + +- 4 completely independent MLPs +- Each branch learns its own features from scratch +- Chroma branch is wider (2x) to handle its complexity +- Better accuracy than Multi-Head on large dataset (Delta-E 0.52 vs 0.56 with error predictors) + +*Multi-ResNet* + +Deep branches with residual-style connections: + + Input (3) ──► Hue Branch (3→256→512→512→512→256→1) [6 layers] + Input (3) ──► Value Branch (3→256→512→512→512→256→1) [6 layers] + Input (3) ──► Chroma Branch (3→512→1024→1024→1024→512→1) [6 layers, 2x wider] + Input (3) ──► Code Branch (3→256→512→512→512→256→1) [6 layers] + +- Deeper architecture than Multi-MLP +- BatchNorm + SiLU activation +- Best accuracy when combined with error predictor (Delta-E 0.52) +- Largest model (~14MB base, ~28MB with error predictor) + +*Deep + Wide* + +Combined deep and wide network paths: + + Input (3) ──┬──► Deep Path (multiple layers) ──┬──► Concat ──► Output (4) + └──► Wide Path (direct connection) ─┘ + +- Deep path captures complex patterns +- Wide path preserves direct input information +- Good for mixed linear/nonlinear relationships + +*MLP + Self-Attention* + +MLP with attention mechanism for feature weighting: + + Input (3) ──► MLP ──► Self-Attention ──► Output (4) + +- Attention weights learn feature importance +- Slower due to attention computation (0.173 ms) +- Did not improve over simpler MLPs + +*Mixture of Experts* + +Gating network selecting specialized expert networks: + + Input (3) ──► Gating Network ──► Weighted sum of Expert outputs ──► Output (4) + +- Multiple expert networks specialize in different input regions +- Gating network learns which expert to use +- More complex but did not outperform Multi-MLP + +*FT-Transformer* + +Feature Tokenizer Transformer for tabular data: + + Input (3) ──► Feature Tokenizer ──► Transformer Blocks ──► Output (4) + +- Each input feature tokenized separately +- Self-attention across feature tokens +- Good for tabular data with feature interactions +- Slower inference due to attention computation + +*Error Predictor (Two-Stage)* + +Second-stage network that corrects base model errors: + + Stage 1: Input (3) ──► Base Model ──► Base Prediction (4) + Stage 2: [Input (3), Base Prediction (4)] ──► Error Predictor ──► Error Correction (4) + Final: Base Prediction + Error Correction = Final Output + +- Learns residual corrections for each component +- Can have unified (1 network) or multi (4 networks) error predictors +- Consistently improves accuracy across all base architectures +- Best results: Multi-ResNet + Multi-Error Predictor (Delta-E 0.52) + +**Loss-Metric Mismatch** + +An important finding: **optimizing MSE does not optimize Delta-E**. + +The Optuna hyperparameter search minimized validation MSE, but the best MSE configuration did not achieve the best Delta-E. This is because: + +- MSE treats all component errors equally +- Delta-E (CIE2000) weights errors based on human perception +- The precision-focused loss with custom weights better approximates perceptual importance + +**Weighted Boundary Loss (Experimental)** + +Analysis of model errors revealed systematic underperformance on Y/GY/G hues (Yellow/Green-Yellow/Green) with high value and chroma. The weighted boundary loss approach was explored to address this by: + +1. Applying 3x loss weight to samples in challenging regions: + - Hue: 0.18-0.35 (normalized range covering Y/YG/G) + - Value > 0.7 (high brightness) + - Chroma > 0.5 (high saturation) +2. Adding boundary penalty to prevent predictions exceeding Munsell gamut limits + +**Finding**: The large dataset approach (~1.4M samples with dense boundary sampling) naturally provides sufficient coverage of these challenging regions. Both the weighted boundary loss model (Multi-MLP W+B + Multi-Error Predictor W+B Large, Delta-E 0.524) and the standard large dataset model (Multi-MLP + Multi-Error Predictor Large, Delta-E 0.525) achieve nearly identical results, making explicit loss weighting optional. The best overall model is Multi-ResNet + Multi-Error Predictor (Large Dataset) with Delta-E 0.52. + +### Experimental Findings + +The following experiments were conducted but did not improve results: + +**Delta-E Training** + +Training with differentiable Delta-E CIE2000 loss via round-trip through the Munsell-to-xyY approximator. + +*Hypothesis*: Perceptual Delta-E loss might outperform MSE-trained models. + +*Implementation*: JAX/Flax model with combined MSE + Delta-E loss. Requires lower learning rate (1e-4 vs 3e-4) for stability; higher rates cause NaN gradients. + +*Results*: While Delta-E is comparable, **hue accuracy is ~10x worse**: + +| Metric (Normalized MAE) | Delta-E Model | MSE Model | +|--------------------------|---------------|-----------| +| Hue MAE | 0.30 | 0.03 | +| Value MAE | 0.002 | 0.004 | +| Chroma MAE | 0.007 | 0.008 | +| Code MAE | 0.07 | 0.01 | +| **Delta-E (perceptual)** | **0.52** | **0.50** | + +*Key Takeaway*: **Perceptual similarity != specification accuracy**. The MSE model's slightly better Delta-E (0.50 vs 0.52) comes at the cost of ~10x worse hue accuracy, making it unsuitable for specification prediction. Delta-E is too permissive for hue, allowing the model to find "shortcuts" that minimize perceptual difference without correctly predicting the Munsell specification. + +**Classical Interpolation** + +Classical interpolation methods were tested on 4,995 reference Munsell colors (80% train / 20% test split). ML evaluated on 2,734 REAL Munsell colors. + +*Results (Validation MAE)*: + +| Component | RBF | KD-Tree | Delaunay | ML (Best) | +|-----------|------|---------|----------|-----------| +| Hue | 1.40 | 1.40 | 1.29 | **0.03** | +| Value | 0.01 | 0.10 | 0.02 | 0.05 | +| Chroma | 0.22 | 0.99 | 0.35 | **0.11** | +| Code | 0.33 | 0.28 | 0.28 | **0.00** | + +*Key Insight*: The reference dataset (4,995 colors) is too sparse for 3D xyY interpolation. Classical methods fail on hue prediction (MAE ~1.3-1.4), while ML achieves 47x better hue accuracy and 2-3x better chroma/code accuracy. + +**Circular Hue Loss** + +Circular distance metrics for hue prediction, accounting for cyclic nature (0-10 wraps). + +*Results*: The circular loss model performed **21x worse** on hue MAE (5.14 vs 0.24). + +*Key Takeaway*: **Mathematical correctness != training effectiveness**. The circular distance creates gradient discontinuities that harm optimization. + +**REAL-Only Refinement** + +Fine-tuning using only REAL Munsell colors (2,734) instead of ALL colors (4,995). + +*Results*: Essentially identical performance (Delta-E 1.5233 vs 1.5191). + +*Key Takeaway*: **Data quality is not the bottleneck**. Both REAL and extrapolated colors are sufficiently accurate. + +**Gamma Normalization** + +Gamma correction to the Y (luminance) channel during normalization. + +*Results*: No consistent improvement across gamma values 1.0-3.0: + +| Gamma | Median ΔE (± std) | +|----------------|-------------------| +| 1.0 (baseline) | 0.730 ± 0.054 | +| 2.5 (best) | 0.683 ± 0.132 | + +![Gamma sweep results](_static/gamma_sweep_plot.png) + +*Key Takeaway*: **Gamma normalization does not provide consistent improvement**. Standard deviations overlap - differences are within noise. + +## Munsell to xyY (to_xyY) + +### Performance Benchmarks + +Comprehensive comparison using all 2,734 REAL Munsell colors: + +| Model | Delta-E | Speed (ms) | +|-----------------------------------------------|-------------|------------| +| Colour Library (Baseline) | 0.00 | 1.27 | +| **Multi-MLP (Optimized)** | **0.48** | 0.008 | +| Multi-MLP (Opt) + Multi-Error Predictor (Opt) | 0.48 | 0.025 | +| Multi-MLP + Multi-Error Predictor | 0.65 | 0.030 | +| Multi-MLP | 0.66 | 0.016 | +| Multi-MLP + Error Predictor | 0.67 | 0.018 | +| Multi-Head (Optimized) | 0.71 | 0.015 | +| Multi-Head | 0.78 | 0.008 | +| Multi-Head + Multi-Error Predictor | 1.11 | 0.028 | +| Simple MLP | 1.42 | **0.0008** | + +**Best Models**: + +- **Best Accuracy**: Multi-MLP (Optimized) - Delta-E 0.48 +- **Fastest**: Simple MLP (0.0008 ms/sample) - 1,654x faster than Colour library +- **Best Balance**: Multi-MLP (Optimized) - 154x faster with Delta-E 0.48 + +### Model Architectures + +9 architectures were evaluated for the Munsell to xyY direction: + +**Single-Stage Models** + +1. **Simple MLP** - Basic MLP network, 4 inputs to 3 outputs +2. **Multi-Head** - Shared encoder with 3 independent decoder heads (x, y, Y) +3. **Multi-Head (Optimized)** - Hyperparameter-optimized variant +4. **Multi-MLP** - 3 completely independent MLP branches +5. **Multi-MLP (Optimized)** - Hyperparameter-optimized variant (BEST) + +**Two-Stage Models** + +6. **Multi-MLP + Error Predictor** - Base Multi-MLP with unified error correction +7. **Multi-MLP + Multi-Error Predictor** - 3 independent error predictors +8. **Multi-MLP (Opt) + Multi-Error Predictor (Opt)** - Optimized two-stage +9. **Multi-Head + Multi-Error Predictor** - Multi-Head with error correction + +The **Multi-MLP (Optimized)** architecture achieved the best results with Delta-E 0.48. + +### Differentiable Approximator + +A small MLP (68K parameters) trained to approximate the Munsell to xyY conversion for use in differentiable Delta-E loss: + +- **Architecture**: 4 -> 128 -> 256 -> 128 -> 3 with LayerNorm + SiLU +- **Accuracy**: MAE ~0.0006 for x, y, and Y components +- **Output formats**: PyTorch (.pth), ONNX, and JAX-compatible weights (.npz) + +This enables differentiable Munsell to xyY conversion, which was previously only possible through non-differentiable lookup tables. + +## Shared Infrastructure + +### Hyperparameter Optimization + +Optuna was used for systematic hyperparameter search over: + +- Learning rate (1e-4 to 1e-3) +- Batch size (256, 512, 1024) +- Dropout rate (0.0 to 0.2) +- Chroma branch width multiplier (1.0 to 2.0) +- Loss function weights (MSE, Huber) + +Key finding: **No dropout (0.0)** consistently performed better across all models in both conversion directions, contrary to typical deep learning recommendations for regularization. + +### Training Infrastructure + +- **Optimizer**: AdamW with weight decay +- **Scheduler**: ReduceLROnPlateau (patience=10, factor=0.5) +- **Early stopping**: Patience=20 epochs +- **Checkpointing**: Best model saved based on validation loss +- **Logging**: MLflow for experiment tracking + +### JAX Delta-E Implementation + +Located in `learning_munsell/losses/jax_delta_e.py`: + +- Differentiable xyY -> XYZ -> Lab color space conversions +- Full CIE 2000 Delta-E implementation with gradient support +- JIT-compiled functions for performance + +Usage: + +```python +from learning_munsell.losses import delta_E_loss, delta_E_CIE2000 + +# Compute perceptual loss between predicted and target xyY +loss = delta_E_loss(pred_xyY, target_xyY) +``` + +## Limitations + +### BatchNorm Instability on MPS + +Models using `BatchNorm1d` layers exhibit numerical instability when trained on Apple Silicon GPUs via the MPS backend: + +1. **Validation loss spikes** during training +2. **Occasional extreme outputs** during inference (e.g., 20M instead of ~0.1) +3. **Non-reproducible behavior** + +**Affected Models**: Large dataset error predictors using BatchNorm. + +**Workarounds**: + +1. Use CPU for training +2. Replace BatchNorm with LayerNorm +3. Use smaller models (300K samples vs 2M) +4. Skip error predictor stage for affected models + +The recommended production model (`multi_resnet_error_predictor_large.onnx`) was trained on the large dataset and does not exhibit this instability. + +**References**: + +- [BatchNorm non-trainable exception](https://github.com/pytorch/pytorch/issues/98602) +- [ONNX export incorrect on MPS](https://github.com/pytorch/pytorch/issues/83230) +- [MPS kernel bugs](https://elanapearl.github.io/blog/2025/the-bug-that-taught-me-pytorch/) diff --git a/learning_munsell/__init__.py b/learning_munsell/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f9150aa8ee0f8a9cc44277710e3becbec1a9da32 --- /dev/null +++ b/learning_munsell/__init__.py @@ -0,0 +1,7 @@ +"""Learning Munsell - Machine Learning for Munsell Color Conversions.""" + +from pathlib import Path + +__all__ = ["PROJECT_ROOT"] + +PROJECT_ROOT = Path(__file__).parent.parent diff --git a/learning_munsell/analysis/__init__.py b/learning_munsell/analysis/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e5d74ec59c91f98d1509270569583fb895b13bef --- /dev/null +++ b/learning_munsell/analysis/__init__.py @@ -0,0 +1 @@ +"""Analysis utilities for Munsell color conversion models.""" diff --git a/learning_munsell/analysis/error_analysis.py b/learning_munsell/analysis/error_analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..29ff237e681cf61bb27bebfb3f05376a6067d738 --- /dev/null +++ b/learning_munsell/analysis/error_analysis.py @@ -0,0 +1,304 @@ +""" +Analyze error distribution to identify problematic regions in Munsell space. + +This script: +1. Runs the best model on all REAL Munsell colors +2. Computes Delta-E for each sample +3. Identifies samples with high error (Delta-E > threshold) +4. Analyzes patterns: which hue families, value ranges, chroma ranges have issues +5. Outputs statistics and visualizations +""" + +import logging +from collections import defaultdict + +import numpy as np +import onnxruntime as ort +from colour import XYZ_to_Lab, xyY_to_XYZ +from colour.difference import delta_E_CIE2000 +from colour.notation.datasets.munsell import MUNSELL_COLOURS_REAL +from colour.notation.munsell import ( + CCS_ILLUMINANT_MUNSELL, + munsell_colour_to_munsell_specification, + munsell_specification_to_xyY, +) + +from learning_munsell import PROJECT_ROOT + +logging.basicConfig(level=logging.INFO, format="%(message)s") +LOGGER = logging.getLogger(__name__) + +HUE_NAMES = { + 1: "R", + 2: "YR", + 3: "Y", + 4: "GY", + 5: "G", + 6: "BG", + 7: "B", + 8: "PB", + 9: "P", + 10: "RP", + 0: "RP", +} + + +def load_model_and_params(model_name: str): + """Load ONNX model and normalization parameters.""" + model_dir = PROJECT_ROOT / "models" / "from_xyY" + + model_path = model_dir / f"{model_name}.onnx" + params_path = model_dir / f"{model_name}_normalization_params.npz" + + if not model_path.exists(): + raise FileNotFoundError(f"Model not found: {model_path}") + if not params_path.exists(): + raise FileNotFoundError(f"Params not found: {params_path}") + + session = ort.InferenceSession(str(model_path)) + params = np.load(params_path, allow_pickle=True) + input_params = params["input_params"].item() + output_params = params["output_params"].item() + + return session, input_params, output_params + + +def normalize_input(xyY: np.ndarray, params: dict) -> np.ndarray: + """Normalize xyY input.""" + normalized = np.copy(xyY).astype(np.float32) + # Scale Y from 0-100 to 0-1 range before normalization + normalized[..., 2] = xyY[..., 2] / 100.0 + normalized[..., 0] = (xyY[..., 0] - params["x_range"][0]) / ( + params["x_range"][1] - params["x_range"][0] + ) + normalized[..., 1] = (xyY[..., 1] - params["y_range"][0]) / ( + params["y_range"][1] - params["y_range"][0] + ) + normalized[..., 2] = (normalized[..., 2] - params["Y_range"][0]) / ( + params["Y_range"][1] - params["Y_range"][0] + ) + return normalized + + +def denormalize_output(pred: np.ndarray, params: dict) -> np.ndarray: + """Denormalize Munsell output.""" + denorm = np.copy(pred) + denorm[..., 0] = ( + pred[..., 0] * (params["hue_range"][1] - params["hue_range"][0]) + + params["hue_range"][0] + ) + denorm[..., 1] = ( + pred[..., 1] * (params["value_range"][1] - params["value_range"][0]) + + params["value_range"][0] + ) + denorm[..., 2] = ( + pred[..., 2] * (params["chroma_range"][1] - params["chroma_range"][0]) + + params["chroma_range"][0] + ) + denorm[..., 3] = ( + pred[..., 3] * (params["code_range"][1] - params["code_range"][0]) + + params["code_range"][0] + ) + return denorm + + +def compute_delta_e(pred_spec: np.ndarray, gt_xyY: np.ndarray) -> float: + """Compute Delta-E between predicted spec (via xyY) and ground truth xyY.""" + try: + pred_xyY = munsell_specification_to_xyY(pred_spec) + pred_XYZ = xyY_to_XYZ(pred_xyY) + pred_Lab = XYZ_to_Lab(pred_XYZ, CCS_ILLUMINANT_MUNSELL) + + # Ground truth Y is in 0-100 range, need to scale to 0-1 + gt_xyY_scaled = gt_xyY.copy() + gt_xyY_scaled[2] = gt_xyY[2] / 100.0 + gt_XYZ = xyY_to_XYZ(gt_xyY_scaled) + gt_Lab = XYZ_to_Lab(gt_XYZ, CCS_ILLUMINANT_MUNSELL) + + return delta_E_CIE2000(gt_Lab, pred_Lab) + except Exception: + return np.nan + + +def analyze_errors(model_name: str = "multi_head_large", threshold: float = 3.0): + """Analyze error distribution for a model.""" + LOGGER.info("=" * 80) + LOGGER.info("Error Analysis for %s", model_name) + LOGGER.info("=" * 80) + + # Load model + session, input_params, output_params = load_model_and_params(model_name) + input_name = session.get_inputs()[0].name + + # Collect data + results = [] + + for munsell_spec_tuple, xyY_gt in MUNSELL_COLOURS_REAL: + hue_code_str, value, chroma = munsell_spec_tuple + munsell_str = f"{hue_code_str} {value}/{chroma}" + + try: + gt_spec = munsell_colour_to_munsell_specification(munsell_str) + gt_xyY = np.array(xyY_gt) + + # Predict + xyY_norm = normalize_input(gt_xyY.reshape(1, 3), input_params) + pred_norm = session.run(None, {input_name: xyY_norm})[0] + pred_spec = denormalize_output(pred_norm, output_params)[0] + + # Clamp to valid ranges + pred_spec[0] = np.clip(pred_spec[0], 0.5, 10.0) + pred_spec[1] = np.clip(pred_spec[1], 1.0, 9.0) + pred_spec[2] = np.clip(pred_spec[2], 0.0, 50.0) + pred_spec[3] = np.clip(pred_spec[3], 1.0, 10.0) + pred_spec[3] = np.round(pred_spec[3]) + + # Compute Delta-E + delta_e = compute_delta_e(pred_spec, gt_xyY) + + if not np.isnan(delta_e): + results.append({ + "munsell_str": munsell_str, + "gt_spec": gt_spec, + "pred_spec": pred_spec, + "delta_e": delta_e, + "hue": gt_spec[0], + "value": gt_spec[1], + "chroma": gt_spec[2], + "code": int(gt_spec[3]), + "gt_xyY": gt_xyY, + }) + except Exception as e: + LOGGER.warning("Failed for %s: %s", munsell_str, e) + + LOGGER.info("\nTotal samples evaluated: %d", len(results)) + + # Overall statistics + delta_es = [r["delta_e"] for r in results] + LOGGER.info("\nOverall Delta-E Statistics:") + LOGGER.info(" Mean: %.4f", np.mean(delta_es)) + LOGGER.info(" Median: %.4f", np.median(delta_es)) + LOGGER.info(" Std: %.4f", np.std(delta_es)) + LOGGER.info(" Min: %.4f", np.min(delta_es)) + LOGGER.info(" Max: %.4f", np.max(delta_es)) + + # Distribution + LOGGER.info("\nDelta-E Distribution:") + for thresh in [1.0, 2.0, 3.0, 5.0, 10.0]: + count = sum(1 for d in delta_es if d <= thresh) + pct = 100 * count / len(delta_es) + LOGGER.info(" <= %.1f: %4d (%.1f%%)", thresh, count, pct) + + # High error samples + high_error = [r for r in results if r["delta_e"] > threshold] + LOGGER.info("\nSamples with Delta-E > %.1f: %d (%.1f%%)", + threshold, len(high_error), 100 * len(high_error) / len(results)) + + # Analyze by hue family + LOGGER.info("\n" + "=" * 40) + LOGGER.info("Analysis by Hue Family") + LOGGER.info("=" * 40) + + by_hue = defaultdict(list) + for r in results: + hue_name = HUE_NAMES.get(r["code"], f"?{r['code']}") + by_hue[hue_name].append(r["delta_e"]) + + LOGGER.info("\n%-4s %5s %6s %6s %6s %s", + "Hue", "Count", "Mean", "Median", "Max", ">3.0") + for hue_name in ["R", "YR", "Y", "GY", "G", "BG", "B", "PB", "P", "RP"]: + if hue_name in by_hue: + des = by_hue[hue_name] + high = sum(1 for d in des if d > 3.0) + LOGGER.info("%-4s %5d %6.2f %6.2f %6.2f %d (%.0f%%)", + hue_name, len(des), np.mean(des), np.median(des), + np.max(des), high, 100*high/len(des)) + + # Analyze by value range + LOGGER.info("\n" + "=" * 40) + LOGGER.info("Analysis by Value Range") + LOGGER.info("=" * 40) + + value_ranges = [(1, 3), (3, 5), (5, 7), (7, 9)] + LOGGER.info("\n%-8s %5s %6s %6s %6s %s", + "Value", "Count", "Mean", "Median", "Max", ">3.0") + for v_min, v_max in value_ranges: + des = [r["delta_e"] for r in results if v_min <= r["value"] < v_max] + if des: + high = sum(1 for d in des if d > 3.0) + LOGGER.info("[%d-%d) %5d %6.2f %6.2f %6.2f %d (%.0f%%)", + v_min, v_max, len(des), np.mean(des), np.median(des), + np.max(des), high, 100*high/len(des) if des else 0) + + # Analyze by chroma range + LOGGER.info("\n" + "=" * 40) + LOGGER.info("Analysis by Chroma Range") + LOGGER.info("=" * 40) + + chroma_ranges = [(0, 4), (4, 8), (8, 12), (12, 20), (20, 50)] + LOGGER.info("\n%-8s %5s %6s %6s %6s %s", + "Chroma", "Count", "Mean", "Median", "Max", ">3.0") + for c_min, c_max in chroma_ranges: + des = [r["delta_e"] for r in results if c_min <= r["chroma"] < c_max] + if des: + high = sum(1 for d in des if d > 3.0) + LOGGER.info("[%2d-%2d) %5d %6.2f %6.2f %6.2f %d (%.0f%%)", + c_min, c_max, len(des), np.mean(des), np.median(des), + np.max(des), high, 100*high/len(des) if des else 0) + + # Top 20 worst samples + LOGGER.info("\n" + "=" * 40) + LOGGER.info("Top 20 Worst Samples") + LOGGER.info("=" * 40) + + worst = sorted(results, key=lambda r: r["delta_e"], reverse=True)[:20] + LOGGER.info("\n%-15s %6s %-20s %-20s", + "Munsell", "DeltaE", "GT Spec", "Pred Spec") + for r in worst: + gt = f"[{r['gt_spec'][0]:.1f}, {r['gt_spec'][1]:.1f}, {r['gt_spec'][2]:.1f}, {int(r['gt_spec'][3])}]" + pred = f"[{r['pred_spec'][0]:.1f}, {r['pred_spec'][1]:.1f}, {r['pred_spec'][2]:.1f}, {int(r['pred_spec'][3])}]" + LOGGER.info("%-15s %6.2f %-20s %-20s", + r["munsell_str"], r["delta_e"], gt, pred) + + # Analyze component errors for high-error samples + LOGGER.info("\n" + "=" * 40) + LOGGER.info("Component Errors for High-Error Samples (Delta-E > %.1f)", threshold) + LOGGER.info("=" * 40) + + if high_error: + hue_errors = [abs(r["pred_spec"][0] - r["gt_spec"][0]) for r in high_error] + value_errors = [abs(r["pred_spec"][1] - r["gt_spec"][1]) for r in high_error] + chroma_errors = [abs(r["pred_spec"][2] - r["gt_spec"][2]) for r in high_error] + code_errors = [abs(r["pred_spec"][3] - r["gt_spec"][3]) for r in high_error] + + LOGGER.info("\n%-10s %6s %6s %6s", + "Component", "Mean", "Median", "Max") + LOGGER.info("%-10s %6.2f %6.2f %6.2f", "Hue", + np.mean(hue_errors), np.median(hue_errors), np.max(hue_errors)) + LOGGER.info("%-10s %6.2f %6.2f %6.2f", "Value", + np.mean(value_errors), np.median(value_errors), np.max(value_errors)) + LOGGER.info("%-10s %6.2f %6.2f %6.2f", "Chroma", + np.mean(chroma_errors), np.median(chroma_errors), np.max(chroma_errors)) + LOGGER.info("%-10s %6.2f %6.2f %6.2f", "Code", + np.mean(code_errors), np.median(code_errors), np.max(code_errors)) + + return results + + +def main(): + """Run error analysis.""" + # Try the best models + models = [ + "multi_head_large", + ] + + for model_name in models: + try: + analyze_errors(model_name, threshold=3.0) + except FileNotFoundError as e: + LOGGER.warning("Skipping %s: %s", model_name, e) + LOGGER.info("\n") + + +if __name__ == "__main__": + main() diff --git a/learning_munsell/comparison/from_xyY/__init__.py b/learning_munsell/comparison/from_xyY/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6c11c4d7bb304179319eaa99da19a13219485a02 --- /dev/null +++ b/learning_munsell/comparison/from_xyY/__init__.py @@ -0,0 +1 @@ +"""Comparison scripts for xyY to Munsell conversion models.""" diff --git a/learning_munsell/comparison/from_xyY/compare_all_models.py b/learning_munsell/comparison/from_xyY/compare_all_models.py new file mode 100644 index 0000000000000000000000000000000000000000..2a19ecab84e7ac38046110f427f64f9119f4a897 --- /dev/null +++ b/learning_munsell/comparison/from_xyY/compare_all_models.py @@ -0,0 +1,1292 @@ +""" +Compare all ML models for xyY to Munsell conversion on real Munsell data. + +Models to compare: +1. MLP (Base only) +2. MLP + Error Predictor (Two-stage) +3. Unified MLP +4. MLP + Self-Attention +5. MLP + Self-Attention + Error Predictor +6. Deep + Wide +7. Mixture of Experts +8. FT-Transformer +""" + +import logging +import time +import warnings +from datetime import datetime +from pathlib import Path +from typing import Any + +import numpy as np +import onnxruntime as ort +from colour import XYZ_to_Lab, xyY_to_XYZ +from colour.difference import delta_E_CIE2000 +from colour.notation.datasets.munsell import MUNSELL_COLOURS_REAL +from colour.notation.munsell import ( + CCS_ILLUMINANT_MUNSELL, + munsell_colour_to_munsell_specification, + munsell_specification_to_xyY, + xyY_to_munsell_specification, +) +from numpy.typing import NDArray + +from learning_munsell import PROJECT_ROOT +from learning_munsell.utilities.common import ( + benchmark_inference_speed, + get_model_size_mb, +) + +logging.basicConfig(level=logging.INFO, format="%(message)s") +LOGGER = logging.getLogger(__name__) + + +def normalize_input(X: NDArray, params: dict[str, Any] | None) -> NDArray: + """Normalize xyY input. + + If params is None, xyY is assumed to already be in [0, 1] range (no normalization needed). + """ + if params is None: + # xyY is already in [0, 1] range - no normalization needed + return X.astype(np.float32) + + X_norm = np.copy(X) + X_norm[..., 0] = (X[..., 0] - params["x_range"][0]) / ( + params["x_range"][1] - params["x_range"][0] + ) + X_norm[..., 1] = (X[..., 1] - params["y_range"][0]) / ( + params["y_range"][1] - params["y_range"][0] + ) + X_norm[..., 2] = (X[..., 2] - params["Y_range"][0]) / ( + params["Y_range"][1] - params["Y_range"][0] + ) + return X_norm.astype(np.float32) + + +def denormalize_output(y_norm: NDArray, params: dict[str, Any]) -> NDArray: + """Denormalize Munsell output.""" + y = np.copy(y_norm) + y[..., 0] = ( + y_norm[..., 0] * (params["hue_range"][1] - params["hue_range"][0]) + + params["hue_range"][0] + ) + y[..., 1] = ( + y_norm[..., 1] * (params["value_range"][1] - params["value_range"][0]) + + params["value_range"][0] + ) + y[..., 2] = ( + y_norm[..., 2] * (params["chroma_range"][1] - params["chroma_range"][0]) + + params["chroma_range"][0] + ) + y[..., 3] = ( + y_norm[..., 3] * (params["code_range"][1] - params["code_range"][0]) + + params["code_range"][0] + ) + return y + + +def clamp_munsell_specification(specification: NDArray) -> NDArray: + """Clamp Munsell specification to valid ranges.""" + + clamped = np.copy(specification) + clamped[..., 0] = np.clip(specification[..., 0], 0.0, 10.0) # Hue: [0, 10] + clamped[..., 1] = np.clip(specification[..., 1], 1.0, 9.0) # Value: [1, 9] (colour library constraint) + clamped[..., 2] = np.clip(specification[..., 2], 0.0, 50.0) # Chroma: [0, 50] + clamped[..., 3] = np.clip(specification[..., 3], 1.0, 10.0) # Code: [1, 10] + + return clamped + + +def evaluate_model( + session: ort.InferenceSession, + X_norm: NDArray, + ground_truth: NDArray, + params: dict[str, Any], + input_name: str = "xyY", + reference_Lab: NDArray | None = None, +) -> dict[str, Any]: + """Evaluate a single model.""" + pred_norm = session.run(None, {input_name: X_norm})[0] + pred = denormalize_output(pred_norm, params) + errors = np.abs(pred - ground_truth) + + result = { + "hue_mae": np.mean(errors[:, 0]), + "value_mae": np.mean(errors[:, 1]), + "chroma_mae": np.mean(errors[:, 2]), + "code_mae": np.mean(errors[:, 3]), + "max_errors": np.max(errors, axis=1), + "hue_errors": errors[:, 0], + "value_errors": errors[:, 1], + "chroma_errors": errors[:, 2], + "code_errors": errors[:, 3], + } + + # Compute Delta-E against ground truth + if reference_Lab is not None: + delta_E_values = [] + for idx in range(len(pred)): + try: + # Convert ML prediction to Lab: Munsell spec → xyY → XYZ → Lab + ml_spec = clamp_munsell_specification(pred[idx]) + + # Round Code to nearest integer before round-trip conversion + ml_spec_for_conversion = ml_spec.copy() + ml_spec_for_conversion[3] = round(ml_spec[3]) + + ml_xyy = munsell_specification_to_xyY(ml_spec_for_conversion) + ml_XYZ = xyY_to_XYZ(ml_xyy) + ml_Lab = XYZ_to_Lab(ml_XYZ, CCS_ILLUMINANT_MUNSELL) + + delta_E = delta_E_CIE2000(reference_Lab[idx], ml_Lab) + delta_E_values.append(delta_E) + except (RuntimeError, ValueError): + # Skip if conversion fails + continue + + result["delta_E"] = np.mean(delta_E_values) if delta_E_values else np.nan + else: + result["delta_E"] = np.nan + + return result + + +def generate_html_report( + results: dict[str, dict[str, Any]], + num_samples: int, + output_file: Path, + baseline_inference_time_ms: float | None = None, +) -> None: + """Generate HTML report with visualizations.""" + # Calculate metrics + avg_maes = {} + for model_name, result in results.items(): + avg_maes[model_name] = np.mean( + [ + result["hue_mae"], + result["value_mae"], + result["chroma_mae"], + result["code_mae"], + ] + ) + + # Sort by average MAE + sorted_models = sorted(avg_maes.items(), key=lambda x: x[1]) + + # Precision thresholds + thresholds = [1e-4, 1e-3, 1e-2, 1e-1, 1.0] + + html = f""" + + + + + ML Model Comparison Report - {datetime.now().strftime("%Y-%m-%d %H:%M")} + + + + + +
+ +
+

ML Model Comparison Report

+
+

xyY to Munsell Specification Conversion

+

Generated: {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}

+

Test Samples: {num_samples:,} real Munsell colors

+
+
+""" + + # Best Models Summary (FIRST - moved to top) + # Find best models for each metric + delta_E_values = [ + r["delta_E"] for r in results.values() if not np.isnan(r["delta_E"]) + ] + + best_delta_E = ( + min( + results.items(), + key=lambda x: x[1]["delta_E"] + if not np.isnan(x[1]["delta_E"]) + else float("inf"), + )[0] + if delta_E_values + else None + ) + best_avg = sorted_models[0][0] + + # Performance Metrics Table (FIRST - as summary) + # Find best for each metric + best_size = min(results.items(), key=lambda x: x[1]["model_size_mb"])[0] + best_speed = min(results.items(), key=lambda x: x[1]["inference_time_ms"])[0] + + # Add Best Models Summary HTML + html += f""" + +
+

Best Models by Metric

+
+
+
Smallest Size
+
{results[best_size]["model_size_mb"]:.2f} MB
+
{best_size}
+
+
+
Fastest Speed
+
{results[best_speed]["inference_time_ms"]:.4f} ms
+
{best_speed}
+
+
+
Best Delta-E
+
{results[best_delta_E]["delta_E"]:.4f}
+
{best_delta_E}
+
+
+
Best Average MAE
+
{avg_maes[best_avg]:.4f}
+
{best_avg}
+
+
+
+""" + + # Get baseline speed (Colour Library Iterative) + baseline_speed = baseline_inference_time_ms + + # Sort by Delta-E for performance table (best first) + sorted_by_delta_E = sorted( + results.items(), + key=lambda x: x[1]["delta_E"] + if not np.isnan(x[1]["delta_E"]) + else float("inf"), + ) + + # Calculate maximum speed multiplier (fastest model) for highlighting + max_speed_multiplier = 0.0 + best_multiplier_model = None + for model_name, result in results.items(): + speed_ms = result["inference_time_ms"] + if speed_ms > 0: + speed_multiplier = baseline_speed / speed_ms + if speed_multiplier > max_speed_multiplier: + max_speed_multiplier = speed_multiplier + best_multiplier_model = model_name + + html += """ + +
+

Model Performance Metrics

+
+ + + + + + + + + + + + +""" + + for model_name, result in sorted_by_delta_E: + size_mb = result["model_size_mb"] + speed_ms = result["inference_time_ms"] + avg_mae = avg_maes[model_name] + delta_E = result["delta_E"] + + # Calculate relative speed (how many times faster than baseline) + speed_multiplier = baseline_speed / speed_ms if speed_ms > 0 else 0 + + size_class = "text-primary font-semibold" if model_name == best_size else "" + speed_class = "text-primary font-semibold" if model_name == best_speed else "" + avg_class = "text-primary font-semibold" if model_name == best_avg else "" + delta_E_class = ( + "text-primary font-semibold" if model_name == best_delta_E else "" + ) + + # Format Delta-E value + delta_E_str = f"{delta_E:.4f}" if not np.isnan(delta_E) else "—" + + # Highlight only the fastest model + if abs(speed_multiplier - 1.0) < 0.01: + # Baseline + multiplier_class = "text-muted-foreground" + multiplier_text = "1.0x" + elif model_name == best_multiplier_model: + # Fastest model (highest multiplier) + multiplier_class = "text-primary font-semibold" + if speed_multiplier > 1000: + multiplier_text = f"{speed_multiplier:.0f}x" + elif speed_multiplier > 100: + multiplier_text = f"{speed_multiplier:.1f}x" + else: + multiplier_text = f"{speed_multiplier:.2f}x" + elif speed_multiplier > 1.0: + # Faster than baseline but not the fastest + multiplier_class = "" + if speed_multiplier > 1000: + multiplier_text = f"{speed_multiplier:.0f}x" + elif speed_multiplier > 100: + multiplier_text = f"{speed_multiplier:.1f}x" + else: + multiplier_text = f"{speed_multiplier:.2f}x" + else: + # Slower than baseline + multiplier_class = "text-destructive" + multiplier_text = f"{speed_multiplier:.2f}x" + + html += f""" + + + + + + + + +""" + + html += """ + +
Model + Size (MB) +
ONNX files
+
+ Speed (ms/sample) +
10 iterations
+
+ vs Baseline +
Colour Iterative
+
+ Delta-E +
vs Colour Lib
+
Average MAE
{model_name}{size_mb:.2f}{speed_ms:.4f}{multiplier_text}{delta_E_str}{avg_mae:.4f}
+
+
+
+
Note: Speed measured with 10 iterations (3 warmup + 10 benchmark) on 2,734 samples.
+
Two-stage models include both base and error predictor. Highlighted values show best in each metric.
+
Baseline comparison: Speed multipliers show relative performance vs Colour Library's iterative xyY_to_munsell_specification(). Values <1.0x are faster.
+
+
+
+""" + + # Overall ranking by Delta-E + html += """ + +
+

Overall Ranking (by Delta-E)

+
+""" + + # Sort by Delta-E (best = lowest) + sorted_by_delta_E_ranking = sorted( + [ + (name, res["delta_E"]) + for name, res in results.items() + if not np.isnan(res["delta_E"]) + ], + key=lambda x: x[1], + ) + + max_delta_E = ( + max(delta_E for _, delta_E in sorted_by_delta_E_ranking) + if sorted_by_delta_E_ranking + else 1.0 + ) + for rank, (model_name, delta_E) in enumerate(sorted_by_delta_E_ranking, 1): + width_pct = (delta_E / max_delta_E) * 100 + html += f""" +
+
+ {rank}. {model_name} +
+
+
+
+
{delta_E:.4f}
+
+""" + + html += """ +
+
+""" + + # Precision Threshold Table + html += """ +
+

Accuracy at Precision Thresholds

+

Percentage of predictions where max error across all components is below threshold:

+
+ + + + +""" + + for threshold in thresholds: + html += f' \n' + + html += """ + + + +""" + + # Find best (highest) accuracy for each threshold column + best_accuracies = {} + min_accuracies = {} + for threshold in thresholds: + accuracies = [ + np.mean(results[model_name]["max_errors"] < threshold) * 100 + for model_name, _ in sorted_models + ] + best_accuracies[threshold] = max(accuracies) + min_accuracies[threshold] = min(accuracies) + + for model_name, _ in sorted_models: + result = results[model_name] + row_class = ( + "bg-primary/10 border-l-2 border-l-primary" + if model_name == best_avg + else "" + ) + html += f""" + + +""" + for threshold in thresholds: + accuracy_pct = np.mean(result["max_errors"] < threshold) * 100 + # Only highlight if there's meaningful variation + # (>0.1% difference between best and worst) + has_variation = ( + best_accuracies[threshold] - min_accuracies[threshold] + ) > 0.1 + is_best = abs(accuracy_pct - best_accuracies[threshold]) < 0.01 + cell_class = ( + "text-right py-3 px-4 font-bold text-primary" + if (has_variation and is_best) + else "text-right py-3 px-4" + ) + html += f' \n' + + html += """ + +""" + + html += """ + +
Model< {threshold:.0e}
{model_name}{accuracy_pct:.2f}%
+
+
+ +
+ + +""" + + # Write HTML file + with open(output_file, "w") as f: + f.write(html) + + LOGGER.info("") + LOGGER.info("HTML report saved to: %s", output_file) + + +def main() -> None: + """Compare all models.""" + LOGGER.info("=" * 80) + LOGGER.info("Comprehensive Model Comparison") + LOGGER.info("=" * 80) + + # Paths + model_directory = PROJECT_ROOT / "models" / "from_xyY" + + # Load real Munsell dataset + LOGGER.info("") + LOGGER.info("Loading real Munsell dataset...") + xyY_samples = [] + ground_truth = [] + + for munsell_spec_tuple, xyY in MUNSELL_COLOURS_REAL: + try: + hue_code, value, chroma = munsell_spec_tuple + munsell_str = f"{hue_code} {value}/{chroma}" + spec = munsell_colour_to_munsell_specification(munsell_str) + xyY_scaled = np.array([xyY[0], xyY[1], xyY[2] / 100.0]) + xyY_samples.append(xyY_scaled) + ground_truth.append(spec) + except Exception: # noqa: BLE001, S112 + continue + + xyY_samples = np.array(xyY_samples) + ground_truth = np.array(ground_truth) + LOGGER.info("Loaded %d valid Munsell colors", len(xyY_samples)) + + # Define models to compare + models = [ + { + "name": "MLP (Base Only)", + "files": [model_directory / "mlp.onnx"], + "params_file": model_directory / "mlp_normalization_params.npz", + "type": "single", + }, + { + "name": "MLP + Error Predictor", + "files": [ + model_directory / "mlp.onnx", + model_directory / "mlp_error_predictor.onnx", + ], + "params_file": model_directory / "mlp_normalization_params.npz", + "type": "two_stage", + }, + { + "name": "Unified MLP", + "files": [model_directory / "unified_mlp.onnx"], + "params_file": model_directory / "unified_mlp_normalization_params.npz", + "type": "single", + }, + { + "name": "MLP + Self-Attention", + "files": [model_directory / "mlp_attention.onnx"], + "params_file": model_directory + / "mlp_attention_normalization_params.npz", + "type": "single", + }, + { + "name": "MLP + Self-Attention + Error Predictor", + "files": [ + model_directory / "mlp_attention.onnx", + model_directory / "mlp_attention_error_predictor.onnx", + ], + "params_file": model_directory + / "mlp_attention_normalization_params.npz", + "type": "two_stage", + }, + { + "name": "Deep + Wide", + "files": [model_directory / "deep_wide.onnx"], + "params_file": model_directory / "deep_wide_normalization_params.npz", + "type": "single", + }, + { + "name": "Mixture of Experts", + "files": [model_directory / "mixture_of_experts.onnx"], + "params_file": model_directory + / "mixture_of_experts_normalization_params.npz", + "type": "single", + }, + { + "name": "FT-Transformer", + "files": [model_directory / "ft_transformer.onnx"], + "params_file": model_directory / "ft_transformer_normalization_params.npz", + "type": "single", + }, + { + "name": "Multi-Head", + "files": [model_directory / "multi_head.onnx"], + "params_file": model_directory / "multi_head_normalization_params.npz", + "type": "single", + }, + { + "name": "Multi-Head (Optimized)", + "files": [model_directory / "multi_head_optimized.onnx"], + "params_file": model_directory / "multi_head_optimized_normalization_params.npz", + "type": "single", + }, + { + "name": "Multi-Head + Error Predictor", + "files": [ + model_directory / "multi_head.onnx", + model_directory / "multi_head_error_predictor.onnx", + ], + "params_file": model_directory / "multi_head_normalization_params.npz", + "type": "two_stage", + }, + { + "name": "Multi-MLP", + "files": [model_directory / "multi_mlp.onnx"], + "params_file": model_directory / "multi_mlp_normalization_params.npz", + "type": "single", + }, + { + "name": "Multi-MLP + Error Predictor", + "files": [ + model_directory / "multi_mlp.onnx", + model_directory / "multi_mlp_error_predictor.onnx", + ], + "params_file": model_directory / "multi_mlp_normalization_params.npz", + "type": "two_stage", + }, + { + "name": "Multi-MLP + Multi-Error Predictor", + "files": [ + model_directory / "multi_mlp.onnx", + model_directory / "multi_mlp_multi_error_predictor.onnx", + ], + "params_file": model_directory / "multi_mlp_normalization_params.npz", + "type": "two_stage", + }, + { + "name": "Multi-MLP + Multi-Error Predictor (Optimized)", + "files": [ + model_directory / "multi_mlp.onnx", + model_directory / "multi_mlp_multi_error_predictor_optimized.onnx", + ], + "params_file": model_directory / "multi_mlp_normalization_params.npz", + "type": "two_stage", + }, + { + "name": "Multi-MLP (Optimized)", + "files": [model_directory / "multi_mlp_optimized.onnx"], + "params_file": model_directory / "multi_mlp_optimized_normalization_params.npz", + "type": "single", + }, + { + "name": "Multi-Head + Multi-Error Predictor", + "files": [ + model_directory / "multi_head.onnx", + model_directory / "multi_head_multi_error_predictor.onnx", + ], + "params_file": model_directory / "multi_head_normalization_params.npz", + "type": "two_stage", + }, + { + "name": "Multi-Head + Cross-Attention Error Predictor", + "files": [ + model_directory / "multi_head.onnx", + model_directory / "multi_head_cross_attention_error_predictor.onnx", + ], + "params_file": model_directory / "multi_head_normalization_params.npz", + "type": "two_stage", + }, + { + "name": "Multi-Head (Optimized) + Multi-Error Predictor (Optimized)", + "files": [ + model_directory / "multi_head_optimized.onnx", + model_directory / "multi_head_error_predictor_optimized.onnx", + ], + "params_file": model_directory / "multi_head_optimized_normalization_params.npz", + "type": "two_stage", + }, + { + "name": "Multi-Head (Circular Loss)", + "files": [model_directory / "multi_head_circular.onnx"], + "params_file": model_directory / "multi_head_circular_normalization_params.npz", + "type": "single", + }, + { + "name": "Multi-Head (Large Dataset)", + "files": [model_directory / "multi_head_large.onnx"], + "params_file": model_directory / "multi_head_large_normalization_params.npz", + "type": "single", + }, + { + "name": "Multi-Head + Multi-Error Predictor (Large Dataset)", + "files": [ + model_directory / "multi_head_large.onnx", + model_directory / "multi_head_multi_error_predictor_large.onnx", + ], + "params_file": model_directory / "multi_head_large_normalization_params.npz", + "type": "two_stage", + }, + { + "name": "Multi-MLP (Large Dataset)", + "files": [model_directory / "multi_mlp_large.onnx"], + "params_file": model_directory / "multi_mlp_large_normalization_params.npz", + "type": "single", + }, + { + "name": "Multi-MLP + Multi-Error Predictor (Large Dataset)", + "files": [ + model_directory / "multi_mlp_large.onnx", + model_directory / "multi_mlp_multi_error_predictor_large.onnx", + ], + "params_file": model_directory / "multi_mlp_large_normalization_params.npz", + "type": "two_stage", + }, + { + "name": "Transformer (Large Dataset)", + "files": [model_directory / "transformer_large.onnx"], + "params_file": model_directory / "transformer_large_normalization_params.npz", + "type": "single", + }, + { + "name": "Transformer + Error Predictor (Large Dataset)", + "files": [ + model_directory / "transformer_large.onnx", + model_directory / "transformer_multi_error_predictor_large.onnx", + ], + "params_file": model_directory / "transformer_large_normalization_params.npz", + "type": "two_stage", + }, + { + "name": "Multi-Head Refined (REAL Only)", + "files": [model_directory / "multi_head_refined_real.onnx"], + "params_file": model_directory / "multi_head_refined_real_normalization_params.npz", + "type": "single", + }, + { + "name": "Multi-Head Refined + Error Predictor (REAL Only)", + "files": [ + model_directory / "multi_head_refined_real.onnx", + model_directory / "multi_head_multi_error_predictor_refined_real.onnx", + ], + "params_file": model_directory / "multi_head_refined_real_normalization_params.npz", + "type": "two_stage", + }, + { + "name": "Multi-Head + Multi-Error Predictor + Multi-Error Predictor (3-Stage)", + "files": [ + model_directory / "multi_head_large.onnx", + model_directory / "multi_head_multi_error_predictor_large.onnx", + model_directory / "multi_head_3stage_error_predictor.onnx", + ], + "params_file": model_directory / "multi_head_large_normalization_params.npz", + "type": "three_stage", + }, + { + "name": "Multi-Head (Weighted + Boundary Loss)", + "files": [model_directory / "multi_head_weighted_boundary.onnx"], + "params_file": model_directory / "multi_head_weighted_boundary_normalization_params.npz", + "type": "single", + }, + { + "name": "Multi-Head (Weighted + Boundary Loss) + Multi-Error Predictor", + "files": [ + model_directory / "multi_head_weighted_boundary.onnx", + model_directory / "multi_head_weighted_boundary_multi_error_predictor.onnx", + ], + "params_file": model_directory / "multi_head_weighted_boundary_normalization_params.npz", + "type": "two_stage", + }, + { + "name": "Multi-Head (Weighted + Boundary Loss) + Multi-Error Predictor (Weighted + Boundary Loss)", + "files": [ + model_directory / "multi_head_weighted_boundary.onnx", + model_directory / "multi_head_weighted_boundary_multi_error_predictor_weighted_boundary.onnx", + ], + "params_file": model_directory / "multi_head_weighted_boundary_normalization_params.npz", + "type": "two_stage", + }, + { + "name": "Multi-MLP (Weighted + Boundary Loss) (Large Dataset)", + "files": [model_directory / "multi_mlp_weighted_boundary.onnx"], + "params_file": model_directory / "multi_mlp_weighted_boundary_normalization_params.npz", + "type": "single", + }, + { + "name": "Multi-MLP (Weighted + Boundary Loss) + Multi-Error Predictor (Weighted + Boundary Loss) (Large Dataset)", + "files": [ + model_directory / "multi_mlp_weighted_boundary.onnx", + model_directory / "multi_mlp_weighted_boundary_multi_error_predictor.onnx", + ], + "params_file": model_directory / "multi_mlp_weighted_boundary_normalization_params.npz", + "type": "two_stage", + }, + { + "name": "Multi-ResNet (Large Dataset)", + "files": [model_directory / "multi_resnet_large.onnx"], + "params_file": model_directory / "multi_resnet_large_normalization_params.npz", + "type": "single", + }, + { + "name": "Multi-ResNet + Multi-Error Predictor (Large Dataset)", + "files": [ + model_directory / "multi_resnet_large.onnx", + model_directory / "multi_resnet_error_predictor_large.onnx", + ], + "params_file": model_directory / "multi_resnet_large_normalization_params.npz", + "type": "two_stage", + }, + ] + + # Benchmark colour library's iterative implementation first + LOGGER.info("") + LOGGER.info("=" * 80) + LOGGER.info("Colour Library (Iterative)") + LOGGER.info("=" * 80) + + # Benchmark the iterative xyY_to_munsell_specification function + # Note: Using full dataset (100% of samples) + + # Set random seed for reproducibility + np.random.seed(42) + + # Use 100% of samples for comprehensive benchmarking + sample_count = len(xyY_samples) + sampled_indices = np.arange(len(xyY_samples)) + xyY_benchmark_samples = xyY_samples[sampled_indices] + + # Measure inference time on sampled Munsell colors + start_time = time.perf_counter() + convergence_failures = 0 + successful_inferences = 0 + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + for xyy in xyY_benchmark_samples: + try: + xyY_to_munsell_specification(xyy) + successful_inferences += 1 + except (RuntimeError, ValueError): + # Out-of-gamut color that doesn't converge or not in renotation system + convergence_failures += 1 + + end_time = time.perf_counter() + + # Calculate average time per successful inference (in milliseconds) + total_time_s = end_time - start_time + colour_inference_time_ms = ( + (total_time_s / successful_inferences) * 1000 + if successful_inferences > 0 + else 0 + ) + + LOGGER.info("") + LOGGER.info("Performance Metrics:") + LOGGER.info(" Successful inferences: %d", successful_inferences) + LOGGER.info(" Convergence failures: %d", convergence_failures) + LOGGER.info(" Inference Speed: %.4f ms/sample", colour_inference_time_ms) + LOGGER.info(" Note: This is the baseline iterative implementation") + + # Store the baseline speed + baseline_inference_time_ms = colour_inference_time_ms + + # Convert ground truth Munsell specs to CIE Lab for Delta-E comparison + # Path: Munsell spec → xyY → XYZ → Lab + LOGGER.info("") + LOGGER.info( + "Converting ground truth to CIE Lab for Delta-E comparison..." + ) + LOGGER.info(" Path: Munsell spec \u2192 xyY \u2192 XYZ \u2192 Lab") + reference_Lab = [] + for spec in ground_truth: + try: + # Munsell specification → xyY + xyy = munsell_specification_to_xyY(spec) + # xyY → XYZ + XYZ = xyY_to_XYZ(xyy) + # XYZ → Lab (Illuminant C for Munsell) + Lab = XYZ_to_Lab(XYZ, CCS_ILLUMINANT_MUNSELL) + reference_Lab.append(Lab) + except (RuntimeError, ValueError): + # If conversion fails, use NaN + reference_Lab.append(np.array([np.nan, np.nan, np.nan])) + + reference_Lab = np.array(reference_Lab) + LOGGER.info( + " Converted %d ground truth specs to CIE Lab", + len(reference_Lab), + ) + + # Use the same sampled subset for ML model evaluations (for fair comparison) + xyY_samples = xyY_benchmark_samples + ground_truth = ground_truth[sampled_indices] + + # Evaluate each model + results = {} + + for model_info in models: + model_name = model_info["name"] + LOGGER.info("") + LOGGER.info("=" * 80) + LOGGER.info(model_name) + LOGGER.info("=" * 80) + + # Load normalization params for this model + params = np.load(model_info["params_file"], allow_pickle=True) + # input_params may not exist if xyY is already in [0, 1] range + input_params = ( + params["input_params"].item() + if "input_params" in params.files + else None + ) + output_params = params["output_params"].item() + + # Normalize input with this model's params (None means no normalization) + X_norm = normalize_input(xyY_samples, input_params) + + # Calculate model size + model_size_mb = get_model_size_mb(model_info["files"]) + + if model_info["type"] == "two_stage": + # Two-stage model + base_session = ort.InferenceSession(str(model_info["files"][0])) + error_session = ort.InferenceSession(str(model_info["files"][1])) + + # Define inference callable for benchmarking + def two_stage_inference( + _base_session: ort.InferenceSession = base_session, + _error_session: ort.InferenceSession = error_session, + _X_norm: NDArray = X_norm, + ) -> NDArray: + base_pred = _base_session.run(None, {"xyY": _X_norm})[0] + combined = np.concatenate([_X_norm, base_pred], axis=1).astype( + np.float32 + ) + error_corr = _error_session.run(None, {"combined_input": combined})[ + 0 + ] + return base_pred + error_corr + + # Benchmark speed + inference_time_ms = benchmark_inference_speed( + two_stage_inference, X_norm + ) + + # Get predictions + base_pred_norm = base_session.run(None, {"xyY": X_norm})[0] + combined_input = np.concatenate( + [X_norm, base_pred_norm], axis=1 + ).astype(np.float32) + error_correction_norm = error_session.run( + None, {"combined_input": combined_input} + )[0] + final_pred_norm = base_pred_norm + error_correction_norm + pred = denormalize_output(final_pred_norm, output_params) + errors = np.abs(pred - ground_truth) + + result = { + "hue_mae": np.mean(errors[:, 0]), + "value_mae": np.mean(errors[:, 1]), + "chroma_mae": np.mean(errors[:, 2]), + "code_mae": np.mean(errors[:, 3]), + "max_errors": np.max(errors, axis=1), + "hue_errors": errors[:, 0], + "value_errors": errors[:, 1], + "chroma_errors": errors[:, 2], + "code_errors": errors[:, 3], + "model_size_mb": model_size_mb, + "inference_time_ms": inference_time_ms, + } + + # Compute Delta-E against ground truth + delta_E_values = [] + for idx in range(len(pred)): + try: + # Convert ML prediction to Lab: Munsell spec → xyY → XYZ → Lab + ml_spec = clamp_munsell_specification(pred[idx]) + + # Round Code to nearest integer before round-trip conversion + ml_spec_for_conversion = ml_spec.copy() + ml_spec_for_conversion[3] = round(ml_spec[3]) + + ml_xyy = munsell_specification_to_xyY(ml_spec_for_conversion) + ml_XYZ = xyY_to_XYZ(ml_xyy) + ml_Lab = XYZ_to_Lab(ml_XYZ, CCS_ILLUMINANT_MUNSELL) + + # Get ground truth Lab + reference_Lab_sample = reference_Lab[idx] + + # Compute Delta-E CIE2000 + delta_E = delta_E_CIE2000(reference_Lab_sample, ml_Lab) + delta_E_values.append(delta_E) + except (RuntimeError, ValueError): + # Skip if conversion fails + continue + + result["delta_E"] = ( + np.mean(delta_E_values) if delta_E_values else np.nan + ) + elif model_info["type"] == "three_stage": + # Three-stage model: base + error predictor 1 + error predictor 2 + base_session = ort.InferenceSession(str(model_info["files"][0])) + error1_session = ort.InferenceSession(str(model_info["files"][1])) + error2_session = ort.InferenceSession(str(model_info["files"][2])) + + # Define inference callable for benchmarking + def three_stage_inference( + _base_session: ort.InferenceSession = base_session, + _error1_session: ort.InferenceSession = error1_session, + _error2_session: ort.InferenceSession = error2_session, + _X_norm: NDArray = X_norm, + ) -> NDArray: + # Stage 1: Base model + base_pred = _base_session.run(None, {"xyY": _X_norm})[0] + # Stage 2: First error correction + combined1 = np.concatenate([_X_norm, base_pred], axis=1).astype( + np.float32 + ) + error1_corr = _error1_session.run( + None, {"combined_input": combined1} + )[0] + stage2_pred = base_pred + error1_corr + # Stage 3: Second error correction + combined2 = np.concatenate([_X_norm, stage2_pred], axis=1).astype( + np.float32 + ) + error2_corr = _error2_session.run( + None, {"combined_input": combined2} + )[0] + return stage2_pred + error2_corr + + # Benchmark speed + inference_time_ms = benchmark_inference_speed( + three_stage_inference, X_norm + ) + + # Get predictions + base_pred_norm = base_session.run(None, {"xyY": X_norm})[0] + combined1 = np.concatenate([X_norm, base_pred_norm], axis=1).astype( + np.float32 + ) + error1_corr_norm = error1_session.run( + None, {"combined_input": combined1} + )[0] + stage2_pred_norm = base_pred_norm + error1_corr_norm + combined2 = np.concatenate([X_norm, stage2_pred_norm], axis=1).astype( + np.float32 + ) + error2_corr_norm = error2_session.run( + None, {"combined_input": combined2} + )[0] + final_pred_norm = stage2_pred_norm + error2_corr_norm + pred = denormalize_output(final_pred_norm, output_params) + errors = np.abs(pred - ground_truth) + + result = { + "hue_mae": np.mean(errors[:, 0]), + "value_mae": np.mean(errors[:, 1]), + "chroma_mae": np.mean(errors[:, 2]), + "code_mae": np.mean(errors[:, 3]), + "max_errors": np.max(errors, axis=1), + "hue_errors": errors[:, 0], + "value_errors": errors[:, 1], + "chroma_errors": errors[:, 2], + "code_errors": errors[:, 3], + "model_size_mb": model_size_mb, + "inference_time_ms": inference_time_ms, + } + + # Compute Delta-E against ground truth for three-stage model + delta_E_values = [] + for idx in range(len(pred)): + try: + ml_spec = clamp_munsell_specification(pred[idx]) + ml_spec_for_conversion = ml_spec.copy() + ml_spec_for_conversion[3] = round(ml_spec[3]) + ml_xyy = munsell_specification_to_xyY(ml_spec_for_conversion) + ml_XYZ = xyY_to_XYZ(ml_xyy) + ml_Lab = XYZ_to_Lab(ml_XYZ, CCS_ILLUMINANT_MUNSELL) + delta_E = delta_E_CIE2000(reference_Lab[idx], ml_Lab) + delta_E_values.append(delta_E) + except (RuntimeError, ValueError): + continue + + result["delta_E"] = ( + np.mean(delta_E_values) if delta_E_values else np.nan + ) + else: + # Single model + session = ort.InferenceSession(str(model_info["files"][0])) + + # Define inference callable for benchmarking + def single_inference( + _session: ort.InferenceSession = session, _X_norm: NDArray = X_norm + ) -> NDArray: + return _session.run(None, {"xyY": _X_norm})[0] + + # Benchmark speed + inference_time_ms = benchmark_inference_speed(single_inference, X_norm) + + result = evaluate_model( + session, + X_norm, + ground_truth, + output_params, + reference_Lab=reference_Lab, + ) + result["model_size_mb"] = model_size_mb + result["inference_time_ms"] = inference_time_ms + + results[model_name] = result + + # Print results + LOGGER.info("") + LOGGER.info("Mean Absolute Errors:") + LOGGER.info(" Hue: %.4f", result["hue_mae"]) + LOGGER.info(" Value: %.4f", result["value_mae"]) + LOGGER.info(" Chroma: %.4f", result["chroma_mae"]) + LOGGER.info(" Code: %.4f", result["code_mae"]) + if not np.isnan(result["delta_E"]): + LOGGER.info(" Delta-E (vs Ground Truth): %.4f", result["delta_E"]) + LOGGER.info("") + LOGGER.info("Performance Metrics:") + LOGGER.info(" Model Size: %.2f MB", result["model_size_mb"]) + LOGGER.info( + " Inference Speed: %.4f ms/sample", result["inference_time_ms"] + ) + + + # Summary comparison + LOGGER.info("") + LOGGER.info("=" * 80) + LOGGER.info("SUMMARY COMPARISON") + LOGGER.info("=" * 80) + LOGGER.info("") + + if not results: + LOGGER.info("⚠️ No models were successfully evaluated") + return + + # MAE comparison table + LOGGER.info("Mean Absolute Error Comparison:") + LOGGER.info("") + header = "{:<35} {:>8} {:>8} {:>8} {:>8} {:>10}".format( + "Model", + "Hue", + "Value", + "Chroma", + "Code", + "Delta-E", + ) + LOGGER.info(header) + LOGGER.info("-" * 90) + + for model_name, result in results.items(): + delta_E_str = ( + f"{result['delta_E']:.4f}" if not np.isnan(result["delta_E"]) else "N/A" + ) + LOGGER.info( + "%-35s %8.4f %8.4f %8.4f %8.4f %10s", + model_name[:35], + result["hue_mae"], + result["value_mae"], + result["chroma_mae"], + result["code_mae"], + delta_E_str, + ) + + # Precision threshold comparison + LOGGER.info("") + LOGGER.info("Accuracy at Precision Thresholds:") + LOGGER.info("") + + thresholds = [1e-4, 1e-3, 1e-2, 1e-1, 1.0] + header_parts = [f"{'Model/Threshold':<35}"] + header_parts.extend(f"{f'< {threshold:.0e}':>10}" for threshold in thresholds) + LOGGER.info(" ".join(header_parts)) + LOGGER.info("-" * 80) + + for model_name, result in results.items(): + row_parts = [f"{model_name[:35]:<35}"] + for threshold in thresholds: + accuracy_pct = np.mean(result["max_errors"] < threshold) * 100 + row_parts.append(f"{accuracy_pct:9.2f}%") + LOGGER.info(" ".join(row_parts)) + + # Performance metrics comparison + LOGGER.info("") + LOGGER.info("Model Size and Inference Speed Comparison:") + LOGGER.info("") + header = f"{'Model':<35} {'Size (MB)':>12} {'Speed (ms/sample)':>18}" + LOGGER.info(header) + LOGGER.info("-" * 80) + + for model_name, result in results.items(): + LOGGER.info( + "%-35s %11.2f %17.4f", + model_name[:35], + result["model_size_mb"], + result["inference_time_ms"], + ) + + # Find best model + LOGGER.info("") + LOGGER.info("=" * 80) + LOGGER.info("BEST MODELS BY METRIC") + LOGGER.info("=" * 80) + LOGGER.info("") + + metrics = ["hue_mae", "value_mae", "chroma_mae", "code_mae"] + metric_names = ["Hue MAE", "Value MAE", "Chroma MAE", "Code MAE"] + + for metric, metric_name in zip(metrics, metric_names, strict=False): + best_model = min(results.items(), key=lambda x: x[1][metric]) + LOGGER.info( + "%-15s: %s (%.4f)", + metric_name, + best_model[0], + best_model[1][metric], + ) + + # Overall best (average rank) + LOGGER.info("") + LOGGER.info("Overall Best (by average component MAE):") + for model_name, result in results.items(): + avg_mae = np.mean( + [ + result["hue_mae"], + result["value_mae"], + result["chroma_mae"], + result["code_mae"], + ] + ) + LOGGER.info(" %s: %.4f", model_name, avg_mae) + + LOGGER.info("") + LOGGER.info("=" * 80) + + # Generate HTML report + report_dir = PROJECT_ROOT / "reports" / "from_xyY" + report_dir.mkdir(exist_ok=True) + report_file = report_dir / "model_comparison.html" + generate_html_report( + results, len(xyY_samples), report_file, baseline_inference_time_ms + ) + + +if __name__ == "__main__": + main() diff --git a/learning_munsell/comparison/from_xyY/compare_gamma_model.py b/learning_munsell/comparison/from_xyY/compare_gamma_model.py new file mode 100644 index 0000000000000000000000000000000000000000..7decac3b3d8ac9a167d828292e95f7a1a8519d56 --- /dev/null +++ b/learning_munsell/comparison/from_xyY/compare_gamma_model.py @@ -0,0 +1,390 @@ +""" +Quick comparison of the gamma-corrected models against baselines. + +This script compares: +1. MLP (Base) vs MLP (Gamma 2.33) +2. Multi-Head (Base) vs Multi-Head (Gamma 2.33) vs Multi-Head (ST.2084) +""" + +import logging +from typing import Any + +import numpy as np +import onnxruntime as ort +from colour import XYZ_to_Lab, xyY_to_XYZ +from colour.difference import delta_E_CIE2000 +from colour.models import eotf_inverse_ST2084 +from colour.notation.datasets.munsell import MUNSELL_COLOURS_REAL +from colour.notation.munsell import ( + CCS_ILLUMINANT_MUNSELL, + munsell_colour_to_munsell_specification, + munsell_specification_to_xyY, +) +from numpy.typing import NDArray + +from learning_munsell import PROJECT_ROOT + +logging.basicConfig(level=logging.INFO, format="%(message)s") +LOGGER = logging.getLogger(__name__) + + +def normalize_input_standard(X: NDArray, params: dict[str, Any]) -> NDArray: + """Standard xyY normalization.""" + X_norm = np.copy(X) + X_norm[..., 0] = (X[..., 0] - params["x_range"][0]) / ( + params["x_range"][1] - params["x_range"][0] + ) + X_norm[..., 1] = (X[..., 1] - params["y_range"][0]) / ( + params["y_range"][1] - params["y_range"][0] + ) + X_norm[..., 2] = (X[..., 2] - params["Y_range"][0]) / ( + params["Y_range"][1] - params["Y_range"][0] + ) + return X_norm.astype(np.float32) + + +def normalize_input_gamma(X: NDArray, params: dict[str, Any]) -> NDArray: + """Gamma-corrected xyY normalization.""" + gamma = params.get("gamma", 2.33) + X_norm = np.copy(X) + X_norm[..., 0] = (X[..., 0] - params["x_range"][0]) / ( + params["x_range"][1] - params["x_range"][0] + ) + X_norm[..., 1] = (X[..., 1] - params["y_range"][0]) / ( + params["y_range"][1] - params["y_range"][0] + ) + # Normalize Y then apply gamma + Y_normalized = (X[..., 2] - params["Y_range"][0]) / ( + params["Y_range"][1] - params["Y_range"][0] + ) + Y_normalized = np.clip(Y_normalized, 0, 1) + X_norm[..., 2] = np.power(Y_normalized, 1.0 / gamma) + return X_norm.astype(np.float32) + + +def normalize_input_st2084(X: NDArray, params: dict[str, Any]) -> NDArray: + """ST.2084 (PQ) encoded xyY normalization.""" + L_p = params.get("L_p", 100.0) + X_norm = np.copy(X) + X_norm[..., 0] = (X[..., 0] - params["x_range"][0]) / ( + params["x_range"][1] - params["x_range"][0] + ) + X_norm[..., 1] = (X[..., 1] - params["y_range"][0]) / ( + params["y_range"][1] - params["y_range"][0] + ) + # Normalize Y then apply ST.2084 + Y_normalized = (X[..., 2] - params["Y_range"][0]) / ( + params["Y_range"][1] - params["Y_range"][0] + ) + Y_normalized = np.clip(Y_normalized, 0, 1) + # Scale to cd/m² and apply ST.2084 inverse EOTF + Y_cdm2 = Y_normalized * L_p + X_norm[..., 2] = eotf_inverse_ST2084(Y_cdm2, L_p=L_p) + return X_norm.astype(np.float32) + + +def denormalize_output(y_norm: NDArray, params: dict[str, Any]) -> NDArray: + """Denormalize Munsell output.""" + y = np.copy(y_norm) + y[..., 0] = ( + y_norm[..., 0] * (params["hue_range"][1] - params["hue_range"][0]) + + params["hue_range"][0] + ) + y[..., 1] = ( + y_norm[..., 1] * (params["value_range"][1] - params["value_range"][0]) + + params["value_range"][0] + ) + y[..., 2] = ( + y_norm[..., 2] * (params["chroma_range"][1] - params["chroma_range"][0]) + + params["chroma_range"][0] + ) + y[..., 3] = ( + y_norm[..., 3] * (params["code_range"][1] - params["code_range"][0]) + + params["code_range"][0] + ) + return y + + +def clamp_munsell_specification(spec: NDArray) -> NDArray: + """Clamp Munsell specification to valid ranges.""" + clamped = np.copy(spec) + clamped[..., 0] = np.clip(spec[..., 0], 0.0, 10.0) # Hue: [0, 10] + clamped[..., 1] = np.clip(spec[..., 1], 1.0, 9.0) # Value: [1, 9] (colour library constraint) + clamped[..., 2] = np.clip(spec[..., 2], 0.0, 50.0) # Chroma: [0, 50] + clamped[..., 3] = np.clip(spec[..., 3], 1.0, 10.0) # Code: [1, 10] + return clamped + + +def compute_delta_e(pred: NDArray, reference_Lab: NDArray) -> list[float]: + """Compute Delta-E for predictions.""" + delta_E_values = [] + for idx in range(len(pred)): + try: + ml_spec = clamp_munsell_specification(pred[idx]) + ml_spec_for_conversion = ml_spec.copy() + ml_spec_for_conversion[3] = round(ml_spec[3]) + ml_xyy = munsell_specification_to_xyY(ml_spec_for_conversion) + ml_XYZ = xyY_to_XYZ(ml_xyy) + ml_Lab = XYZ_to_Lab(ml_XYZ, CCS_ILLUMINANT_MUNSELL) + delta_E = delta_E_CIE2000(reference_Lab[idx], ml_Lab) + delta_E_values.append(delta_E) + except (RuntimeError, ValueError): + continue + return delta_E_values + + +def main() -> None: + """Compare gamma model against baseline.""" + LOGGER.info("=" * 80) + LOGGER.info("Gamma Model Comparison: MLP vs MLP (Gamma 2.33)") + LOGGER.info("=" * 80) + + models_dir = PROJECT_ROOT / "models" / "from_xyY" + + # Load real Munsell data + LOGGER.info("\nLoading real Munsell colours...") + xyY_values = [] + munsell_specs = [] + reference_Lab = [] + + for munsell_spec_tuple, xyY in MUNSELL_COLOURS_REAL: + try: + hue_code, value, chroma = munsell_spec_tuple + munsell_str = f"{hue_code} {value}/{chroma}" + spec = munsell_colour_to_munsell_specification(munsell_str) + xyY_scaled = np.array([xyY[0], xyY[1], xyY[2] / 100.0]) + + XYZ = xyY_to_XYZ(xyY_scaled) + Lab = XYZ_to_Lab(XYZ, CCS_ILLUMINANT_MUNSELL) + + xyY_values.append(xyY_scaled) + munsell_specs.append(spec) + reference_Lab.append(Lab) + except (RuntimeError, ValueError): + continue + + xyY_array = np.array(xyY_values) + ground_truth = np.array(munsell_specs) + reference_Lab = np.array(reference_Lab) + + LOGGER.info("Loaded %d real Munsell colours", len(xyY_array)) + + # Test baseline MLP + LOGGER.info("\n" + "-" * 40) + LOGGER.info("1. MLP (Base) - Standard Normalization") + LOGGER.info("-" * 40) + + base_onnx = models_dir / "mlp.onnx" + base_params_file = models_dir / "mlp_normalization_params.npz" + + if base_onnx.exists() and base_params_file.exists(): + base_session = ort.InferenceSession(str(base_onnx)) + base_params_data = np.load(base_params_file, allow_pickle=True) + base_input_params = base_params_data["input_params"].item() + base_output_params = base_params_data["output_params"].item() + + X_norm_base = normalize_input_standard(xyY_array, base_input_params) + pred_norm = base_session.run(None, {"xyY": X_norm_base})[0] + pred_base = denormalize_output(pred_norm, base_output_params) + + errors_base = np.abs(pred_base - ground_truth) + delta_E_base = compute_delta_e(pred_base, reference_Lab) + + LOGGER.info(" Hue MAE: %.4f", np.mean(errors_base[:, 0])) + LOGGER.info(" Value MAE: %.4f", np.mean(errors_base[:, 1])) + LOGGER.info(" Chroma MAE: %.4f", np.mean(errors_base[:, 2])) + LOGGER.info(" Code MAE: %.4f", np.mean(errors_base[:, 3])) + LOGGER.info(" Delta-E: %.4f (mean), %.4f (median)", + np.mean(delta_E_base), np.median(delta_E_base)) + else: + LOGGER.info(" Model not found, skipping...") + delta_E_base = [] + + # Test gamma MLP + LOGGER.info("\n" + "-" * 40) + LOGGER.info("2. MLP (Gamma 2.33) - Gamma-Corrected Y") + LOGGER.info("-" * 40) + + gamma_onnx = models_dir / "mlp_gamma.onnx" + gamma_params_file = models_dir / "mlp_gamma_normalization_params.npz" + + if gamma_onnx.exists() and gamma_params_file.exists(): + gamma_session = ort.InferenceSession(str(gamma_onnx)) + gamma_params_data = np.load(gamma_params_file, allow_pickle=True) + gamma_input_params = gamma_params_data["input_params"].item() + gamma_output_params = gamma_params_data["output_params"].item() + + X_norm_gamma = normalize_input_gamma(xyY_array, gamma_input_params) + pred_norm = gamma_session.run(None, {"xyY_gamma": X_norm_gamma})[0] + pred_gamma = denormalize_output(pred_norm, gamma_output_params) + + errors_gamma = np.abs(pred_gamma - ground_truth) + delta_E_gamma = compute_delta_e(pred_gamma, reference_Lab) + + LOGGER.info(" Hue MAE: %.4f", np.mean(errors_gamma[:, 0])) + LOGGER.info(" Value MAE: %.4f", np.mean(errors_gamma[:, 1])) + LOGGER.info(" Chroma MAE: %.4f", np.mean(errors_gamma[:, 2])) + LOGGER.info(" Code MAE: %.4f", np.mean(errors_gamma[:, 3])) + LOGGER.info(" Delta-E: %.4f (mean), %.4f (median)", + np.mean(delta_E_gamma), np.median(delta_E_gamma)) + else: + LOGGER.info(" Model not found, skipping...") + delta_E_gamma = [] + + # Summary comparison for MLP + if delta_E_base and delta_E_gamma: + LOGGER.info("\n" + "=" * 80) + LOGGER.info("MLP COMPARISON SUMMARY") + LOGGER.info("=" * 80) + LOGGER.info("") + LOGGER.info("Delta-E (lower is better):") + LOGGER.info(" MLP (Base): %.4f mean, %.4f median", + np.mean(delta_E_base), np.median(delta_E_base)) + LOGGER.info(" MLP (Gamma): %.4f mean, %.4f median", + np.mean(delta_E_gamma), np.median(delta_E_gamma)) + LOGGER.info("") + + improvement = (np.mean(delta_E_base) - np.mean(delta_E_gamma)) / np.mean(delta_E_base) * 100 + if improvement > 0: + LOGGER.info(" Gamma model is %.1f%% BETTER", improvement) + else: + LOGGER.info(" Gamma model is %.1f%% WORSE", -improvement) + + # Test Multi-Head baseline + LOGGER.info("\n" + "=" * 80) + LOGGER.info("MULTI-HEAD GAMMA EXPERIMENT") + LOGGER.info("=" * 80) + + LOGGER.info("\n" + "-" * 40) + LOGGER.info("3. Multi-Head (Base) - Standard Normalization") + LOGGER.info("-" * 40) + + mh_base_onnx = models_dir / "multi_head.onnx" + mh_base_params_file = models_dir / "multi_head_normalization_params.npz" + + if mh_base_onnx.exists() and mh_base_params_file.exists(): + mh_base_session = ort.InferenceSession(str(mh_base_onnx)) + mh_base_params_data = np.load(mh_base_params_file, allow_pickle=True) + mh_base_input_params = mh_base_params_data["input_params"].item() + mh_base_output_params = mh_base_params_data["output_params"].item() + + X_norm_mh_base = normalize_input_standard(xyY_array, mh_base_input_params) + pred_norm = mh_base_session.run(None, {"xyY": X_norm_mh_base})[0] + pred_mh_base = denormalize_output(pred_norm, mh_base_output_params) + + errors_mh_base = np.abs(pred_mh_base - ground_truth) + delta_E_mh_base = compute_delta_e(pred_mh_base, reference_Lab) + + LOGGER.info(" Hue MAE: %.4f", np.mean(errors_mh_base[:, 0])) + LOGGER.info(" Value MAE: %.4f", np.mean(errors_mh_base[:, 1])) + LOGGER.info(" Chroma MAE: %.4f", np.mean(errors_mh_base[:, 2])) + LOGGER.info(" Code MAE: %.4f", np.mean(errors_mh_base[:, 3])) + LOGGER.info(" Delta-E: %.4f (mean), %.4f (median)", + np.mean(delta_E_mh_base), np.median(delta_E_mh_base)) + else: + LOGGER.info(" Model not found, skipping...") + delta_E_mh_base = [] + + # Test Multi-Head gamma + LOGGER.info("\n" + "-" * 40) + LOGGER.info("4. Multi-Head (Gamma 2.33) - Gamma-Corrected Y") + LOGGER.info("-" * 40) + + mh_gamma_onnx = models_dir / "multi_head_gamma.onnx" + mh_gamma_params_file = models_dir / "multi_head_gamma_normalization_params.npz" + + if mh_gamma_onnx.exists() and mh_gamma_params_file.exists(): + mh_gamma_session = ort.InferenceSession(str(mh_gamma_onnx)) + mh_gamma_params_data = np.load(mh_gamma_params_file, allow_pickle=True) + mh_gamma_input_params = mh_gamma_params_data["input_params"].item() + mh_gamma_output_params = mh_gamma_params_data["output_params"].item() + + X_norm_mh_gamma = normalize_input_gamma(xyY_array, mh_gamma_input_params) + pred_norm = mh_gamma_session.run(None, {"xyY_gamma": X_norm_mh_gamma})[0] + pred_mh_gamma = denormalize_output(pred_norm, mh_gamma_output_params) + + errors_mh_gamma = np.abs(pred_mh_gamma - ground_truth) + delta_E_mh_gamma = compute_delta_e(pred_mh_gamma, reference_Lab) + + LOGGER.info(" Hue MAE: %.4f", np.mean(errors_mh_gamma[:, 0])) + LOGGER.info(" Value MAE: %.4f", np.mean(errors_mh_gamma[:, 1])) + LOGGER.info(" Chroma MAE: %.4f", np.mean(errors_mh_gamma[:, 2])) + LOGGER.info(" Code MAE: %.4f", np.mean(errors_mh_gamma[:, 3])) + LOGGER.info(" Delta-E: %.4f (mean), %.4f (median)", + np.mean(delta_E_mh_gamma), np.median(delta_E_mh_gamma)) + else: + LOGGER.info(" Model not found, skipping...") + delta_E_mh_gamma = [] + + # Test Multi-Head ST.2084 + LOGGER.info("\n" + "-" * 40) + LOGGER.info("5. Multi-Head (ST.2084) - PQ-Encoded Y") + LOGGER.info("-" * 40) + + mh_st2084_onnx = models_dir / "multi_head_st2084.onnx" + mh_st2084_params_file = models_dir / "multi_head_st2084_normalization_params.npz" + + if mh_st2084_onnx.exists() and mh_st2084_params_file.exists(): + mh_st2084_session = ort.InferenceSession(str(mh_st2084_onnx)) + mh_st2084_params_data = np.load(mh_st2084_params_file, allow_pickle=True) + mh_st2084_input_params = mh_st2084_params_data["input_params"].item() + mh_st2084_output_params = mh_st2084_params_data["output_params"].item() + + X_norm_mh_st2084 = normalize_input_st2084(xyY_array, mh_st2084_input_params) + pred_norm = mh_st2084_session.run(None, {"xyY_st2084": X_norm_mh_st2084})[0] + pred_mh_st2084 = denormalize_output(pred_norm, mh_st2084_output_params) + + errors_mh_st2084 = np.abs(pred_mh_st2084 - ground_truth) + delta_E_mh_st2084 = compute_delta_e(pred_mh_st2084, reference_Lab) + + LOGGER.info(" Hue MAE: %.4f", np.mean(errors_mh_st2084[:, 0])) + LOGGER.info(" Value MAE: %.4f", np.mean(errors_mh_st2084[:, 1])) + LOGGER.info(" Chroma MAE: %.4f", np.mean(errors_mh_st2084[:, 2])) + LOGGER.info(" Code MAE: %.4f", np.mean(errors_mh_st2084[:, 3])) + LOGGER.info(" Delta-E: %.4f (mean), %.4f (median)", + np.mean(delta_E_mh_st2084), np.median(delta_E_mh_st2084)) + else: + LOGGER.info(" Model not found, skipping...") + delta_E_mh_st2084 = [] + + # Summary comparison for Multi-Head + if delta_E_mh_base and delta_E_mh_gamma: + LOGGER.info("\n" + "=" * 80) + LOGGER.info("MULTI-HEAD COMPARISON SUMMARY") + LOGGER.info("=" * 80) + LOGGER.info("") + LOGGER.info("Delta-E (lower is better):") + LOGGER.info(" Multi-Head (Base): %.4f mean, %.4f median", + np.mean(delta_E_mh_base), np.median(delta_E_mh_base)) + LOGGER.info(" Multi-Head (Gamma): %.4f mean, %.4f median", + np.mean(delta_E_mh_gamma), np.median(delta_E_mh_gamma)) + if delta_E_mh_st2084: + LOGGER.info(" Multi-Head (ST.2084): %.4f mean, %.4f median", + np.mean(delta_E_mh_st2084), np.median(delta_E_mh_st2084)) + LOGGER.info("") + + mh_gamma_improvement = (np.mean(delta_E_mh_base) - np.mean(delta_E_mh_gamma)) / np.mean(delta_E_mh_base) * 100 + if mh_gamma_improvement > 0: + LOGGER.info(" Multi-Head Gamma vs Base: %.1f%% BETTER", mh_gamma_improvement) + else: + LOGGER.info(" Multi-Head Gamma vs Base: %.1f%% WORSE", -mh_gamma_improvement) + + if delta_E_mh_st2084: + mh_st2084_improvement = (np.mean(delta_E_mh_base) - np.mean(delta_E_mh_st2084)) / np.mean(delta_E_mh_base) * 100 + if mh_st2084_improvement > 0: + LOGGER.info(" Multi-Head ST.2084 vs Base: %.1f%% BETTER", mh_st2084_improvement) + else: + LOGGER.info(" Multi-Head ST.2084 vs Base: %.1f%% WORSE", -mh_st2084_improvement) + + # Compare ST.2084 vs Gamma + st2084_vs_gamma = (np.mean(delta_E_mh_gamma) - np.mean(delta_E_mh_st2084)) / np.mean(delta_E_mh_gamma) * 100 + if st2084_vs_gamma > 0: + LOGGER.info(" Multi-Head ST.2084 vs Gamma: %.1f%% BETTER", st2084_vs_gamma) + else: + LOGGER.info(" Multi-Head ST.2084 vs Gamma: %.1f%% WORSE", -st2084_vs_gamma) + + LOGGER.info("\n" + "=" * 80) + + +if __name__ == "__main__": + main() diff --git a/learning_munsell/comparison/to_xyY/__init__.py b/learning_munsell/comparison/to_xyY/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..de52dc85459247b0820ff558b8d679c6b8121fe3 --- /dev/null +++ b/learning_munsell/comparison/to_xyY/__init__.py @@ -0,0 +1 @@ +"""Comparison scripts for Munsell to xyY conversion models.""" diff --git a/learning_munsell/comparison/to_xyY/compare_all_models.py b/learning_munsell/comparison/to_xyY/compare_all_models.py new file mode 100644 index 0000000000000000000000000000000000000000..0eca8b89b97ed7ed6652da98974a581a04e1bd04 --- /dev/null +++ b/learning_munsell/comparison/to_xyY/compare_all_models.py @@ -0,0 +1,617 @@ +""" +Compare all ML models for Munsell to xyY conversion on real Munsell data. + +Models to compare: +1. Simple MLP Approximator +2. Multi-Head MLP +3. Multi-Head MLP (Optimized) - with hyperparameter optimization +4. Multi-Head + Multi-Error Predictor +5. Multi-MLP - 3 independent branches +6. Multi-MLP (Optimized) - 3 independent branches with optimized hyperparameters +7. Multi-MLP + Error Predictor +8. Multi-MLP + Multi-Error Predictor +9. Multi-MLP (Optimized) + Multi-Error Predictor (Optimized) +""" + +from __future__ import annotations + +import logging +import time +import warnings +from typing import TYPE_CHECKING + +import numpy as np +import onnxruntime as ort +from colour import XYZ_to_Lab, xyY_to_XYZ +from colour.difference import delta_E_CIE2000 +from colour.notation.datasets.munsell import MUNSELL_COLOURS_REAL +from colour.notation.munsell import ( + CCS_ILLUMINANT_MUNSELL, + munsell_colour_to_munsell_specification, + munsell_specification_to_xyY, +) +from numpy.typing import NDArray # noqa: TC002 + +from learning_munsell import PROJECT_ROOT +from learning_munsell.utilities.common import ( + benchmark_inference_speed, + generate_html_report_footer, + generate_html_report_header, + generate_ranking_section, + get_model_size_mb, +) + +if TYPE_CHECKING: + from pathlib import Path + +logging.basicConfig(level=logging.INFO, format="%(message)s") +LOGGER = logging.getLogger(__name__) + + +def normalize_munsell(munsell: np.ndarray) -> np.ndarray: + """Normalize Munsell specs to [0, 1] range.""" + normalized = munsell.copy() + normalized[..., 0] = munsell[..., 0] / 10.0 # Hue (in decade) + normalized[..., 1] = munsell[..., 1] / 10.0 # Value + normalized[..., 2] = munsell[..., 2] / 50.0 # Chroma + normalized[..., 3] = munsell[..., 3] / 10.0 # Code + return normalized.astype(np.float32) + + +def evaluate_model( + session: ort.InferenceSession, + X_norm: np.ndarray, + ground_truth: np.ndarray, + input_name: str = "munsell_normalized", +) -> dict: + """Evaluate a single model.""" + pred = session.run(None, {input_name: X_norm})[0] + errors = np.abs(pred - ground_truth) + + return { + "x_mae": np.mean(errors[:, 0]), + "y_mae": np.mean(errors[:, 1]), + "Y_mae": np.mean(errors[:, 2]), + "predictions": pred, + "errors": errors, + "max_errors": np.max(errors, axis=1), + } + + +def compute_delta_E( + ml_predictions: np.ndarray, + reference_xyY: np.ndarray, +) -> float: + """Compute Delta-E CIE2000 between ML predictions and reference xyY (ground truth).""" + delta_E_values = [] + + for ml_xyY, ref_xyY in zip(ml_predictions, reference_xyY, strict=False): + try: + ml_XYZ = xyY_to_XYZ(ml_xyY) + ml_Lab = XYZ_to_Lab(ml_XYZ, CCS_ILLUMINANT_MUNSELL) + + ref_XYZ = xyY_to_XYZ(ref_xyY) + ref_Lab = XYZ_to_Lab(ref_XYZ, CCS_ILLUMINANT_MUNSELL) + + delta_E = delta_E_CIE2000(ref_Lab, ml_Lab) + if not np.isnan(delta_E): + delta_E_values.append(delta_E) + except (RuntimeError, ValueError): + continue + + return np.mean(delta_E_values) if delta_E_values else np.nan + + +def generate_html_report( + results: dict, + num_samples: int, + output_file: Path, + baseline_inference_time_ms: float, +) -> None: + """Generate HTML report with visualizations.""" + # Calculate average MAE + avg_maes = {} + for model_name, result in results.items(): + avg_maes[model_name] = np.mean( + [ + result["x_mae"], + result["y_mae"], + result["Y_mae"], + ] + ) + + # Sort by average MAE + sorted_models = sorted(avg_maes.items(), key=lambda x: x[1]) + + # Start HTML + html = generate_html_report_header( + title="ML Model Comparison Report", + subtitle="Munsell to xyY Conversion", + num_samples=num_samples, + ) + + # Best Models Summary + best_size = min(results.items(), key=lambda x: x[1]["model_size_mb"])[0] + best_speed = min(results.items(), key=lambda x: x[1]["inference_time_ms"])[0] + best_avg = sorted_models[0][0] + + # Find best Delta-E + delta_E_results = [ + (n, r["delta_E"]) for n, r in results.items() if not np.isnan(r["delta_E"]) + ] + best_delta_E = ( + min(delta_E_results, key=lambda x: x[1])[0] if delta_E_results else None + ) + + html += f""" + +
+

Best Models by Metric

+
+
+
Smallest Size
+
{results[best_size]["model_size_mb"]:.2f} MB
+
{best_size}
+
+
+
Fastest Speed
+
{results[best_speed]["inference_time_ms"]:.4f} ms
+
{best_speed}
+
+""" + + if best_delta_E: + html += f""" +
+
Best Delta-E
+
{results[best_delta_E]["delta_E"]:.6f}
+
{best_delta_E}
+
+""" + + html += f""" +
+
Best Average MAE
+
{avg_maes[best_avg]:.6f}
+
{best_avg}
+
+
+
+""" + + # Performance Metrics Table + sorted_by_avg_mae = sorted(results.items(), key=lambda x: avg_maes[x[0]]) + + html += """ + +
+

Model Performance Metrics

+
+ + + + + + + + + + + + + + +""" + + for model_name, result in sorted_by_avg_mae: + size_mb = result["model_size_mb"] + speed_ms = result["inference_time_ms"] + delta_E = result["delta_E"] + + # Calculate speedup vs baseline + speedup = baseline_inference_time_ms / speed_ms if speed_ms > 0 else 0 + + size_class = "text-primary font-semibold" if model_name == best_size else "" + speed_class = "text-primary font-semibold" if model_name == best_speed else "" + delta_E_class = ( + "text-primary font-semibold" if model_name == best_delta_E else "" + ) + + delta_E_str = f"{delta_E:.6f}" if not np.isnan(delta_E) else "—" + + speedup_text = f"{speedup:.0f}x" if speedup > 100 else f"{speedup:.1f}x" + + html += f""" + + + + + + + + + + +""" + + html += """ + +
ModelSize (MB)Speed (ms/sample)vs BaselineMAE xMAE yMAE YDelta-E
{model_name}{size_mb:.2f}{speed_ms:.4f}{speedup_text}{result["x_mae"]:.6f}{result["y_mae"]:.6f}{result["Y_mae"]:.6f}{delta_E_str}
+
+
+""" + + # Add ranking section + html += generate_ranking_section( + results, + metric_key="avg_mae", + title="Overall Ranking (by Average MAE)", + ) + + # Precision thresholds + thresholds = [1e-5, 1e-4, 1e-3, 1e-2, 1e-1] + + html += """ +
+

Accuracy at Precision Thresholds

+

Percentage of predictions where max error across all components is below threshold:

+
+ + + + +""" + + for threshold in thresholds: + html += f' \n' + + html += """ + + + +""" + + for model_name, _ in sorted_models: + result = results[model_name] + html += f""" + + +""" + for threshold in thresholds: + accuracy_pct = np.mean(result["max_errors"] < threshold) * 100 + html += f' \n' + + html += """ + +""" + + html += """ + +
Model< {threshold:.0e}
{model_name}{accuracy_pct:.2f}%
+
+
+""" + + html += generate_html_report_footer() + + # Write HTML file + with open(output_file, "w") as f: + f.write(html) + + LOGGER.info("") + LOGGER.info("HTML report saved to: %s", output_file) + + +def main() -> None: + """Compare all models.""" + LOGGER.info("=" * 80) + LOGGER.info("Munsell to xyY Model Comparison") + LOGGER.info("=" * 80) + + # Paths + model_directory = PROJECT_ROOT / "models" / "to_xyY" + + # Load real Munsell dataset + LOGGER.info("") + LOGGER.info("Loading real Munsell dataset...") + munsell_specs = [] + xyY_ground_truth = [] + + for munsell_spec_tuple, xyY in MUNSELL_COLOURS_REAL: + try: + hue_code, value, chroma = munsell_spec_tuple + munsell_str = f"{hue_code} {value}/{chroma}" + spec = munsell_colour_to_munsell_specification(munsell_str) + xyY_scaled = np.array([xyY[0], xyY[1], xyY[2] / 100.0]) + munsell_specs.append(spec) + xyY_ground_truth.append(xyY_scaled) + except Exception: # noqa: BLE001, S112 + continue + + munsell_specs = np.array(munsell_specs, dtype=np.float32) + xyY_ground_truth = np.array(xyY_ground_truth, dtype=np.float32) + LOGGER.info("Loaded %d valid Munsell colors", len(munsell_specs)) + + # Normalize inputs + munsell_normalized = normalize_munsell(munsell_specs) + + # Benchmark colour library first + LOGGER.info("") + LOGGER.info("=" * 80) + LOGGER.info("Colour Library (munsell_specification_to_xyY)") + LOGGER.info("=" * 80) + + # Benchmark the munsell_specification_to_xyY function + # Note: Using full dataset (100% of samples) + + # Set random seed for reproducibility + np.random.seed(42) + + # Use 100% of samples for comprehensive benchmarking + sampled_indices = np.arange(len(munsell_specs)) + munsell_benchmark = munsell_specs[sampled_indices] + + start_time = time.perf_counter() + colour_predictions = [] + successful_inferences = 0 + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + for spec in munsell_benchmark: + try: + xyY = munsell_specification_to_xyY(spec) + colour_predictions.append(xyY) + successful_inferences += 1 + except (RuntimeError, ValueError): + colour_predictions.append(np.array([np.nan, np.nan, np.nan])) + + end_time = time.perf_counter() + + total_time_s = end_time - start_time + baseline_inference_time_ms = ( + (total_time_s / successful_inferences) * 1000 + if successful_inferences > 0 + else 0 + ) + colour_predictions = np.array(colour_predictions) + + LOGGER.info(" Successful inferences: %d", successful_inferences) + LOGGER.info(" Inference Speed: %.4f ms/sample", baseline_inference_time_ms) + + # Define models to compare + models = [ + { + "name": "Simple MLP", + "files": [model_directory / "munsell_to_xyY_approximator.onnx"], + "params_file": model_directory + / "munsell_to_xyY_approximator_normalization_params.npz", + "type": "single", + }, + { + "name": "Multi-Head", + "files": [model_directory / "multi_head.onnx"], + "params_file": model_directory / "multi_head_normalization_params.npz", + "type": "single", + }, + { + "name": "Multi-Head (Optimized)", + "files": [model_directory / "multi_head_optimized.onnx"], + "params_file": model_directory + / "multi_head_optimized_normalization_params.npz", + "type": "single", + }, + { + "name": "Multi-Head + Multi-Error Predictor", + "files": [ + model_directory / "multi_head.onnx", + model_directory / "multi_head_multi_error_predictor.onnx", + ], + "params_file": model_directory + / "multi_head_multi_error_predictor_normalization_params.npz", + "type": "two_stage", + }, + { + "name": "Multi-MLP", + "files": [model_directory / "multi_mlp.onnx"], + "params_file": model_directory / "multi_mlp_normalization_params.npz", + "type": "single", + }, + { + "name": "Multi-MLP (Optimized)", + "files": [model_directory / "multi_mlp_optimized.onnx"], + "params_file": model_directory + / "multi_mlp_optimized_normalization_params.npz", + "type": "single", + }, + { + "name": "Multi-MLP + Error Predictor", + "files": [ + model_directory / "multi_mlp.onnx", + model_directory / "multi_mlp_error_predictor.onnx", + ], + "params_file": model_directory + / "multi_mlp_error_predictor_normalization_params.npz", + "type": "two_stage", + }, + { + "name": "Multi-MLP + Multi-Error Predictor", + "files": [ + model_directory / "multi_mlp.onnx", + model_directory / "multi_mlp_multi_error_predictor.onnx", + ], + "params_file": model_directory + / "multi_mlp_multi_error_predictor_normalization_params.npz", + "type": "two_stage", + }, + { + "name": "Multi-MLP (Optimized) + Multi-Error Predictor (Optimized)", + "files": [ + model_directory / "multi_mlp_optimized.onnx", + model_directory / "multi_mlp_multi_error_predictor_optimized.onnx", + ], + "params_file": model_directory + / "multi_mlp_multi_error_predictor_optimized_normalization_params.npz", + "type": "two_stage", + }, + ] + + # Evaluate each model + results = {} + + for model_info in models: + model_name = model_info["name"] + LOGGER.info("") + LOGGER.info("=" * 80) + LOGGER.info(model_name) + LOGGER.info("=" * 80) + + # Calculate model size + model_size_mb = get_model_size_mb(model_info["files"]) + + if model_info["type"] == "two_stage": + # Two-stage model + base_session = ort.InferenceSession(str(model_info["files"][0])) + error_session = ort.InferenceSession(str(model_info["files"][1])) + error_input_name = error_session.get_inputs()[0].name + + # Define inference callable + def two_stage_inference( + _base_session: ort.InferenceSession = base_session, + _error_session: ort.InferenceSession = error_session, + _munsell_normalized: NDArray = munsell_normalized, + _error_input_name: str = error_input_name, + ) -> NDArray: + base_pred = _base_session.run( + None, {"munsell_normalized": _munsell_normalized} + )[0] + combined = np.concatenate( + [_munsell_normalized, base_pred], axis=1 + ).astype(np.float32) + error_corr = _error_session.run( + None, {_error_input_name: combined} + )[0] + return base_pred + error_corr + + # Benchmark speed + inference_time_ms = benchmark_inference_speed( + two_stage_inference, munsell_normalized + ) + + # Get predictions + base_pred = base_session.run( + None, {"munsell_normalized": munsell_normalized} + )[0] + combined = np.concatenate( + [munsell_normalized, base_pred], axis=1 + ).astype(np.float32) + error_corr = error_session.run( + None, {error_input_name: combined} + )[0] + pred = base_pred + error_corr + + errors = np.abs(pred - xyY_ground_truth) + result = { + "x_mae": np.mean(errors[:, 0]), + "y_mae": np.mean(errors[:, 1]), + "Y_mae": np.mean(errors[:, 2]), + "predictions": pred, + "errors": errors, + "max_errors": np.max(errors, axis=1), + } + else: + # Single model + session = ort.InferenceSession(str(model_info["files"][0])) + + # Define inference callable + def single_inference( + _session: ort.InferenceSession = session, + _munsell_normalized: NDArray = munsell_normalized, + ) -> NDArray: + return _session.run( + None, {"munsell_normalized": _munsell_normalized} + )[0] + + # Benchmark speed + inference_time_ms = benchmark_inference_speed( + single_inference, munsell_normalized + ) + + result = evaluate_model(session, munsell_normalized, xyY_ground_truth) + + result["model_size_mb"] = model_size_mb + result["inference_time_ms"] = inference_time_ms + result["avg_mae"] = np.mean( + [result["x_mae"], result["y_mae"], result["Y_mae"]] + ) + + # Compute Delta-E against ground truth (measured xyY) + sampled_predictions = result["predictions"][sampled_indices] + result["delta_E"] = compute_delta_E( + sampled_predictions, + xyY_ground_truth, + ) + + results[model_name] = result + + # Print results + LOGGER.info("") + LOGGER.info("Mean Absolute Errors:") + LOGGER.info(" x: %.6f", result["x_mae"]) + LOGGER.info(" y: %.6f", result["y_mae"]) + LOGGER.info(" Y: %.6f", result["Y_mae"]) + if not np.isnan(result["delta_E"]): + LOGGER.info(" Delta-E (vs Ground Truth): %.6f", result["delta_E"]) + LOGGER.info("") + LOGGER.info("Performance Metrics:") + LOGGER.info(" Model Size: %.2f MB", result["model_size_mb"]) + LOGGER.info( + " Inference Speed: %.4f ms/sample", result["inference_time_ms"] + ) + LOGGER.info( + " Speedup vs Colour: %.1fx", + baseline_inference_time_ms / inference_time_ms, + ) + + + # Summary + LOGGER.info("") + LOGGER.info("=" * 80) + LOGGER.info("SUMMARY COMPARISON") + LOGGER.info("=" * 80) + LOGGER.info("") + + if not results: + LOGGER.info("No models were successfully evaluated") + return + + # MAE comparison table + LOGGER.info("Mean Absolute Error Comparison:") + LOGGER.info("") + header = f"{'Model':<40} {'x':>10} {'y':>10} {'Y':>10} {'Delta-E':>12}" + LOGGER.info(header) + LOGGER.info("-" * 85) + + for model_name, result in results.items(): + delta_E_str = ( + f"{result['delta_E']:.6f}" if not np.isnan(result["delta_E"]) else "N/A" + ) + LOGGER.info( + "%-40s %10.6f %10.6f %10.6f %12s", + model_name, + result["x_mae"], + result["y_mae"], + result["Y_mae"], + delta_E_str, + ) + + # Generate HTML report + report_dir = PROJECT_ROOT / "reports" / "to_xyY" + report_dir.mkdir(parents=True, exist_ok=True) + report_file = report_dir / "model_comparison.html" + generate_html_report( + results, len(munsell_specs), report_file, baseline_inference_time_ms + ) + + +if __name__ == "__main__": + main() diff --git a/learning_munsell/data_generation/generate_training_data.py b/learning_munsell/data_generation/generate_training_data.py new file mode 100644 index 0000000000000000000000000000000000000000..5b72c9a6317729570f6c9a40a3707ac3d5b4641a --- /dev/null +++ b/learning_munsell/data_generation/generate_training_data.py @@ -0,0 +1,310 @@ +""" +Generate training data for ML-based xyY to Munsell conversion. + +Generates samples by sampling in Munsell space and converting to xyY via +forward conversion, guaranteeing 100% valid samples. + +Usage: + uv run python -m learning_munsell.data_generation.generate_training_data + uv run python -m learning_munsell.data_generation.generate_training_data \\ + --n-samples 2000000 --perturbation 0.10 --output training_data_large +""" + +import argparse +import json +import logging +import multiprocessing as mp +import warnings +from datetime import datetime, timezone + +import numpy as np +from colour.notation.datasets.munsell import MUNSELL_COLOURS_ALL +from colour.notation.munsell import ( + munsell_colour_to_munsell_specification, + munsell_specification_to_xyY, +) +from colour.utilities import ColourUsageWarning +from numpy.typing import NDArray +from sklearn.model_selection import train_test_split + +from learning_munsell import PROJECT_ROOT + +logging.basicConfig(level=logging.INFO, format="%(message)s") +LOGGER = logging.getLogger(__name__) + + +def _worker_generate_samples( + args: tuple[int, NDArray, int, float], +) -> tuple[list[NDArray], list[NDArray]]: + """ + Worker function to generate samples in parallel. + + Parameters + ---------- + args : tuple + - worker_id: Worker identifier + - base_specs: Array of base Munsell specifications + - samples_per_base: Number of samples to generate per base color + - perturbation_pct: Perturbation percentage + + Returns + ------- + tuple + - xyY_samples: List of xyY arrays + - munsell_samples: List of Munsell specification arrays + """ + worker_id, base_specs, samples_per_base, perturbation_pct = args + + np.random.seed(42 + worker_id) + + warnings.filterwarnings("ignore", category=ColourUsageWarning) + warnings.filterwarnings("ignore", category=RuntimeWarning) + + xyY_samples = [] + munsell_samples = [] + + hue_range = 9.5 + value_range = 9.0 + chroma_range = 50.0 + + for base_spec in base_specs: + for _ in range(samples_per_base): + hue_delta = np.random.uniform( + -perturbation_pct * hue_range, perturbation_pct * hue_range + ) + value_delta = np.random.uniform( + -perturbation_pct * value_range, perturbation_pct * value_range + ) + chroma_delta = np.random.uniform( + -perturbation_pct * chroma_range, perturbation_pct * chroma_range + ) + + perturbed_spec = base_spec.copy() + perturbed_spec[0] = np.clip(base_spec[0] + hue_delta, 0.5, 10.0) + perturbed_spec[1] = np.clip(base_spec[1] + value_delta, 1.0, 10.0) + perturbed_spec[2] = np.clip(base_spec[2] + chroma_delta, 0.0, 50.0) + + try: + xyY = munsell_specification_to_xyY(perturbed_spec) + xyY_samples.append(xyY) + munsell_samples.append(perturbed_spec) + except Exception: # noqa: BLE001, S110 + continue + + return xyY_samples, munsell_samples + + +def generate_forward_munsell_samples( + n_samples: int = 500000, + perturbation_pct: float = 0.05, + n_workers: int | None = None, +) -> tuple[NDArray, NDArray]: + """ + Generate samples by sampling directly in Munsell space and converting to xyY. + + Parameters + ---------- + n_samples : int + Target number of samples to generate. + perturbation_pct : float + Perturbation as percentage of valid range. + n_workers : int, optional + Number of parallel workers. Defaults to CPU count. + + Returns + ------- + tuple + - xyY_samples: Array of shape (n, 3) with xyY values + - munsell_samples: Array of shape (n, 4) with Munsell specifications + """ + if n_workers is None: + n_workers = mp.cpu_count() + + LOGGER.info( + "Generating %d samples with %.0f%% perturbations using %d workers...", + n_samples, + perturbation_pct * 100, + n_workers, + ) + + # Extract base Munsell specifications + base_specs = [] + for munsell_spec_tuple, _ in MUNSELL_COLOURS_ALL: + hue_code_str, value, chroma = munsell_spec_tuple + munsell_str = f"{hue_code_str} {value}/{chroma}" + spec = munsell_colour_to_munsell_specification(munsell_str) + base_specs.append(spec) + + base_specs = np.array(base_specs) + samples_per_base = n_samples // len(base_specs) + 1 + + LOGGER.info("Using %d base Munsell colors", len(base_specs)) + LOGGER.info("Generating ~%d samples per base color", samples_per_base) + + # Split base specs across workers + specs_per_worker = len(base_specs) // n_workers + worker_args = [] + + for i in range(n_workers): + start_idx = i * specs_per_worker + end_idx = start_idx + specs_per_worker if i < n_workers - 1 else len(base_specs) + worker_specs = base_specs[start_idx:end_idx] + worker_args.append((i, worker_specs, samples_per_base, perturbation_pct)) + + # Run in parallel + LOGGER.info("Starting %d parallel workers...", n_workers) + with mp.Pool(n_workers) as pool: + results = pool.map(_worker_generate_samples, worker_args) + + # Combine results + all_xyY = [] + all_munsell = [] + for xyY_samples, munsell_samples in results: + all_xyY.extend(xyY_samples) + all_munsell.extend(munsell_samples) + + # Trim to exact sample count + all_xyY = all_xyY[:n_samples] + all_munsell = all_munsell[:n_samples] + + LOGGER.info("Generated %d valid samples", len(all_xyY)) + return np.array(all_xyY), np.array(all_munsell) + + +def main( + n_samples: int = 500000, + perturbation_pct: float = 0.05, + output: str = "training_data", +) -> None: + """Generate and save training data.""" + LOGGER.info("=" * 80) + LOGGER.info("Training Data Generation") + LOGGER.info("=" * 80) + + output_dir = PROJECT_ROOT / "data" + output_dir.mkdir(exist_ok=True) + + LOGGER.info("") + LOGGER.info("SAMPLING STRATEGY") + LOGGER.info("=" * 80) + LOGGER.info("Forward Munsell->xyY sampling:") + LOGGER.info( + " - Base: %d colors from MUNSELL_COLOURS_ALL", len(MUNSELL_COLOURS_ALL) + ) + LOGGER.info( + " - Perturbations: +/-%.0f%% of valid range per component", + perturbation_pct * 100, + ) + LOGGER.info( + " - Hue: +/-%.2f (+/-%.0f%% of 9.5 range)", + perturbation_pct * 9.5, + perturbation_pct * 100, + ) + LOGGER.info( + " - Value: +/-%.2f (+/-%.0f%% of 9.0 range)", + perturbation_pct * 9.0, + perturbation_pct * 100, + ) + LOGGER.info( + " - Chroma: +/-%.1f (+/-%.0f%% of 50.0 range)", + perturbation_pct * 50.0, + perturbation_pct * 100, + ) + LOGGER.info(" - Target samples: %d", n_samples) + LOGGER.info("=" * 80) + LOGGER.info("") + + # Generate samples + xyY_all, munsell_all = generate_forward_munsell_samples( + n_samples=n_samples, + perturbation_pct=perturbation_pct, + ) + + valid_mask = np.ones(len(xyY_all), dtype=bool) + + LOGGER.info("") + LOGGER.info("Sample statistics:") + LOGGER.info(" Total samples generated: %d", len(xyY_all)) + LOGGER.info(" All samples are valid (100%% by forward conversion)") + + LOGGER.info("") + LOGGER.info("Using %d valid samples for training", len(xyY_all)) + + # Split into train/validation/test (70/15/15) + X_temp, X_test, y_temp, y_test = train_test_split( + xyY_all, munsell_all, test_size=0.15, random_state=42 + ) + X_train, X_val, y_train, y_val = train_test_split( + X_temp, y_temp, test_size=0.15 / 0.85, random_state=42 + ) + + LOGGER.info("") + LOGGER.info("Data split:") + LOGGER.info(" Train: %d samples", len(X_train)) + LOGGER.info(" Validation: %d samples", len(X_val)) + LOGGER.info(" Test: %d samples", len(X_test)) + + # Save training data + cache_file = output_dir / f"{output}.npz" + np.savez_compressed( + cache_file, + X_train=X_train, + y_train=y_train, + X_val=X_val, + y_val=y_val, + X_test=X_test, + y_test=y_test, + xyY_all=xyY_all, + munsell_all=munsell_all, + valid_mask=valid_mask, + ) + + # Save parameters to sidecar file + params_file = output_dir / f"{output}_params.json" + params = { + "n_samples": n_samples, + "perturbation_pct": perturbation_pct, + "n_base_colors": len(MUNSELL_COLOURS_ALL), + "train_samples": len(X_train), + "val_samples": len(X_val), + "test_samples": len(X_test), + "generated_at": datetime.now(timezone.utc).isoformat(), + } + with open(params_file, "w") as f: + json.dump(params, f, indent=2) + + LOGGER.info("") + LOGGER.info("Training data saved to: %s", cache_file) + LOGGER.info("Parameters saved to: %s", params_file) + LOGGER.info("=" * 80) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Generate training data for xyY to Munsell conversion" + ) + parser.add_argument( + "--n-samples", + type=int, + default=500000, + help="Number of samples to generate (default: 500000)", + ) + parser.add_argument( + "--perturbation", + type=float, + default=0.05, + help="Perturbation as fraction of valid range (default: 0.05)", + ) + parser.add_argument( + "--output", + type=str, + default="training_data", + help="Output filename without extension (default: training_data)", + ) + args = parser.parse_args() + + main( + n_samples=args.n_samples, + perturbation_pct=args.perturbation, + output=args.output, + ) diff --git a/learning_munsell/interpolation/__init__.py b/learning_munsell/interpolation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fc5167886e2be88638356214968d26f7212f1e70 --- /dev/null +++ b/learning_munsell/interpolation/__init__.py @@ -0,0 +1 @@ +"""Interpolation-based methods for Munsell conversions.""" diff --git a/learning_munsell/interpolation/from_xyY/__init__.py b/learning_munsell/interpolation/from_xyY/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1069646b6bb56174322c41558fdec874f7cd3998 --- /dev/null +++ b/learning_munsell/interpolation/from_xyY/__init__.py @@ -0,0 +1,43 @@ +"""Interpolation-based methods for xyY to Munsell conversions.""" + +import numpy as np +from colour.notation.datasets.munsell import MUNSELL_COLOURS_ALL +from colour.notation.munsell import munsell_colour_to_munsell_specification +from numpy.typing import NDArray + + +def load_munsell_reference_data() -> tuple[NDArray, NDArray]: + """ + Load reference Munsell data from colour library. + + Returns xyY coordinates and corresponding Munsell specifications + [hue, value, chroma, code] for all 4,995 reference colors. + + The Y values are normalized to [0, 1] range (originally 0-102.57). + + Returns + ------- + Tuple[NDArray, NDArray] + X : xyY values of shape (4995, 3) with Y normalized to [0, 1] + y : Munsell specifications of shape (4995, 4) + """ + xyY_list = [] + munsell_list = [] + + for munsell_tuple, xyY in MUNSELL_COLOURS_ALL: + hue_name, value, chroma = munsell_tuple + munsell_string = f"{hue_name} {value}/{chroma}" + + # Convert to numeric specification [hue, value, chroma, code] + spec = munsell_colour_to_munsell_specification(munsell_string) + + # Normalize Y to [0, 1] range (max ~102.57) + xyY_normalized = np.array([xyY[0], xyY[1], xyY[2] / 100.0]) + + xyY_list.append(xyY_normalized) + munsell_list.append(spec) + + return np.array(xyY_list), np.array(munsell_list) + + +__all__ = ["load_munsell_reference_data"] diff --git a/learning_munsell/interpolation/from_xyY/compare_methods.py b/learning_munsell/interpolation/from_xyY/compare_methods.py new file mode 100644 index 0000000000000000000000000000000000000000..60d82aef2c4e03837e7d1a4a03151e875fb1f29b --- /dev/null +++ b/learning_munsell/interpolation/from_xyY/compare_methods.py @@ -0,0 +1,208 @@ +""" +Compare classical interpolation methods against the best ML model. + +Evaluates RBF, KD-Tree, and Delaunay interpolation on REAL Munsell colors +and compares with the Multi-Head (W+B) + Multi-Error Predictor (W+B) model. +""" + +import logging + +import numpy as np +import onnxruntime as ort +from colour.notation.datasets.munsell import MUNSELL_COLOURS_ALL +from colour.notation.munsell import munsell_colour_to_munsell_specification +from scipy.interpolate import LinearNDInterpolator, RBFInterpolator +from scipy.spatial import KDTree +from sklearn.model_selection import train_test_split + +from learning_munsell import PROJECT_ROOT + +logging.basicConfig(level=logging.INFO, format="%(message)s") +LOGGER = logging.getLogger(__name__) + + +def load_reference_data(): + """Load ALL Munsell colors as training data for interpolators.""" + X, y = [], [] + for munsell_tuple, xyY in MUNSELL_COLOURS_ALL: + hue_name, value, chroma = munsell_tuple + munsell_str = f"{hue_name} {value}/{chroma}" + spec = munsell_colour_to_munsell_specification(munsell_str) + # Normalize Y to [0, 1] + X.append([xyY[0], xyY[1], xyY[2] / 100.0]) + y.append(spec) + return np.array(X), np.array(y) + + + + +def evaluate(predictions, y_true, method_name): + """Calculate MAE for each component.""" + errors = np.abs(predictions - y_true) + results = { + "hue": errors[:, 0].mean(), + "value": errors[:, 1].mean(), + "chroma": errors[:, 2].mean(), + "code": errors[:, 3].mean(), + } + LOGGER.info(" %s:", method_name) + for comp in ["hue", "value", "chroma", "code"]: + LOGGER.info(" %s MAE: %.4f", comp.capitalize(), results[comp]) + return results + + +def rbf_predict(X_train, y_train, X_test): + """RBF interpolation prediction.""" + predictions = np.zeros((len(X_test), 4)) + for i in range(4): + rbf = RBFInterpolator(X_train, y_train[:, i], kernel="thin_plate_spline") + predictions[:, i] = rbf(X_test) + return predictions + + +def kdtree_predict(X_train, y_train, X_test, k=5): + """KD-Tree with inverse distance weighting prediction.""" + tree = KDTree(X_train) + distances, indices = tree.query(X_test, k=k) + distances = np.maximum(distances, 1e-10) + weights = 1.0 / (distances**2) + weights /= weights.sum(axis=1, keepdims=True) + + predictions = np.zeros((len(X_test), 4)) + for i in range(len(X_test)): + predictions[i] = np.sum(weights[i, :, np.newaxis] * y_train[indices[i]], axis=0) + return predictions + + +def delaunay_predict(X_train, y_train, X_test): + """Delaunay interpolation with NN fallback.""" + predictions = np.zeros((len(X_test), 4)) + tree = KDTree(X_train) + + for i in range(4): + interp = LinearNDInterpolator(X_train, y_train[:, i]) + predictions[:, i] = interp(X_test) + + # Fallback to nearest neighbor for NaN + nan_mask = np.any(np.isnan(predictions), axis=1) + if nan_mask.sum() > 0: + _, indices = tree.query(X_test[nan_mask]) + predictions[nan_mask] = y_train[indices] + + return predictions + + +def ml_predict(X_test): + """ML model prediction using base + error predictor.""" + base_path = PROJECT_ROOT / "models" / "from_xyY" / "multi_head_weighted_boundary.onnx" + error_path = ( + PROJECT_ROOT + / "models" + / "from_xyY" + / "multi_head_weighted_boundary_multi_error_predictor_weighted_boundary.onnx" + ) + + if not base_path.exists() or not error_path.exists(): + return None + + # Input is already normalized to [0, 1] for x, y, Y + X_norm = X_test.astype(np.float32) + + # Base model prediction + base_session = ort.InferenceSession(str(base_path)) + base_out = base_session.run(None, {"xyY": X_norm})[0] + + # Error predictor (takes xyY + base predictions) + error_session = ort.InferenceSession(str(error_path)) + combined_input = np.concatenate([X_norm, base_out], axis=1).astype(np.float32) + error_out = error_session.run(None, {"combined_input": combined_input})[0] + + # Combined prediction (normalized) + pred_norm = base_out + error_out + + # Denormalize using actual ranges from params file + predictions = np.zeros_like(pred_norm) + predictions[:, 0] = pred_norm[:, 0] * (10.0 - 0.5) + 0.5 # Hue: [0.5, 10] + predictions[:, 1] = pred_norm[:, 1] * (10.0 - 0.0) + 0.0 # Value: [0, 10] + predictions[:, 2] = pred_norm[:, 2] * (50.0 - 0.0) + 0.0 # Chroma: [0, 50] + predictions[:, 3] = pred_norm[:, 3] * (10.0 - 1.0) + 1.0 # Code: [1, 10] + + return predictions + + +def main(): + """Compare all methods using held-out test set.""" + LOGGER.info("=" * 80) + LOGGER.info("Classical Interpolation vs ML Model Comparison") + LOGGER.info("=" * 80) + + LOGGER.info("") + LOGGER.info("Loading data...") + X_all, y_all = load_reference_data() + + # 80/20 train/test split for fair comparison + X_train, X_test, y_train, y_test = train_test_split( + X_all, y_all, test_size=0.2, random_state=42 + ) + LOGGER.info(" Total: %d colors", len(X_all)) + LOGGER.info(" Training: %d colors (80%%)", len(X_train)) + LOGGER.info(" Test: %d colors (20%%)", len(X_test)) + + results = {} + + # RBF + LOGGER.info("") + LOGGER.info("-" * 60) + LOGGER.info("RBF Interpolation (thin_plate_spline)") + rbf_pred = rbf_predict(X_train, y_train, X_test) + results["RBF"] = evaluate(rbf_pred, y_test, "RBF") + + # KD-Tree + LOGGER.info("") + LOGGER.info("-" * 60) + LOGGER.info("KD-Tree Interpolation (k=5, IDW)") + kdt_pred = kdtree_predict(X_train, y_train, X_test, k=5) + results["KD-Tree"] = evaluate(kdt_pred, y_test, "KD-Tree") + + # Delaunay + LOGGER.info("") + LOGGER.info("-" * 60) + LOGGER.info("Delaunay Interpolation (with NN fallback)") + del_pred = delaunay_predict(X_train, y_train, X_test) + results["Delaunay"] = evaluate(del_pred, y_test, "Delaunay") + + # ML + LOGGER.info("") + LOGGER.info("-" * 60) + LOGGER.info("ML Model (Multi-Head W+B + Multi-Error Predictor W+B)") + ml_pred = ml_predict(X_test) + if ml_pred is not None: + results["ML"] = evaluate(ml_pred, y_test, "ML") + else: + LOGGER.info(" Skipped (model not found)") + + # Summary + LOGGER.info("") + LOGGER.info("=" * 80) + LOGGER.info("SUMMARY (MAE on %d held-out test colors)", len(X_test)) + LOGGER.info("=" * 80) + LOGGER.info("") + LOGGER.info("%-12s %8s %8s %8s %8s", "Method", "Hue", "Value", "Chroma", "Code") + LOGGER.info("-" * 52) + + for method, mae in results.items(): + LOGGER.info( + "%-12s %8.4f %8.4f %8.4f %8.4f", + method, + mae["hue"], + mae["value"], + mae["chroma"], + mae["code"], + ) + + LOGGER.info("") + LOGGER.info("=" * 80) + + +if __name__ == "__main__": + main() diff --git a/learning_munsell/interpolation/from_xyY/delaunay_interpolator.py b/learning_munsell/interpolation/from_xyY/delaunay_interpolator.py new file mode 100644 index 0000000000000000000000000000000000000000..722a39da3195b4be7f0abd8bfa820543c0a14ef1 --- /dev/null +++ b/learning_munsell/interpolation/from_xyY/delaunay_interpolator.py @@ -0,0 +1,283 @@ +""" +Delaunay triangulation based interpolation for xyY to Munsell conversion. + +This approach uses scipy's LinearNDInterpolator which performs piecewise +linear interpolation based on Delaunay triangulation. + +Uses the 4,995 reference colors from MUNSELL_COLOURS_ALL directly. + +Advantages: +- Piecewise linear: exact at data points, linear between +- Handles irregular point distributions +- No hyperparameters to tune + +Disadvantages: +- Returns NaN outside convex hull of data points +- Non-convex Munsell boundary may cause issues +- C0 continuous only (discontinuous gradients at cell boundaries) +""" + +import logging +import pickle +from pathlib import Path + +import numpy as np +from numpy.typing import NDArray +from scipy.interpolate import LinearNDInterpolator +from scipy.spatial import KDTree +from sklearn.model_selection import train_test_split + +from learning_munsell import PROJECT_ROOT, setup_logging +from learning_munsell.interpolation.from_xyY import load_munsell_reference_data + +logging.basicConfig(level=logging.INFO, format="%(message)s") +LOGGER = logging.getLogger(__name__) + + +class MunsellDelaunayInterpolator: + """ + Delaunay triangulation based interpolator for xyY to Munsell conversion. + + Uses LinearNDInterpolator for piecewise linear interpolation within + the Delaunay triangulation. Falls back to nearest neighbor for points + outside the convex hull. + """ + + def __init__(self, fallback_to_nearest: bool = True) -> None: + """ + Initialize the Delaunay interpolator. + + Parameters + ---------- + fallback_to_nearest + If True, use nearest neighbor for points outside convex hull. + If False, return NaN for such points. + """ + self.fallback_to_nearest = fallback_to_nearest + self.interpolators: dict = {} + self.kdtree: KDTree | None = None + self.y_data: NDArray | None = None + self.fitted = False + + def fit(self, X: NDArray, y: NDArray) -> "MunsellDelaunayInterpolator": + """ + Build the Delaunay interpolator from training data. + + Parameters + ---------- + X + xyY input values of shape (n, 3) + y + Munsell output values [hue, value, chroma, code] of shape (n, 4) + + Returns + ------- + self + """ + LOGGER.info("Building Delaunay interpolator...") + LOGGER.info(" Fallback to nearest: %s", self.fallback_to_nearest) + LOGGER.info(" Data points: %d", len(X)) + + component_names = ["hue", "value", "chroma", "code"] + + for i, name in enumerate(component_names): + LOGGER.info(" Building %s interpolator...", name) + self.interpolators[name] = LinearNDInterpolator(X, y[:, i]) + + # Build KDTree for nearest neighbor fallback + if self.fallback_to_nearest: + LOGGER.info(" Building KD-Tree for fallback...") + self.kdtree = KDTree(X) + self.y_data = y.copy() + + self.fitted = True + LOGGER.info("Delaunay interpolator built successfully") + return self + + def predict(self, X: NDArray) -> NDArray: + """ + Predict Munsell values using Delaunay interpolation. + + Parameters + ---------- + X + xyY input values of shape (n, 3) + + Returns + ------- + NDArray + Predicted Munsell values [hue, value, chroma, code] of shape (n, 4) + """ + if not self.fitted: + msg = "Interpolator not fitted. Call fit() first." + raise RuntimeError(msg) + + results = np.zeros((len(X), 4)) + + for i, name in enumerate(["hue", "value", "chroma", "code"]): + results[:, i] = self.interpolators[name](X) + + # Handle NaN values (points outside convex hull) + if self.fallback_to_nearest: + nan_mask = np.any(np.isnan(results), axis=1) + n_nan = nan_mask.sum() + + if n_nan > 0: + LOGGER.debug(" %d points outside hull, using nearest neighbor", n_nan) + # Find nearest neighbors for NaN points + _, indices = self.kdtree.query(X[nan_mask]) + results[nan_mask] = self.y_data[indices] + + return results + + def save(self, path: Path) -> None: + """Save the interpolator to disk.""" + with open(path, "wb") as f: + pickle.dump( + { + "fallback_to_nearest": self.fallback_to_nearest, + "interpolators": self.interpolators, + "kdtree": self.kdtree, + "y_data": self.y_data, + }, + f, + ) + LOGGER.info("Saved Delaunay interpolator to %s", path) + + @classmethod + def load(cls, path: Path) -> "MunsellDelaunayInterpolator": + """Load the interpolator from disk.""" + with open(path, "rb") as f: + data = pickle.load(f) # noqa: S301 + + instance = cls(fallback_to_nearest=data["fallback_to_nearest"]) + instance.interpolators = data["interpolators"] + instance.kdtree = data["kdtree"] + instance.y_data = data["y_data"] + instance.fitted = True + + LOGGER.info("Loaded Delaunay interpolator from %s", path) + return instance + + +def evaluate_delaunay( + interpolator: MunsellDelaunayInterpolator, + X: NDArray, + y: NDArray, + name: str = "Test", +) -> dict: + """Evaluate Delaunay interpolator performance.""" + predictions = interpolator.predict(X) + + # Check for NaN values + nan_count = np.isnan(predictions).any(axis=1).sum() + if nan_count > 0: + LOGGER.warning(" %d/%d predictions contain NaN", nan_count, len(X)) + + # Filter out NaN for error calculation + valid_mask = ~np.isnan(predictions).any(axis=1) + if valid_mask.sum() == 0: + LOGGER.error(" All predictions are NaN!") + return { + "hue": float("nan"), + "value": float("nan"), + "chroma": float("nan"), + "code": float("nan"), + } + + errors = np.abs(predictions[valid_mask] - y[valid_mask]) + + component_names = ["Hue", "Value", "Chroma", "Code"] + results = {} + + LOGGER.info("%s set MAE (%d/%d valid):", name, valid_mask.sum(), len(X)) + for i, comp_name in enumerate(component_names): + mae = errors[:, i].mean() + results[comp_name.lower()] = mae + LOGGER.info(" %s: %.4f", comp_name, mae) + + return results + + +def main() -> None: + """Build and evaluate Delaunay interpolator using reference Munsell data.""" + + log_file = setup_logging("delaunay_interpolator", "from_xyY") + + LOGGER.info("=" * 80) + LOGGER.info("Delaunay Interpolation for xyY to Munsell Conversion") + LOGGER.info("Using MUNSELL_COLOURS_ALL reference data (4,995 colors)") + LOGGER.info("=" * 80) + + # Load reference data from colour library + LOGGER.info("") + LOGGER.info("Loading reference Munsell data...") + X_all, y_all = load_munsell_reference_data() + LOGGER.info("Total reference colors: %d", len(X_all)) + + # Split into train/validation (80/20) + X_train, X_val, y_train, y_val = train_test_split( + X_all, y_all, test_size=0.2, random_state=42 + ) + + LOGGER.info("Train samples: %d", len(X_train)) + LOGGER.info("Validation samples: %d", len(X_val)) + + # Test with and without fallback + LOGGER.info("") + LOGGER.info("Testing Delaunay interpolation...") + LOGGER.info("-" * 60) + + best_config = None + best_mae = float("inf") + + for fallback in [True, False]: + LOGGER.info("") + LOGGER.info("Fallback to nearest: %s", fallback) + + interpolator = MunsellDelaunayInterpolator(fallback_to_nearest=fallback) + interpolator.fit(X_train, y_train) + + results = evaluate_delaunay(interpolator, X_val, y_val, "Validation") + + # Skip if results contain NaN + if any(np.isnan(v) for v in results.values()): + LOGGER.info(" Skipping due to NaN results") + continue + + total_mae = sum(results.values()) + + if total_mae < best_mae: + best_mae = total_mae + best_config = fallback + + LOGGER.info("") + LOGGER.info("=" * 60) + LOGGER.info("Best configuration: fallback_to_nearest=%s", best_config) + LOGGER.info("=" * 60) + + # Train final model on ALL data + LOGGER.info("") + LOGGER.info("Training final model on all %d reference colors...", len(X_all)) + + final_interpolator = MunsellDelaunayInterpolator(fallback_to_nearest=best_config) + final_interpolator.fit(X_all, y_all) + + LOGGER.info("") + LOGGER.info("Final evaluation (training set = all data):") + evaluate_delaunay(final_interpolator, X_all, y_all, "All data") + + # Save the model + model_dir = PROJECT_ROOT / "models" / "from_xyY" + model_dir.mkdir(parents=True, exist_ok=True) + model_path = model_dir / "delaunay_interpolator.pkl" + final_interpolator.save(model_path) + + LOGGER.info("") + LOGGER.info("=" * 80) + + log_file.close() + + +if __name__ == "__main__": + main() diff --git a/learning_munsell/interpolation/from_xyY/kdtree_interpolator.py b/learning_munsell/interpolation/from_xyY/kdtree_interpolator.py new file mode 100644 index 0000000000000000000000000000000000000000..ac2c50055c86e491373ed37a83d2ee0d579368a1 --- /dev/null +++ b/learning_munsell/interpolation/from_xyY/kdtree_interpolator.py @@ -0,0 +1,263 @@ +""" +KD-Tree based interpolation for xyY to Munsell conversion. + +This approach uses scipy's KDTree for fast nearest neighbor lookups, +with optional weighted interpolation using k nearest neighbors. + +Uses the 4,995 reference colors from MUNSELL_COLOURS_ALL directly. + +Advantages over RBF: +- O(n) memory, O(log n) query time +- Scales to millions of data points +- No matrix inversion required + +Advantages over ML: +- Deterministic +- No training required +- Easy to understand +""" + +import logging +import pickle +from pathlib import Path + +import numpy as np +from numpy.typing import NDArray +from scipy.spatial import KDTree +from sklearn.model_selection import train_test_split + +from learning_munsell import PROJECT_ROOT, setup_logging +from learning_munsell.interpolation.from_xyY import load_munsell_reference_data + +logging.basicConfig(level=logging.INFO, format="%(message)s") +LOGGER = logging.getLogger(__name__) + + +class MunsellKDTreeInterpolator: + """ + KD-Tree based interpolator for xyY to Munsell conversion. + + Uses k-nearest neighbors with inverse distance weighting + for smooth interpolation. + """ + + def __init__(self, k: int = 5, power: float = 2.0) -> None: + """ + Initialize the KD-Tree interpolator. + + Parameters + ---------- + k + Number of nearest neighbors to use for interpolation. + power + Power for inverse distance weighting. Higher = sharper. + """ + self.k = k + self.power = power + self.tree: KDTree | None = None + self.y_data: NDArray | None = None + self.fitted = False + + def fit(self, X: NDArray, y: NDArray) -> "MunsellKDTreeInterpolator": + """ + Build the KD-Tree from training data. + + Parameters + ---------- + X + xyY input values of shape (n, 3) + y + Munsell output values [hue, value, chroma, code] of shape (n, 4) + + Returns + ------- + self + """ + LOGGER.info("Building KD-Tree interpolator...") + LOGGER.info(" k neighbors: %d", self.k) + LOGGER.info(" IDW power: %.1f", self.power) + LOGGER.info(" Data points: %d", len(X)) + + self.tree = KDTree(X) + self.y_data = y.copy() + self.fitted = True + + LOGGER.info("KD-Tree built successfully") + return self + + def predict(self, X: NDArray) -> NDArray: + """ + Predict Munsell values using k-NN with IDW. + + Parameters + ---------- + X + xyY input values of shape (n, 3) + + Returns + ------- + NDArray + Predicted Munsell values [hue, value, chroma, code] of shape (n, 4) + """ + if not self.fitted: + msg = "Interpolator not fitted. Call fit() first." + raise RuntimeError(msg) + + # Query k nearest neighbors + distances, indices = self.tree.query(X, k=self.k) + + # Ensure 2D arrays for consistent handling + if self.k == 1: + distances = distances.reshape(-1, 1) + indices = indices.reshape(-1, 1) + + # Inverse distance weighting + # Avoid division by zero + distances = np.maximum(distances, 1e-10) + weights = 1.0 / (distances**self.power) + weights /= weights.sum(axis=1, keepdims=True) + + # Weighted average of neighbor values + results = np.zeros((len(X), 4)) + for i in range(len(X)): + neighbor_values = self.y_data[indices[i]] + if self.k == 1: + results[i] = neighbor_values.flatten() + else: + results[i] = np.sum(weights[i, :, np.newaxis] * neighbor_values, axis=0) + + return results + + def save(self, path: Path) -> None: + """Save the interpolator to disk.""" + with open(path, "wb") as f: + pickle.dump( + { + "k": self.k, + "power": self.power, + "tree": self.tree, + "y_data": self.y_data, + }, + f, + ) + LOGGER.info("Saved KD-Tree interpolator to %s", path) + + @classmethod + def load(cls, path: Path) -> "MunsellKDTreeInterpolator": + """Load the interpolator from disk.""" + with open(path, "rb") as f: + data = pickle.load(f) # noqa: S301 + + instance = cls(k=data["k"], power=data["power"]) + instance.tree = data["tree"] + instance.y_data = data["y_data"] + instance.fitted = True + + LOGGER.info("Loaded KD-Tree interpolator from %s", path) + return instance + + +def evaluate_kdtree( + interpolator: MunsellKDTreeInterpolator, + X: NDArray, + y: NDArray, + name: str = "Test", +) -> dict: + """Evaluate KD-Tree interpolator performance.""" + predictions = interpolator.predict(X) + errors = np.abs(predictions - y) + + component_names = ["Hue", "Value", "Chroma", "Code"] + results = {} + + LOGGER.info("%s set MAE:", name) + for i, comp_name in enumerate(component_names): + mae = errors[:, i].mean() + results[comp_name.lower()] = mae + LOGGER.info(" %s: %.4f", comp_name, mae) + + return results + + +def main() -> None: + """Build and evaluate KD-Tree interpolator using reference Munsell data.""" + + log_file = setup_logging("kdtree_interpolator", "from_xyY") + + LOGGER.info("=" * 80) + LOGGER.info("KD-Tree Interpolation for xyY to Munsell Conversion") + LOGGER.info("Using MUNSELL_COLOURS_ALL reference data (4,995 colors)") + LOGGER.info("=" * 80) + + # Load reference data from colour library + LOGGER.info("") + LOGGER.info("Loading reference Munsell data...") + X_all, y_all = load_munsell_reference_data() + LOGGER.info("Total reference colors: %d", len(X_all)) + + # Split into train/validation (80/20) + X_train, X_val, y_train, y_val = train_test_split( + X_all, y_all, test_size=0.2, random_state=42 + ) + + LOGGER.info("Train samples: %d", len(X_train)) + LOGGER.info("Validation samples: %d", len(X_val)) + + # Test different k values + k_values = [1, 3, 5, 10, 20, 50] + + best_k = None + best_mae = float("inf") + + LOGGER.info("") + LOGGER.info("Testing different k values...") + LOGGER.info("-" * 60) + + for k in k_values: + LOGGER.info("") + LOGGER.info("k = %d:", k) + + interpolator = MunsellKDTreeInterpolator(k=k, power=2.0) + interpolator.fit(X_train, y_train) + + results = evaluate_kdtree(interpolator, X_val, y_val, "Validation") + total_mae = sum(results.values()) + + if total_mae < best_mae: + best_mae = total_mae + best_k = k + + LOGGER.info("") + LOGGER.info("=" * 60) + LOGGER.info("Best k: %d", best_k) + LOGGER.info("=" * 60) + + # Train final model with best k on ALL data + LOGGER.info("") + LOGGER.info( + "Training final model on all %d reference colors with k=%d...", + len(X_all), + best_k, + ) + + final_interpolator = MunsellKDTreeInterpolator(k=best_k, power=2.0) + final_interpolator.fit(X_all, y_all) + + LOGGER.info("") + LOGGER.info("Final evaluation (training set = all data):") + evaluate_kdtree(final_interpolator, X_all, y_all, "All data") + + # Save the model + model_dir = PROJECT_ROOT / "models" / "from_xyY" + model_dir.mkdir(parents=True, exist_ok=True) + model_path = model_dir / "kdtree_interpolator.pkl" + final_interpolator.save(model_path) + + LOGGER.info("") + LOGGER.info("=" * 80) + + log_file.close() + + +if __name__ == "__main__": + main() diff --git a/learning_munsell/interpolation/from_xyY/rbf_interpolator.py b/learning_munsell/interpolation/from_xyY/rbf_interpolator.py new file mode 100644 index 0000000000000000000000000000000000000000..b7b37e4db9f4f7ab10c5b44566d548ae2b649004 --- /dev/null +++ b/learning_munsell/interpolation/from_xyY/rbf_interpolator.py @@ -0,0 +1,300 @@ +""" +RBF (Radial Basis Function) interpolation for xyY to Munsell conversion. + +This approach uses scipy's RBFInterpolator to build a lookup table +with smooth interpolation between known color samples. + +Uses the 4,995 reference colors from MUNSELL_COLOURS_ALL directly. + +Advantages over ML: +- Deterministic, no training required +- Exact interpolation at known points +- Smooth interpolation between points +- Easy to understand and debug + +Disadvantages: +- Memory scales with number of data points +- Query time scales with data points (O(n) naive, can optimize) +- May struggle with extrapolation +""" + +import logging +import pickle +from pathlib import Path + +import numpy as np +from numpy.typing import NDArray +from scipy.interpolate import RBFInterpolator +from sklearn.model_selection import train_test_split + +from learning_munsell import PROJECT_ROOT, setup_logging +from learning_munsell.interpolation.from_xyY import load_munsell_reference_data + +logging.basicConfig(level=logging.INFO, format="%(message)s") +LOGGER = logging.getLogger(__name__) + + +class MunsellRBFInterpolator: + """ + RBF-based interpolator for xyY to Munsell conversion. + + Uses separate RBF interpolators for each Munsell component + (hue, value, chroma, code) to allow independent kernel tuning. + """ + + def __init__( + self, + kernel: str = "thin_plate_spline", + smoothing: float = 0.0, + epsilon: float | None = None, + ) -> None: + """ + Initialize the RBF interpolator. + + Parameters + ---------- + kernel + RBF kernel type. Options: 'linear', 'thin_plate_spline', + 'cubic', 'quintic', 'multiquadric', 'inverse_multiquadric', + 'inverse_quadratic', 'gaussian' + smoothing + Smoothing parameter. 0 = exact interpolation. + epsilon + Shape parameter for kernels that use it. + """ + self.kernel = kernel + self.smoothing = smoothing + self.epsilon = epsilon + + self.interpolators: dict[str, RBFInterpolator] = {} + self.fitted = False + + def fit(self, X: NDArray, y: NDArray) -> "MunsellRBFInterpolator": + """ + Fit RBF interpolators to the training data. + + Parameters + ---------- + X + xyY input values of shape (n, 3) + y + Munsell output values [hue, value, chroma, code] of shape (n, 4) + + Returns + ------- + self + """ + LOGGER.info("Fitting RBF interpolators...") + LOGGER.info(" Kernel: %s", self.kernel) + LOGGER.info(" Smoothing: %s", self.smoothing) + LOGGER.info(" Data points: %d", len(X)) + + component_names = ["hue", "value", "chroma", "code"] + + for i, name in enumerate(component_names): + LOGGER.info(" Building %s interpolator...", name) + + kwargs = { + "kernel": self.kernel, + "smoothing": self.smoothing, + } + if self.epsilon is not None: + kwargs["epsilon"] = self.epsilon + + self.interpolators[name] = RBFInterpolator(X, y[:, i], **kwargs) + + self.fitted = True + LOGGER.info("RBF interpolators fitted successfully") + + return self + + def predict(self, X: NDArray) -> NDArray: + """ + Predict Munsell values for given xyY inputs. + + Parameters + ---------- + X + xyY input values of shape (n, 3) + + Returns + ------- + NDArray + Predicted Munsell values [hue, value, chroma, code] of shape (n, 4) + """ + if not self.fitted: + msg = "Interpolator not fitted. Call fit() first." + raise RuntimeError(msg) + + results = np.zeros((len(X), 4)) + + for i, name in enumerate(["hue", "value", "chroma", "code"]): + results[:, i] = self.interpolators[name](X) + + return results + + def save(self, path: Path) -> None: + """Save the interpolator to disk.""" + with open(path, "wb") as f: + pickle.dump( + { + "kernel": self.kernel, + "smoothing": self.smoothing, + "epsilon": self.epsilon, + "interpolators": self.interpolators, + }, + f, + ) + LOGGER.info("Saved RBF interpolator to %s", path) + + @classmethod + def load(cls, path: Path) -> "MunsellRBFInterpolator": + """Load the interpolator from disk.""" + with open(path, "rb") as f: + data = pickle.load(f) # noqa: S301 + + instance = cls( + kernel=data["kernel"], + smoothing=data["smoothing"], + epsilon=data["epsilon"], + ) + instance.interpolators = data["interpolators"] + instance.fitted = True + + LOGGER.info("Loaded RBF interpolator from %s", path) + return instance + + +def evaluate_rbf( + interpolator: MunsellRBFInterpolator, + X: NDArray, + y: NDArray, + name: str = "Test", +) -> dict[str, float]: + """ + Evaluate RBF interpolator performance. + + Parameters + ---------- + interpolator + Fitted RBF interpolator + X + Input xyY values + y + Ground truth Munsell values + name + Name for logging + + Returns + ------- + dict + Dictionary of MAE values for each component + """ + predictions = interpolator.predict(X) + errors = np.abs(predictions - y) + + component_names = ["Hue", "Value", "Chroma", "Code"] + results = {} + + LOGGER.info("%s set MAE:", name) + for i, comp_name in enumerate(component_names): + mae = errors[:, i].mean() + results[comp_name.lower()] = mae + LOGGER.info(" %s: %.4f", comp_name, mae) + + return results + + +def main() -> None: + """Build and evaluate RBF interpolator using reference Munsell data.""" + + log_file = setup_logging("rbf_interpolator", "from_xyY") + + LOGGER.info("=" * 80) + LOGGER.info("RBF Interpolation for xyY to Munsell Conversion") + LOGGER.info("Using MUNSELL_COLOURS_ALL reference data (4,995 colors)") + LOGGER.info("=" * 80) + + # Load reference data from colour library + LOGGER.info("") + LOGGER.info("Loading reference Munsell data...") + X_all, y_all = load_munsell_reference_data() + LOGGER.info("Total reference colors: %d", len(X_all)) + + # Split into train/validation (80/20) + X_train, X_val, y_train, y_val = train_test_split( + X_all, y_all, test_size=0.2, random_state=42 + ) + + LOGGER.info("Train samples: %d", len(X_train)) + LOGGER.info("Validation samples: %d", len(X_val)) + + # Test different kernels + kernels_to_test = [ + ("thin_plate_spline", 0.0), + ("thin_plate_spline", 0.001), + ("thin_plate_spline", 0.01), + ("cubic", 0.0), + ("linear", 0.0), + ("multiquadric", 0.0), + ] + + best_kernel = None + best_smoothing = None + best_mae = float("inf") + + LOGGER.info("") + LOGGER.info("Testing different RBF kernels...") + LOGGER.info("-" * 60) + + for kernel, smoothing in kernels_to_test: + LOGGER.info("") + LOGGER.info("Kernel: %s, Smoothing: %s", kernel, smoothing) + + try: + interpolator = MunsellRBFInterpolator(kernel=kernel, smoothing=smoothing) + interpolator.fit(X_train, y_train) + + results = evaluate_rbf(interpolator, X_val, y_val, "Validation") + total_mae = sum(results.values()) + + if total_mae < best_mae: + best_mae = total_mae + best_kernel = kernel + best_smoothing = smoothing + + except Exception: + LOGGER.exception(" Failed") + + LOGGER.info("") + LOGGER.info("=" * 60) + LOGGER.info("Best configuration: %s with smoothing=%s", best_kernel, best_smoothing) + LOGGER.info("=" * 60) + + # Train final model with best kernel on ALL data + LOGGER.info("") + LOGGER.info("Training final model on all %d reference colors...", len(X_all)) + + final_interpolator = MunsellRBFInterpolator( + kernel=best_kernel, smoothing=best_smoothing + ) + final_interpolator.fit(X_all, y_all) + + LOGGER.info("") + LOGGER.info("Final evaluation (training set = all data):") + evaluate_rbf(final_interpolator, X_all, y_all, "All data") + + # Save the model + model_dir = PROJECT_ROOT / "models" / "from_xyY" + model_dir.mkdir(parents=True, exist_ok=True) + model_path = model_dir / "rbf_interpolator.pkl" + final_interpolator.save(model_path) + + LOGGER.info("") + LOGGER.info("=" * 80) + + log_file.close() + + +if __name__ == "__main__": + main() diff --git a/learning_munsell/losses/__init__.py b/learning_munsell/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..10afe671739158e83d2c1e14cd5a6e49c44950e6 --- /dev/null +++ b/learning_munsell/losses/__init__.py @@ -0,0 +1,17 @@ +"""Loss functions for Munsell ML training.""" + +from learning_munsell.losses.jax_delta_e import ( + XYZ_to_Lab, + delta_E_CIE2000, + delta_E_loss, + xyY_to_Lab, + xyY_to_XYZ, +) + +__all__ = [ + "delta_E_CIE2000", + "delta_E_loss", + "xyY_to_Lab", + "xyY_to_XYZ", + "XYZ_to_Lab", +] diff --git a/learning_munsell/losses/jax_delta_e.py b/learning_munsell/losses/jax_delta_e.py new file mode 100644 index 0000000000000000000000000000000000000000..6dadd3f90f425896b0ee1bda29ac978db0052e1f --- /dev/null +++ b/learning_munsell/losses/jax_delta_e.py @@ -0,0 +1,299 @@ +""" +Differentiable Delta-E Loss Functions using JAX +================================================ + +This module provides JAX implementations of color space conversions +and Delta-E (CIE2000) loss function for use in training. + +The key insight is that we can compute Delta-E between: +- The input xyY (which we convert to Lab as the "target") +- The predicted Munsell converted back to Lab + +For the Munsell -> xyY conversion, we either: +1. Use a pre-trained neural network approximator +2. Use differentiable interpolation on the Munsell Renotation data +""" + +from __future__ import annotations + +import colour +import jax +import jax.numpy as jnp +import numpy as np +from jax import Array + +# D65 illuminant XYZ reference values (standard for sRGB) +D65_XYZ = jnp.array([95.047, 100.0, 108.883]) + +# Illuminant C XYZ reference values (used by Munsell system) +ILLUMINANT_C_XYZ = jnp.array([98.074, 100.0, 118.232]) + + +def xyY_to_XYZ(xyY: Array, scale_Y: bool = True) -> Array: + """ + Convert CIE xyY to CIE XYZ. + + Parameters + ---------- + xyY : Array + CIE xyY values with shape (..., 3) + scale_Y : bool + If True, scale Y from 0-1 to 0-100 range (required for Lab conversion) + + Returns + ------- + Array + CIE XYZ values with shape (..., 3) + """ + x = xyY[..., 0] + y = xyY[..., 1] + Y = xyY[..., 2] + + # Scale Y to 0-100 range if needed (colour library uses 0-100) + if scale_Y: + Y = Y * 100.0 + + # Avoid division by zero + y_safe = jnp.where(y == 0, 1e-10, y) + + X = (x * Y) / y_safe + Z = ((1 - x - y) * Y) / y_safe + + # Handle y=0 case (set X=Z=0) + X = jnp.where(y == 0, 0.0, X) + Z = jnp.where(y == 0, 0.0, Z) + + return jnp.stack([X, Y, Z], axis=-1) + + +def XYZ_to_Lab(XYZ: Array, illuminant: Array = ILLUMINANT_C_XYZ) -> Array: + """ + Convert CIE XYZ to CIE Lab. + + Parameters + ---------- + XYZ : Array + CIE XYZ values with shape (..., 3) + illuminant : Array + Reference white XYZ values + + Returns + ------- + Array + CIE Lab values with shape (..., 3) + """ + # Normalize by illuminant + XYZ_n = XYZ / illuminant + + # CIE Lab transfer function + delta = 6.0 / 29.0 + delta_cube = delta**3 + + # f(t) = t^(1/3) if t > delta^3, else t/(3*delta^2) + 4/29 + def f(t: Array) -> Array: + return jnp.where(t > delta_cube, jnp.cbrt(t), t / (3 * delta**2) + 4.0 / 29.0) + + f_X = f(XYZ_n[..., 0]) + f_Y = f(XYZ_n[..., 1]) + f_Z = f(XYZ_n[..., 2]) + + L = 116.0 * f_Y - 16.0 + a = 500.0 * (f_X - f_Y) + b = 200.0 * (f_Y - f_Z) + + return jnp.stack([L, a, b], axis=-1) + + +def xyY_to_Lab(xyY: Array, illuminant: Array = ILLUMINANT_C_XYZ) -> Array: + """Convert CIE xyY directly to CIE Lab.""" + return XYZ_to_Lab(xyY_to_XYZ(xyY), illuminant) + + +def delta_E_CIE2000(Lab_1: Array, Lab_2: Array) -> Array: + """ + Compute CIE 2000 Delta-E color difference. + + This is a differentiable JAX implementation of the CIE 2000 Delta-E formula. + + Parameters + ---------- + Lab_1 : Array + First CIE Lab color(s) with shape (..., 3) + Lab_2 : Array + Second CIE Lab color(s) with shape (..., 3) + + Returns + ------- + Array + Delta-E values with shape (...) + """ + L_1, a_1, b_1 = Lab_1[..., 0], Lab_1[..., 1], Lab_1[..., 2] + L_2, a_2, b_2 = Lab_2[..., 0], Lab_2[..., 1], Lab_2[..., 2] + + # Chroma + C_1_ab = jnp.sqrt(a_1**2 + b_1**2) + C_2_ab = jnp.sqrt(a_2**2 + b_2**2) + + C_bar_ab = (C_1_ab + C_2_ab) / 2 + C_bar_ab_7 = C_bar_ab**7 + + # G factor for a' adjustment (25^7 = 6103515625.0) + G = 0.5 * (1 - jnp.sqrt(C_bar_ab_7 / (C_bar_ab_7 + 6103515625.0))) + + # Adjusted a' + a_p_1 = (1 + G) * a_1 + a_p_2 = (1 + G) * a_2 + + # Adjusted chroma C' + C_p_1 = jnp.sqrt(a_p_1**2 + b_1**2) + C_p_2 = jnp.sqrt(a_p_2**2 + b_2**2) + + # Hue angle h' (in degrees) + h_p_1 = jnp.degrees(jnp.arctan2(b_1, a_p_1)) % 360 + h_p_2 = jnp.degrees(jnp.arctan2(b_2, a_p_2)) % 360 + + # Handle achromatic case + h_p_1 = jnp.where((b_1 == 0) & (a_p_1 == 0), 0.0, h_p_1) + h_p_2 = jnp.where((b_2 == 0) & (a_p_2 == 0), 0.0, h_p_2) + + # Delta L', C' + delta_L_p = L_2 - L_1 + delta_C_p = C_p_2 - C_p_1 + + # Delta h' + h_p_diff = h_p_2 - h_p_1 + C_p_product = C_p_1 * C_p_2 + + delta_h_p = jnp.where( + C_p_product == 0, + 0.0, + jnp.where( + jnp.abs(h_p_diff) <= 180, + h_p_diff, + jnp.where(h_p_diff > 180, h_p_diff - 360, h_p_diff + 360), + ), + ) + + # Delta H' + delta_H_p = 2 * jnp.sqrt(C_p_product) * jnp.sin(jnp.radians(delta_h_p / 2)) + + # Mean L', C' + L_bar_p = (L_1 + L_2) / 2 + C_bar_p = (C_p_1 + C_p_2) / 2 + + # Mean h' + h_p_sum = h_p_1 + h_p_2 + h_p_abs_diff = jnp.abs(h_p_1 - h_p_2) + + h_bar_p = jnp.where( + C_p_product == 0, + h_p_sum, + jnp.where( + h_p_abs_diff <= 180, + h_p_sum / 2, + jnp.where(h_p_sum < 360, (h_p_sum + 360) / 2, (h_p_sum - 360) / 2), + ), + ) + + # T factor + T = ( + 1 + - 0.17 * jnp.cos(jnp.radians(h_bar_p - 30)) + + 0.24 * jnp.cos(jnp.radians(2 * h_bar_p)) + + 0.32 * jnp.cos(jnp.radians(3 * h_bar_p + 6)) + - 0.20 * jnp.cos(jnp.radians(4 * h_bar_p - 63)) + ) + + # Delta theta + delta_theta = 30 * jnp.exp(-(((h_bar_p - 275) / 25) ** 2)) + + # R_C (25^7 = 6103515625.0) + C_bar_p_7 = C_bar_p**7 + R_C = 2 * jnp.sqrt(C_bar_p_7 / (C_bar_p_7 + 6103515625.0)) + + # S_L, S_C, S_H + L_bar_p_minus_50_sq = (L_bar_p - 50) ** 2 + S_L = 1 + (0.015 * L_bar_p_minus_50_sq) / jnp.sqrt(20 + L_bar_p_minus_50_sq) + S_C = 1 + 0.045 * C_bar_p + S_H = 1 + 0.015 * C_bar_p * T + + # R_T + R_T = -jnp.sin(jnp.radians(2 * delta_theta)) * R_C + + # Final Delta E + k_L, k_C, k_H = 1.0, 1.0, 1.0 + + term_L = delta_L_p / (k_L * S_L) + term_C = delta_C_p / (k_C * S_C) + term_H = delta_H_p / (k_H * S_H) + + return jnp.sqrt(term_L**2 + term_C**2 + term_H**2 + R_T * term_C * term_H) + + +def delta_E_loss(pred_xyY: Array, target_xyY: Array) -> Array: + """ + Compute mean Delta-E loss between predicted and target xyY values. + + This is the primary loss function for training with perceptual accuracy. + + Parameters + ---------- + pred_xyY : Array + Predicted xyY values with shape (batch, 3) + target_xyY : Array + Target xyY values with shape (batch, 3) + + Returns + ------- + Array + Scalar mean Delta-E loss + """ + pred_Lab = xyY_to_Lab(pred_xyY) + target_Lab = xyY_to_Lab(target_xyY) + return jnp.mean(delta_E_CIE2000(pred_Lab, target_Lab)) + + +# JIT-compiled versions for performance +xyY_to_XYZ_jit = jax.jit(xyY_to_XYZ) +XYZ_to_Lab_jit = jax.jit(XYZ_to_Lab) +xyY_to_Lab_jit = jax.jit(xyY_to_Lab) +delta_E_CIE2000_jit = jax.jit(delta_E_CIE2000) +delta_E_loss_jit = jax.jit(delta_E_loss) + +# Gradient functions +grad_delta_E_loss = jax.grad(delta_E_loss) + + +def test_jax_delta_e() -> None: + """Test the JAX Delta-E implementation against colour library.""" + # Test xyY values + xyY_1 = np.array([0.3127, 0.3290, 0.5]) # D65 white point, Y=0.5 + xyY_2 = np.array([0.35, 0.35, 0.5]) # Slightly shifted + + # Convert using JAX + Lab_1_jax = xyY_to_Lab(jnp.array(xyY_1)) + Lab_2_jax = xyY_to_Lab(jnp.array(xyY_2)) + delta_E_CIE2000(Lab_1_jax, Lab_2_jax) + + # Convert using colour library + XYZ_1 = colour.xyY_to_XYZ(xyY_1) + XYZ_2 = colour.xyY_to_XYZ(xyY_2) + Lab_1_colour = colour.XYZ_to_Lab( + XYZ_1, colour.CCS_ILLUMINANTS["CIE 1931 2 Degree Standard Observer"]["C"] + ) + Lab_2_colour = colour.XYZ_to_Lab( + XYZ_2, colour.CCS_ILLUMINANTS["CIE 1931 2 Degree Standard Observer"]["C"] + ) + colour.delta_E(Lab_1_colour, Lab_2_colour, method="CIE 2000") + + # Test gradient computation + pred_xyY = jnp.array([[0.35, 0.35, 0.5]]) + target_xyY = jnp.array([[0.3127, 0.3290, 0.5]]) + + # Compute gradient + grad_fn = jax.grad(lambda x: delta_E_loss(x, target_xyY)) + grad_fn(pred_xyY) + + +if __name__ == "__main__": + test_jax_delta_e() diff --git a/learning_munsell/models/__init__.py b/learning_munsell/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..554bc0835c77e2cc6b3d8dad57a78028f9553111 --- /dev/null +++ b/learning_munsell/models/__init__.py @@ -0,0 +1,47 @@ +"""Neural network models for Munsell color conversions.""" + +from learning_munsell.models.networks import ( + # Building blocks + ResidualBlock, + # Component networks + ComponentMLP, + ComponentErrorPredictor, + # Transformer building blocks + FeatureTokenizer, + TransformerBlock, + # Composite models: xyY → Munsell + MLPToMunsell, + MultiHeadMLPToMunsell, + MultiMLPToMunsell, + TransformerToMunsell, + # Error predictors: xyY → Munsell + MultiHeadErrorPredictorToMunsell, + MultiMLPErrorPredictorToMunsell, + # Composite models: Munsell → xyY + MultiMLPToxyY, + # Error predictors: Munsell → xyY + MultiMLPErrorPredictorToxyY, +) + +__all__ = [ + # Building blocks + "ResidualBlock", + # Component networks (single output) + "ComponentMLP", + "ComponentErrorPredictor", + # Transformer building blocks + "FeatureTokenizer", + "TransformerBlock", + # Composite models: xyY → Munsell + "MLPToMunsell", + "MultiHeadMLPToMunsell", + "MultiMLPToMunsell", + "TransformerToMunsell", + # Error predictors: xyY → Munsell + "MultiHeadErrorPredictorToMunsell", + "MultiMLPErrorPredictorToMunsell", + # Composite models: Munsell → xyY + "MultiMLPToxyY", + # Error predictors: Munsell → xyY + "MultiMLPErrorPredictorToxyY", +] diff --git a/learning_munsell/models/networks.py b/learning_munsell/models/networks.py new file mode 100644 index 0000000000000000000000000000000000000000..8ce17e9b32e7c5075cf7a011d7f0613ca25c284c --- /dev/null +++ b/learning_munsell/models/networks.py @@ -0,0 +1,1294 @@ +""" +Reusable neural network building blocks. + +Provides shared network architectures for training scripts, +including MLP components and error predictors. +""" + +from __future__ import annotations + +import torch +from torch import nn, Tensor + +__all__ = [ + # Building blocks + "ResidualBlock", + # Component networks (single output) + "ComponentMLP", + "ComponentResNet", + "ComponentErrorPredictor", + # Transformer building blocks + "FeatureTokenizer", + "TransformerBlock", + # Composite models: xyY → Munsell + "MLPToMunsell", + "MultiHeadMLPToMunsell", + "MultiMLPToMunsell", + "MultiResNetToMunsell", + "TransformerToMunsell", + # Error predictors: xyY → Munsell + "MultiHeadErrorPredictorToMunsell", + "MultiMLPErrorPredictorToMunsell", + "MultiResNetErrorPredictorToMunsell", + # Composite models: Munsell → xyY + "MultiMLPToxyY", + # Error predictors: Munsell → xyY + "MultiMLPErrorPredictorToxyY", +] + + +# ============================================================================= +# Building Blocks +# ============================================================================= + + +class ResidualBlock(nn.Module): + """ + Residual block with GELU activation and batch normalization. + + Architecture: input → Linear → GELU → BatchNorm → Linear → BatchNorm → add input → GELU + + Parameters + ---------- + dim : int + Dimension of input and output features. + + Attributes + ---------- + block : nn.Sequential + Sequential block with linear layers, GELU, and BatchNorm. + activation : nn.GELU + Final activation after residual addition. + """ + + def __init__(self, dim: int) -> None: + """Initialize residual block.""" + super().__init__() + self.block = nn.Sequential( + nn.Linear(dim, dim), + nn.GELU(), + nn.BatchNorm1d(dim), + nn.Linear(dim, dim), + nn.BatchNorm1d(dim), + ) + self.activation = nn.GELU() + + def forward(self, x: Tensor) -> Tensor: + """ + Forward pass with residual connection. + + Parameters + ---------- + x : Tensor + Input tensor of shape (batch_size, dim). + + Returns + ------- + Tensor + Output tensor of shape (batch_size, dim). + """ + return self.activation(x + self.block(x)) + + +# ============================================================================= +# Component Networks (Single Output) +# ============================================================================= + + +class ComponentMLP(nn.Module): + """ + Independent MLP for a single Munsell component. + + Architecture: input_dim → 128 → 256 → 512 → 256 → 128 → 1 + + Parameters + ---------- + input_dim : int, optional + Input feature dimension. Default is 3 (for xyY). + width_multiplier : float, optional + Multiplier for hidden layer dimensions. Default is 1.0. + dropout : float, optional + Dropout probability between layers. Default is 0.0. + + Attributes + ---------- + network : nn.Sequential + Feed-forward network with encoder-decoder structure. + + Notes + ----- + Uses ReLU activations and batch normalization. The encoder-decoder + architecture expands to 512-dim (or scaled by width_multiplier) and + then contracts back to a single output. Optional dropout can be + applied between layers for regularization. + """ + + def __init__( + self, + input_dim: int = 3, + width_multiplier: float = 1.0, + dropout: float = 0.0, + ) -> None: + """Initialize the component-specific MLP.""" + super().__init__() + + # Scale hidden dimensions + h1 = int(128 * width_multiplier) + h2 = int(256 * width_multiplier) + h3 = int(512 * width_multiplier) + + layers: list[nn.Module] = [ + # Encoder + nn.Linear(input_dim, h1), + nn.ReLU(), + nn.BatchNorm1d(h1), + ] + + if dropout > 0: + layers.append(nn.Dropout(dropout)) + + layers.extend( + [ + nn.Linear(h1, h2), + nn.ReLU(), + nn.BatchNorm1d(h2), + ] + ) + + if dropout > 0: + layers.append(nn.Dropout(dropout)) + + layers.extend( + [ + nn.Linear(h2, h3), + nn.ReLU(), + nn.BatchNorm1d(h3), + ] + ) + + if dropout > 0: + layers.append(nn.Dropout(dropout)) + + layers.extend( + [ + # Decoder + nn.Linear(h3, h2), + nn.ReLU(), + nn.BatchNorm1d(h2), + ] + ) + + if dropout > 0: + layers.append(nn.Dropout(dropout)) + + layers.extend( + [ + nn.Linear(h2, h1), + nn.ReLU(), + nn.BatchNorm1d(h1), + # Output + nn.Linear(h1, 1), + ] + ) + + self.network = nn.Sequential(*layers) + + def forward(self, x: Tensor) -> Tensor: + """ + Forward pass through the component-specific network. + + Parameters + ---------- + x : Tensor + Input tensor of shape (batch_size, input_dim). + + Returns + ------- + Tensor + Output tensor of shape (batch_size, 1) containing the predicted + component value. + """ + return self.network(x) + + +class ComponentResNet(nn.Module): + """ + Independent ResNet for a single Munsell component with true skip connections. + + Architecture: input → projection → ResidualBlock × num_blocks → output + + Unlike ComponentMLP, this uses actual residual blocks where: + output = activation(x + f(x)) + + Parameters + ---------- + input_dim : int, optional + Input feature dimension. Default is 3 (for xyY). + hidden_dim : int, optional + Hidden dimension for residual blocks. Default is 256. + num_blocks : int, optional + Number of residual blocks. Default is 4. + + Attributes + ---------- + input_proj : nn.Sequential + Projects input to hidden dimension with GELU activation. + res_blocks : nn.ModuleList + List of ResidualBlock modules with skip connections. + output_proj : nn.Linear + Projects hidden dimension to single output. + """ + + def __init__( + self, + input_dim: int = 3, + hidden_dim: int = 256, + num_blocks: int = 4, + ) -> None: + """Initialize the component-specific ResNet.""" + super().__init__() + + # Project input to hidden dimension + self.input_proj = nn.Sequential( + nn.Linear(input_dim, hidden_dim), + nn.GELU(), + ) + + # Stack of residual blocks with skip connections + self.res_blocks = nn.ModuleList( + [ResidualBlock(hidden_dim) for _ in range(num_blocks)] + ) + + # Project to output + self.output_proj = nn.Linear(hidden_dim, 1) + + def forward(self, x: Tensor) -> Tensor: + """ + Forward pass through the ResNet with skip connections. + + Parameters + ---------- + x : Tensor + Input tensor of shape (batch_size, input_dim). + + Returns + ------- + Tensor + Output tensor of shape (batch_size, 1). + """ + x = self.input_proj(x) + for block in self.res_blocks: + x = block(x) # Each block applies: activation(x + f(x)) + return self.output_proj(x) + + +class ComponentErrorPredictor(nn.Module): + """ + Independent error predictor for a single Munsell component. + + A deep MLP that learns to predict residual errors for one Munsell + component (hue, value, chroma, or code). + + Parameters + ---------- + input_dim : int, optional + Input feature dimension. Default is 7 (xyY_norm + base_pred_norm). + width_multiplier : float, optional + Multiplier for hidden layer widths. Default is 1.0. + Use 1.5 for chroma which requires more capacity. + + Attributes + ---------- + network : nn.Sequential + Feed-forward network: input → 128 → 256 → 512 → 256 → 128 → 1 + with GELU activations and BatchNorm after each hidden layer. + + Notes + ----- + Default input is [xyY_norm (3) + base_pred_norm (4)] = 7 features. + Output is a single scalar error correction for the component. + """ + + def __init__( + self, + input_dim: int = 7, + width_multiplier: float = 1.0, + ) -> None: + """Initialize the error predictor.""" + super().__init__() + + # Scale hidden dimensions + h1 = int(128 * width_multiplier) + h2 = int(256 * width_multiplier) + h3 = int(512 * width_multiplier) + + self.network = nn.Sequential( + # Encoder + nn.Linear(input_dim, h1), + nn.GELU(), + nn.BatchNorm1d(h1), + nn.Linear(h1, h2), + nn.GELU(), + nn.BatchNorm1d(h2), + nn.Linear(h2, h3), + nn.GELU(), + nn.BatchNorm1d(h3), + # Decoder + nn.Linear(h3, h2), + nn.GELU(), + nn.BatchNorm1d(h2), + nn.Linear(h2, h1), + nn.GELU(), + nn.BatchNorm1d(h1), + # Output + nn.Linear(h1, 1), + ) + + def forward(self, x: Tensor) -> Tensor: + """ + Forward pass through the error predictor. + + Parameters + ---------- + x : Tensor + Combined input of shape (batch_size, input_dim). + + Returns + ------- + Tensor + Predicted error correction of shape (batch_size, 1). + """ + return self.network(x) + + +# ============================================================================= +# Transformer Building Blocks +# ============================================================================= + + +class FeatureTokenizer(nn.Module): + """ + Tokenize each input feature into high-dimensional embedding. + + Converts each scalar input feature into a learned embedding vector, + similar to word embeddings in NLP. Also prepends a learnable CLS token + used for regression output. + + Parameters + ---------- + num_features : int + Number of input features to tokenize. + embedding_dim : int + Dimensionality of each token embedding. + + Attributes + ---------- + feature_embeddings : nn.ModuleList + List of linear layers, one per input feature. + cls_token : nn.Parameter + Learnable classification token prepended to feature tokens. + """ + + def __init__(self, num_features: int, embedding_dim: int) -> None: + """Initialize the feature tokenizer.""" + super().__init__() + # Each feature gets its own embedding + self.feature_embeddings = nn.ModuleList( + [nn.Linear(1, embedding_dim) for _ in range(num_features)] + ) + # CLS token for regression + self.cls_token = nn.Parameter(torch.randn(1, 1, embedding_dim)) + + def forward(self, x: Tensor) -> Tensor: + """ + Transform input features into token embeddings. + + Parameters + ---------- + x : Tensor + Input tensor of shape (batch_size, num_features). + + Returns + ------- + Tensor + Token embeddings of shape (batch_size, 1+num_features, embedding_dim). + First token is CLS, followed by feature tokens. + """ + batch_size = x.size(0) + + # Tokenize each feature + tokens = [] + for i, embedding in enumerate(self.feature_embeddings): + feature_val = x[:, i : i + 1] # (batch_size, 1) + token = embedding(feature_val) # (batch_size, embedding_dim) + tokens.append(token.unsqueeze(1)) # (batch_size, 1, embedding_dim) + + # Concatenate feature tokens + feature_tokens = torch.cat( + tokens, dim=1 + ) # (batch_size, num_features, embedding_dim) + + # Prepend CLS token + cls_tokens = self.cls_token.expand( + batch_size, -1, -1 + ) # (batch_size, 1, embedding_dim) + return torch.cat( + [cls_tokens, feature_tokens], dim=1 + ) # (batch_size, 1+num_features, embedding_dim) + + +class TransformerBlock(nn.Module): + """ + Standard transformer block with multi-head attention and feedforward network. + + Implements the classic transformer architecture with self-attention, + feedforward layers, layer normalization, and residual connections. + + Parameters + ---------- + embedding_dim : int + Dimension of token embeddings. + num_heads : int + Number of attention heads. + ff_dim : int + Hidden dimension of feedforward network. + dropout : float, optional + Dropout probability, default is 0.1. + + Attributes + ---------- + attention : nn.MultiheadAttention + Multi-head self-attention mechanism. + norm1 : nn.LayerNorm + Layer normalization after attention. + feedforward : nn.Sequential + Feedforward network with GELU activation. + norm2 : nn.LayerNorm + Layer normalization after feedforward. + """ + + def __init__( + self, embedding_dim: int, num_heads: int, ff_dim: int, dropout: float = 0.1 + ) -> None: + """Initialize the transformer block.""" + super().__init__() + + self.attention = nn.MultiheadAttention( + embedding_dim, num_heads, dropout=dropout, batch_first=True + ) + self.norm1 = nn.LayerNorm(embedding_dim) + + self.feedforward = nn.Sequential( + nn.Linear(embedding_dim, ff_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(ff_dim, embedding_dim), + nn.Dropout(dropout), + ) + self.norm2 = nn.LayerNorm(embedding_dim) + + def forward(self, x: Tensor) -> Tensor: + """ + Apply transformer block to input tokens. + + Parameters + ---------- + x : Tensor + Input tokens of shape (batch_size, num_tokens, embedding_dim). + + Returns + ------- + Tensor + Transformed tokens of shape (batch_size, num_tokens, embedding_dim). + """ + # Self-attention with residual + attn_output, _ = self.attention(x, x, x) + x = self.norm1(x + attn_output) + + # Feedforward with residual + ff_output = self.feedforward(x) + return self.norm2(x + ff_output) + + +# ============================================================================= +# Composite Models: xyY → Munsell +# ============================================================================= + + +class MLPToMunsell(nn.Module): + """ + Large MLP for xyY to Munsell conversion. + + Architecture: 3 → 128 → 256 → 512 → 512 → 256 → 128 → 4 + + Attributes + ---------- + network : nn.Sequential + Feed-forward network with ReLU activations and BatchNorm. + """ + + def __init__(self) -> None: + """Initialize the MunsellMLP network.""" + super().__init__() + + self.network = nn.Sequential( + nn.Linear(3, 128), + nn.ReLU(), + nn.BatchNorm1d(128), + nn.Linear(128, 256), + nn.ReLU(), + nn.BatchNorm1d(256), + nn.Linear(256, 512), + nn.ReLU(), + nn.BatchNorm1d(512), + nn.Linear(512, 512), + nn.ReLU(), + nn.BatchNorm1d(512), + nn.Linear(512, 256), + nn.ReLU(), + nn.BatchNorm1d(256), + nn.Linear(256, 128), + nn.ReLU(), + nn.BatchNorm1d(128), + nn.Linear(128, 4), + ) + + def forward(self, x: Tensor) -> Tensor: + """ + Forward pass through the network. + + Parameters + ---------- + x : Tensor + Input tensor of shape (batch_size, 3) containing normalized xyY values. + + Returns + ------- + Tensor + Output tensor of shape (batch_size, 4) containing normalized Munsell + specifications [hue, value, chroma, code]. + """ + return self.network(x) + + +class MultiHeadMLPToMunsell(nn.Module): + """ + Multi-head MLP for xyY to Munsell conversion. + + Each component (hue, value, chroma, code) has a specialized decoder head + after a shared encoder. The chroma head is wider to handle the more complex + non-linear relationship between xyY and chroma. + + Attributes + ---------- + encoder : nn.Sequential + Shared encoder: 3 → 128 → 256 → 512 with ReLU and BatchNorm. + hue_head : nn.Sequential + Hue decoder: 512 → 256 → 128 → 1 (circular component). + value_head : nn.Sequential + Value decoder: 512 → 256 → 128 → 1 (linear component). + chroma_head : nn.Sequential + Chroma decoder: 512 → 384 → 256 → 128 → 1 (wider for complexity). + code_head : nn.Sequential + Code decoder: 512 → 256 → 128 → 1 (discrete component). + + Notes + ----- + The chroma head has increased capacity (384 units in first layer) to handle + the more complex non-linear relationship between xyY and chroma. + """ + + def __init__(self) -> None: + """Initialize the multi-head MLP model.""" + super().__init__() + + # Shared encoder - learns general color space features + self.encoder = nn.Sequential( + nn.Linear(3, 128), + nn.ReLU(), + nn.BatchNorm1d(128), + nn.Linear(128, 256), + nn.ReLU(), + nn.BatchNorm1d(256), + nn.Linear(256, 512), + nn.ReLU(), + nn.BatchNorm1d(512), + ) + + # Hue head - circular/angular component + self.hue_head = nn.Sequential( + nn.Linear(512, 256), + nn.ReLU(), + nn.BatchNorm1d(256), + nn.Linear(256, 128), + nn.ReLU(), + nn.BatchNorm1d(128), + nn.Linear(128, 1), + ) + + # Value head - linear lightness + self.value_head = nn.Sequential( + nn.Linear(512, 256), + nn.ReLU(), + nn.BatchNorm1d(256), + nn.Linear(256, 128), + nn.ReLU(), + nn.BatchNorm1d(128), + nn.Linear(128, 1), + ) + + # Chroma head - non-linear saturation (WIDER for harder task) + self.chroma_head = nn.Sequential( + nn.Linear(512, 384), # Wider than other heads + nn.ReLU(), + nn.BatchNorm1d(384), + nn.Linear(384, 256), + nn.ReLU(), + nn.BatchNorm1d(256), + nn.Linear(256, 128), + nn.ReLU(), + nn.BatchNorm1d(128), + nn.Linear(128, 1), + ) + + # Code head - discrete categorical + self.code_head = nn.Sequential( + nn.Linear(512, 256), + nn.ReLU(), + nn.BatchNorm1d(256), + nn.Linear(256, 128), + nn.ReLU(), + nn.BatchNorm1d(128), + nn.Linear(128, 1), + ) + + def forward(self, x: Tensor) -> Tensor: + """ + Forward pass through the multi-head network. + + Parameters + ---------- + x : Tensor + Input xyY values of shape (batch_size, 3). + + Returns + ------- + Tensor + Concatenated Munsell predictions [hue, value, chroma, code] + of shape (batch_size, 4). + """ + # Shared feature extraction + features = self.encoder(x) + + # Component-specific predictions + hue = self.hue_head(features) + value = self.value_head(features) + chroma = self.chroma_head(features) + code = self.code_head(features) + + # Concatenate: [Hue, Value, Chroma, Code] + return torch.cat([hue, value, chroma, code], dim=1) + + +class MultiMLPToMunsell(nn.Module): + """ + Multi-MLP for xyY to Munsell conversion. + + Uses 4 independent ComponentMLP branches, one for each Munsell component. + The chroma branch can be wider to handle the more complex relationship. + + Parameters + ---------- + chroma_width_multiplier : float, optional + Width multiplier for the chroma branch. Default is 2.0. + dropout : float, optional + Dropout probability for all branches. Default is 0.1. + + Attributes + ---------- + hue_branch : ComponentMLP + MLP for hue component (1.0x width). + value_branch : ComponentMLP + MLP for value component (1.0x width). + chroma_branch : ComponentMLP + MLP for chroma component (configurable width). + code_branch : ComponentMLP + MLP for hue code component (1.0x width). + """ + + def __init__( + self, chroma_width_multiplier: float = 2.0, dropout: float = 0.1 + ) -> None: + """Initialize the multi-branch MLP model.""" + super().__init__() + + self.hue_branch = ComponentMLP( + input_dim=3, width_multiplier=1.0, dropout=dropout + ) + self.value_branch = ComponentMLP( + input_dim=3, width_multiplier=1.0, dropout=dropout + ) + self.chroma_branch = ComponentMLP( + input_dim=3, width_multiplier=chroma_width_multiplier, dropout=dropout + ) + self.code_branch = ComponentMLP( + input_dim=3, width_multiplier=1.0, dropout=dropout + ) + + def forward(self, x: Tensor) -> Tensor: + """ + Forward pass through all 4 independent branches. + + Parameters + ---------- + x : Tensor + Input tensor of shape (batch_size, 3) containing normalized xyY values. + + Returns + ------- + Tensor + Concatenated predictions [hue, value, chroma, code] + of shape (batch_size, 4). + """ + hue = self.hue_branch(x) + value = self.value_branch(x) + chroma = self.chroma_branch(x) + code = self.code_branch(x) + return torch.cat([hue, value, chroma, code], dim=1) + + +class MultiResNetToMunsell(nn.Module): + """ + Multi-ResNet for xyY to Munsell conversion with true skip connections. + + Uses 4 independent ComponentResNet branches, one for each Munsell component. + Each branch contains actual residual blocks with skip connections. + + Parameters + ---------- + hidden_dim : int, optional + Hidden dimension for residual blocks. Default is 256. + num_blocks : int, optional + Number of residual blocks per branch. Default is 4. + chroma_hidden_dim : int, optional + Hidden dimension for chroma branch (typically larger). Default is 512. + + Attributes + ---------- + hue_branch : ComponentResNet + ResNet for hue component. + value_branch : ComponentResNet + ResNet for value component. + chroma_branch : ComponentResNet + ResNet for chroma component (larger hidden dim). + code_branch : ComponentResNet + ResNet for hue code component. + """ + + def __init__( + self, + hidden_dim: int = 256, + num_blocks: int = 4, + chroma_hidden_dim: int = 512, + ) -> None: + """Initialize the multi-branch ResNet model.""" + super().__init__() + + self.hue_branch = ComponentResNet( + input_dim=3, hidden_dim=hidden_dim, num_blocks=num_blocks + ) + self.value_branch = ComponentResNet( + input_dim=3, hidden_dim=hidden_dim, num_blocks=num_blocks + ) + self.chroma_branch = ComponentResNet( + input_dim=3, hidden_dim=chroma_hidden_dim, num_blocks=num_blocks + ) + self.code_branch = ComponentResNet( + input_dim=3, hidden_dim=hidden_dim, num_blocks=num_blocks + ) + + def forward(self, x: Tensor) -> Tensor: + """ + Forward pass through all 4 independent ResNet branches. + + Parameters + ---------- + x : Tensor + Input tensor of shape (batch_size, 3) containing normalized xyY values. + + Returns + ------- + Tensor + Concatenated predictions [hue, value, chroma, code] + of shape (batch_size, 4). + """ + hue = self.hue_branch(x) + value = self.value_branch(x) + chroma = self.chroma_branch(x) + code = self.code_branch(x) + return torch.cat([hue, value, chroma, code], dim=1) + + +class TransformerToMunsell(nn.Module): + """ + Transformer for xyY to Munsell conversion. + + Uses a feature tokenizer to convert input features to embeddings, + followed by transformer blocks with self-attention, and separate + output heads for each Munsell component. + + Parameters + ---------- + num_features : int, optional + Number of input features (default is 3 for xyY). + embedding_dim : int, optional + Dimension of token embeddings (default is 256). + num_blocks : int, optional + Number of transformer blocks (default is 6). + num_heads : int, optional + Number of attention heads (default is 8). + ff_dim : int, optional + Feedforward network hidden dimension (default is 1024). + dropout : float, optional + Dropout probability (default is 0.1). + + Attributes + ---------- + tokenizer : FeatureTokenizer + Converts input features to token embeddings with CLS token. + transformer_blocks : nn.ModuleList + Stack of transformer blocks with self-attention. + final_norm : nn.LayerNorm + Final layer normalization before output heads. + hue_head : nn.Sequential + Output head for hue prediction. + value_head : nn.Sequential + Output head for value prediction. + chroma_head : nn.Sequential + Deeper output head for chroma prediction. + code_head : nn.Sequential + Output head for hue code prediction. + + Notes + ----- + Architecture: 3 xyY features → 3 tokens + 1 CLS token → transformer blocks + with self-attention → multi-head output with specialized component heads. + The chroma head has additional depth due to prediction difficulty. + """ + + def __init__( + self, + num_features: int = 3, + embedding_dim: int = 256, + num_blocks: int = 6, + num_heads: int = 8, + ff_dim: int = 1024, + dropout: float = 0.1, + ) -> None: + """Initialize the transformer model.""" + super().__init__() + + self.tokenizer = FeatureTokenizer(num_features, embedding_dim) + + self.transformer_blocks = nn.ModuleList( + [ + TransformerBlock(embedding_dim, num_heads, ff_dim, dropout) + for _ in range(num_blocks) + ] + ) + + self.final_norm = nn.LayerNorm(embedding_dim) + + # Multi-head output - separate heads for each Munsell component + self.hue_head = nn.Sequential( + nn.Linear(embedding_dim, 128), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(128, 1), + ) + self.value_head = nn.Sequential( + nn.Linear(embedding_dim, 128), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(128, 1), + ) + self.chroma_head = nn.Sequential( + nn.Linear(embedding_dim, 256), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(256, 128), + nn.GELU(), + nn.Linear(128, 1), + ) + self.code_head = nn.Sequential( + nn.Linear(embedding_dim, 128), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(128, 1), + ) + + def forward(self, x: Tensor) -> Tensor: + """ + Forward pass through the transformer. + + Parameters + ---------- + x : Tensor + Input xyY values of shape (batch_size, 3). + + Returns + ------- + Tensor + Predicted Munsell specification [hue, value, chroma, code] + of shape (batch_size, 4). + + Notes + ----- + The CLS token representation is used for the final prediction through + separate task-specific heads for each Munsell component. + """ + tokens = self.tokenizer(x) + + for block in self.transformer_blocks: + tokens = block(tokens) + + tokens = self.final_norm(tokens) + cls_token = tokens[:, 0, :] + + hue = self.hue_head(cls_token) + value = self.value_head(cls_token) + chroma = self.chroma_head(cls_token) + code = self.code_head(cls_token) + + return torch.cat([hue, value, chroma, code], dim=1) + + +# ============================================================================= +# Error Predictors: xyY → Munsell +# ============================================================================= + + +class MultiHeadErrorPredictorToMunsell(nn.Module): + """ + Multi-Head error predictor for xyY to Munsell conversion. + + Each branch is a ComponentErrorPredictor specialized for one + Munsell component. The chroma branch is wider (1.5x) to handle + the more complex error patterns in chroma prediction. + + Parameters + ---------- + input_dim : int, optional + Input feature dimension. Default is 7. + chroma_width : float, optional + Width multiplier for chroma branch. Default is 1.5. + + Attributes + ---------- + hue_branch : ComponentErrorPredictor + Error predictor for hue component (1.0x width). + value_branch : ComponentErrorPredictor + Error predictor for value component (1.0x width). + chroma_branch : ComponentErrorPredictor + Error predictor for chroma component (1.5x width by default). + code_branch : ComponentErrorPredictor + Error predictor for hue code component (1.0x width). + """ + + def __init__( + self, + input_dim: int = 7, + chroma_width: float = 1.5, + ) -> None: + """Initialize the multi-head error predictor.""" + super().__init__() + + # Independent error predictor for each component + self.hue_branch = ComponentErrorPredictor( + input_dim=input_dim, width_multiplier=1.0 + ) + self.value_branch = ComponentErrorPredictor( + input_dim=input_dim, width_multiplier=1.0 + ) + self.chroma_branch = ComponentErrorPredictor( + input_dim=input_dim, width_multiplier=chroma_width + ) + self.code_branch = ComponentErrorPredictor( + input_dim=input_dim, width_multiplier=1.0 + ) + + def forward(self, x: Tensor) -> Tensor: + """ + Forward pass through all error predictor branches. + + Parameters + ---------- + x : Tensor + Combined input of shape (batch_size, input_dim). + + Returns + ------- + Tensor + Concatenated error corrections [hue, value, chroma, code] + of shape (batch_size, 4). + """ + # Each branch processes the same combined input independently + hue_error = self.hue_branch(x) + value_error = self.value_branch(x) + chroma_error = self.chroma_branch(x) + code_error = self.code_branch(x) + + # Concatenate: [Hue_error, Value_error, Chroma_error, Code_error] + return torch.cat([hue_error, value_error, chroma_error, code_error], dim=1) + + +class MultiMLPErrorPredictorToMunsell(nn.Module): + """ + Multi-MLP error predictor for xyY to Munsell conversion. + + Uses 4 independent ComponentErrorPredictor branches, one for each + Munsell component error. + + Parameters + ---------- + chroma_width : float, optional + Width multiplier for chroma branch. Default is 1.5. + + Attributes + ---------- + hue_branch : ComponentErrorPredictor + Error predictor for hue component (1.0x width). + value_branch : ComponentErrorPredictor + Error predictor for value component (1.0x width). + chroma_branch : ComponentErrorPredictor + Error predictor for chroma component (configurable width). + code_branch : ComponentErrorPredictor + Error predictor for hue code component (1.0x width). + """ + + def __init__(self, chroma_width: float = 1.5) -> None: + """Initialize the multi-head error predictor.""" + super().__init__() + + self.hue_branch = ComponentErrorPredictor(input_dim=7, width_multiplier=1.0) + self.value_branch = ComponentErrorPredictor(input_dim=7, width_multiplier=1.0) + self.chroma_branch = ComponentErrorPredictor( + input_dim=7, width_multiplier=chroma_width + ) + self.code_branch = ComponentErrorPredictor(input_dim=7, width_multiplier=1.0) + + def forward(self, x: Tensor) -> Tensor: + """ + Forward pass through all error predictor branches. + + Parameters + ---------- + x : Tensor + Combined input [xyY_norm, base_pred_norm] of shape (batch_size, 7). + + Returns + ------- + Tensor + Concatenated error corrections [hue, value, chroma, code] + of shape (batch_size, 4). + """ + hue_error = self.hue_branch(x) + value_error = self.value_branch(x) + chroma_error = self.chroma_branch(x) + code_error = self.code_branch(x) + return torch.cat([hue_error, value_error, chroma_error, code_error], dim=1) + + +class MultiResNetErrorPredictorToMunsell(nn.Module): + """ + Multi-ResNet error predictor for xyY to Munsell conversion. + + Uses 4 independent ComponentResNet branches with true skip connections, + one for each Munsell component error. + + Parameters + ---------- + hidden_dim : int, optional + Hidden dimension for residual blocks. Default is 256. + num_blocks : int, optional + Number of residual blocks per branch. Default is 4. + chroma_hidden_dim : int, optional + Hidden dimension for chroma branch. Default is 384. + + Attributes + ---------- + hue_branch : ComponentResNet + ResNet error predictor for hue component. + value_branch : ComponentResNet + ResNet error predictor for value component. + chroma_branch : ComponentResNet + ResNet error predictor for chroma component. + code_branch : ComponentResNet + ResNet error predictor for code component. + """ + + def __init__( + self, + hidden_dim: int = 256, + num_blocks: int = 4, + chroma_hidden_dim: int = 384, + ) -> None: + """Initialize the multi-ResNet error predictor.""" + super().__init__() + + # Input: xyY (3) + base prediction (4) = 7 + self.hue_branch = ComponentResNet( + input_dim=7, hidden_dim=hidden_dim, num_blocks=num_blocks + ) + self.value_branch = ComponentResNet( + input_dim=7, hidden_dim=hidden_dim, num_blocks=num_blocks + ) + self.chroma_branch = ComponentResNet( + input_dim=7, hidden_dim=chroma_hidden_dim, num_blocks=num_blocks + ) + self.code_branch = ComponentResNet( + input_dim=7, hidden_dim=hidden_dim, num_blocks=num_blocks + ) + + def forward(self, x: Tensor) -> Tensor: + """ + Forward pass through all error predictor branches. + + Parameters + ---------- + x : Tensor + Combined input [xyY_norm, base_pred_norm] of shape (batch_size, 7). + + Returns + ------- + Tensor + Concatenated error corrections [hue, value, chroma, code] + of shape (batch_size, 4). + """ + hue_error = self.hue_branch(x) + value_error = self.value_branch(x) + chroma_error = self.chroma_branch(x) + code_error = self.code_branch(x) + return torch.cat([hue_error, value_error, chroma_error, code_error], dim=1) + + +# ============================================================================= +# Composite Models: Munsell → xyY +# ============================================================================= + + +class MultiMLPToxyY(nn.Module): + """ + Multi-MLP for Munsell to xyY conversion. + + Uses 3 independent ComponentMLP branches, one for each xyY component. + + Parameters + ---------- + width_multiplier : float, optional + Width multiplier for x and y branches. Default is 1.0. + y_width_multiplier : float, optional + Width multiplier for Y (luminance) branch. Default is 1.25. + + Attributes + ---------- + x_branch : ComponentMLP + MLP for x chromaticity component. + y_branch : ComponentMLP + MLP for y chromaticity component. + Y_branch : ComponentMLP + MLP for Y luminance component. + """ + + def __init__( + self, width_multiplier: float = 1.0, y_width_multiplier: float = 1.25 + ) -> None: + """Initialize the multi-MLP model.""" + super().__init__() + + self.x_branch = ComponentMLP(input_dim=4, width_multiplier=width_multiplier) + self.y_branch = ComponentMLP(input_dim=4, width_multiplier=width_multiplier) + self.Y_branch = ComponentMLP( + input_dim=4, width_multiplier=y_width_multiplier + ) + + def forward(self, munsell: Tensor) -> Tensor: + """ + Forward pass through all branches. + + Parameters + ---------- + munsell : Tensor + Normalized Munsell specification [hue, value, chroma, code] + of shape (batch_size, 4). + + Returns + ------- + Tensor + Predicted xyY values [x, y, Y] of shape (batch_size, 3). + """ + x = self.x_branch(munsell) + y = self.y_branch(munsell) + Y = self.Y_branch(munsell) + return torch.cat([x, y, Y], dim=1) + + +# ============================================================================= +# Error Predictors: Munsell → xyY +# ============================================================================= + + +class MultiMLPErrorPredictorToxyY(nn.Module): + """ + Multi-MLP error predictor for Munsell to xyY conversion. + + Uses 3 independent ComponentErrorPredictor branches, one for each + xyY component error. + + Parameters + ---------- + width_multiplier : float, optional + Width multiplier for all branches. Default is 1.0. + + Attributes + ---------- + x_branch : ComponentErrorPredictor + Error predictor for x chromaticity component. + y_branch : ComponentErrorPredictor + Error predictor for y chromaticity component. + Y_branch : ComponentErrorPredictor + Error predictor for Y luminance component. + """ + + def __init__(self, width_multiplier: float = 1.0) -> None: + """Initialize the multi-head error predictor.""" + super().__init__() + + self.x_branch = ComponentErrorPredictor( + input_dim=7, width_multiplier=width_multiplier + ) + self.y_branch = ComponentErrorPredictor( + input_dim=7, width_multiplier=width_multiplier + ) + self.Y_branch = ComponentErrorPredictor( + input_dim=7, width_multiplier=width_multiplier + ) + + def forward(self, combined_input: Tensor) -> Tensor: + """ + Forward pass through all error predictor branches. + + Parameters + ---------- + combined_input : Tensor + Combined input [munsell_norm, base_pred] of shape (batch_size, 7). + + Returns + ------- + Tensor + Concatenated error corrections [x, y, Y] of shape (batch_size, 3). + """ + x_error = self.x_branch(combined_input) + y_error = self.y_branch(combined_input) + Y_error = self.Y_branch(combined_input) + return torch.cat([x_error, y_error, Y_error], dim=1) diff --git a/learning_munsell/training/from_xyY/__init__.py b/learning_munsell/training/from_xyY/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..43884fa12942f51baff8e5e1f71e4ca9eea3c146 --- /dev/null +++ b/learning_munsell/training/from_xyY/__init__.py @@ -0,0 +1 @@ +"""Training scripts for xyY to Munsell conversion.""" diff --git a/learning_munsell/training/from_xyY/hyperparameter_search_error_predictor.py b/learning_munsell/training/from_xyY/hyperparameter_search_error_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..ab6b107a00184f28cf582b6fb98dd84c7d1e8e07 --- /dev/null +++ b/learning_munsell/training/from_xyY/hyperparameter_search_error_predictor.py @@ -0,0 +1,503 @@ +""" +Hyperparameter search for Multi-Error Predictor using Optuna. + +Optimizes: +- Learning rate +- Batch size +- Chroma width multiplier +- Loss function weights (MSE, MAE, log penalty, Huber) +- Huber delta +- Dropout + +Objective: Minimize validation loss +""" + +import logging +from datetime import datetime +from pathlib import Path + +import mlflow +import numpy as np +import onnxruntime as ort +import optuna +import torch +from numpy.typing import NDArray +from optuna.trial import Trial +from torch import nn, optim +from torch.utils.data import DataLoader, TensorDataset + +from learning_munsell import PROJECT_ROOT +from learning_munsell.models.networks import ( + ComponentErrorPredictor, + MultiMLPErrorPredictorToMunsell, +) +from learning_munsell.utilities.common import setup_mlflow_experiment +from learning_munsell.utilities.data import normalize_xyY, normalize_munsell + +LOGGER = logging.getLogger(__name__) + + +def precision_focused_loss( + pred: torch.Tensor, + target: torch.Tensor, + mse_weight: float = 1.0, + mae_weight: float = 0.5, + log_weight: float = 0.3, + huber_weight: float = 0.5, + huber_delta: float = 0.01, +) -> torch.Tensor: + """ + Precision-focused loss function with configurable weights. + + Combines multiple loss components to encourage accurate error prediction: + - MSE: Standard mean squared error + - MAE: Mean absolute error for robustness + - Log penalty: Penalizes small errors more heavily + - Huber loss: Robust to outliers with adjustable delta + + Parameters + ---------- + pred : torch.Tensor + Predicted values, shape (batch_size, n_components). + target : torch.Tensor + Target values, shape (batch_size, n_components). + mse_weight : float, optional + Weight for MSE component. Default is 1.0. + mae_weight : float, optional + Weight for MAE component. Default is 0.5. + log_weight : float, optional + Weight for logarithmic penalty component. Default is 0.3. + huber_weight : float, optional + Weight for Huber loss component. Default is 0.5. + huber_delta : float, optional + Delta parameter for Huber loss transition point. Default is 0.01. + + Returns + ------- + torch.Tensor + Weighted combination of loss components, scalar tensor. + """ + + mse = torch.mean((pred - target) ** 2) + mae = torch.mean(torch.abs(pred - target)) + log_penalty = torch.mean(torch.log1p(torch.abs(pred - target) * 1000.0)) + + abs_error = torch.abs(pred - target) + huber = torch.where( + abs_error <= huber_delta, + 0.5 * abs_error**2, + huber_delta * (abs_error - 0.5 * huber_delta), + ) + huber_loss = torch.mean(huber) + + return ( + mse_weight * mse + + mae_weight * mae + + log_weight * log_penalty + + huber_weight * huber_loss + ) + + +def load_base_model( + model_path: Path, params_path: Path +) -> tuple[ort.InferenceSession, dict, dict]: + """ + Load the base ONNX model and its normalization parameters. + + Parameters + ---------- + model_path : Path + Path to the base model ONNX file. + params_path : Path + Path to the normalization parameters NPZ file. + + Returns + ------- + ort.InferenceSession + ONNX Runtime inference session for the base model. + dict + Input normalization parameters (x_range, y_range, Y_range). + dict + Output normalization parameters (hue_range, value_range, chroma_range, code_range). + """ + session = ort.InferenceSession(str(model_path)) + params = np.load(params_path, allow_pickle=True) + return session, params["input_params"].item(), params["output_params"].item() + + +def train_epoch( + model: nn.Module, + dataloader: DataLoader, + optimizer: optim.Optimizer, + device: torch.device, + loss_params: dict[str, float], +) -> float: + """ + Train the model for one epoch. + + Parameters + ---------- + model : nn.Module + Error predictor model to train. + dataloader : DataLoader + DataLoader providing training batches. + optimizer : optim.Optimizer + Optimizer for updating model parameters. + device : torch.device + Device to run training on (CPU, CUDA, or MPS). + loss_params : dict of str to float + Parameters for precision_focused_loss function. + + Returns + ------- + float + Average training loss over the epoch. + """ + model.train() + total_loss = 0.0 + + for X_batch, y_batch in dataloader: + X_batch = X_batch.to(device) + y_batch = y_batch.to(device) + outputs = model(X_batch) + loss = precision_focused_loss(outputs, y_batch, **loss_params) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + total_loss += loss.item() + + return total_loss / len(dataloader) + + +def validate( + model: nn.Module, + dataloader: DataLoader, + device: torch.device, + loss_params: dict[str, float], +) -> float: + """ + Validate the model on the validation set. + + Parameters + ---------- + model : nn.Module + Error predictor model to validate. + dataloader : DataLoader + DataLoader providing validation batches. + device : torch.device + Device to run validation on (CPU, CUDA, or MPS). + loss_params : dict of str to float + Parameters for precision_focused_loss function. + + Returns + ------- + float + Average validation loss. + """ + model.eval() + total_loss = 0.0 + + with torch.no_grad(): + for X_batch, y_batch in dataloader: + X_batch = X_batch.to(device) + y_batch = y_batch.to(device) + outputs = model(X_batch) + loss = precision_focused_loss(outputs, y_batch, **loss_params) + + total_loss += loss.item() + + return total_loss / len(dataloader) + + +def objective(trial: Trial) -> float: + """ + Optuna objective function to minimize validation loss. + + This function defines the hyperparameter search space and training + procedure for each trial. It optimizes: + - Learning rate (5e-4 to 1e-3, log scale) + - Batch size (512 or 1024) + - Chroma branch width multiplier (1.0 to 1.5) + - Dropout rate (0.1 to 0.2) + - Loss function weights (MSE, Huber) + - Huber delta parameter (0.01 to 0.05) + + Parameters + ---------- + trial : Trial + Optuna trial object for suggesting hyperparameters. + + Returns + ------- + float + Best validation loss achieved during training. + + Raises + ------ + FileNotFoundError + If base model or training data files are not found. + optuna.TrialPruned + If trial is pruned based on intermediate results. + """ + + # Hyperparameters to optimize - constrained based on Trial 0 insights + lr = trial.suggest_float("lr", 5e-4, 1e-3, log=True) # Higher LR worked well + batch_size = trial.suggest_categorical( + "batch_size", [512, 1024] + ) # Smaller batches better + chroma_width = trial.suggest_float( + "chroma_width", 1.0, 1.5, step=0.25 + ) # Smaller worked + dropout = trial.suggest_float("dropout", 0.1, 0.2, step=0.05) + + # Simplified loss - just MSE + optional small Huber (no log penalty!) + mse_weight = trial.suggest_float("mse_weight", 1.0, 2.0, step=0.25) + huber_weight = trial.suggest_float("huber_weight", 0.0, 0.5, step=0.25) + huber_delta = trial.suggest_float("huber_delta", 0.01, 0.05, step=0.01) + + loss_params = { + "mse_weight": mse_weight, + "mae_weight": 0.0, # Fixed at 0 + "log_weight": 0.0, # Fixed at 0 (was causing scale issues) + "huber_weight": huber_weight, + "huber_delta": huber_delta, + } + + LOGGER.info("") + LOGGER.info("=" * 80) + LOGGER.info("Trial %d", trial.number) + LOGGER.info("=" * 80) + LOGGER.info(" lr: %.6f", lr) + LOGGER.info(" batch_size: %d", batch_size) + LOGGER.info(" chroma_width: %.2f", chroma_width) + LOGGER.info(" dropout: %.2f", dropout) + LOGGER.info(" mse_weight: %.2f", mse_weight) + LOGGER.info(" huber_weight: %.2f", huber_weight) + LOGGER.info(" huber_delta: %.3f", huber_delta) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # Load base model and data + model_dir = PROJECT_ROOT / "models" / "from_xyY" + data_dir = PROJECT_ROOT / "data" + + base_model_path = model_dir / "multi_mlp.onnx" + params_path = model_dir / "multi_mlp_normalization_params.npz" + cache_file = data_dir / "training_data.npz" + + if not base_model_path.exists(): + msg = f"Base model not found: {base_model_path}" + raise FileNotFoundError(msg) + + base_session, input_params, output_params = load_base_model( + base_model_path, params_path + ) + + # Load data + data = np.load(cache_file) + X_train = data["X_train"] + y_train = data["y_train"] + X_val = data["X_val"] + y_val = data["y_val"] + + # Normalize and generate base predictions + X_train_norm = normalize_xyY(X_train, input_params) + y_train_norm = normalize_munsell(y_train, output_params) + base_pred_train_norm = base_session.run(None, {"xyY": X_train_norm})[0] + + X_val_norm = normalize_xyY(X_val, input_params) + y_val_norm = normalize_munsell(y_val, output_params) + base_pred_val_norm = base_session.run(None, {"xyY": X_val_norm})[0] + + # Compute errors + error_train = y_train_norm - base_pred_train_norm + error_val = y_val_norm - base_pred_val_norm + + # Combined input + X_train_combined = np.concatenate([X_train_norm, base_pred_train_norm], axis=1) + X_val_combined = np.concatenate([X_val_norm, base_pred_val_norm], axis=1) + + # PyTorch tensors + X_train_t = torch.FloatTensor(X_train_combined) + error_train_t = torch.FloatTensor(error_train) + X_val_t = torch.FloatTensor(X_val_combined) + error_val_t = torch.FloatTensor(error_val) + + # Data loaders + train_dataset = TensorDataset(X_train_t, error_train_t) + val_dataset = TensorDataset(X_val_t, error_val_t) + train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) + val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False) + + # Initialize model + model = MultiMLPErrorPredictorToMunsell(chroma_width=chroma_width, dropout=dropout).to( + device + ) + + total_params = sum(p.numel() for p in model.parameters()) + LOGGER.info(" Total parameters: %s", f"{total_params:,}") + + optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5) + + # MLflow setup + run_name = setup_mlflow_experiment( + "from_xyY", f"hparam_error_predictor_trial_{trial.number}" + ) + + # Training loop + num_epochs = 100 + patience = 15 + best_val_loss = float("inf") + patience_counter = 0 + + with mlflow.start_run(run_name=run_name): + mlflow.log_params( + { + "trial": trial.number, + "lr": lr, + "batch_size": batch_size, + "chroma_width": chroma_width, + "dropout": dropout, + "mse_weight": mse_weight, + "huber_weight": huber_weight, + "huber_delta": huber_delta, + "total_params": total_params, + } + ) + + for epoch in range(num_epochs): + train_loss = train_epoch( + model, train_loader, optimizer, device, loss_params + ) + val_loss = validate(model, val_loader, device, loss_params) + + mlflow.log_metrics( + { + "train_loss": train_loss, + "val_loss": val_loss, + }, + step=epoch, + ) + + if (epoch + 1) % 10 == 0: + LOGGER.info( + " Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f", + epoch + 1, + num_epochs, + train_loss, + val_loss, + ) + + if val_loss < best_val_loss: + best_val_loss = val_loss + patience_counter = 0 + else: + patience_counter += 1 + if patience_counter >= patience: + LOGGER.info(" Early stopping at epoch %d", epoch + 1) + break + + trial.report(val_loss, epoch) + + if trial.should_prune(): + LOGGER.info(" Trial pruned at epoch %d", epoch + 1) + mlflow.log_metrics({"pruned": 1, "pruned_epoch": epoch}) + raise optuna.TrialPruned + + # Log final results + mlflow.log_metrics( + { + "best_val_loss": best_val_loss, + "final_train_loss": train_loss, + "final_epoch": epoch + 1, + } + ) + + LOGGER.info(" Final validation loss: %.6f", best_val_loss) + + return best_val_loss + + +def main() -> None: + """ + Run hyperparameter search for Multi-MLP Error Predictor. + + Performs systematic hyperparameter optimization using Optuna with: + - MedianPruner for early stopping of unpromising trials + - 15 total trials + - MLflow logging for each trial + - Result visualization and saving + + The search aims to find optimal hyperparameters for predicting errors + in a base Munsell prediction model, which can then be used to improve + predictions by correcting systematic biases. + """ + + LOGGER.info("=" * 80) + LOGGER.info("Multi-Error Predictor Hyperparameter Search with Optuna") + LOGGER.info("=" * 80) + + study = optuna.create_study( + direction="minimize", + study_name="multi_mlp_error_predictor_hparam_search", + pruner=optuna.pruners.MedianPruner(n_startup_trials=3, n_warmup_steps=10), + ) + + n_trials = 15 + + LOGGER.info("") + LOGGER.info("Starting hyperparameter search with %d trials...", n_trials) + LOGGER.info("") + + study.optimize(objective, n_trials=n_trials, timeout=None) + + # Print results + LOGGER.info("") + LOGGER.info("=" * 80) + LOGGER.info("Hyperparameter Search Results") + LOGGER.info("=" * 80) + LOGGER.info("") + LOGGER.info("Best trial:") + LOGGER.info(" Value (val_loss): %.6f", study.best_value) + LOGGER.info("") + LOGGER.info("Best hyperparameters:") + for key, value in study.best_params.items(): + LOGGER.info(" %s: %s", key, value) + + # Save results + results_dir = PROJECT_ROOT / "results" / "from_xyY" + results_dir.mkdir(exist_ok=True) + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + results_file = results_dir / f"error_predictor_hparam_search_{timestamp}.txt" + + with open(results_file, "w") as f: + f.write("=" * 80 + "\n") + f.write("Multi-Error Predictor Hyperparameter Search Results\n") + f.write("=" * 80 + "\n\n") + f.write(f"Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n") + f.write(f"Number of trials: {len(study.trials)}\n") + f.write(f"Best validation loss: {study.best_value:.6f}\n\n") + f.write("Best hyperparameters:\n") + for key, value in study.best_params.items(): + f.write(f" {key}: {value}\n") + f.write("\n\nAll trials:\n") + f.write("-" * 80 + "\n") + + for trial in study.trials: + f.write(f"\nTrial {trial.number}:\n") + f.write(f" Value: {trial.value:.6f if trial.value else 'Pruned'}\n") + f.write(" Params:\n") + for key, value in trial.params.items(): + f.write(f" {key}: {value}\n") + + LOGGER.info("") + LOGGER.info("Results saved to: %s", results_file) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO, format="%(message)s", force=True) + + main() diff --git a/learning_munsell/training/from_xyY/hyperparameter_search_multi_head.py b/learning_munsell/training/from_xyY/hyperparameter_search_multi_head.py new file mode 100644 index 0000000000000000000000000000000000000000..a9bde74ba8881a0376e7a5017f6bd45b274eb9ee --- /dev/null +++ b/learning_munsell/training/from_xyY/hyperparameter_search_multi_head.py @@ -0,0 +1,541 @@ +""" +Hyperparameter search for Multi-Head model (xyY to Munsell) using Optuna. + +Optimizes: +- Learning rate +- Batch size +- Encoder width multiplier (shared encoder capacity) +- Head width multiplier (component-specific head capacity) +- Chroma head width (specialized for chroma prediction) +- Dropout +- Weight decay + +Objective: Minimize validation loss +""" + +from __future__ import annotations + +import logging +from datetime import datetime + +import matplotlib.pyplot as plt +import mlflow +import numpy as np +import optuna +import torch +from optuna.trial import Trial +from torch import nn, optim +from torch.utils.data import DataLoader, TensorDataset + +from learning_munsell import PROJECT_ROOT +from learning_munsell.utilities.common import setup_mlflow_experiment +from learning_munsell.utilities.data import MUNSELL_NORMALIZATION_PARAMS, normalize_munsell +from learning_munsell.utilities.losses import weighted_mse_loss +from learning_munsell.utilities.training import train_epoch, validate + +LOGGER = logging.getLogger(__name__) + + +class MultiHeadParametric(nn.Module): + """ + Parametric Multi-Head model for hyperparameter search (xyY to Munsell). + + This model uses a shared encoder to extract general color space features + from xyY inputs, followed by component-specific heads for predicting + each Munsell component independently. + + Architecture: + - Shared encoder: 3 → h1 → h2 → h3 (scaled by encoder_width) + - hue, value, code heads: h3 → h2' → h1' → 1 (scaled by head_width) + - chroma head: h3 → h2'' → h1'' → 1 (scaled by chroma_head_width) + + Parameters + ---------- + encoder_width : float, optional + Width multiplier for shared encoder layers. Default is 1.0. + Base dimensions: h1=128, h2=256, h3=512. + head_width : float, optional + Width multiplier for hue, value, and code heads. Default is 1.0. + Base dimensions: h1=128, h2=256. + chroma_head_width : float, optional + Width multiplier for chroma head (typically wider). Default is 1.0. + Base dimensions: h1=128, h2=256, h3=384. + dropout : float, optional + Dropout rate applied after hidden layers. Default is 0.0. + """ + + def __init__( + self, + encoder_width: float = 1.0, + head_width: float = 1.0, + chroma_head_width: float = 1.0, + dropout: float = 0.0, + ) -> None: + super().__init__() + + # Encoder dimensions (shared) + e_h1 = int(128 * encoder_width) + e_h2 = int(256 * encoder_width) + e_h3 = int(512 * encoder_width) + + # Head dimensions (component-specific) + h_h1 = int(128 * head_width) + h_h2 = int(256 * head_width) + + # Chroma head dimensions (specialized) + c_h1 = int(128 * chroma_head_width) + c_h2 = int(256 * chroma_head_width) + c_h3 = int(384 * chroma_head_width) + + # Shared encoder - learns general color space features + encoder_layers = [ + nn.Linear(3, e_h1), + nn.ReLU(), + nn.BatchNorm1d(e_h1), + ] + + if dropout > 0: + encoder_layers.append(nn.Dropout(dropout)) + + encoder_layers.extend( + [ + nn.Linear(e_h1, e_h2), + nn.ReLU(), + nn.BatchNorm1d(e_h2), + ] + ) + + if dropout > 0: + encoder_layers.append(nn.Dropout(dropout)) + + encoder_layers.extend( + [ + nn.Linear(e_h2, e_h3), + nn.ReLU(), + nn.BatchNorm1d(e_h3), + ] + ) + + if dropout > 0: + encoder_layers.append(nn.Dropout(dropout)) + + self.encoder = nn.Sequential(*encoder_layers) + + # Component-specific heads (hue, value, code) + def create_head() -> nn.Sequential: + head_layers = [ + nn.Linear(e_h3, h_h2), + nn.ReLU(), + nn.BatchNorm1d(h_h2), + ] + + if dropout > 0: + head_layers.append(nn.Dropout(dropout)) + + head_layers.extend( + [ + nn.Linear(h_h2, h_h1), + nn.ReLU(), + nn.BatchNorm1d(h_h1), + ] + ) + + if dropout > 0: + head_layers.append(nn.Dropout(dropout)) + + head_layers.append(nn.Linear(h_h1, 1)) + + return nn.Sequential(*head_layers) + + self.hue_head = create_head() + self.value_head = create_head() + self.code_head = create_head() + + # Chroma head - wider for harder task + chroma_layers = [ + nn.Linear(e_h3, c_h3), + nn.ReLU(), + nn.BatchNorm1d(c_h3), + ] + + if dropout > 0: + chroma_layers.append(nn.Dropout(dropout)) + + chroma_layers.extend( + [ + nn.Linear(c_h3, c_h2), + nn.ReLU(), + nn.BatchNorm1d(c_h2), + ] + ) + + if dropout > 0: + chroma_layers.append(nn.Dropout(dropout)) + + chroma_layers.extend( + [ + nn.Linear(c_h2, c_h1), + nn.ReLU(), + nn.BatchNorm1d(c_h1), + ] + ) + + if dropout > 0: + chroma_layers.append(nn.Dropout(dropout)) + + chroma_layers.append(nn.Linear(c_h1, 1)) + + self.chroma_head = nn.Sequential(*chroma_layers) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass through shared encoder and component-specific heads. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape (batch_size, 3) containing normalized + xyY values. + + Returns + ------- + torch.Tensor + Predicted Munsell components, shape (batch_size, 4). + Output order: [hue, value, chroma, code]. + """ + # Shared feature extraction + features = self.encoder(x) + + # Component-specific predictions + hue = self.hue_head(features) + value = self.value_head(features) + chroma = self.chroma_head(features) + code = self.code_head(features) + + # Concatenate: [hue, value, chroma, code] + return torch.cat([hue, value, chroma, code], dim=1) + + +def objective(trial: Trial) -> float: + """ + Optuna objective function to minimize validation loss. + + This function defines the hyperparameter search space and training + procedure for each trial. It optimizes: + - Learning rate (1e-4 to 1e-3, log scale) + - Batch size (256, 512, or 1024) + - Encoder width multiplier (0.75 to 1.5) + - Head width multiplier (0.75 to 1.5) + - Chroma head width multiplier (1.0 to 1.75) + - Dropout rate (0.0 to 0.2) + - Weight decay (1e-5 to 1e-3, log scale) + + Parameters + ---------- + trial : Trial + Optuna trial object for suggesting hyperparameters. + + Returns + ------- + float + Best validation loss achieved during training. + + Raises + ------ + optuna.TrialPruned + If trial is pruned based on intermediate results. + """ + + # Suggest hyperparameters + lr = trial.suggest_float("lr", 1e-4, 1e-3, log=True) + batch_size = trial.suggest_categorical("batch_size", [256, 512, 1024]) + encoder_width = trial.suggest_float("encoder_width", 0.75, 1.5, step=0.25) + head_width = trial.suggest_float("head_width", 0.75, 1.5, step=0.25) + chroma_head_width = trial.suggest_float("chroma_head_width", 1.0, 1.75, step=0.25) + dropout = trial.suggest_float("dropout", 0.0, 0.2, step=0.05) + weight_decay = trial.suggest_float("weight_decay", 1e-5, 1e-3, log=True) + + LOGGER.info("") + LOGGER.info("=" * 80) + LOGGER.info("Trial %d", trial.number) + LOGGER.info("=" * 80) + LOGGER.info(" lr: %.6f", lr) + LOGGER.info(" batch_size: %d", batch_size) + LOGGER.info(" encoder_width: %.2f", encoder_width) + LOGGER.info(" head_width: %.2f", head_width) + LOGGER.info(" chroma_head_width: %.2f", chroma_head_width) + LOGGER.info(" dropout: %.2f", dropout) + LOGGER.info(" weight_decay: %.6f", weight_decay) + + # Set device + device = torch.device("mps" if torch.backends.mps.is_available() else "cpu") + LOGGER.info(" device: %s", device) + + # Load data + data_dir = PROJECT_ROOT / "data" + cache_file = data_dir / "training_data.npz" + data = np.load(cache_file) + + X_train = data["X_train"] + y_train = data["y_train"] + X_val = data["X_val"] + y_val = data["y_val"] + + # Normalize outputs (xyY inputs are already in [0, 1] range) + output_params = MUNSELL_NORMALIZATION_PARAMS + y_train_norm = normalize_munsell(y_train, output_params) + y_val_norm = normalize_munsell(y_val, output_params) + + # Convert to tensors + X_train_t = torch.from_numpy(X_train).float() + y_train_t = torch.from_numpy(y_train_norm).float() + X_val_t = torch.from_numpy(X_val).float() + y_val_t = torch.from_numpy(y_val_norm).float() + + train_loader = DataLoader( + TensorDataset(X_train_t, y_train_t), batch_size=batch_size, shuffle=True + ) + val_loader = DataLoader( + TensorDataset(X_val_t, y_val_t), batch_size=batch_size, shuffle=False + ) + + LOGGER.info( + " Training samples: %d, Validation samples: %d", len(X_train_t), len(X_val_t) + ) + + # Initialize model + model = MultiHeadParametric( + encoder_width=encoder_width, + head_width=head_width, + chroma_head_width=chroma_head_width, + dropout=dropout, + ).to(device) + + # Count parameters + total_params = sum(p.numel() for p in model.parameters()) + LOGGER.info(" Total parameters: %s", f"{total_params:,}") + + # Training setup + optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay) + scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100) + + # MLflow setup + run_name = setup_mlflow_experiment( + "from_xyY", f"hparam_multi_head_trial_{trial.number}" + ) + + # Training loop with early stopping + num_epochs = 100 # Reduced for hyperparameter search + patience = 15 + best_val_loss = float("inf") + patience_counter = 0 + + with mlflow.start_run(run_name=run_name): + mlflow.log_params( + { + "trial": trial.number, + "lr": lr, + "batch_size": batch_size, + "encoder_width": encoder_width, + "head_width": head_width, + "chroma_head_width": chroma_head_width, + "dropout": dropout, + "weight_decay": weight_decay, + "total_params": total_params, + } + ) + + for epoch in range(num_epochs): + train_loss = train_epoch( + model, train_loader, optimizer, weighted_mse_loss, device + ) + val_loss = validate(model, val_loader, weighted_mse_loss, device) + scheduler.step() + + # Per-component MAE + with torch.no_grad(): + pred_val = model(X_val_t.to(device)) + mae = torch.mean(torch.abs(pred_val - y_val_t.to(device)), dim=0).cpu() + + # Log to MLflow + mlflow.log_metrics( + { + "train_loss": train_loss, + "val_loss": val_loss, + "mae_hue": mae[0].item(), + "mae_value": mae[1].item(), + "mae_chroma": mae[2].item(), + "mae_code": mae[3].item(), + "learning_rate": optimizer.param_groups[0]["lr"], + }, + step=epoch, + ) + + if (epoch + 1) % 10 == 0: + LOGGER.info( + " Epoch %03d/%d - Train: %.6f, Val: %.6f - " + "MAE: hue=%.6f, value=%.6f, chroma=%.6f, code=%.6f", + epoch + 1, + num_epochs, + train_loss, + val_loss, + mae[0], + mae[1], + mae[2], + mae[3], + ) + + # Early stopping + if val_loss < best_val_loss: + best_val_loss = val_loss + patience_counter = 0 + else: + patience_counter += 1 + if patience_counter >= patience: + LOGGER.info(" Early stopping at epoch %d", epoch + 1) + break + + # Report intermediate value for pruning + trial.report(val_loss, epoch) + + # Handle pruning + if trial.should_prune(): + LOGGER.info(" Trial pruned at epoch %d", epoch + 1) + mlflow.log_metrics({"pruned": 1, "pruned_epoch": epoch}) + raise optuna.TrialPruned + + # Log final results + mlflow.log_metrics( + { + "best_val_loss": best_val_loss, + "final_train_loss": train_loss, + "final_mae_hue": mae[0].item(), + "final_mae_value": mae[1].item(), + "final_mae_chroma": mae[2].item(), + "final_mae_code": mae[3].item(), + "final_epoch": epoch + 1, + } + ) + + LOGGER.info(" Final validation loss: %.6f", best_val_loss) + + return best_val_loss + + +def main() -> None: + """ + Run hyperparameter search for Multi-Head model (xyY to Munsell). + + Performs systematic hyperparameter optimization using Optuna with: + - MedianPruner for early stopping of unpromising trials + - 20 total trials + - MLflow logging for each trial + - Result visualization using matplotlib (optimization history, + parameter importances, parallel coordinate plot) + + The search aims to find optimal hyperparameters for converting xyY + color coordinates to Munsell color specifications using a multi-head + architecture with shared encoder and component-specific heads. + """ + + LOGGER.info("=" * 80) + LOGGER.info("Multi-Head (from_xyY) Hyperparameter Search with Optuna") + LOGGER.info("=" * 80) + + # Create study + study = optuna.create_study( + direction="minimize", + study_name="multi_head_from_xyY_hparam_search", + pruner=optuna.pruners.MedianPruner(n_startup_trials=3, n_warmup_steps=10), + ) + + # Run optimization + n_trials = 20 # Number of trials to run + + LOGGER.info("") + LOGGER.info("Starting hyperparameter search with %d trials...", n_trials) + LOGGER.info("") + + study.optimize(objective, n_trials=n_trials, timeout=None) + + # Print results + LOGGER.info("") + LOGGER.info("=" * 80) + LOGGER.info("Hyperparameter Search Results") + LOGGER.info("=" * 80) + LOGGER.info("") + LOGGER.info("Best trial:") + LOGGER.info(" Value (val_loss): %.6f", study.best_value) + LOGGER.info("") + LOGGER.info("Best hyperparameters:") + for key, value in study.best_params.items(): + LOGGER.info(" %s: %s", key, value) + + # Save results + results_dir = PROJECT_ROOT / "results" / "from_xyY" + results_dir.mkdir(exist_ok=True, parents=True) + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + results_file = results_dir / f"hparam_search_multi_head_{timestamp}.txt" + + with open(results_file, "w") as f: + f.write("=" * 80 + "\n") + f.write("Multi-Head (from_xyY) Hyperparameter Search Results\n") + f.write("=" * 80 + "\n\n") + f.write(f"Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n") + f.write(f"Number of trials: {len(study.trials)}\n") + f.write(f"Best validation loss: {study.best_value:.6f}\n\n") + f.write("Best hyperparameters:\n") + for key, value in study.best_params.items(): + f.write(f" {key}: {value}\n") + f.write("\n\nAll trials:\n") + f.write("-" * 80 + "\n") + + for t in study.trials: + f.write(f"\nTrial {t.number}:\n") + if t.value is not None: + f.write(f" Value: {t.value:.6f}\n") + else: + f.write(" Value: Pruned\n") + f.write(" Params:\n") + for key, value in t.params.items(): + f.write(f" {key}: {value}\n") + + LOGGER.info("") + LOGGER.info("Results saved to: %s", results_file) + + # Generate visualizations using matplotlib + from optuna.visualization.matplotlib import ( + plot_optimization_history, + plot_param_importances, + plot_parallel_coordinate, + ) + + # Optimization history + ax = plot_optimization_history(study) + ax.figure.savefig( + results_dir / f"optimization_history_multi_head_{timestamp}.png", dpi=150 + ) + plt.close(ax.figure) + + # Parameter importances + ax = plot_param_importances(study) + ax.figure.savefig( + results_dir / f"param_importances_multi_head_{timestamp}.png", dpi=150 + ) + plt.close(ax.figure) + + # Parallel coordinate plot + ax = plot_parallel_coordinate(study) + ax.figure.savefig( + results_dir / f"parallel_coordinate_multi_head_{timestamp}.png", dpi=150 + ) + plt.close(ax.figure) + + LOGGER.info("Visualizations saved to: %s", results_dir) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO, format="%(message)s", force=True) + + main() diff --git a/learning_munsell/training/from_xyY/hyperparameter_search_multi_head_error_predictor.py b/learning_munsell/training/from_xyY/hyperparameter_search_multi_head_error_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..4555069ea323fe79f52da847786449704b2c2804 --- /dev/null +++ b/learning_munsell/training/from_xyY/hyperparameter_search_multi_head_error_predictor.py @@ -0,0 +1,552 @@ +""" +Hyperparameter search for Multi-Head Error Predictor using Optuna. + +Optimizes: +- Learning rate +- Batch size +- Width multipliers for each component branch (hue, value, chroma, code) +- Loss function component weights + +Objective: Minimize validation loss (combined base + error predictor) +""" + +from __future__ import annotations + +import logging +from datetime import datetime +from pathlib import Path + +import matplotlib.pyplot as plt +import mlflow +import numpy as np +import onnxruntime as ort +import optuna +import torch +from numpy.typing import NDArray +from optuna.trial import Trial +from torch import nn, optim +from torch.utils.data import DataLoader, TensorDataset + +from learning_munsell import PROJECT_ROOT +from learning_munsell.models.networks import ComponentErrorPredictor +from learning_munsell.utilities.common import setup_mlflow_experiment +from learning_munsell.utilities.data import normalize_xyY, normalize_munsell +from learning_munsell.utilities.training import train_epoch, validate + +LOGGER = logging.getLogger(__name__) + + +class MultiHeadErrorPredictorParametric(nn.Module): + """ + Parametric Multi-Head error predictor with 4 independent branches. + + This model consists of four independent ComponentErrorPredictor + networks, one for each Munsell component (hue, value, chroma, code). + Each branch can have different widths for hyperparameter optimization. + + Parameters + ---------- + hue_width : float, optional + Width multiplier for the hue branch. Default is 1.0. + value_width : float, optional + Width multiplier for the value branch. Default is 1.0. + chroma_width : float, optional + Width multiplier for the chroma branch. Default is 1.5. + code_width : float, optional + Width multiplier for the code branch. Default is 1.0. + """ + + def __init__( + self, + hue_width: float = 1.0, + value_width: float = 1.0, + chroma_width: float = 1.5, + code_width: float = 1.0, + ) -> None: + super().__init__() + + # Independent error predictor for each component + self.hue_branch = ComponentErrorPredictor(width_multiplier=hue_width) + self.value_branch = ComponentErrorPredictor( + width_multiplier=value_width + ) + self.chroma_branch = ComponentErrorPredictor( + width_multiplier=chroma_width + ) + self.code_branch = ComponentErrorPredictor( + width_multiplier=code_width + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass through all four error predictor branches. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape (batch_size, 7) containing normalized + xyY values and base model predictions. + + Returns + ------- + torch.Tensor + Predicted errors for all components, shape (batch_size, 4). + Output order: [hue_error, value_error, chroma_error, code_error]. + """ + # Each branch processes the same combined input independently + hue_error = self.hue_branch(x) + value_error = self.value_branch(x) + chroma_error = self.chroma_branch(x) + code_error = self.code_branch(x) + + # Concatenate: [Hue_error, Value_error, Chroma_error, Code_error] + return torch.cat([hue_error, value_error, chroma_error, code_error], dim=1) + + +def load_base_model( + model_path: Path, params_path: Path +) -> tuple[ort.InferenceSession, dict, dict]: + """ + Load the base Multi-Head ONNX model and its normalization parameters. + + Parameters + ---------- + model_path : Path + Path to the base Multi-Head model ONNX file. + params_path : Path + Path to the normalization parameters NPZ file. + + Returns + ------- + ort.InferenceSession + ONNX Runtime inference session for the base model. + dict + Input normalization parameters (x_range, y_range, Y_range). + dict + Output normalization parameters (hue_range, value_range, chroma_range, code_range). + """ + session = ort.InferenceSession(str(model_path)) + params = np.load(params_path, allow_pickle=True) + return session, params["input_params"].item(), params["output_params"].item() + + +def create_weighted_loss( + mse_weight: float, + mae_weight: float, + log_weight: float, + huber_weight: float, + huber_delta: float, +): + """ + Create a weighted loss function combining multiple loss components. + + Parameters + ---------- + mse_weight : float + Weight for MSE component. + mae_weight : float + Weight for MAE component. + log_weight : float + Weight for logarithmic penalty component. + huber_weight : float + Weight for Huber loss component. + huber_delta : float + Delta parameter for Huber loss transition point. + + Returns + ------- + callable + Loss function that accepts (pred, target) and returns a scalar loss. + """ + + def weighted_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + Compute weighted combination of loss components. + + Parameters + ---------- + pred : torch.Tensor + Predicted values, shape (batch_size, n_components). + target : torch.Tensor + Target values, shape (batch_size, n_components). + + Returns + ------- + torch.Tensor + Weighted combination of loss components, scalar tensor. + """ + # Standard MSE + mse = torch.mean((pred - target) ** 2) + + # Mean absolute error + mae = torch.mean(torch.abs(pred - target)) + + # Logarithmic penalty + log_penalty = torch.mean(torch.log1p(torch.abs(pred - target) * 1000.0)) + + # Huber loss + abs_error = torch.abs(pred - target) + huber = torch.where( + abs_error <= huber_delta, + 0.5 * abs_error**2, + huber_delta * (abs_error - 0.5 * huber_delta), + ) + huber_loss = torch.mean(huber) + + # Combine with weights + return ( + mse_weight * mse + + mae_weight * mae + + log_weight * log_penalty + + huber_weight * huber_loss + ) + + return weighted_loss + + +def objective(trial: Trial) -> float: + """ + Optuna objective function to minimize validation loss. + + This function defines the hyperparameter search space and training + procedure for each trial. It optimizes: + - Learning rate (1e-4 to 1e-3, log scale) + - Batch size (512, 1024, or 2048) + - Width multipliers for each component branch + - Loss function weights (MSE, MAE, log penalty, Huber) + - Huber delta parameter (0.005 to 0.02) + + Parameters + ---------- + trial : Trial + Optuna trial object for suggesting hyperparameters. + + Returns + ------- + float + Best validation loss achieved during training. + + Raises + ------ + optuna.TrialPruned + If trial is pruned based on intermediate results. + """ + + # Suggest hyperparameters + lr = trial.suggest_float("lr", 1e-4, 1e-3, log=True) + batch_size = trial.suggest_categorical("batch_size", [512, 1024, 2048]) + hue_width = trial.suggest_float("hue_width", 0.75, 1.5, step=0.25) + value_width = trial.suggest_float("value_width", 0.75, 1.5, step=0.25) + chroma_width = trial.suggest_float("chroma_width", 1.0, 2.0, step=0.25) + code_width = trial.suggest_float("code_width", 0.75, 1.5, step=0.25) + + # Loss function weights + mse_weight = trial.suggest_float("mse_weight", 0.5, 2.0, step=0.5) + mae_weight = trial.suggest_float("mae_weight", 0.0, 1.0, step=0.25) + log_weight = trial.suggest_float("log_weight", 0.0, 0.5, step=0.1) + huber_weight = trial.suggest_float("huber_weight", 0.0, 1.0, step=0.25) + huber_delta = trial.suggest_float("huber_delta", 0.005, 0.02, step=0.005) + + LOGGER.info("") + LOGGER.info("=" * 80) + LOGGER.info("Trial %d", trial.number) + LOGGER.info("=" * 80) + LOGGER.info(" lr: %.6f", lr) + LOGGER.info(" batch_size: %d", batch_size) + LOGGER.info(" hue_width: %.2f", hue_width) + LOGGER.info(" value_width: %.2f", value_width) + LOGGER.info(" chroma_width: %.2f", chroma_width) + LOGGER.info(" code_width: %.2f", code_width) + LOGGER.info(" mse_weight: %.2f", mse_weight) + LOGGER.info(" mae_weight: %.2f", mae_weight) + LOGGER.info(" log_weight: %.2f", log_weight) + LOGGER.info(" huber_weight: %.2f", huber_weight) + LOGGER.info(" huber_delta: %.3f", huber_delta) + + # Set device + device = torch.device("mps" if torch.backends.mps.is_available() else "cpu") + LOGGER.info(" device: %s", device) + + # Paths + model_directory = PROJECT_ROOT / "models" / "from_xyY" + data_dir = PROJECT_ROOT / "data" + + base_model_path = model_directory / "multi_head.onnx" + params_path = model_directory / "multi_head_normalization_params.npz" + cache_file = data_dir / "training_data.npz" + + # Load base model + base_session, input_params, output_params = load_base_model( + base_model_path, params_path + ) + + # Load training data + data = np.load(cache_file) + X_train = data["X_train"] + y_train = data["y_train"] + X_val = data["X_val"] + y_val = data["y_val"] + + # Normalize + X_train_norm = normalize_xyY(X_train, input_params) + y_train_norm = normalize_munsell(y_train, output_params) + X_val_norm = normalize_xyY(X_val, input_params) + y_val_norm = normalize_munsell(y_val, output_params) + + # Generate base model predictions + base_pred_train_norm = base_session.run(None, {"xyY": X_train_norm})[0] + base_pred_val_norm = base_session.run(None, {"xyY": X_val_norm})[0] + + # Compute errors + error_train = y_train_norm - base_pred_train_norm + error_val = y_val_norm - base_pred_val_norm + + # Create combined input: [xyY_norm, base_prediction_norm] + X_train_combined = np.concatenate([X_train_norm, base_pred_train_norm], axis=1) + X_val_combined = np.concatenate([X_val_norm, base_pred_val_norm], axis=1) + + # Convert to PyTorch tensors + X_train_t = torch.FloatTensor(X_train_combined) + error_train_t = torch.FloatTensor(error_train) + X_val_t = torch.FloatTensor(X_val_combined) + error_val_t = torch.FloatTensor(error_val) + + # Create data loaders + train_loader = DataLoader( + TensorDataset(X_train_t, error_train_t), batch_size=batch_size, shuffle=True + ) + val_loader = DataLoader( + TensorDataset(X_val_t, error_val_t), batch_size=batch_size, shuffle=False + ) + + LOGGER.info( + " Training samples: %d, Validation samples: %d", len(X_train_t), len(X_val_t) + ) + + # Initialize error predictor model + model = MultiHeadErrorPredictorParametric( + hue_width=hue_width, + value_width=value_width, + chroma_width=chroma_width, + code_width=code_width, + ).to(device) + + # Count parameters + total_params = sum(p.numel() for p in model.parameters()) + LOGGER.info(" Total parameters: %s", f"{total_params:,}") + + # Training setup + optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5) + scheduler = optim.lr_scheduler.ReduceLROnPlateau( + optimizer, mode="min", factor=0.5, patience=5 + ) + + # Create loss function + criterion = create_weighted_loss( + mse_weight, mae_weight, log_weight, huber_weight, huber_delta + ) + + # MLflow setup + run_name = setup_mlflow_experiment( + "from_xyY", f"hparam_multi_head_error_trial_{trial.number}" + ) + + # Training loop with early stopping + num_epochs = 50 # Reduced for hyperparameter search + patience = 10 + best_val_loss = float("inf") + patience_counter = 0 + + with mlflow.start_run(run_name=run_name): + mlflow.log_params( + { + "lr": lr, + "batch_size": batch_size, + "hue_width": hue_width, + "value_width": value_width, + "chroma_width": chroma_width, + "code_width": code_width, + "mse_weight": mse_weight, + "mae_weight": mae_weight, + "log_weight": log_weight, + "huber_weight": huber_weight, + "huber_delta": huber_delta, + "total_params": total_params, + "trial_number": trial.number, + } + ) + + for epoch in range(num_epochs): + train_loss = train_epoch(model, train_loader, optimizer, criterion, device) + val_loss = validate(model, val_loader, criterion, device) + scheduler.step(val_loss) + + # Log to MLflow + mlflow.log_metrics( + { + "train_loss": train_loss, + "val_loss": val_loss, + "learning_rate": optimizer.param_groups[0]["lr"], + }, + step=epoch, + ) + + if (epoch + 1) % 10 == 0: + LOGGER.info( + " Epoch %03d/%d - Train: %.6f, Val: %.6f, LR: %.6f", + epoch + 1, + num_epochs, + train_loss, + val_loss, + optimizer.param_groups[0]["lr"], + ) + + # Early stopping + if val_loss < best_val_loss: + best_val_loss = val_loss + patience_counter = 0 + else: + patience_counter += 1 + if patience_counter >= patience: + LOGGER.info(" Early stopping at epoch %d", epoch + 1) + break + + # Report intermediate value for pruning + trial.report(val_loss, epoch) + + # Handle pruning + if trial.should_prune(): + LOGGER.info(" Trial pruned at epoch %d", epoch + 1) + mlflow.log_metrics({"pruned": 1, "pruned_epoch": epoch}) + raise optuna.TrialPruned + + # Log final results + mlflow.log_metrics( + { + "best_val_loss": best_val_loss, + "final_train_loss": train_loss, + } + ) + + LOGGER.info(" Final validation loss: %.6f", best_val_loss) + + return best_val_loss + + +def main() -> None: + """ + Run hyperparameter search for Multi-Head Error Predictor. + + Performs systematic hyperparameter optimization using Optuna with: + - MedianPruner for early stopping of unpromising trials + - 30 total trials + - MLflow logging for each trial + - Result visualization using matplotlib (optimization history, + parameter importances, parallel coordinate plot) + + The search aims to find optimal hyperparameters for predicting errors + in a base Multi-Head model, allowing for error correction and improved + Munsell predictions. + """ + + LOGGER.info("=" * 80) + LOGGER.info("Multi-Head Error Predictor Hyperparameter Search with Optuna") + LOGGER.info("=" * 80) + + # Create study + study = optuna.create_study( + direction="minimize", + study_name="multi_head_error_predictor_hparam_search", + pruner=optuna.pruners.MedianPruner(n_startup_trials=3, n_warmup_steps=5), + ) + + # Run optimization + n_trials = 30 # Number of trials to run + + LOGGER.info("") + LOGGER.info("Starting hyperparameter search with %d trials...", n_trials) + LOGGER.info("") + + study.optimize(objective, n_trials=n_trials, timeout=None) + + # Print results + LOGGER.info("") + LOGGER.info("=" * 80) + LOGGER.info("Hyperparameter Search Results") + LOGGER.info("=" * 80) + LOGGER.info("") + LOGGER.info("Best trial:") + LOGGER.info(" Value (val_loss): %.6f", study.best_value) + LOGGER.info("") + LOGGER.info("Best hyperparameters:") + for key, value in study.best_params.items(): + LOGGER.info(" %s: %s", key, value) + + # Save results + results_dir = PROJECT_ROOT / "results" / "from_xyY" + results_dir.mkdir(exist_ok=True, parents=True) + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + results_file = results_dir / f"hparam_search_multi_head_error_{timestamp}.txt" + + with open(results_file, "w") as f: + f.write("=" * 80 + "\n") + f.write("Multi-Head Error Predictor Hyperparameter Search Results\n") + f.write("=" * 80 + "\n\n") + f.write(f"Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n") + f.write(f"Number of trials: {len(study.trials)}\n") + f.write(f"Best validation loss: {study.best_value:.6f}\n\n") + f.write("Best hyperparameters:\n") + for key, value in study.best_params.items(): + f.write(f" {key}: {value}\n") + f.write("\n\nAll trials:\n") + f.write("-" * 80 + "\n") + + for t in study.trials: + f.write(f"\nTrial {t.number}:\n") + if t.value is not None: + f.write(f" Value: {t.value:.6f}\n") + else: + f.write(" Value: Pruned\n") + f.write(" Params:\n") + for key, value in t.params.items(): + f.write(f" {key}: {value}\n") + + LOGGER.info("") + LOGGER.info("Results saved to: %s", results_file) + + # Generate visualizations using matplotlib + from optuna.visualization.matplotlib import ( + plot_optimization_history, + plot_param_importances, + plot_parallel_coordinate, + ) + + # Optimization history + ax = plot_optimization_history(study) + ax.figure.savefig( + results_dir / f"optimization_history_multi_head_error_{timestamp}.png", dpi=150 + ) + plt.close(ax.figure) + + # Parameter importances + ax = plot_param_importances(study) + ax.figure.savefig( + results_dir / f"param_importances_multi_head_error_{timestamp}.png", dpi=150 + ) + plt.close(ax.figure) + + # Parallel coordinate plot + ax = plot_parallel_coordinate(study) + ax.figure.savefig( + results_dir / f"parallel_coordinate_multi_head_error_{timestamp}.png", dpi=150 + ) + plt.close(ax.figure) + + LOGGER.info("Visualizations saved to: %s", results_dir) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO, format="%(message)s", force=True) + + main() diff --git a/learning_munsell/training/from_xyY/hyperparameter_search_multi_mlp.py b/learning_munsell/training/from_xyY/hyperparameter_search_multi_mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..335af995bad8dba15e5db3130eecb3079eb3af72 --- /dev/null +++ b/learning_munsell/training/from_xyY/hyperparameter_search_multi_mlp.py @@ -0,0 +1,471 @@ +""" +Hyperparameter search for Multi-MLP model using Optuna. + +Optimizes: +- Learning rate +- Batch size +- Chroma width multiplier +- Chroma loss weight +- Code loss weight +- Dropout (optional) + +Objective: Minimize validation loss +""" + +import logging +from datetime import datetime + +import matplotlib.pyplot as plt +import mlflow +import numpy as np +import optuna +import torch +from numpy.typing import NDArray +from optuna.trial import Trial +from torch import nn, optim +from torch.utils.data import DataLoader, TensorDataset + +from learning_munsell import PROJECT_ROOT +from learning_munsell.models.networks import MultiMLPToMunsell +from learning_munsell.utilities.common import setup_mlflow_experiment +from learning_munsell.utilities.data import MUNSELL_NORMALIZATION_PARAMS, normalize_munsell + +LOGGER = logging.getLogger(__name__) + + +def weighted_mse_loss( + pred: torch.Tensor, + target: torch.Tensor, + hue_weight: float = 1.0, + value_weight: float = 1.0, + chroma_weight: float = 4.0, + code_weight: float = 0.5, +) -> torch.Tensor: + """ + Component-wise weighted MSE loss with configurable weights. + + Applies different weights to each Munsell component to account for + varying prediction difficulty and importance. + + Parameters + ---------- + pred : torch.Tensor + Predicted values, shape (batch_size, 4). + target : torch.Tensor + Target values, shape (batch_size, 4). + hue_weight : float, optional + Weight for hue component. Default is 1.0. + value_weight : float, optional + Weight for value component. Default is 1.0. + chroma_weight : float, optional + Weight for chroma component (typically higher). Default is 4.0. + code_weight : float, optional + Weight for code component (typically lower). Default is 0.5. + + Returns + ------- + torch.Tensor + Weighted MSE loss, scalar tensor. + """ + weights = torch.tensor( + [hue_weight, value_weight, chroma_weight, code_weight], device=pred.device + ) + + mse = (pred - target) ** 2 + weighted_mse = mse * weights + return weighted_mse.mean() + + +def train_epoch( + model: nn.Module, + dataloader: DataLoader, + optimizer: optim.Optimizer, + device: torch.device, + chroma_weight: float, + code_weight: float, +) -> float: + """ + Train the model for one epoch. + + Parameters + ---------- + model : nn.Module + Multi-MLP model to train. + dataloader : DataLoader + DataLoader providing training batches. + optimizer : optim.Optimizer + Optimizer for updating model parameters. + device : torch.device + Device to run training on (CPU, CUDA, or MPS). + chroma_weight : float + Weight for chroma component in loss function. + code_weight : float + Weight for code component in loss function. + + Returns + ------- + float + Average training loss over the epoch. + """ + model.train() + total_loss = 0.0 + + for X_batch, y_batch in dataloader: + X_batch = X_batch.to(device) + y_batch = y_batch.to(device) + # Forward pass + outputs = model(X_batch) + loss = weighted_mse_loss( + outputs, y_batch, chroma_weight=chroma_weight, code_weight=code_weight + ) + + # Backward pass + optimizer.zero_grad() + loss.backward() + optimizer.step() + + total_loss += loss.item() + + return total_loss / len(dataloader) + + +def validate( + model: nn.Module, + dataloader: DataLoader, + device: torch.device, + chroma_weight: float, + code_weight: float, +) -> float: + """ + Validate the model on the validation set. + + Parameters + ---------- + model : nn.Module + Multi-MLP model to validate. + dataloader : DataLoader + DataLoader providing validation batches. + device : torch.device + Device to run validation on (CPU, CUDA, or MPS). + chroma_weight : float + Weight for chroma component in loss function. + code_weight : float + Weight for code component in loss function. + + Returns + ------- + float + Average validation loss. + """ + model.eval() + total_loss = 0.0 + + with torch.no_grad(): + for X_batch, y_batch in dataloader: + X_batch = X_batch.to(device) + y_batch = y_batch.to(device) + outputs = model(X_batch) + loss = weighted_mse_loss( + outputs, y_batch, chroma_weight=chroma_weight, code_weight=code_weight + ) + + total_loss += loss.item() + + return total_loss / len(dataloader) + + +def objective(trial: Trial) -> float: + """ + Optuna objective function to minimize validation loss. + + This function defines the hyperparameter search space and training + procedure for each trial. It optimizes: + - Learning rate (1e-4 to 1e-3, log scale) + - Batch size (512, 1024, or 2048) + - Chroma branch width multiplier (1.5 to 2.5) + - Chroma loss weight (3.0 to 6.0) + - Code loss weight (0.3 to 1.0) + - Dropout rate (0.0 to 0.2) + + Parameters + ---------- + trial : Trial + Optuna trial object for suggesting hyperparameters. + + Returns + ------- + float + Best validation loss achieved during training. + + Raises + ------ + FileNotFoundError + If training data file is not found. + optuna.TrialPruned + If trial is pruned based on intermediate results. + """ + + # Suggest hyperparameters + lr = trial.suggest_float("lr", 1e-4, 1e-3, log=True) + batch_size = trial.suggest_categorical("batch_size", [512, 1024, 2048]) + chroma_width = trial.suggest_float("chroma_width", 1.5, 2.5, step=0.25) + chroma_weight = trial.suggest_float("chroma_weight", 3.0, 6.0, step=0.5) + code_weight = trial.suggest_float("code_weight", 0.3, 1.0, step=0.1) + dropout = trial.suggest_float("dropout", 0.0, 0.2, step=0.05) + + LOGGER.info("") + LOGGER.info("=" * 80) + LOGGER.info("Trial %d", trial.number) + LOGGER.info("=" * 80) + LOGGER.info(" lr: %.6f", lr) + LOGGER.info(" batch_size: %d", batch_size) + LOGGER.info(" chroma_width: %.2f", chroma_width) + LOGGER.info(" chroma_weight: %.1f", chroma_weight) + LOGGER.info(" code_weight: %.1f", code_weight) + LOGGER.info(" dropout: %.2f", dropout) + + # Set device + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # Load training data + data_file = PROJECT_ROOT / "data" / "training_data.npz" + + if not data_file.exists(): + LOGGER.error("Training data not found at %s", data_file) + LOGGER.error("Run generate_training_data.py first") + msg = f"Training data not found: {data_file}" + raise FileNotFoundError(msg) + + data = np.load(data_file) + + # Use pre-split data + X_train = data["X_train"] + y_train = data["y_train"] + X_val = data["X_val"] + y_val = data["y_val"] + + LOGGER.info( + "Loaded %d training samples, %d validation samples", len(X_train), len(X_val) + ) + + # Normalize outputs (xyY inputs are already in [0, 1] range) + output_params = MUNSELL_NORMALIZATION_PARAMS + y_train = normalize_munsell(y_train, output_params) + y_val = normalize_munsell(y_val, output_params) + + # Convert to PyTorch tensors + X_train_t = torch.FloatTensor(X_train) + y_train_t = torch.FloatTensor(y_train) + X_val_t = torch.FloatTensor(X_val) + y_val_t = torch.FloatTensor(y_val) + + # Create data loaders + train_dataset = TensorDataset(X_train_t, y_train_t) + val_dataset = TensorDataset(X_val_t, y_val_t) + + train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) + val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False) + + # Initialize model + model = MultiMLPToMunsell( + chroma_width_multiplier=chroma_width, dropout=dropout + ).to(device) + + # Count parameters + total_params = sum(p.numel() for p in model.parameters()) + LOGGER.info("Total parameters: %s", f"{total_params:,}") + + # Training setup + optimizer = optim.Adam(model.parameters(), lr=lr) + + # MLflow setup + run_name = setup_mlflow_experiment( + "from_xyY", f"hparam_multi_mlp_trial_{trial.number}" + ) + + # Training loop with early stopping + num_epochs = 100 # Reduced for hyperparameter search + patience = 15 + best_val_loss = float("inf") + patience_counter = 0 + + with mlflow.start_run(run_name=run_name): + mlflow.log_params( + { + "trial": trial.number, + "lr": lr, + "batch_size": batch_size, + "chroma_width": chroma_width, + "chroma_weight": chroma_weight, + "code_weight": code_weight, + "dropout": dropout, + "total_params": total_params, + } + ) + + for epoch in range(num_epochs): + train_loss = train_epoch( + model, train_loader, optimizer, device, chroma_weight, code_weight + ) + val_loss = validate(model, val_loader, device, chroma_weight, code_weight) + + # Log to MLflow + mlflow.log_metrics( + { + "train_loss": train_loss, + "val_loss": val_loss, + "learning_rate": lr, + }, + step=epoch, + ) + + if (epoch + 1) % 10 == 0: + LOGGER.info( + " Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f", + epoch + 1, + num_epochs, + train_loss, + val_loss, + ) + + # Early stopping + if val_loss < best_val_loss: + best_val_loss = val_loss + patience_counter = 0 + else: + patience_counter += 1 + if patience_counter >= patience: + LOGGER.info(" Early stopping at epoch %d", epoch + 1) + break + + # Report intermediate value for pruning + trial.report(val_loss, epoch) + + # Handle pruning + if trial.should_prune(): + LOGGER.info(" Trial pruned at epoch %d", epoch + 1) + mlflow.log_metrics({"pruned": 1, "pruned_epoch": epoch}) + raise optuna.TrialPruned + + # Log final results + mlflow.log_metrics( + { + "best_val_loss": best_val_loss, + "final_train_loss": train_loss, + "final_epoch": epoch + 1, + } + ) + + LOGGER.info(" Final validation loss: %.6f", best_val_loss) + + return best_val_loss + + +def main() -> None: + """ + Run hyperparameter search for Multi-MLP model. + + Performs systematic hyperparameter optimization using Optuna with: + - MedianPruner for early stopping of unpromising trials + - 15 total trials + - MLflow logging for each trial + - Result visualization using matplotlib (optimization history, + parameter importances, parallel coordinate plot) + + The search aims to find optimal hyperparameters for converting xyY + color coordinates to Munsell color specifications using a multi-MLP + architecture with independent branches for each component. + """ + + LOGGER.info("=" * 80) + LOGGER.info("Multi-MLP Hyperparameter Search with Optuna") + LOGGER.info("=" * 80) + + # Create study + study = optuna.create_study( + direction="minimize", + study_name="multi_mlp_hparam_search", + pruner=optuna.pruners.MedianPruner(n_startup_trials=3, n_warmup_steps=10), + ) + + # Run optimization + n_trials = 15 # Number of trials to run + + LOGGER.info("") + LOGGER.info("Starting hyperparameter search with %d trials...", n_trials) + LOGGER.info("") + + study.optimize(objective, n_trials=n_trials, timeout=None) + + # Print results + LOGGER.info("") + LOGGER.info("=" * 80) + LOGGER.info("Hyperparameter Search Results") + LOGGER.info("=" * 80) + LOGGER.info("") + LOGGER.info("Best trial:") + LOGGER.info(" Value (val_loss): %.6f", study.best_value) + LOGGER.info("") + LOGGER.info("Best hyperparameters:") + for key, value in study.best_params.items(): + LOGGER.info(" %s: %s", key, value) + + # Save results + results_dir = PROJECT_ROOT / "results" / "from_xyY" + results_dir.mkdir(exist_ok=True) + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + results_file = results_dir / f"hparam_search_{timestamp}.txt" + + with open(results_file, "w") as f: + f.write("=" * 80 + "\n") + f.write("Multi-MLP Hyperparameter Search Results\n") + f.write("=" * 80 + "\n\n") + f.write(f"Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n") + f.write(f"Number of trials: {len(study.trials)}\n") + f.write(f"Best validation loss: {study.best_value:.6f}\n\n") + f.write("Best hyperparameters:\n") + for key, value in study.best_params.items(): + f.write(f" {key}: {value}\n") + f.write("\n\nAll trials:\n") + f.write("-" * 80 + "\n") + + for trial in study.trials: + f.write(f"\nTrial {trial.number}:\n") + f.write(f" Value: {trial.value:.6f if trial.value else 'Pruned'}\n") + f.write(" Params:\n") + for key, value in trial.params.items(): + f.write(f" {key}: {value}\n") + + LOGGER.info("") + LOGGER.info("Results saved to: %s", results_file) + + # Generate visualizations using matplotlib + from optuna.visualization.matplotlib import ( + plot_optimization_history, + plot_param_importances, + plot_parallel_coordinate, + ) + + # Optimization history + ax = plot_optimization_history(study) + ax.figure.savefig(results_dir / f"optimization_history_{timestamp}.png", dpi=150) + plt.close(ax.figure) + + # Parameter importances + ax = plot_param_importances(study) + ax.figure.savefig(results_dir / f"param_importances_{timestamp}.png", dpi=150) + plt.close(ax.figure) + + # Parallel coordinate plot + ax = plot_parallel_coordinate(study) + ax.figure.savefig(results_dir / f"parallel_coordinate_{timestamp}.png", dpi=150) + plt.close(ax.figure) + + LOGGER.info("Visualizations saved to: %s", results_dir) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO, format="%(message)s", force=True) + + main() diff --git a/learning_munsell/training/from_xyY/refine_multi_head_real.py b/learning_munsell/training/from_xyY/refine_multi_head_real.py new file mode 100644 index 0000000000000000000000000000000000000000..4942138ef5485b07effd46bdee4f446726c97a06 --- /dev/null +++ b/learning_munsell/training/from_xyY/refine_multi_head_real.py @@ -0,0 +1,358 @@ +""" +Refine Multi-Head model on REAL Munsell colors only. + +This script fine-tunes the best Multi-Head model using only the 2734 real +(measured) Munsell colors, which should improve accuracy on the evaluation set. +""" + +import logging +from typing import Any + +import click +import mlflow +import mlflow.pytorch +import numpy as np +import torch +from colour.notation.datasets.munsell import MUNSELL_COLOURS_REAL +from colour.notation.munsell import ( + munsell_colour_to_munsell_specification, + munsell_specification_to_xyY, +) +from numpy.typing import NDArray +from sklearn.model_selection import train_test_split +from torch import nn, optim +from torch.utils.data import DataLoader, TensorDataset + +from learning_munsell import PROJECT_ROOT +from learning_munsell.models.networks import MultiHeadMLPToMunsell +from learning_munsell.utilities.common import ( + log_training_epoch, + setup_mlflow_experiment, +) +from learning_munsell.utilities.data import ( + MUNSELL_NORMALIZATION_PARAMS, + XYY_NORMALIZATION_PARAMS, + normalize_munsell, +) +from learning_munsell.utilities.training import train_epoch, validate + +LOGGER = logging.getLogger(__name__) + + +def generate_real_samples( + n_samples_per_color: int = 100, + perturbation_pct: float = 0.05, +) -> tuple[NDArray, NDArray]: + """ + Generate training samples from REAL (measured) Munsell colors only. + + Creates augmented samples by applying small perturbations to the 2734 real + Munsell color specifications to increase training data while staying close + to measured values. + + Parameters + ---------- + n_samples_per_color : int, optional + Number of perturbed samples to generate per real color (default is 100). + perturbation_pct : float, optional + Percentage of range to use for perturbations (default is 0.05 = 5%). + + Returns + ------- + xyY_samples : NDArray + Array of shape (n_samples, 3) containing xyY coordinates. + munsell_samples : NDArray + Array of shape (n_samples, 4) containing Munsell specifications + [hue, value, chroma, code]. + + Notes + ----- + Perturbations are applied uniformly within ±perturbation_pct of the + component ranges: + - Hue range: 9.5 (0.5 to 10.0) + - Value range: 9.0 (1.0 to 10.0) + - Chroma range: 50.0 (0.0 to 50.0) + + Invalid samples (that cannot be converted to xyY) are skipped. + """ + LOGGER.info( + "Generating samples from %d REAL Munsell colors...", len(MUNSELL_COLOURS_REAL) + ) + + np.random.seed(42) + + hue_range = 9.5 + value_range = 9.0 + chroma_range = 50.0 + + xyY_samples = [] + munsell_samples = [] + + for munsell_spec_tuple, _ in MUNSELL_COLOURS_REAL: + hue_code_str, value, chroma = munsell_spec_tuple + munsell_str = f"{hue_code_str} {value}/{chroma}" + base_spec = munsell_colour_to_munsell_specification(munsell_str) + + for _ in range(n_samples_per_color): + hue_delta = np.random.uniform( + -perturbation_pct * hue_range, perturbation_pct * hue_range + ) + value_delta = np.random.uniform( + -perturbation_pct * value_range, perturbation_pct * value_range + ) + chroma_delta = np.random.uniform( + -perturbation_pct * chroma_range, perturbation_pct * chroma_range + ) + + perturbed_spec = base_spec.copy() + perturbed_spec[0] = np.clip(base_spec[0] + hue_delta, 0.5, 10.0) + perturbed_spec[1] = np.clip(base_spec[1] + value_delta, 1.0, 10.0) + perturbed_spec[2] = np.clip(base_spec[2] + chroma_delta, 0.0, 50.0) + + try: + xyY = munsell_specification_to_xyY(perturbed_spec) + xyY_samples.append(xyY) + munsell_samples.append(perturbed_spec) + except Exception: + continue + + LOGGER.info("Generated %d samples", len(xyY_samples)) + return np.array(xyY_samples), np.array(munsell_samples) + + +@click.command() +@click.option("--epochs", default=200, help="Number of training epochs") +@click.option("--batch-size", default=512, help="Batch size for training") +@click.option("--lr", default=1e-5, help="Learning rate") +@click.option("--patience", default=30, help="Early stopping patience") +def main(epochs: int, batch_size: int, lr: float, patience: int) -> None: + """ + Refine Multi-Head model on REAL Munsell colors only. + + Fine-tunes a pretrained Multi-Head MLP model using only the 2734 real + (measured) Munsell colors with small perturbations. This refinement step + aims to improve accuracy on actual measured colors by focusing the model + on the real color gamut. + + Notes + ----- + Training configuration: + - Dataset: 2734 real Munsell colors with 200 samples per color + - Perturbation: 3% of component ranges (smaller than initial training) + - Learning rate: 1e-5 (lower for fine-tuning) + - Batch size: 512 + - Early stopping: patience of 30 epochs + - Optimizer: AdamW with weight decay 0.01 + - Scheduler: ReduceLROnPlateau with factor 0.5, patience 15 + + Workflow: + 1. Generate augmented samples from real Munsell colors + 2. Load pretrained model (multi_head_large_best.pth) + 3. Fine-tune with lower learning rate + 4. Save best model based on validation loss + 5. Export to ONNX format + 6. Log metrics to MLflow + + Files generated: + - multi_head_refined_real_best.pth: Best checkpoint + - multi_head_refined_real.onnx: ONNX model + - multi_head_refined_real_normalization_params.npz: Normalization params + """ + LOGGER.info("=" * 80) + LOGGER.info("Multi-Head Refinement on REAL Munsell Colors") + LOGGER.info("=" * 80) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if torch.backends.mps.is_available(): + device = torch.device("mps") + LOGGER.info("Using device: %s", device) + + # Generate REAL-only samples + LOGGER.info("") + xyY_all, munsell_all = generate_real_samples( + n_samples_per_color=200, # 200 samples per real color + perturbation_pct=0.03, # Smaller perturbations for refinement + ) + + # Split data + X_train, X_val, y_train, y_val = train_test_split( + xyY_all, munsell_all, test_size=0.15, random_state=42 + ) + + LOGGER.info("Train samples: %d", len(X_train)) + LOGGER.info("Validation samples: %d", len(X_val)) + + # Normalize outputs (xyY inputs are already in [0, 1] range) + # Use hardcoded ranges covering the full Munsell space for generalization + output_params = MUNSELL_NORMALIZATION_PARAMS + y_train_norm = normalize_munsell(y_train, output_params) + y_val_norm = normalize_munsell(y_val, output_params) + + # Convert to tensors + X_train_t = torch.FloatTensor(X_train) + y_train_t = torch.FloatTensor(y_train_norm) + X_val_t = torch.FloatTensor(X_val) + y_val_t = torch.FloatTensor(y_val_norm) + + # Data loaders + train_dataset = TensorDataset(X_train_t, y_train_t) + val_dataset = TensorDataset(X_val_t, y_val_t) + train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) + val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False) + + # Load pretrained model + model_directory = PROJECT_ROOT / "models" / "from_xyY" + pretrained_path = model_directory / "multi_head_large_best.pth" + + model = MultiHeadMLPToMunsell().to(device) + + if pretrained_path.exists(): + LOGGER.info("") + LOGGER.info("Loading pretrained model from %s...", pretrained_path) + checkpoint = torch.load( + pretrained_path, weights_only=False, map_location=device + ) + model.load_state_dict(checkpoint["model_state_dict"]) + LOGGER.info("Pretrained model loaded successfully") + else: + LOGGER.info("") + LOGGER.info("No pretrained model found, training from scratch") + + total_params = sum(p.numel() for p in model.parameters()) + LOGGER.info("Total parameters: %s", f"{total_params:,}") + + # Fine-tuning with lower learning rate + optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01) + scheduler = optim.lr_scheduler.ReduceLROnPlateau( + optimizer, mode="min", factor=0.5, patience=15 + ) + criterion = nn.MSELoss() + + # MLflow setup + run_name = setup_mlflow_experiment("from_xyY", "multi_head_refined_real") + + LOGGER.info("") + LOGGER.info("MLflow run: %s", run_name) + LOGGER.info("Learning rate: %e (fine-tuning)", lr) + + # Training loop + best_val_loss = float("inf") + patience_counter = 0 + + LOGGER.info("") + LOGGER.info("Starting refinement training...") + + with mlflow.start_run(run_name=run_name): + mlflow.log_params( + { + "model": "multi_head_refined_real", + "learning_rate": lr, + "batch_size": batch_size, + "num_epochs": epochs, + "patience": patience, + "total_params": total_params, + "train_samples": len(X_train), + "val_samples": len(X_val), + "dataset": "REAL_only", + "perturbation_pct": 0.03, + "samples_per_color": 200, + } + ) + + for epoch in range(epochs): + train_loss = train_epoch(model, train_loader, optimizer, criterion, device) + val_loss = validate(model, val_loader, criterion, device) + + scheduler.step(val_loss) + + log_training_epoch( + epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"] + ) + + LOGGER.info( + "Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f, LR: %.2e", + epoch + 1, + epochs, + train_loss, + val_loss, + optimizer.param_groups[0]["lr"], + ) + + if val_loss < best_val_loss: + best_val_loss = val_loss + patience_counter = 0 + + checkpoint_file = model_directory / "multi_head_refined_real_best.pth" + + torch.save( + { + "model_state_dict": model.state_dict(), + "output_params": output_params, + "epoch": epoch, + "val_loss": val_loss, + }, + checkpoint_file, + ) + + LOGGER.info(" -> Saved best model (val_loss: %.6f)", val_loss) + else: + patience_counter += 1 + if patience_counter >= patience: + LOGGER.info("") + LOGGER.info("Early stopping after %d epochs", epoch + 1) + break + + mlflow.log_metrics( + { + "best_val_loss": best_val_loss, + "final_epoch": epoch + 1, + } + ) + + # Export to ONNX + LOGGER.info("") + LOGGER.info("Exporting refined model to ONNX...") + model.eval() + + checkpoint = torch.load(checkpoint_file, weights_only=False) + model.load_state_dict(checkpoint["model_state_dict"]) + + model_cpu = model.cpu() + dummy_input = torch.randn(1, 3) + + onnx_file = model_directory / "multi_head_refined_real.onnx" + torch.onnx.export( + model_cpu, + dummy_input, + onnx_file, + export_params=True, + opset_version=14, + input_names=["xyY"], + output_names=["munsell_spec"], + dynamic_axes={"xyY": {0: "batch_size"}, "munsell_spec": {0: "batch_size"}}, + ) + + params_file = ( + model_directory / "multi_head_refined_real_normalization_params.npz" + ) + input_params = XYY_NORMALIZATION_PARAMS + np.savez( + params_file, + input_params=input_params, + output_params=output_params, + ) + + mlflow.log_artifact(str(checkpoint_file)) + mlflow.log_artifact(str(onnx_file)) + mlflow.log_artifact(str(params_file)) + + LOGGER.info("ONNX model saved to: %s", onnx_file) + LOGGER.info("Normalization params saved to: %s", params_file) + + LOGGER.info("=" * 80) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO, format="%(message)s", force=True) + + main() diff --git a/learning_munsell/training/from_xyY/train_deep_wide.py b/learning_munsell/training/from_xyY/train_deep_wide.py new file mode 100644 index 0000000000000000000000000000000000000000..13e845605c0dc84d0b0057118b076f0f71df1974 --- /dev/null +++ b/learning_munsell/training/from_xyY/train_deep_wide.py @@ -0,0 +1,371 @@ +""" +Train Deep + Wide model for xyY to Munsell conversion. + +Option 5: Hybrid Deep + Wide architecture +- Input: 3 features (xyY) +- Deep path: 3 → 512 → 1024 (ResBlocks) → 512 +- Wide path: 3 → 128 (direct linear) +- Combine: [512, 128] → 256 → 4 +- Output: 4 features (hue, value, chroma, code) +""" + +import logging +from typing import Any + +import click +import mlflow +import mlflow.pytorch +import numpy as np +import torch +from numpy.typing import NDArray +from torch import nn, optim +from torch.utils.data import DataLoader, TensorDataset + +from learning_munsell import PROJECT_ROOT +from learning_munsell.models.networks import ResidualBlock +from learning_munsell.utilities.common import log_training_epoch, setup_mlflow_experiment +from learning_munsell.utilities.data import ( + MUNSELL_NORMALIZATION_PARAMS, + XYY_NORMALIZATION_PARAMS, + normalize_munsell, +) +from learning_munsell.utilities.losses import precision_focused_loss +from learning_munsell.utilities.training import train_epoch, validate + +LOGGER = logging.getLogger(__name__) + + +class DeepWideNet(nn.Module): + """ + Deep + Wide Network for xyY to Munsell conversion. + + Architecture: + - Deep path: Complex non-linear transformation + - Wide path: Direct linear connections + - Combines both for final prediction + + Parameters + ---------- + num_residual_blocks : int, optional + Number of residual blocks in deep path. Default is 4. + + Attributes + ---------- + deep_encoder : nn.Sequential + Deep path encoder: 3 → 512 → 1024. + deep_residual_blocks : nn.ModuleList + Stack of residual blocks in deep path. + deep_decoder : nn.Sequential + Deep path decoder: 1024 → 512. + wide_path : nn.Sequential + Wide path: 3 → 128. + output_head : nn.Sequential + Combined output: [512, 128] → 256 → 4. + + Notes + ----- + Hybrid architecture inspired by Google's Wide & Deep Learning: + - Deep path: 3 → 512 → 1024 → (ResBlocks) → 512 + - Wide path: 3 → 128 (direct linear transformation) + - Combined: Concatenate [512, 128] → 256 → 4 + + The deep path learns complex non-linear transformations while the + wide path provides direct linear connections to preserve simple + relationships. Both paths are concatenated before the final output. + """ + + def __init__(self, num_residual_blocks: int = 4) -> None: + """Initialize the deep and wide network.""" + super().__init__() + + # Deep path: Complex transformation + self.deep_encoder = nn.Sequential( + nn.Linear(3, 512), + nn.GELU(), + nn.BatchNorm1d(512), + nn.Linear(512, 1024), + nn.GELU(), + nn.BatchNorm1d(1024), + ) + + self.deep_residual_blocks = nn.ModuleList( + [ResidualBlock(1024) for _ in range(num_residual_blocks)] + ) + + self.deep_decoder = nn.Sequential( + nn.Linear(1024, 512), + nn.GELU(), + nn.BatchNorm1d(512), + ) + + # Wide path: Direct linear transformation + self.wide_path = nn.Sequential( + nn.Linear(3, 128), + nn.GELU(), + nn.BatchNorm1d(128), + ) + + # Combined output: Concatenate deep (512) + wide (128) = 640 + self.output_head = nn.Sequential( + nn.Linear(640, 256), + nn.GELU(), + nn.BatchNorm1d(256), + nn.Linear(256, 4), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass through deep and wide paths. + + Parameters + ---------- + x : Tensor + Input tensor of shape (batch_size, 3) containing normalized xyY values. + + Returns + ------- + Tensor + Output tensor of shape (batch_size, 4) containing normalized Munsell + specifications [hue, value, chroma, code]. + + Notes + ----- + The forward pass processes input through two parallel paths: + 1. Deep path: Complex transformation through encoder, residual blocks, + and decoder (3 → 512 → 1024 → 512) + 2. Wide path: Direct linear transformation (3 → 128) + 3. Concatenation: Combine deep (512) + wide (128) = 640 features + 4. Output head: Final transformation to 4 components (640 → 256 → 4) + """ + # Deep path + deep = self.deep_encoder(x) + for block in self.deep_residual_blocks: + deep = block(deep) + deep = self.deep_decoder(deep) + + # Wide path + wide = self.wide_path(x) + + # Concatenate and output + combined = torch.cat([deep, wide], dim=1) + return self.output_head(combined) + + +@click.command() +@click.option("--epochs", default=200, help="Number of training epochs") +@click.option("--batch-size", default=1024, help="Batch size for training") +@click.option("--lr", default=3e-4, help="Learning rate") +@click.option("--patience", default=20, help="Early stopping patience") +def main(epochs: int, batch_size: int, lr: float, patience: int) -> None: + """ + Train the DeepWideNet model for xyY to Munsell conversion. + + Notes + ----- + The training pipeline: + 1. Loads normalization parameters from existing config + 2. Loads training data from cache + 3. Normalizes inputs and outputs to [0, 1] range + 4. Creates PyTorch DataLoaders + 5. Initializes DeepWideNet with deep and wide paths + 6. Trains with AdamW optimizer and precision-focused loss + 7. Uses learning rate scheduler (ReduceLROnPlateau) + 8. Implements early stopping based on validation loss + 9. Exports best model to ONNX format + 10. Logs all metrics and artifacts to MLflow + """ + + + LOGGER.info("=" * 80) + LOGGER.info("Deep + Wide Network: xyY → Munsell") + LOGGER.info("=" * 80) + + # Set device + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + LOGGER.info("Using device: %s", device) + + # Paths + model_directory = PROJECT_ROOT / "models" / "from_xyY" + data_dir = PROJECT_ROOT / "data" + cache_file = data_dir / "training_data.npz" + + # Load training data + LOGGER.info("") + LOGGER.info("Loading training data from %s...", cache_file) + data = np.load(cache_file) + X_train = data["X_train"] + y_train = data["y_train"] + X_val = data["X_val"] + y_val = data["y_val"] + + LOGGER.info("Train samples: %d", len(X_train)) + LOGGER.info("Validation samples: %d", len(X_val)) + + # Normalize outputs (xyY inputs are already in [0, 1] range) + # Use hardcoded ranges covering the full Munsell space for generalization + output_params = MUNSELL_NORMALIZATION_PARAMS + y_train_norm = normalize_munsell(y_train, output_params) + y_val_norm = normalize_munsell(y_val, output_params) + + # Convert to PyTorch tensors + X_train_t = torch.FloatTensor(X_train) + y_train_t = torch.FloatTensor(y_train_norm) + X_val_t = torch.FloatTensor(X_val) + y_val_t = torch.FloatTensor(y_val_norm) + + # Create data loaders + train_dataset = TensorDataset(X_train_t, y_train_t) + val_dataset = TensorDataset(X_val_t, y_val_t) + + train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) + val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False) + + # Initialize model + model = DeepWideNet(num_residual_blocks=4).to(device) + LOGGER.info("") + LOGGER.info("Deep + Wide architecture:") + LOGGER.info("%s", model) + + # Count parameters + total_params = sum(p.numel() for p in model.parameters()) + LOGGER.info("Total parameters: %s", f"{total_params:,}") + + # Training setup + learning_rate = lr + optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-5) + scheduler = optim.lr_scheduler.ReduceLROnPlateau( + optimizer, mode="min", factor=0.5, patience=5 + ) + criterion = precision_focused_loss + + # MLflow setup + run_name = setup_mlflow_experiment("from_xyY", "deep_wide") + + LOGGER.info("") + LOGGER.info("MLflow run: %s", run_name) + + # Training loop + best_val_loss = float("inf") + patience_counter = 0 + + LOGGER.info("") + LOGGER.info("Starting training...") + + with mlflow.start_run(run_name=run_name): + # Log parameters + mlflow.log_params( + { + "model": "deep_wide", + "learning_rate": learning_rate, + "batch_size": batch_size, + "num_epochs": epochs, + "patience": patience, + "total_params": total_params, + } + ) + + for epoch in range(epochs): + train_loss = train_epoch(model, train_loader, optimizer, criterion, device) + val_loss = validate(model, val_loader, criterion, device) + + scheduler.step(val_loss) + + # Log to MLflow + log_training_epoch( + epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"] + ) + + LOGGER.info( + "Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f, LR: %.6f", + epoch + 1, + epochs, + train_loss, + val_loss, + optimizer.param_groups[0]["lr"], + ) + + # Early stopping + if val_loss < best_val_loss: + best_val_loss = val_loss + patience_counter = 0 + + model_directory.mkdir(exist_ok=True) + checkpoint_file = model_directory / "deep_wide_best.pth" + + torch.save( + { + "model_state_dict": model.state_dict(), + "epoch": epoch, + "val_loss": val_loss, + }, + checkpoint_file, + ) + + LOGGER.info(" → Saved best model (val_loss: %.6f)", val_loss) + else: + patience_counter += 1 + if patience_counter >= patience: + LOGGER.info("") + LOGGER.info("Early stopping after %d epochs", epoch + 1) + break + + # Log final metrics + mlflow.log_metrics( + { + "best_val_loss": best_val_loss, + "final_epoch": epoch + 1, + } + ) + + # Export to ONNX + LOGGER.info("") + LOGGER.info("Exporting to ONNX...") + model.eval() + + checkpoint = torch.load(checkpoint_file) + model.load_state_dict(checkpoint["model_state_dict"]) + + dummy_input = torch.randn(1, 3).to(device) + + onnx_file = model_directory / "deep_wide.onnx" + torch.onnx.export( + model, + dummy_input, + onnx_file, + export_params=True, + opset_version=15, + input_names=["xyY"], + output_names=["munsell_spec"], + dynamic_axes={ + "xyY": {0: "batch_size"}, + "munsell_spec": {0: "batch_size"}, + }, + ) + + # Save normalization parameters alongside model + params_file = model_directory / "deep_wide_normalization_params.npz" + input_params = XYY_NORMALIZATION_PARAMS + np.savez( + params_file, + input_params=input_params, + output_params=output_params, + ) + + # Log artifacts to MLflow + mlflow.log_artifact(str(checkpoint_file)) + mlflow.log_artifact(str(onnx_file)) + mlflow.log_artifact(str(params_file)) + mlflow.pytorch.log_model(model, "model") + + LOGGER.info("ONNX model saved to: %s", onnx_file) + LOGGER.info("Normalization parameters saved to: %s", params_file) + LOGGER.info("Artifacts logged to MLflow") + + + LOGGER.info("=" * 80) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO, format="%(message)s", force=True) + + main() diff --git a/learning_munsell/training/from_xyY/train_ft_transformer.py b/learning_munsell/training/from_xyY/train_ft_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..d3f83c760934f083e0190de6cfcdb67476fa98ba --- /dev/null +++ b/learning_munsell/training/from_xyY/train_ft_transformer.py @@ -0,0 +1,356 @@ +""" +Train FT-Transformer model for xyY to Munsell conversion. + +Option 4: Feature Tokenizer + Transformer architecture +- Input: 3 features (xyY) → each becomes a 256-dim token +- Add [CLS] token for regression +- 4-6 transformer blocks with multi-head attention +- Output: Take [CLS] token → MLP → 4 features +""" + +import logging +import click +from typing import Any + +import mlflow +import mlflow.pytorch +import numpy as np +import torch +from numpy.typing import NDArray +from torch import nn, optim +from torch.utils.data import DataLoader, TensorDataset + +from learning_munsell import PROJECT_ROOT +from learning_munsell.models.networks import FeatureTokenizer, TransformerBlock +from learning_munsell.utilities.common import log_training_epoch, setup_mlflow_experiment +from learning_munsell.utilities.data import ( + MUNSELL_NORMALIZATION_PARAMS, + XYY_NORMALIZATION_PARAMS, + normalize_munsell, +) +from learning_munsell.utilities.losses import precision_focused_loss +from learning_munsell.utilities.training import train_epoch, validate + +LOGGER = logging.getLogger(__name__) + + +class FTTransformer(nn.Module): + """ + Feature Tokenizer + Transformer for xyY to Munsell conversion. + + This model adapts transformer architecture for tabular data by tokenizing + each input feature separately and using self-attention to capture complex + feature interactions. + + Architecture + ------------ + - Tokenize each feature (3 features → 3 tokens) + - Add CLS token (4 tokens total) + - 4 transformer blocks with multi-head attention + - Extract CLS token → MLP head → 4 outputs + + Parameters + ---------- + num_features : int, optional + Number of input features (xyY), default is 3. + embedding_dim : int, optional + Dimension of token embeddings, default is 256. + num_blocks : int, optional + Number of transformer blocks, default is 4. + num_heads : int, optional + Number of attention heads, default is 4. + ff_dim : int, optional + Feedforward network hidden dimension, default is 512. + dropout : float, optional + Dropout probability, default is 0.1. + + Attributes + ---------- + tokenizer : FeatureTokenizer + Converts input features to token embeddings. + transformer_blocks : nn.ModuleList + Stack of transformer blocks. + output_head : nn.Sequential + MLP that maps CLS token to output predictions. + """ + + def __init__( + self, + num_features: int = 3, + embedding_dim: int = 256, + num_blocks: int = 4, + num_heads: int = 4, + ff_dim: int = 512, + dropout: float = 0.1, + ) -> None: + """Initialize the FT-Transformer model.""" + super().__init__() + + # Feature tokenizer + self.tokenizer = FeatureTokenizer(num_features, embedding_dim) + + # Transformer blocks + self.transformer_blocks = nn.ModuleList( + [ + TransformerBlock(embedding_dim, num_heads, ff_dim, dropout) + for _ in range(num_blocks) + ] + ) + + # Output head (from CLS token) + self.output_head = nn.Sequential( + nn.Linear(embedding_dim, 128), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(128, 4), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass through FT-Transformer. + + Parameters + ---------- + x : Tensor + Input xyY values of shape (batch_size, 3). + + Returns + ------- + Tensor + Predicted Munsell specification [hue, value, chroma, code] + of shape (batch_size, 4). + """ + # Tokenize features + tokens = self.tokenizer(x) # (batch_size, 1+num_features, embedding_dim) + + # Transformer blocks + for block in self.transformer_blocks: + tokens = block(tokens) + + # Extract CLS token (first token) + cls_token = tokens[:, 0, :] # (batch_size, embedding_dim) + + # Output head + return self.output_head(cls_token) + + +@click.command() +@click.option("--epochs", default=200, help="Number of training epochs") +@click.option("--batch-size", default=1024, help="Batch size for training") +@click.option("--lr", default=3e-4, help="Learning rate") +@click.option("--patience", default=20, help="Early stopping patience") +def main(epochs: int, batch_size: int, lr: float, patience: int) -> None: + """ + Train FT-Transformer model for xyY to Munsell conversion. + + Notes + ----- + The training pipeline: + 1. Loads normalization parameters from existing config + 2. Loads training data from cache + 3. Normalizes inputs and outputs to [0, 1] range + 4. Creates PyTorch DataLoaders + 5. Initializes FT-Transformer with feature tokenization + 6. Trains with AdamW optimizer and precision-focused loss + 7. Uses learning rate scheduler (ReduceLROnPlateau) + 8. Implements early stopping based on validation loss + 9. Exports best model to ONNX format + 10. Logs all metrics and artifacts to MLflow + """ + + + LOGGER.info("=" * 80) + LOGGER.info("FT-Transformer: xyY → Munsell") + LOGGER.info("=" * 80) + + # Set device + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + LOGGER.info("Using device: %s", device) + + # Paths + model_directory = PROJECT_ROOT / "models" / "from_xyY" + data_dir = PROJECT_ROOT / "data" + cache_file = data_dir / "training_data.npz" + + # Load training data + LOGGER.info("") + LOGGER.info("Loading training data from %s...", cache_file) + data = np.load(cache_file) + X_train = data["X_train"] + y_train = data["y_train"] + X_val = data["X_val"] + y_val = data["y_val"] + + LOGGER.info("Train samples: %d", len(X_train)) + LOGGER.info("Validation samples: %d", len(X_val)) + + # Normalize outputs (xyY inputs are already in [0, 1] range) + output_params = MUNSELL_NORMALIZATION_PARAMS + y_train_norm = normalize_munsell(y_train, output_params) + y_val_norm = normalize_munsell(y_val, output_params) + + # Convert to PyTorch tensors + X_train_t = torch.FloatTensor(X_train) + y_train_t = torch.FloatTensor(y_train_norm) + X_val_t = torch.FloatTensor(X_val) + y_val_t = torch.FloatTensor(y_val_norm) + + # Create data loaders + train_dataset = TensorDataset(X_train_t, y_train_t) + val_dataset = TensorDataset(X_val_t, y_val_t) + + train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) + val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False) + + # Initialize model + model = FTTransformer( + num_features=3, + embedding_dim=256, + num_blocks=4, + num_heads=4, + ff_dim=512, + dropout=0.1, + ).to(device) + + LOGGER.info("") + LOGGER.info("FT-Transformer architecture:") + LOGGER.info("%s", model) + + # Count parameters + total_params = sum(p.numel() for p in model.parameters()) + LOGGER.info("Total parameters: %s", f"{total_params:,}") + + # Training setup + optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5) + scheduler = optim.lr_scheduler.ReduceLROnPlateau( + optimizer, mode="min", factor=0.5, patience=5 + ) + criterion = precision_focused_loss + + # MLflow setup + run_name = setup_mlflow_experiment("from_xyY", "ft_transformer") + + LOGGER.info("") + LOGGER.info("MLflow run: %s", run_name) + + # Training loop + best_val_loss = float("inf") + patience_counter = 0 + + LOGGER.info("") + LOGGER.info("Starting training...") + + with mlflow.start_run(run_name=run_name): + mlflow.log_params( + { + "model": "ft_transformer", + "learning_rate": lr, + "batch_size": batch_size, + "num_epochs": epochs, + "patience": patience, + "total_params": total_params, + } + ) + + for epoch in range(epochs): + train_loss = train_epoch(model, train_loader, optimizer, criterion, device) + val_loss = validate(model, val_loader, criterion, device) + + scheduler.step(val_loss) + + log_training_epoch( + epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"] + ) + + LOGGER.info( + "Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f, LR: %.6f", + epoch + 1, + epochs, + train_loss, + val_loss, + optimizer.param_groups[0]["lr"], + ) + + # Early stopping + if val_loss < best_val_loss: + best_val_loss = val_loss + patience_counter = 0 + + model_directory.mkdir(exist_ok=True) + checkpoint_file = model_directory / "ft_transformer_best.pth" + + torch.save( + { + "model_state_dict": model.state_dict(), + "epoch": epoch, + "val_loss": val_loss, + }, + checkpoint_file, + ) + + LOGGER.info(" → Saved best model (val_loss: %.6f)", val_loss) + else: + patience_counter += 1 + if patience_counter >= patience: + LOGGER.info("") + LOGGER.info("Early stopping after %d epochs", epoch + 1) + break + + mlflow.log_metrics( + { + "best_val_loss": best_val_loss, + "final_epoch": epoch + 1, + } + ) + + # Export to ONNX + LOGGER.info("") + LOGGER.info("Exporting to ONNX...") + model.eval() + + checkpoint = torch.load(checkpoint_file) + model.load_state_dict(checkpoint["model_state_dict"]) + + dummy_input = torch.randn(1, 3).to(device) + + onnx_file = model_directory / "ft_transformer.onnx" + torch.onnx.export( + model, + dummy_input, + onnx_file, + export_params=True, + opset_version=15, + input_names=["xyY"], + output_names=["munsell_spec"], + dynamic_axes={ + "xyY": {0: "batch_size"}, + "munsell_spec": {0: "batch_size"}, + }, + ) + + # Save normalization parameters alongside model + params_file = model_directory / "ft_transformer_normalization_params.npz" + input_params = XYY_NORMALIZATION_PARAMS + np.savez( + params_file, + input_params=input_params, + output_params=output_params, + ) + + mlflow.log_artifact(str(checkpoint_file)) + mlflow.log_artifact(str(onnx_file)) + mlflow.log_artifact(str(params_file)) + mlflow.pytorch.log_model(model, "model") + + LOGGER.info("ONNX model saved to: %s", onnx_file) + LOGGER.info("Normalization parameters saved to: %s", params_file) + LOGGER.info("Artifacts logged to MLflow") + + + LOGGER.info("=" * 80) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO, format="%(message)s", force=True) + + main() diff --git a/learning_munsell/training/from_xyY/train_mixture_of_experts.py b/learning_munsell/training/from_xyY/train_mixture_of_experts.py new file mode 100644 index 0000000000000000000000000000000000000000..2c7b3a0e7a2fc327cf8bf27620ebcc24b2ce614d --- /dev/null +++ b/learning_munsell/training/from_xyY/train_mixture_of_experts.py @@ -0,0 +1,620 @@ +""" +Train Mixture of Experts model for xyY to Munsell conversion. + +Option 6: Mixture of Experts architecture +- Input: 3 features (xyY) +- Gating network: 3 → 128 → 64 → 4 (softmax weights) +- 4 Expert networks: Each 3 → 256 → 256 → 4 (MLP) +- Output: Weighted combination of expert outputs +""" + +import logging +import click + +import mlflow +import mlflow.pytorch +import numpy as np +import torch +from numpy.typing import NDArray +from torch import nn, optim +from torch.utils.data import DataLoader, TensorDataset + +from learning_munsell import PROJECT_ROOT +from learning_munsell.models.networks import ResidualBlock +from learning_munsell.utilities.common import log_training_epoch, setup_mlflow_experiment +from learning_munsell.utilities.data import ( + MUNSELL_NORMALIZATION_PARAMS, + XYY_NORMALIZATION_PARAMS, + normalize_munsell, +) + +LOGGER = logging.getLogger(__name__) + + +class ExpertNetwork(nn.Module): + """ + Single expert network with MLP architecture. + + Each expert is a specialized neural network that learns to handle + specific regions of the input space. Uses residual connections for + improved gradient flow. + + Architecture + ------------ + - Encoder: 3 → 256 with GELU and BatchNorm + - Residual blocks: Configurable number of ResidualBlock(256) + - Decoder: 256 → 4 + + Parameters + ---------- + num_residual_blocks : int, optional + Number of residual blocks, default is 2. + + Attributes + ---------- + encoder : nn.Sequential + Input encoding layer. + residual_blocks : nn.ModuleList + Stack of residual blocks. + decoder : nn.Sequential + Output decoding layer. + """ + + def __init__(self, num_residual_blocks: int = 2) -> None: + """Initialize the expert network.""" + super().__init__() + + self.encoder = nn.Sequential( + nn.Linear(3, 256), + nn.GELU(), + nn.BatchNorm1d(256), + ) + + self.residual_blocks = nn.ModuleList( + [ResidualBlock(256) for _ in range(num_residual_blocks)] + ) + + self.decoder = nn.Sequential( + nn.Linear(256, 4), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass through expert network. + + Parameters + ---------- + x : Tensor + Input xyY values of shape (batch_size, 3). + + Returns + ------- + Tensor + Expert's prediction of shape (batch_size, 4). + """ + x = self.encoder(x) + for block in self.residual_blocks: + x = block(x) + return self.decoder(x) + + +class GatingNetwork(nn.Module): + """ + Gating network to compute expert weights. + + Learns to route inputs to appropriate experts by outputting a probability + distribution over all experts. Different inputs activate different experts + based on learned input characteristics. + + Architecture + ------------ + 3 → 128 → 64 → num_experts → softmax + + Parameters + ---------- + num_experts : int + Number of expert networks to gate. + + Attributes + ---------- + gate : nn.Sequential + MLP that maps inputs to expert logits. + """ + + def __init__(self, num_experts: int) -> None: + """Initialize the gating network.""" + super().__init__() + + self.gate = nn.Sequential( + nn.Linear(3, 128), + nn.GELU(), + nn.BatchNorm1d(128), + nn.Linear(128, 64), + nn.GELU(), + nn.BatchNorm1d(64), + nn.Linear(64, num_experts), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Compute expert weights for input. + + Parameters + ---------- + x : Tensor + Input xyY values of shape (batch_size, 3). + + Returns + ------- + Tensor + Softmax weights over experts of shape (batch_size, num_experts). + Weights sum to 1 along expert dimension. + """ + # Output softmax weights for each expert + return torch.softmax(self.gate(x), dim=-1) + + +class MixtureOfExperts(nn.Module): + """ + Mixture of Experts for xyY to Munsell conversion. + + Implements a mixture of experts architecture where multiple specialized + neural networks (experts) are combined via learned gating weights. This + allows different experts to specialize in different regions of the input + space (e.g., different color ranges or hue families). + + Architecture + ------------ + - Gating network: Learns which expert(s) to use for each input + - Multiple expert networks: Each specializes in different input regions + - Output: Weighted combination of expert predictions based on gate weights + - Load balancing: Auxiliary loss encourages balanced expert usage + + Parameters + ---------- + num_experts : int, optional + Number of expert networks, default is 4. + num_residual_blocks : int, optional + Number of residual blocks per expert, default is 2. + + Attributes + ---------- + num_experts : int + Number of expert networks. + gating_network : GatingNetwork + Network that computes expert weights. + experts : nn.ModuleList + List of expert networks. + load_balance_weight : float + Weight for load balancing auxiliary loss. + """ + + def __init__(self, num_experts: int = 4, num_residual_blocks: int = 2) -> None: + """Initialize the mixture of experts model.""" + super().__init__() + + self.num_experts = num_experts + + # Gating network + self.gating_network = GatingNetwork(num_experts) + + # Expert networks + self.experts = nn.ModuleList( + [ExpertNetwork(num_residual_blocks) for _ in range(num_experts)] + ) + + # Load balancing loss weight + self.load_balance_weight = 0.01 + + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass through mixture of experts. + + Parameters + ---------- + x : Tensor + Input xyY values of shape (batch_size, 3). + + Returns + ------- + tuple + (output, gate_weights) where: + - output: Weighted expert predictions of shape (batch_size, 4) + - gate_weights: Expert weights of shape (batch_size, num_experts) + """ + # Get gating weights + gate_weights = self.gating_network(x) # (batch_size, num_experts) + + # Get expert outputs + expert_outputs = torch.stack( + [expert(x) for expert in self.experts], dim=1 + ) # (batch_size, num_experts, 4) + + # Weighted combination + gate_weights_expanded = gate_weights.unsqueeze( + -1 + ) # (batch_size, num_experts, 1) + output = torch.sum( + expert_outputs * gate_weights_expanded, dim=1 + ) # (batch_size, 4) + + return output, gate_weights + + +def precision_focused_loss( + pred: torch.Tensor, + target: torch.Tensor, + gate_weights: torch.Tensor, + load_balance_weight: float = 0.01, +) -> torch.Tensor: + """ + Precision-focused loss function with load balancing for mixture of experts. + + Combines standard regression losses (MSE, MAE, log penalty, Huber) with + a load balancing auxiliary loss that encourages uniform expert usage across + the dataset to prevent expert collapse. + + Parameters + ---------- + pred : torch.Tensor + Predicted values. + target : torch.Tensor + Target ground truth values. + gate_weights : torch.Tensor + Expert gating weights of shape (batch_size, num_experts). + load_balance_weight : float, optional + Weight for load balancing auxiliary loss, default is 0.01. + + Returns + ------- + torch.Tensor + Combined loss value including load balancing term. + + Notes + ----- + The load balancing loss encourages each expert to handle roughly + 1/num_experts of the data, preventing scenarios where only a few + experts are used while others remain idle. + """ + # Standard precision loss + mse = torch.mean((pred - target) ** 2) + mae = torch.mean(torch.abs(pred - target)) + log_penalty = torch.mean(torch.log1p(torch.abs(pred - target) * 1000.0)) + + delta = 0.01 + abs_error = torch.abs(pred - target) + huber = torch.where( + abs_error <= delta, 0.5 * abs_error**2, delta * (abs_error - 0.5 * delta) + ) + huber_loss = torch.mean(huber) + + # Load balancing loss: Encourage balanced expert usage + # Compute importance (sum of gate weights per expert) + importance = gate_weights.sum(dim=0) # (num_experts,) + # Normalize to probabilities + importance = importance / importance.sum() + # Encourage uniform distribution (1/num_experts for each) + num_experts = gate_weights.size(1) + target_importance = torch.ones_like(importance) / num_experts + load_balance_loss = torch.mean((importance - target_importance) ** 2) + + return ( + 1.0 * mse + + 0.5 * mae + + 0.3 * log_penalty + + 0.5 * huber_loss + + load_balance_weight * load_balance_loss + ) + + +def train_epoch( + model: nn.Module, + dataloader: DataLoader, + optimizer: optim.Optimizer, + device: torch.device, +) -> float: + """ + Train the mixture of experts model for one epoch. + + Parameters + ---------- + model : nn.Module + The neural network model to train. + dataloader : DataLoader + DataLoader providing training batches (X, y). + optimizer : optim.Optimizer + Optimizer for updating model parameters. + device : torch.device + Device to run training on. + + Returns + ------- + float + Average loss for the epoch. + + Notes + ----- + Loss includes both prediction error and load balancing term. + The loss function is computed by precision_focused_loss which is + passed gate_weights for load balancing. + """ + model.train() + total_loss = 0.0 + + for X_batch, y_batch in dataloader: + X_batch = X_batch.to(device) + y_batch = y_batch.to(device) + outputs, gate_weights = model(X_batch) + loss = precision_focused_loss( + outputs, y_batch, gate_weights, model.load_balance_weight + ) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + total_loss += loss.item() + + return total_loss / len(dataloader) + + +def validate(model: nn.Module, dataloader: DataLoader, device: torch.device) -> float: + """ + Validate the mixture of experts model on validation set. + + Parameters + ---------- + model : nn.Module + The neural network model to validate. + dataloader : DataLoader + DataLoader providing validation batches (X, y). + device : torch.device + Device to run validation on. + + Returns + ------- + float + Average loss for the validation set. + """ + model.eval() + total_loss = 0.0 + + with torch.no_grad(): + for X_batch, y_batch in dataloader: + X_batch = X_batch.to(device) + y_batch = y_batch.to(device) + outputs, gate_weights = model(X_batch) + loss = precision_focused_loss( + outputs, y_batch, gate_weights, model.load_balance_weight + ) + + total_loss += loss.item() + + return total_loss / len(dataloader) + + +@click.command() +@click.option("--epochs", default=200, help="Number of training epochs") +@click.option("--batch-size", default=1024, help="Batch size for training") +@click.option("--lr", default=3e-4, help="Learning rate") +@click.option("--patience", default=20, help="Early stopping patience") +def main(epochs: int, batch_size: int, lr: float, patience: int) -> None: + """ + Train mixture of experts model for xyY to Munsell conversion. + + Notes + ----- + The training pipeline: + 1. Loads normalization parameters from existing config + 2. Loads training data from cache + 3. Normalizes inputs and outputs to [0, 1] range + 4. Creates PyTorch DataLoaders + 5. Initializes MixtureOfExperts with 4 expert networks + 6. Trains with AdamW optimizer and precision-focused loss + 7. Uses learning rate scheduler (ReduceLROnPlateau) + 8. Implements early stopping based on validation loss + 9. Exports best model to ONNX format + 10. Logs all metrics and artifacts to MLflow + """ + + + LOGGER.info("=" * 80) + LOGGER.info("Mixture of Experts: xyY → Munsell") + LOGGER.info("=" * 80) + + # Set device + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + LOGGER.info("Using device: %s", device) + + # Paths + model_directory = PROJECT_ROOT / "models" / "from_xyY" + data_dir = PROJECT_ROOT / "data" + cache_file = data_dir / "training_data.npz" + + # Load training data + LOGGER.info("") + LOGGER.info("Loading training data from %s...", cache_file) + data = np.load(cache_file) + X_train = data["X_train"] + y_train = data["y_train"] + X_val = data["X_val"] + y_val = data["y_val"] + + LOGGER.info("Train samples: %d", len(X_train)) + LOGGER.info("Validation samples: %d", len(X_val)) + + # Normalize outputs (xyY inputs are already in [0, 1] range) + # Use hardcoded ranges covering the full Munsell space for generalization + output_params = MUNSELL_NORMALIZATION_PARAMS + y_train_norm = normalize_munsell(y_train, output_params) + y_val_norm = normalize_munsell(y_val, output_params) + + # Convert to PyTorch tensors + X_train_t = torch.FloatTensor(X_train) + y_train_t = torch.FloatTensor(y_train_norm) + X_val_t = torch.FloatTensor(X_val) + y_val_t = torch.FloatTensor(y_val_norm) + + # Create data loaders + train_dataset = TensorDataset(X_train_t, y_train_t) + val_dataset = TensorDataset(X_val_t, y_val_t) + + train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) + val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False) + + # Initialize model + model = MixtureOfExperts(num_experts=4, num_residual_blocks=2).to(device) + LOGGER.info("") + LOGGER.info("Mixture of Experts architecture:") + LOGGER.info("%s", model) + + # Count parameters + total_params = sum(p.numel() for p in model.parameters()) + LOGGER.info("Total parameters: %s", f"{total_params:,}") + + # Training setup + optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5) + scheduler = optim.lr_scheduler.ReduceLROnPlateau( + optimizer, mode="min", factor=0.5, patience=5 + ) + + # MLflow setup + run_name = setup_mlflow_experiment("from_xyY", "mixture_of_experts") + + LOGGER.info("") + LOGGER.info("MLflow run: %s", run_name) + + # Training loop + best_val_loss = float("inf") + patience_counter = 0 + + LOGGER.info("") + LOGGER.info("Starting training...") + + with mlflow.start_run(run_name=run_name): + mlflow.log_params( + { + "model": "mixture_of_experts", + "learning_rate": lr, + "batch_size": batch_size, + "num_epochs": epochs, + "patience": patience, + "total_params": total_params, + } + ) + + for epoch in range(epochs): + train_loss = train_epoch(model, train_loader, optimizer, device) + val_loss = validate(model, val_loader, device) + + scheduler.step(val_loss) + + log_training_epoch( + epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"] + ) + + LOGGER.info( + "Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f, LR: %.6f", + epoch + 1, + epochs, + train_loss, + val_loss, + optimizer.param_groups[0]["lr"], + ) + + # Early stopping + if val_loss < best_val_loss: + best_val_loss = val_loss + patience_counter = 0 + + model_directory.mkdir(exist_ok=True) + checkpoint_file = model_directory / "mixture_of_experts_best.pth" + + torch.save( + { + "model_state_dict": model.state_dict(), + "epoch": epoch, + "val_loss": val_loss, + }, + checkpoint_file, + ) + + LOGGER.info(" → Saved best model (val_loss: %.6f)", val_loss) + else: + patience_counter += 1 + if patience_counter >= patience: + LOGGER.info("") + LOGGER.info("Early stopping after %d epochs", epoch + 1) + break + + mlflow.log_metrics( + { + "best_val_loss": best_val_loss, + "final_epoch": epoch + 1, + } + ) + + # Export to ONNX (simplified - outputs only prediction, not gate weights) + LOGGER.info("") + LOGGER.info("Exporting to ONNX...") + model.eval() + + checkpoint = torch.load(checkpoint_file) + model.load_state_dict(checkpoint["model_state_dict"]) + + # Create wrapper for ONNX export (only return prediction) + class MoEWrapper(nn.Module): + def __init__(self, moe_model: nn.Module) -> None: + super().__init__() + self.moe_model = moe_model + + def forward(self, x: torch.Tensor) -> torch.Tensor: + output, _ = self.moe_model(x) + return output + + wrapped_model = MoEWrapper(model).to(device) + wrapped_model.eval() + + dummy_input = torch.randn(1, 3).to(device) + + onnx_file = model_directory / "mixture_of_experts.onnx" + torch.onnx.export( + wrapped_model, + dummy_input, + onnx_file, + export_params=True, + opset_version=15, + input_names=["xyY"], + output_names=["munsell_spec"], + dynamic_axes={ + "xyY": {0: "batch_size"}, + "munsell_spec": {0: "batch_size"}, + }, + ) + + # Save normalization parameters alongside model + params_file = model_directory / "mixture_of_experts_normalization_params.npz" + input_params = XYY_NORMALIZATION_PARAMS + np.savez( + params_file, + input_params=input_params, + output_params=output_params, + ) + + mlflow.log_artifact(str(checkpoint_file)) + mlflow.log_artifact(str(onnx_file)) + mlflow.log_artifact(str(params_file)) + mlflow.pytorch.log_model(model, "model") + + LOGGER.info("ONNX model saved to: %s", onnx_file) + LOGGER.info("Normalization parameters saved to: %s", params_file) + LOGGER.info("Artifacts logged to MLflow") + + + LOGGER.info("=" * 80) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO, format="%(message)s", force=True) + + main() diff --git a/learning_munsell/training/from_xyY/train_mlp.py b/learning_munsell/training/from_xyY/train_mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..4d1df014888fdcdf8a3fc6bb70019b06b1c939fe --- /dev/null +++ b/learning_munsell/training/from_xyY/train_mlp.py @@ -0,0 +1,269 @@ +""" +Train ML model for xyY to Munsell conversion. + +This script trains a compact MLP/DNN model with architecture: +3 inputs → [64, 128, 128, 64] hidden layers → 4 outputs + +Target: < 1e-7 accuracy compared to iterative algorithm +""" + +import logging + +import click +import mlflow +import mlflow.pytorch +import numpy as np +import torch +from torch import optim +from torch.utils.data import DataLoader, TensorDataset + +from learning_munsell import PROJECT_ROOT +from learning_munsell.models.networks import MLPToMunsell +from learning_munsell.utilities.common import log_training_epoch, setup_mlflow_experiment +from learning_munsell.utilities.data import ( + MUNSELL_NORMALIZATION_PARAMS, + XYY_NORMALIZATION_PARAMS, + normalize_munsell, +) +from learning_munsell.utilities.losses import weighted_mse_loss +from learning_munsell.utilities.training import train_epoch, validate + +LOGGER = logging.getLogger(__name__) + + +@click.command() +@click.option("--epochs", default=200, help="Maximum training epochs.") +@click.option("--batch-size", default=1024, help="Training batch size.") +@click.option("--lr", default=5e-4, help="Learning rate.") +@click.option("--patience", default=20, help="Early stopping patience.") +def main(epochs: int, batch_size: int, lr: float, patience: int) -> None: + """ + Train the MLPToMunsell model for xyY to Munsell conversion. + + Parameters + ---------- + epochs : int + Maximum number of training epochs. + batch_size : int + Training batch size. + lr : float + Learning rate for AdamW optimizer. + patience : int + Early stopping patience (epochs without improvement). + + Notes + ----- + The training pipeline: + 1. Loads training data from cache + 2. Normalizes Munsell outputs to [0, 1] range + 3. Trains compact MLP model (3 → [64, 128, 128, 64] → 4) + 4. Uses weighted MSE loss function + 5. Learning rate scheduling with ReduceLROnPlateau + 6. Early stopping based on validation loss + 7. Exports model to ONNX format + 8. Logs metrics and artifacts to MLflow + """ + LOGGER.info("=" * 80) + LOGGER.info("ML-Based xyY to Munsell Conversion: Model Training") + LOGGER.info("=" * 80) + + # Set device + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + LOGGER.info("Using device: %s", device) + + # Load training data + data_dir = PROJECT_ROOT / "data" + cache_file = data_dir / "training_data.npz" + + if not cache_file.exists(): + LOGGER.error("Error: Training data not found at %s", cache_file) + LOGGER.error("Please run 01_generate_training_data.py first") + return + + LOGGER.info("Loading training data from %s...", cache_file) + data = np.load(cache_file) + + X_train = data["X_train"] + y_train = data["y_train"] + X_val = data["X_val"] + y_val = data["y_val"] + + # Note: Invalid samples (outside Munsell gamut) are also stored in the cache + # Available as: data['xyY_all'], data['munsell_all'], data['valid_mask'] + # These can be used for future enhancements like: + # - Adversarial training to avoid extrapolation + # - Gamut-aware loss functions + # - Uncertainty estimation at boundaries + + LOGGER.info("Train samples: %d", len(X_train)) + LOGGER.info("Validation samples: %d", len(X_val)) + + # Normalize outputs (xyY inputs are already in [0, 1] range) + # Use hardcoded ranges covering the full Munsell space for generalization + output_params = MUNSELL_NORMALIZATION_PARAMS + y_train_norm = normalize_munsell(y_train, output_params) + y_val_norm = normalize_munsell(y_val, output_params) + + # Convert to PyTorch tensors + X_train_t = torch.FloatTensor(X_train) + y_train_t = torch.FloatTensor(y_train_norm) + X_val_t = torch.FloatTensor(X_val) + y_val_t = torch.FloatTensor(y_val_norm) + + # Create data loaders + train_dataset = TensorDataset(X_train_t, y_train_t) + val_dataset = TensorDataset(X_val_t, y_val_t) + + # Larger batch size for larger dataset (500K samples) + train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) + val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False) + + # Initialize model + model = MLPToMunsell().to(device) + LOGGER.info("") + LOGGER.info("Model architecture:") + LOGGER.info("%s", model) + + # Count parameters + total_params = sum(p.numel() for p in model.parameters()) + LOGGER.info("Total parameters: %s", f"{total_params:,}") + + # Training setup - lower learning rate for larger model + optimizer = optim.Adam(model.parameters(), lr=lr) + # Use weighted MSE with default weights + weights = torch.tensor([1.0, 1.0, 2.0, 0.5]) + criterion = lambda pred, target: weighted_mse_loss(pred, target, weights) + + # MLflow setup + run_name = setup_mlflow_experiment("from_xyY", "mlp") + + LOGGER.info("") + LOGGER.info("MLflow run: %s", run_name) + + # Training loop + best_val_loss = float("inf") + patience_counter = 0 + + LOGGER.info("") + LOGGER.info("Starting training...") + + with mlflow.start_run(run_name=run_name): + # Log hyperparameters + mlflow.log_params( + { + "epochs": epochs, + "batch_size": batch_size, + "learning_rate": lr, + "optimizer": "Adam", + "criterion": "weighted_mse_loss", + "patience": patience, + "total_params": total_params, + } + ) + + for epoch in range(epochs): + train_loss = train_epoch(model, train_loader, optimizer, criterion, device) + val_loss = validate(model, val_loader, criterion, device) + + # Log to MLflow + log_training_epoch( + epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"] + ) + + LOGGER.info( + "Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f", + epoch + 1, + epochs, + train_loss, + val_loss, + ) + + # Early stopping + if val_loss < best_val_loss: + best_val_loss = val_loss + patience_counter = 0 + + # Save best model + model_directory = PROJECT_ROOT / "models" / "from_xyY" + model_directory.mkdir(exist_ok=True) + checkpoint_file = model_directory / "mlp_best.pth" + + torch.save( + { + "model_state_dict": model.state_dict(), + "output_params": output_params, + "epoch": epoch, + "val_loss": val_loss, + }, + checkpoint_file, + ) + + LOGGER.info(" → Saved best model (val_loss: %.6f)", val_loss) + else: + patience_counter += 1 + if patience_counter >= patience: + LOGGER.info("") + LOGGER.info("Early stopping after %d epochs", epoch + 1) + break + + # Log final metrics + mlflow.log_metrics( + { + "best_val_loss": best_val_loss, + "final_epoch": epoch + 1, + } + ) + + # Export to ONNX + LOGGER.info("") + LOGGER.info("Exporting model to ONNX...") + model.eval() + + # Load best model + checkpoint = torch.load(checkpoint_file) + model.load_state_dict(checkpoint["model_state_dict"]) + + # Create dummy input + dummy_input = torch.randn(1, 3).to(device) + + # Export + onnx_file = model_directory / "mlp.onnx" + torch.onnx.export( + model, + dummy_input, + onnx_file, + export_params=True, + opset_version=15, + input_names=["xyY"], + output_names=["munsell_spec"], + dynamic_axes={"xyY": {0: "batch_size"}, "munsell_spec": {0: "batch_size"}}, + ) + + # Save normalization parameters alongside model + params_file = model_directory / "mlp_normalization_params.npz" + input_params = XYY_NORMALIZATION_PARAMS + np.savez( + params_file, + input_params=input_params, + output_params=output_params, + ) + + LOGGER.info("ONNX model saved to: %s", onnx_file) + LOGGER.info("Normalization parameters saved to: %s", params_file) + + # Log artifacts + mlflow.log_artifact(str(checkpoint_file)) + mlflow.log_artifact(str(onnx_file)) + mlflow.log_artifact(str(params_file)) + + # Log model + mlflow.pytorch.log_model(model, "model") + + + LOGGER.info("=" * 80) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO, format="%(message)s", force=True) + + main() diff --git a/learning_munsell/training/from_xyY/train_mlp_attention.py b/learning_munsell/training/from_xyY/train_mlp_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..c378829d04167b6b45c8d09666e020394ea0f592 --- /dev/null +++ b/learning_munsell/training/from_xyY/train_mlp_attention.py @@ -0,0 +1,460 @@ +""" +Train MLP + Self-Attention model for xyY to Munsell conversion. + +Option 1: MLP backbone with multi-head self-attention layers +- Input: 3 features (xyY) +- Architecture: 3 -> 512 -> 1024 + [Attention + ResBlock] x 4 -> 512 -> 4 +- Output: 4 features (hue, value, chroma, code) +""" + +import logging +import click +import mlflow +import mlflow.pytorch +import numpy as np +import torch +from numpy.typing import NDArray +from torch import nn, optim +from torch.utils.data import DataLoader, TensorDataset + +from learning_munsell import PROJECT_ROOT +from learning_munsell.models.networks import ResidualBlock +from learning_munsell.utilities.common import log_training_epoch, setup_mlflow_experiment +from learning_munsell.utilities.data import ( + MUNSELL_NORMALIZATION_PARAMS, + XYY_NORMALIZATION_PARAMS, + normalize_munsell, +) +from learning_munsell.utilities.losses import precision_focused_loss +from learning_munsell.utilities.training import train_epoch, validate + +LOGGER = logging.getLogger(__name__) + + +class MultiHeadSelfAttention(nn.Module): + """ + Multi-head self-attention layer for feature interaction. + + Implements scaled dot-product attention with multiple heads to capture + different aspects of feature relationships. + + Parameters + ---------- + dim + Input and output feature dimension. + num_heads + Number of attention heads. Must divide ``dim`` evenly. + + Attributes + ---------- + query + Linear projection for query vectors. + key + Linear projection for key vectors. + value + Linear projection for value vectors. + out + Output projection after attention. + scale + Scaling factor (1/sqrt(head_dim)) for dot-product attention. + """ + + def __init__(self, dim: int, num_heads: int = 4) -> None: + super().__init__() + self.num_heads = num_heads + self.dim = dim + self.head_dim = dim // num_heads + + assert dim % num_heads == 0, "dim must be divisible by num_heads" # noqa: S101 + + self.query = nn.Linear(dim, dim) + self.key = nn.Linear(dim, dim) + self.value = nn.Linear(dim, dim) + self.out = nn.Linear(dim, dim) + + self.scale = self.head_dim**-0.5 + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Apply multi-head self-attention. + + Parameters + ---------- + x + Input tensor of shape ``(batch_size, dim)``. + + Returns + ------- + torch.Tensor + Output tensor of shape ``(batch_size, dim)`` with attention applied. + """ + batch_size = x.size(0) + + # Linear projections + Q = self.query(x).view(batch_size, self.num_heads, self.head_dim) + K = self.key(x).view(batch_size, self.num_heads, self.head_dim) + V = self.value(x).view(batch_size, self.num_heads, self.head_dim) + + # Scaled dot-product attention + attn_weights = torch.softmax( + torch.matmul(Q, K.transpose(-2, -1)) * self.scale, dim=-1 + ) + + # Apply attention to values + attn_output = torch.matmul(attn_weights, V) + + # Concatenate heads and project + attn_output = attn_output.view(batch_size, self.dim) + return self.out(attn_output) + + +class AttentionResBlock(nn.Module): + """ + Combined attention and residual block. + + Applies self-attention followed by a residual MLP block, each with + batch normalization and skip connections. + + Parameters + ---------- + dim + Input and output feature dimension. + num_heads + Number of attention heads for the self-attention layer. + + Attributes + ---------- + attention + Multi-head self-attention layer. + norm1 + Batch normalization after attention. + residual + Residual MLP block. + norm2 + Batch normalization after residual block. + """ + + def __init__(self, dim: int, num_heads: int = 4) -> None: + super().__init__() + self.attention = MultiHeadSelfAttention(dim, num_heads) + self.norm1 = nn.BatchNorm1d(dim) + self.residual = ResidualBlock(dim) + self.norm2 = nn.BatchNorm1d(dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Apply attention and residual transformations. + + Parameters + ---------- + x + Input tensor of shape ``(batch_size, dim)``. + + Returns + ------- + torch.Tensor + Output tensor of shape ``(batch_size, dim)``. + """ + # Attention with residual + attn_out = self.norm1(x + self.attention(x)) + # ResBlock with residual + return self.norm2(self.residual(attn_out)) + + +class MLPAttention(nn.Module): + """ + MLP with self-attention for xyY to Munsell conversion. + + Architecture: + - Input: 3 features (xyY normalized to [0, 1]) + - Encoder: 3 -> 512 -> 1024 + - Attention-ResBlocks at 1024-dim (configurable count) + - Decoder: 1024 -> 512 -> 4 + - Output: 4 features (hue, value, chroma, code normalized) + + Parameters + ---------- + num_blocks + Number of attention-residual blocks in the middle. + num_heads + Number of attention heads in each attention layer. + + Attributes + ---------- + encoder + MLP that projects 3D xyY input to 1024D feature space. + blocks + List of AttentionResBlock modules. + decoder + MLP that projects 1024D features to 4D Munsell output. + """ + + def __init__(self, num_blocks: int = 4, num_heads: int = 4) -> None: + super().__init__() + + # Encoder + self.encoder = nn.Sequential( + nn.Linear(3, 512), + nn.GELU(), + nn.BatchNorm1d(512), + nn.Linear(512, 1024), + nn.GELU(), + nn.BatchNorm1d(1024), + ) + + # Attention-ResBlocks + self.blocks = nn.ModuleList( + [AttentionResBlock(1024, num_heads) for _ in range(num_blocks)] + ) + + # Decoder + self.decoder = nn.Sequential( + nn.Linear(1024, 512), + nn.GELU(), + nn.BatchNorm1d(512), + nn.Linear(512, 4), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Predict Munsell specification from xyY input. + + Parameters + ---------- + x + Input tensor of shape ``(batch_size, 3)`` containing normalized + xyY values. + + Returns + ------- + torch.Tensor + Output tensor of shape ``(batch_size, 4)`` containing normalized + Munsell specification [hue, value, chroma, code]. + """ + # Encode + x = self.encoder(x) + + # Attention-ResBlocks + for block in self.blocks: + x = block(x) + + # Decode + return self.decoder(x) + + +@click.command() +@click.option("--epochs", default=200, help="Number of training epochs") +@click.option("--batch-size", default=1024, help="Batch size for training") +@click.option("--lr", default=3e-4, help="Learning rate") +@click.option("--patience", default=20, help="Early stopping patience") +def main(epochs: int, batch_size: int, lr: float, patience: int) -> None: + """ + Train MLP + Self-Attention model for xyY to Munsell conversion. + + Notes + ----- + The training pipeline: + 1. Loads normalization parameters and training data from disk + 2. Normalizes inputs (xyY) and outputs (Munsell specification) to [0, 1] + 3. Creates MLPAttention model (4 blocks, 4 attention heads) + 4. Trains with precision-focused loss (MSE + MAE + log + Huber) + 5. Uses AdamW optimizer with ReduceLROnPlateau scheduler + 6. Applies early stopping based on validation loss (patience=20) + 7. Exports best model to ONNX format + 8. Logs metrics and artifacts to MLflow + """ + LOGGER.info("=" * 80) + LOGGER.info("MLP + Self-Attention: xyY → Munsell") + LOGGER.info("=" * 80) + + # Set device + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + LOGGER.info("Using device: %s", device) + + # Paths + model_directory = PROJECT_ROOT / "models" / "from_xyY" + data_dir = PROJECT_ROOT / "data" + cache_file = data_dir / "training_data.npz" + + # Load training data + LOGGER.info("") + LOGGER.info("Loading training data from %s...", cache_file) + data = np.load(cache_file) + X_train = data["X_train"] + y_train = data["y_train"] + X_val = data["X_val"] + y_val = data["y_val"] + + LOGGER.info("Train samples: %d", len(X_train)) + LOGGER.info("Validation samples: %d", len(X_val)) + + output_params = MUNSELL_NORMALIZATION_PARAMS + y_train_norm = normalize_munsell(y_train, output_params) + y_val_norm = normalize_munsell(y_val, output_params) + + # Convert to PyTorch tensors + X_train_t = torch.FloatTensor(X_train) + y_train_t = torch.FloatTensor(y_train_norm) + X_val_t = torch.FloatTensor(X_val) + y_val_t = torch.FloatTensor(y_val_norm) + + # Create data loaders + train_dataset = TensorDataset(X_train_t, y_train_t) + val_dataset = TensorDataset(X_val_t, y_val_t) + + train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) + val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False) + + # Initialize model + model = MLPAttention(num_blocks=4, num_heads=4).to(device) + LOGGER.info("") + LOGGER.info("MLP + Attention architecture:") + LOGGER.info("%s", model) + + # Count parameters + total_params = sum(p.numel() for p in model.parameters()) + LOGGER.info("Total parameters: %s", f"{total_params:,}") + + # Training setup + optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5) + scheduler = optim.lr_scheduler.ReduceLROnPlateau( + optimizer, mode="min", factor=0.5, patience=5 + ) + criterion = precision_focused_loss + + # MLflow setup + run_name = setup_mlflow_experiment("from_xyY", "mlp_attention") + + LOGGER.info("") + LOGGER.info("MLflow run: %s", run_name) + + # Training loop + best_val_loss = float("inf") + patience_counter = 0 + + LOGGER.info("") + LOGGER.info("Starting training...") + + with mlflow.start_run(run_name=run_name): + # Log hyperparameters + mlflow.log_params( + { + "num_epochs": epochs, + "batch_size": batch_size, + "learning_rate": lr, + "weight_decay": 1e-5, + "optimizer": "AdamW", + "scheduler": "ReduceLROnPlateau", + "criterion": "precision_focused_loss", + "patience": patience, + "total_params": total_params, + "num_blocks": 4, + "num_heads": 4, + } + ) + + for epoch in range(epochs): + train_loss = train_epoch(model, train_loader, optimizer, criterion, device) + val_loss = validate(model, val_loader, criterion, device) + + scheduler.step(val_loss) + + log_training_epoch( + epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"] + ) + + LOGGER.info( + "Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f, LR: %.6f", + epoch + 1, + epochs, + train_loss, + val_loss, + optimizer.param_groups[0]["lr"], + ) + + # Early stopping + if val_loss < best_val_loss: + best_val_loss = val_loss + patience_counter = 0 + + model_directory.mkdir(exist_ok=True) + checkpoint_file = model_directory / "mlp_attention_best.pth" + + torch.save( + { + "model_state_dict": model.state_dict(), + "epoch": epoch, + "val_loss": val_loss, + }, + checkpoint_file, + ) + + LOGGER.info(" → Saved best model (val_loss: %.6f)", val_loss) + else: + patience_counter += 1 + if patience_counter >= patience: + LOGGER.info("") + LOGGER.info("Early stopping after %d epochs", epoch + 1) + break + + # Log final metrics + mlflow.log_metrics( + { + "best_val_loss": best_val_loss, + "final_epoch": epoch + 1, + } + ) + + # Export to ONNX + LOGGER.info("") + LOGGER.info("Exporting to ONNX...") + model.eval() + + checkpoint = torch.load(checkpoint_file) + model.load_state_dict(checkpoint["model_state_dict"]) + + dummy_input = torch.randn(1, 3).to(device) + + onnx_file = model_directory / "mlp_attention.onnx" + torch.onnx.export( + model, + dummy_input, + onnx_file, + export_params=True, + opset_version=15, + input_names=["xyY"], + output_names=["munsell_spec"], + dynamic_axes={ + "xyY": {0: "batch_size"}, + "munsell_spec": {0: "batch_size"}, + }, + ) + + # Save normalization parameters alongside model + params_file = model_directory / "mlp_attention_normalization_params.npz" + input_params = XYY_NORMALIZATION_PARAMS + np.savez( + params_file, + input_params=input_params, + output_params=output_params, + ) + + LOGGER.info("ONNX model saved to: %s", onnx_file) + LOGGER.info("Normalization parameters saved to: %s", params_file) + + # Log artifacts + mlflow.log_artifact(str(checkpoint_file)) + mlflow.log_artifact(str(onnx_file)) + mlflow.log_artifact(str(params_file)) + + # Log model + mlflow.pytorch.log_model(model, "model") + + + LOGGER.info("=" * 80) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO, format="%(message)s", force=True) + + main() diff --git a/learning_munsell/training/from_xyY/train_mlp_error_predictor.py b/learning_munsell/training/from_xyY/train_mlp_error_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..707d2c8b98e8631f2bcab204cdd53c1e937115c2 --- /dev/null +++ b/learning_munsell/training/from_xyY/train_mlp_error_predictor.py @@ -0,0 +1,457 @@ +""" +Train error predictor with advanced MLP architecture. + +Architecture features: +- Larger capacity: 7 → 256 → 512 → 512 → 256 → 4 +- Residual connections (MLP-style) for better gradient flow +- Modern activation functions (GELU instead of ReLU) +- Precision-focused loss function + +Generic error predictor that can work with any base model. +""" + +import logging +from pathlib import Path +from typing import Any + +import click +import mlflow +import mlflow.pytorch +import numpy as np +import onnxruntime as ort +import torch +from numpy.typing import NDArray +from torch import nn, optim +from torch.utils.data import DataLoader, TensorDataset + +from learning_munsell import PROJECT_ROOT +from learning_munsell.models.networks import ResidualBlock +from learning_munsell.utilities.common import log_training_epoch, setup_mlflow_experiment +from learning_munsell.utilities.data import normalize_munsell, normalize_xyY +from learning_munsell.utilities.losses import precision_focused_loss +from learning_munsell.utilities.training import train_epoch, validate + +LOGGER = logging.getLogger(__name__) + +# Note: This script has a custom ErrorPredictorMLP architecture +# so we don't import ComponentErrorPredictor/MultiHeadErrorPredictor from shared modules. + + +class ErrorPredictorMLP(nn.Module): + """ + Advanced error predictor with residual connections. + + This model implements a two-stage architecture for Munsell color prediction: + 1. Base model makes initial predictions from xyY coordinates + 2. Error predictor learns residual corrections to improve base predictions + + The error predictor uses MLP-style residual blocks for better gradient + flow and deeper representations. It takes both the input xyY coordinates + and the base model's predictions to predict the error that should be added + to the base predictions. + + Architecture: + - Input: 7 features (xyY_norm + base_pred_norm) + - Encoder: 7 → 256 → 512 + - Residual blocks at 512-dim + - Decoder: 512 → 256 → 128 → 4 + - Uses GELU activations and residual connections + + Parameters + ---------- + num_residual_blocks : int, optional + Number of residual blocks to use in the middle of the network. + Default is 3. + + Attributes + ---------- + encoder : nn.Sequential + Encoder network that maps 7D input to 512D representation. + residual_blocks : nn.ModuleList + List of residual blocks for deep feature extraction. + decoder : nn.Sequential + Decoder network that maps 512D representation to 4D error prediction. + """ + + def __init__(self, num_residual_blocks: int = 3) -> None: + super().__init__() + + # Encoder + self.encoder = nn.Sequential( + nn.Linear(7, 256), + nn.GELU(), + nn.BatchNorm1d(256), + nn.Linear(256, 512), + nn.GELU(), + nn.BatchNorm1d(512), + ) + + # Residual blocks + self.residual_blocks = nn.ModuleList( + [ResidualBlock(512) for _ in range(num_residual_blocks)] + ) + + # Decoder + self.decoder = nn.Sequential( + nn.Linear(512, 256), + nn.GELU(), + nn.BatchNorm1d(256), + nn.Linear(256, 128), + nn.GELU(), + nn.BatchNorm1d(128), + nn.Linear(128, 4), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass through the error predictor. + + Parameters + ---------- + x : Tensor + Combined input [xyY_norm, base_pred_norm] of shape (batch_size, 7). + + Returns + ------- + Tensor + Predicted error correction of shape (batch_size, 4). + """ + # Encode + x = self.encoder(x) + + # Residual blocks + for block in self.residual_blocks: + x = block(x) + + # Decode + return self.decoder(x) + + +def load_base_model( + model_path: Path, params_path: Path +) -> tuple[ort.InferenceSession, dict, dict]: + """ + Load the base ONNX model and its normalization parameters. + + The base model is the first stage of the two-stage architecture that makes + initial predictions from xyY coordinates to Munsell specifications. + + Parameters + ---------- + model_path : Path + Path to the ONNX model file. + params_path : Path + Path to the .npz file containing input and output normalization parameters. + + Returns + ------- + session : ort.InferenceSession + ONNX Runtime inference session for the base model. + input_params : dict + Dictionary containing input normalization ranges (x_range, y_range, Y_range). + output_params : dict + Dictionary containing output normalization ranges (hue_range, value_range, + chroma_range, code_range). + """ + session = ort.InferenceSession(str(model_path)) + params = np.load(params_path, allow_pickle=True) + return session, params["input_params"].item(), params["output_params"].item() + + +@click.command() +@click.option( + "--base-model", + type=click.Path(exists=True, path_type=Path), + help="Path to base model ONNX file", +) +@click.option( + "--params", + type=click.Path(exists=True, path_type=Path), + help="Path to normalization params file", +) +@click.option( + "--epochs", + type=int, + default=200, + help="Number of training epochs", +) +@click.option( + "--batch-size", + type=int, + default=1024, + help="Batch size for training", +) +@click.option( + "--lr", + type=float, + default=3e-4, + help="Learning rate", +) +@click.option( + "--patience", + type=int, + default=20, + help="Patience for early stopping", +) +def main( + base_model: Path | None, + params: Path | None, + epochs: int, + batch_size: int, + lr: float, + patience: int, +) -> None: + """ + Train error predictor with advanced MLP architecture. + + Parameters + ---------- + base_model : Path or None + Path to the base model ONNX file. If None, uses default path. + params : Path or None + Path to normalization parameters .npz file. If None, uses default path. + + Notes + ----- + The training pipeline: + 1. Loads pre-trained base model + 2. Generates base model predictions for training data + 3. Computes residual errors between predictions and targets + 4. Trains error predictor on these residuals + 5. Uses precision-focused loss function + 6. Learning rate scheduling with ReduceLROnPlateau + 7. Early stopping based on validation loss + 8. Exports model to ONNX format + 9. Logs metrics and artifacts to MLflow + """ + + + LOGGER.info("=" * 80) + LOGGER.info("Error Predictor: MLP + GELU + Precision Loss") + LOGGER.info("=" * 80) + + # Set device + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + LOGGER.info("Using device: %s", device) + + # Paths + model_directory = PROJECT_ROOT / "models" / "from_xyY" + data_dir = PROJECT_ROOT / "data" + + base_model_path = base_model + params_path = params + cache_file = data_dir / "training_data.npz" + + # Extract base model name for error predictor naming + base_model_name = ( + base_model_path.stem if base_model_path else "xyY_to_munsell_specification" + ) + + # Load base model + LOGGER.info("") + LOGGER.info("Loading base model from %s...", base_model_path) + base_session, input_params, output_params = load_base_model( + base_model_path, params_path + ) + + # Load training data + LOGGER.info("Loading training data from %s...", cache_file) + data = np.load(cache_file) + X_train = data["X_train"] + y_train = data["y_train"] + X_val = data["X_val"] + y_val = data["y_val"] + + LOGGER.info("Train samples: %d", len(X_train)) + LOGGER.info("Validation samples: %d", len(X_val)) + + # Generate base model predictions + LOGGER.info("") + LOGGER.info("Generating base model predictions...") + X_train_norm = normalize_xyY(X_train, input_params) + y_train_norm = normalize_munsell(y_train, output_params) + + # Base predictions (normalized) + base_pred_train_norm = base_session.run(None, {"xyY": X_train_norm})[0] + + X_val_norm = normalize_xyY(X_val, input_params) + y_val_norm = normalize_munsell(y_val, output_params) + base_pred_val_norm = base_session.run(None, {"xyY": X_val_norm})[0] + + # Compute errors (in normalized space) + error_train = y_train_norm - base_pred_train_norm + error_val = y_val_norm - base_pred_val_norm + + # Statistics + LOGGER.info("") + LOGGER.info("Base model error statistics (normalized space):") + LOGGER.info(" Mean absolute error: %.6f", np.mean(np.abs(error_train))) + LOGGER.info(" Std of error: %.6f", np.std(error_train)) + LOGGER.info(" Max absolute error: %.6f", np.max(np.abs(error_train))) + + # Create combined input: [xyY_norm, base_prediction_norm] + X_train_combined = np.concatenate([X_train_norm, base_pred_train_norm], axis=1) + X_val_combined = np.concatenate([X_val_norm, base_pred_val_norm], axis=1) + + # Convert to PyTorch tensors + X_train_t = torch.FloatTensor(X_train_combined) + error_train_t = torch.FloatTensor(error_train) + X_val_t = torch.FloatTensor(X_val_combined) + error_val_t = torch.FloatTensor(error_val) + + # Create data loaders + train_dataset = TensorDataset(X_train_t, error_train_t) + val_dataset = TensorDataset(X_val_t, error_val_t) + + train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) + val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False) + + # Initialize error predictor model with MLP architecture + model = ErrorPredictorMLP(num_residual_blocks=3).to(device) + LOGGER.info("") + LOGGER.info("Error predictor architecture:") + LOGGER.info("%s", model) + + # Count parameters + total_params = sum(p.numel() for p in model.parameters()) + LOGGER.info("Total parameters: %s", f"{total_params:,}") + + # Training setup with precision-focused loss + LOGGER.info("") + LOGGER.info("Using precision-focused loss function:") + LOGGER.info(" - MSE (weight: 1.0)") + LOGGER.info(" - MAE (weight: 0.5)") + LOGGER.info(" - Log penalty for small errors (weight: 0.3)") + LOGGER.info(" - Huber loss with delta=0.01 (weight: 0.5)") + + optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5) + scheduler = optim.lr_scheduler.ReduceLROnPlateau( + optimizer, mode="min", factor=0.5, patience=5 + ) + criterion = precision_focused_loss + + # MLflow setup + model_name = f"{base_model_name}_error_predictor" + run_name = setup_mlflow_experiment("from_xyY", model_name) + + LOGGER.info("") + LOGGER.info("MLflow run: %s", run_name) + + # Training loop + best_val_loss = float("inf") + patience_counter = 0 + + LOGGER.info("") + LOGGER.info("Starting training...") + + with mlflow.start_run(run_name=run_name): + mlflow.log_params( + { + "model": model_name, + "base_model": base_model_name, + "learning_rate": lr, + "batch_size": batch_size, + "num_epochs": epochs, + "patience": patience, + "total_params": total_params, + } + ) + + for epoch in range(epochs): + train_loss = train_epoch(model, train_loader, optimizer, criterion, device) + val_loss = validate(model, val_loader, criterion, device) + + # Update learning rate + scheduler.step(val_loss) + + log_training_epoch( + epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"] + ) + + LOGGER.info( + "Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f, LR: %.6f", + epoch + 1, + epochs, + train_loss, + val_loss, + optimizer.param_groups[0]["lr"], + ) + + # Early stopping + if val_loss < best_val_loss: + best_val_loss = val_loss + patience_counter = 0 + + # Save best model + model_directory.mkdir(exist_ok=True) + checkpoint_file = ( + model_directory / f"{base_model_name}_error_predictor_best.pth" + ) + + torch.save( + { + "model_state_dict": model.state_dict(), + "epoch": epoch, + "val_loss": val_loss, + }, + checkpoint_file, + ) + + LOGGER.info(" → Saved best model (val_loss: %.6f)", val_loss) + else: + patience_counter += 1 + if patience_counter >= patience: + LOGGER.info("") + LOGGER.info("Early stopping after %d epochs", epoch + 1) + break + + mlflow.log_metrics( + { + "best_val_loss": best_val_loss, + "final_epoch": epoch + 1, + } + ) + + # Export to ONNX + LOGGER.info("") + LOGGER.info("Exporting error predictor to ONNX...") + model.eval() + + # Load best model + checkpoint = torch.load(checkpoint_file) + model.load_state_dict(checkpoint["model_state_dict"]) + + # Create dummy input (xyY_norm + base_pred_norm = 7 inputs) + dummy_input = torch.randn(1, 7).to(device) + + # Export + onnx_file = model_directory / f"{base_model_name}_error_predictor.onnx" + torch.onnx.export( + model, + dummy_input, + onnx_file, + export_params=True, + opset_version=15, + input_names=["combined_input"], + output_names=["error_correction"], + dynamic_axes={ + "combined_input": {0: "batch_size"}, + "error_correction": {0: "batch_size"}, + }, + ) + + mlflow.log_artifact(str(checkpoint_file)) + mlflow.log_artifact(str(onnx_file)) + mlflow.pytorch.log_model(model, "model") + + LOGGER.info("Error predictor ONNX model saved to: %s", onnx_file) + LOGGER.info("Artifacts logged to MLflow") + + + LOGGER.info("=" * 80) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO, format="%(message)s", force=True) + + main() diff --git a/learning_munsell/training/from_xyY/train_mlp_gamma.py b/learning_munsell/training/from_xyY/train_mlp_gamma.py new file mode 100644 index 0000000000000000000000000000000000000000..b4c3fc3b1296f1efa31b31a25e0de533a14db355 --- /dev/null +++ b/learning_munsell/training/from_xyY/train_mlp_gamma.py @@ -0,0 +1,297 @@ +""" +Train ML model for xyY to Munsell conversion with gamma-corrected Y. + +Experiment: Apply gamma 2.33 to Y before normalization to better align +with perceptual lightness (Munsell Value scale is perceptually uniform). +""" + +import logging +from typing import Any + +import click +import mlflow +import mlflow.pytorch +import numpy as np +import torch +from numpy.typing import NDArray +from torch import optim +from torch.utils.data import DataLoader, TensorDataset + +from learning_munsell import PROJECT_ROOT +from learning_munsell.models.networks import MLPToMunsell +from learning_munsell.utilities.common import log_training_epoch, setup_mlflow_experiment +from learning_munsell.utilities.data import ( + MUNSELL_NORMALIZATION_PARAMS, + normalize_munsell, +) +from learning_munsell.utilities.losses import weighted_mse_loss +from learning_munsell.utilities.training import train_epoch, validate + +LOGGER = logging.getLogger(__name__) + +# Gamma value for Y transformation +GAMMA = 2.33 + + +def normalize_inputs( + X: NDArray, gamma: float = GAMMA +) -> tuple[NDArray, dict[str, Any]]: + """ + Normalize xyY inputs to [0, 1] range with gamma correction on Y. + + Parameters + ---------- + X : ndarray + xyY values of shape (n, 3) where columns are [x, y, Y]. + gamma : float + Gamma value to apply to Y component. + + Returns + ------- + ndarray + Normalized values with gamma-corrected Y, dtype float32. + dict + Normalization parameters including gamma value. + """ + # Typical ranges for xyY + x_range = (0.0, 1.0) + y_range = (0.0, 1.0) + Y_range = (0.0, 1.0) + + X_norm = X.copy() + X_norm[:, 0] = (X[:, 0] - x_range[0]) / (x_range[1] - x_range[0]) + X_norm[:, 1] = (X[:, 1] - y_range[0]) / (y_range[1] - y_range[0]) + + # Normalize Y first, then apply gamma + Y_normalized = (X[:, 2] - Y_range[0]) / (Y_range[1] - Y_range[0]) + # Clip to avoid numerical issues with negative values + Y_normalized = np.clip(Y_normalized, 0, 1) + # Apply gamma: Y_gamma = Y^(1/gamma) - this spreads dark values, compresses light + X_norm[:, 2] = np.power(Y_normalized, 1.0 / gamma) + + params = { + "x_range": x_range, + "y_range": y_range, + "Y_range": Y_range, + "gamma": gamma, + } + + return X_norm, params + + +@click.command() +@click.option("--epochs", default=200, help="Number of training epochs") +@click.option("--batch-size", default=1024, help="Batch size for training") +@click.option("--lr", default=5e-4, help="Learning rate") +@click.option("--patience", default=20, help="Early stopping patience") +def main(epochs: int, batch_size: int, lr: float, patience: int) -> None: + """ + Train MLP model with gamma-corrected Y input. + + Notes + ----- + The training pipeline: + 1. Loads training and validation data from cache + 2. Normalizes inputs with gamma correction (gamma=2.33) on Y + 3. Normalizes Munsell outputs to [0, 1] range + 4. Trains MLP with weighted MSE loss + 5. Uses early stopping based on validation loss + 6. Exports best model to ONNX format + 7. Logs metrics and artifacts to MLflow + + The gamma correction on Y aligns with perceptual lightness. The gamma + transformation spreads dark values and compresses light values, matching + human lightness perception and the perceptually uniform Munsell Value scale. + """ + + LOGGER.info("=" * 80) + LOGGER.info("ML-Based xyY to Munsell Conversion: Gamma Experiment") + LOGGER.info("Gamma = %.2f applied to Y component", GAMMA) + LOGGER.info("=" * 80) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + LOGGER.info("Using device: %s", device) + + # Load training data + data_dir = PROJECT_ROOT / "data" + cache_file = data_dir / "training_data.npz" + + if not cache_file.exists(): + LOGGER.error("Error: Training data not found at %s", cache_file) + LOGGER.error("Please run 01_generate_training_data.py first") + return + + LOGGER.info("Loading training data from %s...", cache_file) + data = np.load(cache_file) + + X_train = data["X_train"] + y_train = data["y_train"] + X_val = data["X_val"] + y_val = data["y_val"] + + LOGGER.info("Train samples: %d", len(X_train)) + LOGGER.info("Validation samples: %d", len(X_val)) + + # Normalize data with gamma correction + X_train_norm, input_params = normalize_inputs(X_train, gamma=GAMMA) + X_val_norm, _ = normalize_inputs(X_val, gamma=GAMMA) + + output_params = MUNSELL_NORMALIZATION_PARAMS + y_train_norm = normalize_munsell(y_train, output_params) + y_val_norm = normalize_munsell(y_val, output_params) + + LOGGER.info("") + LOGGER.info("Input normalization with gamma=%.2f:", GAMMA) + LOGGER.info(" Y range after gamma: [%.4f, %.4f]", X_train_norm[:, 2].min(), X_train_norm[:, 2].max()) + + # Convert to PyTorch tensors + X_train_t = torch.FloatTensor(X_train_norm) + y_train_t = torch.FloatTensor(y_train_norm) + X_val_t = torch.FloatTensor(X_val_norm) + y_val_t = torch.FloatTensor(y_val_norm) + + # Create data loaders + train_dataset = TensorDataset(X_train_t, y_train_t) + val_dataset = TensorDataset(X_val_t, y_val_t) + + train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) + val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False) + + # Initialize model + model = MLPToMunsell().to(device) + LOGGER.info("") + LOGGER.info("Model architecture:") + LOGGER.info("%s", model) + + total_params = sum(p.numel() for p in model.parameters()) + LOGGER.info("Total parameters: %s", f"{total_params:,}") + + # Training setup + optimizer = optim.Adam(model.parameters(), lr=lr) + # Component weights: emphasize chroma (2.0), de-emphasize code (0.5) + weights = torch.tensor([1.0, 1.0, 2.0, 0.5]) + criterion = lambda pred, target: weighted_mse_loss(pred, target, weights) + + # MLflow setup + run_name = setup_mlflow_experiment("from_xyY", f"mlp_gamma_{GAMMA}") + + LOGGER.info("") + LOGGER.info("MLflow run: %s", run_name) + + # Training loop + best_val_loss = float("inf") + patience_counter = 0 + + LOGGER.info("") + LOGGER.info("Starting training...") + + with mlflow.start_run(run_name=run_name): + mlflow.log_params( + { + "num_epochs": epochs, + "batch_size": batch_size, + "learning_rate": lr, + "optimizer": "Adam", + "criterion": "weighted_mse_loss", + "patience": patience, + "total_params": total_params, + "gamma": GAMMA, + } + ) + + for epoch in range(epochs): + train_loss = train_epoch(model, train_loader, optimizer, criterion, device) + val_loss = validate(model, val_loader, criterion, device) + + log_training_epoch( + epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"] + ) + + LOGGER.info( + "Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f", + epoch + 1, + epochs, + train_loss, + val_loss, + ) + + if val_loss < best_val_loss: + best_val_loss = val_loss + patience_counter = 0 + + model_directory = PROJECT_ROOT / "models" / "from_xyY" + model_directory.mkdir(exist_ok=True) + checkpoint_file = model_directory / "mlp_gamma_best.pth" + + torch.save( + { + "model_state_dict": model.state_dict(), + "input_params": input_params, + "output_params": output_params, + "epoch": epoch, + "val_loss": val_loss, + }, + checkpoint_file, + ) + + LOGGER.info(" → Saved best model (val_loss: %.6f)", val_loss) + else: + patience_counter += 1 + if patience_counter >= patience: + LOGGER.info("") + LOGGER.info("Early stopping after %d epochs", epoch + 1) + break + + mlflow.log_metrics( + { + "best_val_loss": best_val_loss, + "final_epoch": epoch + 1, + } + ) + + # Export to ONNX + LOGGER.info("") + LOGGER.info("Exporting model to ONNX...") + model.eval() + + checkpoint = torch.load(checkpoint_file) + model.load_state_dict(checkpoint["model_state_dict"]) + + dummy_input = torch.randn(1, 3).to(device) + + onnx_file = model_directory / "mlp_gamma.onnx" + torch.onnx.export( + model, + dummy_input, + onnx_file, + export_params=True, + opset_version=15, + input_names=["xyY_gamma"], + output_names=["munsell_spec"], + dynamic_axes={"xyY_gamma": {0: "batch_size"}, "munsell_spec": {0: "batch_size"}}, + ) + + # Save normalization parameters (including gamma) + params_file = model_directory / "mlp_gamma_normalization_params.npz" + np.savez( + params_file, + input_params=input_params, + output_params=output_params, + ) + + LOGGER.info("ONNX model saved to: %s", onnx_file) + LOGGER.info("Normalization parameters saved to: %s", params_file) + LOGGER.info("IMPORTANT: Input Y must be gamma-corrected with gamma=%.2f", GAMMA) + + mlflow.log_artifact(str(checkpoint_file)) + mlflow.log_artifact(str(onnx_file)) + mlflow.log_artifact(str(params_file)) + mlflow.pytorch.log_model(model, "model") + + LOGGER.info("=" * 80) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO, format="%(message)s", force=True) + + main() diff --git a/learning_munsell/training/from_xyY/train_multi_head_3stage_error_predictor.py b/learning_munsell/training/from_xyY/train_multi_head_3stage_error_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..8de8baf3b5983547aa2613ef2cff2fd05008b2e3 --- /dev/null +++ b/learning_munsell/training/from_xyY/train_multi_head_3stage_error_predictor.py @@ -0,0 +1,411 @@ +""" +Train second-stage error predictor for 3-stage model. + +Architecture: Multi-Head + Multi-Error Predictor + Multi-Error Predictor +- Stage 1: Multi-Head base model (existing) +- Stage 2: First error predictor (existing) +- Stage 3: Second error predictor (this script) - learns residuals from stage 2 + +The second error predictor has the same architecture as the first but learns +the remaining errors after the first error correction is applied. +""" + +import logging +from pathlib import Path +from typing import Any + +import click +import mlflow +import mlflow.pytorch +import numpy as np +import onnxruntime as ort +import torch +from numpy.typing import NDArray +from torch import nn, optim +from torch.utils.data import DataLoader, TensorDataset + +from learning_munsell import PROJECT_ROOT +from learning_munsell.models.networks import ( + ComponentErrorPredictor, + MultiHeadErrorPredictorToMunsell, +) +from learning_munsell.utilities.common import log_training_epoch, setup_mlflow_experiment +from learning_munsell.utilities.data import normalize_munsell, normalize_xyY +from learning_munsell.utilities.losses import precision_focused_loss +from learning_munsell.utilities.training import train_epoch, validate + +LOGGER = logging.getLogger(__name__) + + +@click.command() +@click.option( + "--base-model", + type=click.Path(exists=True, path_type=Path), + default=None, + help="Path to Multi-Head base model ONNX file", +) +@click.option( + "--first-error-predictor", + type=click.Path(exists=True, path_type=Path), + default=None, + help="Path to first error predictor ONNX file", +) +@click.option( + "--params", + type=click.Path(exists=True, path_type=Path), + default=None, + help="Path to normalization params file", +) +@click.option( + "--epochs", + type=int, + default=300, + help="Number of training epochs (default: 300)", +) +@click.option( + "--batch-size", + type=int, + default=2048, + help="Batch size for training (default: 2048)", +) +@click.option( + "--lr", + type=float, + default=3e-4, + help="Learning rate (default: 3e-4)", +) +@click.option( + "--patience", + type=int, + default=30, + help="Early stopping patience (default: 30)", +) +def main( + base_model: Path | None, + first_error_predictor: Path | None, + params: Path | None, + epochs: int, + batch_size: int, + lr: float, + patience: int, +) -> None: + """ + Train the second-stage error predictor for the 3-stage model. + + This script trains the third stage of a 3-stage model: + - Stage 1: Multi-Head base model (pre-trained) + - Stage 2: First error predictor (pre-trained) + - Stage 3: Second error predictor (trained by this script) + + The second error predictor learns the residual errors remaining after + the first error correction is applied, further refining the predictions. + + Parameters + ---------- + base_model : Path, optional + Path to the Multi-Head base model ONNX file. + Default: models/from_xyY/multi_head_large.onnx + first_error_predictor : Path, optional + Path to the first error predictor ONNX file. + Default: models/from_xyY/multi_head_multi_error_predictor_large.onnx + params : Path, optional + Path to the normalization parameters file. + Default: models/from_xyY/multi_head_large_normalization_params.npz + + Notes + ----- + The training pipeline: + 1. Loads pre-trained Stage 1 and Stage 2 models + 2. Generates Stage 2 predictions (base + first error correction) + 3. Computes remaining residual errors + 4. Trains Stage 3 error predictor on these residuals + 5. Exports the model to ONNX format + 6. Logs metrics and artifacts to MLflow + """ + + LOGGER.info("=" * 80) + LOGGER.info("Second Error Predictor: 3-Stage Model Training") + LOGGER.info("Multi-Head + Multi-Error Predictor + Multi-Error Predictor") + LOGGER.info("=" * 80) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if torch.backends.mps.is_available(): + device = torch.device("mps") + LOGGER.info("Using device: %s", device) + + model_directory = PROJECT_ROOT / "models" / "from_xyY" + data_dir = PROJECT_ROOT / "data" + + if base_model is None: + base_model = model_directory / "multi_head_large.onnx" + if first_error_predictor is None: + first_error_predictor = model_directory / "multi_head_multi_error_predictor_large.onnx" + if params is None: + params = model_directory / "multi_head_large_normalization_params.npz" + + cache_file = data_dir / "training_data_large.npz" + + if not cache_file.exists(): + LOGGER.error("Error: Large training data not found at %s", cache_file) + return + + if not base_model.exists(): + LOGGER.error("Error: Base model not found at %s", base_model) + return + + if not first_error_predictor.exists(): + LOGGER.error("Error: First error predictor not found at %s", first_error_predictor) + return + + # Load models + LOGGER.info("") + LOGGER.info("Loading Stage 1: Multi-Head base model from %s...", base_model) + base_session = ort.InferenceSession(str(base_model)) + + LOGGER.info("Loading Stage 2: First error predictor from %s...", first_error_predictor) + error_predictor_session = ort.InferenceSession(str(first_error_predictor)) + + # Load normalization params + params_data = np.load(params, allow_pickle=True) + input_params = params_data["input_params"].item() + output_params = params_data["output_params"].item() + + # Load training data + LOGGER.info("Loading large training data from %s...", cache_file) + data = np.load(cache_file) + X_train = data["X_train"] + y_train = data["y_train"] + X_val = data["X_val"] + y_val = data["y_val"] + + LOGGER.info("Train samples: %d", len(X_train)) + LOGGER.info("Validation samples: %d", len(X_val)) + + # Generate stage 2 predictions (base + first error correction) + LOGGER.info("") + LOGGER.info("Computing Stage 2 predictions (base + first error correction)...") + + X_train_norm = normalize_xyY(X_train, input_params) + y_train_norm = normalize_munsell(y_train, output_params) + X_val_norm = normalize_xyY(X_val, input_params) + y_val_norm = normalize_munsell(y_val, output_params) + + inference_batch_size = 50000 + + # Stage 1: Base model predictions + LOGGER.info(" Stage 1: Base model predictions (training set)...") + base_pred_train = [] + for i in range(0, len(X_train_norm), inference_batch_size): + batch = X_train_norm[i : i + inference_batch_size] + pred = base_session.run(None, {"xyY": batch})[0] + base_pred_train.append(pred) + base_pred_train = np.concatenate(base_pred_train, axis=0) + + LOGGER.info(" Stage 1: Base model predictions (validation set)...") + base_pred_val = [] + for i in range(0, len(X_val_norm), inference_batch_size): + batch = X_val_norm[i : i + inference_batch_size] + pred = base_session.run(None, {"xyY": batch})[0] + base_pred_val.append(pred) + base_pred_val = np.concatenate(base_pred_val, axis=0) + + # Stage 2: First error predictor corrections + LOGGER.info(" Stage 2: First error predictor corrections (training set)...") + combined_train = np.concatenate([X_train_norm, base_pred_train], axis=1).astype(np.float32) + error_correction_train = [] + for i in range(0, len(combined_train), inference_batch_size): + batch = combined_train[i : i + inference_batch_size] + correction = error_predictor_session.run(None, {"combined_input": batch})[0] + error_correction_train.append(correction) + error_correction_train = np.concatenate(error_correction_train, axis=0) + + LOGGER.info(" Stage 2: First error predictor corrections (validation set)...") + combined_val = np.concatenate([X_val_norm, base_pred_val], axis=1).astype(np.float32) + error_correction_val = [] + for i in range(0, len(combined_val), inference_batch_size): + batch = combined_val[i : i + inference_batch_size] + correction = error_predictor_session.run(None, {"combined_input": batch})[0] + error_correction_val.append(correction) + error_correction_val = np.concatenate(error_correction_val, axis=0) + + # Stage 2 predictions (base + first error correction) + stage2_pred_train = base_pred_train + error_correction_train + stage2_pred_val = base_pred_val + error_correction_val + + # Compute remaining errors for stage 3 + error_train = y_train_norm - stage2_pred_train + error_val = y_val_norm - stage2_pred_val + + # Statistics + LOGGER.info("") + LOGGER.info("Stage 2 prediction error statistics (normalized space):") + LOGGER.info(" Mean absolute error: %.6f", np.mean(np.abs(error_train))) + LOGGER.info(" Std of error: %.6f", np.std(error_train)) + LOGGER.info(" Max absolute error: %.6f", np.max(np.abs(error_train))) + + # Compare with stage 1 errors + stage1_error_train = y_train_norm - base_pred_train + LOGGER.info("") + LOGGER.info("Stage 1 (base only) error statistics for comparison:") + LOGGER.info(" Mean absolute error: %.6f", np.mean(np.abs(stage1_error_train))) + LOGGER.info(" Std of error: %.6f", np.std(stage1_error_train)) + + error_reduction = ( + (np.mean(np.abs(stage1_error_train)) - np.mean(np.abs(error_train))) + / np.mean(np.abs(stage1_error_train)) + * 100 + ) + LOGGER.info("") + LOGGER.info("Stage 2 error reduction vs Stage 1: %.1f%%", error_reduction) + + # Create combined input for stage 3: [xyY_norm, stage2_pred_norm] + X_train_combined = np.concatenate([X_train_norm, stage2_pred_train], axis=1) + X_val_combined = np.concatenate([X_val_norm, stage2_pred_val], axis=1) + + # Convert to PyTorch tensors + X_train_t = torch.FloatTensor(X_train_combined) + error_train_t = torch.FloatTensor(error_train) + X_val_t = torch.FloatTensor(X_val_combined) + error_val_t = torch.FloatTensor(error_val) + + train_dataset = TensorDataset(X_train_t, error_train_t) + val_dataset = TensorDataset(X_val_t, error_val_t) + + train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) + val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False) + + # Initialize second error predictor (same architecture as first) + model = MultiHeadErrorPredictorToMunsell().to(device) + LOGGER.info("") + LOGGER.info("Stage 3: Second error predictor architecture:") + LOGGER.info("%s", model) + + total_params = sum(p.numel() for p in model.parameters()) + LOGGER.info("Total parameters: %s", f"{total_params:,}") + + # Training setup + optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5) + scheduler = optim.lr_scheduler.ReduceLROnPlateau( + optimizer, mode="min", factor=0.5, patience=10 + ) + criterion = precision_focused_loss + + run_name = setup_mlflow_experiment("from_xyY", "multi_head_3stage_error_predictor") + + LOGGER.info("") + LOGGER.info("MLflow run: %s", run_name) + + # Training loop + best_val_loss = float("inf") + patience_counter = 0 + + LOGGER.info("") + LOGGER.info("Starting Stage 3 training...") + + with mlflow.start_run(run_name=run_name): + mlflow.log_params( + { + "model": "multi_head_3stage_error_predictor", + "num_epochs": epochs, + "batch_size": batch_size, + "learning_rate": lr, + "weight_decay": 1e-5, + "optimizer": "AdamW", + "scheduler": "ReduceLROnPlateau", + "criterion": "precision_focused_loss", + "patience": patience, + "total_params": total_params, + "train_samples": len(X_train), + "val_samples": len(X_val), + "stage2_error_reduction_pct": error_reduction, + } + ) + + for epoch in range(epochs): + train_loss = train_epoch(model, train_loader, optimizer, criterion, device) + val_loss = validate(model, val_loader, criterion, device) + + scheduler.step(val_loss) + + log_training_epoch( + epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"] + ) + + LOGGER.info( + "Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f, LR: %.6f", + epoch + 1, + epochs, + train_loss, + val_loss, + optimizer.param_groups[0]["lr"], + ) + + if val_loss < best_val_loss: + best_val_loss = val_loss + patience_counter = 0 + + model_directory.mkdir(exist_ok=True) + checkpoint_file = model_directory / "multi_head_3stage_error_predictor_best.pth" + + torch.save( + { + "model_state_dict": model.state_dict(), + "epoch": epoch, + "val_loss": val_loss, + }, + checkpoint_file, + ) + + LOGGER.info(" → Saved best model (val_loss: %.6f)", val_loss) + else: + patience_counter += 1 + if patience_counter >= patience: + LOGGER.info("") + LOGGER.info("Early stopping after %d epochs", epoch + 1) + break + + mlflow.log_metrics( + { + "best_val_loss": best_val_loss, + "final_epoch": epoch + 1, + } + ) + + # Export to ONNX + LOGGER.info("") + LOGGER.info("Exporting Stage 3 error predictor to ONNX...") + model.eval() + + checkpoint = torch.load(checkpoint_file, weights_only=False) + model.load_state_dict(checkpoint["model_state_dict"]) + + dummy_input = torch.randn(1, 7).to(device) + + onnx_file = model_directory / "multi_head_3stage_error_predictor.onnx" + torch.onnx.export( + model, + dummy_input, + onnx_file, + export_params=True, + opset_version=15, + input_names=["combined_input"], + output_names=["error_correction"], + dynamic_axes={ + "combined_input": {0: "batch_size"}, + "error_correction": {0: "batch_size"}, + }, + ) + + LOGGER.info("Stage 3 error predictor ONNX model saved to: %s", onnx_file) + + mlflow.log_artifact(str(checkpoint_file)) + mlflow.log_artifact(str(onnx_file)) + mlflow.pytorch.log_model(model, "model") + + LOGGER.info("=" * 80) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO, format="%(message)s", force=True) + + main() diff --git a/learning_munsell/training/from_xyY/train_multi_head_circular.py b/learning_munsell/training/from_xyY/train_multi_head_circular.py new file mode 100644 index 0000000000000000000000000000000000000000..6987e534fd8a1f8741ed37c6ffc1d03cbf748034 --- /dev/null +++ b/learning_munsell/training/from_xyY/train_multi_head_circular.py @@ -0,0 +1,479 @@ +""" +Train Multi-Head model with circular hue loss for xyY to Munsell conversion. + +This version uses circular loss for the hue component (which wraps from 0-10) +to avoid penalizing predictions near the boundary. + +Key Difference from Standard Training: +- Uses munsell_component_loss() which applies circular MSE for hue +- and regular MSE for value/chroma/code components +""" + +from __future__ import annotations + +import copy +import logging + +import click +import mlflow +import mlflow.pytorch +import numpy as np +import torch +from torch import nn, optim +from torch.utils.data import DataLoader, TensorDataset + +from learning_munsell import PROJECT_ROOT +from learning_munsell.utilities.common import setup_mlflow_experiment +from learning_munsell.utilities.data import ( + MUNSELL_NORMALIZATION_PARAMS, + normalize_munsell, +) +from learning_munsell.training.from_xyY.hyperparameter_search_multi_head import ( + MultiHeadParametric, +) + +LOGGER = logging.getLogger(__name__) + + +def circular_mse_loss( + pred_hue: torch.Tensor, target_hue: torch.Tensor, hue_range: float = 1.0 +) -> torch.Tensor: + """ + Circular MSE loss for hue component (normalized 0-1). + + Parameters + ---------- + pred_hue : Tensor + Predicted hue values (normalized 0-1) + target_hue : Tensor + Target hue values (normalized 0-1) + hue_range : float + Range of hue values (1.0 for normalized) + + Returns + ------- + Tensor + Circular MSE loss + """ + diff = torch.abs(pred_hue - target_hue) + circular_diff = torch.min(diff, hue_range - diff) + return torch.mean(circular_diff**2) + + +def munsell_component_loss( + pred: torch.Tensor, target: torch.Tensor, hue_range: float = 1.0 +) -> torch.Tensor: + """ + Component-wise loss for Munsell predictions. + + Uses circular MSE for hue (component 0) and regular MSE + for value, chroma, code (components 1-3). + + Parameters + ---------- + pred : Tensor + Predictions [hue, value, chroma, code] (shape: [batch, 4]) + target : Tensor + Ground truth [hue, value, chroma, code] (shape: [batch, 4]) + hue_range : float + Range of normalized hue values (default 1.0) + + Returns + ------- + Tensor + Combined loss + """ + hue_loss = circular_mse_loss(pred[:, 0], target[:, 0], hue_range) + other_loss = nn.functional.mse_loss(pred[:, 1:], target[:, 1:]) + return hue_loss + other_loss + + +@click.command() +@click.option("--epochs", default=300, help="Number of training epochs") +@click.option("--batch-size", default=512, help="Batch size for training") +@click.option("--lr", default=0.000837, help="Learning rate") +@click.option("--patience", default=30, help="Early stopping patience") +def main( + epochs: int, + batch_size: int, + lr: float, + patience: int, + encoder_width: float = 0.75, + head_width: float = 1.5, + chroma_head_width: float = 1.5, + dropout: float = 0.0, + weight_decay: float = 0.000013, +) -> tuple[MultiHeadParametric, float]: + """ + Train Multi-Head model with circular hue loss. + + This script uses circular loss for the hue component (which wraps from + 0-10) to avoid penalizing predictions near the boundary. + + Parameters + ---------- + epochs : int, optional + Maximum number of training epochs. + batch_size : int, optional + Training batch size. + lr : float, optional + Learning rate for AdamW optimizer. + encoder_width : float, optional + Width multiplier for the shared encoder. + head_width : float, optional + Width multiplier for hue, value, and code heads. + chroma_head_width : float, optional + Width multiplier for chroma head (typically larger). + dropout : float, optional + Dropout rate for regularization. + weight_decay : float, optional + Weight decay for AdamW optimizer. + + Returns + ------- + model : MultiHeadParametric + Trained model with best validation loss weights. + best_val_loss : float + Best validation loss achieved during training. + + Notes + ----- + The training pipeline: + 1. Loads training data from cache + 2. Normalizes outputs to [0, 1] range + 3. Trains with circular MSE for hue and regular MSE for other components + 4. Uses CosineAnnealingLR scheduler + 5. Early stopping based on validation loss + 6. Exports model to ONNX format + 7. Logs metrics and artifacts to MLflow + + The circular loss experiment showed that while mathematically correct, + the circular distance creates gradient discontinuities that harm + optimization. This model is included for comparison purposes. + """ + + LOGGER.info("=" * 80) + LOGGER.info("Training Multi-Head (Circular Hue Loss) for xyY to Munsell conversion") + LOGGER.info("=" * 80) + LOGGER.info("") + LOGGER.info("Using Circular Loss for Hue Component") + LOGGER.info("=" * 80) + LOGGER.info("") + LOGGER.info("Hyperparameters:") + LOGGER.info(" lr: %.6f", lr) + LOGGER.info(" batch_size: %d", batch_size) + LOGGER.info(" encoder_width: %.2f", encoder_width) + LOGGER.info(" head_width: %.2f", head_width) + LOGGER.info(" chroma_head_width: %.2f", chroma_head_width) + LOGGER.info(" dropout: %.2f", dropout) + LOGGER.info(" weight_decay: %.6f", weight_decay) + LOGGER.info("") + + device = torch.device("mps" if torch.backends.mps.is_available() else "cpu") + LOGGER.info("Using device: %s", device) + + # Load data from cache + data_dir = PROJECT_ROOT / "data" + cache_file = data_dir / "training_data.npz" + data = np.load(cache_file) + + X_train = data["X_train"] + y_train = data["y_train"] + X_val = data["X_val"] + y_val = data["y_val"] + + LOGGER.info("Training samples: %d", len(X_train)) + LOGGER.info("Validation samples: %d", len(X_val)) + + # Normalize outputs (xyY inputs already in [0, 1] range) + # Use shared normalization parameters covering the full Munsell space for generalization + output_params = MUNSELL_NORMALIZATION_PARAMS + y_train_norm = normalize_munsell(y_train, output_params) + y_val_norm = normalize_munsell(y_val, output_params) + + # Convert to tensors + X_train_t = torch.from_numpy(X_train).float() + y_train_t = torch.from_numpy(y_train_norm).float() + X_val_t = torch.from_numpy(X_val).float() + y_val_t = torch.from_numpy(y_val_norm).float() + + train_loader = DataLoader( + TensorDataset(X_train_t, y_train_t), batch_size=batch_size, shuffle=True + ) + val_loader = DataLoader( + TensorDataset(X_val_t, y_val_t), batch_size=batch_size, shuffle=False + ) + + # Create model + model = MultiHeadParametric( + encoder_width=encoder_width, + head_width=head_width, + chroma_head_width=chroma_head_width, + dropout=dropout, + ).to(device) + + optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay) + scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs) + + total_params = sum(p.numel() for p in model.parameters()) + LOGGER.info("") + LOGGER.info("Model parameters: %s", f"{total_params:,}") + + encoder_params = sum(p.numel() for p in model.encoder.parameters()) + hue_params = sum(p.numel() for p in model.hue_head.parameters()) + value_params = sum(p.numel() for p in model.value_head.parameters()) + chroma_params = sum(p.numel() for p in model.chroma_head.parameters()) + code_params = sum(p.numel() for p in model.code_head.parameters()) + + LOGGER.info(" - Shared encoder (%.2fx): %s", encoder_width, f"{encoder_params:,}") + LOGGER.info(" - Hue head (%.2fx): %s", head_width, f"{hue_params:,}") + LOGGER.info(" - Value head (%.2fx): %s", head_width, f"{value_params:,}") + LOGGER.info(" - Chroma head (%.2fx): %s", chroma_head_width, f"{chroma_params:,}") + LOGGER.info(" - Code head (%.2fx): %s", head_width, f"{code_params:,}") + + # MLflow setup + run_name = setup_mlflow_experiment("from_xyY", "multi_head_circular") + LOGGER.info("") + LOGGER.info("MLflow run: %s", run_name) + + best_val_loss = float("inf") + best_state = None + patience_counter = 0 + + LOGGER.info("") + LOGGER.info("Starting training with circular hue loss...") + + with mlflow.start_run(run_name=run_name): + mlflow.log_params( + { + "model": "multi_head_circular", + "encoder_width": encoder_width, + "head_width": head_width, + "chroma_head_width": chroma_head_width, + "dropout": dropout, + "learning_rate": lr, + "batch_size": batch_size, + "weight_decay": weight_decay, + "epochs": epochs, + "patience": patience, + "total_params": total_params, + "loss_type": "circular_hue", + } + ) + + for epoch in range(epochs): + # Training + model.train() + train_loss = 0.0 + for X_batch, y_batch in train_loader: + X_batch, y_batch = X_batch.to(device), y_batch.to(device) + + optimizer.zero_grad() + pred = model(X_batch) + + # Use circular loss for hue component + loss = munsell_component_loss(pred, y_batch, hue_range=1.0) + + loss.backward() + optimizer.step() + train_loss += loss.item() * len(X_batch) + + train_loss /= len(X_train_t) + scheduler.step() + + # Validation + model.eval() + val_loss = 0.0 + with torch.no_grad(): + for X_batch, y_batch in val_loader: + X_batch, y_batch = X_batch.to(device), y_batch.to(device) + pred = model(X_batch) + val_loss += munsell_component_loss( + pred, y_batch, hue_range=1.0 + ).item() * len(X_batch) + val_loss /= len(X_val_t) + + # Per-component MAE (denormalized for interpretability) + with torch.no_grad(): + pred_val = model(X_val_t.to(device)).cpu() + # Denormalize predictions and ground truth + pred_denorm = pred_val.numpy() + hue_min, hue_max = output_params["hue_range"] + value_min, value_max = output_params["value_range"] + chroma_min, chroma_max = output_params["chroma_range"] + code_min, code_max = output_params["code_range"] + + pred_denorm[:, 0] = pred_val[:, 0].numpy() * (hue_max - hue_min) + hue_min # hue + pred_denorm[:, 1] = pred_val[:, 1].numpy() * (value_max - value_min) + value_min # value + pred_denorm[:, 2] = pred_val[:, 2].numpy() * (chroma_max - chroma_min) + chroma_min # chroma + pred_denorm[:, 3] = pred_val[:, 3].numpy() * (code_max - code_min) + code_min # code + + y_denorm = y_val_norm.copy() + y_denorm[:, 0] = y_val_norm[:, 0] * (hue_max - hue_min) + hue_min + y_denorm[:, 1] = y_val_norm[:, 1] * (value_max - value_min) + value_min + y_denorm[:, 2] = y_val_norm[:, 2] * (chroma_max - chroma_min) + chroma_min + y_denorm[:, 3] = y_val_norm[:, 3] * (code_max - code_min) + code_min + + mae = np.mean(np.abs(pred_denorm - y_denorm), axis=0) + + mlflow.log_metrics( + { + "train_loss": train_loss, + "val_loss": val_loss, + "mae_hue": mae[0], + "mae_value": mae[1], + "mae_chroma": mae[2], + "mae_code": mae[3], + }, + step=epoch, + ) + + if val_loss < best_val_loss: + best_val_loss = val_loss + best_state = copy.deepcopy(model.state_dict()) + patience_counter = 0 + LOGGER.info( + "Epoch %03d/%d - Train: %.6f, Val: %.6f (best) - " + "MAE: hue=%.4f, value=%.4f, chroma=%.4f, code=%.4f", + epoch + 1, + epochs, + train_loss, + val_loss, + mae[0], + mae[1], + mae[2], + mae[3], + ) + else: + patience_counter += 1 + if (epoch + 1) % 50 == 0: + LOGGER.info( + "Epoch %03d/%d - Train: %.6f, Val: %.6f", + epoch + 1, + epochs, + train_loss, + val_loss, + ) + + if patience_counter >= patience: + LOGGER.info("Early stopping at epoch %d", epoch + 1) + break + + # Load best model + model.load_state_dict(best_state) + + # Final evaluation + model.eval() + with torch.no_grad(): + pred_val = model(X_val_t.to(device)).cpu() + pred_denorm = pred_val.numpy() + hue_min, hue_max = output_params["hue_range"] + value_min, value_max = output_params["value_range"] + chroma_min, chroma_max = output_params["chroma_range"] + code_min, code_max = output_params["code_range"] + + pred_denorm[:, 0] = pred_val[:, 0].numpy() * (hue_max - hue_min) + hue_min + pred_denorm[:, 1] = pred_val[:, 1].numpy() * (value_max - value_min) + value_min + pred_denorm[:, 2] = pred_val[:, 2].numpy() * (chroma_max - chroma_min) + chroma_min + pred_denorm[:, 3] = pred_val[:, 3].numpy() * (code_max - code_min) + code_min + + y_denorm = y_val_norm.copy() + y_denorm[:, 0] = y_val_norm[:, 0] * (hue_max - hue_min) + hue_min + y_denorm[:, 1] = y_val_norm[:, 1] * (value_max - value_min) + value_min + y_denorm[:, 2] = y_val_norm[:, 2] * (chroma_max - chroma_min) + chroma_min + y_denorm[:, 3] = y_val_norm[:, 3] * (code_max - code_min) + code_min + + mae = np.mean(np.abs(pred_denorm - y_denorm), axis=0) + + # Log final metrics + mlflow.log_metrics( + { + "best_val_loss": best_val_loss, + "final_mae_hue": mae[0], + "final_mae_value": mae[1], + "final_mae_chroma": mae[2], + "final_mae_code": mae[3], + "final_epoch": epoch + 1, + } + ) + + LOGGER.info("") + LOGGER.info("Final Results:") + LOGGER.info(" Best Val Loss: %.6f", best_val_loss) + LOGGER.info(" MAE hue: %.6f", mae[0]) + LOGGER.info(" MAE value: %.6f", mae[1]) + LOGGER.info(" MAE chroma: %.6f", mae[2]) + LOGGER.info(" MAE code: %.6f", mae[3]) + + # Save model + models_dir = PROJECT_ROOT / "models" / "from_xyY" + models_dir.mkdir(exist_ok=True) + + checkpoint_path = models_dir / "multi_head_circular.pth" + torch.save( + { + "model_state_dict": model.state_dict(), + "output_params": output_params, + "val_loss": best_val_loss, + "mae": { + "hue": float(mae[0]), + "value": float(mae[1]), + "chroma": float(mae[2]), + "code": float(mae[3]), + }, + "hyperparameters": { + "encoder_width": encoder_width, + "head_width": head_width, + "chroma_head_width": chroma_head_width, + "dropout": dropout, + "lr": lr, + "batch_size": batch_size, + "weight_decay": weight_decay, + }, + "loss_type": "circular_hue", + }, + checkpoint_path, + ) + LOGGER.info("") + LOGGER.info("Saved checkpoint: %s", checkpoint_path) + + # Export to ONNX + model.cpu().eval() + dummy_input = torch.randn(1, 3) + onnx_path = models_dir / "multi_head_circular.onnx" + + torch.onnx.export( + model, + dummy_input, + onnx_path, + input_names=["xyY"], # Match other models for comparison compatibility + output_names=["munsell_spec"], + dynamic_axes={"xyY": {0: "batch"}, "munsell_spec": {0: "batch"}}, + opset_version=17, + ) + LOGGER.info("Saved ONNX: %s", onnx_path) + + # Save normalization parameters + params_path = models_dir / "multi_head_circular_normalization_params.npz" + np.savez( + params_path, + output_params=output_params, + ) + LOGGER.info("Saved normalization parameters: %s", params_path) + + # Log artifacts to MLflow + mlflow.log_artifact(str(checkpoint_path)) + mlflow.log_artifact(str(onnx_path)) + mlflow.log_artifact(str(params_path)) + mlflow.pytorch.log_model(model, "model") + LOGGER.info("Artifacts logged to MLflow") + + LOGGER.info("=" * 80) + + return model, best_val_loss + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO, format="%(message)s", force=True) + + main() diff --git a/learning_munsell/training/from_xyY/train_multi_head_cross_attention_error_predictor.py b/learning_munsell/training/from_xyY/train_multi_head_cross_attention_error_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..4a0294ff95c16723037f58200bf89341c7519058 --- /dev/null +++ b/learning_munsell/training/from_xyY/train_multi_head_cross_attention_error_predictor.py @@ -0,0 +1,640 @@ +""" +Train Multi-Head + Cross-Attention Error Predictor for xyY to Munsell conversion. + +This version uses cross-attention between component branches to learn +correlations between errors in different Munsell components. + +Key Features: +- Shared context encoder +- Multi-head cross-attention between components +- Component-specific prediction heads +- Residual connections +""" + +from __future__ import annotations + +import copy +import logging + +import mlflow +import mlflow.pytorch +import numpy as np +import onnxruntime as ort +import torch +from torch import nn, optim +from torch.utils.data import DataLoader, TensorDataset + +from learning_munsell import PROJECT_ROOT +from learning_munsell.utilities.common import log_training_epoch, setup_mlflow_experiment +from learning_munsell.utilities.data import normalize_xyY, normalize_munsell + +LOGGER = logging.getLogger(__name__) + +# Note: This script has a custom CrossAttentionErrorPredictor architecture +# so we don't import ComponentErrorPredictor/MultiHeadErrorPredictor from shared modules. + + +class CustomMultiheadAttention(nn.Module): + """ + Custom multi-head attention that exports cleanly to ONNX. + + Uses basic operations instead of nn.MultiheadAttention to avoid + reshape issues with dynamic batch sizes during ONNX export. + + Parameters + ---------- + embed_dim : int + Total dimension of the model (must be divisible by num_heads). + num_heads : int + Number of parallel attention heads. + dropout : float, optional + Dropout probability on attention weights. + + Attributes + ---------- + embed_dim : int + Total embedding dimension. + num_heads : int + Number of attention heads. + head_dim : int + Dimension of each attention head (embed_dim // num_heads). + scale : float + Scaling factor for attention scores (head_dim ** -0.5). + q_proj : nn.Linear + Query projection layer. + k_proj : nn.Linear + Key projection layer. + v_proj : nn.Linear + Value projection layer. + out_proj : nn.Linear + Output projection layer. + dropout : nn.Dropout + Dropout layer for attention weights. + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + ) -> None: + """Initialize the custom multi-head attention module.""" + super().__init__() + + assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads" + + self.embed_dim = embed_dim + self.num_heads = num_heads + self.head_dim = embed_dim // num_heads + self.scale = self.head_dim**-0.5 + + # Linear projections for Q, K, V + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.out_proj = nn.Linear(embed_dim, embed_dim) + + self.dropout = nn.Dropout(dropout) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass for self-attention. + + Parameters + ---------- + x : Tensor + Input tensor [batch, seq_len, embed_dim] + + Returns + ------- + Tensor + Output tensor [batch, seq_len, embed_dim] + """ + batch_size, seq_len, embed_dim = x.shape + + # Project to Q, K, V + q = self.q_proj(x) # [batch, seq_len, embed_dim] + k = self.k_proj(x) # [batch, seq_len, embed_dim] + v = self.v_proj(x) # [batch, seq_len, embed_dim] + + # Reshape for multi-head attention: [batch, seq_len, num_heads, head_dim] + # Then transpose to: [batch, num_heads, seq_len, head_dim] + # Use -1 for batch dimension to enable dynamic batch size in ONNX + q = q.reshape(-1, seq_len, self.num_heads, self.head_dim).transpose(1, 2) + k = k.reshape(-1, seq_len, self.num_heads, self.head_dim).transpose(1, 2) + v = v.reshape(-1, seq_len, self.num_heads, self.head_dim).transpose(1, 2) + + # Scaled dot-product attention + # Q @ K^T: [batch, heads, seq, dim] @ [batch, heads, dim, seq] + # -> [batch, heads, seq, seq] + attn_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale + attn_weights = torch.softmax(attn_scores, dim=-1) + attn_weights = self.dropout(attn_weights) + + # Apply attention to values + # [batch, num_heads, seq_len, seq_len] @ [batch, num_heads, seq_len, head_dim] + # -> [batch, num_heads, seq_len, head_dim] + attn_output = torch.matmul(attn_weights, v) + + # Transpose back and reshape: [batch, num_heads, seq_len, head_dim] + # -> [batch, seq_len, num_heads, head_dim] + # -> [batch, seq_len, embed_dim] + # Use -1 for batch dimension to enable dynamic batch size in ONNX + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(-1, seq_len, self.embed_dim) + + # Final projection + output = self.out_proj(attn_output) + + return output + + +class CrossAttentionErrorPredictor(nn.Module): + """ + Error predictor with cross-attention between Munsell components. + + Uses cross-attention to learn correlations between errors in different + Munsell components (hue, value, chroma, code). + + Parameters + ---------- + input_dim : int, optional + Input dimension (7 = xyY_norm + base_pred_norm). + context_dim : int, optional + Dimension of shared context features. + component_dim : int, optional + Dimension of component-specific features. + n_components : int, optional + Number of Munsell components (4). + n_attention_heads : int, optional + Number of attention heads for cross-attention. + dropout : float, optional + Dropout probability. + + Attributes + ---------- + context_encoder : nn.Sequential + Shared encoder: input_dim → 256 → context_dim. + component_encoders : nn.ModuleList + Component-specific encoders: context_dim → component_dim (x4). + cross_attention : CustomMultiheadAttention + Cross-attention module between component features. + attention_norm : nn.LayerNorm + Layer normalization after attention. + component_decoders : nn.ModuleList + Component-specific decoders: component_dim → 128 → 1 (x4). + + Notes + ----- + Architecture: + 1. Shared context encoder: 7 → 256 → 512 + 2. Component-specific encoders: 512 → 256 (x4) + 3. Multi-head cross-attention between components + 4. Residual connection + layer norm + 5. Component-specific decoders: 256 → 128 → 1 + """ + + def __init__( + self, + input_dim: int = 7, + context_dim: int = 512, + component_dim: int = 256, + n_components: int = 4, + n_attention_heads: int = 4, + dropout: float = 0.1, + ) -> None: + """Initialize the cross-attention error predictor.""" + super().__init__() + + self.n_components = n_components + self.component_dim = component_dim + + # Shared context encoder + self.context_encoder = nn.Sequential( + nn.Linear(input_dim, 256), + nn.GELU(), + nn.LayerNorm(256), + nn.Dropout(dropout), + nn.Linear(256, context_dim), + nn.GELU(), + nn.LayerNorm(context_dim), + ) + + # Component-specific encoders + self.component_encoders = nn.ModuleList( + [ + nn.Sequential( + nn.Linear(context_dim, component_dim), + nn.GELU(), + nn.LayerNorm(component_dim), + ) + for _ in range(n_components) + ] + ) + + # Multi-head cross-attention (using custom implementation) + self.cross_attention = CustomMultiheadAttention( + embed_dim=component_dim, + num_heads=n_attention_heads, + dropout=dropout, + ) + + # Layer norm after attention + self.attention_norm = nn.LayerNorm(component_dim) + + # Component-specific decoders + self.component_decoders = nn.ModuleList( + [ + nn.Sequential( + nn.Linear(component_dim, 128), + nn.GELU(), + nn.LayerNorm(128), + nn.Dropout(dropout), + nn.Linear(128, 1), + ) + for _ in range(n_components) + ] + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass with cross-attention. + + Parameters + ---------- + x : Tensor + Input [xyY_norm (3) + base_pred_norm (4)] = 7 features + + Returns + ------- + Tensor + Predicted errors [hue_err, value_err, chroma_err, code_err] + """ + # Shared context encoding + context = self.context_encoder(x) # [batch, 512] + + # Component-specific encoding + component_features = [] + for encoder in self.component_encoders: + feat = encoder(context) # [batch, 256] + component_features.append(feat) + + # Stack for cross-attention: [batch, 4, 256] + component_stack = torch.stack(component_features, dim=1) + + # Cross-attention between components + attended = self.cross_attention(component_stack) # [batch, 4, 256] + + # Residual connection + layer norm + component_stack = self.attention_norm(component_stack + attended) + + # Component-specific decoding (unrolled for ONNX compatibility) + # Use unbind to split the tensor instead of indexing to preserve batch dimension + components = torch.unbind( + component_stack, dim=1 + ) # Split into 4 tensors of shape [batch, 256] + + # Decode each component explicitly + pred_0 = self.component_decoders[0](components[0]) # [batch, 1] + pred_1 = self.component_decoders[1](components[1]) # [batch, 1] + pred_2 = self.component_decoders[2](components[2]) # [batch, 1] + pred_3 = self.component_decoders[3](components[3]) # [batch, 1] + + # Concatenate along dimension 1 and squeeze + predictions = torch.cat([pred_0, pred_1, pred_2, pred_3], dim=1) # [batch, 4] + + return predictions + + +def train_cross_attention_error_predictor( + epochs: int = 300, + batch_size: int = 1024, + lr: float = 0.0005, + dropout: float = 0.1, + context_dim: int = 512, + component_dim: int = 256, + n_attention_heads: int = 4, +) -> tuple[CrossAttentionErrorPredictor, float]: + """ + Train cross-attention error predictor. + + This model uses cross-attention between component branches to learn + correlations between errors in different Munsell components. + + Parameters + ---------- + epochs : int, optional + Maximum number of training epochs. + batch_size : int, optional + Training batch size. + lr : float, optional + Learning rate for AdamW optimizer. + dropout : float, optional + Dropout rate for regularization. + context_dim : int, optional + Dimension of shared context features. + component_dim : int, optional + Dimension of component-specific features. + n_attention_heads : int, optional + Number of attention heads for cross-attention. + + Returns + ------- + model : CrossAttentionErrorPredictor + Trained model with best validation loss weights. + best_val_loss : float + Best validation loss achieved during training. + + Notes + ----- + The training pipeline: + 1. Loads pre-trained Multi-Head base model + 2. Generates base model predictions for training data + 3. Computes residual errors between predictions and targets + 4. Trains cross-attention error predictor on these residuals + 5. Uses CosineAnnealingLR scheduler + 6. Early stopping based on validation loss + 7. Exports model to ONNX format + 8. Logs metrics and artifacts to MLflow + """ + + LOGGER.info("=" * 80) + LOGGER.info("Training Multi-Head + Cross-Attention Error Predictor") + LOGGER.info("=" * 80) + LOGGER.info("") + LOGGER.info("Architecture:") + LOGGER.info(" - Shared context encoder: 7 → 256 → %d", context_dim) + LOGGER.info(" - Component encoders: %d → %d (x4)", context_dim, component_dim) + LOGGER.info(" - Cross-attention: %d heads", n_attention_heads) + LOGGER.info(" - Component decoders: %d → 128 → 1 (x4)", component_dim) + LOGGER.info("") + LOGGER.info("Hyperparameters:") + LOGGER.info(" lr: %.6f", lr) + LOGGER.info(" batch_size: %d", batch_size) + LOGGER.info(" dropout: %.2f", dropout) + LOGGER.info("") + + device = torch.device("mps" if torch.backends.mps.is_available() else "cpu") + LOGGER.info("Using device: %s", device) + + # Paths + model_directory = PROJECT_ROOT / "models" / "from_xyY" + data_dir = PROJECT_ROOT / "data" + base_model_path = model_directory / "multi_head.onnx" + params_path = model_directory / "multi_head_normalization_params.npz" + cache_file = data_dir / "training_data.npz" + + # Load base model + LOGGER.info("") + LOGGER.info("Loading Multi-Head base model from %s...", base_model_path) + base_session = ort.InferenceSession(str(base_model_path)) + params = np.load(params_path, allow_pickle=True) + input_params = params["input_params"].item() + output_params = params["output_params"].item() + + # Load training data + LOGGER.info("Loading training data from %s...", cache_file) + data = np.load(cache_file) + X_train = data["X_train"] + y_train = data["y_train"] + X_val = data["X_val"] + y_val = data["y_val"] + + LOGGER.info("Train samples: %d", len(X_train)) + LOGGER.info("Validation samples: %d", len(X_val)) + + # Generate base model predictions + LOGGER.info("") + LOGGER.info("Generating Multi-Head base model predictions...") + X_train_norm = normalize_xyY(X_train, input_params) + y_train_norm = normalize_munsell(y_train, output_params) + base_pred_train_norm = base_session.run(None, {"xyY": X_train_norm})[0] + + X_val_norm = normalize_xyY(X_val, input_params) + y_val_norm = normalize_munsell(y_val, output_params) + base_pred_val_norm = base_session.run(None, {"xyY": X_val_norm})[0] + + # Compute errors + error_train = y_train_norm - base_pred_train_norm + error_val = y_val_norm - base_pred_val_norm + + LOGGER.info("") + LOGGER.info("Base model error statistics (normalized space):") + LOGGER.info(" Mean absolute error: %.6f", np.mean(np.abs(error_train))) + LOGGER.info(" Std of error: %.6f", np.std(error_train)) + LOGGER.info(" Max absolute error: %.6f", np.max(np.abs(error_train))) + + # Create combined input: [xyY_norm, base_prediction_norm] + X_train_combined = np.concatenate([X_train_norm, base_pred_train_norm], axis=1) + X_val_combined = np.concatenate([X_val_norm, base_pred_val_norm], axis=1) + + # Convert to PyTorch tensors + X_train_t = torch.FloatTensor(X_train_combined) + error_train_t = torch.FloatTensor(error_train) + X_val_t = torch.FloatTensor(X_val_combined) + error_val_t = torch.FloatTensor(error_val) + + # Create data loaders + train_dataset = TensorDataset(X_train_t, error_train_t) + val_dataset = TensorDataset(X_val_t, error_val_t) + train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) + val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False) + + # Initialize model + model = CrossAttentionErrorPredictor( + input_dim=7, + context_dim=context_dim, + component_dim=component_dim, + n_attention_heads=n_attention_heads, + dropout=dropout, + ).to(device) + + # Count parameters + total_params = sum(p.numel() for p in model.parameters()) + LOGGER.info("") + LOGGER.info("Total parameters: %s", f"{total_params:,}") + + context_params = sum(p.numel() for p in model.context_encoder.parameters()) + attention_params = sum(p.numel() for p in model.cross_attention.parameters()) + LOGGER.info(" - Context encoder: %s", f"{context_params:,}") + LOGGER.info(" - Cross-attention: %s", f"{attention_params:,}") + + # Training setup + criterion = nn.MSELoss() + optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5) + scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs) + + # MLflow setup + run_name = setup_mlflow_experiment("from_xyY", "cross_attention_error_predictor") + LOGGER.info("") + LOGGER.info("MLflow run: %s", run_name) + + # Training loop + best_val_loss = float("inf") + best_state = None + patience = 30 + patience_counter = 0 + + LOGGER.info("") + LOGGER.info("Starting training...") + + with mlflow.start_run(run_name=run_name): + mlflow.log_params( + { + "model": "cross_attention_error_predictor", + "context_dim": context_dim, + "component_dim": component_dim, + "n_attention_heads": n_attention_heads, + "dropout": dropout, + "learning_rate": lr, + "batch_size": batch_size, + "epochs": epochs, + "patience": patience, + "total_params": total_params, + } + ) + + for epoch in range(epochs): + # Training + model.train() + train_loss = 0.0 + for X_batch, y_batch in train_loader: + X_batch, y_batch = X_batch.to(device), y_batch.to(device) + + optimizer.zero_grad() + pred = model(X_batch) + loss = criterion(pred, y_batch) + loss.backward() + optimizer.step() + train_loss += loss.item() * len(X_batch) + + train_loss /= len(X_train_t) + scheduler.step() + + # Validation + model.eval() + val_loss = 0.0 + with torch.no_grad(): + for X_batch, y_batch in val_loader: + X_batch, y_batch = X_batch.to(device), y_batch.to(device) + pred = model(X_batch) + val_loss += criterion(pred, y_batch).item() * len(X_batch) + val_loss /= len(X_val_t) + + log_training_epoch( + epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"] + ) + + if val_loss < best_val_loss: + best_val_loss = val_loss + best_state = copy.deepcopy(model.state_dict()) + patience_counter = 0 + LOGGER.info( + "Epoch %03d/%d - Train: %.6f, Val: %.6f (best) - LR: %.6f", + epoch + 1, + epochs, + train_loss, + val_loss, + optimizer.param_groups[0]["lr"], + ) + else: + patience_counter += 1 + if (epoch + 1) % 50 == 0: + LOGGER.info( + "Epoch %03d/%d - Train: %.6f, Val: %.6f", + epoch + 1, + epochs, + train_loss, + val_loss, + ) + + if patience_counter >= patience: + LOGGER.info("Early stopping at epoch %d", epoch + 1) + break + + # Load best model + model.load_state_dict(best_state) + + mlflow.log_metrics( + { + "best_val_loss": best_val_loss, + "final_epoch": epoch + 1, + } + ) + + LOGGER.info("") + LOGGER.info("Final Results:") + LOGGER.info(" Best Val Loss: %.6f", best_val_loss) + + # Save model + model_directory.mkdir(exist_ok=True) + checkpoint_path = ( + model_directory / "multi_head_cross_attention_error_predictor.pth" + ) + + torch.save( + { + "model_state_dict": model.state_dict(), + "val_loss": best_val_loss, + "hyperparameters": { + "context_dim": context_dim, + "component_dim": component_dim, + "n_attention_heads": n_attention_heads, + "dropout": dropout, + "lr": lr, + "batch_size": batch_size, + }, + }, + checkpoint_path, + ) + LOGGER.info("") + LOGGER.info("Saved checkpoint: %s", checkpoint_path) + + # Export to ONNX + LOGGER.info("") + LOGGER.info("Exporting error predictor to ONNX...") + model.eval() + model.cpu() + + dummy_input = torch.randn(1, 7) + onnx_path = model_directory / "multi_head_cross_attention_error_predictor.onnx" + + torch.onnx.export( + model, + dummy_input, + onnx_path, + export_params=True, + opset_version=17, + input_names=["combined_input"], + output_names=["error_correction"], + dynamic_axes={ + "combined_input": {0: "batch_size"}, + "error_correction": {0: "batch_size"}, + }, + ) + + mlflow.log_artifact(str(checkpoint_path)) + mlflow.log_artifact(str(onnx_path)) + mlflow.pytorch.log_model(model, "model") + + LOGGER.info("ONNX model saved to: %s", onnx_path) + LOGGER.info("Artifacts logged to MLflow") + + LOGGER.info("=" * 80) + + + return model, best_val_loss + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO, format="%(message)s", force=True) + + train_cross_attention_error_predictor( + epochs=300, + batch_size=1024, + lr=0.0005, + dropout=0.1, + context_dim=512, + component_dim=256, + n_attention_heads=4, + ) diff --git a/learning_munsell/training/from_xyY/train_multi_head_gamma.py b/learning_munsell/training/from_xyY/train_multi_head_gamma.py new file mode 100644 index 0000000000000000000000000000000000000000..aae0d5c2c924e9e1ca1c05b12d605d0299e841d3 --- /dev/null +++ b/learning_munsell/training/from_xyY/train_multi_head_gamma.py @@ -0,0 +1,300 @@ +""" +Train multi-head ML model for xyY to Munsell conversion with gamma-corrected Y. + +Experiment: Apply gamma 2.33 to Y before normalization to better align +with perceptual lightness (Munsell Value scale is perceptually uniform). + +The multi-head architecture has separate heads for each Munsell component, +so gamma correction on Y should primarily benefit Value prediction without +negatively impacting Chroma prediction (unlike the single MLP). +""" + +import logging +from typing import Any + +import click +import mlflow +import mlflow.pytorch +import numpy as np +import torch +from numpy.typing import NDArray +from torch import nn, optim +from torch.utils.data import DataLoader, TensorDataset + +from learning_munsell import PROJECT_ROOT +from learning_munsell.models.networks import MultiHeadMLPToMunsell +from learning_munsell.utilities.common import log_training_epoch, setup_mlflow_experiment +from learning_munsell.utilities.data import MUNSELL_NORMALIZATION_PARAMS, normalize_munsell +from learning_munsell.utilities.losses import weighted_mse_loss +from learning_munsell.utilities.training import train_epoch, validate + +LOGGER = logging.getLogger(__name__) + +# Gamma value for Y transformation +GAMMA = 2.33 + + +def normalize_inputs( + X: NDArray, gamma: float = GAMMA +) -> tuple[NDArray, dict[str, Any]]: + """ + Normalize xyY inputs to [0, 1] range with gamma correction on Y. + + Parameters + ---------- + X : ndarray + xyY values of shape (n, 3) where columns are [x, y, Y]. + gamma : float + Gamma value to apply to Y component. + + Returns + ------- + ndarray + Normalized values with gamma-corrected Y, dtype float32. + dict + Normalization parameters including gamma value. + """ + # xyY chromaticity and luminance ranges (all [0, 1]) + x_range = (0.0, 1.0) + y_range = (0.0, 1.0) + Y_range = (0.0, 1.0) + + X_norm = X.copy() + X_norm[:, 0] = (X[:, 0] - x_range[0]) / (x_range[1] - x_range[0]) + X_norm[:, 1] = (X[:, 1] - y_range[0]) / (y_range[1] - y_range[0]) + + # Normalize Y first, then apply gamma + Y_normalized = (X[:, 2] - Y_range[0]) / (Y_range[1] - Y_range[0]) + # Clip to avoid numerical issues with negative values + Y_normalized = np.clip(Y_normalized, 0, 1) + # Apply gamma: Y_gamma = Y^(1/gamma) - this spreads dark values, compresses light + X_norm[:, 2] = np.power(Y_normalized, 1.0 / gamma) + + params = { + "x_range": x_range, + "y_range": y_range, + "Y_range": Y_range, + "gamma": gamma, + } + + return X_norm, params + + + + +@click.command() +@click.option("--epochs", default=200, help="Number of training epochs") +@click.option("--batch-size", default=1024, help="Batch size for training") +@click.option("--lr", default=5e-4, help="Learning rate") +@click.option("--patience", default=20, help="Early stopping patience") +def main(epochs: int, batch_size: int, lr: float, patience: int) -> None: + """ + Train the multi-head model with gamma-corrected Y input. + + Notes + ----- + The training pipeline: + 1. Loads training and validation data from cache + 2. Normalizes inputs with gamma correction (gamma=2.33) on Y + 3. Normalizes Munsell outputs to [0, 1] range + 4. Trains multi-head MLP with weighted MSE loss + 5. Uses early stopping based on validation loss + 6. Exports best model to ONNX format + 7. Logs metrics and artifacts to MLflow + + The gamma correction on Y aligns with perceptual lightness. The Munsell + Value scale is perceptually uniform, so gamma correction should primarily + benefit Value prediction without negatively impacting Chroma prediction. + """ + + LOGGER.info("=" * 80) + LOGGER.info("ML-Based xyY to Munsell Conversion: Multi-Head Gamma Experiment") + LOGGER.info("Gamma = %.2f applied to Y component", GAMMA) + LOGGER.info("=" * 80) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + LOGGER.info("Using device: %s", device) + + # Load training data + data_dir = PROJECT_ROOT / "data" + cache_file = data_dir / "training_data.npz" + + if not cache_file.exists(): + LOGGER.error("Error: Training data not found at %s", cache_file) + LOGGER.error("Please run 01_generate_training_data.py first") + return + + LOGGER.info("Loading training data from %s...", cache_file) + data = np.load(cache_file) + + X_train = data["X_train"] + y_train = data["y_train"] + X_val = data["X_val"] + y_val = data["y_val"] + + LOGGER.info("Train samples: %d", len(X_train)) + LOGGER.info("Validation samples: %d", len(X_val)) + + # Normalize data with gamma correction + X_train_norm, input_params = normalize_inputs(X_train, gamma=GAMMA) + X_val_norm, _ = normalize_inputs(X_val, gamma=GAMMA) + + # Use shared normalization parameters for Munsell outputs + output_params = MUNSELL_NORMALIZATION_PARAMS + y_train_norm = normalize_munsell(y_train, output_params) + y_val_norm = normalize_munsell(y_val, output_params) + + LOGGER.info("") + LOGGER.info("Input normalization with gamma=%.2f:", GAMMA) + LOGGER.info(" Y range after gamma: [%.4f, %.4f]", X_train_norm[:, 2].min(), X_train_norm[:, 2].max()) + + # Convert to PyTorch tensors + X_train_t = torch.FloatTensor(X_train_norm) + y_train_t = torch.FloatTensor(y_train_norm) + X_val_t = torch.FloatTensor(X_val_norm) + y_val_t = torch.FloatTensor(y_val_norm) + + # Create data loaders + train_dataset = TensorDataset(X_train_t, y_train_t) + val_dataset = TensorDataset(X_val_t, y_val_t) + + train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) + val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False) + + # Initialize model + model = MultiHeadMLPToMunsell().to(device) + LOGGER.info("") + LOGGER.info("Model architecture:") + LOGGER.info("%s", model) + + total_params = sum(p.numel() for p in model.parameters()) + LOGGER.info("Total parameters: %s", f"{total_params:,}") + + # Training setup + optimizer = optim.Adam(model.parameters(), lr=lr) + criterion = weighted_mse_loss + + # MLflow setup + run_name = setup_mlflow_experiment("from_xyY", f"multi_head_gamma_{GAMMA}") + + LOGGER.info("") + LOGGER.info("MLflow run: %s", run_name) + + # Training loop + best_val_loss = float("inf") + patience_counter = 0 + + LOGGER.info("") + LOGGER.info("Starting training...") + + with mlflow.start_run(run_name=run_name): + mlflow.log_params( + { + "model": "multi_head_gamma", + "num_epochs": epochs, + "batch_size": batch_size, + "learning_rate": lr, + "optimizer": "Adam", + "criterion": "weighted_mse_loss", + "patience": patience, + "total_params": total_params, + "gamma": GAMMA, + } + ) + + for epoch in range(epochs): + train_loss = train_epoch(model, train_loader, optimizer, criterion, device) + val_loss = validate(model, val_loader, criterion, device) + + log_training_epoch( + epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"] + ) + + LOGGER.info( + "Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f", + epoch + 1, + epochs, + train_loss, + val_loss, + ) + + if val_loss < best_val_loss: + best_val_loss = val_loss + patience_counter = 0 + + model_directory = PROJECT_ROOT / "models" / "from_xyY" + model_directory.mkdir(exist_ok=True) + checkpoint_file = model_directory / "multi_head_gamma_best.pth" + + torch.save( + { + "model_state_dict": model.state_dict(), + "input_params": input_params, + "output_params": output_params, + "epoch": epoch, + "val_loss": val_loss, + }, + checkpoint_file, + ) + + LOGGER.info(" → Saved best model (val_loss: %.6f)", val_loss) + else: + patience_counter += 1 + if patience_counter >= patience: + LOGGER.info("") + LOGGER.info("Early stopping after %d epochs", epoch + 1) + break + + mlflow.log_metrics( + { + "best_val_loss": best_val_loss, + "final_epoch": epoch + 1, + } + ) + + # Export to ONNX + LOGGER.info("") + LOGGER.info("Exporting model to ONNX...") + model.eval() + + checkpoint = torch.load(checkpoint_file) + model.load_state_dict(checkpoint["model_state_dict"]) + + dummy_input = torch.randn(1, 3).to(device) + + onnx_file = model_directory / "multi_head_gamma.onnx" + torch.onnx.export( + model, + dummy_input, + onnx_file, + export_params=True, + opset_version=15, + input_names=["xyY_gamma"], + output_names=["munsell_spec"], + dynamic_axes={"xyY_gamma": {0: "batch_size"}, "munsell_spec": {0: "batch_size"}}, + ) + + # Save normalization parameters (including gamma) + params_file = model_directory / "multi_head_gamma_normalization_params.npz" + np.savez( + params_file, + input_params=input_params, + output_params=output_params, + ) + + LOGGER.info("ONNX model saved to: %s", onnx_file) + LOGGER.info("Normalization parameters saved to: %s", params_file) + LOGGER.info("IMPORTANT: Input Y must be gamma-corrected with gamma=%.2f", GAMMA) + + mlflow.log_artifact(str(checkpoint_file)) + mlflow.log_artifact(str(onnx_file)) + mlflow.log_artifact(str(params_file)) + mlflow.pytorch.log_model(model, "model") + + LOGGER.info("=" * 80) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO, format="%(message)s", force=True) + + main() diff --git a/learning_munsell/training/from_xyY/train_multi_head_gamma_sweep.py b/learning_munsell/training/from_xyY/train_multi_head_gamma_sweep.py new file mode 100644 index 0000000000000000000000000000000000000000..795e0bb4207d4d605cd4c844eabab908c4898bba --- /dev/null +++ b/learning_munsell/training/from_xyY/train_multi_head_gamma_sweep.py @@ -0,0 +1,605 @@ +""" +Train multi-head ML models with various gamma values to find optimal gamma. + +Sweeps gamma from 1.0 to 3.0 in increments of 0.1 and evaluates each model +on real Munsell colours using Delta-E CIE2000. + +Supports parallel execution with multiple runs per gamma for averaging. +""" + +import logging +from concurrent.futures import ProcessPoolExecutor, as_completed +from typing import Any + +import numpy as np +import torch +from colour import XYZ_to_Lab, xyY_to_XYZ +from colour.difference import delta_E_CIE2000 +from colour.notation.datasets.munsell import MUNSELL_COLOURS_REAL +from colour.notation.munsell import ( + CCS_ILLUMINANT_MUNSELL, + munsell_specification_to_xyY, +) +from numpy.typing import NDArray +from torch import nn, optim +from torch.utils.data import DataLoader, TensorDataset + +from learning_munsell import PROJECT_ROOT +from learning_munsell.models.networks import MultiHeadMLPToMunsell +from learning_munsell.utilities.data import ( + MUNSELL_NORMALIZATION_PARAMS, + normalize_munsell, +) + +LOGGER = logging.getLogger(__name__) + + +def normalize_inputs(X: NDArray, gamma: float) -> tuple[NDArray, dict[str, Any]]: + """ + Normalize xyY inputs to [0, 1] range with gamma correction on Y. + + Parameters + ---------- + X : ndarray + xyY values of shape (n, 3) where columns are [x, y, Y]. + gamma : float + Gamma value to apply to Y component. + + Returns + ------- + ndarray + Normalized values with gamma-corrected Y, dtype float32. + dict + Normalization parameters including gamma value. + """ + x_range = (0.0, 1.0) + y_range = (0.0, 1.0) + Y_range = (0.0, 1.0) + + X_norm = X.copy() + X_norm[:, 0] = (X[:, 0] - x_range[0]) / (x_range[1] - x_range[0]) + X_norm[:, 1] = (X[:, 1] - y_range[0]) / (y_range[1] - y_range[0]) + + Y_normalized = (X[:, 2] - Y_range[0]) / (Y_range[1] - Y_range[0]) + Y_normalized = np.clip(Y_normalized, 0, 1) + X_norm[:, 2] = np.power(Y_normalized, 1.0 / gamma) + + params = { + "x_range": x_range, + "y_range": y_range, + "Y_range": Y_range, + "gamma": gamma, + } + + return X_norm, params + + +def denormalize_output(y_norm: NDArray, params: dict[str, Any]) -> NDArray: + """ + Denormalize Munsell output from [0, 1] to original ranges. + + Parameters + ---------- + y_norm : ndarray + Normalized Munsell values in [0, 1] range. + params : dict + Normalization parameters containing range information. + + Returns + ------- + ndarray + Denormalized Munsell values in original ranges. + """ + y = np.copy(y_norm) + y[..., 0] = ( + y_norm[..., 0] * (params["hue_range"][1] - params["hue_range"][0]) + + params["hue_range"][0] + ) + y[..., 1] = ( + y_norm[..., 1] * (params["value_range"][1] - params["value_range"][0]) + + params["value_range"][0] + ) + y[..., 2] = ( + y_norm[..., 2] * (params["chroma_range"][1] - params["chroma_range"][0]) + + params["chroma_range"][0] + ) + y[..., 3] = ( + y_norm[..., 3] * (params["code_range"][1] - params["code_range"][0]) + + params["code_range"][0] + ) + return y + + +def weighted_mse_loss( + pred: torch.Tensor, target: torch.Tensor, weights: torch.Tensor = None +) -> torch.Tensor: + """ + Component-wise weighted MSE loss. + + Parameters + ---------- + pred : Tensor + Predicted Munsell values. + target : Tensor + Ground truth Munsell values. + weights : Tensor, optional + Component weights [w_hue, w_value, w_chroma, w_code]. + + Returns + ------- + Tensor + Weighted mean squared error loss. + """ + if weights is None: + weights = torch.tensor([1.0, 1.0, 3.0, 0.5], device=pred.device) + mse = (pred - target) ** 2 + weighted_mse = mse * weights + return weighted_mse.mean() + + +def clamp_munsell_specification(spec: NDArray) -> NDArray: + """ + Clamp Munsell specification to valid ranges. + + Parameters + ---------- + spec : ndarray + Munsell specification [hue, value, chroma, code]. + + Returns + ------- + ndarray + Clamped Munsell specification within valid ranges. + """ + clamped = np.copy(spec) + clamped[..., 0] = np.clip(spec[..., 0], 0.5, 10.0) + clamped[..., 1] = np.clip(spec[..., 1], 1.0, 9.0) + clamped[..., 2] = np.clip(spec[..., 2], 0.0, 50.0) + clamped[..., 3] = np.clip(spec[..., 3], 1.0, 10.0) + return clamped + + +def compute_delta_e(pred: NDArray, reference_Lab: NDArray) -> list[float]: + """ + Compute Delta-E CIE2000 for predicted Munsell specifications. + + Parameters + ---------- + pred : ndarray + Predicted Munsell specifications. + reference_Lab : ndarray + Reference CIELAB values for comparison. + + Returns + ------- + list of float + Delta-E CIE2000 values for valid predictions. + + Notes + ----- + Predictions that cannot be converted to valid xyY are skipped. + """ + delta_E_values = [] + for idx in range(len(pred)): + try: + ml_spec = clamp_munsell_specification(pred[idx]) + ml_spec_for_conversion = ml_spec.copy() + ml_spec_for_conversion[3] = round(ml_spec[3]) + ml_xyy = munsell_specification_to_xyY(ml_spec_for_conversion) + ml_XYZ = xyY_to_XYZ(ml_xyy) + ml_Lab = XYZ_to_Lab(ml_XYZ, CCS_ILLUMINANT_MUNSELL) + delta_E = delta_E_CIE2000(reference_Lab[idx], ml_Lab) + delta_E_values.append(delta_E) + except (RuntimeError, ValueError): + continue + return delta_E_values + + +def train_model( + gamma: float, + X_train: NDArray, + y_train: NDArray, + X_val: NDArray, + y_val: NDArray, + device: torch.device, + num_epochs: int = 100, + patience: int = 15, +) -> tuple[nn.Module, dict[str, Any], dict[str, Any], float]: + """ + Train a multi-head model with specified gamma value. + + Parameters + ---------- + gamma : float + Gamma value for Y correction. + X_train : ndarray + Training inputs (xyY values). + y_train : ndarray + Training targets (Munsell specifications). + X_val : ndarray + Validation inputs. + y_val : ndarray + Validation targets. + device : torch.device + Device to run training on. + num_epochs : int, optional + Maximum number of training epochs. Default is 100. + patience : int, optional + Early stopping patience. Default is 15. + + Returns + ------- + nn.Module + Trained model with best validation loss. + dict + Input normalization parameters. + dict + Output normalization parameters. + float + Best validation loss achieved. + """ + # Normalize data + X_train_norm, input_params = normalize_inputs(X_train, gamma=gamma) + X_val_norm, _ = normalize_inputs(X_val, gamma=gamma) + + # Use shared normalization parameters covering the full Munsell space for generalization + output_params = MUNSELL_NORMALIZATION_PARAMS + y_train_norm = normalize_munsell(y_train, output_params) + y_val_norm = normalize_munsell(y_val, output_params) + + # Convert to tensors + X_train_t = torch.FloatTensor(X_train_norm) + y_train_t = torch.FloatTensor(y_train_norm) + X_val_t = torch.FloatTensor(X_val_norm) + y_val_t = torch.FloatTensor(y_val_norm) + + # Create data loaders + train_dataset = TensorDataset(X_train_t, y_train_t) + val_dataset = TensorDataset(X_val_t, y_val_t) + + train_loader = DataLoader(train_dataset, batch_size=1024, shuffle=True) + val_loader = DataLoader(val_dataset, batch_size=1024, shuffle=False) + + # Initialize model + model = MultiHeadMLPToMunsell().to(device) + optimizer = optim.Adam(model.parameters(), lr=5e-4) + criterion = weighted_mse_loss + + best_val_loss = float("inf") + patience_counter = 0 + best_state = None + + for epoch in range(num_epochs): + # Train + model.train() + for X_batch, y_batch in train_loader: + X_batch = X_batch.to(device) + y_batch = y_batch.to(device) + + outputs = model(X_batch) + loss = criterion(outputs, y_batch) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + # Validate + model.eval() + total_val_loss = 0.0 + with torch.no_grad(): + for X_batch, y_batch in val_loader: + X_batch = X_batch.to(device) + y_batch = y_batch.to(device) + outputs = model(X_batch) + loss = criterion(outputs, y_batch) + total_val_loss += loss.item() + val_loss = total_val_loss / len(val_loader) + + if val_loss < best_val_loss: + best_val_loss = val_loss + patience_counter = 0 + best_state = model.state_dict().copy() + else: + patience_counter += 1 + if patience_counter >= patience: + break + + # Load best state + if best_state is not None: + model.load_state_dict(best_state) + + return model, input_params, output_params, best_val_loss + + +def evaluate_on_real_munsell( + model: nn.Module, + input_params: dict[str, Any], + output_params: dict[str, Any], + xyY_array: NDArray, + reference_Lab: NDArray, + device: torch.device, +) -> tuple[float, float]: + """ + Evaluate model on real Munsell colors using Delta-E CIE2000. + + Parameters + ---------- + model : nn.Module + Trained model to evaluate. + input_params : dict + Input normalization parameters. + output_params : dict + Output normalization parameters. + xyY_array : ndarray + Real Munsell xyY values. + reference_Lab : ndarray + Reference CIELAB values for Delta-E computation. + device : torch.device + Device to run evaluation on. + + Returns + ------- + float + Mean Delta-E CIE2000. + float + Median Delta-E CIE2000. + """ + model.eval() + gamma = input_params["gamma"] + + # Normalize inputs + X_norm, _ = normalize_inputs(xyY_array, gamma=gamma) + X_t = torch.FloatTensor(X_norm).to(device) + + # Predict + with torch.no_grad(): + pred_norm = model(X_t).cpu().numpy() + + pred = denormalize_output(pred_norm, output_params) + delta_E_values = compute_delta_e(pred, reference_Lab) + + return np.mean(delta_E_values), np.median(delta_E_values) + + +def run_single_trial( + gamma: float, + run_id: int, + X_train: NDArray, + y_train: NDArray, + X_val: NDArray, + y_val: NDArray, + xyY_array: NDArray, + reference_Lab: NDArray, +) -> dict[str, Any]: + """ + Run a single training trial for a given gamma value. + + Parameters + ---------- + gamma : float + Gamma value for Y correction. + run_id : int + Run identifier for this trial. + X_train : ndarray + Training inputs. + y_train : ndarray + Training targets. + X_val : ndarray + Validation inputs. + y_val : ndarray + Validation targets. + xyY_array : ndarray + Real Munsell xyY values for evaluation. + reference_Lab : ndarray + Reference CIELAB values for Delta-E computation. + + Returns + ------- + dict + Results dictionary containing gamma, run_id, val_loss, + mean_delta_e, and median_delta_e. + + Notes + ----- + Uses CPU to avoid MPS multiprocessing issues. + """ + # Each process uses CPU to avoid MPS multiprocessing issues + device = torch.device("cpu") + + model, input_params, output_params, val_loss = train_model( + gamma=gamma, + X_train=X_train, + y_train=y_train, + X_val=X_val, + y_val=y_val, + device=device, + num_epochs=100, + patience=15, + ) + + mean_delta_e, median_delta_e = evaluate_on_real_munsell( + model, input_params, output_params, xyY_array, reference_Lab, device + ) + + return { + "gamma": gamma, + "run_id": run_id, + "val_loss": val_loss, + "mean_delta_e": mean_delta_e, + "median_delta_e": median_delta_e, + } + + +def main() -> None: + """ + Run gamma sweep experiment to find optimal gamma value. + + Notes + ----- + The training pipeline: + 1. Loads training and validation data from cache + 2. Loads real Munsell colors for evaluation + 3. Sweeps gamma values from 1.0 to 3.0 in 0.1 increments + 4. Trains multiple models per gamma value for averaging + 5. Evaluates each model on real Munsell colors using Delta-E CIE2000 + 6. Aggregates results and identifies best gamma value + 7. Saves results to NPZ file for analysis + + Uses parallel execution with ProcessPoolExecutor for efficiency. + Each model is trained with early stopping and evaluated on validation set. + """ + import argparse + + parser = argparse.ArgumentParser(description="Gamma sweep with averaging") + parser.add_argument("--runs", type=int, default=3, help="Number of runs per gamma") + parser.add_argument("--workers", type=int, default=4, help="Number of parallel workers") + args = parser.parse_args() + + num_runs = args.runs + num_workers = args.workers + + LOGGER.info("=" * 80) + LOGGER.info("Multi-Head Gamma Sweep: Finding Optimal Gamma Value") + LOGGER.info("Testing gamma values from 1.0 to 3.0 in increments of 0.1") + LOGGER.info("Runs per gamma: %d, Parallel workers: %d", num_runs, num_workers) + LOGGER.info("=" * 80) + + # Load training data + data_dir = PROJECT_ROOT / "data" + cache_file = data_dir / "training_data.npz" + + if not cache_file.exists(): + LOGGER.error("Error: Training data not found at %s", cache_file) + return + + LOGGER.info("\nLoading training data...") + data = np.load(cache_file) + X_train = data["X_train"] + y_train = data["y_train"] + X_val = data["X_val"] + y_val = data["y_val"] + LOGGER.info("Train samples: %d, Validation samples: %d", len(X_train), len(X_val)) + + # Load real Munsell data for evaluation + LOGGER.info("Loading real Munsell colours for evaluation...") + xyY_values = [] + reference_Lab = [] + + for munsell_spec_tuple, xyY in MUNSELL_COLOURS_REAL: + try: + xyY_scaled = np.array([xyY[0], xyY[1], xyY[2] / 100.0]) + XYZ = xyY_to_XYZ(xyY_scaled) + Lab = XYZ_to_Lab(XYZ, CCS_ILLUMINANT_MUNSELL) + xyY_values.append(xyY_scaled) + reference_Lab.append(Lab) + except (RuntimeError, ValueError): + continue + + xyY_array = np.array(xyY_values) + reference_Lab = np.array(reference_Lab) + LOGGER.info("Loaded %d real Munsell colours", len(xyY_array)) + + # Gamma values to test + gamma_values = [round(1.0 + i * 0.1, 1) for i in range(21)] # 1.0 to 3.0 + + # Create all tasks: (gamma, run_id) pairs + tasks = [(gamma, run_id) for gamma in gamma_values for run_id in range(num_runs)] + total_tasks = len(tasks) + + LOGGER.info("\n" + "-" * 80) + LOGGER.info("Starting gamma sweep: %d total tasks (%d gamma values x %d runs)", + total_tasks, len(gamma_values), num_runs) + LOGGER.info("-" * 80) + + all_results = [] + completed = 0 + + with ProcessPoolExecutor(max_workers=num_workers) as executor: + futures = { + executor.submit( + run_single_trial, gamma, run_id, + X_train, y_train, X_val, y_val, xyY_array, reference_Lab + ): (gamma, run_id) + for gamma, run_id in tasks + } + + for future in as_completed(futures): + gamma, run_id = futures[future] + try: + result = future.result() + all_results.append(result) + completed += 1 + LOGGER.info( + "[%3d/%3d] gamma=%.1f run=%d: mean_ΔE=%.4f, median_ΔE=%.4f", + completed, total_tasks, gamma, run_id, + result["mean_delta_e"], result["median_delta_e"] + ) + except Exception as e: + LOGGER.error("Task failed for gamma=%.1f run=%d: %s", gamma, run_id, e) + completed += 1 + + # Aggregate results by gamma (average across runs) + aggregated = {} + for r in all_results: + gamma = r["gamma"] + if gamma not in aggregated: + aggregated[gamma] = {"val_losses": [], "means": [], "medians": []} + aggregated[gamma]["val_losses"].append(r["val_loss"]) + aggregated[gamma]["means"].append(r["mean_delta_e"]) + aggregated[gamma]["medians"].append(r["median_delta_e"]) + + results = [] + for gamma in sorted(aggregated.keys()): + agg = aggregated[gamma] + results.append({ + "gamma": gamma, + "val_loss": np.mean(agg["val_losses"]), + "val_loss_std": np.std(agg["val_losses"]), + "mean_delta_e": np.mean(agg["means"]), + "mean_delta_e_std": np.std(agg["means"]), + "median_delta_e": np.mean(agg["medians"]), + "median_delta_e_std": np.std(agg["medians"]), + "num_runs": len(agg["means"]), + }) + + # Print results + LOGGER.info("\n" + "=" * 80) + LOGGER.info("GAMMA SWEEP RESULTS (averaged over %d runs)", num_runs) + LOGGER.info("=" * 80) + LOGGER.info("") + LOGGER.info("%-8s %-14s %-14s %-14s", "Gamma", "Val Loss", "Mean ΔE", "Median ΔE") + LOGGER.info("-" * 50) + + for r in results: + LOGGER.info( + "%-8.1f %-14s %-14s %-14s", + r["gamma"], + f"{r['val_loss']:.6f}±{r['val_loss_std']:.4f}", + f"{r['mean_delta_e']:.4f}±{r['mean_delta_e_std']:.4f}", + f"{r['median_delta_e']:.4f}±{r['median_delta_e_std']:.4f}", + ) + + # Find best by mean Delta-E + best_by_mean = min(results, key=lambda x: x["mean_delta_e"]) + best_by_median = min(results, key=lambda x: x["median_delta_e"]) + + LOGGER.info("") + LOGGER.info("Best gamma by MEAN Delta-E: %.1f (ΔE = %.4f ± %.4f)", + best_by_mean["gamma"], best_by_mean["mean_delta_e"], + best_by_mean["mean_delta_e_std"]) + LOGGER.info("Best gamma by MEDIAN Delta-E: %.1f (ΔE = %.4f ± %.4f)", + best_by_median["gamma"], best_by_median["median_delta_e"], + best_by_median["median_delta_e_std"]) + + # Save results + results_file = PROJECT_ROOT / "models" / "from_xyY" / "gamma_sweep_results_averaged.npz" + np.savez(results_file, results=results, all_results=all_results) + LOGGER.info("\nResults saved to: %s", results_file) + + LOGGER.info("\n" + "=" * 80) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO, format="%(message)s", force=True) + + main() diff --git a/learning_munsell/training/from_xyY/train_multi_head_large.py b/learning_munsell/training/from_xyY/train_multi_head_large.py new file mode 100644 index 0000000000000000000000000000000000000000..64cf19949d3d7d073498feedeea1129ac431ae1c --- /dev/null +++ b/learning_munsell/training/from_xyY/train_multi_head_large.py @@ -0,0 +1,246 @@ +""" +Train multi-head ML model on large dataset (2M samples) for xyY to Munsell conversion. + +This script trains on the larger dataset for potentially improved accuracy. +Uses the same architecture as train_multi_head_mlp.py but with the large dataset. +""" + +import logging + +import click +import mlflow +import mlflow.pytorch +import numpy as np +import torch +from numpy.typing import NDArray +from torch import nn, optim +from torch.utils.data import DataLoader, TensorDataset + +from learning_munsell import PROJECT_ROOT +from learning_munsell.models.networks import MultiHeadMLPToMunsell +from learning_munsell.utilities.common import log_training_epoch, setup_mlflow_experiment +from learning_munsell.utilities.data import ( + MUNSELL_NORMALIZATION_PARAMS, + XYY_NORMALIZATION_PARAMS, + normalize_munsell, +) +from learning_munsell.utilities.losses import weighted_mse_loss +from learning_munsell.utilities.training import train_epoch, validate + +LOGGER = logging.getLogger(__name__) + + +@click.command() +@click.option("--epochs", default=300, help="Number of training epochs") +@click.option("--batch-size", default=2048, help="Batch size for training") +@click.option("--lr", default=5e-4, help="Learning rate") +@click.option("--patience", default=30, help="Early stopping patience") +def main(epochs: int, batch_size: int, lr: float, patience: int) -> None: + """ + Train multi-head MLP on large dataset (2M samples) for xyY to Munsell. + + Notes + ----- + The training pipeline: + 1. Loads training and validation data from large cached .npz file + 2. Normalizes xyY inputs (already [0,1]) and Munsell outputs to [0,1] + 3. Creates multi-head MLP with shared encoder and component-specific heads + 4. Trains with weighted MSE loss (emphasizing chroma) + 5. Uses Adam optimizer with ReduceLROnPlateau scheduler + 6. Applies early stopping based on validation loss (patience=30) + 7. Exports best model to ONNX format + 8. Logs metrics and artifacts to MLflow + """ + + LOGGER.info("=" * 80) + LOGGER.info("Multi-Head Model Training on Large Dataset (2M samples)") + LOGGER.info("=" * 80) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if torch.backends.mps.is_available(): + device = torch.device("mps") + LOGGER.info("Using device: %s", device) + + # Load large training data + data_dir = PROJECT_ROOT / "data" + cache_file = data_dir / "training_data_large.npz" + + if not cache_file.exists(): + LOGGER.error("Error: Large training data not found at %s", cache_file) + LOGGER.error("Please run generate_large_training_data.py first") + return + + LOGGER.info("Loading large training data from %s...", cache_file) + data = np.load(cache_file) + + X_train = data["X_train"] + y_train = data["y_train"] + X_val = data["X_val"] + y_val = data["y_val"] + + LOGGER.info("Train samples: %d", len(X_train)) + LOGGER.info("Validation samples: %d", len(X_val)) + + # Normalize outputs (xyY inputs are already in [0, 1] range) + # Use shared normalization parameters covering the full Munsell space for generalization + output_params = MUNSELL_NORMALIZATION_PARAMS + y_train_norm = normalize_munsell(y_train, output_params) + y_val_norm = normalize_munsell(y_val, output_params) + + # Convert to PyTorch tensors + X_train_t = torch.FloatTensor(X_train) + y_train_t = torch.FloatTensor(y_train_norm) + X_val_t = torch.FloatTensor(X_val) + y_val_t = torch.FloatTensor(y_val_norm) + + # Create data loaders (larger batch size for larger dataset) + train_dataset = TensorDataset(X_train_t, y_train_t) + val_dataset = TensorDataset(X_val_t, y_val_t) + + train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) + val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False) + + # Initialize model + model = MultiHeadMLPToMunsell().to(device) + LOGGER.info("") + LOGGER.info("Model architecture:") + LOGGER.info("%s", model) + + total_params = sum(p.numel() for p in model.parameters()) + LOGGER.info("Total parameters: %s", f"{total_params:,}") + + # Training setup + learning_rate = lr + optimizer = optim.Adam(model.parameters(), lr=learning_rate) + scheduler = optim.lr_scheduler.ReduceLROnPlateau( + optimizer, mode="min", factor=0.5, patience=10 + ) + criterion = weighted_mse_loss + + # MLflow setup + run_name = setup_mlflow_experiment("from_xyY", "multi_head_large") + + LOGGER.info("") + LOGGER.info("MLflow run: %s", run_name) + + # Training loop + best_val_loss = float("inf") + patience_counter = 0 + + LOGGER.info("") + LOGGER.info("Starting training...") + + with mlflow.start_run(run_name=run_name): + mlflow.log_params( + { + "model": "multi_head_large", + "learning_rate": learning_rate, + "batch_size": batch_size, + "num_epochs": epochs, + "patience": patience, + "total_params": total_params, + "train_samples": len(X_train), + "val_samples": len(X_val), + "dataset": "large_2M", + } + ) + + for epoch in range(epochs): + train_loss = train_epoch(model, train_loader, optimizer, criterion, device) + val_loss = validate(model, val_loader, criterion, device) + + scheduler.step(val_loss) + + log_training_epoch( + epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"] + ) + + LOGGER.info( + "Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f, LR: %.6f", + epoch + 1, + epochs, + train_loss, + val_loss, + optimizer.param_groups[0]["lr"], + ) + + if val_loss < best_val_loss: + best_val_loss = val_loss + patience_counter = 0 + + model_directory = PROJECT_ROOT / "models" / "from_xyY" + model_directory.mkdir(exist_ok=True) + checkpoint_file = model_directory / "multi_head_large_best.pth" + + torch.save( + { + "model_state_dict": model.state_dict(), + "output_params": output_params, + "epoch": epoch, + "val_loss": val_loss, + }, + checkpoint_file, + ) + + LOGGER.info(" → Saved best model (val_loss: %.6f)", val_loss) + else: + patience_counter += 1 + if patience_counter >= patience: + LOGGER.info("") + LOGGER.info("Early stopping after %d epochs", epoch + 1) + break + + mlflow.log_metrics( + { + "best_val_loss": best_val_loss, + "final_epoch": epoch + 1, + } + ) + + # Export to ONNX + LOGGER.info("") + LOGGER.info("Exporting model to ONNX...") + model.eval() + + checkpoint = torch.load(checkpoint_file, weights_only=False) + model.load_state_dict(checkpoint["model_state_dict"]) + + dummy_input = torch.randn(1, 3).to(device) + + model_directory = PROJECT_ROOT / "models" / "from_xyY" + onnx_file = model_directory / "multi_head_large.onnx" + torch.onnx.export( + model, + dummy_input, + onnx_file, + export_params=True, + opset_version=15, + input_names=["xyY"], + output_names=["munsell_spec"], + dynamic_axes={"xyY": {0: "batch_size"}, "munsell_spec": {0: "batch_size"}}, + ) + + params_file = model_directory / "multi_head_large_normalization_params.npz" + input_params = XYY_NORMALIZATION_PARAMS + np.savez( + params_file, + input_params=input_params, + output_params=output_params, + ) + + mlflow.log_artifact(str(checkpoint_file)) + mlflow.log_artifact(str(onnx_file)) + mlflow.log_artifact(str(params_file)) + mlflow.pytorch.log_model(model, "model") + + LOGGER.info("ONNX model saved to: %s", onnx_file) + LOGGER.info("Normalization parameters saved to: %s", params_file) + LOGGER.info("Artifacts logged to MLflow") + + LOGGER.info("=" * 80) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO, format="%(message)s", force=True) + + main() diff --git a/learning_munsell/training/from_xyY/train_multi_head_mlp.py b/learning_munsell/training/from_xyY/train_multi_head_mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..9d5515f26f4569fe4cac62c49e5c7fff5eaf860a --- /dev/null +++ b/learning_munsell/training/from_xyY/train_multi_head_mlp.py @@ -0,0 +1,269 @@ +""" +Train multi-head ML model for xyY to Munsell conversion. + +Architecture: +- Shared encoder: 3 inputs → 512-dim features +- 4 separate heads (one per component): + - Hue head (circular/angular) + - Value head (linear lightness) + - Chroma head (non-linear saturation - larger capacity) + - Code head (discrete categorical) + +This architecture allows each component to learn specialized features +while sharing the general color space understanding. +""" + +import logging +import click +import mlflow +import mlflow.pytorch +import numpy as np +import torch +from numpy.typing import NDArray +from torch import nn, optim +from torch.utils.data import DataLoader, TensorDataset + +from learning_munsell import PROJECT_ROOT +from learning_munsell.models.networks import MultiHeadMLPToMunsell +from learning_munsell.utilities.common import log_training_epoch, setup_mlflow_experiment +from learning_munsell.utilities.data import ( + MUNSELL_NORMALIZATION_PARAMS, + XYY_NORMALIZATION_PARAMS, + normalize_munsell, +) +from learning_munsell.utilities.losses import weighted_mse_loss +from learning_munsell.utilities.training import train_epoch, validate + +LOGGER = logging.getLogger(__name__) + + +@click.command() +@click.option("--epochs", default=200, help="Number of training epochs") +@click.option("--batch-size", default=1024, help="Batch size for training") +@click.option("--lr", default=5e-4, help="Learning rate") +@click.option("--patience", default=20, help="Early stopping patience") +def main(epochs: int, batch_size: int, lr: float, patience: int) -> None: + """ + Train multi-head MLP for xyY to Munsell conversion. + + Notes + ----- + The training pipeline: + 1. Loads training and validation data from cached .npz file + 2. Normalizes xyY inputs (already [0,1]) and Munsell outputs to [0,1] + 3. Creates multi-head MLP with shared encoder and component-specific heads + 4. Trains with weighted MSE loss (emphasizing chroma) + 5. Uses Adam optimizer with no learning rate scheduling + 6. Applies early stopping based on validation loss (patience=20) + 7. Exports best model to ONNX format + 8. Logs metrics and artifacts to MLflow + """ + + + LOGGER.info("=" * 80) + LOGGER.info("ML-Based xyY to Munsell Conversion: Multi-Head Model Training") + LOGGER.info("=" * 80) + + # Set device + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + LOGGER.info("Using device: %s", device) + + # Load training data + data_dir = PROJECT_ROOT / "data" + cache_file = data_dir / "training_data.npz" + + if not cache_file.exists(): + LOGGER.error("Error: Training data not found at %s", cache_file) + LOGGER.error("Please run 01_generate_training_data.py first") + return + + LOGGER.info("Loading training data from %s...", cache_file) + data = np.load(cache_file) + + X_train = data["X_train"] + y_train = data["y_train"] + X_val = data["X_val"] + y_val = data["y_val"] + + LOGGER.info("Train samples: %d", len(X_train)) + LOGGER.info("Validation samples: %d", len(X_val)) + + # Normalize outputs (xyY inputs are already in [0, 1] range) + # Use shared normalization parameters covering the full Munsell space for generalization + output_params = MUNSELL_NORMALIZATION_PARAMS + y_train_norm = normalize_munsell(y_train, output_params) + y_val_norm = normalize_munsell(y_val, output_params) + + # Convert to PyTorch tensors + X_train_t = torch.FloatTensor(X_train) + y_train_t = torch.FloatTensor(y_train_norm) + X_val_t = torch.FloatTensor(X_val) + y_val_t = torch.FloatTensor(y_val_norm) + + # Create data loaders + train_dataset = TensorDataset(X_train_t, y_train_t) + val_dataset = TensorDataset(X_val_t, y_val_t) + + train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) + val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False) + + # Initialize model + model = MultiHeadMLPToMunsell().to(device) + LOGGER.info("") + LOGGER.info("Model architecture:") + LOGGER.info("%s", model) + + # Count parameters + total_params = sum(p.numel() for p in model.parameters()) + LOGGER.info("Total parameters: %s", f"{total_params:,}") + + # Count parameters per component + encoder_params = sum(p.numel() for p in model.encoder.parameters()) + hue_params = sum(p.numel() for p in model.hue_head.parameters()) + value_params = sum(p.numel() for p in model.value_head.parameters()) + chroma_params = sum(p.numel() for p in model.chroma_head.parameters()) + code_params = sum(p.numel() for p in model.code_head.parameters()) + + LOGGER.info(" - Shared encoder: %s", f"{encoder_params:,}") + LOGGER.info(" - Hue head: %s", f"{hue_params:,}") + LOGGER.info(" - Value head: %s", f"{value_params:,}") + LOGGER.info(" - Chroma head: %s (WIDER)", f"{chroma_params:,}") + LOGGER.info(" - Code head: %s", f"{code_params:,}") + + # Training setup + optimizer = optim.Adam(model.parameters(), lr=lr) + # Use weighted MSE with default weights + weights = torch.tensor([1.0, 1.0, 3.0, 0.5]) + criterion = lambda pred, target: weighted_mse_loss(pred, target, weights) + + # MLflow setup + run_name = setup_mlflow_experiment("from_xyY", "multi_head") + + LOGGER.info("") + LOGGER.info("MLflow run: %s", run_name) + + # Training loop + best_val_loss = float("inf") + patience_counter = 0 + + LOGGER.info("") + LOGGER.info("Starting training...") + + with mlflow.start_run(run_name=run_name): + # Log parameters + mlflow.log_params( + { + "model": "multi_head", + "learning_rate": lr, + "batch_size": batch_size, + "num_epochs": epochs, + "patience": patience, + "total_params": total_params, + } + ) + + for epoch in range(epochs): + train_loss = train_epoch(model, train_loader, optimizer, criterion, device) + val_loss = validate(model, val_loader, criterion, device) + + # Log to MLflow + log_training_epoch( + epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"] + ) + + LOGGER.info( + "Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f", + epoch + 1, + epochs, + train_loss, + val_loss, + ) + + # Early stopping + if val_loss < best_val_loss: + best_val_loss = val_loss + patience_counter = 0 + + # Save best model + model_directory = PROJECT_ROOT / "models" / "from_xyY" + model_directory.mkdir(exist_ok=True) + checkpoint_file = model_directory / "multi_head_best.pth" + + torch.save( + { + "model_state_dict": model.state_dict(), + "output_params": output_params, + "epoch": epoch, + "val_loss": val_loss, + }, + checkpoint_file, + ) + + LOGGER.info(" → Saved best model (val_loss: %.6f)", val_loss) + else: + patience_counter += 1 + if patience_counter >= patience: + LOGGER.info("") + LOGGER.info("Early stopping after %d epochs", epoch + 1) + break + + # Log final metrics + mlflow.log_metrics( + { + "best_val_loss": best_val_loss, + "final_epoch": epoch + 1, + } + ) + + # Export to ONNX + LOGGER.info("") + LOGGER.info("Exporting model to ONNX...") + model.eval() + + # Load best model + checkpoint = torch.load(checkpoint_file) + model.load_state_dict(checkpoint["model_state_dict"]) + + # Create dummy input + dummy_input = torch.randn(1, 3).to(device) + + # Export + onnx_file = model_directory / "multi_head.onnx" + torch.onnx.export( + model, + dummy_input, + onnx_file, + export_params=True, + opset_version=15, + input_names=["xyY"], + output_names=["munsell_spec"], + dynamic_axes={"xyY": {0: "batch_size"}, "munsell_spec": {0: "batch_size"}}, + ) + + # Save normalization parameters alongside model + params_file = model_directory / "multi_head_normalization_params.npz" + input_params = XYY_NORMALIZATION_PARAMS + np.savez( + params_file, + input_params=input_params, + output_params=output_params, + ) + + # Log artifacts to MLflow + mlflow.log_artifact(str(checkpoint_file)) + mlflow.log_artifact(str(onnx_file)) + mlflow.log_artifact(str(params_file)) + mlflow.pytorch.log_model(model, "model") + + LOGGER.info("ONNX model saved to: %s", onnx_file) + LOGGER.info("Normalization parameters saved to: %s", params_file) + LOGGER.info("Artifacts logged to MLflow") + + + LOGGER.info("=" * 80) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO, format="%(message)s", force=True) + + main() diff --git a/learning_munsell/training/from_xyY/train_multi_head_multi_error_predictor.py b/learning_munsell/training/from_xyY/train_multi_head_multi_error_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..a4481b20cedbf19981e38a757d1edbf8e76ddd72 --- /dev/null +++ b/learning_munsell/training/from_xyY/train_multi_head_multi_error_predictor.py @@ -0,0 +1,378 @@ +""" +Train Multi-Head error predictor for Multi-Head base model. + +Architecture: +- 4 independent error correction branches (one per component) +- Each branch: 7 inputs (xyY + base_pred) → encoder → decoder → 1 error output +- Chroma branch: WIDER (1.5x capacity for hardest component) + +Complete independence matches the Multi-Head base model philosophy. +""" + +import logging +from pathlib import Path +from typing import Any + +import click +import mlflow +import mlflow.pytorch +import numpy as np +import onnxruntime as ort +import torch +from numpy.typing import NDArray +from torch import nn, optim +from torch.utils.data import DataLoader, TensorDataset + +from learning_munsell import PROJECT_ROOT +from learning_munsell.models.networks import ( + ComponentErrorPredictor, + MultiHeadErrorPredictorToMunsell, +) +from learning_munsell.utilities.common import log_training_epoch, setup_mlflow_experiment +from learning_munsell.utilities.data import normalize_munsell, normalize_xyY +from learning_munsell.utilities.losses import precision_focused_loss +from learning_munsell.utilities.training import train_epoch, validate + +LOGGER = logging.getLogger(__name__) + + +def load_base_model( + model_path: Path, params_path: Path +) -> tuple[ort.InferenceSession, dict, dict]: + """ + Load Multi-Head base ONNX model and normalization parameters. + + Parameters + ---------- + model_path : Path + Path to Multi-Head base model ONNX file. + params_path : Path + Path to normalization parameters .npz file. + + Returns + ------- + session : ort.InferenceSession + ONNX Runtime inference session. + input_params : dict + Input normalization ranges. + output_params : dict + Output normalization ranges. + """ + session = ort.InferenceSession(str(model_path)) + params = np.load(params_path, allow_pickle=True) + return session, params["input_params"].item(), params["output_params"].item() + + +@click.command() +@click.option( + "--base-model", + type=click.Path(exists=True, path_type=Path), + default=None, + help="Path to Multi-Head base model ONNX file", +) +@click.option( + "--params", + type=click.Path(exists=True, path_type=Path), + default=None, + help="Path to normalization params file", +) +@click.option( + "--epochs", + type=int, + default=200, + help="Number of training epochs", +) +@click.option( + "--batch-size", + type=int, + default=1024, + help="Batch size for training", +) +@click.option( + "--lr", + type=float, + default=3e-4, + help="Learning rate", +) +@click.option( + "--patience", + type=int, + default=20, + help="Patience for early stopping", +) +def main( + base_model: Path | None, + params: Path | None, + epochs: int, + batch_size: int, + lr: float, + patience: int, +) -> None: + """ + Train Multi-Head error predictor with 4 independent branches. + + Parameters + ---------- + base_model : Path or None + Path to Multi-Head base model ONNX file. Uses default if None. + params : Path or None + Path to normalization parameters. Uses default if None. + + Notes + ----- + The training pipeline: + 1. Loads pre-trained base model + 2. Generates base model predictions for training data + 3. Computes residual errors between predictions and targets + 4. Trains error predictor on these residuals + 5. Uses precision-focused loss function + 6. Learning rate scheduling with ReduceLROnPlateau + 7. Early stopping based on validation loss + 8. Exports model to ONNX format + 9. Logs metrics and artifacts to MLflow + """ + + + LOGGER.info("=" * 80) + LOGGER.info("Multi-Head Error Predictor: 4 Independent Branches") + LOGGER.info("=" * 80) + + # Set device + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + LOGGER.info("Using device: %s", device) + + # Paths + model_directory = PROJECT_ROOT / "models" / "from_xyY" + data_dir = PROJECT_ROOT / "data" + + # Use provided paths or defaults + if base_model is None: + base_model = model_directory / "multi_head.onnx" + if params is None: + params = model_directory / "multi_head_normalization_params.npz" + + cache_file = data_dir / "training_data.npz" + + # Load base model + LOGGER.info("") + LOGGER.info("Loading Multi-Head base model from %s...", base_model) + base_session, input_params, output_params = load_base_model(base_model, params) + + # Load training data + LOGGER.info("Loading training data from %s...", cache_file) + data = np.load(cache_file) + X_train = data["X_train"] + y_train = data["y_train"] + X_val = data["X_val"] + y_val = data["y_val"] + + LOGGER.info("Train samples: %d", len(X_train)) + LOGGER.info("Validation samples: %d", len(X_val)) + + # Generate base model predictions + LOGGER.info("") + LOGGER.info("Generating Multi-Head base model predictions...") + X_train_norm = normalize_xyY(X_train, input_params) + y_train_norm = normalize_munsell(y_train, output_params) + + # Base predictions (normalized) + base_pred_train_norm = base_session.run(None, {"xyY": X_train_norm})[0] + + X_val_norm = normalize_xyY(X_val, input_params) + y_val_norm = normalize_munsell(y_val, output_params) + base_pred_val_norm = base_session.run(None, {"xyY": X_val_norm})[0] + + # Compute errors (in normalized space) + error_train = y_train_norm - base_pred_train_norm + error_val = y_val_norm - base_pred_val_norm + + # Statistics + LOGGER.info("") + LOGGER.info("Multi-Head base model error statistics (normalized space):") + LOGGER.info(" Mean absolute error: %.6f", np.mean(np.abs(error_train))) + LOGGER.info(" Std of error: %.6f", np.std(error_train)) + LOGGER.info(" Max absolute error: %.6f", np.max(np.abs(error_train))) + + # Create combined input: [xyY_norm, base_prediction_norm] + X_train_combined = np.concatenate([X_train_norm, base_pred_train_norm], axis=1) + X_val_combined = np.concatenate([X_val_norm, base_pred_val_norm], axis=1) + + # Convert to PyTorch tensors + X_train_t = torch.FloatTensor(X_train_combined) + error_train_t = torch.FloatTensor(error_train) + X_val_t = torch.FloatTensor(X_val_combined) + error_val_t = torch.FloatTensor(error_val) + + # Create data loaders + train_dataset = TensorDataset(X_train_t, error_train_t) + val_dataset = TensorDataset(X_val_t, error_val_t) + + train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) + val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False) + + # Initialize Multi-Head error predictor + model = MultiHeadErrorPredictorToMunsell().to(device) + LOGGER.info("") + LOGGER.info("Multi-Head error predictor architecture:") + LOGGER.info("%s", model) + + # Count parameters + total_params = sum(p.numel() for p in model.parameters()) + LOGGER.info("Total parameters: %s", f"{total_params:,}") + + # Count parameters per branch + hue_params = sum(p.numel() for p in model.hue_branch.parameters()) + value_params = sum(p.numel() for p in model.value_branch.parameters()) + chroma_params = sum(p.numel() for p in model.chroma_branch.parameters()) + code_params = sum(p.numel() for p in model.code_branch.parameters()) + + LOGGER.info(" - Hue branch: %s", f"{hue_params:,}") + LOGGER.info(" - Value branch: %s", f"{value_params:,}") + LOGGER.info(" - Chroma branch: %s (WIDER 1.5x)", f"{chroma_params:,}") + LOGGER.info(" - Code branch: %s", f"{code_params:,}") + + # Training setup with precision-focused loss + LOGGER.info("") + LOGGER.info("Using precision-focused loss function:") + LOGGER.info(" - MSE (weight: 1.0)") + LOGGER.info(" - MAE (weight: 0.5)") + LOGGER.info(" - Log penalty for small errors (weight: 0.3)") + LOGGER.info(" - Huber loss with delta=0.01 (weight: 0.5)") + + optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5) + scheduler = optim.lr_scheduler.ReduceLROnPlateau( + optimizer, mode="min", factor=0.5, patience=5 + ) + criterion = precision_focused_loss + + # MLflow setup + run_name = setup_mlflow_experiment("from_xyY", "multi_head_multi_error_predictor") + + LOGGER.info("") + LOGGER.info("MLflow run: %s", run_name) + + # Training loop + best_val_loss = float("inf") + patience_counter = 0 + + LOGGER.info("") + LOGGER.info("Starting training...") + + with mlflow.start_run(run_name=run_name): + # Log hyperparameters + mlflow.log_params( + { + "num_epochs": epochs, + "batch_size": batch_size, + "learning_rate": lr, + "weight_decay": 1e-5, + "optimizer": "AdamW", + "scheduler": "ReduceLROnPlateau", + "criterion": "precision_focused_loss", + "patience": patience, + "total_params": total_params, + } + ) + + for epoch in range(epochs): + train_loss = train_epoch(model, train_loader, optimizer, criterion, device) + val_loss = validate(model, val_loader, criterion, device) + + # Update learning rate + scheduler.step(val_loss) + + # Log to MLflow + log_training_epoch( + epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"] + ) + + LOGGER.info( + "Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f, LR: %.6f", + epoch + 1, + epochs, + train_loss, + val_loss, + optimizer.param_groups[0]["lr"], + ) + + # Early stopping + if val_loss < best_val_loss: + best_val_loss = val_loss + patience_counter = 0 + + # Save best model + model_directory.mkdir(exist_ok=True) + checkpoint_file = ( + model_directory / "multi_head_multi_error_predictor_best.pth" + ) + + torch.save( + { + "model_state_dict": model.state_dict(), + "epoch": epoch, + "val_loss": val_loss, + }, + checkpoint_file, + ) + + LOGGER.info(" → Saved best model (val_loss: %.6f)", val_loss) + else: + patience_counter += 1 + if patience_counter >= patience: + LOGGER.info("") + LOGGER.info("Early stopping after %d epochs", epoch + 1) + break + + # Log final metrics + mlflow.log_metrics( + { + "best_val_loss": best_val_loss, + "final_epoch": epoch + 1, + } + ) + + # Export to ONNX + LOGGER.info("") + LOGGER.info("Exporting Multi-Head error predictor to ONNX...") + model.eval() + + # Load best model + checkpoint = torch.load(checkpoint_file) + model.load_state_dict(checkpoint["model_state_dict"]) + + # Create dummy input (xyY_norm + base_pred_norm = 7 inputs) + dummy_input = torch.randn(1, 7).to(device) + + # Export + onnx_file = model_directory / "multi_head_multi_error_predictor.onnx" + torch.onnx.export( + model, + dummy_input, + onnx_file, + export_params=True, + opset_version=15, + input_names=["combined_input"], + output_names=["error_correction"], + dynamic_axes={ + "combined_input": {0: "batch_size"}, + "error_correction": {0: "batch_size"}, + }, + ) + + LOGGER.info("Multi-Head error predictor ONNX model saved to: %s", onnx_file) + + # Log artifacts + mlflow.log_artifact(str(checkpoint_file)) + mlflow.log_artifact(str(onnx_file)) + + # Log model + mlflow.pytorch.log_model(model, "model") + + + LOGGER.info("=" * 80) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO, format="%(message)s", force=True) + + main() diff --git a/learning_munsell/training/from_xyY/train_multi_head_multi_error_predictor_large.py b/learning_munsell/training/from_xyY/train_multi_head_multi_error_predictor_large.py new file mode 100644 index 0000000000000000000000000000000000000000..dc912f6489b4f96e487f80e3f5c65579fc3d03ac --- /dev/null +++ b/learning_munsell/training/from_xyY/train_multi_head_multi_error_predictor_large.py @@ -0,0 +1,409 @@ +""" +Train Multi-Head error predictor on large dataset (2M samples). + +Architecture: +- 4 independent error correction branches (one per component) +- Each branch: 7 inputs (xyY + base_pred) → encoder → decoder → 1 error output +- Chroma branch: WIDER (1.5x capacity for hardest component) + +Uses the large dataset for improved model training. +""" + +import logging +from pathlib import Path +import click +import mlflow +import mlflow.pytorch +import numpy as np +import onnxruntime as ort +import torch +from numpy.typing import NDArray +from torch import nn, optim +from torch.utils.data import DataLoader, TensorDataset + +from learning_munsell import PROJECT_ROOT +from learning_munsell.models.networks import ( + ComponentErrorPredictor, + MultiHeadErrorPredictorToMunsell, +) +from learning_munsell.utilities.common import log_training_epoch, setup_mlflow_experiment +from learning_munsell.utilities.data import normalize_xyY, normalize_munsell +from learning_munsell.utilities.losses import precision_focused_loss +from learning_munsell.utilities.training import train_epoch, validate + +LOGGER = logging.getLogger(__name__) + + +def load_base_model( + model_path: Path, params_path: Path +) -> tuple[ort.InferenceSession, dict, dict]: + """ + Load the base ONNX model and normalization parameters. + + Parameters + ---------- + model_path : Path + Path to the ONNX model file. + params_path : Path + Path to the normalization parameters file (.npz). + + Returns + ------- + session : ort.InferenceSession + ONNX Runtime inference session. + input_params : dict + Input normalization parameters. + output_params : dict + Output normalization parameters. + """ + session = ort.InferenceSession(str(model_path)) + params = np.load(params_path, allow_pickle=True) + return session, params["input_params"].item(), params["output_params"].item() + + +@click.command() +@click.option( + "--base-model", + type=click.Path(exists=True, path_type=Path), + default=None, + help="Path to Multi-Head large base model ONNX file", +) +@click.option( + "--params", + type=click.Path(exists=True, path_type=Path), + default=None, + help="Path to normalization params file", +) +@click.option( + "--output-suffix", + type=str, + default="large", + help="Suffix for output filenames (default: 'large')", +) +@click.option( + "--epochs", + type=int, + default=300, + help="Number of training epochs (default: 300)", +) +@click.option( + "--batch-size", + type=int, + default=2048, + help="Batch size for training (default: 2048)", +) +@click.option( + "--lr", + type=float, + default=3e-4, + help="Learning rate (default: 3e-4)", +) +@click.option( + "--patience", + type=int, + default=30, + help="Early stopping patience (default: 30)", +) +def main( + base_model: Path | None, + params: Path | None, + output_suffix: str, + epochs: int, + batch_size: int, + lr: float, + patience: int, +) -> None: + """ + Train Multi-Head error predictor on large dataset. + + This script trains an error predictor on top of the Multi-Head large + base model, using the 2M sample dataset for improved accuracy. + + Parameters + ---------- + base_model : Path, optional + Path to the Multi-Head large base model ONNX file. + Default: models/from_xyY/multi_head_large.onnx + params : Path, optional + Path to the normalization parameters file. + Default: models/from_xyY/multi_head_large_normalization_params.npz + output_suffix : str + Suffix for output filenames (default: 'large'). + + Notes + ----- + The training pipeline: + 1. Loads pre-trained Multi-Head large base model + 2. Generates base model predictions for training data (in batches) + 3. Computes residual errors between predictions and targets + 4. Trains multi-head error predictor on these residuals + 5. Uses precision-focused loss function + 6. Learning rate scheduling with ReduceLROnPlateau + 7. Early stopping based on validation loss + 8. Exports model to ONNX format + 9. Logs metrics and artifacts to MLflow + """ + + LOGGER.info("=" * 80) + LOGGER.info("Multi-Head Error Predictor: Large Dataset (2M samples)") + LOGGER.info("=" * 80) + + # Set device + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if torch.backends.mps.is_available(): + device = torch.device("mps") + LOGGER.info("Using device: %s", device) + + # Paths + model_directory = PROJECT_ROOT / "models" / "from_xyY" + data_dir = PROJECT_ROOT / "data" + + # Use provided paths or defaults for large model + if base_model is None: + base_model = model_directory / "multi_head_large.onnx" + if params is None: + params = model_directory / "multi_head_large_normalization_params.npz" + + cache_file = data_dir / "training_data_large.npz" + + if not cache_file.exists(): + LOGGER.error("Error: Large training data not found at %s", cache_file) + LOGGER.error("Please run generate_large_training_data.py first") + return + + if not base_model.exists(): + LOGGER.error("Error: Multi-Head large base model not found at %s", base_model) + LOGGER.error("Please run train_multi_head_large.py first") + return + + # Load base model + LOGGER.info("") + LOGGER.info("Loading Multi-Head large base model from %s...", base_model) + base_session, input_params, output_params = load_base_model(base_model, params) + + # Load training data + LOGGER.info("Loading large training data from %s...", cache_file) + data = np.load(cache_file) + X_train = data["X_train"] + y_train = data["y_train"] + X_val = data["X_val"] + y_val = data["y_val"] + + LOGGER.info("Train samples: %d", len(X_train)) + LOGGER.info("Validation samples: %d", len(X_val)) + + # Generate base model predictions + LOGGER.info("") + LOGGER.info("Generating Multi-Head large base model predictions...") + X_train_norm = normalize_xyY(X_train, input_params) + y_train_norm = normalize_munsell(y_train, output_params) + + # Base predictions (normalized) - process in batches for memory efficiency + LOGGER.info(" Processing training set predictions...") + inference_batch_size = 50000 + base_pred_train_norm = [] + for i in range(0, len(X_train_norm), inference_batch_size): + batch = X_train_norm[i : i + inference_batch_size] + pred = base_session.run(None, {"xyY": batch})[0] + base_pred_train_norm.append(pred) + base_pred_train_norm = np.concatenate(base_pred_train_norm, axis=0) + + X_val_norm = normalize_xyY(X_val, input_params) + y_val_norm = normalize_munsell(y_val, output_params) + + LOGGER.info(" Processing validation set predictions...") + base_pred_val_norm = [] + for i in range(0, len(X_val_norm), inference_batch_size): + batch = X_val_norm[i : i + inference_batch_size] + pred = base_session.run(None, {"xyY": batch})[0] + base_pred_val_norm.append(pred) + base_pred_val_norm = np.concatenate(base_pred_val_norm, axis=0) + + # Compute errors (in normalized space) + error_train = y_train_norm - base_pred_train_norm + error_val = y_val_norm - base_pred_val_norm + + # Statistics + LOGGER.info("") + LOGGER.info("Multi-Head large base model error statistics (normalized space):") + LOGGER.info(" Mean absolute error: %.6f", np.mean(np.abs(error_train))) + LOGGER.info(" Std of error: %.6f", np.std(error_train)) + LOGGER.info(" Max absolute error: %.6f", np.max(np.abs(error_train))) + + # Create combined input: [xyY_norm, base_prediction_norm] + X_train_combined = np.concatenate([X_train_norm, base_pred_train_norm], axis=1) + X_val_combined = np.concatenate([X_val_norm, base_pred_val_norm], axis=1) + + # Convert to PyTorch tensors + X_train_t = torch.FloatTensor(X_train_combined) + error_train_t = torch.FloatTensor(error_train) + X_val_t = torch.FloatTensor(X_val_combined) + error_val_t = torch.FloatTensor(error_val) + + # Create data loaders (larger batch size for large dataset) + train_dataset = TensorDataset(X_train_t, error_train_t) + val_dataset = TensorDataset(X_val_t, error_val_t) + + train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) + val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False) + + # Initialize Multi-Head error predictor + model = MultiHeadErrorPredictorToMunsell().to(device) + LOGGER.info("") + LOGGER.info("Multi-Head error predictor architecture:") + LOGGER.info("%s", model) + + # Count parameters + total_params = sum(p.numel() for p in model.parameters()) + LOGGER.info("Total parameters: %s", f"{total_params:,}") + + # Count parameters per branch + hue_params = sum(p.numel() for p in model.hue_branch.parameters()) + value_params = sum(p.numel() for p in model.value_branch.parameters()) + chroma_params = sum(p.numel() for p in model.chroma_branch.parameters()) + code_params = sum(p.numel() for p in model.code_branch.parameters()) + + LOGGER.info(" - Hue branch: %s", f"{hue_params:,}") + LOGGER.info(" - Value branch: %s", f"{value_params:,}") + LOGGER.info(" - Chroma branch: %s (WIDER 1.5x)", f"{chroma_params:,}") + LOGGER.info(" - Code branch: %s", f"{code_params:,}") + + # Training setup + LOGGER.info("") + LOGGER.info("Using precision-focused loss function:") + LOGGER.info(" - MSE (weight: 1.0)") + LOGGER.info(" - MAE (weight: 0.5)") + LOGGER.info(" - Log penalty for small errors (weight: 0.3)") + LOGGER.info(" - Huber loss with delta=0.01 (weight: 0.5)") + + optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5) + scheduler = optim.lr_scheduler.ReduceLROnPlateau( + optimizer, mode="min", factor=0.5, patience=10 + ) + criterion = precision_focused_loss + + # MLflow setup + run_name = setup_mlflow_experiment( + "from_xyY", f"multi_head_multi_error_predictor_{output_suffix}" + ) + + LOGGER.info("") + LOGGER.info("MLflow run: %s", run_name) + + # Training loop + best_val_loss = float("inf") + patience_counter = 0 + + LOGGER.info("") + LOGGER.info("Starting training...") + + with mlflow.start_run(run_name=run_name): + mlflow.log_params( + { + "model": f"multi_head_multi_error_predictor_{output_suffix}", + "num_epochs": epochs, + "batch_size": batch_size, + "learning_rate": lr, + "weight_decay": 1e-5, + "optimizer": "AdamW", + "scheduler": "ReduceLROnPlateau", + "criterion": "precision_focused_loss", + "patience": patience, + "total_params": total_params, + "train_samples": len(X_train), + "val_samples": len(X_val), + "dataset": "large_2M", + } + ) + + for epoch in range(epochs): + train_loss = train_epoch(model, train_loader, optimizer, criterion, device) + val_loss = validate(model, val_loader, criterion, device) + + scheduler.step(val_loss) + + log_training_epoch( + epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"] + ) + + LOGGER.info( + "Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f, LR: %.6f", + epoch + 1, + epochs, + train_loss, + val_loss, + optimizer.param_groups[0]["lr"], + ) + + if val_loss < best_val_loss: + best_val_loss = val_loss + patience_counter = 0 + + model_directory.mkdir(exist_ok=True) + checkpoint_file = ( + model_directory / f"multi_head_multi_error_predictor_{output_suffix}_best.pth" + ) + + torch.save( + { + "model_state_dict": model.state_dict(), + "epoch": epoch, + "val_loss": val_loss, + "output_params": output_params, + }, + checkpoint_file, + ) + + LOGGER.info(" → Saved best model (val_loss: %.6f)", val_loss) + else: + patience_counter += 1 + if patience_counter >= patience: + LOGGER.info("") + LOGGER.info("Early stopping after %d epochs", epoch + 1) + break + + mlflow.log_metrics( + { + "best_val_loss": best_val_loss, + "final_epoch": epoch + 1, + } + ) + + # Export to ONNX + LOGGER.info("") + LOGGER.info("Exporting Multi-Head error predictor to ONNX...") + model.eval() + + checkpoint = torch.load(checkpoint_file, weights_only=False) + model.load_state_dict(checkpoint["model_state_dict"]) + + dummy_input = torch.randn(1, 7).to(device) + + onnx_file = model_directory / f"multi_head_multi_error_predictor_{output_suffix}.onnx" + torch.onnx.export( + model, + dummy_input, + onnx_file, + export_params=True, + opset_version=15, + input_names=["combined_input"], + output_names=["error_correction"], + dynamic_axes={ + "combined_input": {0: "batch_size"}, + "error_correction": {0: "batch_size"}, + }, + ) + + LOGGER.info("Multi-Head error predictor ONNX model saved to: %s", onnx_file) + + mlflow.log_artifact(str(checkpoint_file)) + mlflow.log_artifact(str(onnx_file)) + mlflow.pytorch.log_model(model, "model") + + LOGGER.info("=" * 80) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO, format="%(message)s", force=True) + + main() diff --git a/learning_munsell/training/from_xyY/train_multi_head_st2084.py b/learning_munsell/training/from_xyY/train_multi_head_st2084.py new file mode 100644 index 0000000000000000000000000000000000000000..dec1b17e81c804b94b7fa6cfb4e49055a2bca001 --- /dev/null +++ b/learning_munsell/training/from_xyY/train_multi_head_st2084.py @@ -0,0 +1,313 @@ +""" +Train multi-head ML model for xyY to Munsell conversion with ST.2084 (PQ) encoded Y. + +Experiment: Apply SMPTE ST.2084 (Perceptual Quantizer) encoding to Y before +normalization. ST.2084 is designed for perceptual uniformity across a wide +luminance range, potentially providing better alignment with Munsell Value +than simple gamma correction. + +The multi-head architecture has separate heads for each Munsell component, +so PQ encoding on Y should primarily benefit Value prediction without +negatively impacting Chroma prediction. +""" + +import logging +from typing import Any + +import click +import mlflow +import mlflow.pytorch +import numpy as np +import torch +from colour.models import eotf_inverse_ST2084 +from numpy.typing import NDArray +from torch import nn, optim +from torch.utils.data import DataLoader, TensorDataset + +from learning_munsell import PROJECT_ROOT +from learning_munsell.models.networks import MultiHeadMLPToMunsell +from learning_munsell.utilities.common import log_training_epoch, setup_mlflow_experiment +from learning_munsell.utilities.data import ( + MUNSELL_NORMALIZATION_PARAMS, + normalize_munsell, +) +from learning_munsell.utilities.losses import weighted_mse_loss +from learning_munsell.utilities.training import train_epoch, validate + +LOGGER = logging.getLogger(__name__) + +# Peak luminance for ST.2084 scaling +# Munsell Y is relative luminance [0, 1], we scale to cd/m² for ST.2084 +# Using 100 cd/m² as reference white (typical SDR display) +L_P_REFERENCE = 100.0 + + +def normalize_inputs( + X: NDArray, L_p: float = L_P_REFERENCE +) -> tuple[NDArray, dict[str, Any]]: + """ + Normalize xyY inputs to [0, 1] range with ST.2084 (PQ) encoding on Y. + + Parameters + ---------- + X : ndarray + xyY values of shape (n, 3) where columns are [x, y, Y]. + L_p : float + Peak luminance in cd/m² for ST.2084 scaling. + + Returns + ------- + ndarray + Normalized values with ST.2084-encoded Y, dtype float32. + dict + Normalization parameters including L_p and encoding type. + """ + # xyY chromaticity and luminance ranges (all [0, 1]) + x_range = (0.0, 1.0) + y_range = (0.0, 1.0) + Y_range = (0.0, 1.0) + + X_norm = X.copy() + X_norm[:, 0] = (X[:, 0] - x_range[0]) / (x_range[1] - x_range[0]) + X_norm[:, 1] = (X[:, 1] - y_range[0]) / (y_range[1] - y_range[0]) + + # Normalize Y first, then apply ST.2084 + Y_normalized = (X[:, 2] - Y_range[0]) / (Y_range[1] - Y_range[0]) + # Clip to avoid numerical issues + Y_normalized = np.clip(Y_normalized, 0, 1) + # Scale to cd/m² and apply ST.2084 inverse EOTF (PQ encoding) + # ST.2084 expects absolute luminance in cd/m² + Y_cdm2 = Y_normalized * L_p + # eotf_inverse_ST2084 returns values in [0, 1] for the 10000 cd/m² range + # We use a custom L_p to scale appropriately + X_norm[:, 2] = eotf_inverse_ST2084(Y_cdm2, L_p=L_p) + + params = { + "x_range": x_range, + "y_range": y_range, + "Y_range": Y_range, + "encoding": "ST2084", + "L_p": L_p, + } + + return X_norm, params + + +@click.command() +@click.option("--epochs", default=200, help="Number of training epochs") +@click.option("--batch-size", default=1024, help="Batch size for training") +@click.option("--lr", default=5e-4, help="Learning rate") +@click.option("--patience", default=20, help="Early stopping patience") +def main(epochs: int, batch_size: int, lr: float, patience: int) -> None: + """ + Train the multi-head model with ST.2084 (PQ) encoded Y input. + + Notes + ----- + The training pipeline: + 1. Loads training and validation data from cache + 2. Normalizes inputs with ST.2084 (PQ) encoding on Y + 3. Normalizes Munsell outputs to [0, 1] range + 4. Trains multi-head MLP with weighted MSE loss + 5. Uses early stopping based on validation loss + 6. Exports best model to ONNX format + 7. Logs metrics and artifacts to MLflow + + ST.2084 (Perceptual Quantizer) encoding is designed for perceptual + uniformity across a wide luminance range, potentially providing better + alignment with Munsell Value than simple gamma correction. The multi-head + architecture isolates this effect to the Value head without negatively + impacting Chroma prediction. + """ + + LOGGER.info("=" * 80) + LOGGER.info("ML-Based xyY to Munsell Conversion: Multi-Head ST.2084 Experiment") + LOGGER.info("ST.2084 (PQ) encoding applied to Y component (L_p=%.0f cd/m²)", L_P_REFERENCE) + LOGGER.info("=" * 80) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + LOGGER.info("Using device: %s", device) + + # Load training data + data_dir = PROJECT_ROOT / "data" + cache_file = data_dir / "training_data.npz" + + if not cache_file.exists(): + LOGGER.error("Error: Training data not found at %s", cache_file) + LOGGER.error("Please run 01_generate_training_data.py first") + return + + LOGGER.info("Loading training data from %s...", cache_file) + data = np.load(cache_file) + + X_train = data["X_train"] + y_train = data["y_train"] + X_val = data["X_val"] + y_val = data["y_val"] + + LOGGER.info("Train samples: %d", len(X_train)) + LOGGER.info("Validation samples: %d", len(X_val)) + + # Normalize data with ST.2084 encoding + X_train_norm, input_params = normalize_inputs(X_train, L_p=L_P_REFERENCE) + X_val_norm, _ = normalize_inputs(X_val, L_p=L_P_REFERENCE) + + output_params = MUNSELL_NORMALIZATION_PARAMS + y_train_norm = normalize_munsell(y_train, output_params) + y_val_norm = normalize_munsell(y_val, output_params) + + LOGGER.info("") + LOGGER.info("Input normalization with ST.2084 (L_p=%.0f):", L_P_REFERENCE) + LOGGER.info(" Y range after ST.2084: [%.4f, %.4f]", X_train_norm[:, 2].min(), X_train_norm[:, 2].max()) + + # Convert to PyTorch tensors + X_train_t = torch.FloatTensor(X_train_norm) + y_train_t = torch.FloatTensor(y_train_norm) + X_val_t = torch.FloatTensor(X_val_norm) + y_val_t = torch.FloatTensor(y_val_norm) + + # Create data loaders + train_dataset = TensorDataset(X_train_t, y_train_t) + val_dataset = TensorDataset(X_val_t, y_val_t) + + train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) + val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False) + + # Initialize model + model = MultiHeadMLPToMunsell().to(device) + LOGGER.info("") + LOGGER.info("Model architecture:") + LOGGER.info("%s", model) + + total_params = sum(p.numel() for p in model.parameters()) + LOGGER.info("Total parameters: %s", f"{total_params:,}") + + # Training setup + optimizer = optim.Adam(model.parameters(), lr=lr) + criterion = weighted_mse_loss + + # MLflow setup + run_name = setup_mlflow_experiment("from_xyY", "multi_head_st2084") + + LOGGER.info("") + LOGGER.info("MLflow run: %s", run_name) + + # Training loop + best_val_loss = float("inf") + patience_counter = 0 + + LOGGER.info("") + LOGGER.info("Starting training...") + + with mlflow.start_run(run_name=run_name): + mlflow.log_params( + { + "model": "multi_head_st2084", + "num_epochs": epochs, + "batch_size": batch_size, + "learning_rate": lr, + "optimizer": "Adam", + "criterion": "weighted_mse_loss", + "patience": patience, + "total_params": total_params, + "encoding": "ST2084", + "L_p": L_P_REFERENCE, + } + ) + + for epoch in range(epochs): + train_loss = train_epoch(model, train_loader, optimizer, criterion, device) + val_loss = validate(model, val_loader, criterion, device) + + log_training_epoch( + epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"] + ) + + LOGGER.info( + "Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f", + epoch + 1, + epochs, + train_loss, + val_loss, + ) + + if val_loss < best_val_loss: + best_val_loss = val_loss + patience_counter = 0 + + model_directory = PROJECT_ROOT / "models" / "from_xyY" + model_directory.mkdir(exist_ok=True) + checkpoint_file = model_directory / "multi_head_st2084_best.pth" + + torch.save( + { + "model_state_dict": model.state_dict(), + "input_params": input_params, + "output_params": output_params, + "epoch": epoch, + "val_loss": val_loss, + }, + checkpoint_file, + ) + + LOGGER.info(" → Saved best model (val_loss: %.6f)", val_loss) + else: + patience_counter += 1 + if patience_counter >= patience: + LOGGER.info("") + LOGGER.info("Early stopping after %d epochs", epoch + 1) + break + + mlflow.log_metrics( + { + "best_val_loss": best_val_loss, + "final_epoch": epoch + 1, + } + ) + + # Export to ONNX + LOGGER.info("") + LOGGER.info("Exporting model to ONNX...") + model.eval() + + checkpoint = torch.load(checkpoint_file) + model.load_state_dict(checkpoint["model_state_dict"]) + + dummy_input = torch.randn(1, 3).to(device) + + onnx_file = model_directory / "multi_head_st2084.onnx" + torch.onnx.export( + model, + dummy_input, + onnx_file, + export_params=True, + opset_version=17, + input_names=["xyY_st2084"], + output_names=["munsell_spec"], + dynamic_axes={"xyY_st2084": {0: "batch_size"}, "munsell_spec": {0: "batch_size"}}, + ) + + # Save normalization parameters (including ST.2084 info) + params_file = model_directory / "multi_head_st2084_normalization_params.npz" + np.savez( + params_file, + input_params=input_params, + output_params=output_params, + ) + + LOGGER.info("ONNX model saved to: %s", onnx_file) + LOGGER.info("Normalization parameters saved to: %s", params_file) + LOGGER.info("IMPORTANT: Input Y must be ST.2084-encoded with L_p=%.0f", L_P_REFERENCE) + + mlflow.log_artifact(str(checkpoint_file)) + mlflow.log_artifact(str(onnx_file)) + mlflow.log_artifact(str(params_file)) + mlflow.pytorch.log_model(model, "model") + + LOGGER.info("=" * 80) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO, format="%(message)s", force=True) + + main() diff --git a/learning_munsell/training/from_xyY/train_multi_head_weighted_boundary.py b/learning_munsell/training/from_xyY/train_multi_head_weighted_boundary.py new file mode 100644 index 0000000000000000000000000000000000000000..acc475a36262d6eb0001b5b7f5b433d51dae8d1f --- /dev/null +++ b/learning_munsell/training/from_xyY/train_multi_head_weighted_boundary.py @@ -0,0 +1,625 @@ +""" +Train multi-head model with weighted sampling and boundary-aware loss. + +This script implements two approaches to improve performance on problematic +high-value, high-chroma regions (Y/GY/G hues): + +1. Weighted Training: Apply higher loss weights to samples in problem regions + (Y/GY/G hues, value >= 8, chroma >= 12) + +2. Boundary-Aware Loss: Add penalty term when chroma prediction exceeds + maximum valid chroma for the (hue_code, value) combination from real + Munsell gamut boundaries. +""" + +import logging + +import click +import mlflow +import mlflow.pytorch +import numpy as np +import torch +from numpy.typing import NDArray +from torch import nn, optim +from torch.utils.data import DataLoader, TensorDataset + +from learning_munsell import PROJECT_ROOT +from learning_munsell.models.networks import MultiHeadMLPToMunsell +from learning_munsell.utilities.common import log_training_epoch, setup_mlflow_experiment +from learning_munsell.utilities.data import ( + MUNSELL_NORMALIZATION_PARAMS, + XYY_NORMALIZATION_PARAMS, + normalize_munsell, +) + +LOGGER = logging.getLogger(__name__) + +# Maximum valid chroma per (hue_code, value) from MUNSELL_COLOURS_REAL +# Keys are hue codes (1=R, 2=YR, 3=Y, 4=GY, 5=G, 6=BG, 7=B, 8=PB, 9=P, 10=RP) +# Values are dicts mapping Munsell value (1-9) to maximum chroma +MUNSELL_MAX_CHROMA = { + 1: {1.0: 8.0, 2.0: 10.0, 3.0: 14.0, 4.0: 16.0, 5.0: 18.0, + 6.0: 16.0, 7.0: 16.0, 8.0: 12.0, 9.0: 4.0}, + 2: {1.0: 8.0, 2.0: 14.0, 3.0: 20.0, 4.0: 24.0, 5.0: 24.0, + 6.0: 22.0, 7.0: 22.0, 8.0: 18.0, 9.0: 10.0}, + 3: {1.0: 8.0, 2.0: 16.0, 3.0: 22.0, 4.0: 26.0, 5.0: 28.0, + 6.0: 28.0, 7.0: 26.0, 8.0: 24.0, 9.0: 16.0}, + 4: {1.0: 6.0, 2.0: 12.0, 3.0: 14.0, 4.0: 16.0, 5.0: 18.0, + 6.0: 20.0, 7.0: 22.0, 8.0: 24.0, 9.0: 18.0}, + 5: {1.0: 2.0, 2.0: 4.0, 3.0: 6.0, 4.0: 10.0, 5.0: 12.0, + 6.0: 14.0, 7.0: 16.0, 8.0: 20.0, 9.0: 20.0}, + 6: {1.0: 8.0, 2.0: 8.0, 3.0: 10.0, 4.0: 12.0, 5.0: 16.0, + 6.0: 18.0, 7.0: 20.0, 8.0: 20.0, 9.0: 8.0}, + 7: {1.0: 10.0, 2.0: 14.0, 3.0: 16.0, 4.0: 20.0, 5.0: 20.0, + 6.0: 18.0, 7.0: 16.0, 8.0: 10.0, 9.0: 6.0}, + 8: {1.0: 16.0, 2.0: 20.0, 3.0: 22.0, 4.0: 26.0, 5.0: 26.0, + 6.0: 24.0, 7.0: 20.0, 8.0: 14.0, 9.0: 6.0}, + 9: {1.0: 26.0, 2.0: 30.0, 3.0: 34.0, 4.0: 32.0, 5.0: 30.0, + 6.0: 26.0, 7.0: 22.0, 8.0: 14.0, 9.0: 6.0}, + 10: {1.0: 38.0, 2.0: 38.0, 3.0: 34.0, 4.0: 30.0, 5.0: 22.0, + 6.0: 16.0, 7.0: 12.0, 8.0: 8.0, 9.0: 4.0}, +} + + +def compute_sample_weights( + y: NDArray, + problem_weight: float = 3.0, +) -> NDArray: + """ + Compute per-sample weights, upweighting problem regions. + + Analysis of model errors revealed systematic underperformance on + yellow-green hues with high value and chroma. This function assigns + higher weights to these samples to focus learning on difficult cases. + + Parameters + ---------- + y : ndarray + Munsell specifications [hue, value, chroma, code] of shape (n, 4). + Code values: 1=R, 2=YR, 3=Y, 4=GY, 5=G, 6=BG, 7=B, 8=PB, 9=P, 10=RP. + problem_weight : float, optional + Weight multiplier for problem region samples. + + Returns + ------- + ndarray + Per-sample weights of shape (n,). Normal samples have weight 1.0, + problem region samples have weight equal to ``problem_weight``. + + Notes + ----- + Problem region criteria: + - Hue codes 3, 4, or 5 (Yellow, Yellow-Green, Green) + - Value >= 8 (high brightness) + - Chroma >= 12 (high saturation) + """ + weights = np.ones(len(y), dtype=np.float32) + + # Identify problem region samples + codes = np.round(y[:, 3]).astype(int) + values = y[:, 1] + chromas = y[:, 2] + + is_problem_hue = np.isin(codes, [3, 4, 5]) + is_high_value = values >= 8.0 + is_high_chroma = chromas >= 12.0 + + problem_mask = is_problem_hue & is_high_value & is_high_chroma + weights[problem_mask] = problem_weight + + return weights + + +def build_max_chroma_tensor(device: torch.device) -> torch.Tensor: + """ + Build a lookup tensor for maximum valid chroma values. + + Creates a 2D tensor for O(1) lookup of the maximum chroma allowed + for any (hue_code, value) combination in the Munsell system. + + Parameters + ---------- + device : torch.device + Device to create the tensor on (CPU, CUDA, or MPS). + + Returns + ------- + torch.Tensor + Lookup tensor of shape (11, 10) where ``tensor[code, value]`` + gives the maximum valid chroma. Index 0 is unused for both + dimensions (codes are 1-10, values are 1-9). + + Notes + ----- + Maximum chroma values are derived from ``MUNSELL_COLOURS_REAL`` dataset, + which contains the 2,734 physically realizable Munsell colors. + + Examples + -------- + >>> tensor = build_max_chroma_tensor(torch.device('cpu')) + >>> tensor[3, 8] # Max chroma for Yellow (code=3), value=8 + tensor(24.) + """ + max_chroma_tensor = torch.zeros(11, 10, device=device) + + for code, values in MUNSELL_MAX_CHROMA.items(): + for value, max_chroma in values.items(): + max_chroma_tensor[code, int(value)] = max_chroma + + return max_chroma_tensor + + +def boundary_aware_loss( + pred: torch.Tensor, + target: torch.Tensor, + max_chroma_tensor: torch.Tensor, + output_params: dict, + component_weights: torch.Tensor, + sample_weights: torch.Tensor, + boundary_penalty_weight: float = 0.5, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Compute loss with boundary-aware chroma penalty. + + Parameters + ---------- + pred : Tensor + Predicted Munsell values (normalized). + target : Tensor + Ground truth Munsell values (normalized). + max_chroma_tensor : Tensor + Lookup tensor for maximum valid chroma per (code, value). + output_params : dict + Output normalization parameters for denormalization. + component_weights : Tensor + Component-wise loss weights [w_hue, w_value, w_chroma, w_code]. + sample_weights : Tensor + Per-sample weights for problem region upweighting. + boundary_penalty_weight : float, optional + Weight for boundary violation penalty. Default is 0.5. + + Returns + ------- + Tensor + Total loss (MSE + boundary penalty). + Tensor + Weighted MSE loss component. + Tensor + Boundary violation loss component. + + Notes + ----- + The boundary loss penalizes chroma predictions that exceed the + maximum valid chroma for the predicted (hue_code, value) combination + from the real Munsell gamut. + """ + # Denormalize predictions to get actual Munsell values + hue_range = output_params["hue_range"] + value_range = output_params["value_range"] + chroma_range = output_params["chroma_range"] + code_range = output_params["code_range"] + + pred_chroma = pred[:, 2] * (chroma_range[1] - chroma_range[0]) + chroma_range[0] + pred_value = pred[:, 1] * (value_range[1] - value_range[0]) + value_range[0] + pred_code = pred[:, 3] * (code_range[1] - code_range[0]) + code_range[0] + + # Round code and value for lookup (during training these are continuous) + code_idx = torch.clamp(torch.round(pred_code), 1, 10).long() + value_idx = torch.clamp(torch.round(pred_value), 1, 9).long() + + # Look up max chroma for each sample + max_chroma = max_chroma_tensor[code_idx, value_idx] + + # Compute boundary violation penalty (only when chroma exceeds max) + chroma_excess = torch.relu(pred_chroma - max_chroma) + boundary_loss = (chroma_excess ** 2).mean() + + # Standard weighted MSE + mse = (pred - target) ** 2 + weighted_mse = mse * component_weights + + # Apply sample weights + sample_weighted_mse = (weighted_mse.mean(dim=1) * sample_weights).mean() + + # Total loss + total_loss = sample_weighted_mse + boundary_penalty_weight * boundary_loss + + return total_loss, sample_weighted_mse, boundary_loss + + +def train_epoch( + model: nn.Module, + dataloader: DataLoader, + optimizer: optim.Optimizer, + device: torch.device, + max_chroma_tensor: torch.Tensor, + output_params: dict, + component_weights: torch.Tensor, + boundary_penalty_weight: float, +) -> tuple[float, float, float]: + """ + Train the model for one epoch with boundary-aware loss. + + Parameters + ---------- + model : nn.Module + The neural network model to train. + dataloader : DataLoader + DataLoader providing training batches (X, y, weights). + optimizer : optim.Optimizer + Optimizer for updating model parameters. + device : torch.device + Device to run training on. + max_chroma_tensor : Tensor + Lookup tensor for maximum valid chroma. + output_params : dict + Output normalization parameters. + component_weights : Tensor + Component-wise loss weights. + boundary_penalty_weight : float + Weight for boundary violation penalty. + + Returns + ------- + float + Average total loss for the epoch. + float + Average MSE loss component. + float + Average boundary loss component. + """ + model.train() + total_loss_sum = 0.0 + mse_loss_sum = 0.0 + boundary_loss_sum = 0.0 + + for X_batch, y_batch, w_batch in dataloader: + X_batch = X_batch.to(device) + y_batch = y_batch.to(device) + w_batch = w_batch.to(device) + + outputs = model(X_batch) + total_loss, mse_loss, boundary_loss = boundary_aware_loss( + outputs, y_batch, max_chroma_tensor, output_params, + component_weights, w_batch, boundary_penalty_weight + ) + + optimizer.zero_grad() + total_loss.backward() + optimizer.step() + + total_loss_sum += total_loss.item() + mse_loss_sum += mse_loss.item() + boundary_loss_sum += boundary_loss.item() + + n = len(dataloader) + return total_loss_sum / n, mse_loss_sum / n, boundary_loss_sum / n + + +def validate( + model: nn.Module, + dataloader: DataLoader, + device: torch.device, + max_chroma_tensor: torch.Tensor, + output_params: dict, + component_weights: torch.Tensor, + boundary_penalty_weight: float, +) -> tuple[float, float, float]: + """ + Validate the model with boundary-aware loss. + + Parameters + ---------- + model : nn.Module + The neural network model to validate. + dataloader : DataLoader + DataLoader providing validation batches (X, y, weights). + device : torch.device + Device to run validation on. + max_chroma_tensor : Tensor + Lookup tensor for maximum valid chroma. + output_params : dict + Output normalization parameters. + component_weights : Tensor + Component-wise loss weights. + boundary_penalty_weight : float + Weight for boundary violation penalty. + + Returns + ------- + float + Average total loss. + float + Average MSE loss component. + float + Average boundary loss component. + """ + model.eval() + total_loss_sum = 0.0 + mse_loss_sum = 0.0 + boundary_loss_sum = 0.0 + + with torch.no_grad(): + for X_batch, y_batch, w_batch in dataloader: + X_batch = X_batch.to(device) + y_batch = y_batch.to(device) + w_batch = w_batch.to(device) + + outputs = model(X_batch) + total_loss, mse_loss, boundary_loss = boundary_aware_loss( + outputs, y_batch, max_chroma_tensor, output_params, + component_weights, w_batch, boundary_penalty_weight + ) + + total_loss_sum += total_loss.item() + mse_loss_sum += mse_loss.item() + boundary_loss_sum += boundary_loss.item() + + n = len(dataloader) + return total_loss_sum / n, mse_loss_sum / n, boundary_loss_sum / n + + +@click.command() +@click.option("--epochs", default=300, help="Number of training epochs") +@click.option("--batch-size", default=2048, help="Batch size for training") +@click.option("--lr", default=5e-4, help="Learning rate") +@click.option("--patience", default=30, help="Early stopping patience") +def main(epochs: int, batch_size: int, lr: float, patience: int) -> None: + """ + Train the multi-head model with weighted sampling and boundary-aware loss. + + Notes + ----- + This script implements two complementary approaches to improve performance + on problematic high-value, high-chroma regions (Y/GY/G hues): + + 1. Weighted Training: Applies higher loss weights (3x) to samples in + problem regions (Y/GY/G hues, value >= 8, chroma >= 12) to focus + learning on these challenging cases. + + 2. Boundary-Aware Loss: Adds a penalty term when chroma predictions + exceed the maximum valid chroma for the (hue_code, value) combination + from the real Munsell gamut boundaries. This prevents the model from + predicting invalid colors. + + The training pipeline includes: + - Loading large training dataset (2M samples) + - Computing sample weights for problem region upweighting + - Building max chroma lookup tensor from real Munsell data + - Training with combined MSE + boundary penalty loss + - Learning rate scheduling with ReduceLROnPlateau + - Early stopping based on validation loss + - Exporting the best model to ONNX format + - Logging metrics and artifacts to MLflow + """ + + LOGGER.info("=" * 80) + LOGGER.info("Multi-Head Model with Weighted Training + Boundary-Aware Loss") + LOGGER.info("=" * 80) + + # Hyperparameters for the two approaches + problem_region_weight = 3.0 # Weight multiplier for problem region samples + boundary_penalty_weight = 0.5 # Weight for boundary violation penalty + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if torch.backends.mps.is_available(): + device = torch.device("mps") + LOGGER.info("Using device: %s", device) + + # Load large training data + data_dir = PROJECT_ROOT / "data" + cache_file = data_dir / "training_data_large.npz" + + if not cache_file.exists(): + LOGGER.error("Error: Large training data not found at %s", cache_file) + LOGGER.error("Please run generate_large_training_data.py first") + return + + LOGGER.info("Loading large training data from %s...", cache_file) + data = np.load(cache_file) + + X_train = data["X_train"] + y_train = data["y_train"] + X_val = data["X_val"] + y_val = data["y_val"] + + LOGGER.info("Train samples: %d", len(X_train)) + LOGGER.info("Validation samples: %d", len(X_val)) + + # Compute sample weights + LOGGER.info("") + LOGGER.info("Computing sample weights (problem region weight: %.1f)...", + problem_region_weight) + train_weights = compute_sample_weights(y_train, problem_region_weight) + val_weights = compute_sample_weights(y_val, problem_region_weight) + + n_problem_train = np.sum(train_weights > 1.0) + n_problem_val = np.sum(val_weights > 1.0) + LOGGER.info(" Train problem region samples: %d (%.2f%%)", + n_problem_train, 100 * n_problem_train / len(y_train)) + LOGGER.info(" Val problem region samples: %d (%.2f%%)", + n_problem_val, 100 * n_problem_val / len(y_val)) + + # Normalize outputs (xyY inputs are already in [0, 1] range) + # Use shared normalization parameters covering the full Munsell space for generalization + output_params = MUNSELL_NORMALIZATION_PARAMS + y_train_norm = normalize_munsell(y_train, output_params) + y_val_norm = normalize_munsell(y_val, output_params) + + # Convert to PyTorch tensors + X_train_t = torch.FloatTensor(X_train) + y_train_t = torch.FloatTensor(y_train_norm) + w_train_t = torch.FloatTensor(train_weights) + X_val_t = torch.FloatTensor(X_val) + y_val_t = torch.FloatTensor(y_val_norm) + w_val_t = torch.FloatTensor(val_weights) + + # Create data loaders with weights + train_dataset = TensorDataset(X_train_t, y_train_t, w_train_t) + val_dataset = TensorDataset(X_val_t, y_val_t, w_val_t) + + train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) + val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False) + + # Initialize model + model = MultiHeadMLPToMunsell().to(device) + LOGGER.info("") + LOGGER.info("Model architecture:") + LOGGER.info("%s", model) + + total_params = sum(p.numel() for p in model.parameters()) + LOGGER.info("Total parameters: %s", f"{total_params:,}") + + # Build max chroma lookup tensor + max_chroma_tensor = build_max_chroma_tensor(device) + LOGGER.info("") + LOGGER.info("Built max chroma lookup table for boundary-aware loss") + LOGGER.info("Boundary penalty weight: %.2f", boundary_penalty_weight) + + # Component weights: [hue, value, chroma, code] + component_weights = torch.tensor([1.0, 1.0, 3.0, 0.5], device=device) + LOGGER.info("Component weights: %s", component_weights.tolist()) + + # Training setup + optimizer = optim.Adam(model.parameters(), lr=lr) + scheduler = optim.lr_scheduler.ReduceLROnPlateau( + optimizer, mode="min", factor=0.5, patience=10 + ) + + # MLflow setup + run_name = setup_mlflow_experiment("from_xyY", "multi_head_weighted_boundary") + + LOGGER.info("") + LOGGER.info("MLflow run: %s", run_name) + + # Training loop + best_val_loss = float("inf") + patience_counter = 0 + + LOGGER.info("") + LOGGER.info("Starting training...") + + with mlflow.start_run(run_name=run_name): + mlflow.log_params({ + "model": "multi_head_weighted_boundary", + "learning_rate": lr, + "batch_size": batch_size, + "num_epochs": epochs, + "patience": patience, + "total_params": total_params, + "train_samples": len(X_train), + "val_samples": len(X_val), + "dataset": "large_2M", + "problem_region_weight": problem_region_weight, + "boundary_penalty_weight": boundary_penalty_weight, + "component_weights": component_weights.tolist(), + }) + + for epoch in range(epochs): + train_total, train_mse, train_boundary = train_epoch( + model, train_loader, optimizer, device, + max_chroma_tensor, output_params, component_weights, + boundary_penalty_weight + ) + val_total, val_mse, val_boundary = validate( + model, val_loader, device, + max_chroma_tensor, output_params, component_weights, + boundary_penalty_weight + ) + + scheduler.step(val_total) + + log_training_epoch( + epoch, train_total, val_total, optimizer.param_groups[0]["lr"] + ) + mlflow.log_metrics({ + "train_mse": train_mse, + "train_boundary": train_boundary, + "val_mse": val_mse, + "val_boundary": val_boundary, + }, step=epoch) + + LOGGER.info( + "Epoch %03d/%d - Train: %.6f (mse=%.6f, bnd=%.6f) | " + "Val: %.6f (mse=%.6f, bnd=%.6f) | LR: %.6f", + epoch + 1, epochs, + train_total, train_mse, train_boundary, + val_total, val_mse, val_boundary, + optimizer.param_groups[0]["lr"], + ) + + if val_total < best_val_loss: + best_val_loss = val_total + patience_counter = 0 + + model_directory = PROJECT_ROOT / "models" / "from_xyY" + model_directory.mkdir(exist_ok=True) + checkpoint_file = model_directory / "multi_head_weighted_boundary_best.pth" + + torch.save({ + "model_state_dict": model.state_dict(), + "output_params": output_params, + "epoch": epoch, + "val_loss": val_total, + }, checkpoint_file) + + LOGGER.info(" → Saved best model (val_loss: %.6f)", val_total) + else: + patience_counter += 1 + if patience_counter >= patience: + LOGGER.info("") + LOGGER.info("Early stopping after %d epochs", epoch + 1) + break + + mlflow.log_metrics({ + "best_val_loss": best_val_loss, + "final_epoch": epoch + 1, + }) + + # Export to ONNX + LOGGER.info("") + LOGGER.info("Exporting model to ONNX...") + model.eval() + + checkpoint = torch.load(checkpoint_file, weights_only=False) + model.load_state_dict(checkpoint["model_state_dict"]) + + dummy_input = torch.randn(1, 3).to(device) + + model_directory = PROJECT_ROOT / "models" / "from_xyY" + onnx_file = model_directory / "multi_head_weighted_boundary.onnx" + torch.onnx.export( + model, + dummy_input, + onnx_file, + export_params=True, + opset_version=15, + input_names=["xyY"], + output_names=["munsell_spec"], + dynamic_axes={"xyY": {0: "batch_size"}, "munsell_spec": {0: "batch_size"}}, + ) + + params_file = model_directory / "multi_head_weighted_boundary_normalization_params.npz" + input_params = XYY_NORMALIZATION_PARAMS + np.savez( + params_file, + input_params=input_params, + output_params=output_params, + ) + + mlflow.log_artifact(str(checkpoint_file)) + mlflow.log_artifact(str(onnx_file)) + mlflow.log_artifact(str(params_file)) + mlflow.pytorch.log_model(model, "model") + + LOGGER.info("ONNX model saved to: %s", onnx_file) + LOGGER.info("Normalization parameters saved to: %s", params_file) + LOGGER.info("Artifacts logged to MLflow") + + LOGGER.info("=" * 80) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO, format="%(message)s", force=True) + + main() diff --git a/learning_munsell/training/from_xyY/train_multi_head_weighted_boundary_multi_error_predictor.py b/learning_munsell/training/from_xyY/train_multi_head_weighted_boundary_multi_error_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..6212c129a238d0c546c42838e88a6b0f766ed027 --- /dev/null +++ b/learning_munsell/training/from_xyY/train_multi_head_weighted_boundary_multi_error_predictor.py @@ -0,0 +1,717 @@ +""" +Train Multi-Head error predictor with weighted + boundary-aware loss. + +This extends the error predictor training to also apply: +1. Weighted Training: Higher loss weights for problem regions (Y/GY/G hues, + value >= 8, chroma >= 12) +2. Boundary-Aware Loss: Penalty when corrected chroma prediction exceeds + maximum valid chroma for the (hue_code, value) combination + +Uses the weighted boundary base model for two-stage inference. +""" + +import logging +from pathlib import Path + +import click +import mlflow +import mlflow.pytorch +import numpy as np +import onnxruntime as ort +import torch +from numpy.typing import NDArray +from torch import nn, optim +from torch.utils.data import DataLoader, TensorDataset + +from learning_munsell import PROJECT_ROOT +from learning_munsell.models.networks import ( + ComponentErrorPredictor, + MultiHeadErrorPredictorToMunsell, +) +from learning_munsell.utilities.common import log_training_epoch, setup_mlflow_experiment +from learning_munsell.utilities.data import normalize_xyY, normalize_munsell + +LOGGER = logging.getLogger(__name__) + +# Maximum valid chroma per (hue_code, value) from MUNSELL_COLOURS_REAL +MUNSELL_MAX_CHROMA = { + 1: {1.0: 8.0, 2.0: 10.0, 3.0: 14.0, 4.0: 16.0, 5.0: 18.0, + 6.0: 16.0, 7.0: 16.0, 8.0: 12.0, 9.0: 4.0}, + 2: {1.0: 8.0, 2.0: 14.0, 3.0: 20.0, 4.0: 24.0, 5.0: 24.0, + 6.0: 22.0, 7.0: 22.0, 8.0: 18.0, 9.0: 10.0}, + 3: {1.0: 8.0, 2.0: 16.0, 3.0: 22.0, 4.0: 26.0, 5.0: 28.0, + 6.0: 28.0, 7.0: 26.0, 8.0: 24.0, 9.0: 16.0}, + 4: {1.0: 6.0, 2.0: 12.0, 3.0: 14.0, 4.0: 16.0, 5.0: 18.0, + 6.0: 20.0, 7.0: 22.0, 8.0: 24.0, 9.0: 18.0}, + 5: {1.0: 2.0, 2.0: 4.0, 3.0: 6.0, 4.0: 10.0, 5.0: 12.0, + 6.0: 14.0, 7.0: 16.0, 8.0: 20.0, 9.0: 20.0}, + 6: {1.0: 8.0, 2.0: 8.0, 3.0: 10.0, 4.0: 12.0, 5.0: 16.0, + 6.0: 18.0, 7.0: 20.0, 8.0: 20.0, 9.0: 8.0}, + 7: {1.0: 10.0, 2.0: 14.0, 3.0: 16.0, 4.0: 20.0, 5.0: 20.0, + 6.0: 18.0, 7.0: 16.0, 8.0: 10.0, 9.0: 6.0}, + 8: {1.0: 16.0, 2.0: 20.0, 3.0: 22.0, 4.0: 26.0, 5.0: 26.0, + 6.0: 24.0, 7.0: 20.0, 8.0: 14.0, 9.0: 6.0}, + 9: {1.0: 26.0, 2.0: 30.0, 3.0: 34.0, 4.0: 32.0, 5.0: 30.0, + 6.0: 26.0, 7.0: 22.0, 8.0: 14.0, 9.0: 6.0}, + 10: {1.0: 38.0, 2.0: 38.0, 3.0: 34.0, 4.0: 30.0, 5.0: 22.0, + 6.0: 16.0, 7.0: 12.0, 8.0: 8.0, 9.0: 4.0}, +} + + +def load_base_model( + model_path: Path, params_path: Path +) -> tuple[ort.InferenceSession, dict, dict]: + """ + Load the base ONNX model and normalization parameters. + + Parameters + ---------- + model_path : Path + Path to the ONNX model file. + params_path : Path + Path to the normalization parameters file (.npz). + + Returns + ------- + session : ort.InferenceSession + ONNX Runtime inference session. + input_params : dict + Input normalization parameters. + output_params : dict + Output normalization parameters. + """ + session = ort.InferenceSession(str(model_path)) + params = np.load(params_path, allow_pickle=True) + return session, params["input_params"].item(), params["output_params"].item() + + +def compute_sample_weights( + y: NDArray, + problem_weight: float = 3.0, +) -> NDArray: + """ + Compute per-sample weights, upweighting problem regions. + + Analysis of model errors revealed systematic underperformance on + yellow-green hues with high value and chroma. This function assigns + higher weights to these samples to focus learning on difficult cases. + + Parameters + ---------- + y : ndarray + Munsell specifications [hue, value, chroma, code] of shape (n, 4). + Code values: 1=R, 2=YR, 3=Y, 4=GY, 5=G, 6=BG, 7=B, 8=PB, 9=P, 10=RP. + problem_weight : float, optional + Weight multiplier for problem region samples. + + Returns + ------- + ndarray + Per-sample weights of shape (n,). Normal samples have weight 1.0, + problem region samples have weight equal to ``problem_weight``. + + Notes + ----- + Problem region criteria: + - Hue codes 3, 4, or 5 (Yellow, Yellow-Green, Green) + - Value >= 8 (high brightness) + - Chroma >= 12 (high saturation) + """ + weights = np.ones(len(y), dtype=np.float32) + + codes = np.round(y[:, 3]).astype(int) + values = y[:, 1] + chromas = y[:, 2] + + is_problem_hue = np.isin(codes, [3, 4, 5]) + is_high_value = values >= 8.0 + is_high_chroma = chromas >= 12.0 + + problem_mask = is_problem_hue & is_high_value & is_high_chroma + weights[problem_mask] = problem_weight + + return weights + + +def build_max_chroma_tensor(device: torch.device) -> torch.Tensor: + """ + Build a lookup tensor for maximum valid chroma values. + + Creates a 2D tensor for O(1) lookup of the maximum chroma allowed + for any (hue_code, value) combination in the Munsell system. + + Parameters + ---------- + device : torch.device + Device to create the tensor on (CPU, CUDA, or MPS). + + Returns + ------- + torch.Tensor + Lookup tensor of shape (11, 10) where ``tensor[code, value]`` + gives the maximum valid chroma. Index 0 is unused for both + dimensions (codes are 1-10, values are 1-9). + + Notes + ----- + Maximum chroma values are derived from ``MUNSELL_COLOURS_REAL`` dataset, + which contains the 2,734 physically realizable Munsell colors. + """ + max_chroma_tensor = torch.zeros(11, 10, device=device) + + for code, values in MUNSELL_MAX_CHROMA.items(): + for value, max_chroma in values.items(): + max_chroma_tensor[code, int(value)] = max_chroma + + return max_chroma_tensor + + +def weighted_boundary_error_loss( + error_pred: torch.Tensor, + error_target: torch.Tensor, + base_pred_norm: torch.Tensor, + max_chroma_tensor: torch.Tensor, + output_params: dict, + sample_weights: torch.Tensor, + boundary_penalty_weight: float = 0.5, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Compute weighted + boundary-aware loss for error predictor. + + The corrected prediction is: base_pred + error_pred. + We penalize when the corrected chroma exceeds the gamut boundary. + + Parameters + ---------- + error_pred : Tensor + Predicted error corrections (normalized) of shape (batch_size, 4). + error_target : Tensor + Target errors (normalized) of shape (batch_size, 4). + base_pred_norm : Tensor + Base model predictions (normalized) of shape (batch_size, 4). + max_chroma_tensor : Tensor + Lookup tensor for maximum valid chroma per (code, value). + output_params : dict + Output normalization parameters for denormalization. + sample_weights : Tensor + Per-sample weights for problem region upweighting. + boundary_penalty_weight : float, optional + Weight for boundary violation penalty. Default is 0.5. + + Returns + ------- + Tensor + Total loss (precision-focused + boundary penalty). + Tensor + Weighted precision-focused loss component. + Tensor + Boundary violation loss component. + + Notes + ----- + The precision-focused loss combines: + - MSE: Standard mean squared error (weight 1.0) + - MAE: Mean absolute error (weight 0.5) + - Log penalty: Penalizes small errors heavily (weight 0.3) + - Huber: Small delta (0.01) for precision on small errors (weight 0.5) + + The boundary loss penalizes chroma predictions that exceed the + maximum valid chroma for the predicted (hue_code, value) combination. + """ + # Compute corrected prediction (normalized) + corrected_pred = base_pred_norm + error_pred + + # Denormalize to get actual Munsell values for boundary check + value_range = output_params["value_range"] + chroma_range = output_params["chroma_range"] + code_range = output_params["code_range"] + + pred_value = corrected_pred[:, 1] * (value_range[1] - value_range[0]) + value_range[0] + pred_chroma = corrected_pred[:, 2] * (chroma_range[1] - chroma_range[0]) + chroma_range[0] + pred_code = corrected_pred[:, 3] * (code_range[1] - code_range[0]) + code_range[0] + + # Round code and value for lookup + code_idx = torch.clamp(torch.round(pred_code), 1, 10).long() + value_idx = torch.clamp(torch.round(pred_value), 1, 9).long() + + # Look up max chroma for each sample + max_chroma = max_chroma_tensor[code_idx, value_idx] + + # Boundary violation penalty + chroma_excess = torch.relu(pred_chroma - max_chroma) + boundary_loss = (chroma_excess ** 2).mean() + + # Precision-focused base loss components + mse = torch.mean((error_pred - error_target) ** 2, dim=1) + mae = torch.mean(torch.abs(error_pred - error_target), dim=1) + log_penalty = torch.mean( + torch.log1p(torch.abs(error_pred - error_target) * 1000.0), dim=1 + ) + + # Huber loss with small delta + delta = 0.01 + abs_error = torch.abs(error_pred - error_target) + huber = torch.where( + abs_error <= delta, 0.5 * abs_error**2, delta * (abs_error - 0.5 * delta) + ) + huber_loss = torch.mean(huber, dim=1) + + # Combine base loss components + base_loss = 1.0 * mse + 0.5 * mae + 0.3 * log_penalty + 0.5 * huber_loss + + # Apply sample weights + weighted_base_loss = (base_loss * sample_weights).mean() + + # Total loss + total_loss = weighted_base_loss + boundary_penalty_weight * boundary_loss + + return total_loss, weighted_base_loss, boundary_loss + + +def train_epoch( + model: nn.Module, + dataloader: DataLoader, + optimizer: optim.Optimizer, + device: torch.device, + max_chroma_tensor: torch.Tensor, + output_params: dict, + boundary_penalty_weight: float, +) -> tuple[float, float, float]: + """ + Train the model for one epoch with weighted boundary-aware loss. + + Parameters + ---------- + model : nn.Module + The neural network model to train. + dataloader : DataLoader + DataLoader providing training batches (X, y, base_pred, weights). + optimizer : optim.Optimizer + Optimizer for updating model parameters. + device : torch.device + Device to run training on. + max_chroma_tensor : Tensor + Lookup tensor for maximum valid chroma. + output_params : dict + Output normalization parameters. + boundary_penalty_weight : float + Weight for boundary violation penalty. + + Returns + ------- + float + Average total loss for the epoch. + float + Average precision-focused loss component. + float + Average boundary loss component. + """ + model.train() + total_loss_sum = 0.0 + base_loss_sum = 0.0 + boundary_loss_sum = 0.0 + + for X_batch, y_batch, base_pred_batch, w_batch in dataloader: + X_batch = X_batch.to(device) + y_batch = y_batch.to(device) + base_pred_batch = base_pred_batch.to(device) + w_batch = w_batch.to(device) + + outputs = model(X_batch) + total_loss, base_loss, boundary_loss = weighted_boundary_error_loss( + outputs, y_batch, base_pred_batch, max_chroma_tensor, output_params, + w_batch, boundary_penalty_weight + ) + + optimizer.zero_grad() + total_loss.backward() + optimizer.step() + + total_loss_sum += total_loss.item() + base_loss_sum += base_loss.item() + boundary_loss_sum += boundary_loss.item() + + n = len(dataloader) + return total_loss_sum / n, base_loss_sum / n, boundary_loss_sum / n + + +def validate( + model: nn.Module, + dataloader: DataLoader, + device: torch.device, + max_chroma_tensor: torch.Tensor, + output_params: dict, + boundary_penalty_weight: float, +) -> tuple[float, float, float]: + """ + Validate the model with weighted boundary-aware loss. + + Parameters + ---------- + model : nn.Module + The neural network model to validate. + dataloader : DataLoader + DataLoader providing validation batches (X, y, base_pred, weights). + device : torch.device + Device to run validation on. + max_chroma_tensor : Tensor + Lookup tensor for maximum valid chroma. + output_params : dict + Output normalization parameters. + boundary_penalty_weight : float + Weight for boundary violation penalty. + + Returns + ------- + float + Average total loss. + float + Average precision-focused loss component. + float + Average boundary loss component. + """ + model.eval() + total_loss_sum = 0.0 + base_loss_sum = 0.0 + boundary_loss_sum = 0.0 + + with torch.no_grad(): + for X_batch, y_batch, base_pred_batch, w_batch in dataloader: + X_batch = X_batch.to(device) + y_batch = y_batch.to(device) + base_pred_batch = base_pred_batch.to(device) + w_batch = w_batch.to(device) + + outputs = model(X_batch) + total_loss, base_loss, boundary_loss = weighted_boundary_error_loss( + outputs, y_batch, base_pred_batch, max_chroma_tensor, output_params, + w_batch, boundary_penalty_weight + ) + + total_loss_sum += total_loss.item() + base_loss_sum += base_loss.item() + boundary_loss_sum += boundary_loss.item() + + n = len(dataloader) + return total_loss_sum / n, base_loss_sum / n, boundary_loss_sum / n + + +@click.command() +@click.option( + "--epochs", + type=int, + default=300, + help="Number of training epochs", +) +@click.option( + "--batch-size", + type=int, + default=2048, + help="Batch size for training", +) +@click.option( + "--lr", + type=float, + default=3e-4, + help="Learning rate", +) +@click.option( + "--patience", + type=int, + default=30, + help="Patience for early stopping", +) +def main(epochs: int, batch_size: int, lr: float, patience: int) -> None: + """ + Train the weighted + boundary-aware error predictor. + + This script trains an error predictor that applies: + 1. Weighted Training: Higher loss weights for problem regions (Y/GY/G hues, + value >= 8, chroma >= 12) to focus learning on challenging cases. + 2. Boundary-Aware Loss: Penalty when corrected chroma prediction exceeds + the maximum valid chroma for the (hue_code, value) combination. + + The error predictor is trained on top of the weighted boundary base model + for two-stage inference: base_model → error_predictor. + + Notes + ----- + The training pipeline: + 1. Loads pre-trained weighted boundary base model + 2. Computes sample weights for problem region upweighting + 3. Generates base model predictions for training data + 4. Computes residual errors between predictions and targets + 5. Builds max chroma lookup tensor from real Munsell data + 6. Trains with combined precision-focused + boundary penalty loss + 7. Learning rate scheduling with ReduceLROnPlateau + 8. Early stopping based on validation loss + 9. Exports model to ONNX format + 10. Logs metrics and artifacts to MLflow + + This is the best-performing model (Multi-Head W+B + Multi-Error Predictor W+B), + achieving Delta-E 0.50 on REAL Munsell colors. + """ + + LOGGER.info("=" * 80) + LOGGER.info("Multi-Head Error Predictor: Weighted + Boundary-Aware Loss") + LOGGER.info("=" * 80) + + # Hyperparameters + problem_region_weight = 3.0 + boundary_penalty_weight = 0.5 + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if torch.backends.mps.is_available(): + device = torch.device("mps") + LOGGER.info("Using device: %s", device) + + # Paths + model_directory = PROJECT_ROOT / "models" / "from_xyY" + data_dir = PROJECT_ROOT / "data" + + base_model_path = model_directory / "multi_head_weighted_boundary.onnx" + params_path = model_directory / "multi_head_weighted_boundary_normalization_params.npz" + cache_file = data_dir / "training_data_large.npz" + + if not cache_file.exists(): + LOGGER.error("Error: Large training data not found at %s", cache_file) + return + + if not base_model_path.exists(): + LOGGER.error("Error: Base model not found at %s", base_model_path) + return + + # Load base model + LOGGER.info("") + LOGGER.info("Loading weighted boundary base model from %s...", base_model_path) + base_session, input_params, output_params = load_base_model(base_model_path, params_path) + + # Load training data + LOGGER.info("Loading large training data from %s...", cache_file) + data = np.load(cache_file) + X_train = data["X_train"] + y_train = data["y_train"] + X_val = data["X_val"] + y_val = data["y_val"] + + LOGGER.info("Train samples: %d", len(X_train)) + LOGGER.info("Validation samples: %d", len(X_val)) + + # Compute sample weights + LOGGER.info("") + LOGGER.info("Computing sample weights (problem region weight: %.1f)...", + problem_region_weight) + train_weights = compute_sample_weights(y_train, problem_region_weight) + val_weights = compute_sample_weights(y_val, problem_region_weight) + + n_problem_train = np.sum(train_weights > 1.0) + LOGGER.info(" Train problem region samples: %d (%.2f%%)", + n_problem_train, 100 * n_problem_train / len(y_train)) + + # Generate base model predictions + LOGGER.info("") + LOGGER.info("Generating base model predictions...") + X_train_norm = normalize_xyY(X_train, input_params) + y_train_norm = normalize_munsell(y_train, output_params) + + batch_size = 50000 + base_pred_train_norm = [] + for i in range(0, len(X_train_norm), batch_size): + batch = X_train_norm[i : i + batch_size] + pred = base_session.run(None, {"xyY": batch})[0] + base_pred_train_norm.append(pred) + base_pred_train_norm = np.concatenate(base_pred_train_norm, axis=0) + + X_val_norm = normalize_xyY(X_val, input_params) + y_val_norm = normalize_munsell(y_val, output_params) + + base_pred_val_norm = [] + for i in range(0, len(X_val_norm), batch_size): + batch = X_val_norm[i : i + batch_size] + pred = base_session.run(None, {"xyY": batch})[0] + base_pred_val_norm.append(pred) + base_pred_val_norm = np.concatenate(base_pred_val_norm, axis=0) + + # Compute errors (in normalized space) + error_train = y_train_norm - base_pred_train_norm + error_val = y_val_norm - base_pred_val_norm + + LOGGER.info("") + LOGGER.info("Base model error statistics (normalized space):") + LOGGER.info(" Mean absolute error: %.6f", np.mean(np.abs(error_train))) + LOGGER.info(" Std of error: %.6f", np.std(error_train)) + + # Create combined input: [xyY_norm, base_prediction_norm] + X_train_combined = np.concatenate([X_train_norm, base_pred_train_norm], axis=1) + X_val_combined = np.concatenate([X_val_norm, base_pred_val_norm], axis=1) + + # Convert to PyTorch tensors + X_train_t = torch.FloatTensor(X_train_combined) + error_train_t = torch.FloatTensor(error_train) + base_pred_train_t = torch.FloatTensor(base_pred_train_norm) + w_train_t = torch.FloatTensor(train_weights) + + X_val_t = torch.FloatTensor(X_val_combined) + error_val_t = torch.FloatTensor(error_val) + base_pred_val_t = torch.FloatTensor(base_pred_val_norm) + w_val_t = torch.FloatTensor(val_weights) + + # Create data loaders (include base predictions for boundary loss) + train_dataset = TensorDataset(X_train_t, error_train_t, base_pred_train_t, w_train_t) + val_dataset = TensorDataset(X_val_t, error_val_t, base_pred_val_t, w_val_t) + + train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) + val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False) + + # Initialize model + model = MultiHeadErrorPredictorToMunsell().to(device) + LOGGER.info("") + LOGGER.info("Multi-Head error predictor architecture:") + + total_params = sum(p.numel() for p in model.parameters()) + LOGGER.info("Total parameters: %s", f"{total_params:,}") + + # Build max chroma lookup tensor + max_chroma_tensor = build_max_chroma_tensor(device) + LOGGER.info("") + LOGGER.info("Built max chroma lookup table for boundary-aware loss") + LOGGER.info("Boundary penalty weight: %.2f", boundary_penalty_weight) + LOGGER.info("Problem region weight: %.1f", problem_region_weight) + + # Training setup + optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5) + scheduler = optim.lr_scheduler.ReduceLROnPlateau( + optimizer, mode="min", factor=0.5, patience=10 + ) + + # MLflow setup + run_name = setup_mlflow_experiment( + "from_xyY", "multi_head_weighted_boundary_multi_error_predictor_weighted_boundary" + ) + + LOGGER.info("") + LOGGER.info("MLflow run: %s", run_name) + + # Training loop + best_val_loss = float("inf") + patience_counter = 0 + + LOGGER.info("") + LOGGER.info("Starting training...") + + with mlflow.start_run(run_name=run_name): + mlflow.log_params({ + "model": "multi_head_weighted_boundary_multi_error_predictor_weighted_boundary", + "num_epochs": epochs, + "batch_size": batch_size, + "learning_rate": lr, + "weight_decay": 1e-5, + "optimizer": "AdamW", + "scheduler": "ReduceLROnPlateau", + "patience": patience, + "total_params": total_params, + "train_samples": len(X_train), + "val_samples": len(X_val), + "dataset": "large_2M", + "problem_region_weight": problem_region_weight, + "boundary_penalty_weight": boundary_penalty_weight, + "base_model": "multi_head_weighted_boundary", + }) + + for epoch in range(epochs): + train_total, train_base, train_boundary = train_epoch( + model, train_loader, optimizer, device, + max_chroma_tensor, output_params, boundary_penalty_weight + ) + val_total, val_base, val_boundary = validate( + model, val_loader, device, + max_chroma_tensor, output_params, boundary_penalty_weight + ) + + scheduler.step(val_total) + + log_training_epoch( + epoch, train_total, val_total, optimizer.param_groups[0]["lr"] + ) + mlflow.log_metrics({ + "train_base": train_base, + "train_boundary": train_boundary, + "val_base": val_base, + "val_boundary": val_boundary, + }, step=epoch) + + LOGGER.info( + "Epoch %03d/%d - Train: %.6f (base=%.6f, bnd=%.6f) | " + "Val: %.6f (base=%.6f, bnd=%.6f) | LR: %.6f", + epoch + 1, epochs, + train_total, train_base, train_boundary, + val_total, val_base, val_boundary, + optimizer.param_groups[0]["lr"], + ) + + if val_total < best_val_loss: + best_val_loss = val_total + patience_counter = 0 + + model_directory.mkdir(exist_ok=True) + checkpoint_file = ( + model_directory / "multi_head_weighted_boundary_multi_error_predictor_weighted_boundary_best.pth" + ) + + torch.save({ + "model_state_dict": model.state_dict(), + "epoch": epoch, + "val_loss": val_total, + }, checkpoint_file) + + LOGGER.info(" → Saved best model (val_loss: %.6f)", val_total) + else: + patience_counter += 1 + if patience_counter >= patience: + LOGGER.info("") + LOGGER.info("Early stopping after %d epochs", epoch + 1) + break + + mlflow.log_metrics({ + "best_val_loss": best_val_loss, + "final_epoch": epoch + 1, + }) + + # Export to ONNX + LOGGER.info("") + LOGGER.info("Exporting error predictor to ONNX...") + model.eval() + + checkpoint = torch.load(checkpoint_file, weights_only=False) + model.load_state_dict(checkpoint["model_state_dict"]) + + dummy_input = torch.randn(1, 7).to(device) + + onnx_file = model_directory / "multi_head_weighted_boundary_multi_error_predictor_weighted_boundary.onnx" + torch.onnx.export( + model, + dummy_input, + onnx_file, + export_params=True, + opset_version=15, + input_names=["combined_input"], + output_names=["error_correction"], + dynamic_axes={ + "combined_input": {0: "batch_size"}, + "error_correction": {0: "batch_size"}, + }, + ) + + LOGGER.info("Error predictor ONNX model saved to: %s", onnx_file) + + mlflow.log_artifact(str(checkpoint_file)) + mlflow.log_artifact(str(onnx_file)) + mlflow.pytorch.log_model(model, "model") + + LOGGER.info("=" * 80) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO, format="%(message)s", force=True) + + main() diff --git a/learning_munsell/training/from_xyY/train_multi_mlp.py b/learning_munsell/training/from_xyY/train_multi_mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..30047d2439b5941f0bae1ec050001a19a66823dc --- /dev/null +++ b/learning_munsell/training/from_xyY/train_multi_mlp.py @@ -0,0 +1,275 @@ +""" +Train multi-MLP model for xyY to Munsell conversion. + +Architecture: +- 4 independent MLP branches (one per component) +- Each branch: 3 inputs → encoder → decoder → 1 output +- Hue branch: standard size +- Value branch: standard size +- Chroma branch: WIDER (2x capacity for hardest component) +- Code branch: standard size + +Complete independence allows maximum component specialization. +""" + +import logging + +import click +import mlflow +import mlflow.pytorch +import numpy as np +import torch +from torch import nn, optim +from torch.utils.data import DataLoader, TensorDataset + +from learning_munsell import PROJECT_ROOT +from learning_munsell.models.networks import MultiMLPToMunsell +from learning_munsell.utilities.common import log_training_epoch, setup_mlflow_experiment +from learning_munsell.utilities.data import ( + MUNSELL_NORMALIZATION_PARAMS, + XYY_NORMALIZATION_PARAMS, + normalize_munsell, +) +from learning_munsell.utilities.losses import weighted_mse_loss +from learning_munsell.utilities.training import train_epoch, validate + +LOGGER = logging.getLogger(__name__) + + +@click.command() +@click.option("--epochs", default=200, help="Number of training epochs") +@click.option("--batch-size", default=1024, help="Batch size for training") +@click.option("--lr", default=3.41e-4, help="Learning rate") +@click.option("--patience", default=20, help="Early stopping patience") +def main(epochs: int, batch_size: int, lr: float, patience: int) -> None: + """ + Train the MultiMLPToMunsell model for xyY to Munsell conversion. + + Notes + ----- + The training pipeline: + 1. Loads training data from cache + 2. Normalizes outputs to [0, 1] range + 3. Creates PyTorch DataLoaders + 4. Initializes MultiMLPToMunsell with 4 independent branches + 5. Trains with Adam optimizer and weighted MSE loss + 6. Implements early stopping based on validation loss + 7. Exports best model to ONNX format + 8. Logs all metrics and artifacts to MLflow + """ + + + LOGGER.info("=" * 80) + LOGGER.info("ML-Based xyY to Munsell Conversion: Multi-MLP Model Training") + LOGGER.info("=" * 80) + + # Set device + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + LOGGER.info("Using device: %s", device) + + # Load training data + data_dir = PROJECT_ROOT / "data" + cache_file = data_dir / "training_data.npz" + + if not cache_file.exists(): + LOGGER.error("Error: Training data not found at %s", cache_file) + LOGGER.error("Please run 01_generate_training_data.py first") + return + + LOGGER.info("Loading training data from %s...", cache_file) + data = np.load(cache_file) + + X_train = data["X_train"] + y_train = data["y_train"] + X_val = data["X_val"] + y_val = data["y_val"] + + LOGGER.info("Train samples: %d", len(X_train)) + LOGGER.info("Validation samples: %d", len(X_val)) + + # Normalize outputs (xyY inputs are already in [0, 1] range) + # Use hardcoded ranges covering the full Munsell space for generalization + output_params = MUNSELL_NORMALIZATION_PARAMS + y_train_norm = normalize_munsell(y_train, output_params) + y_val_norm = normalize_munsell(y_val, output_params) + + # Convert to PyTorch tensors + X_train_t = torch.FloatTensor(X_train) + y_train_t = torch.FloatTensor(y_train_norm) + X_val_t = torch.FloatTensor(X_val) + y_val_t = torch.FloatTensor(y_val_norm) + + # Create data loaders + train_dataset = TensorDataset(X_train_t, y_train_t) + val_dataset = TensorDataset(X_val_t, y_val_t) + + train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) + val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False) + + # Initialize model + model = MultiMLPToMunsell(chroma_width_multiplier=2.0, dropout=0.1).to(device) + LOGGER.info("") + LOGGER.info("Model architecture:") + LOGGER.info("%s", model) + + # Count parameters + total_params = sum(p.numel() for p in model.parameters()) + LOGGER.info("Total parameters: %s", f"{total_params:,}") + + # Count parameters per branch + hue_params = sum(p.numel() for p in model.hue_branch.parameters()) + value_params = sum(p.numel() for p in model.value_branch.parameters()) + chroma_params = sum(p.numel() for p in model.chroma_branch.parameters()) + code_params = sum(p.numel() for p in model.code_branch.parameters()) + + LOGGER.info(" - Hue branch: %s", f"{hue_params:,}") + LOGGER.info(" - Value branch: %s", f"{value_params:,}") + LOGGER.info(" - Chroma branch: %s (WIDER 1.5x)", f"{chroma_params:,}") + LOGGER.info(" - Code branch: %s", f"{code_params:,}") + + # Training setup + # Optimized learning rate from hyperparameter search (Trial 2) + optimizer = optim.Adam(model.parameters(), lr=lr) + # Use weighted MSE with optimized weights from hyperparameter search (Trial 2) + weights = torch.tensor([1.0, 1.0, 5.0, 0.4]) # Chroma: 5.0, Code: 0.4 + criterion = lambda pred, target: weighted_mse_loss(pred, target, weights) + + # MLflow setup + run_name = setup_mlflow_experiment("from_xyY", "multi_mlp") + + LOGGER.info("") + LOGGER.info("MLflow run: %s", run_name) + LOGGER.info( + "View experiments with: mlflow ui --backend-store-uri %s", + PROJECT_ROOT / "mlruns", + ) + + # Training loop + best_val_loss = float("inf") + patience_counter = 0 + + LOGGER.info("") + LOGGER.info("Starting training...") + + with mlflow.start_run(run_name=run_name): + # Log parameters + mlflow.log_params( + { + "model": "multi_mlp", + "learning_rate": lr, + "batch_size": batch_size, + "num_epochs": epochs, + "patience": patience, + "dropout": 0.1, + "chroma_width_multiplier": 2.0, + "total_params": total_params, + } + ) + + for epoch in range(epochs): + train_loss = train_epoch(model, train_loader, optimizer, criterion, device) + val_loss = validate(model, val_loader, criterion, device) + + # Log to MLflow + log_training_epoch( + epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"] + ) + + LOGGER.info( + "Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f", + epoch + 1, + epochs, + train_loss, + val_loss, + ) + + # Early stopping + if val_loss < best_val_loss: + best_val_loss = val_loss + patience_counter = 0 + + # Save best model + model_directory = PROJECT_ROOT / "models" / "from_xyY" + model_directory.mkdir(exist_ok=True) + checkpoint_file = model_directory / "multi_mlp_best.pth" + + torch.save( + { + "model_state_dict": model.state_dict(), + "output_params": output_params, + "epoch": epoch, + "val_loss": val_loss, + }, + checkpoint_file, + ) + + LOGGER.info(" → Saved best model (val_loss: %.6f)", val_loss) + else: + patience_counter += 1 + if patience_counter >= patience: + LOGGER.info("") + LOGGER.info("Early stopping after %d epochs", epoch + 1) + break + + # Log final metrics + mlflow.log_metrics( + { + "best_val_loss": best_val_loss, + "final_epoch": epoch + 1, + } + ) + + # Export to ONNX + LOGGER.info("") + LOGGER.info("Exporting model to ONNX...") + model.eval() + + # Load best model + checkpoint = torch.load(checkpoint_file) + model.load_state_dict(checkpoint["model_state_dict"]) + + # Create dummy input + dummy_input = torch.randn(1, 3).to(device) + + # Export + onnx_file = model_directory / "multi_mlp.onnx" + torch.onnx.export( + model, + dummy_input, + onnx_file, + export_params=True, + opset_version=15, + input_names=["xyY"], + output_names=["munsell_spec"], + dynamic_axes={"xyY": {0: "batch_size"}, "munsell_spec": {0: "batch_size"}}, + ) + + # Save normalization parameters alongside model + params_file = model_directory / "multi_mlp_normalization_params.npz" + input_params = XYY_NORMALIZATION_PARAMS + np.savez( + params_file, + input_params=input_params, + output_params=output_params, + ) + + # Log artifacts to MLflow + mlflow.log_artifact(str(checkpoint_file)) + mlflow.log_artifact(str(onnx_file)) + mlflow.log_artifact(str(params_file)) + + # Log the model + mlflow.pytorch.log_model(model, "model") + + LOGGER.info("ONNX model saved to: %s", onnx_file) + LOGGER.info("Normalization parameters saved to: %s", params_file) + LOGGER.info("Artifacts logged to MLflow") + + LOGGER.info("=" * 80) + + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO, format="%(message)s", force=True) + + main() diff --git a/learning_munsell/training/from_xyY/train_multi_mlp_jax.py b/learning_munsell/training/from_xyY/train_multi_mlp_jax.py new file mode 100644 index 0000000000000000000000000000000000000000..9c07c4c5a31f82e0aa731004473fabb1cbc2207b --- /dev/null +++ b/learning_munsell/training/from_xyY/train_multi_mlp_jax.py @@ -0,0 +1,796 @@ +""" +Train multi-MLP model for xyY to Munsell conversion using JAX/Flax. + +Uses differentiable Delta-E CIE2000 loss via round-trip: +1. Predict Munsell from xyY (main network) +2. Convert predicted Munsell → xyY (using trained approximator) +3. Compute Delta-E between reconstructed xyY and input xyY + +Architecture: +- 4 independent MLP branches (Hue, Value, Chroma, Code) +- Chroma branch is 2x wider (hardest component) +- Uses LayerNorm for simplicity (no batch statistics) +""" + +from __future__ import annotations + +import logging +from typing import Any + +import flax.linen as nn +import jax +import jax.numpy as jnp +import mlflow +import numpy as np +import optax +from flax.training import train_state +from jax import Array +from numpy.typing import NDArray + +from learning_munsell import PROJECT_ROOT +from learning_munsell.losses.jax_delta_e import delta_E_loss_jit +from learning_munsell.utilities.common import setup_mlflow_experiment +from learning_munsell.utilities.data import ( + MUNSELL_NORMALIZATION_PARAMS, + XYY_NORMALIZATION_PARAMS, + normalize_munsell, +) + +LOGGER = logging.getLogger(__name__) + + +# ============================================================================ +# Munsell → xyY Approximator (for round-trip loss) +# ============================================================================ + + +class MunsellToXYYApproximator(nn.Module): + """ + Small MLP to approximate Munsell to xyY conversion in JAX/Flax. + + This approximator is used for computing differentiable Delta-E CIE2000 loss + via round-trip conversion: xyY → Munsell → xyY → Delta-E. + + Attributes + ---------- + hidden_dims : tuple of int + Hidden layer dimensions, default is (128, 256, 128). + + Notes + ----- + Uses LayerNorm and SiLU activations for stable training. + Input is normalized Munsell [hue, value, chroma, code]. + Output is approximate xyY values. + """ + + hidden_dims: tuple[int, ...] = (128, 256, 128) + + @nn.compact + def __call__(self, munsell: Array) -> Array: + """ + Apply Munsell to xyY approximation. + + Parameters + ---------- + munsell : Array + Normalized Munsell specification of shape (batch_size, 4). + + Returns + ------- + Array + Approximate xyY values of shape (batch_size, 3). + """ + x = munsell + for dim in self.hidden_dims: + x = nn.Dense(dim)(x) + x = nn.LayerNorm()(x) + x = nn.silu(x) + return nn.Dense(3)(x) # Output: x, y, Y + + +def load_munsell_to_xyY_approximator() -> np.lib.npyio.NpzFile: + """ + Load pre-trained Munsell to xyY approximator weights. + + Returns + ------- + np.lib.npyio.NpzFile + NPZ file containing pre-trained network weights. + + Raises + ------ + FileNotFoundError + If the approximator weights file is not found. + + Notes + ----- + The approximator must be trained first using + train_munsell_to_xyY_approximator.py before this model can be trained. + """ + weights_path = ( + PROJECT_ROOT / "models" / "to_xyY" / "munsell_to_xyY_approximator_weights.npz" + ) + + if not weights_path.exists(): + msg = ( + f"Munsell → xyY approximator not found at {weights_path}. " + "Please train it first with train_munsell_to_xyY_approximator.py" + ) + raise FileNotFoundError(msg) + + LOGGER.info("Loading Munsell → xyY approximator from %s", weights_path) + return np.load(weights_path) + + +# ============================================================================ +# xyY → Munsell Network +# ============================================================================ + + +class ComponentMLP(nn.Module): + """ + Independent MLP for predicting a single Munsell component in JAX/Flax. + + A deep feedforward network with LayerNorm and ReLU activations. The width + can be scaled via width_multiplier to create wider networks for harder + components like chroma. + + Attributes + ---------- + width_multiplier : float + Multiplier for all hidden layer widths, default is 1.0. + Use 2.0 for chroma which requires more capacity. + + Notes + ----- + Architecture: 3 → 128 → 256 → 512 → 256 → 128 → 1 + (All dimensions scaled by width_multiplier) + Input is xyY values. Output is a single Munsell component value. + """ + + width_multiplier: float = 1.0 + + @nn.compact + def __call__(self, x: Array) -> Array: + """ + Predict single Munsell component. + + Parameters + ---------- + x : Array + Input xyY values of shape (batch_size, 3). + + Returns + ------- + Array + Predicted component value of shape (batch_size, 1). + """ + h1 = int(128 * self.width_multiplier) + h2 = int(256 * self.width_multiplier) + h3 = int(512 * self.width_multiplier) + + # Encoder + x = nn.Dense(h1)(x) + x = nn.relu(x) + x = nn.LayerNorm()(x) + + x = nn.Dense(h2)(x) + x = nn.relu(x) + x = nn.LayerNorm()(x) + + x = nn.Dense(h3)(x) + x = nn.relu(x) + x = nn.LayerNorm()(x) + + # Decoder + x = nn.Dense(h2)(x) + x = nn.relu(x) + x = nn.LayerNorm()(x) + + x = nn.Dense(h1)(x) + x = nn.relu(x) + x = nn.LayerNorm()(x) + + # Output + return nn.Dense(1)(x) + + +class MultiMLPMunsell(nn.Module): + """ + Multi-MLP for xyY to Munsell conversion in JAX/Flax with Delta-E loss. + + Uses 4 independent MLP branches, one for each Munsell component. This + architecture allows component-specific learning and is trained with a + differentiable Delta-E CIE2000 loss via round-trip conversion. + + Attributes + ---------- + hue_branch : ComponentMLP + MLP for hue prediction (1.0x width). + value_branch : ComponentMLP + MLP for value prediction (1.0x width). + chroma_branch : ComponentMLP + MLP for chroma prediction (2.0x width). + code_branch : ComponentMLP + MLP for hue code prediction (1.0x width). + + Notes + ----- + Trained with combined MSE + Delta-E loss where Delta-E is computed via: + xyY → Munsell (this model) → xyY (approximator) → Delta-E + + This enables end-to-end differentiable training with perceptual color + difference metrics. The chroma branch is wider (2.0x) to handle the + hardest component prediction task. + """ + + @nn.compact + def __call__(self, x: Array) -> Array: + """ + Predict Munsell specification from xyY. + + Parameters + ---------- + x : Array + Input xyY values of shape (batch_size, 3). + + Returns + ------- + Array + Predicted Munsell specification [hue, value, chroma, code] + of shape (batch_size, 4). + """ + hue = ComponentMLP(width_multiplier=1.0)(x) + value = ComponentMLP(width_multiplier=1.0)(x) + chroma = ComponentMLP(width_multiplier=2.0)(x) # WIDER + code = ComponentMLP(width_multiplier=1.0)(x) + return jnp.concatenate([hue, value, chroma, code], axis=-1) + + +# ============================================================================ +# Data Loading and Normalization +# ============================================================================ + + +def load_training_data() -> tuple[NDArray, NDArray, NDArray, NDArray]: + """ + Load training data from cache. + + Returns + ------- + tuple + (X_train, y_train, X_val, y_val) where: + - X_train: Training xyY values + - y_train: Training Munsell specifications + - X_val: Validation xyY values + - y_val: Validation Munsell specifications + """ + cache_file = PROJECT_ROOT / "data" / "training_data.npz" + + LOGGER.info("Loading training data from %s...", cache_file) + data = np.load(cache_file) + + return ( + data["X_train"].astype(np.float32), + data["y_train"].astype(np.float32), + data["X_val"].astype(np.float32), + data["y_val"].astype(np.float32), + ) + + +def denormalize_outputs_jax(y_norm: Array, params: dict) -> Array: + """ + Denormalize Munsell outputs back to original range (JAX version). + + Parameters + ---------- + y_norm : Array + Normalized Munsell specification [hue, value, chroma, code] in [0, 1]. + params : dict + Normalization parameters with 'hue_range', 'value_range', + 'chroma_range', 'code_range' keys. + + Returns + ------- + Array + Denormalized Munsell specification. + """ + hue = ( + y_norm[..., 0] * (params["hue_range"][1] - params["hue_range"][0]) + + params["hue_range"][0] + ) + value = ( + y_norm[..., 1] * (params["value_range"][1] - params["value_range"][0]) + + params["value_range"][0] + ) + chroma = ( + y_norm[..., 2] * (params["chroma_range"][1] - params["chroma_range"][0]) + + params["chroma_range"][0] + ) + code = ( + y_norm[..., 3] * (params["code_range"][1] - params["code_range"][0]) + + params["code_range"][0] + ) + return jnp.stack([hue, value, chroma, code], axis=-1) + + +def normalize_munsell_for_approximator(munsell: Array) -> Array: + """ + Normalize Munsell specification for the to_xyY approximator. + + Parameters + ---------- + munsell : Array + Denormalized Munsell specification [hue, value, chroma, code]. + + Returns + ------- + Array + Normalized Munsell for approximator input: + [hue/10, value/10, chroma/50, code/10]. + + Notes + ----- + The approximator was trained with this specific normalization scheme, + which differs from the main model's normalization. + """ + # Approximator expects: [hue_in_decade/10, value/10, chroma/50, code/10] + hue_norm = munsell[..., 0] / 10.0 + value_norm = munsell[..., 1] / 10.0 + chroma_norm = munsell[..., 2] / 50.0 + code_norm = munsell[..., 3] / 10.0 + return jnp.stack([hue_norm, value_norm, chroma_norm, code_norm], axis=-1) + + +# ============================================================================ +# Training +# ============================================================================ + + +def create_train_state( + rng: Array, model: nn.Module, learning_rate: float +) -> train_state.TrainState: + """ + Create initial Flax training state. + + Parameters + ---------- + rng : Array + JAX random key for parameter initialization. + model : nn.Module + Flax model to initialize. + learning_rate : float + Learning rate for AdamW optimizer. + + Returns + ------- + train_state.TrainState + Initial training state with model parameters and optimizer. + """ + params = model.init(rng, jnp.ones((1, 3))) + # Use gradient clipping for stability + tx = optax.chain( + optax.clip_by_global_norm(1.0), + optax.adamw(learning_rate, weight_decay=1e-4), + ) + return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx) + + +def train_multi_mlp_jax( + epochs: int = 200, + batch_size: int = 1024, + lr: float = 1e-4, + delta_e_weight: float = 1.0, + mse_weight: float = 1.0, +) -> tuple[train_state.TrainState, float]: + """ + Train Multi-MLP model using JAX with combined MSE + Delta-E loss. + + This function implements differentiable Delta-E CIE2000 training via + round-trip conversion using a pre-trained Munsell → xyY approximator. + The model is trained end-to-end with perceptual color difference metrics. + + Parameters + ---------- + epochs : int, optional + Number of training epochs, default is 200. + batch_size : int, optional + Training batch size, default is 1024. + lr : float, optional + Learning rate, default is 1e-4. + delta_e_weight : float, optional + Weight for Delta-E loss component, default is 1.0. + mse_weight : float, optional + Weight for MSE loss component, default is 1.0. + + Returns + ------- + tuple + (final_state, best_val_delta_e) where: + - final_state: Final training state with best parameters + - best_val_delta_e: Best validation Delta-E achieved + + Notes + ----- + The combined loss is: + L = mse_weight * MSE + delta_e_weight * Delta-E + + Delta-E is computed via round-trip: + 1. Predict Munsell from xyY (this model) + 2. Convert Munsell → xyY (using approximator) + 3. Compute Delta-E between reconstructed and original xyY + """ + + LOGGER.info("=" * 80) + LOGGER.info("Multi-MLP JAX with Delta-E CIE2000 Loss") + LOGGER.info("=" * 80) + LOGGER.info("") + LOGGER.info("JAX devices: %s", jax.devices()) + + # Load Munsell → xyY approximator for round-trip loss + approx_weights = load_munsell_to_xyY_approximator() + + # Initialize approximator + approx_model = MunsellToXYYApproximator() + approx_rng = jax.random.PRNGKey(42) + approx_params = approx_model.init(approx_rng, jnp.ones((1, 4))) + + # Load weights into approximator (manual mapping from PyTorch format) + approx_params = { + "params": { + "Dense_0": { + "kernel": jnp.array(approx_weights["net_0_weight"].T), + "bias": jnp.array(approx_weights["net_0_bias"]), + }, + "LayerNorm_0": { + "scale": jnp.array(approx_weights["net_1_weight"]), + "bias": jnp.array(approx_weights["net_1_bias"]), + }, + "Dense_1": { + "kernel": jnp.array(approx_weights["net_3_weight"].T), + "bias": jnp.array(approx_weights["net_3_bias"]), + }, + "LayerNorm_1": { + "scale": jnp.array(approx_weights["net_4_weight"]), + "bias": jnp.array(approx_weights["net_4_bias"]), + }, + "Dense_2": { + "kernel": jnp.array(approx_weights["net_6_weight"].T), + "bias": jnp.array(approx_weights["net_6_bias"]), + }, + "LayerNorm_2": { + "scale": jnp.array(approx_weights["net_7_weight"]), + "bias": jnp.array(approx_weights["net_7_bias"]), + }, + "Dense_3": { + "kernel": jnp.array(approx_weights["net_9_weight"].T), + "bias": jnp.array(approx_weights["net_9_bias"]), + }, + } + } + + @jax.jit + def apply_approximator(munsell_norm: Array) -> Array: + return approx_model.apply(approx_params, munsell_norm) + + # Load data + X_train, y_train, X_val, y_val = load_training_data() + + # Normalize outputs to [0, 1] range + output_params = MUNSELL_NORMALIZATION_PARAMS + y_train_norm = normalize_munsell(y_train, output_params) + y_val_norm = normalize_munsell(y_val, output_params) + + # Input params for saving (xyY values used as-is) + input_params = XYY_NORMALIZATION_PARAMS + + # Convert to JAX arrays (xyY inputs used directly without normalization) + X_train_jax = jnp.array(X_train) + y_train_jax = jnp.array(y_train_norm) + X_val_jax = jnp.array(X_val) + y_val_jax = jnp.array(y_val_norm) + + # Keep xyY for Delta-E computation (clipped to valid range for color math) + X_train_xyY = jnp.array(np.clip(X_train, 0.0, 1.0)) + X_val_xyY = jnp.array(np.clip(X_val, 0.0, 1.0)) + + LOGGER.info("") + LOGGER.info("Training samples: %d", len(X_train)) + LOGGER.info("Validation samples: %d", len(X_val)) + LOGGER.info("") + LOGGER.info("Loss weights:") + LOGGER.info(" MSE weight: %.2f", mse_weight) + LOGGER.info(" Delta-E weight: %.2f", delta_e_weight) + + # Create model and state + model = MultiMLPMunsell() + rng = jax.random.PRNGKey(0) + state = create_train_state(rng, model, lr) + + # Count parameters + param_count = sum(x.size for x in jax.tree.leaves(state.params)) + LOGGER.info("Model parameters: %s", f"{param_count:,}") + + # Component weights for MSE [hue, value, chroma, code] + mse_weights = jnp.array([1.0, 1.0, 5.0, 0.4]) + + def loss_fn( + params: dict[str, Any], + batch_x_norm: Array, + batch_x_xyY: Array, + batch_y_norm: Array, + ) -> Array: + pred_norm = model.apply(params, batch_x_norm) + + loss = jnp.array(0.0) + + # MSE loss on normalized outputs + if mse_weight > 0: + mse = jnp.mean((pred_norm - batch_y_norm) ** 2 * mse_weights) + loss = loss + mse_weight * mse + + # Delta-E loss via round-trip + if delta_e_weight > 0: + # Denormalize predictions + pred_munsell = denormalize_outputs_jax(pred_norm, output_params) + # Clamp values to valid ranges to avoid NaN + pred_munsell = jnp.stack( + [ + jnp.clip(pred_munsell[..., 0], 0.5, 10.0), # Hue + jnp.clip(pred_munsell[..., 1], 0.0, 10.0), # Value + jnp.clip(pred_munsell[..., 2], 0.0, 50.0), # Chroma + jnp.clip(pred_munsell[..., 3], 1.0, 10.0), # Code + ], + axis=-1, + ) + # Normalize for approximator + pred_munsell_approx_norm = normalize_munsell_for_approximator(pred_munsell) + # Convert to xyY + pred_xyY = apply_approximator(pred_munsell_approx_norm) + # Clamp xyY to valid ranges + pred_xyY = jnp.clip(pred_xyY, 0.001, 1.0) + # Delta-E loss + delta_e = delta_E_loss_jit(pred_xyY, batch_x_xyY) + loss = loss + delta_e_weight * delta_e + + return loss + + @jax.jit + def train_step( + state: train_state.TrainState, + batch_x_norm: Array, + batch_x_xyY: Array, + batch_y_norm: Array, + ) -> tuple[train_state.TrainState, Array]: + loss, grads = jax.value_and_grad(loss_fn)( + state.params, batch_x_norm, batch_x_xyY, batch_y_norm + ) + state = state.apply_gradients(grads=grads) + return state, loss + + @jax.jit + def eval_step( + params: dict[str, Any], + batch_x_norm: Array, + batch_x_xyY: Array, + batch_y_norm: Array, + ) -> tuple[Array, Array, Array]: + pred_norm = model.apply(params, batch_x_norm) + + # MSE + mse = jnp.mean((pred_norm - batch_y_norm) ** 2 * mse_weights) + + # Delta-E + pred_munsell = denormalize_outputs_jax(pred_norm, output_params) + pred_munsell = jnp.stack( + [ + jnp.clip(pred_munsell[..., 0], 0.5, 10.0), + jnp.clip(pred_munsell[..., 1], 0.0, 10.0), + jnp.clip(pred_munsell[..., 2], 0.0, 50.0), + jnp.clip(pred_munsell[..., 3], 1.0, 10.0), + ], + axis=-1, + ) + pred_munsell_approx_norm = normalize_munsell_for_approximator(pred_munsell) + pred_xyY = apply_approximator(pred_munsell_approx_norm) + pred_xyY = jnp.clip(pred_xyY, 0.001, 1.0) + delta_e = delta_E_loss_jit(pred_xyY, batch_x_xyY) + + # Per-component MAE (in normalized space) + mae = jnp.mean(jnp.abs(pred_norm - batch_y_norm), axis=0) + + return mse, delta_e, mae + + # MLflow setup + run_name = setup_mlflow_experiment("from_xyY", "multi_mlp_jax_delta_e") + LOGGER.info("") + LOGGER.info("MLflow run: %s", run_name) + + # Training loop + best_val_delta_e = float("inf") + best_params = None + patience = 20 + patience_counter = 0 + n_batches = len(X_train) // batch_size + + LOGGER.info("") + LOGGER.info("Starting training...") + + with mlflow.start_run(run_name=run_name): + mlflow.log_params({ + "model": "multi_mlp_jax_delta_e", + "learning_rate": lr, + "batch_size": batch_size, + "num_epochs": epochs, + "patience": patience, + "total_params": param_count, + "train_samples": len(X_train), + "val_samples": len(X_val), + "mse_weight": mse_weight, + "delta_e_weight": delta_e_weight, + "framework": "JAX/Flax", + }) + + for epoch in range(epochs): + # Shuffle training data + rng, shuffle_rng = jax.random.split(rng) + perm = jax.random.permutation(shuffle_rng, len(X_train_jax)) + X_train_shuffled = X_train_jax[perm] + X_train_xyY_shuffled = X_train_xyY[perm] + y_train_shuffled = y_train_jax[perm] + + # Training + train_loss = 0.0 + for i in range(n_batches): + start = i * batch_size + end = start + batch_size + batch_x_norm = X_train_shuffled[start:end] + batch_x_xyY = X_train_xyY_shuffled[start:end] + batch_y_norm = y_train_shuffled[start:end] + + state, loss = train_step(state, batch_x_norm, batch_x_xyY, batch_y_norm) + train_loss += float(loss) + + train_loss /= n_batches + + # Validation + val_mse, val_delta_e, val_mae = eval_step( + state.params, X_val_jax, X_val_xyY, y_val_jax + ) + val_mse = float(val_mse) + val_delta_e = float(val_delta_e) + val_mae = np.array(val_mae) + + # Log to MLflow + mlflow.log_metrics({ + "train_loss": train_loss, + "val_mse": val_mse, + "val_delta_e": val_delta_e, + "mae_hue": val_mae[0], + "mae_value": val_mae[1], + "mae_chroma": val_mae[2], + "mae_code": val_mae[3], + }, step=epoch) + + # Log every epoch with consistent format + is_best = val_delta_e < best_val_delta_e + best_marker = " (best)" if is_best else "" + + LOGGER.info( + "Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f " + "(mse=%.6f, dE=%.4f)%s, LR: %.6f", + epoch + 1, + epochs, + train_loss, + val_mse + val_delta_e, + val_mse, + val_delta_e, + best_marker, + lr, + ) + + if is_best: + best_val_delta_e = val_delta_e + best_params = jax.tree.map(lambda x: x.copy(), state.params) + patience_counter = 0 + else: + patience_counter += 1 + + if patience_counter >= patience: + LOGGER.info("Early stopping at epoch %d", epoch + 1) + break + + # Final evaluation with best params + val_mse, val_delta_e, val_mae = eval_step( + best_params, X_val_jax, X_val_xyY, y_val_jax + ) + val_mae = np.array(val_mae) + + LOGGER.info("") + LOGGER.info("Final Results:") + LOGGER.info(" Best Val Delta-E: %.4f", best_val_delta_e) + LOGGER.info(" Val MSE: %.6f", float(val_mse)) + LOGGER.info(" MAE Hue: %.4f", val_mae[0]) + LOGGER.info(" MAE Value: %.4f", val_mae[1]) + LOGGER.info(" MAE Chroma: %.4f", val_mae[2]) + LOGGER.info(" MAE Code: %.4f", val_mae[3]) + + mlflow.log_metrics({ + "best_val_delta_e": best_val_delta_e, + "final_val_mse": float(val_mse), + "final_mae_hue": val_mae[0], + "final_mae_value": val_mae[1], + "final_mae_chroma": val_mae[2], + "final_mae_code": val_mae[3], + }) + + # Save model weights + models_dir = PROJECT_ROOT / "models" / "from_xyY" + models_dir.mkdir(exist_ok=True) + + weights_path = models_dir / "multi_mlp_jax_delta_e.npz" + flat_params = {} + for key, value in jax.tree_util.tree_leaves_with_path(best_params): + path_str = "_".join(str(k.key) for k in key) + flat_params[path_str] = np.array(value) + + np.savez( + weights_path, + **flat_params, + metadata=np.array( + [ + f"val_delta_e={best_val_delta_e:.4f}", + f"val_mse={float(val_mse):.6f}", + f"mae_hue={val_mae[0]:.4f}", + f"mae_value={val_mae[1]:.4f}", + f"mae_chroma={val_mae[2]:.4f}", + f"mae_code={val_mae[3]:.4f}", + ] + ), + ) + LOGGER.info("") + LOGGER.info("Saved weights: %s", weights_path) + + # Save normalization parameters + params_path = models_dir / "multi_mlp_jax_delta_e_normalization_params.npz" + np.savez( + params_path, + input_params=input_params, + output_params=output_params, + ) + LOGGER.info("Saved normalization params: %s", params_path) + + mlflow.log_artifact(str(weights_path)) + mlflow.log_artifact(str(params_path)) + + LOGGER.info("=" * 80) + + return state, best_val_delta_e + + +def main() -> None: + """ + Train multi-MLP JAX model with Delta-E loss. + + Notes + ----- + The training pipeline: + 1. Loads pre-trained Munsell → xyY approximator for round-trip loss + 2. Loads training and validation data from cache + 3. Normalizes inputs and outputs to [0, 1] range + 4. Trains multi-MLP with combined MSE + Delta-E CIE2000 loss + 5. Uses gradient clipping and AdamW optimizer for stability + 6. Early stopping based on validation Delta-E + 7. Saves best model weights and normalization parameters + + The combined loss enables end-to-end differentiable training with + perceptual color difference metrics via round-trip conversion. + """ + # Combined MSE + Delta-E loss with lower learning rate for stability + train_multi_mlp_jax(mse_weight=1.0, delta_e_weight=1.0, lr=1e-4) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO, format="%(message)s", force=True) + + main() diff --git a/learning_munsell/training/from_xyY/train_multi_mlp_large.py b/learning_munsell/training/from_xyY/train_multi_mlp_large.py new file mode 100644 index 0000000000000000000000000000000000000000..ef247368b5da4acba300ba90242342b6fdb7f55b --- /dev/null +++ b/learning_munsell/training/from_xyY/train_multi_mlp_large.py @@ -0,0 +1,246 @@ +""" +Train multi-MLP model on large dataset (2M samples) for xyY to Munsell conversion. + +This script trains on the larger dataset for potentially improved accuracy. +Uses the same architecture as train_multi_mlp.py but with the large dataset. +""" + +import logging +import click +import mlflow +import mlflow.pytorch +import numpy as np +import torch +from numpy.typing import NDArray +from torch import nn, optim +from torch.utils.data import DataLoader, TensorDataset + +from learning_munsell import PROJECT_ROOT +from learning_munsell.models.networks import MultiMLPToMunsell +from learning_munsell.utilities.common import log_training_epoch, setup_mlflow_experiment +from learning_munsell.utilities.data import ( + MUNSELL_NORMALIZATION_PARAMS, + XYY_NORMALIZATION_PARAMS, + normalize_munsell, +) +from learning_munsell.utilities.losses import weighted_mse_loss +from learning_munsell.utilities.training import train_epoch, validate + +LOGGER = logging.getLogger(__name__) + + +@click.command() +@click.option("--epochs", default=300, help="Number of training epochs") +@click.option("--batch-size", default=2048, help="Batch size for training") +@click.option("--lr", default=3.41e-4, help="Learning rate") +@click.option("--patience", default=30, help="Early stopping patience") +def main(epochs: int, batch_size: int, lr: float, patience: int) -> None: + """ + Train multi-MLP model on large dataset (2M samples) for xyY to Munsell. + + Notes + ----- + The training pipeline: + 1. Loads training and validation data from large cached .npz file + 2. Normalizes xyY inputs (already [0,1]) and Munsell outputs to [0,1] + 3. Creates multi-MLP with 4 independent component branches + 4. Trains with weighted MSE loss (emphasizing chroma 5x) + 5. Uses Adam optimizer with ReduceLROnPlateau scheduler + 6. Applies early stopping based on validation loss (patience=30) + 7. Exports best model to ONNX format + 8. Logs metrics and artifacts to MLflow + """ + + LOGGER.info("=" * 80) + LOGGER.info("Multi-MLP Model Training on Large Dataset (2M samples)") + LOGGER.info("=" * 80) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if torch.backends.mps.is_available(): + device = torch.device("mps") + LOGGER.info("Using device: %s", device) + + # Load large training data + data_dir = PROJECT_ROOT / "data" + cache_file = data_dir / "training_data_large.npz" + + if not cache_file.exists(): + LOGGER.error("Error: Large training data not found at %s", cache_file) + LOGGER.error("Please run generate_large_training_data.py first") + return + + LOGGER.info("Loading large training data from %s...", cache_file) + data = np.load(cache_file) + + X_train = data["X_train"] + y_train = data["y_train"] + X_val = data["X_val"] + y_val = data["y_val"] + + LOGGER.info("Train samples: %d", len(X_train)) + LOGGER.info("Validation samples: %d", len(X_val)) + + # Normalize outputs (xyY inputs are already in [0, 1] range) + output_params = MUNSELL_NORMALIZATION_PARAMS + y_train_norm = normalize_munsell(y_train, output_params) + y_val_norm = normalize_munsell(y_val, output_params) + + # Convert to PyTorch tensors + X_train_t = torch.FloatTensor(X_train) + y_train_t = torch.FloatTensor(y_train_norm) + X_val_t = torch.FloatTensor(X_val) + y_val_t = torch.FloatTensor(y_val_norm) + + # Create data loaders (larger batch size for larger dataset) + train_dataset = TensorDataset(X_train_t, y_train_t) + val_dataset = TensorDataset(X_val_t, y_val_t) + + train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) + val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False) + + # Initialize model + model = MultiMLPToMunsell().to(device) + LOGGER.info("") + LOGGER.info("Model architecture:") + LOGGER.info("%s", model) + + total_params = sum(p.numel() for p in model.parameters()) + LOGGER.info("Total parameters: %s", f"{total_params:,}") + + # Training setup + learning_rate = lr + optimizer = optim.Adam(model.parameters(), lr=learning_rate) + scheduler = optim.lr_scheduler.ReduceLROnPlateau( + optimizer, mode="min", factor=0.5, patience=10 + ) + # Use weighted MSE with default weights + weights = torch.tensor([1.0, 1.0, 5.0, 0.4]) + criterion = lambda pred, target: weighted_mse_loss(pred, target, weights) + + # MLflow setup + run_name = setup_mlflow_experiment("from_xyY", "multi_mlp_large") + + LOGGER.info("") + LOGGER.info("MLflow run: %s", run_name) + + # Training loop + best_val_loss = float("inf") + patience_counter = 0 + + LOGGER.info("") + LOGGER.info("Starting training...") + + with mlflow.start_run(run_name=run_name): + mlflow.log_params( + { + "model": "multi_mlp_large", + "learning_rate": learning_rate, + "batch_size": batch_size, + "num_epochs": epochs, + "patience": patience, + "total_params": total_params, + "train_samples": len(X_train), + "val_samples": len(X_val), + "dataset": "large_2M", + } + ) + + for epoch in range(epochs): + train_loss = train_epoch(model, train_loader, optimizer, criterion, device) + val_loss = validate(model, val_loader, criterion, device) + + scheduler.step(val_loss) + + log_training_epoch( + epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"] + ) + + LOGGER.info( + "Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f, LR: %.6f", + epoch + 1, + epochs, + train_loss, + val_loss, + optimizer.param_groups[0]["lr"], + ) + + if val_loss < best_val_loss: + best_val_loss = val_loss + patience_counter = 0 + + model_directory = PROJECT_ROOT / "models" / "from_xyY" + model_directory.mkdir(exist_ok=True) + checkpoint_file = model_directory / "multi_mlp_large_best.pth" + + torch.save( + { + "model_state_dict": model.state_dict(), + "output_params": output_params, + "epoch": epoch, + "val_loss": val_loss, + }, + checkpoint_file, + ) + + LOGGER.info(" → Saved best model (val_loss: %.6f)", val_loss) + else: + patience_counter += 1 + if patience_counter >= patience: + LOGGER.info("") + LOGGER.info("Early stopping after %d epochs", epoch + 1) + break + + mlflow.log_metrics( + { + "best_val_loss": best_val_loss, + "final_epoch": epoch + 1, + } + ) + + # Export to ONNX + LOGGER.info("") + LOGGER.info("Exporting model to ONNX...") + model.eval() + + checkpoint = torch.load(checkpoint_file, weights_only=False) + model.load_state_dict(checkpoint["model_state_dict"]) + + dummy_input = torch.randn(1, 3).to(device) + + model_directory = PROJECT_ROOT / "models" / "from_xyY" + onnx_file = model_directory / "multi_mlp_large.onnx" + torch.onnx.export( + model, + dummy_input, + onnx_file, + export_params=True, + opset_version=15, + input_names=["xyY"], + output_names=["munsell_spec"], + dynamic_axes={"xyY": {0: "batch_size"}, "munsell_spec": {0: "batch_size"}}, + ) + + params_file = model_directory / "multi_mlp_large_normalization_params.npz" + input_params = XYY_NORMALIZATION_PARAMS + np.savez( + params_file, + input_params=input_params, + output_params=output_params, + ) + + mlflow.log_artifact(str(checkpoint_file)) + mlflow.log_artifact(str(onnx_file)) + mlflow.log_artifact(str(params_file)) + mlflow.pytorch.log_model(model, "model") + + LOGGER.info("ONNX model saved to: %s", onnx_file) + LOGGER.info("Normalization parameters saved to: %s", params_file) + LOGGER.info("Artifacts logged to MLflow") + + LOGGER.info("=" * 80) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO, format="%(message)s", force=True) + + main() diff --git a/learning_munsell/training/from_xyY/train_multi_mlp_multi_error_predictor.py b/learning_munsell/training/from_xyY/train_multi_mlp_multi_error_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..d07983e9abe5a547d32b9f3c6771da705b08d78f --- /dev/null +++ b/learning_munsell/training/from_xyY/train_multi_mlp_multi_error_predictor.py @@ -0,0 +1,474 @@ +""" +Train Multi-MLP error predictor for Multi-MLP base model. + +Architecture: +- 4 independent error correction branches (one per component) +- Each branch: 7 inputs (xyY + base_pred) → encoder → decoder → 1 error output + +Two configurations available: +1. Default (non-optimized): + - lr: 3e-4, batch_size: 1024, chroma_width: 1.5, precision_focused_loss + +2. Optimized (--optimized flag): + - lr: 0.0008, batch_size: 512, dropout: 0.15, pure MSE loss + - All branches use standard width (1.0x) +""" + +import logging +from pathlib import Path +from typing import Any + +import click +import mlflow +import mlflow.pytorch +import numpy as np +import onnxruntime as ort +import torch +from numpy.typing import NDArray +from torch import nn, optim +from torch.utils.data import DataLoader, TensorDataset + +from learning_munsell import PROJECT_ROOT +from learning_munsell.models.networks import MultiMLPErrorPredictorToMunsell +from learning_munsell.utilities.common import log_training_epoch, setup_mlflow_experiment +from learning_munsell.utilities.data import normalize_xyY, normalize_munsell +from learning_munsell.utilities.training import train_epoch, validate + +LOGGER = logging.getLogger(__name__) + + +def mse_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + Pure mean squared error loss function. + + Hyperparameter search (Optuna) showed that pure MSE outperforms complex + loss combinations for this multi-branch error predictor architecture. + + Parameters + ---------- + pred : torch.Tensor + Predicted error corrections of shape (batch_size, 4). + target : torch.Tensor + Target error corrections of shape (batch_size, 4). + + Returns + ------- + torch.Tensor + Scalar MSE loss value. + + Notes + ----- + This simple loss was found to be optimal through hyperparameter search, + suggesting that the independent branch architecture naturally handles + component-specific error characteristics without needing complex loss terms. + """ + return torch.mean((pred - target) ** 2) + + +def precision_focused_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + Compute precision-focused loss for small residual errors. + + Combines multiple loss terms to heavily penalize small errors, + which is important for achieving sub-JND (Just Noticeable Difference) + accuracy in color prediction. + + Parameters + ---------- + pred : Tensor + Predicted values of shape (batch_size, n_components). + target : Tensor + Target values of shape (batch_size, n_components). + + Returns + ------- + Tensor + Scalar loss value. + + Notes + ----- + The loss combines four components: + - MSE: Standard mean squared error (weight 2.0) + - Huber: Small beta (0.01) for precision on small errors (weight 0.5) + - Log penalty: Penalizes small errors heavily (weight 0.1) + """ + mse = torch.mean((pred - target) ** 2) + huber = torch.nn.functional.smooth_l1_loss(pred, target, beta=0.01) + log_penalty = torch.mean(torch.log1p(torch.abs(pred - target))) + return 2.0 * mse + 0.5 * huber + 0.1 * log_penalty + + +def load_base_model( + model_path: Path, params_path: Path +) -> tuple[ort.InferenceSession, dict, dict]: + """ + Load the Multi-MLP base ONNX model and normalization parameters. + + Parameters + ---------- + model_path : Path + Path to the Multi-MLP base model ONNX file. + params_path : Path + Path to the .npz file containing normalization parameters. + + Returns + ------- + session : ort.InferenceSession + ONNX Runtime inference session for the base model. + input_params : dict + Dictionary containing input normalization ranges. + output_params : dict + Dictionary containing output normalization ranges. + """ + session = ort.InferenceSession(str(model_path)) + params = np.load(params_path, allow_pickle=True) + return session, params["input_params"].item(), params["output_params"].item() + + +@click.command() +@click.option( + "--base-model", + type=click.Path(exists=True, path_type=Path), + default=None, + help="Path to Multi-MLP base model ONNX file", +) +@click.option( + "--params", + type=click.Path(exists=True, path_type=Path), + default=None, + help="Path to normalization params file", +) +@click.option( + "--optimized", + is_flag=True, + default=False, + help="Use optimized hyperparameters (Optuna results)", +) +@click.option( + "--epochs", + type=int, + default=200, + help="Number of training epochs", +) +@click.option( + "--batch-size", + type=int, + default=None, + help="Batch size for training (default: 1024 or 512 if optimized)", +) +@click.option( + "--lr", + type=float, + default=None, + help="Learning rate (default: 3e-4 or 5e-4 if optimized)", +) +@click.option( + "--patience", + type=int, + default=20, + help="Patience for early stopping", +) +def main( + base_model: Path | None, + params: Path | None, + optimized: bool, + epochs: int, + batch_size: int | None, + lr: float | None, + patience: int, +) -> None: + """ + Train Multi-MLP error predictor with 4 independent branches. + + Parameters + ---------- + base_model : Path or None + Path to Multi-MLP base model ONNX file. Uses default if None. + params : Path or None + Path to normalization parameters file. Uses default if None. + optimized : bool + If True, use optimized hyperparameters from Optuna search. + If False, use original configuration. + + Notes + ----- + The training pipeline: + 1. Loads pre-trained base model + 2. Generates base model predictions for training data + 3. Computes residual errors between predictions and targets + 4. Trains error predictor on these residuals + 5. Uses precision-focused loss function (or MSE if optimized) + 6. Learning rate scheduling with ReduceLROnPlateau + 7. Early stopping based on validation loss + 8. Exports model to ONNX format + 9. Logs metrics and artifacts to MLflow + """ + + + config_name = "Optimized" if optimized else "Original" + LOGGER.info("=" * 80) + LOGGER.info("Multi-MLP Multi-Error Predictor: %s Configuration", config_name) + LOGGER.info("=" * 80) + + # Set device + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + LOGGER.info("Using device: %s", device) + + # Paths + model_directory = PROJECT_ROOT / "models" / "from_xyY" + data_dir = PROJECT_ROOT / "data" + + if base_model is None: + base_model = model_directory / "multi_mlp.onnx" + if params is None: + params = model_directory / "multi_mlp_normalization_params.npz" + + cache_file = data_dir / "training_data.npz" + + # Load base model + LOGGER.info("") + LOGGER.info("Loading Multi-MLP base model from %s...", base_model) + base_session, input_params, output_params = load_base_model(base_model, params) + + # Load training data + LOGGER.info("Loading training data from %s...", cache_file) + data = np.load(cache_file) + X_train = data["X_train"] + y_train = data["y_train"] + X_val = data["X_val"] + y_val = data["y_val"] + + LOGGER.info("Train samples: %d", len(X_train)) + LOGGER.info("Validation samples: %d", len(X_val)) + + # Generate base model predictions + LOGGER.info("") + LOGGER.info("Generating Multi-MLP base model predictions...") + X_train_norm = normalize_xyY(X_train, input_params) + y_train_norm = normalize_munsell(y_train, output_params) + + # Base predictions (normalized) + base_pred_train_norm = base_session.run(None, {"xyY": X_train_norm})[0] + + X_val_norm = normalize_xyY(X_val, input_params) + y_val_norm = normalize_munsell(y_val, output_params) + base_pred_val_norm = base_session.run(None, {"xyY": X_val_norm})[0] + + # Compute errors (in normalized space) + error_train = y_train_norm - base_pred_train_norm + error_val = y_val_norm - base_pred_val_norm + + # Statistics + LOGGER.info("") + LOGGER.info("Multi-MLP base model error statistics (normalized space):") + LOGGER.info(" Mean absolute error: %.6f", np.mean(np.abs(error_train))) + LOGGER.info(" Std of error: %.6f", np.std(error_train)) + LOGGER.info(" Max absolute error: %.6f", np.max(np.abs(error_train))) + + # Create combined input: [xyY_norm, base_prediction_norm] + X_train_combined = np.concatenate([X_train_norm, base_pred_train_norm], axis=1) + X_val_combined = np.concatenate([X_val_norm, base_pred_val_norm], axis=1) + + # Convert to PyTorch tensors + X_train_t = torch.FloatTensor(X_train_combined) + error_train_t = torch.FloatTensor(error_train) + X_val_t = torch.FloatTensor(X_val_combined) + error_val_t = torch.FloatTensor(error_val) + + # Create data loaders + train_dataset = TensorDataset(X_train_t, error_train_t) + val_dataset = TensorDataset(X_val_t, error_val_t) + + # Configuration based on --optimized flag + if optimized: + default_batch_size = 512 + default_lr = 5e-4 # Reduced from 8e-4 for stability + dropout = 0.15 + chroma_width = 1.0 + criterion = mse_loss + suffix = "_optimized" + else: + default_batch_size = 1024 + default_lr = 3e-4 + dropout = 0.0 + chroma_width = 1.5 + criterion = precision_focused_loss + suffix = "" + + # Use provided values or defaults + batch_size = batch_size if batch_size is not None else default_batch_size + lr = lr if lr is not None else default_lr + + train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) + val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False) + + # Initialize Multi-MLP error predictor + model = MultiMLPErrorPredictorToMunsell(chroma_width=chroma_width).to(device) + LOGGER.info("") + LOGGER.info("Multi-MLP error predictor architecture:") + LOGGER.info("%s", model) + + # Count parameters + total_params = sum(p.numel() for p in model.parameters()) + LOGGER.info("Total parameters: %s", f"{total_params:,}") + + # Count parameters per branch + hue_params = sum(p.numel() for p in model.hue_branch.parameters()) + value_params = sum(p.numel() for p in model.value_branch.parameters()) + chroma_params = sum(p.numel() for p in model.chroma_branch.parameters()) + code_params = sum(p.numel() for p in model.code_branch.parameters()) + + LOGGER.info(" - Hue branch: %s", f"{hue_params:,}") + LOGGER.info(" - Value branch: %s", f"{value_params:,}") + LOGGER.info(" - Chroma branch: %s", f"{chroma_params:,}") + LOGGER.info(" - Code branch: %s", f"{code_params:,}") + + # Training setup + LOGGER.info("") + loss_name = "pure MSE" if optimized else "precision_focused_loss" + LOGGER.info("Using %s loss", loss_name) + LOGGER.info( + "Hyperparameters: lr=%.4f, batch_size=%d, dropout=%.2f, chroma_width=%.1f", + lr, + batch_size, + dropout, + chroma_width, + ) + + optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5) + scheduler = optim.lr_scheduler.ReduceLROnPlateau( + optimizer, mode="min", factor=0.5, patience=5 + ) + + # MLflow setup + model_name = f"multi_mlp_error_predictor{suffix}" + run_name = setup_mlflow_experiment("from_xyY", model_name) + + LOGGER.info("") + LOGGER.info("MLflow run: %s", run_name) + + # Training loop + best_val_loss = float("inf") + patience_counter = 0 + + LOGGER.info("") + LOGGER.info("Starting training...") + + with mlflow.start_run(run_name=run_name): + # Log parameters + mlflow.log_params( + { + "model": model_name, + "learning_rate": lr, + "batch_size": batch_size, + "num_epochs": epochs, + "patience": patience, + "dropout": dropout, + "chroma_width": chroma_width, + "optimized": optimized, + "total_params": total_params, + } + ) + + for epoch in range(epochs): + train_loss = train_epoch(model, train_loader, optimizer, criterion, device) + val_loss = validate(model, val_loader, criterion, device) + + # Update learning rate + scheduler.step(val_loss) + + # Log to MLflow + log_training_epoch( + epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"] + ) + + LOGGER.info( + "Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f, LR: %.6f", + epoch + 1, + epochs, + train_loss, + val_loss, + optimizer.param_groups[0]["lr"], + ) + + # Early stopping + if val_loss < best_val_loss: + best_val_loss = val_loss + patience_counter = 0 + + # Save best model + model_directory.mkdir(exist_ok=True) + checkpoint_file = ( + model_directory + / f"multi_mlp_multi_error_predictor{suffix}_best.pth" + ) + + torch.save( + { + "model_state_dict": model.state_dict(), + "epoch": epoch, + "val_loss": val_loss, + }, + checkpoint_file, + ) + + LOGGER.info(" → Saved best model (val_loss: %.6f)", val_loss) + else: + patience_counter += 1 + if patience_counter >= patience: + LOGGER.info("") + LOGGER.info("Early stopping after %d epochs", epoch + 1) + break + + # Log final metrics + mlflow.log_metrics( + { + "best_val_loss": best_val_loss, + "final_epoch": epoch + 1, + } + ) + + # Export to ONNX + LOGGER.info("") + LOGGER.info("Exporting Multi-MLP error predictor to ONNX...") + model.eval() + + # Load best model + checkpoint = torch.load(checkpoint_file) + model.load_state_dict(checkpoint["model_state_dict"]) + + # Create dummy input (xyY_norm + base_pred_norm = 7 inputs) + dummy_input = torch.randn(1, 7).to(device) + + # Export + onnx_file = model_directory / f"multi_mlp_multi_error_predictor{suffix}.onnx" + torch.onnx.export( + model, + dummy_input, + onnx_file, + export_params=True, + opset_version=15, + input_names=["combined_input"], + output_names=["error_correction"], + dynamic_axes={ + "combined_input": {0: "batch_size"}, + "error_correction": {0: "batch_size"}, + }, + ) + + # Log artifacts to MLflow + mlflow.log_artifact(str(checkpoint_file)) + mlflow.log_artifact(str(onnx_file)) + mlflow.pytorch.log_model(model, "model") + + LOGGER.info("Multi-MLP error predictor ONNX model saved to: %s", onnx_file) + LOGGER.info("Artifacts logged to MLflow") + + + LOGGER.info("=" * 80) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO, format="%(message)s", force=True) + + main() diff --git a/learning_munsell/training/from_xyY/train_multi_mlp_multi_error_predictor_large.py b/learning_munsell/training/from_xyY/train_multi_mlp_multi_error_predictor_large.py new file mode 100644 index 0000000000000000000000000000000000000000..2a251953df1e1ca81588d2cb6339736325efe4fb --- /dev/null +++ b/learning_munsell/training/from_xyY/train_multi_mlp_multi_error_predictor_large.py @@ -0,0 +1,338 @@ +""" +Train Multi-MLP error predictor on large dataset (2M samples). + +This script trains the error predictor on the larger dataset for potentially +improved accuracy. Uses the Multi-MLP (Large) base model. +""" + +import logging +from pathlib import Path +from typing import Any + +import click +import mlflow +import mlflow.pytorch +import numpy as np +import onnxruntime as ort +import torch +from numpy.typing import NDArray +from torch import nn, optim +from torch.utils.data import DataLoader, TensorDataset + +from learning_munsell import PROJECT_ROOT +from learning_munsell.models.networks import MultiMLPErrorPredictorToMunsell +from learning_munsell.utilities.common import log_training_epoch, setup_mlflow_experiment +from learning_munsell.utilities.data import normalize_xyY, normalize_munsell +from learning_munsell.utilities.losses import precision_focused_loss +from learning_munsell.utilities.training import train_epoch, validate + +LOGGER = logging.getLogger(__name__) + + +def load_base_model( + model_path: Path, params_path: Path +) -> tuple[ort.InferenceSession, dict, dict]: + """ + Load Multi-MLP (Large) base ONNX model and normalization parameters. + + Parameters + ---------- + model_path : Path + Path to ONNX model file. + params_path : Path + Path to normalization parameters file. + + Returns + ------- + session : ort.InferenceSession + ONNX Runtime inference session. + input_params : dict + Input normalization ranges. + output_params : dict + Output normalization ranges. + """ + session = ort.InferenceSession(str(model_path)) + params = np.load(params_path, allow_pickle=True) + return session, params["input_params"].item(), params["output_params"].item() + + +@click.command() +@click.option( + "--epochs", + type=int, + default=300, + help="Number of training epochs (default: 300)", +) +@click.option( + "--batch-size", + type=int, + default=2048, + help="Batch size for training (default: 2048)", +) +@click.option( + "--lr", + type=float, + default=3e-4, + help="Learning rate (default: 3e-4)", +) +@click.option( + "--patience", + type=int, + default=30, + help="Early stopping patience (default: 30)", +) +def main(epochs: int, batch_size: int, lr: float, patience: int) -> None: + """ + Train Multi-MLP error predictor on large dataset (2M samples). + + Notes + ----- + The training pipeline: + 1. Loads pre-trained base model + 2. Generates base model predictions for training data + 3. Computes residual errors between predictions and targets + 4. Trains error predictor on these residuals + 5. Uses precision-focused loss function + 6. Learning rate scheduling with ReduceLROnPlateau + 7. Early stopping based on validation loss + 8. Exports model to ONNX format + 9. Logs metrics and artifacts to MLflow + """ + + LOGGER.info("=" * 80) + LOGGER.info("Multi-MLP Error Predictor Training on Large Dataset (2M samples)") + LOGGER.info("=" * 80) + + # Set device + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if torch.backends.mps.is_available(): + device = torch.device("mps") + LOGGER.info("Using device: %s", device) + + # Paths + model_directory = PROJECT_ROOT / "models" / "from_xyY" + data_dir = PROJECT_ROOT / "data" + + base_model_path = model_directory / "multi_mlp_large.onnx" + params_path = model_directory / "multi_mlp_large_normalization_params.npz" + cache_file = data_dir / "training_data_large.npz" + + # Check base model exists + if not base_model_path.exists(): + LOGGER.error("Error: Base model not found at %s", base_model_path) + LOGGER.error("Please run train_multi_mlp_large.py first") + return + + # Load base model + LOGGER.info("") + LOGGER.info("Loading Multi-MLP (Large) base model from %s...", base_model_path) + base_session, input_params, output_params = load_base_model( + base_model_path, params_path + ) + + # Load training data + LOGGER.info("Loading large training data from %s...", cache_file) + data = np.load(cache_file) + X_train = data["X_train"] + y_train = data["y_train"] + X_val = data["X_val"] + y_val = data["y_val"] + + LOGGER.info("Train samples: %d", len(X_train)) + LOGGER.info("Validation samples: %d", len(X_val)) + + # Generate base model predictions + LOGGER.info("") + LOGGER.info("Generating Multi-MLP (Large) base model predictions...") + X_train_norm = normalize_xyY(X_train, input_params) + y_train_norm = normalize_munsell(y_train, output_params) + + # Process in batches to avoid memory issues + batch_size_inference = 50000 + base_pred_train_list = [] + for i in range(0, len(X_train_norm), batch_size_inference): + batch = X_train_norm[i : i + batch_size_inference] + pred = base_session.run(None, {"xyY": batch})[0] + base_pred_train_list.append(pred) + base_pred_train_norm = np.concatenate(base_pred_train_list, axis=0) + + X_val_norm = normalize_xyY(X_val, input_params) + y_val_norm = normalize_munsell(y_val, output_params) + + base_pred_val_list = [] + for i in range(0, len(X_val_norm), batch_size_inference): + batch = X_val_norm[i : i + batch_size_inference] + pred = base_session.run(None, {"xyY": batch})[0] + base_pred_val_list.append(pred) + base_pred_val_norm = np.concatenate(base_pred_val_list, axis=0) + + # Compute errors (in normalized space) + error_train = y_train_norm - base_pred_train_norm + error_val = y_val_norm - base_pred_val_norm + + # Statistics + LOGGER.info("") + LOGGER.info("Multi-MLP (Large) base model error statistics (normalized space):") + LOGGER.info(" Mean absolute error: %.6f", np.mean(np.abs(error_train))) + LOGGER.info(" Std of error: %.6f", np.std(error_train)) + LOGGER.info(" Max absolute error: %.6f", np.max(np.abs(error_train))) + + # Create combined input: [xyY_norm, base_prediction_norm] + X_train_combined = np.concatenate([X_train_norm, base_pred_train_norm], axis=1) + X_val_combined = np.concatenate([X_val_norm, base_pred_val_norm], axis=1) + + # Convert to PyTorch tensors + X_train_t = torch.FloatTensor(X_train_combined) + error_train_t = torch.FloatTensor(error_train) + X_val_t = torch.FloatTensor(X_val_combined) + error_val_t = torch.FloatTensor(error_val) + + # Create data loaders + train_dataset = TensorDataset(X_train_t, error_train_t) + val_dataset = TensorDataset(X_val_t, error_val_t) + + train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) + val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False) + + # Initialize Multi-MLP error predictor + model = MultiMLPErrorPredictorToMunsell(chroma_width=1.5).to(device) + LOGGER.info("") + LOGGER.info("Multi-MLP error predictor architecture:") + LOGGER.info("%s", model) + + # Count parameters + total_params = sum(p.numel() for p in model.parameters()) + LOGGER.info("Total parameters: %s", f"{total_params:,}") + + # Training setup + criterion = precision_focused_loss + + LOGGER.info("") + LOGGER.info("Using precision_focused_loss") + LOGGER.info("Hyperparameters: lr=%.4f, batch_size=%d", lr, batch_size) + + optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5) + scheduler = optim.lr_scheduler.ReduceLROnPlateau( + optimizer, mode="min", factor=0.5, patience=10 + ) + + # MLflow setup + run_name = setup_mlflow_experiment("from_xyY", "multi_mlp_error_predictor_large") + + LOGGER.info("") + LOGGER.info("MLflow run: %s", run_name) + + # Training loop + best_val_loss = float("inf") + patience_counter = 0 + + LOGGER.info("") + LOGGER.info("Starting training...") + + with mlflow.start_run(run_name=run_name): + mlflow.log_params( + { + "model": "multi_mlp_error_predictor_large", + "learning_rate": lr, + "batch_size": batch_size, + "num_epochs": epochs, + "patience": patience, + "total_params": total_params, + "train_samples": len(X_train), + "val_samples": len(X_val), + "dataset": "large_2M", + } + ) + + for epoch in range(epochs): + train_loss = train_epoch(model, train_loader, optimizer, criterion, device) + val_loss = validate(model, val_loader, criterion, device) + + scheduler.step(val_loss) + + log_training_epoch( + epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"] + ) + + LOGGER.info( + "Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f, LR: %.6f", + epoch + 1, + epochs, + train_loss, + val_loss, + optimizer.param_groups[0]["lr"], + ) + + if val_loss < best_val_loss: + best_val_loss = val_loss + patience_counter = 0 + + model_directory.mkdir(exist_ok=True) + checkpoint_file = ( + model_directory / "multi_mlp_multi_error_predictor_large_best.pth" + ) + + torch.save( + { + "model_state_dict": model.state_dict(), + "epoch": epoch, + "val_loss": val_loss, + "output_params": output_params, + }, + checkpoint_file, + ) + + LOGGER.info(" → Saved best model (val_loss: %.6f)", val_loss) + else: + patience_counter += 1 + if patience_counter >= patience: + LOGGER.info("") + LOGGER.info("Early stopping after %d epochs", epoch + 1) + break + + mlflow.log_metrics( + { + "best_val_loss": best_val_loss, + "final_epoch": epoch + 1, + } + ) + + # Export to ONNX + LOGGER.info("") + LOGGER.info("Exporting Multi-MLP error predictor to ONNX...") + model.eval() + + checkpoint = torch.load(checkpoint_file, weights_only=False) + model.load_state_dict(checkpoint["model_state_dict"]) + + dummy_input = torch.randn(1, 7).to(device) + + onnx_file = model_directory / "multi_mlp_multi_error_predictor_large.onnx" + torch.onnx.export( + model, + dummy_input, + onnx_file, + export_params=True, + opset_version=15, + input_names=["combined_input"], + output_names=["error_correction"], + dynamic_axes={ + "combined_input": {0: "batch_size"}, + "error_correction": {0: "batch_size"}, + }, + ) + + mlflow.log_artifact(str(checkpoint_file)) + mlflow.log_artifact(str(onnx_file)) + mlflow.pytorch.log_model(model, "model") + + LOGGER.info("Multi-MLP error predictor ONNX model saved to: %s", onnx_file) + LOGGER.info("Artifacts logged to MLflow") + + LOGGER.info("=" * 80) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO, format="%(message)s", force=True) + + main() diff --git a/learning_munsell/training/from_xyY/train_multi_mlp_weighted_boundary.py b/learning_munsell/training/from_xyY/train_multi_mlp_weighted_boundary.py new file mode 100644 index 0000000000000000000000000000000000000000..7234cef42979704d4ac369cfd5c2ace77a0e0688 --- /dev/null +++ b/learning_munsell/training/from_xyY/train_multi_mlp_weighted_boundary.py @@ -0,0 +1,466 @@ +""" +Train multi-MLP model with weighted sampling and boundary-aware loss. + +This script combines the Multi-MLP architecture (4 independent MLP +branches) with two approaches to improve performance on problematic +high-value, high-chroma regions (Y/GY/G hues): + +1. Weighted Training: Apply higher loss weights to samples in problem regions + (Y/GY/G hues, value >= 8, chroma >= 12) + +2. Boundary-Aware Loss: Add penalty term when chroma prediction exceeds + maximum valid chroma for the (hue_code, value) combination from real + Munsell gamut boundaries. +""" + +import logging + +import click +import mlflow +import mlflow.pytorch +import numpy as np +import torch +from numpy.typing import NDArray +from torch import nn, optim +from torch.utils.data import DataLoader, TensorDataset + +from learning_munsell import PROJECT_ROOT +from learning_munsell.models.networks import MultiMLPToMunsell +from learning_munsell.utilities.common import log_training_epoch, setup_mlflow_experiment +from learning_munsell.utilities.data import ( + MUNSELL_NORMALIZATION_PARAMS, + XYY_NORMALIZATION_PARAMS, + normalize_munsell, +) + +LOGGER = logging.getLogger(__name__) + +# Maximum valid chroma per (hue_code, value) from MUNSELL_COLOURS_REAL +MUNSELL_MAX_CHROMA = { + 1: {1.0: 8.0, 2.0: 10.0, 3.0: 14.0, 4.0: 16.0, 5.0: 18.0, + 6.0: 16.0, 7.0: 16.0, 8.0: 12.0, 9.0: 4.0}, + 2: {1.0: 8.0, 2.0: 14.0, 3.0: 20.0, 4.0: 24.0, 5.0: 24.0, + 6.0: 22.0, 7.0: 22.0, 8.0: 18.0, 9.0: 10.0}, + 3: {1.0: 8.0, 2.0: 16.0, 3.0: 22.0, 4.0: 26.0, 5.0: 28.0, + 6.0: 28.0, 7.0: 26.0, 8.0: 24.0, 9.0: 16.0}, + 4: {1.0: 6.0, 2.0: 12.0, 3.0: 14.0, 4.0: 16.0, 5.0: 18.0, + 6.0: 20.0, 7.0: 22.0, 8.0: 24.0, 9.0: 18.0}, + 5: {1.0: 2.0, 2.0: 4.0, 3.0: 6.0, 4.0: 10.0, 5.0: 12.0, + 6.0: 14.0, 7.0: 16.0, 8.0: 20.0, 9.0: 20.0}, + 6: {1.0: 8.0, 2.0: 8.0, 3.0: 10.0, 4.0: 12.0, 5.0: 16.0, + 6.0: 18.0, 7.0: 20.0, 8.0: 20.0, 9.0: 8.0}, + 7: {1.0: 10.0, 2.0: 14.0, 3.0: 16.0, 4.0: 20.0, 5.0: 20.0, + 6.0: 18.0, 7.0: 16.0, 8.0: 10.0, 9.0: 6.0}, + 8: {1.0: 16.0, 2.0: 20.0, 3.0: 22.0, 4.0: 26.0, 5.0: 26.0, + 6.0: 24.0, 7.0: 20.0, 8.0: 14.0, 9.0: 6.0}, + 9: {1.0: 26.0, 2.0: 30.0, 3.0: 34.0, 4.0: 32.0, 5.0: 30.0, + 6.0: 26.0, 7.0: 22.0, 8.0: 14.0, 9.0: 6.0}, + 10: {1.0: 38.0, 2.0: 38.0, 3.0: 34.0, 4.0: 30.0, 5.0: 22.0, + 6.0: 16.0, 7.0: 12.0, 8.0: 8.0, 9.0: 4.0}, +} + + +def compute_sample_weights( + y: NDArray, + problem_weight: float = 3.0, +) -> NDArray: + """ + Compute per-sample weights, upweighting problem regions. + + Parameters + ---------- + y : ndarray + Munsell specifications [hue, value, chroma, code] of shape (n, 4). + problem_weight : float, optional + Weight multiplier for problem region samples. + + Returns + ------- + ndarray + Per-sample weights of shape (n,). + """ + weights = np.ones(len(y), dtype=np.float32) + + codes = np.round(y[:, 3]).astype(int) + values = y[:, 1] + chromas = y[:, 2] + + is_problem_hue = np.isin(codes, [3, 4, 5]) + is_high_value = values >= 8.0 + is_high_chroma = chromas >= 12.0 + + problem_mask = is_problem_hue & is_high_value & is_high_chroma + weights[problem_mask] = problem_weight + + return weights + + +def build_max_chroma_tensor(device: torch.device) -> torch.Tensor: + """Build a lookup tensor for maximum valid chroma values.""" + max_chroma_tensor = torch.zeros(11, 10, device=device) + + for code, values in MUNSELL_MAX_CHROMA.items(): + for value, max_chroma in values.items(): + max_chroma_tensor[code, int(value)] = max_chroma + + return max_chroma_tensor + + +def boundary_aware_loss( + pred: torch.Tensor, + target: torch.Tensor, + max_chroma_tensor: torch.Tensor, + output_params: dict, + component_weights: torch.Tensor, + sample_weights: torch.Tensor, + boundary_penalty_weight: float = 0.5, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute loss with boundary-aware chroma penalty.""" + chroma_range = output_params["chroma_range"] + value_range = output_params["value_range"] + code_range = output_params["code_range"] + + pred_chroma = pred[:, 2] * (chroma_range[1] - chroma_range[0]) + chroma_range[0] + pred_value = pred[:, 1] * (value_range[1] - value_range[0]) + value_range[0] + pred_code = pred[:, 3] * (code_range[1] - code_range[0]) + code_range[0] + + code_idx = torch.clamp(torch.round(pred_code), 1, 10).long() + value_idx = torch.clamp(torch.round(pred_value), 1, 9).long() + + max_chroma = max_chroma_tensor[code_idx, value_idx] + + chroma_excess = torch.relu(pred_chroma - max_chroma) + boundary_loss = (chroma_excess ** 2).mean() + + mse = (pred - target) ** 2 + weighted_mse = mse * component_weights + + sample_weighted_mse = (weighted_mse.mean(dim=1) * sample_weights).mean() + + total_loss = sample_weighted_mse + boundary_penalty_weight * boundary_loss + + return total_loss, sample_weighted_mse, boundary_loss + + +def train_epoch( + model: nn.Module, + dataloader: DataLoader, + optimizer: optim.Optimizer, + device: torch.device, + max_chroma_tensor: torch.Tensor, + output_params: dict, + component_weights: torch.Tensor, + boundary_penalty_weight: float, +) -> tuple[float, float, float]: + """Train the model for one epoch with boundary-aware loss.""" + model.train() + total_loss_sum = 0.0 + mse_loss_sum = 0.0 + boundary_loss_sum = 0.0 + + for X_batch, y_batch, w_batch in dataloader: + X_batch = X_batch.to(device) + y_batch = y_batch.to(device) + w_batch = w_batch.to(device) + + outputs = model(X_batch) + total_loss, mse_loss, boundary_loss = boundary_aware_loss( + outputs, y_batch, max_chroma_tensor, output_params, + component_weights, w_batch, boundary_penalty_weight + ) + + optimizer.zero_grad() + total_loss.backward() + optimizer.step() + + total_loss_sum += total_loss.item() + mse_loss_sum += mse_loss.item() + boundary_loss_sum += boundary_loss.item() + + n = len(dataloader) + return total_loss_sum / n, mse_loss_sum / n, boundary_loss_sum / n + + +def validate( + model: nn.Module, + dataloader: DataLoader, + device: torch.device, + max_chroma_tensor: torch.Tensor, + output_params: dict, + component_weights: torch.Tensor, + boundary_penalty_weight: float, +) -> tuple[float, float, float]: + """Validate the model with boundary-aware loss.""" + model.eval() + total_loss_sum = 0.0 + mse_loss_sum = 0.0 + boundary_loss_sum = 0.0 + + with torch.no_grad(): + for X_batch, y_batch, w_batch in dataloader: + X_batch = X_batch.to(device) + y_batch = y_batch.to(device) + w_batch = w_batch.to(device) + + outputs = model(X_batch) + total_loss, mse_loss, boundary_loss = boundary_aware_loss( + outputs, y_batch, max_chroma_tensor, output_params, + component_weights, w_batch, boundary_penalty_weight + ) + + total_loss_sum += total_loss.item() + mse_loss_sum += mse_loss.item() + boundary_loss_sum += boundary_loss.item() + + n = len(dataloader) + return total_loss_sum / n, mse_loss_sum / n, boundary_loss_sum / n + + +@click.command() +@click.option("--epochs", default=300, help="Number of training epochs") +@click.option("--batch-size", default=2048, help="Batch size for training") +@click.option("--lr", default=3.41e-4, help="Learning rate") +@click.option("--patience", default=30, help="Early stopping patience") +def main(epochs: int, batch_size: int, lr: float, patience: int) -> None: + """ + Train the Multi-MLP model with weighted sampling and boundary-aware loss. + + Notes + ----- + This script combines: + - Multi-MLP architecture (4 independent MLP branches) + - Weighted training (3x weight for problem regions) + - Boundary-aware loss (penalty for invalid chroma predictions) + - Large dataset training (2M samples) + """ + + LOGGER.info("=" * 80) + LOGGER.info("Multi-MLP Model with Weighted Training + Boundary-Aware Loss") + LOGGER.info("=" * 80) + + problem_region_weight = 3.0 + boundary_penalty_weight = 0.5 + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if torch.backends.mps.is_available(): + device = torch.device("mps") + LOGGER.info("Using device: %s", device) + + # Load large training data + data_dir = PROJECT_ROOT / "data" + cache_file = data_dir / "training_data_large.npz" + + if not cache_file.exists(): + LOGGER.error("Error: Large training data not found at %s", cache_file) + LOGGER.error("Please run generate_large_training_data.py first") + return + + LOGGER.info("Loading large training data from %s...", cache_file) + data = np.load(cache_file) + + X_train = data["X_train"] + y_train = data["y_train"] + X_val = data["X_val"] + y_val = data["y_val"] + + LOGGER.info("Train samples: %d", len(X_train)) + LOGGER.info("Validation samples: %d", len(X_val)) + + # Compute sample weights + LOGGER.info("") + LOGGER.info("Computing sample weights (problem region weight: %.1f)...", + problem_region_weight) + train_weights = compute_sample_weights(y_train, problem_region_weight) + val_weights = compute_sample_weights(y_val, problem_region_weight) + + n_problem_train = np.sum(train_weights > 1.0) + n_problem_val = np.sum(val_weights > 1.0) + LOGGER.info(" Train problem region samples: %d (%.2f%%)", + n_problem_train, 100 * n_problem_train / len(y_train)) + LOGGER.info(" Val problem region samples: %d (%.2f%%)", + n_problem_val, 100 * n_problem_val / len(y_val)) + + # Normalize outputs + output_params = MUNSELL_NORMALIZATION_PARAMS + y_train_norm = normalize_munsell(y_train, output_params) + y_val_norm = normalize_munsell(y_val, output_params) + + # Convert to PyTorch tensors + X_train_t = torch.FloatTensor(X_train) + y_train_t = torch.FloatTensor(y_train_norm) + w_train_t = torch.FloatTensor(train_weights) + X_val_t = torch.FloatTensor(X_val) + y_val_t = torch.FloatTensor(y_val_norm) + w_val_t = torch.FloatTensor(val_weights) + + # Create data loaders with weights + train_dataset = TensorDataset(X_train_t, y_train_t, w_train_t) + val_dataset = TensorDataset(X_val_t, y_val_t, w_val_t) + + train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) + val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False) + + # Initialize Multi-MLP model + model = MultiMLPToMunsell(chroma_width_multiplier=2.0, dropout=0.1).to(device) + LOGGER.info("") + LOGGER.info("Model architecture:") + LOGGER.info("%s", model) + + total_params = sum(p.numel() for p in model.parameters()) + LOGGER.info("Total parameters: %s", f"{total_params:,}") + + # Build max chroma lookup tensor + max_chroma_tensor = build_max_chroma_tensor(device) + LOGGER.info("") + LOGGER.info("Built max chroma lookup table for boundary-aware loss") + LOGGER.info("Boundary penalty weight: %.2f", boundary_penalty_weight) + + # Component weights: [hue, value, chroma, code] + component_weights = torch.tensor([1.0, 1.0, 5.0, 0.4], device=device) + LOGGER.info("Component weights: %s", component_weights.tolist()) + + # Training setup + optimizer = optim.Adam(model.parameters(), lr=lr) + scheduler = optim.lr_scheduler.ReduceLROnPlateau( + optimizer, mode="min", factor=0.5, patience=10 + ) + + # MLflow setup + run_name = setup_mlflow_experiment("from_xyY", "multi_mlp_weighted_boundary") + + LOGGER.info("") + LOGGER.info("MLflow run: %s", run_name) + + # Training loop + best_val_loss = float("inf") + patience_counter = 0 + + LOGGER.info("") + LOGGER.info("Starting training...") + + with mlflow.start_run(run_name=run_name): + mlflow.log_params({ + "model": "multi_mlp_weighted_boundary", + "learning_rate": lr, + "batch_size": batch_size, + "num_epochs": epochs, + "patience": patience, + "total_params": total_params, + "train_samples": len(X_train), + "val_samples": len(X_val), + "dataset": "large_2M", + "problem_region_weight": problem_region_weight, + "boundary_penalty_weight": boundary_penalty_weight, + "component_weights": component_weights.tolist(), + "chroma_width_multiplier": 2.0, + "dropout": 0.1, + }) + + for epoch in range(epochs): + train_total, train_mse, train_boundary = train_epoch( + model, train_loader, optimizer, device, + max_chroma_tensor, output_params, component_weights, + boundary_penalty_weight + ) + val_total, val_mse, val_boundary = validate( + model, val_loader, device, + max_chroma_tensor, output_params, component_weights, + boundary_penalty_weight + ) + + scheduler.step(val_total) + + log_training_epoch( + epoch, train_total, val_total, optimizer.param_groups[0]["lr"] + ) + mlflow.log_metrics({ + "train_mse": train_mse, + "train_boundary": train_boundary, + "val_mse": val_mse, + "val_boundary": val_boundary, + }, step=epoch) + + LOGGER.info( + "Epoch %03d/%d - Train: %.6f (mse=%.6f, bnd=%.6f) | " + "Val: %.6f (mse=%.6f, bnd=%.6f) | LR: %.6f", + epoch + 1, epochs, + train_total, train_mse, train_boundary, + val_total, val_mse, val_boundary, + optimizer.param_groups[0]["lr"], + ) + + if val_total < best_val_loss: + best_val_loss = val_total + patience_counter = 0 + + model_directory = PROJECT_ROOT / "models" / "from_xyY" + model_directory.mkdir(exist_ok=True) + checkpoint_file = model_directory / "multi_mlp_weighted_boundary_best.pth" + + torch.save({ + "model_state_dict": model.state_dict(), + "output_params": output_params, + "epoch": epoch, + "val_loss": val_total, + }, checkpoint_file) + + LOGGER.info(" → Saved best model (val_loss: %.6f)", val_total) + else: + patience_counter += 1 + if patience_counter >= patience: + LOGGER.info("") + LOGGER.info("Early stopping after %d epochs", epoch + 1) + break + + mlflow.log_metrics({ + "best_val_loss": best_val_loss, + "final_epoch": epoch + 1, + }) + + # Export to ONNX + LOGGER.info("") + LOGGER.info("Exporting model to ONNX...") + model.eval() + + checkpoint = torch.load(checkpoint_file, weights_only=False) + model.load_state_dict(checkpoint["model_state_dict"]) + + dummy_input = torch.randn(1, 3).to(device) + + model_directory = PROJECT_ROOT / "models" / "from_xyY" + onnx_file = model_directory / "multi_mlp_weighted_boundary.onnx" + torch.onnx.export( + model, + dummy_input, + onnx_file, + export_params=True, + opset_version=15, + input_names=["xyY"], + output_names=["munsell_spec"], + dynamic_axes={"xyY": {0: "batch_size"}, "munsell_spec": {0: "batch_size"}}, + ) + + params_file = model_directory / "multi_mlp_weighted_boundary_normalization_params.npz" + input_params = XYY_NORMALIZATION_PARAMS + np.savez( + params_file, + input_params=input_params, + output_params=output_params, + ) + + mlflow.log_artifact(str(checkpoint_file)) + mlflow.log_artifact(str(onnx_file)) + mlflow.log_artifact(str(params_file)) + mlflow.pytorch.log_model(model, "model") + + LOGGER.info("ONNX model saved to: %s", onnx_file) + LOGGER.info("Normalization parameters saved to: %s", params_file) + LOGGER.info("Artifacts logged to MLflow") + + LOGGER.info("=" * 80) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO, format="%(message)s", force=True) + + main() diff --git a/learning_munsell/training/from_xyY/train_multi_mlp_weighted_boundary_multi_error_predictor.py b/learning_munsell/training/from_xyY/train_multi_mlp_weighted_boundary_multi_error_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..4b12d02a80fffbf40d64af2159b615fa4be14679 --- /dev/null +++ b/learning_munsell/training/from_xyY/train_multi_mlp_weighted_boundary_multi_error_predictor.py @@ -0,0 +1,511 @@ +""" +Train Multi-MLP error predictor with weighted + boundary-aware loss. + +This extends the error predictor training to also apply: +1. Weighted Training: Higher loss weights for problem regions (Y/GY/G hues, + value >= 8, chroma >= 12) +2. Boundary-Aware Loss: Penalty when corrected chroma prediction exceeds + maximum valid chroma for the (hue_code, value) combination + +Uses the weighted boundary Multi-MLP base model for two-stage inference. +""" + +import logging +from pathlib import Path + +import click +import mlflow +import mlflow.pytorch +import numpy as np +import onnxruntime as ort +import torch +from numpy.typing import NDArray +from torch import nn, optim +from torch.utils.data import DataLoader, TensorDataset + +from learning_munsell import PROJECT_ROOT +from learning_munsell.models.networks import MultiMLPErrorPredictorToMunsell +from learning_munsell.utilities.common import log_training_epoch, setup_mlflow_experiment +from learning_munsell.utilities.data import normalize_xyY, normalize_munsell + +LOGGER = logging.getLogger(__name__) + +# Maximum valid chroma per (hue_code, value) from MUNSELL_COLOURS_REAL +MUNSELL_MAX_CHROMA = { + 1: {1.0: 8.0, 2.0: 10.0, 3.0: 14.0, 4.0: 16.0, 5.0: 18.0, + 6.0: 16.0, 7.0: 16.0, 8.0: 12.0, 9.0: 4.0}, + 2: {1.0: 8.0, 2.0: 14.0, 3.0: 20.0, 4.0: 24.0, 5.0: 24.0, + 6.0: 22.0, 7.0: 22.0, 8.0: 18.0, 9.0: 10.0}, + 3: {1.0: 8.0, 2.0: 16.0, 3.0: 22.0, 4.0: 26.0, 5.0: 28.0, + 6.0: 28.0, 7.0: 26.0, 8.0: 24.0, 9.0: 16.0}, + 4: {1.0: 6.0, 2.0: 12.0, 3.0: 14.0, 4.0: 16.0, 5.0: 18.0, + 6.0: 20.0, 7.0: 22.0, 8.0: 24.0, 9.0: 18.0}, + 5: {1.0: 2.0, 2.0: 4.0, 3.0: 6.0, 4.0: 10.0, 5.0: 12.0, + 6.0: 14.0, 7.0: 16.0, 8.0: 20.0, 9.0: 20.0}, + 6: {1.0: 8.0, 2.0: 8.0, 3.0: 10.0, 4.0: 12.0, 5.0: 16.0, + 6.0: 18.0, 7.0: 20.0, 8.0: 20.0, 9.0: 8.0}, + 7: {1.0: 10.0, 2.0: 14.0, 3.0: 16.0, 4.0: 20.0, 5.0: 20.0, + 6.0: 18.0, 7.0: 16.0, 8.0: 10.0, 9.0: 6.0}, + 8: {1.0: 16.0, 2.0: 20.0, 3.0: 22.0, 4.0: 26.0, 5.0: 26.0, + 6.0: 24.0, 7.0: 20.0, 8.0: 14.0, 9.0: 6.0}, + 9: {1.0: 26.0, 2.0: 30.0, 3.0: 34.0, 4.0: 32.0, 5.0: 30.0, + 6.0: 26.0, 7.0: 22.0, 8.0: 14.0, 9.0: 6.0}, + 10: {1.0: 38.0, 2.0: 38.0, 3.0: 34.0, 4.0: 30.0, 5.0: 22.0, + 6.0: 16.0, 7.0: 12.0, 8.0: 8.0, 9.0: 4.0}, +} + + +def load_base_model( + model_path: Path, params_path: Path +) -> tuple[ort.InferenceSession, dict, dict]: + """Load the base ONNX model and normalization parameters.""" + session = ort.InferenceSession(str(model_path)) + params = np.load(params_path, allow_pickle=True) + return session, params["input_params"].item(), params["output_params"].item() + + +def compute_sample_weights( + y: NDArray, + problem_weight: float = 3.0, +) -> NDArray: + """Compute per-sample weights, upweighting problem regions.""" + weights = np.ones(len(y), dtype=np.float32) + + codes = np.round(y[:, 3]).astype(int) + values = y[:, 1] + chromas = y[:, 2] + + is_problem_hue = np.isin(codes, [3, 4, 5]) + is_high_value = values >= 8.0 + is_high_chroma = chromas >= 12.0 + + problem_mask = is_problem_hue & is_high_value & is_high_chroma + weights[problem_mask] = problem_weight + + return weights + + +def build_max_chroma_tensor(device: torch.device) -> torch.Tensor: + """Build a lookup tensor for maximum valid chroma values.""" + max_chroma_tensor = torch.zeros(11, 10, device=device) + + for code, values in MUNSELL_MAX_CHROMA.items(): + for value, max_chroma in values.items(): + max_chroma_tensor[code, int(value)] = max_chroma + + return max_chroma_tensor + + +def weighted_boundary_error_loss( + error_pred: torch.Tensor, + error_target: torch.Tensor, + base_pred_norm: torch.Tensor, + max_chroma_tensor: torch.Tensor, + output_params: dict, + sample_weights: torch.Tensor, + boundary_penalty_weight: float = 0.5, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute weighted + boundary-aware loss for error predictor.""" + # Compute corrected prediction (normalized) + corrected_pred = base_pred_norm + error_pred + + # Denormalize to get actual Munsell values for boundary check + value_range = output_params["value_range"] + chroma_range = output_params["chroma_range"] + code_range = output_params["code_range"] + + pred_value = corrected_pred[:, 1] * (value_range[1] - value_range[0]) + value_range[0] + pred_chroma = corrected_pred[:, 2] * (chroma_range[1] - chroma_range[0]) + chroma_range[0] + pred_code = corrected_pred[:, 3] * (code_range[1] - code_range[0]) + code_range[0] + + # Round code and value for lookup + code_idx = torch.clamp(torch.round(pred_code), 1, 10).long() + value_idx = torch.clamp(torch.round(pred_value), 1, 9).long() + + # Look up max chroma for each sample + max_chroma = max_chroma_tensor[code_idx, value_idx] + + # Boundary violation penalty + chroma_excess = torch.relu(pred_chroma - max_chroma) + boundary_loss = (chroma_excess ** 2).mean() + + # Precision-focused base loss components + mse = torch.mean((error_pred - error_target) ** 2, dim=1) + mae = torch.mean(torch.abs(error_pred - error_target), dim=1) + log_penalty = torch.mean( + torch.log1p(torch.abs(error_pred - error_target) * 1000.0), dim=1 + ) + + # Huber loss with small delta + delta = 0.01 + abs_error = torch.abs(error_pred - error_target) + huber = torch.where( + abs_error <= delta, 0.5 * abs_error**2, delta * (abs_error - 0.5 * delta) + ) + huber_loss = torch.mean(huber, dim=1) + + # Combine base loss components + base_loss = 1.0 * mse + 0.5 * mae + 0.3 * log_penalty + 0.5 * huber_loss + + # Apply sample weights + weighted_base_loss = (base_loss * sample_weights).mean() + + # Total loss + total_loss = weighted_base_loss + boundary_penalty_weight * boundary_loss + + return total_loss, weighted_base_loss, boundary_loss + + +def train_epoch( + model: nn.Module, + dataloader: DataLoader, + optimizer: optim.Optimizer, + device: torch.device, + max_chroma_tensor: torch.Tensor, + output_params: dict, + boundary_penalty_weight: float, +) -> tuple[float, float, float]: + """Train the model for one epoch with weighted boundary-aware loss.""" + model.train() + total_loss_sum = 0.0 + base_loss_sum = 0.0 + boundary_loss_sum = 0.0 + + for X_batch, y_batch, base_pred_batch, w_batch in dataloader: + X_batch = X_batch.to(device) + y_batch = y_batch.to(device) + base_pred_batch = base_pred_batch.to(device) + w_batch = w_batch.to(device) + + outputs = model(X_batch) + total_loss, base_loss, boundary_loss = weighted_boundary_error_loss( + outputs, y_batch, base_pred_batch, max_chroma_tensor, output_params, + w_batch, boundary_penalty_weight + ) + + optimizer.zero_grad() + total_loss.backward() + optimizer.step() + + total_loss_sum += total_loss.item() + base_loss_sum += base_loss.item() + boundary_loss_sum += boundary_loss.item() + + n = len(dataloader) + return total_loss_sum / n, base_loss_sum / n, boundary_loss_sum / n + + +def validate( + model: nn.Module, + dataloader: DataLoader, + device: torch.device, + max_chroma_tensor: torch.Tensor, + output_params: dict, + boundary_penalty_weight: float, +) -> tuple[float, float, float]: + """Validate the model with weighted boundary-aware loss.""" + model.eval() + total_loss_sum = 0.0 + base_loss_sum = 0.0 + boundary_loss_sum = 0.0 + + with torch.no_grad(): + for X_batch, y_batch, base_pred_batch, w_batch in dataloader: + X_batch = X_batch.to(device) + y_batch = y_batch.to(device) + base_pred_batch = base_pred_batch.to(device) + w_batch = w_batch.to(device) + + outputs = model(X_batch) + total_loss, base_loss, boundary_loss = weighted_boundary_error_loss( + outputs, y_batch, base_pred_batch, max_chroma_tensor, output_params, + w_batch, boundary_penalty_weight + ) + + total_loss_sum += total_loss.item() + base_loss_sum += base_loss.item() + boundary_loss_sum += boundary_loss.item() + + n = len(dataloader) + return total_loss_sum / n, base_loss_sum / n, boundary_loss_sum / n + + +@click.command() +@click.option("--epochs", type=int, default=300, help="Number of training epochs") +@click.option("--batch-size", type=int, default=2048, help="Batch size for training") +@click.option("--lr", type=float, default=3e-4, help="Learning rate") +@click.option("--patience", type=int, default=30, help="Patience for early stopping") +def main(epochs: int, batch_size: int, lr: float, patience: int) -> None: + """ + Train the Multi-MLP weighted + boundary-aware error predictor. + + This script trains an error predictor that applies: + 1. Weighted Training: Higher loss weights for problem regions + 2. Boundary-Aware Loss: Penalty when corrected chroma exceeds gamut + + Uses the Multi-MLP weighted boundary base model for two-stage inference. + """ + + LOGGER.info("=" * 80) + LOGGER.info("Multi-MLP Error Predictor: Weighted + Boundary-Aware Loss") + LOGGER.info("=" * 80) + + problem_region_weight = 3.0 + boundary_penalty_weight = 0.5 + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if torch.backends.mps.is_available(): + device = torch.device("mps") + LOGGER.info("Using device: %s", device) + + # Paths + model_directory = PROJECT_ROOT / "models" / "from_xyY" + data_dir = PROJECT_ROOT / "data" + + base_model_path = model_directory / "multi_mlp_weighted_boundary.onnx" + params_path = model_directory / "multi_mlp_weighted_boundary_normalization_params.npz" + cache_file = data_dir / "training_data_large.npz" + + if not cache_file.exists(): + LOGGER.error("Error: Large training data not found at %s", cache_file) + return + + if not base_model_path.exists(): + LOGGER.error("Error: Base model not found at %s", base_model_path) + return + + # Load base model + LOGGER.info("") + LOGGER.info("Loading Multi-MLP weighted boundary base model from %s...", base_model_path) + base_session, input_params, output_params = load_base_model(base_model_path, params_path) + + # Load training data + LOGGER.info("Loading large training data from %s...", cache_file) + data = np.load(cache_file) + X_train = data["X_train"] + y_train = data["y_train"] + X_val = data["X_val"] + y_val = data["y_val"] + + LOGGER.info("Train samples: %d", len(X_train)) + LOGGER.info("Validation samples: %d", len(X_val)) + + # Compute sample weights + LOGGER.info("") + LOGGER.info("Computing sample weights (problem region weight: %.1f)...", + problem_region_weight) + train_weights = compute_sample_weights(y_train, problem_region_weight) + val_weights = compute_sample_weights(y_val, problem_region_weight) + + n_problem_train = np.sum(train_weights > 1.0) + LOGGER.info(" Train problem region samples: %d (%.2f%%)", + n_problem_train, 100 * n_problem_train / len(y_train)) + + # Generate base model predictions + LOGGER.info("") + LOGGER.info("Generating base model predictions...") + X_train_norm = normalize_xyY(X_train, input_params) + y_train_norm = normalize_munsell(y_train, output_params) + + inference_batch_size = 50000 + base_pred_train_norm = [] + for i in range(0, len(X_train_norm), inference_batch_size): + batch = X_train_norm[i : i + inference_batch_size] + pred = base_session.run(None, {"xyY": batch})[0] + base_pred_train_norm.append(pred) + base_pred_train_norm = np.concatenate(base_pred_train_norm, axis=0) + + X_val_norm = normalize_xyY(X_val, input_params) + y_val_norm = normalize_munsell(y_val, output_params) + + base_pred_val_norm = [] + for i in range(0, len(X_val_norm), inference_batch_size): + batch = X_val_norm[i : i + inference_batch_size] + pred = base_session.run(None, {"xyY": batch})[0] + base_pred_val_norm.append(pred) + base_pred_val_norm = np.concatenate(base_pred_val_norm, axis=0) + + # Compute errors (in normalized space) + error_train = y_train_norm - base_pred_train_norm + error_val = y_val_norm - base_pred_val_norm + + LOGGER.info("") + LOGGER.info("Base model error statistics (normalized space):") + LOGGER.info(" Mean absolute error: %.6f", np.mean(np.abs(error_train))) + LOGGER.info(" Std of error: %.6f", np.std(error_train)) + + # Create combined input: [xyY_norm, base_prediction_norm] + X_train_combined = np.concatenate([X_train_norm, base_pred_train_norm], axis=1) + X_val_combined = np.concatenate([X_val_norm, base_pred_val_norm], axis=1) + + # Convert to PyTorch tensors + X_train_t = torch.FloatTensor(X_train_combined) + error_train_t = torch.FloatTensor(error_train) + base_pred_train_t = torch.FloatTensor(base_pred_train_norm) + w_train_t = torch.FloatTensor(train_weights) + + X_val_t = torch.FloatTensor(X_val_combined) + error_val_t = torch.FloatTensor(error_val) + base_pred_val_t = torch.FloatTensor(base_pred_val_norm) + w_val_t = torch.FloatTensor(val_weights) + + # Create data loaders + train_dataset = TensorDataset(X_train_t, error_train_t, base_pred_train_t, w_train_t) + val_dataset = TensorDataset(X_val_t, error_val_t, base_pred_val_t, w_val_t) + + train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) + val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False) + + # Initialize Multi-MLP error predictor + model = MultiMLPErrorPredictorToMunsell(chroma_width=1.5).to(device) + LOGGER.info("") + LOGGER.info("Multi-MLP error predictor architecture:") + + total_params = sum(p.numel() for p in model.parameters()) + LOGGER.info("Total parameters: %s", f"{total_params:,}") + + # Build max chroma lookup tensor + max_chroma_tensor = build_max_chroma_tensor(device) + LOGGER.info("") + LOGGER.info("Built max chroma lookup table for boundary-aware loss") + LOGGER.info("Boundary penalty weight: %.2f", boundary_penalty_weight) + LOGGER.info("Problem region weight: %.1f", problem_region_weight) + + # Training setup + optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5) + scheduler = optim.lr_scheduler.ReduceLROnPlateau( + optimizer, mode="min", factor=0.5, patience=10 + ) + + # MLflow setup + run_name = setup_mlflow_experiment( + "from_xyY", "multi_mlp_weighted_boundary_multi_error_predictor" + ) + + LOGGER.info("") + LOGGER.info("MLflow run: %s", run_name) + + # Training loop + best_val_loss = float("inf") + patience_counter = 0 + + LOGGER.info("") + LOGGER.info("Starting training...") + + with mlflow.start_run(run_name=run_name): + mlflow.log_params({ + "model": "multi_mlp_weighted_boundary_multi_error_predictor", + "num_epochs": epochs, + "batch_size": batch_size, + "learning_rate": lr, + "weight_decay": 1e-5, + "optimizer": "AdamW", + "scheduler": "ReduceLROnPlateau", + "patience": patience, + "total_params": total_params, + "train_samples": len(X_train), + "val_samples": len(X_val), + "dataset": "large_2M", + "problem_region_weight": problem_region_weight, + "boundary_penalty_weight": boundary_penalty_weight, + "base_model": "multi_mlp_weighted_boundary", + "chroma_width": 1.5, + }) + + for epoch in range(epochs): + train_total, train_base, train_boundary = train_epoch( + model, train_loader, optimizer, device, + max_chroma_tensor, output_params, boundary_penalty_weight + ) + val_total, val_base, val_boundary = validate( + model, val_loader, device, + max_chroma_tensor, output_params, boundary_penalty_weight + ) + + scheduler.step(val_total) + + log_training_epoch( + epoch, train_total, val_total, optimizer.param_groups[0]["lr"] + ) + mlflow.log_metrics({ + "train_base": train_base, + "train_boundary": train_boundary, + "val_base": val_base, + "val_boundary": val_boundary, + }, step=epoch) + + LOGGER.info( + "Epoch %03d/%d - Train: %.6f (base=%.6f, bnd=%.6f) | " + "Val: %.6f (base=%.6f, bnd=%.6f) | LR: %.6f", + epoch + 1, epochs, + train_total, train_base, train_boundary, + val_total, val_base, val_boundary, + optimizer.param_groups[0]["lr"], + ) + + if val_total < best_val_loss: + best_val_loss = val_total + patience_counter = 0 + + model_directory.mkdir(exist_ok=True) + checkpoint_file = ( + model_directory / "multi_mlp_weighted_boundary_multi_error_predictor_best.pth" + ) + + torch.save({ + "model_state_dict": model.state_dict(), + "epoch": epoch, + "val_loss": val_total, + }, checkpoint_file) + + LOGGER.info(" → Saved best model (val_loss: %.6f)", val_total) + else: + patience_counter += 1 + if patience_counter >= patience: + LOGGER.info("") + LOGGER.info("Early stopping after %d epochs", epoch + 1) + break + + mlflow.log_metrics({ + "best_val_loss": best_val_loss, + "final_epoch": epoch + 1, + }) + + # Export to ONNX + LOGGER.info("") + LOGGER.info("Exporting error predictor to ONNX...") + model.eval() + + checkpoint = torch.load(checkpoint_file, weights_only=False) + model.load_state_dict(checkpoint["model_state_dict"]) + + dummy_input = torch.randn(1, 7).to(device) + + onnx_file = model_directory / "multi_mlp_weighted_boundary_multi_error_predictor.onnx" + torch.onnx.export( + model, + dummy_input, + onnx_file, + export_params=True, + opset_version=15, + input_names=["combined_input"], + output_names=["error_correction"], + dynamic_axes={ + "combined_input": {0: "batch_size"}, + "error_correction": {0: "batch_size"}, + }, + ) + + LOGGER.info("Error predictor ONNX model saved to: %s", onnx_file) + + mlflow.log_artifact(str(checkpoint_file)) + mlflow.log_artifact(str(onnx_file)) + mlflow.pytorch.log_model(model, "model") + + LOGGER.info("=" * 80) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO, format="%(message)s", force=True) + + main() diff --git a/learning_munsell/training/from_xyY/train_multi_resnet_error_predictor_large.py b/learning_munsell/training/from_xyY/train_multi_resnet_error_predictor_large.py new file mode 100644 index 0000000000000000000000000000000000000000..956004899131c55c2b03b4990216d7554bb92f73 --- /dev/null +++ b/learning_munsell/training/from_xyY/train_multi_resnet_error_predictor_large.py @@ -0,0 +1,381 @@ +""" +Train Multi-ResNet error predictor on large dataset (2M samples). + +This script trains the error predictor using true ResNet architecture with skip +connections. Uses the Multi-ResNet (Large) base model for two-stage refinement. +""" + +import logging +from pathlib import Path +import click +import mlflow +import mlflow.pytorch +import numpy as np +import onnxruntime as ort +import torch +from numpy.typing import NDArray +from torch import nn, optim +from torch.utils.data import DataLoader, TensorDataset + +from learning_munsell import PROJECT_ROOT +from learning_munsell.models.networks import MultiResNetErrorPredictorToMunsell +from learning_munsell.utilities.common import log_training_epoch, setup_mlflow_experiment +from learning_munsell.utilities.data import MUNSELL_NORMALIZATION_PARAMS, normalize_munsell +from learning_munsell.utilities.losses import precision_focused_loss +from learning_munsell.utilities.training import train_epoch, validate + +LOGGER = logging.getLogger(__name__) + + +def load_base_model( + model_path: Path, params_path: Path +) -> tuple[ort.InferenceSession, dict, dict]: + """ + Load Multi-ResNet (Large) base ONNX model and normalization parameters. + + Parameters + ---------- + model_path : Path + Path to ONNX model file. + params_path : Path + Path to normalization parameters file. + + Returns + ------- + session : ort.InferenceSession + ONNX Runtime inference session. + input_params : dict + Input normalization ranges (None if xyY already in [0, 1]). + output_params : dict + Output normalization ranges. + """ + session = ort.InferenceSession(str(model_path)) + params = np.load(params_path, allow_pickle=True) + # Multi-ResNet uses xyY already in [0, 1] range - no input normalization needed + input_params = ( + params["input_params"].item() if "input_params" in params.files else None + ) + output_params = params["output_params"].item() + return session, input_params, output_params + + +@click.command() +@click.option( + "--epochs", + type=int, + default=300, + help="Number of training epochs (default: 300)", +) +@click.option( + "--batch-size", + type=int, + default=2048, + help="Batch size for training (default: 2048)", +) +@click.option( + "--lr", + type=float, + default=3e-4, + help="Learning rate (default: 3e-4)", +) +@click.option( + "--patience", + type=int, + default=30, + help="Early stopping patience (default: 30)", +) +@click.option( + "--hidden-dim", + type=int, + default=256, + help="Hidden dimension for ResNet blocks (default: 256)", +) +@click.option( + "--num-blocks", + type=int, + default=4, + help="Number of residual blocks per branch (default: 4)", +) +@click.option( + "--chroma-hidden-dim", + type=int, + default=512, + help="Hidden dim for chroma branch (default: 512)", +) +def main( + epochs: int, + batch_size: int, + lr: float, + patience: int, + hidden_dim: int, + num_blocks: int, + chroma_hidden_dim: int, +) -> None: + """ + Train Multi-ResNet error predictor on large dataset (2M samples). + + Notes + ----- + The training pipeline: + 1. Loads pre-trained Multi-ResNet (Large) base model + 2. Generates base model predictions for training data + 3. Computes residual errors between predictions and targets + 4. Trains error predictor with true skip connections on these residuals + 5. Uses precision-focused loss function + 6. Learning rate scheduling with ReduceLROnPlateau + 7. Early stopping based on validation loss + 8. Exports model to ONNX format + 9. Logs metrics and artifacts to MLflow + """ + + LOGGER.info("=" * 80) + LOGGER.info("Multi-ResNet Error Predictor Training on Large Dataset (2M samples)") + LOGGER.info("=" * 80) + + # Set device + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if torch.backends.mps.is_available(): + device = torch.device("mps") + LOGGER.info("Using device: %s", device) + + # Paths + model_directory = PROJECT_ROOT / "models" / "from_xyY" + data_dir = PROJECT_ROOT / "data" + + base_model_path = model_directory / "multi_resnet_large.onnx" + params_path = model_directory / "multi_resnet_large_normalization_params.npz" + cache_file = data_dir / "training_data_large.npz" + + # Check base model exists + if not base_model_path.exists(): + LOGGER.error("Error: Base model not found at %s", base_model_path) + LOGGER.error("Please run train_multi_resnet_large.py first") + return + + # Load base model + LOGGER.info("") + LOGGER.info("Loading Multi-ResNet (Large) base model from %s...", base_model_path) + base_session, input_params, output_params = load_base_model( + base_model_path, params_path + ) + + # Load training data + LOGGER.info("Loading large training data from %s...", cache_file) + data = np.load(cache_file) + X_train = data["X_train"] + y_train = data["y_train"] + X_val = data["X_val"] + y_val = data["y_val"] + + LOGGER.info("Train samples: %d", len(X_train)) + LOGGER.info("Validation samples: %d", len(X_val)) + + # Generate base model predictions + LOGGER.info("") + LOGGER.info("Generating Multi-ResNet (Large) base model predictions...") + + # xyY is already in [0, 1] range for this model + X_train_norm = X_train.astype(np.float32) + y_train_norm = normalize_munsell(y_train, output_params) + + # Process in batches to avoid memory issues + batch_size_inference = 50000 + base_pred_train_list = [] + for i in range(0, len(X_train_norm), batch_size_inference): + batch = X_train_norm[i : i + batch_size_inference] + pred = base_session.run(None, {"xyY": batch})[0] + base_pred_train_list.append(pred) + base_pred_train_norm = np.concatenate(base_pred_train_list, axis=0) + + X_val_norm = X_val.astype(np.float32) + y_val_norm = normalize_munsell(y_val, output_params) + + base_pred_val_list = [] + for i in range(0, len(X_val_norm), batch_size_inference): + batch = X_val_norm[i : i + batch_size_inference] + pred = base_session.run(None, {"xyY": batch})[0] + base_pred_val_list.append(pred) + base_pred_val_norm = np.concatenate(base_pred_val_list, axis=0) + + # Compute errors (in normalized space) + error_train = y_train_norm - base_pred_train_norm + error_val = y_val_norm - base_pred_val_norm + + # Statistics + LOGGER.info("") + LOGGER.info("Multi-ResNet (Large) base model error statistics (normalized space):") + LOGGER.info(" Mean absolute error: %.6f", np.mean(np.abs(error_train))) + LOGGER.info(" Std of error: %.6f", np.std(error_train)) + LOGGER.info(" Max absolute error: %.6f", np.max(np.abs(error_train))) + + # Create combined input: [xyY_norm, base_prediction_norm] + X_train_combined = np.concatenate([X_train_norm, base_pred_train_norm], axis=1) + X_val_combined = np.concatenate([X_val_norm, base_pred_val_norm], axis=1) + + # Convert to PyTorch tensors + X_train_t = torch.FloatTensor(X_train_combined) + error_train_t = torch.FloatTensor(error_train) + X_val_t = torch.FloatTensor(X_val_combined) + error_val_t = torch.FloatTensor(error_val) + + # Create data loaders + train_dataset = TensorDataset(X_train_t, error_train_t) + val_dataset = TensorDataset(X_val_t, error_val_t) + + train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) + val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False) + + # Initialize Multi-ResNet error predictor + model = MultiResNetErrorPredictorToMunsell( + hidden_dim=hidden_dim, + num_blocks=num_blocks, + chroma_hidden_dim=chroma_hidden_dim, + ).to(device) + + LOGGER.info("") + LOGGER.info("Multi-ResNet error predictor architecture:") + LOGGER.info("%s", model) + + # Count parameters + total_params = sum(p.numel() for p in model.parameters()) + LOGGER.info("Total parameters: %s", f"{total_params:,}") + + # Training setup + criterion = precision_focused_loss + + LOGGER.info("") + LOGGER.info("Using precision_focused_loss") + LOGGER.info("Hyperparameters: lr=%.4f, batch_size=%d", lr, batch_size) + + optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5) + scheduler = optim.lr_scheduler.ReduceLROnPlateau( + optimizer, mode="min", factor=0.5, patience=10 + ) + + # MLflow setup + run_name = setup_mlflow_experiment("from_xyY", "multi_resnet_error_predictor_large") + + LOGGER.info("") + LOGGER.info("MLflow run: %s", run_name) + + # Training loop + best_val_loss = float("inf") + patience_counter = 0 + + LOGGER.info("") + LOGGER.info("Starting training...") + + with mlflow.start_run(run_name=run_name): + mlflow.log_params( + { + "model": "multi_resnet_error_predictor_large", + "learning_rate": lr, + "batch_size": batch_size, + "num_epochs": epochs, + "patience": patience, + "hidden_dim": hidden_dim, + "num_blocks": num_blocks, + "chroma_hidden_dim": chroma_hidden_dim, + "total_params": total_params, + "train_samples": len(X_train), + "val_samples": len(X_val), + "dataset": "large_2M", + } + ) + + for epoch in range(epochs): + train_loss = train_epoch(model, train_loader, optimizer, criterion, device) + val_loss = validate(model, val_loader, criterion, device) + + scheduler.step(val_loss) + + log_training_epoch( + epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"] + ) + + LOGGER.info( + "Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f, LR: %.6f", + epoch + 1, + epochs, + train_loss, + val_loss, + optimizer.param_groups[0]["lr"], + ) + + if val_loss < best_val_loss: + best_val_loss = val_loss + patience_counter = 0 + + model_directory.mkdir(exist_ok=True) + checkpoint_file = ( + model_directory / "multi_resnet_error_predictor_large_best.pth" + ) + + torch.save( + { + "model_state_dict": model.state_dict(), + "hidden_dim": hidden_dim, + "num_blocks": num_blocks, + "chroma_hidden_dim": chroma_hidden_dim, + "epoch": epoch, + "val_loss": val_loss, + }, + checkpoint_file, + ) + + LOGGER.info(" → Saved best model (val_loss: %.6f)", val_loss) + else: + patience_counter += 1 + if patience_counter >= patience: + LOGGER.info("") + LOGGER.info("Early stopping after %d epochs", epoch + 1) + break + + mlflow.log_metrics( + { + "best_val_loss": best_val_loss, + "final_epoch": epoch + 1, + } + ) + + # Export to ONNX + LOGGER.info("") + LOGGER.info("Exporting Multi-ResNet error predictor to ONNX...") + model.eval() + + checkpoint = torch.load(checkpoint_file, weights_only=False) + model.load_state_dict(checkpoint["model_state_dict"]) + + # Move to CPU for ONNX export to avoid MPS issues + model_cpu = model.to("cpu") + dummy_input = torch.randn(1, 7) + + onnx_file = model_directory / "multi_resnet_error_predictor_large.onnx" + torch.onnx.export( + model_cpu, + dummy_input, + onnx_file, + export_params=True, + opset_version=15, + input_names=["combined_input"], + output_names=["error_correction"], + dynamic_axes={ + "combined_input": {0: "batch_size"}, + "error_correction": {0: "batch_size"}, + }, + ) + + mlflow.log_artifact(str(checkpoint_file)) + mlflow.log_artifact(str(onnx_file)) + mlflow.pytorch.log_model(model_cpu, "model") + + LOGGER.info("Multi-ResNet error predictor ONNX model saved to: %s", onnx_file) + LOGGER.info("Artifacts logged to MLflow") + + LOGGER.info("=" * 80) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO, format="%(message)s", force=True) + + main() diff --git a/learning_munsell/training/from_xyY/train_multi_resnet_large.py b/learning_munsell/training/from_xyY/train_multi_resnet_large.py new file mode 100644 index 0000000000000000000000000000000000000000..d77cba1fac16511c0d0e3f0a21f084b1f6c29c04 --- /dev/null +++ b/learning_munsell/training/from_xyY/train_multi_resnet_large.py @@ -0,0 +1,271 @@ +""" +Train Multi-ResNet model on large dataset for xyY to Munsell conversion. + +This script trains a true ResNet architecture with skip connections on the +larger dataset (2M samples). Unlike the MLP variants, this uses actual +residual blocks where output = activation(x + f(x)). +""" + +import logging +import click +import mlflow +import mlflow.pytorch +import numpy as np +import torch +from numpy.typing import NDArray +from torch import nn, optim +from torch.utils.data import DataLoader, TensorDataset + +from learning_munsell import PROJECT_ROOT +from learning_munsell.models.networks import MultiResNetToMunsell +from learning_munsell.utilities.common import log_training_epoch, setup_mlflow_experiment +from learning_munsell.utilities.data import ( + MUNSELL_NORMALIZATION_PARAMS, + XYY_NORMALIZATION_PARAMS, + normalize_munsell, +) +from learning_munsell.utilities.losses import weighted_mse_loss +from learning_munsell.utilities.training import train_epoch, validate + +LOGGER = logging.getLogger(__name__) + + +@click.command() +@click.option("--epochs", default=300, help="Number of training epochs") +@click.option("--batch-size", default=2048, help="Batch size for training") +@click.option("--lr", default=3e-4, help="Learning rate") +@click.option("--patience", default=30, help="Early stopping patience") +@click.option("--hidden-dim", default=256, help="Hidden dimension for ResNet blocks") +@click.option("--num-blocks", default=4, help="Number of residual blocks per branch") +@click.option("--chroma-hidden-dim", default=512, help="Hidden dim for chroma branch") +def main( + epochs: int, + batch_size: int, + lr: float, + patience: int, + hidden_dim: int, + num_blocks: int, + chroma_hidden_dim: int, +) -> None: + """ + Train Multi-ResNet model on large dataset for xyY to Munsell. + + Notes + ----- + The training pipeline: + 1. Loads training and validation data from large cached .npz file + 2. Normalizes xyY inputs (already [0,1]) and Munsell outputs to [0,1] + 3. Creates Multi-ResNet with 4 independent branches using true skip connections + 4. Trains with weighted MSE loss (emphasizing chroma 5x) + 5. Uses Adam optimizer with ReduceLROnPlateau scheduler + 6. Applies early stopping based on validation loss + 7. Exports best model to ONNX format + 8. Logs metrics and artifacts to MLflow + """ + + LOGGER.info("=" * 80) + LOGGER.info("Multi-ResNet Model Training on Large Dataset (2M samples)") + LOGGER.info("=" * 80) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if torch.backends.mps.is_available(): + device = torch.device("mps") + LOGGER.info("Using device: %s", device) + + # Load large training data + data_dir = PROJECT_ROOT / "data" + cache_file = data_dir / "training_data_large.npz" + + if not cache_file.exists(): + LOGGER.error("Error: Large training data not found at %s", cache_file) + LOGGER.error("Please run generate_large_training_data.py first") + return + + LOGGER.info("Loading large training data from %s...", cache_file) + data = np.load(cache_file) + + X_train = data["X_train"] + y_train = data["y_train"] + X_val = data["X_val"] + y_val = data["y_val"] + + LOGGER.info("Train samples: %d", len(X_train)) + LOGGER.info("Validation samples: %d", len(X_val)) + + # Normalize outputs (xyY inputs are already in [0, 1] range) + output_params = MUNSELL_NORMALIZATION_PARAMS + y_train_norm = normalize_munsell(y_train, output_params) + y_val_norm = normalize_munsell(y_val, output_params) + + # Convert to PyTorch tensors + X_train_t = torch.FloatTensor(X_train) + y_train_t = torch.FloatTensor(y_train_norm) + X_val_t = torch.FloatTensor(X_val) + y_val_t = torch.FloatTensor(y_val_norm) + + # Create data loaders + train_dataset = TensorDataset(X_train_t, y_train_t) + val_dataset = TensorDataset(X_val_t, y_val_t) + + train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) + val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False) + + # Initialize model with true ResNet architecture + model = MultiResNetToMunsell( + hidden_dim=hidden_dim, + num_blocks=num_blocks, + chroma_hidden_dim=chroma_hidden_dim, + ).to(device) + + LOGGER.info("") + LOGGER.info("Model architecture:") + LOGGER.info("%s", model) + + total_params = sum(p.numel() for p in model.parameters()) + LOGGER.info("Total parameters: %s", f"{total_params:,}") + + # Training setup + learning_rate = lr + optimizer = optim.Adam(model.parameters(), lr=learning_rate) + scheduler = optim.lr_scheduler.ReduceLROnPlateau( + optimizer, mode="min", factor=0.5, patience=10 + ) + # Use weighted MSE with default weights + weights = torch.tensor([1.0, 1.0, 5.0, 0.4]) + criterion = lambda pred, target: weighted_mse_loss(pred, target, weights) + + # MLflow setup + run_name = setup_mlflow_experiment("from_xyY", "multi_resnet_large") + + LOGGER.info("") + LOGGER.info("MLflow run: %s", run_name) + + # Training loop + best_val_loss = float("inf") + patience_counter = 0 + + LOGGER.info("") + LOGGER.info("Starting training...") + + with mlflow.start_run(run_name=run_name): + mlflow.log_params( + { + "model": "multi_resnet_large", + "learning_rate": learning_rate, + "batch_size": batch_size, + "num_epochs": epochs, + "patience": patience, + "hidden_dim": hidden_dim, + "num_blocks": num_blocks, + "chroma_hidden_dim": chroma_hidden_dim, + "total_params": total_params, + "train_samples": len(X_train), + "val_samples": len(X_val), + "dataset": "large_2M", + } + ) + + for epoch in range(epochs): + train_loss = train_epoch(model, train_loader, optimizer, criterion, device) + val_loss = validate(model, val_loader, criterion, device) + + scheduler.step(val_loss) + + log_training_epoch( + epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"] + ) + + LOGGER.info( + "Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f, LR: %.6f", + epoch + 1, + epochs, + train_loss, + val_loss, + optimizer.param_groups[0]["lr"], + ) + + if val_loss < best_val_loss: + best_val_loss = val_loss + patience_counter = 0 + + model_directory = PROJECT_ROOT / "models" / "from_xyY" + model_directory.mkdir(exist_ok=True) + checkpoint_file = model_directory / "multi_resnet_large_best.pth" + + torch.save( + { + "model_state_dict": model.state_dict(), + "output_params": output_params, + "hidden_dim": hidden_dim, + "num_blocks": num_blocks, + "chroma_hidden_dim": chroma_hidden_dim, + "epoch": epoch, + "val_loss": val_loss, + }, + checkpoint_file, + ) + + LOGGER.info(" → Saved best model (val_loss: %.6f)", val_loss) + else: + patience_counter += 1 + if patience_counter >= patience: + LOGGER.info("") + LOGGER.info("Early stopping after %d epochs", epoch + 1) + break + + mlflow.log_metrics( + { + "best_val_loss": best_val_loss, + "final_epoch": epoch + 1, + } + ) + + # Export to ONNX + LOGGER.info("") + LOGGER.info("Exporting model to ONNX...") + model.eval() + + checkpoint = torch.load(checkpoint_file, weights_only=False) + model.load_state_dict(checkpoint["model_state_dict"]) + + # Move to CPU for ONNX export to avoid MPS issues + model_cpu = model.to("cpu") + dummy_input = torch.randn(1, 3) + + model_directory = PROJECT_ROOT / "models" / "from_xyY" + onnx_file = model_directory / "multi_resnet_large.onnx" + torch.onnx.export( + model_cpu, + dummy_input, + onnx_file, + export_params=True, + opset_version=15, + input_names=["xyY"], + output_names=["munsell_spec"], + dynamic_axes={"xyY": {0: "batch_size"}, "munsell_spec": {0: "batch_size"}}, + ) + + params_file = model_directory / "multi_resnet_large_normalization_params.npz" + input_params = XYY_NORMALIZATION_PARAMS + np.savez( + params_file, + input_params=input_params, + output_params=output_params, + ) + + mlflow.log_artifact(str(checkpoint_file)) + mlflow.log_artifact(str(onnx_file)) + mlflow.log_artifact(str(params_file)) + mlflow.pytorch.log_model(model_cpu, "model") + + LOGGER.info("ONNX model saved to: %s", onnx_file) + LOGGER.info("Normalization parameters saved to: %s", params_file) + LOGGER.info("Artifacts logged to MLflow") + + LOGGER.info("=" * 80) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO, format="%(message)s", force=True) + + main() diff --git a/learning_munsell/training/from_xyY/train_transformer_large.py b/learning_munsell/training/from_xyY/train_transformer_large.py new file mode 100644 index 0000000000000000000000000000000000000000..57f861b371e16743473574e4c8b4b6e4dd8b3eb7 --- /dev/null +++ b/learning_munsell/training/from_xyY/train_transformer_large.py @@ -0,0 +1,354 @@ +""" +Train Transformer model on large dataset (2M samples) for xyY to Munsell conversion. + +This uses a proper transformer architecture with: +- Feature tokenization (each xyY component becomes a token) +- Multi-head self-attention to capture feature interactions +- [CLS] token for output regression +- Training on 2M samples with 10% perturbations +""" + +import logging +from typing import Any + +import click +import mlflow +import mlflow.pytorch +import numpy as np +import torch +from numpy.typing import NDArray +from torch import nn, optim +from torch.utils.data import DataLoader, TensorDataset + +from learning_munsell import PROJECT_ROOT +from learning_munsell.models.networks import TransformerToMunsell +from learning_munsell.utilities.common import log_training_epoch, setup_mlflow_experiment +from learning_munsell.utilities.data import ( + MUNSELL_NORMALIZATION_PARAMS, + XYY_NORMALIZATION_PARAMS, + normalize_munsell, +) +from learning_munsell.utilities.losses import weighted_mse_loss + +LOGGER = logging.getLogger(__name__) + + +def train_epoch( + model: nn.Module, + dataloader: DataLoader, + optimizer: optim.Optimizer, + criterion: Any, + device: torch.device, +) -> float: + """ + Train the model for one epoch. + + Parameters + ---------- + model : nn.Module + The neural network model to train. + dataloader : DataLoader + DataLoader providing training batches (X, y). + optimizer : optim.Optimizer + Optimizer for updating model parameters. + criterion : callable + Loss function that takes (predictions, targets) and returns loss. + device : torch.device + Device to run training on. + + Returns + ------- + float + Average loss for the epoch. + """ + model.train() + total_loss = 0.0 + + for X_batch, y_batch in dataloader: + X_batch = X_batch.to(device) # noqa: PLW2901 + y_batch = y_batch.to(device) # noqa: PLW2901 + + outputs = model(X_batch) + loss = criterion(outputs, y_batch) + + optimizer.zero_grad() + loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) + optimizer.step() + + total_loss += loss.item() + + return total_loss / len(dataloader) + + +def validate( + model: nn.Module, dataloader: DataLoader, criterion: Any, device: torch.device +) -> float: + """ + Validate the model for one epoch. + + Parameters + ---------- + model : nn.Module + The neural network model to validate. + dataloader : DataLoader + DataLoader providing validation batches (X, y). + criterion : callable + Loss function that takes (predictions, targets) and returns loss. + device : torch.device + Device to run validation on. + + Returns + ------- + float + Average loss for the epoch. + """ + model.eval() + total_loss = 0.0 + + with torch.no_grad(): + for X_batch, y_batch in dataloader: + X_batch = X_batch.to(device) # noqa: PLW2901 + y_batch = y_batch.to(device) # noqa: PLW2901 + + outputs = model(X_batch) + loss = criterion(outputs, y_batch) + + total_loss += loss.item() + + return total_loss / len(dataloader) + + +@click.command() +@click.option("--epochs", default=300, help="Number of training epochs") +@click.option("--batch-size", default=2048, help="Batch size for training") +@click.option("--lr", default=1e-4, help="Learning rate") +@click.option("--patience", default=40, help="Early stopping patience") +def main(epochs: int, batch_size: int, lr: float, patience: int) -> None: + """ + Train transformer model on large dataset (2M samples) for xyY to Munsell. + + Notes + ----- + The training pipeline: + 1. Loads training and validation data from large cached .npz file + 2. Normalizes xyY inputs (already [0,1]) and Munsell outputs to [0,1] + 3. Creates transformer with 6 blocks, 8 heads, feature tokenization + 4. Trains with weighted MSE loss (emphasizing chroma) + 5. Uses AdamW optimizer with CosineAnnealingWarmRestarts scheduler + 6. Applies gradient clipping (max_norm=1.0) for stability + 7. Applies early stopping based on validation loss (patience=40) + 8. Exports best model to ONNX format + 9. Logs metrics and artifacts to MLflow + """ + + LOGGER.info("=" * 80) + LOGGER.info("Transformer Model Training on Large Dataset (2M samples)") + LOGGER.info("=" * 80) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if torch.backends.mps.is_available(): + device = torch.device("mps") + LOGGER.info("Using device: %s", device) + + # Load large training data + data_dir = PROJECT_ROOT / "data" + cache_file = data_dir / "training_data_large.npz" + + if not cache_file.exists(): + LOGGER.error("Error: Large training data not found at %s", cache_file) + LOGGER.error("Please run generate_large_training_data.py first") + return + + LOGGER.info("Loading large training data from %s...", cache_file) + data = np.load(cache_file) + + X_train = data["X_train"] + y_train = data["y_train"] + X_val = data["X_val"] + y_val = data["y_val"] + + LOGGER.info("Train samples: %d", len(X_train)) + LOGGER.info("Validation samples: %d", len(X_val)) + + # Normalize outputs (xyY inputs are already in [0, 1] range) + output_params = MUNSELL_NORMALIZATION_PARAMS + y_train_norm = normalize_munsell(y_train, output_params) + y_val_norm = normalize_munsell(y_val, output_params) + + # Convert to PyTorch tensors + X_train_t = torch.FloatTensor(X_train) + y_train_t = torch.FloatTensor(y_train_norm) + X_val_t = torch.FloatTensor(X_val) + y_val_t = torch.FloatTensor(y_val_norm) + + # Create data loaders + train_dataset = TensorDataset(X_train_t, y_train_t) + val_dataset = TensorDataset(X_val_t, y_val_t) + + train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) + val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False) + + # Initialize model + model = TransformerToMunsell( + num_features=3, + embedding_dim=256, + num_blocks=6, + num_heads=8, + ff_dim=1024, + dropout=0.1, + ).to(device) + + LOGGER.info("") + LOGGER.info("Model architecture:") + LOGGER.info("%s", model) + + total_params = sum(p.numel() for p in model.parameters()) + LOGGER.info("Total parameters: %s", f"{total_params:,}") + + # Training setup + learning_rate = lr + optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01) + scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts( + optimizer, T_0=50, T_mult=2 + ) + # Component weights: emphasize chroma (3.0), de-emphasize code (0.5) + weights = torch.tensor([1.0, 1.0, 3.0, 0.5]) + criterion = lambda pred, target: weighted_mse_loss(pred, target, weights) + + # MLflow setup + run_name = setup_mlflow_experiment("from_xyY", "transformer_large") + + LOGGER.info("") + LOGGER.info("MLflow run: %s", run_name) + + # Training loop + best_val_loss = float("inf") + patience_counter = 0 + + LOGGER.info("") + LOGGER.info("Starting training...") + + with mlflow.start_run(run_name=run_name): + mlflow.log_params( + { + "model": "transformer_large", + "learning_rate": learning_rate, + "batch_size": batch_size, + "num_epochs": epochs, + "patience": patience, + "total_params": total_params, + "train_samples": len(X_train), + "val_samples": len(X_val), + "dataset": "large_2M", + "num_blocks": 6, + "num_heads": 8, + "embedding_dim": 256, + "ff_dim": 1024, + } + ) + + for epoch in range(epochs): + train_loss = train_epoch(model, train_loader, optimizer, criterion, device) + val_loss = validate(model, val_loader, criterion, device) + + scheduler.step() + + log_training_epoch( + epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"] + ) + + LOGGER.info( + "Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f, LR: %.6f", + epoch + 1, + epochs, + train_loss, + val_loss, + optimizer.param_groups[0]["lr"], + ) + + if val_loss < best_val_loss: + best_val_loss = val_loss + patience_counter = 0 + + model_directory = PROJECT_ROOT / "models" / "from_xyY" + model_directory.mkdir(exist_ok=True) + checkpoint_file = model_directory / "transformer_large_best.pth" + + torch.save( + { + "model_state_dict": model.state_dict(), + "output_params": output_params, + "epoch": epoch, + "val_loss": val_loss, + }, + checkpoint_file, + ) + + LOGGER.info(" → Saved best model (val_loss: %.6f)", val_loss) + else: + patience_counter += 1 + if patience_counter >= patience: + LOGGER.info("") + LOGGER.info("Early stopping after %d epochs", epoch + 1) + break + + mlflow.log_metrics( + { + "best_val_loss": best_val_loss, + "final_epoch": epoch + 1, + } + ) + + # Export to ONNX + LOGGER.info("") + LOGGER.info("Exporting model to ONNX...") + model.eval() + + checkpoint = torch.load(checkpoint_file, weights_only=False) + model.load_state_dict(checkpoint["model_state_dict"]) + + # Move to CPU for ONNX export + model_cpu = model.cpu() + dummy_input = torch.randn(1, 3) + + model_directory = PROJECT_ROOT / "models" / "from_xyY" + onnx_file = model_directory / "transformer_large.onnx" + # Use legacy exporter for better compatibility with transformer architecture + torch.onnx.export( + model_cpu, + dummy_input, + onnx_file, + export_params=True, + opset_version=15, + input_names=["xyY"], + output_names=["munsell_spec"], + dynamic_axes={"xyY": {0: "batch_size"}, "munsell_spec": {0: "batch_size"}}, + do_constant_folding=True, + dynamo=False, + ) + + params_file = model_directory / "transformer_large_normalization_params.npz" + input_params = XYY_NORMALIZATION_PARAMS + np.savez( + params_file, + input_params=input_params, + output_params=output_params, + ) + + mlflow.log_artifact(str(checkpoint_file)) + mlflow.log_artifact(str(onnx_file)) + mlflow.log_artifact(str(params_file)) + mlflow.pytorch.log_model(model, "model") + + LOGGER.info("ONNX model saved to: %s", onnx_file) + LOGGER.info("Normalization parameters saved to: %s", params_file) + LOGGER.info("Artifacts logged to MLflow") + + LOGGER.info("=" * 80) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO, format="%(message)s", force=True) + + main() diff --git a/learning_munsell/training/from_xyY/train_transformer_multi_error_predictor_large.py b/learning_munsell/training/from_xyY/train_transformer_multi_error_predictor_large.py new file mode 100644 index 0000000000000000000000000000000000000000..569b32e1b79a109b4310c981625535c956c4978e --- /dev/null +++ b/learning_munsell/training/from_xyY/train_transformer_multi_error_predictor_large.py @@ -0,0 +1,395 @@ +""" +Train Transformer error predictor on large dataset (2M samples). + +This script trains an error predictor for the Transformer base model, +using a multi-head architecture with separate branches for each Munsell component. +""" + +import logging +from pathlib import Path +from typing import Any + +import click +import mlflow +import mlflow.pytorch +import numpy as np +import onnxruntime as ort +import torch +from numpy.typing import NDArray +from torch import nn, optim +from torch.utils.data import DataLoader, TensorDataset + +from learning_munsell import PROJECT_ROOT +from learning_munsell.models.networks import ComponentErrorPredictor +from learning_munsell.utilities.common import log_training_epoch, setup_mlflow_experiment +from learning_munsell.utilities.data import normalize_xyY, normalize_munsell +from learning_munsell.utilities.losses import precision_focused_loss +from learning_munsell.utilities.training import train_epoch, validate + +LOGGER = logging.getLogger(__name__) + + +class TransformerErrorPredictor(nn.Module): + """ + Multi-head error predictor with 4 independent branches. + + Each branch is a ComponentErrorPredictor specialized for one + Munsell component. The chroma branch is wider (1.5x) to handle + the more complex error patterns in chroma prediction. + + Parameters + ---------- + chroma_width : float, optional + Width multiplier for chroma branch. Default is 1.5. + + Attributes + ---------- + hue_branch : ComponentErrorPredictor + Error predictor for hue component (1.0x width). + value_branch : ComponentErrorPredictor + Error predictor for value component (1.0x width). + chroma_branch : ComponentErrorPredictor + Error predictor for chroma component (1.5x width). + code_branch : ComponentErrorPredictor + Error predictor for hue code component (1.0x width). + """ + + def __init__(self, chroma_width: float = 1.5) -> None: + """Initialize the multi-head error predictor.""" + super().__init__() + + self.hue_branch = ComponentErrorPredictor(width_multiplier=1.0) + self.value_branch = ComponentErrorPredictor(width_multiplier=1.0) + self.chroma_branch = ComponentErrorPredictor(width_multiplier=chroma_width) + self.code_branch = ComponentErrorPredictor(width_multiplier=1.0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass through all error predictor branches. + + Parameters + ---------- + x : Tensor + Combined input [xyY_norm, base_pred_norm] of shape (batch_size, 7). + + Returns + ------- + Tensor + Concatenated error corrections [hue, value, chroma, code] + of shape (batch_size, 4). + """ + hue_error = self.hue_branch(x) + value_error = self.value_branch(x) + chroma_error = self.chroma_branch(x) + code_error = self.code_branch(x) + return torch.cat([hue_error, value_error, chroma_error, code_error], dim=1) + + +def load_base_model( + model_path: Path, params_path: Path +) -> tuple[ort.InferenceSession, dict, dict]: + """ + Load the base transformer ONNX model and its normalization parameters. + + Parameters + ---------- + model_path : Path + Path to ONNX model file. + params_path : Path + Path to normalization parameters NPZ file. + + Returns + ------- + tuple + (session, input_params, output_params) where: + - session: ONNX Runtime inference session + - input_params: Input normalization parameters dict + - output_params: Output normalization parameters dict + """ + session = ort.InferenceSession(str(model_path)) + params = np.load(params_path, allow_pickle=True) + return session, params["input_params"].item(), params["output_params"].item() + + +@click.command() +@click.option( + "--epochs", + type=int, + default=300, + help="Number of training epochs (default: 300)", +) +@click.option( + "--batch-size", + type=int, + default=2048, + help="Batch size for training (default: 2048)", +) +@click.option( + "--lr", + type=float, + default=3e-4, + help="Learning rate (default: 3e-4)", +) +@click.option( + "--patience", + type=int, + default=30, + help="Early stopping patience (default: 30)", +) +def main(epochs: int, batch_size: int, lr: float, patience: int) -> None: + """ + Train transformer error predictor on large 2M sample dataset. + + Notes + ----- + The training pipeline: + 1. Loads pre-trained base model + 2. Generates base model predictions for training data + 3. Computes residual errors between predictions and targets + 4. Trains error predictor on these residuals + 5. Uses precision-focused loss function + 6. Learning rate scheduling with ReduceLROnPlateau + 7. Early stopping based on validation loss + 8. Exports model to ONNX format + 9. Logs metrics and artifacts to MLflow + """ + + LOGGER.info("=" * 80) + LOGGER.info("Transformer Error Predictor Training on Large Dataset (2M samples)") + LOGGER.info("=" * 80) + + # Set device + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if torch.backends.mps.is_available(): + device = torch.device("mps") + LOGGER.info("Using device: %s", device) + + # Paths + model_directory = PROJECT_ROOT / "models" / "from_xyY" + data_dir = PROJECT_ROOT / "data" + + base_model_path = model_directory / "transformer_large.onnx" + params_path = model_directory / "transformer_large_normalization_params.npz" + cache_file = data_dir / "training_data_large.npz" + + # Check base model exists + if not base_model_path.exists(): + LOGGER.error("Error: Base model not found at %s", base_model_path) + LOGGER.error("Please run train_transformer_large.py first") + return + + # Load base model + LOGGER.info("") + LOGGER.info("Loading Transformer (Large) base model from %s...", base_model_path) + base_session, input_params, output_params = load_base_model( + base_model_path, params_path + ) + + # Load training data + LOGGER.info("Loading large training data from %s...", cache_file) + data = np.load(cache_file) + X_train = data["X_train"] + y_train = data["y_train"] + X_val = data["X_val"] + y_val = data["y_val"] + + LOGGER.info("Train samples: %d", len(X_train)) + LOGGER.info("Validation samples: %d", len(X_val)) + + # Generate base model predictions + LOGGER.info("") + LOGGER.info("Generating Transformer (Large) base model predictions...") + X_train_norm = normalize_xyY(X_train, input_params) + y_train_norm = normalize_munsell(y_train, output_params) + + # Process in batches to avoid memory issues + batch_size_inference = 50000 + base_pred_train_list = [] + for i in range(0, len(X_train_norm), batch_size_inference): + batch = X_train_norm[i : i + batch_size_inference] + pred = base_session.run(None, {"xyY": batch})[0] + base_pred_train_list.append(pred) + base_pred_train_norm = np.concatenate(base_pred_train_list, axis=0) + + X_val_norm = normalize_xyY(X_val, input_params) + y_val_norm = normalize_munsell(y_val, output_params) + + base_pred_val_list = [] + for i in range(0, len(X_val_norm), batch_size_inference): + batch = X_val_norm[i : i + batch_size_inference] + pred = base_session.run(None, {"xyY": batch})[0] + base_pred_val_list.append(pred) + base_pred_val_norm = np.concatenate(base_pred_val_list, axis=0) + + # Compute errors (in normalized space) + error_train = y_train_norm - base_pred_train_norm + error_val = y_val_norm - base_pred_val_norm + + # Statistics + LOGGER.info("") + LOGGER.info("Transformer (Large) base model error statistics (normalized space):") + LOGGER.info(" Mean absolute error: %.6f", np.mean(np.abs(error_train))) + LOGGER.info(" Std of error: %.6f", np.std(error_train)) + LOGGER.info(" Max absolute error: %.6f", np.max(np.abs(error_train))) + + # Create combined input: [xyY_norm, base_prediction_norm] + X_train_combined = np.concatenate([X_train_norm, base_pred_train_norm], axis=1) + X_val_combined = np.concatenate([X_val_norm, base_pred_val_norm], axis=1) + + # Convert to PyTorch tensors + X_train_t = torch.FloatTensor(X_train_combined) + error_train_t = torch.FloatTensor(error_train) + X_val_t = torch.FloatTensor(X_val_combined) + error_val_t = torch.FloatTensor(error_val) + + # Create data loaders + train_dataset = TensorDataset(X_train_t, error_train_t) + val_dataset = TensorDataset(X_val_t, error_val_t) + + train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) + val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False) + + # Initialize error predictor + model = TransformerErrorPredictor(chroma_width=1.5).to(device) + LOGGER.info("") + LOGGER.info("Transformer error predictor architecture:") + LOGGER.info("%s", model) + + # Count parameters + total_params = sum(p.numel() for p in model.parameters()) + LOGGER.info("Total parameters: %s", f"{total_params:,}") + + # Training setup + criterion = precision_focused_loss + + LOGGER.info("") + LOGGER.info("Using precision_focused_loss") + LOGGER.info("Hyperparameters: lr=%.4f, batch_size=%d", lr, batch_size) + + optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5) + scheduler = optim.lr_scheduler.ReduceLROnPlateau( + optimizer, mode="min", factor=0.5, patience=10 + ) + + # MLflow setup + run_name = setup_mlflow_experiment("from_xyY", "transformer_error_predictor_large") + + LOGGER.info("") + LOGGER.info("MLflow run: %s", run_name) + + # Training loop + best_val_loss = float("inf") + patience_counter = 0 + + LOGGER.info("") + LOGGER.info("Starting training...") + + with mlflow.start_run(run_name=run_name): + mlflow.log_params( + { + "model": "transformer_error_predictor_large", + "learning_rate": lr, + "batch_size": batch_size, + "num_epochs": epochs, + "patience": patience, + "total_params": total_params, + "train_samples": len(X_train), + "val_samples": len(X_val), + "dataset": "large_2M", + } + ) + + for epoch in range(epochs): + train_loss = train_epoch(model, train_loader, optimizer, criterion, device) + val_loss = validate(model, val_loader, criterion, device) + + scheduler.step(val_loss) + + log_training_epoch( + epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"] + ) + + LOGGER.info( + "Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f, LR: %.6f", + epoch + 1, + epochs, + train_loss, + val_loss, + optimizer.param_groups[0]["lr"], + ) + + if val_loss < best_val_loss: + best_val_loss = val_loss + patience_counter = 0 + + model_directory.mkdir(exist_ok=True) + checkpoint_file = ( + model_directory / "transformer_error_predictor_large_best.pth" + ) + + torch.save( + { + "model_state_dict": model.state_dict(), + "epoch": epoch, + "val_loss": val_loss, + "output_params": output_params, + }, + checkpoint_file, + ) + + LOGGER.info(" → Saved best model (val_loss: %.6f)", val_loss) + else: + patience_counter += 1 + if patience_counter >= patience: + LOGGER.info("") + LOGGER.info("Early stopping after %d epochs", epoch + 1) + break + + mlflow.log_metrics( + { + "best_val_loss": best_val_loss, + "final_epoch": epoch + 1, + } + ) + + # Export to ONNX + LOGGER.info("") + LOGGER.info("Exporting Transformer error predictor to ONNX...") + model.eval() + + checkpoint = torch.load(checkpoint_file, weights_only=False) + model.load_state_dict(checkpoint["model_state_dict"]) + + # Move to CPU for export + model_cpu = model.cpu() + dummy_input = torch.randn(1, 7) + + onnx_file = model_directory / "transformer_multi_error_predictor_large.onnx" + torch.onnx.export( + model_cpu, + dummy_input, + onnx_file, + export_params=True, + opset_version=15, + input_names=["combined_input"], + output_names=["error_correction"], + dynamic_axes={ + "combined_input": {0: "batch_size"}, + "error_correction": {0: "batch_size"}, + }, + ) + + mlflow.log_artifact(str(checkpoint_file)) + mlflow.log_artifact(str(onnx_file)) + mlflow.pytorch.log_model(model, "model") + + LOGGER.info("Transformer error predictor ONNX model saved to: %s", onnx_file) + LOGGER.info("Artifacts logged to MLflow") + + LOGGER.info("=" * 80) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO, format="%(message)s", force=True) + + main() diff --git a/learning_munsell/training/from_xyY/train_unified_mlp.py b/learning_munsell/training/from_xyY/train_unified_mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..31a6fc1a71d3b2e0d814ddb4474f9257048440d1 --- /dev/null +++ b/learning_munsell/training/from_xyY/train_unified_mlp.py @@ -0,0 +1,352 @@ +""" +Train wider unified MLP model for xyY to Munsell conversion. + +Single end-to-end model with wider MLP architecture: +- Input: 3 features (xyY) +- Architecture: 3 → 512 → 1024 + 8 residual blocks → 512 → 4 +- Output: 4 features (hue, value, chroma, code) +- Uses GELU activations, batch normalization, and precision-focused loss +- Much wider than previous version to match two-stage model capacity +""" + +import logging +import click +import mlflow +import mlflow.pytorch +import numpy as np +import torch +from numpy.typing import NDArray +from torch import nn, optim +from torch.utils.data import DataLoader, TensorDataset + +from learning_munsell import PROJECT_ROOT +from learning_munsell.models.networks import ResidualBlock +from learning_munsell.utilities.common import log_training_epoch, setup_mlflow_experiment +from learning_munsell.utilities.data import ( + MUNSELL_NORMALIZATION_PARAMS, + XYY_NORMALIZATION_PARAMS, + normalize_munsell, +) +from learning_munsell.utilities.losses import precision_focused_loss +from learning_munsell.utilities.training import train_epoch, validate + +LOGGER = logging.getLogger(__name__) + + +class UnifiedMLP(nn.Module): + """ + Wider Unified MLP for xyY to Munsell conversion. + + Architecture: + - Input: 3 features (xyY normalized) + - Encoder: 3 → 512 → 1024 + - Residual blocks at 1024-dim (4 blocks) + - Decoder: 1024 → 512 → 4 + - Uses GELU activations and residual connections + + Parameters + ---------- + num_residual_blocks : int, optional + Number of residual blocks at 1024-dim. Default is 4. + + Attributes + ---------- + encoder : nn.Sequential + Input encoder: 3 → 512 → 1024. + residual_blocks : nn.ModuleList + Stack of residual blocks at 1024-dim. + decoder : nn.Sequential + Output decoder: 1024 → 512 → 4. + + Notes + ----- + Single end-to-end model with wide MLP architecture designed to + match two-stage model capacity. Uses GELU activation for smoother + gradients and batch normalization for training stability. + """ + + def __init__(self, num_residual_blocks: int = 8) -> None: + """Initialize the unified MLP model.""" + super().__init__() + + # Wider encoder + self.encoder = nn.Sequential( + nn.Linear(3, 512), + nn.GELU(), + nn.BatchNorm1d(512), + nn.Linear(512, 1024), + nn.GELU(), + nn.BatchNorm1d(1024), + ) + + # More residual blocks at wider dimension + self.residual_blocks = nn.ModuleList( + [ResidualBlock(1024) for _ in range(num_residual_blocks)] + ) + + # Wider decoder + self.decoder = nn.Sequential( + nn.Linear(1024, 512), + nn.GELU(), + nn.BatchNorm1d(512), + nn.Linear(512, 4), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass through encoder, residual blocks, and decoder. + + Parameters + ---------- + x : Tensor + Input tensor of shape (batch_size, 3) containing normalized xyY values. + + Returns + ------- + Tensor + Output tensor of shape (batch_size, 4) containing normalized Munsell + specifications [hue, value, chroma, code]. + + Notes + ----- + The forward pass consists of three stages: + 1. Encoder: Expands from 3 to 1024 dimensions + 2. Residual blocks: Process at 1024-dim with skip connections + 3. Decoder: Contracts from 1024 to 4 dimensions + """ + # Encode + x = self.encoder(x) + + # Residual blocks + for block in self.residual_blocks: + x = block(x) + + # Decode + return self.decoder(x) + + +@click.command() +@click.option("--epochs", default=200, help="Number of training epochs") +@click.option("--batch-size", default=1024, help="Batch size for training") +@click.option("--lr", default=3e-4, help="Learning rate") +@click.option("--patience", default=20, help="Early stopping patience") +def main(epochs: int, batch_size: int, lr: float, patience: int) -> None: + """ + Train the UnifiedMLP model for xyY to Munsell conversion. + + Notes + ----- + The training pipeline: + 1. Loads normalization parameters from existing config + 2. Loads training data from cache + 3. Normalizes inputs and outputs to [0, 1] range + 4. Creates PyTorch DataLoaders + 5. Initializes UnifiedMLP with residual blocks + 6. Trains with AdamW optimizer and precision-focused loss + 7. Uses learning rate scheduler (ReduceLROnPlateau) + 8. Implements early stopping based on validation loss + 9. Exports best model to ONNX format + 10. Logs all metrics and artifacts to MLflow + """ + + + LOGGER.info("=" * 80) + LOGGER.info("Unified MLP: xyY → Munsell (Single Model)") + LOGGER.info("=" * 80) + + # Set device + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + LOGGER.info("Using device: %s", device) + + # Paths + model_directory = PROJECT_ROOT / "models" / "from_xyY" + data_dir = PROJECT_ROOT / "data" + cache_file = data_dir / "training_data.npz" + + # Load training data + LOGGER.info("") + LOGGER.info("Loading training data from %s...", cache_file) + data = np.load(cache_file) + X_train = data["X_train"] + y_train = data["y_train"] + X_val = data["X_val"] + y_val = data["y_val"] + + LOGGER.info("Train samples: %d", len(X_train)) + LOGGER.info("Validation samples: %d", len(X_val)) + + # Normalize outputs (xyY inputs are already in [0, 1] range) + # Use hardcoded ranges covering the full Munsell space for generalization + output_params = MUNSELL_NORMALIZATION_PARAMS + y_train_norm = normalize_munsell(y_train, output_params) + y_val_norm = normalize_munsell(y_val, output_params) + + # Convert to PyTorch tensors + X_train_t = torch.FloatTensor(X_train) + y_train_t = torch.FloatTensor(y_train_norm) + X_val_t = torch.FloatTensor(X_val) + y_val_t = torch.FloatTensor(y_val_norm) + + # Create data loaders + train_dataset = TensorDataset(X_train_t, y_train_t) + val_dataset = TensorDataset(X_val_t, y_val_t) + + train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) + val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False) + + # Initialize unified MLP model + model = UnifiedMLP(num_residual_blocks=4).to(device) + LOGGER.info("") + LOGGER.info("Unified MLP architecture:") + LOGGER.info("%s", model) + + # Count parameters + total_params = sum(p.numel() for p in model.parameters()) + LOGGER.info("Total parameters: %s", f"{total_params:,}") + + # Training setup with precision-focused loss + LOGGER.info("") + LOGGER.info("Using precision-focused loss function:") + LOGGER.info(" - MSE (weight: 1.0)") + LOGGER.info(" - MAE (weight: 0.5)") + LOGGER.info(" - Log penalty for small errors (weight: 0.3)") + LOGGER.info(" - Huber loss with delta=0.01 (weight: 0.5)") + + optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5) + scheduler = optim.lr_scheduler.ReduceLROnPlateau( + optimizer, mode="min", factor=0.5, patience=5 + ) + criterion = precision_focused_loss + + # MLflow setup + run_name = setup_mlflow_experiment("from_xyY", "unified_mlp") + + LOGGER.info("") + LOGGER.info("MLflow run: %s", run_name) + + # Training loop + best_val_loss = float("inf") + patience_counter = 0 + + LOGGER.info("") + LOGGER.info("Starting training...") + + with mlflow.start_run(run_name=run_name): + mlflow.log_params( + { + "model": "unified_mlp", + "learning_rate": lr, + "batch_size": batch_size, + "num_epochs": epochs, + "patience": patience, + "total_params": total_params, + } + ) + + for epoch in range(epochs): + train_loss = train_epoch(model, train_loader, optimizer, criterion, device) + val_loss = validate(model, val_loader, criterion, device) + + # Update learning rate + scheduler.step(val_loss) + + log_training_epoch( + epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"] + ) + + LOGGER.info( + "Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f, LR: %.6f", + epoch + 1, + epochs, + train_loss, + val_loss, + optimizer.param_groups[0]["lr"], + ) + + # Early stopping + if val_loss < best_val_loss: + best_val_loss = val_loss + patience_counter = 0 + + # Save best model + model_directory.mkdir(exist_ok=True) + checkpoint_file = model_directory / "unified_mlp_best.pth" + + torch.save( + { + "model_state_dict": model.state_dict(), + "epoch": epoch, + "val_loss": val_loss, + }, + checkpoint_file, + ) + + LOGGER.info(" → Saved best model (val_loss: %.6f)", val_loss) + else: + patience_counter += 1 + if patience_counter >= patience: + LOGGER.info("") + LOGGER.info("Early stopping after %d epochs", epoch + 1) + break + + mlflow.log_metrics( + { + "best_val_loss": best_val_loss, + "final_epoch": epoch + 1, + } + ) + + # Export to ONNX + LOGGER.info("") + LOGGER.info("Exporting unified MLP to ONNX...") + model.eval() + + # Load best model + checkpoint = torch.load(checkpoint_file) + model.load_state_dict(checkpoint["model_state_dict"]) + + # Create dummy input (xyY = 3 inputs) + dummy_input = torch.randn(1, 3).to(device) + + # Export + onnx_file = model_directory / "unified_mlp.onnx" + torch.onnx.export( + model, + dummy_input, + onnx_file, + export_params=True, + opset_version=15, + input_names=["xyY"], + output_names=["munsell_spec"], + dynamic_axes={ + "xyY": {0: "batch_size"}, + "munsell_spec": {0: "batch_size"}, + }, + ) + + # Save normalization parameters alongside model + params_file = model_directory / "unified_mlp_normalization_params.npz" + input_params = XYY_NORMALIZATION_PARAMS + np.savez( + params_file, + input_params=input_params, + output_params=output_params, + ) + + mlflow.log_artifact(str(checkpoint_file)) + mlflow.log_artifact(str(onnx_file)) + mlflow.log_artifact(str(params_file)) + mlflow.pytorch.log_model(model, "model") + + LOGGER.info("Unified MLP ONNX model saved to: %s", onnx_file) + LOGGER.info("Normalization parameters saved to: %s", params_file) + LOGGER.info("Artifacts logged to MLflow") + + + LOGGER.info("=" * 80) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO, format="%(message)s", force=True) + + main() diff --git a/learning_munsell/training/to_xyY/__init__.py b/learning_munsell/training/to_xyY/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0eef83d6af9400dec7daa981332a5c97a3fdbd04 --- /dev/null +++ b/learning_munsell/training/to_xyY/__init__.py @@ -0,0 +1 @@ +"""Training scripts for Munsell to xyY conversion.""" diff --git a/learning_munsell/training/to_xyY/hyperparameter_search_multi_error_predictor.py b/learning_munsell/training/to_xyY/hyperparameter_search_multi_error_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..be33d0763b509a4d9c941a4c53ec5db5c9547256 --- /dev/null +++ b/learning_munsell/training/to_xyY/hyperparameter_search_multi_error_predictor.py @@ -0,0 +1,414 @@ +""" +Hyperparameter search for Multi-Error Predictor using Optuna. + +Optimizes: +- Learning rate +- Batch size +- Width multiplier + +Objective: Minimize final prediction MAE (base + error correction) +""" + +from __future__ import annotations + +import logging +from datetime import datetime + +import matplotlib.pyplot as plt +import mlflow +import numpy as np +import optuna +import torch +from optuna.trial import Trial +from torch import nn, optim +from torch.utils.data import DataLoader, TensorDataset + +from learning_munsell import PROJECT_ROOT +from learning_munsell.models.networks import MultiMLPErrorPredictorToxyY, MultiMLPToxyY +from learning_munsell.utilities.common import setup_mlflow_experiment +from learning_munsell.utilities.data import load_training_data, normalize_munsell_simple + +LOGGER = logging.getLogger(__name__) + + +def objective(trial: Trial) -> float: + """ + Optuna objective function to minimize final prediction MAE. + + This function trains a multi-error predictor model with hyperparameters + suggested by Optuna, evaluates it, and returns the final prediction MAE + (combining base model predictions with error corrections). The search space + includes learning rate, batch size, and width multiplier. + + Hyperparameter Search Space + ---------------------------- + - lr : float + Learning rate in log scale from 1e-4 to 1e-3. + - batch_size : {256, 512, 1024} + Batch size for training and validation. + - width_multiplier : float + Network width scaling from 0.75 to 1.5 in steps of 0.25. + + Training Configuration + ---------------------- + - Optimizer: AdamW with weight_decay=1e-4 + - Scheduler: CosineAnnealingLR with T_max=100 + - Loss function: MSE loss on predicted errors + - Early stopping: patience=15 epochs based on average MAE + - Pruning: MedianPruner to stop unpromising trials + + Parameters + ---------- + trial : Trial + Optuna trial object that suggests hyperparameters. + + Returns + ------- + float + Best average MAE across all xyY components after adding error + corrections to base model predictions. Lower is better. + """ + + # Suggest hyperparameters + lr = trial.suggest_float("lr", 1e-4, 1e-3, log=True) + batch_size = trial.suggest_categorical("batch_size", [256, 512, 1024]) + width_multiplier = trial.suggest_float("width_multiplier", 0.75, 1.5, step=0.25) + + LOGGER.info("") + LOGGER.info("=" * 80) + LOGGER.info("Trial %d", trial.number) + LOGGER.info("=" * 80) + LOGGER.info(" lr: %.6f", lr) + LOGGER.info(" batch_size: %d", batch_size) + LOGGER.info(" width_multiplier: %.2f", width_multiplier) + + # Set device + device = torch.device("mps" if torch.backends.mps.is_available() else "cpu") + LOGGER.info(" device: %s", device) + + # Load base model (using original multi_mlp for hyperparameter search) + models_dir = PROJECT_ROOT / "models" / "to_xyY" + base_checkpoint = models_dir / "multi_mlp.pth" + + if not base_checkpoint.exists(): + LOGGER.error("Base model not found: %s", base_checkpoint) + raise FileNotFoundError(f"Base model not found: {base_checkpoint}") + + checkpoint = torch.load(base_checkpoint, weights_only=True) + # Original multi_mlp.pth has width_multiplier=1.0 for all branches + base_model = MultiMLPToxyY(width_multiplier=1.0, y_width_multiplier=1.0).to(device) + base_model.load_state_dict(checkpoint["model_state_dict"]) + base_model.eval() + + for param in base_model.parameters(): + param.requires_grad = False + + # Load data + munsell_specs, xyY_values = load_training_data(direction="to_xyY") + LOGGER.info("Loaded %d valid samples", len(munsell_specs)) + munsell_normalized = normalize_munsell_simple(munsell_specs) + + # Split train/val (same split as base model) + n_samples = len(munsell_specs) + np.random.seed(42) + indices = np.random.permutation(n_samples) + train_idx = indices[: int(0.9 * n_samples)] + val_idx = indices[int(0.9 * n_samples) :] + + X_train = torch.from_numpy(munsell_normalized[train_idx]).float() + y_train = torch.from_numpy(xyY_values[train_idx]).float() + X_val = torch.from_numpy(munsell_normalized[val_idx]).float() + y_val = torch.from_numpy(xyY_values[val_idx]).float() + + # Compute base model predictions and errors + with torch.no_grad(): + base_pred_train = base_model(X_train.to(device)).cpu() + base_pred_val = base_model(X_val.to(device)).cpu() + + errors_train = y_train - base_pred_train + errors_val = y_val - base_pred_val + + # Create combined inputs: [munsell_norm(4) + base_pred(3)] = 7 features + combined_train = torch.cat([X_train, base_pred_train], dim=1) + combined_val = torch.cat([X_val, base_pred_val], dim=1) + + train_loader = DataLoader( + TensorDataset(combined_train, errors_train), + batch_size=batch_size, + shuffle=True, + ) + val_loader = DataLoader( + TensorDataset(combined_val, errors_val), + batch_size=batch_size, + shuffle=False, + ) + + LOGGER.info( + " Training samples: %d, Validation samples: %d", len(X_train), len(X_val) + ) + + # Initialize error predictor model + error_model = MultiMLPErrorPredictorToxyY( + width_multiplier=width_multiplier + ).to(device) + + # Count parameters + total_params = sum(p.numel() for p in error_model.parameters()) + LOGGER.info(" Total parameters: %s", f"{total_params:,}") + + # Training setup + optimizer = optim.AdamW(error_model.parameters(), lr=lr, weight_decay=1e-4) + scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100) + + # MLflow setup + run_name = setup_mlflow_experiment( + "to_xyY", f"hparam_multi_error_trial_{trial.number}" + ) + + # Training loop with early stopping + num_epochs = 100 # Reduced for hyperparameter search + patience = 15 + best_mae = float("inf") + patience_counter = 0 + + with mlflow.start_run(run_name=run_name): + mlflow.log_params( + { + "lr": lr, + "batch_size": batch_size, + "width_multiplier": width_multiplier, + "total_params": total_params, + "trial_number": trial.number, + } + ) + + for epoch in range(num_epochs): + # Training + error_model.train() + train_loss = 0.0 + for combined_batch, errors_batch in train_loader: + combined_batch = combined_batch.to(device) + errors_batch = errors_batch.to(device) + optimizer.zero_grad() + pred_errors = error_model(combined_batch) + loss = nn.functional.mse_loss(pred_errors, errors_batch) + loss.backward() + optimizer.step() + train_loss += loss.item() * len(combined_batch) + + train_loss /= len(combined_train) + scheduler.step() + + # Validation + error_model.eval() + val_loss = 0.0 + with torch.no_grad(): + for combined_batch, errors_batch in val_loader: + combined_batch = combined_batch.to(device) + errors_batch = errors_batch.to(device) + pred_errors = error_model(combined_batch) + val_loss += nn.functional.mse_loss( + pred_errors, errors_batch + ).item() * len(combined_batch) + val_loss /= len(combined_val) + + # Compute final prediction MAE (base + error correction) + with torch.no_grad(): + pred_errors = error_model(combined_val.to(device)) + final_pred = base_pred_val.to(device) + pred_errors + mae = torch.mean(torch.abs(final_pred - y_val.to(device)), dim=0).cpu() + avg_mae = mae.mean().item() + + # Log to MLflow + mlflow.log_metrics( + { + "train_loss": train_loss, + "val_loss": val_loss, + "final_mae_x": mae[0].item(), + "final_mae_y": mae[1].item(), + "final_mae_Y": mae[2].item(), + "final_mae_avg": avg_mae, + "learning_rate": optimizer.param_groups[0]["lr"], + }, + step=epoch, + ) + + if (epoch + 1) % 10 == 0: + LOGGER.info( + " Epoch %03d/%d - Train: %.6f, Val: %.6f - " + "Final MAE: x=%.6f, y=%.6f, Y=%.6f, avg=%.6f", + epoch + 1, + num_epochs, + train_loss, + val_loss, + mae[0], + mae[1], + mae[2], + avg_mae, + ) + + # Early stopping based on average MAE + if avg_mae < best_mae: + best_mae = avg_mae + patience_counter = 0 + else: + patience_counter += 1 + if patience_counter >= patience: + LOGGER.info(" Early stopping at epoch %d", epoch + 1) + break + + # Report intermediate value for pruning + trial.report(avg_mae, epoch) + + # Handle pruning + if trial.should_prune(): + LOGGER.info(" Trial pruned at epoch %d", epoch + 1) + mlflow.log_metrics({"pruned": 1, "pruned_epoch": epoch}) + raise optuna.TrialPruned + + # Log final results + mlflow.log_metrics( + { + "best_avg_mae": best_mae, + "final_train_loss": train_loss, + "best_mae_x": mae[0].item(), + "best_mae_y": mae[1].item(), + "best_mae_Y": mae[2].item(), + } + ) + + LOGGER.info(" Final average MAE: %.6f", best_mae) + + return best_mae + + +def main() -> None: + """ + Run hyperparameter search for the Multi-Error Predictor model. + + This function orchestrates an Optuna hyperparameter optimization study to find + the best hyperparameters for the multi-error predictor that improves upon a + base MultiMLP model for Munsell to xyY conversion. + + Study Configuration + ------------------- + - Objective: Minimize average MAE across xyY components + - Number of trials: 20 + - Pruner: MedianPruner with n_startup_trials=3, n_warmup_steps=10 + - Direction: minimize + + Outputs + ------- + - Console logs with trial progress and results + - Text file with detailed results in results/to_xyY/ + - Visualization plots: + - Optimization history showing MAE progression + - Parameter importances showing which hyperparameters matter most + - Parallel coordinate plot showing hyperparameter relationships + - MLflow tracking for each trial + """ + + LOGGER.info("=" * 80) + LOGGER.info("Multi-Error Predictor (to_xyY) Hyperparameter Search with Optuna") + LOGGER.info("=" * 80) + + # Create study + study = optuna.create_study( + direction="minimize", + study_name="multi_error_predictor_to_xyY_hparam_search", + pruner=optuna.pruners.MedianPruner(n_startup_trials=3, n_warmup_steps=10), + ) + + # Run optimization + n_trials = 20 # Number of trials to run + + LOGGER.info("") + LOGGER.info("Starting hyperparameter search with %d trials...", n_trials) + LOGGER.info("") + + study.optimize(objective, n_trials=n_trials, timeout=None) + + # Print results + LOGGER.info("") + LOGGER.info("=" * 80) + LOGGER.info("Hyperparameter Search Results") + LOGGER.info("=" * 80) + LOGGER.info("") + LOGGER.info("Best trial:") + LOGGER.info(" Value (avg_mae): %.6f", study.best_value) + LOGGER.info("") + LOGGER.info("Best hyperparameters:") + for key, value in study.best_params.items(): + LOGGER.info(" %s: %s", key, value) + + # Save results + results_dir = PROJECT_ROOT / "results" / "to_xyY" + results_dir.mkdir(exist_ok=True, parents=True) + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + results_file = results_dir / f"hparam_search_multi_error_predictor_{timestamp}.txt" + + with open(results_file, "w") as f: + f.write("=" * 80 + "\n") + f.write("Multi-Error Predictor (to_xyY) Hyperparameter Search Results\n") + f.write("=" * 80 + "\n\n") + f.write(f"Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n") + f.write(f"Number of trials: {len(study.trials)}\n") + f.write(f"Best average MAE: {study.best_value:.6f}\n\n") + f.write("Best hyperparameters:\n") + for key, value in study.best_params.items(): + f.write(f" {key}: {value}\n") + f.write("\n\nAll trials:\n") + f.write("-" * 80 + "\n") + + for t in study.trials: + f.write(f"\nTrial {t.number}:\n") + if t.value is not None: + f.write(f" Value: {t.value:.6f}\n") + else: + f.write(" Value: Pruned\n") + f.write(" Params:\n") + for key, value in t.params.items(): + f.write(f" {key}: {value}\n") + + LOGGER.info("") + LOGGER.info("Results saved to: %s", results_file) + + # Generate visualizations using matplotlib + from optuna.visualization.matplotlib import ( + plot_optimization_history, + plot_param_importances, + plot_parallel_coordinate, + ) + + # Optimization history + ax = plot_optimization_history(study) + ax.figure.savefig( + results_dir / f"optimization_history_multi_error_predictor_{timestamp}.png", + dpi=150, + ) + plt.close(ax.figure) + + # Parameter importances + ax = plot_param_importances(study) + ax.figure.savefig( + results_dir / f"param_importances_multi_error_predictor_{timestamp}.png", + dpi=150, + ) + plt.close(ax.figure) + + # Parallel coordinate plot + ax = plot_parallel_coordinate(study) + ax.figure.savefig( + results_dir / f"parallel_coordinate_multi_error_predictor_{timestamp}.png", + dpi=150, + ) + plt.close(ax.figure) + + LOGGER.info("Visualizations saved to: %s", results_dir) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO, format="%(message)s", force=True) + + main() diff --git a/learning_munsell/training/to_xyY/hyperparameter_search_multi_head.py b/learning_munsell/training/to_xyY/hyperparameter_search_multi_head.py new file mode 100644 index 0000000000000000000000000000000000000000..69d79e3b5eece3232213895e57c9254715558b6b --- /dev/null +++ b/learning_munsell/training/to_xyY/hyperparameter_search_multi_head.py @@ -0,0 +1,523 @@ +""" +Hyperparameter search for Multi-Head model (Munsell to xyY) using Optuna. + +Optimizes: +- Learning rate +- Batch size +- Encoder width multiplier (shared encoder capacity) +- Head width multiplier (component-specific head capacity) +- Dropout +- Weight decay + +Objective: Minimize validation loss +""" + +from __future__ import annotations + +import logging +from datetime import datetime + +import matplotlib.pyplot as plt +import mlflow +import numpy as np +import optuna +import torch +from optuna.trial import Trial +from torch import nn, optim +from torch.utils.data import DataLoader, TensorDataset + +from learning_munsell import PROJECT_ROOT +from learning_munsell.utilities.common import setup_mlflow_experiment +from learning_munsell.utilities.data import load_training_data, normalize_munsell_simple +from learning_munsell.utilities.losses import weighted_mse_loss +from learning_munsell.utilities.training import train_epoch, validate + +LOGGER = logging.getLogger(__name__) + + +class MultiHeadParametric(nn.Module): + """ + Parametric Multi-Head MLP for Munsell to xyY hyperparameter search. + + This architecture uses a shared encoder to learn general color space features + from Munsell specifications, followed by component-specific prediction heads + for each xyY coordinate. The encoder and heads can be independently scaled + via width multipliers. + + Architecture + ------------ + - Shared encoder: 4 -> h1 -> h2 -> h3 (scaled by encoder_width) + - Learns general features from Munsell (hue, value, chroma, code) + - Uses ReLU activation with BatchNorm and optional Dropout + - Component-specific heads: h3 -> h2' -> h1' -> 1 (scaled by head_width) + - Separate heads for x, y, Y predictions + - Each head has its own learned transformation + + Attributes + ---------- + encoder : nn.Sequential + Shared feature extraction network. + x_head : nn.Sequential + Prediction head for x chromaticity component. + y_head : nn.Sequential + Prediction head for y chromaticity component. + Y_head : nn.Sequential + Prediction head for Y luminance component. + + Parameters + ---------- + encoder_width : float, optional + Scaling factor for encoder hidden dimensions (default: 1.0). + Base dimensions: e_h1=128, e_h2=256, e_h3=512. + head_width : float, optional + Scaling factor for head hidden dimensions (default: 1.0). + Base dimensions: h_h1=128, h_h2=256. + dropout : float, optional + Dropout probability applied after each layer (default: 0.0). + If 0, no dropout layers are added. + """ + + def __init__( + self, + encoder_width: float = 1.0, + head_width: float = 1.0, + dropout: float = 0.0, + ) -> None: + super().__init__() + + # Encoder dimensions (shared) + e_h1 = int(128 * encoder_width) + e_h2 = int(256 * encoder_width) + e_h3 = int(512 * encoder_width) + + # Head dimensions (component-specific) + h_h1 = int(128 * head_width) + h_h2 = int(256 * head_width) + + # Shared encoder - learns general color space features + encoder_layers = [ + nn.Linear(4, e_h1), + nn.ReLU(), + nn.BatchNorm1d(e_h1), + ] + + if dropout > 0: + encoder_layers.append(nn.Dropout(dropout)) + + encoder_layers.extend( + [ + nn.Linear(e_h1, e_h2), + nn.ReLU(), + nn.BatchNorm1d(e_h2), + ] + ) + + if dropout > 0: + encoder_layers.append(nn.Dropout(dropout)) + + encoder_layers.extend( + [ + nn.Linear(e_h2, e_h3), + nn.ReLU(), + nn.BatchNorm1d(e_h3), + ] + ) + + if dropout > 0: + encoder_layers.append(nn.Dropout(dropout)) + + self.encoder = nn.Sequential(*encoder_layers) + + # Component-specific heads + def create_head() -> nn.Sequential: + head_layers = [ + nn.Linear(e_h3, h_h2), + nn.ReLU(), + nn.BatchNorm1d(h_h2), + ] + + if dropout > 0: + head_layers.append(nn.Dropout(dropout)) + + head_layers.extend( + [ + nn.Linear(h_h2, h_h1), + nn.ReLU(), + nn.BatchNorm1d(h_h1), + ] + ) + + if dropout > 0: + head_layers.append(nn.Dropout(dropout)) + + head_layers.append(nn.Linear(h_h1, 1)) + + return nn.Sequential(*head_layers) + + self.x_head = create_head() + self.y_head = create_head() + self.Y_head = create_head() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass through multi-head network. + + Parameters + ---------- + x : torch.Tensor + Normalized Munsell input of shape (batch_size, 4). + Expected input features: [hue, value, chroma, code]. + + Returns + ------- + torch.Tensor + Predicted xyY values of shape (batch_size, 3). + Columns represent [x, y, Y] chromaticity and luminance. + """ + # Shared feature extraction + features = self.encoder(x) + + # Component-specific predictions + x_coord = self.x_head(features) + y_coord = self.y_head(features) + Y_lum = self.Y_head(features) + + # Concatenate: [x, y, Y] + return torch.cat([x_coord, y_coord, Y_lum], dim=1) + + +def objective(trial: Trial) -> float: + """ + Optuna objective function to minimize validation loss for Multi-Head model. + + This function trains a multi-head model with hyperparameters suggested by + Optuna, evaluates it, and returns the validation loss. The search space + includes learning rate, batch size, encoder/head width multipliers, dropout, + and weight decay. + + Hyperparameter Search Space + ---------------------------- + - lr : float + Learning rate in log scale from 1e-4 to 1e-3. + - batch_size : {256, 512, 1024} + Batch size for training and validation. + - encoder_width : float + Encoder width scaling from 0.75 to 1.5 in steps of 0.25. + Controls capacity of shared feature extractor. + - head_width : float + Head width scaling from 0.75 to 1.5 in steps of 0.25. + Controls capacity of component-specific prediction heads. + - dropout : float + Dropout rate from 0.0 to 0.2 in steps of 0.05. + - weight_decay : float + L2 regularization in log scale from 1e-5 to 1e-3. + + Training Configuration + ---------------------- + - Optimizer: AdamW with configurable weight_decay + - Scheduler: CosineAnnealingLR with T_max=100 + - Loss function: Weighted MSE loss (equal weights for x, y, Y) + - Early stopping: patience=15 epochs based on validation loss + - Pruning: MedianPruner to stop unpromising trials + + Parameters + ---------- + trial : Trial + Optuna trial object that suggests hyperparameters. + + Returns + ------- + float + Best validation loss achieved during training. Lower is better. + """ + + # Suggest hyperparameters + lr = trial.suggest_float("lr", 1e-4, 1e-3, log=True) + batch_size = trial.suggest_categorical("batch_size", [256, 512, 1024]) + encoder_width = trial.suggest_float("encoder_width", 0.75, 1.5, step=0.25) + head_width = trial.suggest_float("head_width", 0.75, 1.5, step=0.25) + dropout = trial.suggest_float("dropout", 0.0, 0.2, step=0.05) + weight_decay = trial.suggest_float("weight_decay", 1e-5, 1e-3, log=True) + + LOGGER.info("") + LOGGER.info("=" * 80) + LOGGER.info("Trial %d", trial.number) + LOGGER.info("=" * 80) + LOGGER.info(" lr: %.6f", lr) + LOGGER.info(" batch_size: %d", batch_size) + LOGGER.info(" encoder_width: %.2f", encoder_width) + LOGGER.info(" head_width: %.2f", head_width) + LOGGER.info(" dropout: %.2f", dropout) + LOGGER.info(" weight_decay: %.6f", weight_decay) + + # Set device + device = torch.device("mps" if torch.backends.mps.is_available() else "cpu") + LOGGER.info(" device: %s", device) + + # Load data + munsell_specs, xyY_values = load_training_data(direction="to_xyY") + LOGGER.info("Loaded %d valid samples", len(munsell_specs)) + munsell_normalized = normalize_munsell_simple(munsell_specs) + + # Split train/val + n_samples = len(munsell_specs) + np.random.seed(42) + indices = np.random.permutation(n_samples) + train_idx = indices[: int(0.9 * n_samples)] + val_idx = indices[int(0.9 * n_samples) :] + + X_train = torch.from_numpy(munsell_normalized[train_idx]).float() + y_train = torch.from_numpy(xyY_values[train_idx]).float() + X_val = torch.from_numpy(munsell_normalized[val_idx]).float() + y_val = torch.from_numpy(xyY_values[val_idx]).float() + + train_loader = DataLoader( + TensorDataset(X_train, y_train), batch_size=batch_size, shuffle=True + ) + val_loader = DataLoader( + TensorDataset(X_val, y_val), batch_size=batch_size, shuffle=False + ) + + LOGGER.info( + " Training samples: %d, Validation samples: %d", len(X_train), len(X_val) + ) + + # Initialize model + model = MultiHeadParametric( + encoder_width=encoder_width, + head_width=head_width, + dropout=dropout, + ).to(device) + + # Count parameters + total_params = sum(p.numel() for p in model.parameters()) + LOGGER.info(" Total parameters: %s", f"{total_params:,}") + + # Training setup + optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay) + scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100) + + # MLflow setup + run_name = setup_mlflow_experiment( + "to_xyY", f"hparam_multi_head_trial_{trial.number}" + ) + + # Training loop with early stopping + num_epochs = 100 # Reduced for hyperparameter search + patience = 15 + best_val_loss = float("inf") + patience_counter = 0 + + with mlflow.start_run(run_name=run_name): + mlflow.log_params( + { + "lr": lr, + "batch_size": batch_size, + "encoder_width": encoder_width, + "head_width": head_width, + "dropout": dropout, + "weight_decay": weight_decay, + "total_params": total_params, + "trial_number": trial.number, + } + ) + + for epoch in range(num_epochs): + train_loss = train_epoch(model, train_loader, optimizer, weighted_mse_loss, device) + val_loss = validate(model, val_loader, weighted_mse_loss, device) + scheduler.step() + + # Per-component MAE + with torch.no_grad(): + pred_val = model(X_val.to(device)) + mae = torch.mean(torch.abs(pred_val - y_val.to(device)), dim=0).cpu() + + # Log to MLflow + mlflow.log_metrics( + { + "train_loss": train_loss, + "val_loss": val_loss, + "mae_x": mae[0].item(), + "mae_y": mae[1].item(), + "mae_Y": mae[2].item(), + "learning_rate": optimizer.param_groups[0]["lr"], + }, + step=epoch, + ) + + if (epoch + 1) % 10 == 0: + LOGGER.info( + " Epoch %03d/%d - Train: %.6f, Val: %.6f - " + "MAE: x=%.6f, y=%.6f, Y=%.6f", + epoch + 1, + num_epochs, + train_loss, + val_loss, + mae[0], + mae[1], + mae[2], + ) + + # Early stopping + if val_loss < best_val_loss: + best_val_loss = val_loss + patience_counter = 0 + else: + patience_counter += 1 + if patience_counter >= patience: + LOGGER.info(" Early stopping at epoch %d", epoch + 1) + break + + # Report intermediate value for pruning + trial.report(val_loss, epoch) + + # Handle pruning + if trial.should_prune(): + LOGGER.info(" Trial pruned at epoch %d", epoch + 1) + mlflow.log_metrics({"pruned": 1, "pruned_epoch": epoch}) + raise optuna.TrialPruned + + # Log final results + mlflow.log_metrics( + { + "best_val_loss": best_val_loss, + "final_train_loss": train_loss, + "final_mae_x": mae[0].item(), + "final_mae_y": mae[1].item(), + "final_mae_Y": mae[2].item(), + } + ) + + LOGGER.info(" Final validation loss: %.6f", best_val_loss) + + return best_val_loss + + +def main() -> None: + """ + Run hyperparameter search for the Multi-Head model. + + This function orchestrates an Optuna hyperparameter optimization study to find + the best hyperparameters for the multi-head architecture for Munsell to xyY + conversion. The multi-head approach uses a shared encoder with separate + prediction heads for each xyY component. + + Study Configuration + ------------------- + - Objective: Minimize validation loss + - Number of trials: 20 + - Pruner: MedianPruner with n_startup_trials=3, n_warmup_steps=10 + - Direction: minimize + + Outputs + ------- + - Console logs with trial progress and results + - Text file with detailed results in results/to_xyY/ + - Visualization plots: + - Optimization history showing loss progression + - Parameter importances showing which hyperparameters matter most + - Parallel coordinate plot showing hyperparameter relationships + - MLflow tracking for each trial + """ + + LOGGER.info("=" * 80) + LOGGER.info("Multi-Head (to_xyY) Hyperparameter Search with Optuna") + LOGGER.info("=" * 80) + + # Create study + study = optuna.create_study( + direction="minimize", + study_name="multi_head_to_xyY_hparam_search", + pruner=optuna.pruners.MedianPruner(n_startup_trials=3, n_warmup_steps=10), + ) + + # Run optimization + n_trials = 20 # Number of trials to run + + LOGGER.info("") + LOGGER.info("Starting hyperparameter search with %d trials...", n_trials) + LOGGER.info("") + + study.optimize(objective, n_trials=n_trials, timeout=None) + + # Print results + LOGGER.info("") + LOGGER.info("=" * 80) + LOGGER.info("Hyperparameter Search Results") + LOGGER.info("=" * 80) + LOGGER.info("") + LOGGER.info("Best trial:") + LOGGER.info(" Value (val_loss): %.6f", study.best_value) + LOGGER.info("") + LOGGER.info("Best hyperparameters:") + for key, value in study.best_params.items(): + LOGGER.info(" %s: %s", key, value) + + # Save results + results_dir = PROJECT_ROOT / "results" / "to_xyY" + results_dir.mkdir(exist_ok=True, parents=True) + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + results_file = results_dir / f"hparam_search_multi_head_{timestamp}.txt" + + with open(results_file, "w") as f: + f.write("=" * 80 + "\n") + f.write("Multi-Head (to_xyY) Hyperparameter Search Results\n") + f.write("=" * 80 + "\n\n") + f.write(f"Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n") + f.write(f"Number of trials: {len(study.trials)}\n") + f.write(f"Best validation loss: {study.best_value:.6f}\n\n") + f.write("Best hyperparameters:\n") + for key, value in study.best_params.items(): + f.write(f" {key}: {value}\n") + f.write("\n\nAll trials:\n") + f.write("-" * 80 + "\n") + + for t in study.trials: + f.write(f"\nTrial {t.number}:\n") + if t.value is not None: + f.write(f" Value: {t.value:.6f}\n") + else: + f.write(" Value: Pruned\n") + f.write(" Params:\n") + for key, value in t.params.items(): + f.write(f" {key}: {value}\n") + + LOGGER.info("") + LOGGER.info("Results saved to: %s", results_file) + + # Generate visualizations using matplotlib + from optuna.visualization.matplotlib import ( + plot_optimization_history, + plot_param_importances, + plot_parallel_coordinate, + ) + + # Optimization history + ax = plot_optimization_history(study) + ax.figure.savefig( + results_dir / f"optimization_history_multi_head_{timestamp}.png", dpi=150 + ) + plt.close(ax.figure) + + # Parameter importances + ax = plot_param_importances(study) + ax.figure.savefig( + results_dir / f"param_importances_multi_head_{timestamp}.png", dpi=150 + ) + plt.close(ax.figure) + + # Parallel coordinate plot + ax = plot_parallel_coordinate(study) + ax.figure.savefig( + results_dir / f"parallel_coordinate_multi_head_{timestamp}.png", dpi=150 + ) + plt.close(ax.figure) + + LOGGER.info("Visualizations saved to: %s", results_dir) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO, format="%(message)s", force=True) + + main() diff --git a/learning_munsell/training/to_xyY/hyperparameter_search_multi_mlp.py b/learning_munsell/training/to_xyY/hyperparameter_search_multi_mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..4beebc09d91de29fea1081e565edaf0ad06cc0e6 --- /dev/null +++ b/learning_munsell/training/to_xyY/hyperparameter_search_multi_mlp.py @@ -0,0 +1,370 @@ +""" +Hyperparameter search for Multi-MLP model (Munsell to xyY) using Optuna. + +Optimizes: +- Learning rate +- Batch size +- Width multiplier (network capacity) +- Y branch width multiplier (luminance specialization) +- Dropout +- Weight decay + +Objective: Minimize validation loss +""" + +from __future__ import annotations + +import logging +from datetime import datetime + +import matplotlib.pyplot as plt +import mlflow +import numpy as np +import optuna +import torch +from optuna.trial import Trial +from torch import nn, optim +from torch.utils.data import DataLoader, TensorDataset + +from learning_munsell import PROJECT_ROOT +from learning_munsell.models.networks import ComponentMLP, MultiMLPToxyY +from learning_munsell.utilities.common import setup_mlflow_experiment +from learning_munsell.utilities.data import load_training_data, normalize_munsell_simple +from learning_munsell.utilities.training import train_epoch, validate + +LOGGER = logging.getLogger(__name__) + + +def objective(trial: Trial) -> float: + """ + Optuna objective function to minimize validation loss for Multi-MLP model. + + This function trains a multi-MLP model with hyperparameters suggested by + Optuna, evaluates it, and returns the validation loss. The search space + includes learning rate, batch size, width multipliers, dropout, and weight decay. + + Hyperparameter Search Space + ---------------------------- + - lr : float + Learning rate in log scale from 1e-4 to 1e-3. + - batch_size : {256, 512, 1024} + Batch size for training and validation. + - width_multiplier : float + Width scaling for x and y branches from 0.75 to 1.5 in steps of 0.25. + Controls network capacity for chromaticity predictions. + - y_width_multiplier : float + Width scaling for Y branch from 0.75 to 1.5 in steps of 0.25. + Allows specialized capacity for luminance prediction. + - dropout : float + Dropout rate from 0.0 to 0.2 in steps of 0.05. + - weight_decay : float + L2 regularization in log scale from 1e-5 to 1e-3. + + Training Configuration + ---------------------- + - Optimizer: AdamW with configurable weight_decay + - Scheduler: CosineAnnealingLR with T_max=100 + - Loss function: MSE loss + - Early stopping: patience=15 epochs based on validation loss + - Pruning: MedianPruner to stop unpromising trials + + Parameters + ---------- + trial : Trial + Optuna trial object that suggests hyperparameters. + + Returns + ------- + float + Best validation loss achieved during training. Lower is better. + """ + + # Suggest hyperparameters + lr = trial.suggest_float("lr", 1e-4, 1e-3, log=True) + batch_size = trial.suggest_categorical("batch_size", [256, 512, 1024]) + width_multiplier = trial.suggest_float("width_multiplier", 0.75, 1.5, step=0.25) + y_width_multiplier = trial.suggest_float("y_width_multiplier", 0.75, 1.5, step=0.25) + dropout = trial.suggest_float("dropout", 0.0, 0.2, step=0.05) + weight_decay = trial.suggest_float("weight_decay", 1e-5, 1e-3, log=True) + + LOGGER.info("") + LOGGER.info("=" * 80) + LOGGER.info("Trial %d", trial.number) + LOGGER.info("=" * 80) + LOGGER.info(" lr: %.6f", lr) + LOGGER.info(" batch_size: %d", batch_size) + LOGGER.info(" width_multiplier: %.2f", width_multiplier) + LOGGER.info(" y_width_multiplier: %.2f", y_width_multiplier) + LOGGER.info(" dropout: %.2f", dropout) + LOGGER.info(" weight_decay: %.6f", weight_decay) + + # Set device + device = torch.device("mps" if torch.backends.mps.is_available() else "cpu") + LOGGER.info(" device: %s", device) + + # Load data + munsell_specs, xyY_values = load_training_data(direction="to_xyY") + LOGGER.info("Loaded %d valid samples", len(munsell_specs)) + munsell_normalized = normalize_munsell_simple(munsell_specs) + + # Split train/val + n_samples = len(munsell_specs) + indices = np.random.permutation(n_samples) + train_idx = indices[: int(0.9 * n_samples)] + val_idx = indices[int(0.9 * n_samples) :] + + X_train = torch.from_numpy(munsell_normalized[train_idx]).float() + y_train = torch.from_numpy(xyY_values[train_idx]).float() + X_val = torch.from_numpy(munsell_normalized[val_idx]).float() + y_val = torch.from_numpy(xyY_values[val_idx]).float() + + train_loader = DataLoader( + TensorDataset(X_train, y_train), batch_size=batch_size, shuffle=True + ) + val_loader = DataLoader( + TensorDataset(X_val, y_val), batch_size=batch_size, shuffle=False + ) + + LOGGER.info( + " Training samples: %d, Validation samples: %d", len(X_train), len(X_val) + ) + + # Initialize model + model = MultiMLPToxyY( + width_multiplier=width_multiplier, + y_width_multiplier=y_width_multiplier, + dropout=dropout, + ).to(device) + + # Count parameters + total_params = sum(p.numel() for p in model.parameters()) + LOGGER.info(" Total parameters: %s", f"{total_params:,}") + + # Training setup + optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay) + scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100) + + # MLflow setup + run_name = setup_mlflow_experiment( + "to_xyY", f"hparam_multi_mlp_trial_{trial.number}" + ) + + # Training loop with early stopping + num_epochs = 100 # Reduced for hyperparameter search + patience = 15 + best_val_loss = float("inf") + patience_counter = 0 + + with mlflow.start_run(run_name=run_name): + mlflow.log_params( + { + "lr": lr, + "batch_size": batch_size, + "width_multiplier": width_multiplier, + "y_width_multiplier": y_width_multiplier, + "dropout": dropout, + "weight_decay": weight_decay, + "total_params": total_params, + "trial_number": trial.number, + } + ) + + for epoch in range(num_epochs): + train_loss = train_epoch(model, train_loader, optimizer, nn.functional.mse_loss, device) + val_loss = validate(model, val_loader, nn.functional.mse_loss, device) + scheduler.step() + + # Per-component MAE + with torch.no_grad(): + pred_val = model(X_val.to(device)) + mae = torch.mean(torch.abs(pred_val - y_val.to(device)), dim=0).cpu() + + # Log to MLflow + mlflow.log_metrics( + { + "train_loss": train_loss, + "val_loss": val_loss, + "mae_x": mae[0].item(), + "mae_y": mae[1].item(), + "mae_Y": mae[2].item(), + "learning_rate": optimizer.param_groups[0]["lr"], + }, + step=epoch, + ) + + if (epoch + 1) % 10 == 0: + LOGGER.info( + " Epoch %03d/%d - Train: %.6f, Val: %.6f - " + "MAE: x=%.6f, y=%.6f, Y=%.6f", + epoch + 1, + num_epochs, + train_loss, + val_loss, + mae[0], + mae[1], + mae[2], + ) + + # Early stopping + if val_loss < best_val_loss: + best_val_loss = val_loss + patience_counter = 0 + else: + patience_counter += 1 + if patience_counter >= patience: + LOGGER.info(" Early stopping at epoch %d", epoch + 1) + break + + # Report intermediate value for pruning + trial.report(val_loss, epoch) + + # Handle pruning + if trial.should_prune(): + LOGGER.info(" Trial pruned at epoch %d", epoch + 1) + mlflow.log_metrics({"pruned": 1, "pruned_epoch": epoch}) + raise optuna.TrialPruned + + # Log final results + mlflow.log_metrics( + { + "best_val_loss": best_val_loss, + "final_train_loss": train_loss, + "final_mae_x": mae[0].item(), + "final_mae_y": mae[1].item(), + "final_mae_Y": mae[2].item(), + } + ) + + LOGGER.info(" Final validation loss: %.6f", best_val_loss) + + return best_val_loss + + +def main() -> None: + """ + Run hyperparameter search for the Multi-MLP model. + + This function orchestrates an Optuna hyperparameter optimization study to find + the best hyperparameters for the multi-MLP architecture for Munsell to xyY + conversion. The multi-MLP approach uses three independent MLP branches, + one for each xyY component. + + Study Configuration + ------------------- + - Objective: Minimize validation loss + - Number of trials: 20 + - Pruner: MedianPruner with n_startup_trials=3, n_warmup_steps=10 + - Direction: minimize + + Outputs + ------- + - Console logs with trial progress and results + - Text file with detailed results in results/to_xyY/ + - Visualization plots: + - Optimization history showing loss progression + - Parameter importances showing which hyperparameters matter most + - Parallel coordinate plot showing hyperparameter relationships + - MLflow tracking for each trial + """ + + LOGGER.info("=" * 80) + LOGGER.info("Multi-MLP (to_xyY) Hyperparameter Search with Optuna") + LOGGER.info("=" * 80) + + # Create study + study = optuna.create_study( + direction="minimize", + study_name="multi_mlp_to_xyY_hparam_search", + pruner=optuna.pruners.MedianPruner(n_startup_trials=3, n_warmup_steps=10), + ) + + # Run optimization + n_trials = 20 # Number of trials to run + + LOGGER.info("") + LOGGER.info("Starting hyperparameter search with %d trials...", n_trials) + LOGGER.info("") + + study.optimize(objective, n_trials=n_trials, timeout=None) + + # Print results + LOGGER.info("") + LOGGER.info("=" * 80) + LOGGER.info("Hyperparameter Search Results") + LOGGER.info("=" * 80) + LOGGER.info("") + LOGGER.info("Best trial:") + LOGGER.info(" Value (val_loss): %.6f", study.best_value) + LOGGER.info("") + LOGGER.info("Best hyperparameters:") + for key, value in study.best_params.items(): + LOGGER.info(" %s: %s", key, value) + + # Save results + results_dir = PROJECT_ROOT / "results" / "to_xyY" + results_dir.mkdir(exist_ok=True, parents=True) + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + results_file = results_dir / f"hparam_search_multi_mlp_{timestamp}.txt" + + with open(results_file, "w") as f: + f.write("=" * 80 + "\n") + f.write("Multi-MLP (to_xyY) Hyperparameter Search Results\n") + f.write("=" * 80 + "\n\n") + f.write(f"Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n") + f.write(f"Number of trials: {len(study.trials)}\n") + f.write(f"Best validation loss: {study.best_value:.6f}\n\n") + f.write("Best hyperparameters:\n") + for key, value in study.best_params.items(): + f.write(f" {key}: {value}\n") + f.write("\n\nAll trials:\n") + f.write("-" * 80 + "\n") + + for t in study.trials: + f.write(f"\nTrial {t.number}:\n") + if t.value is not None: + f.write(f" Value: {t.value:.6f}\n") + else: + f.write(" Value: Pruned\n") + f.write(" Params:\n") + for key, value in t.params.items(): + f.write(f" {key}: {value}\n") + + LOGGER.info("") + LOGGER.info("Results saved to: %s", results_file) + + # Generate visualizations using matplotlib + from optuna.visualization.matplotlib import ( + plot_optimization_history, + plot_param_importances, + plot_parallel_coordinate, + ) + + # Optimization history + ax = plot_optimization_history(study) + ax.figure.savefig( + results_dir / f"optimization_history_multi_mlp_{timestamp}.png", dpi=150 + ) + plt.close(ax.figure) + + # Parameter importances + ax = plot_param_importances(study) + ax.figure.savefig( + results_dir / f"param_importances_multi_mlp_{timestamp}.png", dpi=150 + ) + plt.close(ax.figure) + + # Parallel coordinate plot + ax = plot_parallel_coordinate(study) + ax.figure.savefig( + results_dir / f"parallel_coordinate_multi_mlp_{timestamp}.png", dpi=150 + ) + plt.close(ax.figure) + + LOGGER.info("Visualizations saved to: %s", results_dir) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO, format="%(message)s", force=True) + + main() diff --git a/learning_munsell/training/to_xyY/train_multi_head.py b/learning_munsell/training/to_xyY/train_multi_head.py new file mode 100644 index 0000000000000000000000000000000000000000..e46ae88a056ca77a32a2dbe3d8617e8fd21fac92 --- /dev/null +++ b/learning_munsell/training/to_xyY/train_multi_head.py @@ -0,0 +1,417 @@ +""" +Train multi-head ML model for Munsell to xyY conversion. + +Architecture: +- Shared encoder: 4 inputs (Munsell) → 512-dim features +- 3 separate heads (one per component): + - x head (chromaticity coordinate) + - y head (chromaticity coordinate) + - Y head (luminance) + +This architecture allows each component to learn specialized features +while sharing the general color space understanding. +""" + +from __future__ import annotations + +import copy +import logging + +import click +import mlflow +import mlflow.pytorch +import numpy as np +import torch +from torch import nn, optim +from torch.utils.data import DataLoader, TensorDataset + +from learning_munsell import PROJECT_ROOT +from learning_munsell.utilities.common import log_training_epoch, setup_mlflow_experiment +from learning_munsell.utilities.data import load_training_data, normalize_munsell_simple +from learning_munsell.utilities.losses import weighted_mse_loss + +LOGGER = logging.getLogger(__name__) + + +class MultiHead(nn.Module): + """ + Multi-head model for Munsell to xyY conversion. + + Architecture features a shared encoder followed by component-specific + prediction heads. The shared encoder learns general color space features, + while each head specializes in predicting one component. + + Attributes + ---------- + encoder : nn.Sequential + Shared encoder network: 4 → 128 → 256 → 512 with ReLU and BatchNorm. + x_head : nn.Sequential + Prediction head for x chromaticity coordinate: 512 → 256 → 128 → 1. + y_head : nn.Sequential + Prediction head for y chromaticity coordinate: 512 → 256 → 128 → 1. + Y_head : nn.Sequential + Prediction head for Y luminance: 512 → 256 → 128 → 1. + + Notes + ----- + The multi-head architecture allows each xyY component to learn specialized + features while sharing the general color space understanding through the + encoder. + """ + + def __init__(self) -> None: + """Initialize the multi-head model.""" + super().__init__() + + # Shared encoder - learns general color space features + self.encoder = nn.Sequential( + nn.Linear(4, 128), + nn.ReLU(), + nn.BatchNorm1d(128), + nn.Linear(128, 256), + nn.ReLU(), + nn.BatchNorm1d(256), + nn.Linear(256, 512), + nn.ReLU(), + nn.BatchNorm1d(512), + ) + + # x head - chromaticity coordinate + self.x_head = nn.Sequential( + nn.Linear(512, 256), + nn.ReLU(), + nn.BatchNorm1d(256), + nn.Linear(256, 128), + nn.ReLU(), + nn.BatchNorm1d(128), + nn.Linear(128, 1), + ) + + # y head - chromaticity coordinate + self.y_head = nn.Sequential( + nn.Linear(512, 256), + nn.ReLU(), + nn.BatchNorm1d(256), + nn.Linear(256, 128), + nn.ReLU(), + nn.BatchNorm1d(128), + nn.Linear(128, 1), + ) + + # Y head - luminance + self.Y_head = nn.Sequential( + nn.Linear(512, 256), + nn.ReLU(), + nn.BatchNorm1d(256), + nn.Linear(256, 128), + nn.ReLU(), + nn.BatchNorm1d(128), + nn.Linear(128, 1), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass through shared encoder and prediction heads. + + Parameters + ---------- + x : Tensor + Normalized Munsell specification [hue, value, chroma, code] + of shape (batch_size, 4). + + Returns + ------- + Tensor + Predicted xyY values [x, y, Y] of shape (batch_size, 3). + """ + # Shared feature extraction + features = self.encoder(x) + + # Component-specific predictions + x_coord = self.x_head(features) + y_coord = self.y_head(features) + Y_lum = self.Y_head(features) + + # Concatenate: [x, y, Y] + return torch.cat([x_coord, y_coord, Y_lum], dim=1) + + +@click.command() +@click.option("--epochs", default=300, help="Maximum number of training epochs.") +@click.option("--batch-size", default=512, help="Training batch size.") +@click.option("--lr", default=5e-4, help="Initial learning rate for AdamW optimizer.") +def main( + epochs: int = 300, batch_size: int = 512, lr: float = 5e-4 +) -> tuple[MultiHead, float]: + """ + Train the Multi-Head model for Munsell to xyY conversion. + + Parameters + ---------- + epochs : int, optional + Maximum number of training epochs. Default is 300. + batch_size : int, optional + Training batch size. Default is 512. + lr : float, optional + Initial learning rate for AdamW optimizer. Default is 5e-4. + + Returns + ------- + tuple + A tuple containing: + - model : MultiHead + Trained model with best validation loss weights loaded. + - best_val_loss : float + Best validation loss achieved during training. + + Notes + ----- + The training pipeline: + 1. Loads Munsell training data from cache + 2. Normalizes Munsell specifications to [0, 1] range + 3. Splits data into train/validation sets (90/10) + 4. Trains multi-head model with shared encoder architecture + 5. Uses weighted MSE loss with equal component weights + 6. Learning rate scheduling with CosineAnnealingLR + 7. Early stopping based on validation loss (patience=30) + 8. Exports model to ONNX format + 9. Logs metrics and artifacts to MLflow + """ + + device = torch.device("mps" if torch.backends.mps.is_available() else "cpu") + LOGGER.info("Using device: %s", device) + + # Load data + munsell_specs, xyY_values = load_training_data(direction="to_xyY") + LOGGER.info("Loaded %d valid samples", len(munsell_specs)) + munsell_normalized = normalize_munsell_simple(munsell_specs) + + # Split train/val + n_samples = len(munsell_specs) + indices = np.random.permutation(n_samples) + train_idx = indices[: int(0.9 * n_samples)] + val_idx = indices[int(0.9 * n_samples) :] + + X_train = torch.from_numpy(munsell_normalized[train_idx]).float() + y_train = torch.from_numpy(xyY_values[train_idx]).float() + X_val = torch.from_numpy(munsell_normalized[val_idx]).float() + y_val = torch.from_numpy(xyY_values[val_idx]).float() + + train_loader = DataLoader( + TensorDataset(X_train, y_train), batch_size=batch_size, shuffle=True + ) + val_loader = DataLoader( + TensorDataset(X_val, y_val), batch_size=batch_size, shuffle=False + ) + + # Create model + model = MultiHead().to(device) + optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4) + scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs) + + total_params = sum(p.numel() for p in model.parameters()) + LOGGER.info("Model parameters: %s", f"{total_params:,}") + LOGGER.info( + "Training samples: %d, Validation samples: %d", len(X_train), len(X_val) + ) + + # Count parameters per component + encoder_params = sum(p.numel() for p in model.encoder.parameters()) + x_params = sum(p.numel() for p in model.x_head.parameters()) + y_params = sum(p.numel() for p in model.y_head.parameters()) + Y_params = sum(p.numel() for p in model.Y_head.parameters()) + + LOGGER.info(" - Shared encoder: %s", f"{encoder_params:,}") + LOGGER.info(" - x head: %s", f"{x_params:,}") + LOGGER.info(" - y head: %s", f"{y_params:,}") + LOGGER.info(" - Y head: %s", f"{Y_params:,}") + + # MLflow setup + run_name = setup_mlflow_experiment("to_xyY", "multi_head") + LOGGER.info("MLflow run: %s", run_name) + + best_val_loss = float("inf") + best_state = None + patience = 30 + patience_counter = 0 + + LOGGER.info("\nStarting training...") + + with mlflow.start_run(run_name=run_name): + mlflow.log_params( + { + "model": "multi_head", + "epochs": epochs, + "batch_size": batch_size, + "learning_rate": lr, + "total_params": total_params, + } + ) + + for epoch in range(epochs): + # Training + model.train() + train_loss = 0.0 + for X_batch, y_batch in train_loader: + X_batch, y_batch = X_batch.to(device), y_batch.to(device) + + optimizer.zero_grad() + pred = model(X_batch) + loss = weighted_mse_loss(pred, y_batch) + loss.backward() + optimizer.step() + train_loss += loss.item() * len(X_batch) + + train_loss /= len(X_train) + scheduler.step() + + # Validation + model.eval() + val_loss = 0.0 + with torch.no_grad(): + for X_batch, y_batch in val_loader: + X_batch, y_batch = X_batch.to(device), y_batch.to(device) + pred = model(X_batch) + val_loss += weighted_mse_loss(pred, y_batch).item() * len(X_batch) + val_loss /= len(X_val) + + # Per-component MAE + with torch.no_grad(): + pred_val = model(X_val.to(device)) + mae = torch.mean(torch.abs(pred_val - y_val.to(device)), dim=0).cpu() + + log_training_epoch( + epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"] + ) + mlflow.log_metrics( + { + "mae_x": mae[0].item(), + "mae_y": mae[1].item(), + "mae_Y": mae[2].item(), + }, + step=epoch, + ) + + if val_loss < best_val_loss: + best_val_loss = val_loss + best_state = copy.deepcopy(model.state_dict()) + patience_counter = 0 + LOGGER.info( + "Epoch %03d/%d - Train: %.6f, Val: %.6f (best) - " + "MAE: x=%.6f, y=%.6f, Y=%.6f", + epoch + 1, + epochs, + train_loss, + val_loss, + mae[0], + mae[1], + mae[2], + ) + else: + patience_counter += 1 + if (epoch + 1) % 50 == 0: + LOGGER.info( + "Epoch %03d/%d - Train: %.6f, Val: %.6f", + epoch + 1, + epochs, + train_loss, + val_loss, + ) + + if patience_counter >= patience: + LOGGER.info("Early stopping at epoch %d", epoch + 1) + break + + # Load best model + model.load_state_dict(best_state) + + # Final evaluation + model.eval() + with torch.no_grad(): + pred_val = model(X_val.to(device)).cpu() + mae_x = torch.mean(torch.abs(pred_val[:, 0] - y_val[:, 0])).item() + mae_y = torch.mean(torch.abs(pred_val[:, 1] - y_val[:, 1])).item() + mae_Y = torch.mean(torch.abs(pred_val[:, 2] - y_val[:, 2])).item() + + LOGGER.info("\nFinal Results:") + LOGGER.info(" Best Val Loss: %.6f", best_val_loss) + LOGGER.info(" MAE x: %.6f", mae_x) + LOGGER.info(" MAE y: %.6f", mae_y) + LOGGER.info(" MAE Y: %.6f", mae_Y) + + # Log final metrics + mlflow.log_metrics( + { + "best_val_loss": best_val_loss, + "final_mae_x": mae_x, + "final_mae_y": mae_y, + "final_mae_Y": mae_Y, + "final_epoch": epoch + 1, + } + ) + + # Save model + models_dir = PROJECT_ROOT / "models" / "to_xyY" + models_dir.mkdir(exist_ok=True) + + checkpoint_path = models_dir / "multi_head.pth" + torch.save( + { + "model_state_dict": model.state_dict(), + "val_loss": best_val_loss, + "mae": {"x": mae_x, "y": mae_y, "Y": mae_Y}, + }, + checkpoint_path, + ) + LOGGER.info("Saved checkpoint: %s", checkpoint_path) + + # Export to ONNX + model.cpu().eval() + dummy_input = torch.randn(1, 4) + onnx_path = models_dir / "multi_head.onnx" + torch.onnx.export( + model, + dummy_input, + onnx_path, + input_names=["munsell_normalized"], + output_names=["xyY"], + dynamic_axes={"munsell_normalized": {0: "batch"}, "xyY": {0: "batch"}}, + opset_version=17, + ) + LOGGER.info("Saved ONNX: %s", onnx_path) + + # Save normalization parameters alongside model + params_file = models_dir / "multi_head_normalization_params.npz" + input_params = { + "hue_range": (0.0, 10.0), + "value_range": (0.0, 10.0), + "chroma_range": (0.0, 50.0), + "code_range": (0.0, 10.0), + } + output_params = { + "x_range": (0.0, 1.0), + "y_range": (0.0, 1.0), + "Y_range": (0.0, 1.0), + } + np.savez( + params_file, + input_params=input_params, + output_params=output_params, + ) + LOGGER.info("Normalization parameters saved to: %s", params_file) + + # Log artifacts to MLflow + mlflow.log_artifact(str(checkpoint_path)) + mlflow.log_artifact(str(onnx_path)) + mlflow.log_artifact(str(params_file)) + mlflow.pytorch.log_model(model, "model") + LOGGER.info("Artifacts logged to MLflow") + + + return model, best_val_loss + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO, format="%(message)s", force=True) + + main() diff --git a/learning_munsell/training/to_xyY/train_multi_head_multi_error_predictor.py b/learning_munsell/training/to_xyY/train_multi_head_multi_error_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..ddd876aa99a6304d9b6ee84ce5f1c7df54b72f34 --- /dev/null +++ b/learning_munsell/training/to_xyY/train_multi_head_multi_error_predictor.py @@ -0,0 +1,433 @@ +""" +Train Multi-Head Multi-Error Predictor for Munsell to xyY conversion. + +Architecture: +- 3 independent error correction branches (one per component: x, y, Y) +- Each branch: 7 inputs (munsell_norm + base_pred) -> encoder -> decoder + -> 1 error output +- Uses GELU activation and BatchNorm for better generalization +""" + +from __future__ import annotations + +import copy +import logging + +import click +import mlflow +import mlflow.pytorch +import numpy as np +import torch +from torch import nn, optim +from torch.utils.data import DataLoader, TensorDataset + +from learning_munsell import PROJECT_ROOT +from learning_munsell.models.networks import ComponentErrorPredictor +from learning_munsell.utilities.common import log_training_epoch, setup_mlflow_experiment +from learning_munsell.utilities.data import load_training_data, normalize_munsell_simple +from learning_munsell.training.to_xyY.train_multi_head import MultiHead as MultiHeadMLP + +LOGGER = logging.getLogger(__name__) + + +class MultiHeadMultiErrorPredictor(nn.Module): + """ + Multi-Head error predictor with 3 independent branches. + + Each branch is a ComponentErrorPredictor specialized for one + xyY component. All branches receive the same full context + (munsell_norm + all base predictions). + + Attributes + ---------- + x_branch : ComponentErrorPredictor + Error predictor for x chromaticity component. + y_branch : ComponentErrorPredictor + Error predictor for y chromaticity component. + Y_branch : ComponentErrorPredictor + Error predictor for Y luminance component. + """ + + def __init__(self) -> None: + """Initialize the multi-head error predictor.""" + super().__init__() + + # Independent error predictor for each component + # Input dim is 7: [munsell_norm (4) + base_pred (3)] + self.x_branch = ComponentErrorPredictor(input_dim=7, width_multiplier=1.0) + self.y_branch = ComponentErrorPredictor(input_dim=7, width_multiplier=1.0) + self.Y_branch = ComponentErrorPredictor(input_dim=7, width_multiplier=1.0) + + def forward(self, combined_input: torch.Tensor) -> torch.Tensor: + """ + Forward pass through all error predictor branches. + + Parameters + ---------- + combined_input : Tensor + Combined input [munsell_norm, base_pred] of shape (batch_size, 7). + + Returns + ------- + Tensor + Concatenated error corrections [x, y, Y] + of shape (batch_size, 3). + """ + x_error = self.x_branch(combined_input) + y_error = self.y_branch(combined_input) + Y_error = self.Y_branch(combined_input) + + return torch.cat([x_error, y_error, Y_error], dim=1) + + +@click.command() +@click.option("--epochs", default=300, help="Maximum number of training epochs.") +@click.option("--batch-size", default=512, help="Training batch size.") +@click.option("--lr", default=8e-4, help="Initial learning rate for AdamW optimizer.") +def main( + epochs: int = 300, + batch_size: int = 512, + lr: float = 8e-4, +) -> MultiHeadMultiErrorPredictor | None: + """ + Train the Multi-Head Multi-Error Predictor for Munsell to xyY conversion. + + Parameters + ---------- + epochs : int, optional + Maximum number of training epochs. Default is 300. + batch_size : int, optional + Training batch size. Default is 512. + lr : float, optional + Initial learning rate for AdamW optimizer. Default is 8e-4. + + Returns + ------- + MultiHeadMultiErrorPredictor or None + Trained error predictor model with best validation loss weights + loaded, or None if base model is not found. + + Notes + ----- + The training pipeline: + 1. Loads pre-trained base multi-head model + 2. Generates base model predictions for training data + 3. Computes residual errors between predictions and targets + 4. Trains error predictor on these residuals + 5. Uses MSE loss function + 6. Learning rate scheduling with CosineAnnealingLR + 7. Early stopping based on validation loss (patience=30) + 8. Exports model to ONNX format + 9. Logs metrics and artifacts to MLflow + """ + + device = torch.device("mps" if torch.backends.mps.is_available() else "cpu") + LOGGER.info("Using device: %s", device) + + # Load base model + models_dir = PROJECT_ROOT / "models" / "to_xyY" + base_checkpoint = models_dir / "multi_head.pth" + + if not base_checkpoint.exists(): + LOGGER.error("Base model not found: %s", base_checkpoint) + LOGGER.error("Please run train_multi_head.py first") + return None + + LOGGER.info("Loading base model from %s...", base_checkpoint) + checkpoint = torch.load(base_checkpoint, weights_only=True) + base_model = MultiHeadMLP().to(device) + base_model.load_state_dict(checkpoint["model_state_dict"]) + base_model.eval() + + for param in base_model.parameters(): + param.requires_grad = False + + # Load data + munsell_specs, xyY_values = load_training_data(direction="to_xyY") + LOGGER.info("Loaded %d valid samples", len(munsell_specs)) + munsell_normalized = normalize_munsell_simple(munsell_specs) + + # Split train/val (same split as base model) + n_samples = len(munsell_specs) + np.random.seed(42) + indices = np.random.permutation(n_samples) + train_idx = indices[: int(0.9 * n_samples)] + val_idx = indices[int(0.9 * n_samples) :] + + X_train = torch.from_numpy(munsell_normalized[train_idx]).float() + y_train = torch.from_numpy(xyY_values[train_idx]).float() + X_val = torch.from_numpy(munsell_normalized[val_idx]).float() + y_val = torch.from_numpy(xyY_values[val_idx]).float() + + # Compute base model predictions and errors + LOGGER.info("Computing base model predictions...") + with torch.no_grad(): + base_pred_train = base_model(X_train.to(device)).cpu() + base_pred_val = base_model(X_val.to(device)).cpu() + + errors_train = y_train - base_pred_train + errors_val = y_val - base_pred_val + + LOGGER.info( + "Base model MAE - train: x=%.6f, y=%.6f, Y=%.6f", + torch.mean(torch.abs(errors_train[:, 0])).item(), + torch.mean(torch.abs(errors_train[:, 1])).item(), + torch.mean(torch.abs(errors_train[:, 2])).item(), + ) + + # Create combined inputs: [munsell_norm(4) + base_pred(3)] = 7 features + combined_train = torch.cat([X_train, base_pred_train], dim=1) + combined_val = torch.cat([X_val, base_pred_val], dim=1) + + train_loader = DataLoader( + TensorDataset(combined_train, errors_train), + batch_size=batch_size, + shuffle=True, + ) + val_loader = DataLoader( + TensorDataset(combined_val, errors_val), + batch_size=batch_size, + shuffle=False, + ) + + # Create error predictor model + error_model = MultiHeadMultiErrorPredictor().to(device) + optimizer = optim.AdamW(error_model.parameters(), lr=lr, weight_decay=1e-4) + scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs) + + total_params = sum(p.numel() for p in error_model.parameters()) + LOGGER.info("Multi-Error Predictor parameters: %s", f"{total_params:,}") + + # Count parameters per component + x_params = sum(p.numel() for p in error_model.x_branch.parameters()) + y_params = sum(p.numel() for p in error_model.y_branch.parameters()) + Y_params = sum(p.numel() for p in error_model.Y_branch.parameters()) + + LOGGER.info(" - x error branch: %s", f"{x_params:,}") + LOGGER.info(" - y error branch: %s", f"{y_params:,}") + LOGGER.info(" - Y error branch: %s", f"{Y_params:,}") + + # MLflow setup + run_name = setup_mlflow_experiment("to_xyY", "multi_head_multi_error_predictor") + LOGGER.info("MLflow run: %s", run_name) + + best_val_loss = float("inf") + best_state = None + patience = 30 + patience_counter = 0 + + LOGGER.info("\nStarting training...") + + with mlflow.start_run(run_name=run_name): + mlflow.log_params( + { + "model": "multi_head_multi_error_predictor", + "epochs": epochs, + "batch_size": batch_size, + "learning_rate": lr, + "total_params": total_params, + } + ) + + for epoch in range(epochs): + # Training + error_model.train() + train_loss = 0.0 + for combined_batch, errors_batch in train_loader: + combined_batch = combined_batch.to(device) + errors_batch = errors_batch.to(device) + + optimizer.zero_grad() + pred_errors = error_model(combined_batch) + loss = nn.functional.mse_loss(pred_errors, errors_batch) + loss.backward() + optimizer.step() + train_loss += loss.item() * len(combined_batch) + + train_loss /= len(combined_train) + scheduler.step() + + # Validation + error_model.eval() + val_loss = 0.0 + with torch.no_grad(): + for combined_batch, errors_batch in val_loader: + combined_batch = combined_batch.to(device) + errors_batch = errors_batch.to(device) + pred_errors = error_model(combined_batch) + val_loss += nn.functional.mse_loss( + pred_errors, errors_batch + ).item() * len(combined_batch) + val_loss /= len(combined_val) + + # Compute final prediction MAE (base + error correction) + with torch.no_grad(): + pred_errors = error_model(combined_val.to(device)) + final_pred = base_pred_val.to(device) + pred_errors + mae = torch.mean(torch.abs(final_pred - y_val.to(device)), dim=0).cpu() + + log_training_epoch( + epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"] + ) + mlflow.log_metrics( + { + "final_mae_x": mae[0].item(), + "final_mae_y": mae[1].item(), + "final_mae_Y": mae[2].item(), + }, + step=epoch, + ) + + if val_loss < best_val_loss: + best_val_loss = val_loss + best_state = copy.deepcopy(error_model.state_dict()) + patience_counter = 0 + LOGGER.info( + "Epoch %03d/%d - Train: %.6f, Val: %.6f (best) - " + "Final MAE: x=%.6f, y=%.6f, Y=%.6f", + epoch + 1, + epochs, + train_loss, + val_loss, + mae[0], + mae[1], + mae[2], + ) + else: + patience_counter += 1 + if (epoch + 1) % 50 == 0: + LOGGER.info( + "Epoch %03d/%d - Train: %.6f, Val: %.6f", + epoch + 1, + epochs, + train_loss, + val_loss, + ) + + if patience_counter >= patience: + LOGGER.info("Early stopping at epoch %d", epoch + 1) + break + + # Load best model + error_model.load_state_dict(best_state) + + # Final evaluation + error_model.eval() + with torch.no_grad(): + pred_errors = error_model(combined_val.to(device)) + final_pred = (base_pred_val.to(device) + pred_errors).cpu() + mae_x = torch.mean(torch.abs(final_pred[:, 0] - y_val[:, 0])).item() + mae_y = torch.mean(torch.abs(final_pred[:, 1] - y_val[:, 1])).item() + mae_Y = torch.mean(torch.abs(final_pred[:, 2] - y_val[:, 2])).item() + + LOGGER.info("\nFinal Results (Base + Multi-Error Predictor):") + LOGGER.info(" MAE x: %.6f", mae_x) + LOGGER.info(" MAE y: %.6f", mae_y) + LOGGER.info(" MAE Y: %.6f", mae_Y) + + # Compare with base model + base_mae = checkpoint["mae"] + LOGGER.info("\nComparison with base model:") + LOGGER.info( + " x: %.6f -> %.6f (%.2f%% change)", + base_mae["x"], + mae_x, + 100 * (mae_x - base_mae["x"]) / base_mae["x"], + ) + LOGGER.info( + " y: %.6f -> %.6f (%.2f%% change)", + base_mae["y"], + mae_y, + 100 * (mae_y - base_mae["y"]) / base_mae["y"], + ) + LOGGER.info( + " Y: %.6f -> %.6f (%.2f%% change)", + base_mae["Y"], + mae_Y, + 100 * (mae_Y - base_mae["Y"]) / base_mae["Y"], + ) + + # Log final metrics + mlflow.log_metrics( + { + "best_val_loss": best_val_loss, + "best_final_mae_x": mae_x, + "best_final_mae_y": mae_y, + "best_final_mae_Y": mae_Y, + "final_epoch": epoch + 1, + } + ) + + # Save model + checkpoint_path = models_dir / "multi_head_multi_error_predictor.pth" + torch.save( + { + "model_state_dict": error_model.state_dict(), + "val_loss": best_val_loss, + "final_mae": {"x": mae_x, "y": mae_y, "Y": mae_Y}, + }, + checkpoint_path, + ) + LOGGER.info("Saved checkpoint: %s", checkpoint_path) + + # Export to ONNX + error_model.cpu().eval() + dummy_input = torch.randn(1, 7) # [munsell_norm(4) + base_pred(3)] + onnx_path = models_dir / "multi_head_multi_error_predictor.onnx" + + torch.onnx.export( + error_model, + dummy_input, + onnx_path, + input_names=["combined_input"], + output_names=["error_correction"], + dynamic_axes={ + "combined_input": {0: "batch"}, + "error_correction": {0: "batch"}, + }, + opset_version=17, + ) + LOGGER.info("Saved ONNX: %s", onnx_path) + + # Save normalization parameters alongside model + params_file = ( + models_dir / "multi_head_multi_error_predictor_normalization_params.npz" + ) + input_params = { + # Combined input: munsell_normalized (4) + base_pred (3) + "hue_range": (0.0, 10.0), + "value_range": (0.0, 10.0), + "chroma_range": (0.0, 50.0), + "code_range": (0.0, 10.0), + "base_pred_x_range": (0.0, 1.0), + "base_pred_y_range": (0.0, 1.0), + "base_pred_Y_range": (0.0, 1.0), + } + output_params = { + # Error corrections (unnormalized) + "error_x_range": (-1.0, 1.0), + "error_y_range": (-1.0, 1.0), + "error_Y_range": (-1.0, 1.0), + } + np.savez( + params_file, + input_params=input_params, + output_params=output_params, + ) + LOGGER.info("Normalization parameters saved to: %s", params_file) + + # Log artifacts to MLflow + mlflow.log_artifact(str(checkpoint_path)) + mlflow.log_artifact(str(onnx_path)) + mlflow.log_artifact(str(params_file)) + mlflow.pytorch.log_model(error_model, "model") + LOGGER.info("Artifacts logged to MLflow") + + + return error_model + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO, format="%(message)s", force=True) + + main() diff --git a/learning_munsell/training/to_xyY/train_multi_mlp.py b/learning_munsell/training/to_xyY/train_multi_mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..dc4ce1bdbddc9e22bc3c188c39724336d596f3ab --- /dev/null +++ b/learning_munsell/training/to_xyY/train_multi_mlp.py @@ -0,0 +1,329 @@ +""" +Train multi-MLP model for Munsell to xyY conversion. + +Architecture: +- 3 independent MLP branches (one per output component: x, y, Y) +- Each branch: 4 inputs -> encoder -> decoder -> 1 output +- Complete independence allows maximum component specialization. +""" + +from __future__ import annotations + +import copy +import logging + +import click +import mlflow +import mlflow.pytorch +import numpy as np +import torch +from torch import nn, optim +from torch.utils.data import DataLoader, TensorDataset + +from learning_munsell import PROJECT_ROOT +from learning_munsell.models.networks import MultiMLPToxyY +from learning_munsell.utilities.common import log_training_epoch, setup_mlflow_experiment +from learning_munsell.utilities.data import load_training_data, normalize_munsell_simple + +LOGGER = logging.getLogger(__name__) + + +@click.command() +@click.option("--epochs", default=300, help="Maximum number of training epochs.") +@click.option("--batch-size", default=512, help="Training batch size.") +@click.option("--lr", default=0.001, help="Initial learning rate for AdamW optimizer.") +@click.option( + "--width-multiplier", default=1.0, help="Width multiplier for x and y branches." +) +@click.option( + "--y-width-multiplier", default=1.25, help="Width multiplier for Y branch." +) +@click.option("--weight-decay", default=0.000118, help="L2 regularization weight decay.") +def main( + epochs: int = 300, + batch_size: int = 512, + lr: float = 0.001, + width_multiplier: float = 1.0, + y_width_multiplier: float = 1.25, + weight_decay: float = 0.000118, +) -> tuple[MultiMLPToxyY, float]: + """ + Train the Multi-MLP model for Munsell to xyY conversion. + + Parameters + ---------- + epochs : int, optional + Maximum number of training epochs. Default is 300. + batch_size : int, optional + Training batch size. Default is 512. + lr : float, optional + Initial learning rate for AdamW optimizer. Default is 0.001. + width_multiplier : float, optional + Width multiplier for x and y branches. Default is 1.0. + y_width_multiplier : float, optional + Width multiplier for Y branch. Default is 1.25. + weight_decay : float, optional + L2 regularization weight decay. Default is 0.000118. + + Returns + ------- + tuple + A tuple containing: + - model : MultiMLPToxyY + Trained model with best validation loss weights loaded. + - best_val_loss : float + Best validation loss achieved during training. + + Notes + ----- + The training pipeline: + 1. Loads Munsell training data from cache + 2. Normalizes Munsell specifications to [0, 1] range + 3. Splits data into train/validation sets (90/10) + 4. Trains multi-MLP model with independent branches + 5. Uses MSE loss function + 6. Learning rate scheduling with CosineAnnealingLR + 7. Early stopping based on validation loss (patience=30) + 8. Exports model to ONNX format + 9. Logs metrics and artifacts to MLflow + """ + + device = torch.device("mps" if torch.backends.mps.is_available() else "cpu") + LOGGER.info("Using device: %s", device) + LOGGER.info("Hyperparameters (optimized from search):") + LOGGER.info(" lr: %.6f", lr) + LOGGER.info(" batch_size: %d", batch_size) + LOGGER.info(" width_multiplier: %.2f", width_multiplier) + LOGGER.info(" y_width_multiplier: %.2f", y_width_multiplier) + LOGGER.info(" weight_decay: %.6f", weight_decay) + + # Load data + munsell_specs, xyY_values = load_training_data(direction="to_xyY") + LOGGER.info("Loaded %d valid samples", len(munsell_specs)) + munsell_normalized = normalize_munsell_simple(munsell_specs) + + # Split train/val + n_samples = len(munsell_specs) + np.random.seed(42) + indices = np.random.permutation(n_samples) + train_idx = indices[: int(0.9 * n_samples)] + val_idx = indices[int(0.9 * n_samples) :] + + X_train = torch.from_numpy(munsell_normalized[train_idx]).float() + y_train = torch.from_numpy(xyY_values[train_idx]).float() + X_val = torch.from_numpy(munsell_normalized[val_idx]).float() + y_val = torch.from_numpy(xyY_values[val_idx]).float() + + train_loader = DataLoader( + TensorDataset(X_train, y_train), batch_size=batch_size, shuffle=True + ) + val_loader = DataLoader( + TensorDataset(X_val, y_val), batch_size=batch_size, shuffle=False + ) + + # Create model with optimized architecture + model = MultiMLPToxyY( + width_multiplier=width_multiplier, y_width_multiplier=y_width_multiplier + ).to(device) + optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay) + scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs) + + total_params = sum(p.numel() for p in model.parameters()) + LOGGER.info("Model parameters: %s", f"{total_params:,}") + LOGGER.info( + "Training samples: %d, Validation samples: %d", len(X_train), len(X_val) + ) + + # MLflow setup + run_name = setup_mlflow_experiment("to_xyY", "multi_mlp") + LOGGER.info("MLflow run: %s", run_name) + + best_val_loss = float("inf") + best_state = None + patience = 30 + patience_counter = 0 + + LOGGER.info("\nStarting training...") + + with mlflow.start_run(run_name=run_name): + mlflow.log_params( + { + "model": "multi_mlp", + "epochs": epochs, + "batch_size": batch_size, + "learning_rate": lr, + "width_multiplier": width_multiplier, + "y_width_multiplier": y_width_multiplier, + "weight_decay": weight_decay, + "total_params": total_params, + } + ) + + for epoch in range(epochs): + # Training + model.train() + train_loss = 0.0 + for X_batch, y_batch in train_loader: + X_batch, y_batch = X_batch.to(device), y_batch.to(device) + optimizer.zero_grad() + pred = model(X_batch) + loss = nn.functional.mse_loss(pred, y_batch) + loss.backward() + optimizer.step() + train_loss += loss.item() * len(X_batch) + + train_loss /= len(X_train) + scheduler.step() + + # Validation + model.eval() + val_loss = 0.0 + with torch.no_grad(): + for X_batch, y_batch in val_loader: + X_batch, y_batch = X_batch.to(device), y_batch.to(device) + pred = model(X_batch) + val_loss += nn.functional.mse_loss(pred, y_batch).item() * len( + X_batch + ) + val_loss /= len(X_val) + + # Per-component MAE + with torch.no_grad(): + pred_val = model(X_val.to(device)) + mae = torch.mean(torch.abs(pred_val - y_val.to(device)), dim=0).cpu() + + log_training_epoch( + epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"] + ) + mlflow.log_metrics( + { + "mae_x": mae[0].item(), + "mae_y": mae[1].item(), + "mae_Y": mae[2].item(), + }, + step=epoch, + ) + + if val_loss < best_val_loss: + best_val_loss = val_loss + best_state = copy.deepcopy(model.state_dict()) + patience_counter = 0 + LOGGER.info( + "Epoch %03d/%d - Train: %.6f, Val: %.6f (best) - " + "MAE: x=%.6f, y=%.6f, Y=%.6f", + epoch + 1, + epochs, + train_loss, + val_loss, + mae[0], + mae[1], + mae[2], + ) + else: + patience_counter += 1 + if (epoch + 1) % 50 == 0: + LOGGER.info( + "Epoch %03d/%d - Train: %.6f, Val: %.6f", + epoch + 1, + epochs, + train_loss, + val_loss, + ) + + if patience_counter >= patience: + LOGGER.info("Early stopping at epoch %d", epoch + 1) + break + + # Load best model + model.load_state_dict(best_state) + + # Final evaluation + model.eval() + with torch.no_grad(): + pred_val = model(X_val.to(device)).cpu() + mae_x = torch.mean(torch.abs(pred_val[:, 0] - y_val[:, 0])).item() + mae_y = torch.mean(torch.abs(pred_val[:, 1] - y_val[:, 1])).item() + mae_Y = torch.mean(torch.abs(pred_val[:, 2] - y_val[:, 2])).item() + + LOGGER.info("\nFinal Results:") + LOGGER.info(" Best Val Loss: %.6f", best_val_loss) + LOGGER.info(" MAE x: %.6f", mae_x) + LOGGER.info(" MAE y: %.6f", mae_y) + LOGGER.info(" MAE Y: %.6f", mae_Y) + + # Log final metrics + mlflow.log_metrics( + { + "best_val_loss": best_val_loss, + "final_mae_x": mae_x, + "final_mae_y": mae_y, + "final_mae_Y": mae_Y, + "final_epoch": epoch + 1, + } + ) + + # Save model + models_dir = PROJECT_ROOT / "models" / "to_xyY" + models_dir.mkdir(exist_ok=True) + + checkpoint_path = models_dir / "multi_mlp.pth" + torch.save( + { + "model_state_dict": model.state_dict(), + "val_loss": best_val_loss, + "mae": {"x": mae_x, "y": mae_y, "Y": mae_Y}, + }, + checkpoint_path, + ) + LOGGER.info("Saved checkpoint: %s", checkpoint_path) + + # Export to ONNX + model.cpu().eval() + dummy_input = torch.randn(1, 4) + onnx_path = models_dir / "multi_mlp.onnx" + torch.onnx.export( + model, + dummy_input, + onnx_path, + input_names=["munsell_normalized"], + output_names=["xyY"], + dynamic_axes={"munsell_normalized": {0: "batch"}, "xyY": {0: "batch"}}, + opset_version=17, + ) + LOGGER.info("Saved ONNX: %s", onnx_path) + + # Save normalization parameters alongside model + params_file = models_dir / "multi_mlp_normalization_params.npz" + input_params = { + "hue_range": (0.0, 10.0), + "value_range": (0.0, 10.0), + "chroma_range": (0.0, 50.0), + "code_range": (0.0, 10.0), + } + output_params = { + "x_range": (0.0, 1.0), + "y_range": (0.0, 1.0), + "Y_range": (0.0, 1.0), + } + np.savez( + params_file, + input_params=input_params, + output_params=output_params, + ) + LOGGER.info("Normalization parameters saved to: %s", params_file) + + # Log artifacts to MLflow + mlflow.log_artifact(str(checkpoint_path)) + mlflow.log_artifact(str(onnx_path)) + mlflow.log_artifact(str(params_file)) + mlflow.pytorch.log_model(model, "model") + LOGGER.info("Artifacts logged to MLflow") + + + return model, best_val_loss + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO, format="%(message)s", force=True) + + main() diff --git a/learning_munsell/training/to_xyY/train_multi_mlp_error_predictor.py b/learning_munsell/training/to_xyY/train_multi_mlp_error_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..cd50e7333483ba93e8b304ac504039334382bcf1 --- /dev/null +++ b/learning_munsell/training/to_xyY/train_multi_mlp_error_predictor.py @@ -0,0 +1,416 @@ +""" +Train multi-error predictor for Munsell to xyY conversion. + +Architecture: +- 3 independent error predictor branches (one per component: x, y, Y) +- Each branch takes: [munsell_normalized(4) + base_prediction(1)] = 5 inputs +- Predicts the residual error from the base multi-mlp model +""" + +from __future__ import annotations + +import copy +import logging + +import click +import mlflow +import mlflow.pytorch +import numpy as np +import torch +from torch import nn, optim +from torch.utils.data import DataLoader, TensorDataset + +from learning_munsell import PROJECT_ROOT +from learning_munsell.models.networks import ComponentErrorPredictor, MultiMLPToxyY +from learning_munsell.utilities.common import log_training_epoch, setup_mlflow_experiment +from learning_munsell.utilities.data import load_training_data, normalize_munsell_simple + +LOGGER = logging.getLogger(__name__) + + +class MultiErrorPredictor(nn.Module): + """ + Multi-Error Predictor with 3 independent branches. + + Each branch is a ComponentErrorPredictor specialized for one + xyY component. Each branch receives the Munsell specification + and its component's base prediction. + + Attributes + ---------- + x_predictor : ComponentErrorPredictor + Error predictor for x chromaticity component. + y_predictor : ComponentErrorPredictor + Error predictor for y chromaticity component. + Y_predictor : ComponentErrorPredictor + Error predictor for Y luminance component. + """ + + def __init__(self) -> None: + """Initialize the multi-error predictor.""" + super().__init__() + + self.x_predictor = ComponentErrorPredictor(input_dim=5) + self.y_predictor = ComponentErrorPredictor(input_dim=5) + self.Y_predictor = ComponentErrorPredictor(input_dim=5) + + def forward(self, munsell: torch.Tensor, base_pred: torch.Tensor) -> torch.Tensor: + """ + Forward pass through all error predictor branches. + + Parameters + ---------- + munsell : Tensor + Normalized Munsell specification [hue, value, chroma, code] + of shape (batch_size, 4). + base_pred : Tensor + Base model predictions [x, y, Y] of shape (batch_size, 3). + + Returns + ------- + Tensor + Concatenated error corrections [x, y, Y] + of shape (batch_size, 3). + """ + # Each predictor gets munsell input + its component's base prediction + x_input = torch.cat([munsell, base_pred[:, 0:1]], dim=1) + y_input = torch.cat([munsell, base_pred[:, 1:2]], dim=1) + Y_input = torch.cat([munsell, base_pred[:, 2:3]], dim=1) + + x_error = self.x_predictor(x_input) + y_error = self.y_predictor(y_input) + Y_error = self.Y_predictor(Y_input) + + return torch.cat([x_error, y_error, Y_error], dim=1) + + +@click.command() +@click.option("--epochs", default=300, help="Maximum number of training epochs.") +@click.option("--batch-size", default=512, help="Training batch size.") +@click.option("--lr", default=5e-4, help="Initial learning rate for AdamW optimizer.") +def main( + epochs: int = 300, batch_size: int = 512, lr: float = 5e-4 +) -> MultiErrorPredictor | None: + """ + Train the Multi-Error Predictor for Munsell to xyY conversion. + + Parameters + ---------- + epochs : int, optional + Maximum number of training epochs. Default is 300. + batch_size : int, optional + Training batch size. Default is 512. + lr : float, optional + Initial learning rate for AdamW optimizer. Default is 5e-4. + + Returns + ------- + MultiErrorPredictor or None + Trained error predictor model with best validation loss weights + loaded, or None if base model is not found. + + Notes + ----- + The training pipeline: + 1. Loads pre-trained base multi-mlp model + 2. Generates base model predictions for training data + 3. Computes residual errors between predictions and targets + 4. Trains error predictor on these residuals + 5. Uses MSE loss function + 6. Learning rate scheduling with CosineAnnealingLR + 7. Early stopping based on validation loss (patience=30) + 8. Exports model to ONNX format + 9. Logs metrics and artifacts to MLflow + """ + + device = torch.device("mps" if torch.backends.mps.is_available() else "cpu") + LOGGER.info("Using device: %s", device) + + # Load base model + models_dir = PROJECT_ROOT / "models" / "to_xyY" + base_checkpoint = models_dir / "multi_mlp.pth" + + if not base_checkpoint.exists(): + LOGGER.error("Base model not found: %s", base_checkpoint) + LOGGER.error("Please run train_multi_mlp.py first") + return None + + LOGGER.info("Loading base model from %s...", base_checkpoint) + checkpoint = torch.load(base_checkpoint, weights_only=True) + base_model = MultiMLPToxyY().to(device) + base_model.load_state_dict(checkpoint["model_state_dict"]) + base_model.eval() + + for param in base_model.parameters(): + param.requires_grad = False + + # Load data + munsell_specs, xyY_values = load_training_data(direction="to_xyY") + LOGGER.info("Loaded %d valid samples", len(munsell_specs)) + munsell_normalized = normalize_munsell_simple(munsell_specs) + + # Split train/val (same split as base model) + n_samples = len(munsell_specs) + np.random.seed(42) # Ensure consistent split + indices = np.random.permutation(n_samples) + train_idx = indices[: int(0.9 * n_samples)] + val_idx = indices[int(0.9 * n_samples) :] + + X_train = torch.from_numpy(munsell_normalized[train_idx]).float() + y_train = torch.from_numpy(xyY_values[train_idx]).float() + X_val = torch.from_numpy(munsell_normalized[val_idx]).float() + y_val = torch.from_numpy(xyY_values[val_idx]).float() + + # Compute base model predictions and errors + LOGGER.info("Computing base model predictions...") + with torch.no_grad(): + base_pred_train = base_model(X_train.to(device)).cpu() + base_pred_val = base_model(X_val.to(device)).cpu() + + errors_train = y_train - base_pred_train + errors_val = y_val - base_pred_val + + LOGGER.info( + "Base model MAE - train: x=%.6f, y=%.6f, Y=%.6f", + torch.mean(torch.abs(errors_train[:, 0])).item(), + torch.mean(torch.abs(errors_train[:, 1])).item(), + torch.mean(torch.abs(errors_train[:, 2])).item(), + ) + + train_loader = DataLoader( + TensorDataset(X_train, base_pred_train, errors_train), + batch_size=batch_size, + shuffle=True, + ) + val_loader = DataLoader( + TensorDataset(X_val, base_pred_val, errors_val), + batch_size=batch_size, + shuffle=False, + ) + + # Create error predictor model + error_model = MultiErrorPredictor().to(device) + optimizer = optim.AdamW(error_model.parameters(), lr=lr, weight_decay=1e-4) + scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs) + + total_params = sum(p.numel() for p in error_model.parameters()) + LOGGER.info("Error predictor parameters: %s", f"{total_params:,}") + + # MLflow setup + run_name = setup_mlflow_experiment("to_xyY", "multi_mlp_error_predictor") + LOGGER.info("MLflow run: %s", run_name) + + best_val_loss = float("inf") + best_state = None + patience = 30 + patience_counter = 0 + + LOGGER.info("\nStarting training...") + + with mlflow.start_run(run_name=run_name): + mlflow.log_params( + { + "model": "multi_mlp_error_predictor", + "epochs": epochs, + "batch_size": batch_size, + "learning_rate": lr, + "total_params": total_params, + } + ) + + for epoch in range(epochs): + # Training + error_model.train() + train_loss = 0.0 + for X_batch, base_pred_batch, errors_batch in train_loader: + X_batch = X_batch.to(device) + base_pred_batch = base_pred_batch.to(device) + errors_batch = errors_batch.to(device) + optimizer.zero_grad() + pred_errors = error_model(X_batch, base_pred_batch) + loss = nn.functional.mse_loss(pred_errors, errors_batch) + loss.backward() + optimizer.step() + train_loss += loss.item() * len(X_batch) + + train_loss /= len(X_train) + scheduler.step() + + # Validation + error_model.eval() + val_loss = 0.0 + with torch.no_grad(): + for X_batch, base_pred_batch, errors_batch in val_loader: + X_batch = X_batch.to(device) + base_pred_batch = base_pred_batch.to(device) + errors_batch = errors_batch.to(device) + pred_errors = error_model(X_batch, base_pred_batch) + val_loss += nn.functional.mse_loss( + pred_errors, errors_batch + ).item() * len(X_batch) + val_loss /= len(X_val) + + # Compute final prediction MAE (base + error correction) + with torch.no_grad(): + pred_errors = error_model(X_val.to(device), base_pred_val.to(device)) + final_pred = base_pred_val.to(device) + pred_errors + mae = torch.mean(torch.abs(final_pred - y_val.to(device)), dim=0).cpu() + + log_training_epoch( + epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"] + ) + mlflow.log_metrics( + { + "final_mae_x": mae[0].item(), + "final_mae_y": mae[1].item(), + "final_mae_Y": mae[2].item(), + }, + step=epoch, + ) + + if val_loss < best_val_loss: + best_val_loss = val_loss + best_state = copy.deepcopy(error_model.state_dict()) + patience_counter = 0 + LOGGER.info( + "Epoch %03d/%d - Train: %.6f, Val: %.6f (best) - " + "Final MAE: x=%.6f, y=%.6f, Y=%.6f", + epoch + 1, + epochs, + train_loss, + val_loss, + mae[0], + mae[1], + mae[2], + ) + else: + patience_counter += 1 + if (epoch + 1) % 50 == 0: + LOGGER.info( + "Epoch %03d/%d - Train: %.6f, Val: %.6f", + epoch + 1, + epochs, + train_loss, + val_loss, + ) + + if patience_counter >= patience: + LOGGER.info("Early stopping at epoch %d", epoch + 1) + break + + # Load best model + error_model.load_state_dict(best_state) + + # Final evaluation + error_model.eval() + with torch.no_grad(): + pred_errors = error_model(X_val.to(device), base_pred_val.to(device)) + final_pred = (base_pred_val.to(device) + pred_errors).cpu() + mae_x = torch.mean(torch.abs(final_pred[:, 0] - y_val[:, 0])).item() + mae_y = torch.mean(torch.abs(final_pred[:, 1] - y_val[:, 1])).item() + mae_Y = torch.mean(torch.abs(final_pred[:, 2] - y_val[:, 2])).item() + + LOGGER.info("\nFinal Results (Base + Error Predictor):") + LOGGER.info(" MAE x: %.6f", mae_x) + LOGGER.info(" MAE y: %.6f", mae_y) + LOGGER.info(" MAE Y: %.6f", mae_Y) + + # Log final metrics + mlflow.log_metrics( + { + "best_val_loss": best_val_loss, + "best_final_mae_x": mae_x, + "best_final_mae_y": mae_y, + "best_final_mae_Y": mae_Y, + "final_epoch": epoch + 1, + } + ) + + # Save model + checkpoint_path = models_dir / "multi_mlp_error_predictor.pth" + torch.save( + { + "model_state_dict": error_model.state_dict(), + "val_loss": best_val_loss, + "final_mae": {"x": mae_x, "y": mae_y, "Y": mae_Y}, + }, + checkpoint_path, + ) + LOGGER.info("Saved checkpoint: %s", checkpoint_path) + + # Export to ONNX + error_model.cpu().eval() + dummy_munsell = torch.randn(1, 4) + dummy_base_pred = torch.randn(1, 3) + onnx_path = models_dir / "multi_mlp_error_predictor.onnx" + + # For ONNX export, create a wrapper that takes concatenated input + class ErrorPredictorWrapper(nn.Module): + def __init__(self, error_predictor: MultiErrorPredictor) -> None: + super().__init__() + self.error_predictor = error_predictor + + def forward(self, combined_input: torch.Tensor) -> torch.Tensor: + munsell = combined_input[:, :4] + base_pred = combined_input[:, 4:] + return self.error_predictor(munsell, base_pred) + + wrapper = ErrorPredictorWrapper(error_model) + dummy_combined = torch.cat([dummy_munsell, dummy_base_pred], dim=1) + + torch.onnx.export( + wrapper, + dummy_combined, + onnx_path, + input_names=["combined_input"], + output_names=["error_correction"], + dynamic_axes={ + "combined_input": {0: "batch"}, + "error_correction": {0: "batch"}, + }, + opset_version=17, + ) + LOGGER.info("Saved ONNX: %s", onnx_path) + + # Save normalization parameters alongside model + params_file = ( + models_dir / "multi_mlp_error_predictor_normalization_params.npz" + ) + input_params = { + # Combined input: munsell_normalized (4) + base_pred (3) + "hue_range": (0.0, 10.0), + "value_range": (0.0, 10.0), + "chroma_range": (0.0, 50.0), + "code_range": (0.0, 10.0), + "base_pred_x_range": (0.0, 1.0), + "base_pred_y_range": (0.0, 1.0), + "base_pred_Y_range": (0.0, 1.0), + } + output_params = { + # Error corrections (unnormalized) + "error_x_range": (-1.0, 1.0), + "error_y_range": (-1.0, 1.0), + "error_Y_range": (-1.0, 1.0), + } + np.savez( + params_file, + input_params=input_params, + output_params=output_params, + ) + LOGGER.info("Normalization parameters saved to: %s", params_file) + + # Log artifacts to MLflow + mlflow.log_artifact(str(checkpoint_path)) + mlflow.log_artifact(str(onnx_path)) + mlflow.log_artifact(str(params_file)) + mlflow.pytorch.log_model(error_model, "model") + LOGGER.info("Artifacts logged to MLflow") + + + return error_model + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO, format="%(message)s", force=True) + + main() diff --git a/learning_munsell/training/to_xyY/train_multi_mlp_multi_error_predictor.py b/learning_munsell/training/to_xyY/train_multi_mlp_multi_error_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..e5e66e7c861c8779978e5d75a379aabcd72b28f7 --- /dev/null +++ b/learning_munsell/training/to_xyY/train_multi_mlp_multi_error_predictor.py @@ -0,0 +1,395 @@ +""" +Train Multi-MLP Multi-Error Predictor for Munsell to xyY conversion. + +Architecture: +- 3 independent error correction branches (one per component: x, y, Y) +- Each branch: 7 inputs (munsell_norm + base_pred) -> encoder -> decoder + -> 1 error output +- Uses GELU activation, BatchNorm, and Dropout for better generalization +""" + +from __future__ import annotations + +import copy +import logging + +import click +import mlflow +import mlflow.pytorch +import numpy as np +import torch +from torch import nn, optim +from torch.utils.data import DataLoader, TensorDataset + +from learning_munsell import PROJECT_ROOT +from learning_munsell.models.networks import MultiMLPErrorPredictorToxyY, MultiMLPToxyY +from learning_munsell.utilities.common import log_training_epoch, setup_mlflow_experiment +from learning_munsell.utilities.data import load_training_data, normalize_munsell_simple + +LOGGER = logging.getLogger(__name__) + + +@click.command() +@click.option("--epochs", default=300, help="Maximum number of training epochs.") +@click.option("--batch-size", default=256, help="Training batch size.") +@click.option("--lr", default=0.000383, help="Initial learning rate for AdamW optimizer.") +@click.option("--dropout", default=0.0, help="Dropout probability for regularization.") +@click.option( + "--width-multiplier", default=0.75, help="Width multiplier for hidden layers." +) +@click.option( + "--use-optimized-base", + is_flag=True, + default=False, + help="Whether to use optimized base model.", +) +def main( + epochs: int = 300, + batch_size: int = 256, + lr: float = 0.000383, + dropout: float = 0.0, + width_multiplier: float = 0.75, + use_optimized_base: bool = False, +) -> MultiMLPErrorPredictorToxyY | None: + """ + Train the Multi-MLP Multi-Error Predictor for Munsell to xyY conversion. + + Parameters + ---------- + epochs : int, optional + Maximum number of training epochs. Default is 300. + batch_size : int, optional + Training batch size. Default is 256. + lr : float, optional + Initial learning rate for AdamW optimizer. Default is 0.000383. + dropout : float, optional + Dropout probability for regularization. Default is 0.0. + width_multiplier : float, optional + Width multiplier for hidden layers. Default is 0.75. + use_optimized_base : bool, optional + Whether to use optimized base model. Default is False. + + Returns + ------- + MultiMLPErrorPredictorToxyY or None + Trained error predictor model with best validation loss weights + loaded, or None if base model is not found. + + Notes + ----- + The training pipeline: + 1. Loads pre-trained base multi-mlp model + 2. Generates base model predictions for training data + 3. Computes residual errors between predictions and targets + 4. Trains error predictor on these residuals + 5. Uses MSE loss function + 6. Learning rate scheduling with CosineAnnealingLR + 7. Early stopping based on validation loss (patience=30) + 8. Exports model to ONNX format + 9. Logs metrics and artifacts to MLflow + """ + suffix = "_optimized" if use_optimized_base else "" + + device = torch.device("mps" if torch.backends.mps.is_available() else "cpu") + LOGGER.info("Using device: %s", device) + LOGGER.info("Hyperparameters (optimized from search):") + LOGGER.info(" lr: %.6f", lr) + LOGGER.info(" batch_size: %d", batch_size) + LOGGER.info(" dropout: %.2f", dropout) + LOGGER.info(" width_multiplier: %.2f", width_multiplier) + LOGGER.info(" use_optimized_base: %s", use_optimized_base) + + # Load base model + models_dir = PROJECT_ROOT / "models" / "to_xyY" + if use_optimized_base: + base_checkpoint = models_dir / "multi_mlp_optimized.pth" + # Optimized base has width_multiplier=1.0, y_width_multiplier=1.25 + base_width = 1.0 + base_y_width = 1.25 + else: + base_checkpoint = models_dir / "multi_mlp.pth" + # Original base has width_multiplier=1.0, y_width_multiplier=1.25 + base_width = 1.0 + base_y_width = 1.25 + + if not base_checkpoint.exists(): + LOGGER.error("Base model not found: %s", base_checkpoint) + LOGGER.error("Please run train_multi_mlp.py first") + return None + + LOGGER.info("Loading base model from %s...", base_checkpoint) + checkpoint = torch.load(base_checkpoint, weights_only=True) + base_model = MultiMLPToxyY( + width_multiplier=base_width, y_width_multiplier=base_y_width + ).to(device) + base_model.load_state_dict(checkpoint["model_state_dict"]) + base_model.eval() + + for param in base_model.parameters(): + param.requires_grad = False + + # Load data + munsell_specs, xyY_values = load_training_data(direction="to_xyY") + LOGGER.info("Loaded %d valid samples", len(munsell_specs)) + munsell_normalized = normalize_munsell_simple(munsell_specs) + + # Split train/val (same split as base model) + n_samples = len(munsell_specs) + np.random.seed(42) + indices = np.random.permutation(n_samples) + train_idx = indices[: int(0.9 * n_samples)] + val_idx = indices[int(0.9 * n_samples) :] + + X_train = torch.from_numpy(munsell_normalized[train_idx]).float() + y_train = torch.from_numpy(xyY_values[train_idx]).float() + X_val = torch.from_numpy(munsell_normalized[val_idx]).float() + y_val = torch.from_numpy(xyY_values[val_idx]).float() + + # Compute base model predictions and errors + LOGGER.info("Computing base model predictions...") + with torch.no_grad(): + base_pred_train = base_model(X_train.to(device)).cpu() + base_pred_val = base_model(X_val.to(device)).cpu() + + errors_train = y_train - base_pred_train + errors_val = y_val - base_pred_val + + LOGGER.info( + "Base model MAE - train: x=%.6f, y=%.6f, Y=%.6f", + torch.mean(torch.abs(errors_train[:, 0])).item(), + torch.mean(torch.abs(errors_train[:, 1])).item(), + torch.mean(torch.abs(errors_train[:, 2])).item(), + ) + + # Create combined inputs: [munsell_norm(4) + base_pred(3)] = 7 features + combined_train = torch.cat([X_train, base_pred_train], dim=1) + combined_val = torch.cat([X_val, base_pred_val], dim=1) + + train_loader = DataLoader( + TensorDataset(combined_train, errors_train), + batch_size=batch_size, + shuffle=True, + ) + val_loader = DataLoader( + TensorDataset(combined_val, errors_val), + batch_size=batch_size, + shuffle=False, + ) + + # Create error predictor model with optimized hyperparameters + error_model = MultiMLPErrorPredictorToxyY( + width_multiplier=width_multiplier + ).to(device) + optimizer = optim.AdamW(error_model.parameters(), lr=lr, weight_decay=1e-4) + scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs) + + total_params = sum(p.numel() for p in error_model.parameters()) + LOGGER.info("Multi-Error Predictor parameters: %s", f"{total_params:,}") + + # MLflow setup + run_name = setup_mlflow_experiment( + "to_xyY", f"multi_mlp_multi_error_predictor{suffix}" + ) + LOGGER.info("MLflow run: %s", run_name) + + best_val_loss = float("inf") + best_state = None + patience = 30 + patience_counter = 0 + + LOGGER.info("\nStarting training...") + + with mlflow.start_run(run_name=run_name): + mlflow.log_params( + { + "model": f"multi_mlp_multi_error_predictor{suffix}", + "epochs": epochs, + "batch_size": batch_size, + "learning_rate": lr, + "dropout": dropout, + "width_multiplier": width_multiplier, + "use_optimized_base": use_optimized_base, + "total_params": total_params, + } + ) + + for epoch in range(epochs): + # Training + error_model.train() + train_loss = 0.0 + for combined_batch, errors_batch in train_loader: + combined_batch = combined_batch.to(device) + errors_batch = errors_batch.to(device) + optimizer.zero_grad() + pred_errors = error_model(combined_batch) + loss = nn.functional.mse_loss(pred_errors, errors_batch) + loss.backward() + optimizer.step() + train_loss += loss.item() * len(combined_batch) + + train_loss /= len(combined_train) + scheduler.step() + + # Validation + error_model.eval() + val_loss = 0.0 + with torch.no_grad(): + for combined_batch, errors_batch in val_loader: + combined_batch = combined_batch.to(device) + errors_batch = errors_batch.to(device) + pred_errors = error_model(combined_batch) + val_loss += nn.functional.mse_loss( + pred_errors, errors_batch + ).item() * len(combined_batch) + val_loss /= len(combined_val) + + # Compute final prediction MAE (base + error correction) + with torch.no_grad(): + pred_errors = error_model(combined_val.to(device)) + final_pred = base_pred_val.to(device) + pred_errors + mae = torch.mean(torch.abs(final_pred - y_val.to(device)), dim=0).cpu() + + log_training_epoch( + epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"] + ) + mlflow.log_metrics( + { + "final_mae_x": mae[0].item(), + "final_mae_y": mae[1].item(), + "final_mae_Y": mae[2].item(), + }, + step=epoch, + ) + + if val_loss < best_val_loss: + best_val_loss = val_loss + best_state = copy.deepcopy(error_model.state_dict()) + patience_counter = 0 + LOGGER.info( + "Epoch %03d/%d - Train: %.6f, Val: %.6f (best) - " + "Final MAE: x=%.6f, y=%.6f, Y=%.6f", + epoch + 1, + epochs, + train_loss, + val_loss, + mae[0], + mae[1], + mae[2], + ) + else: + patience_counter += 1 + if (epoch + 1) % 50 == 0: + LOGGER.info( + "Epoch %03d/%d - Train: %.6f, Val: %.6f", + epoch + 1, + epochs, + train_loss, + val_loss, + ) + + if patience_counter >= patience: + LOGGER.info("Early stopping at epoch %d", epoch + 1) + break + + # Load best model + error_model.load_state_dict(best_state) + + # Final evaluation + error_model.eval() + with torch.no_grad(): + pred_errors = error_model(combined_val.to(device)) + final_pred = (base_pred_val.to(device) + pred_errors).cpu() + mae_x = torch.mean(torch.abs(final_pred[:, 0] - y_val[:, 0])).item() + mae_y = torch.mean(torch.abs(final_pred[:, 1] - y_val[:, 1])).item() + mae_Y = torch.mean(torch.abs(final_pred[:, 2] - y_val[:, 2])).item() + + LOGGER.info("\nFinal Results (Base + Multi-Error Predictor):") + LOGGER.info(" MAE x: %.6f", mae_x) + LOGGER.info(" MAE y: %.6f", mae_y) + LOGGER.info(" MAE Y: %.6f", mae_Y) + + # Log final metrics + mlflow.log_metrics( + { + "best_val_loss": best_val_loss, + "best_final_mae_x": mae_x, + "best_final_mae_y": mae_y, + "best_final_mae_Y": mae_Y, + "final_epoch": epoch + 1, + } + ) + + # Save model with optional _optimized suffix + checkpoint_path = models_dir / f"multi_mlp_multi_error_predictor{suffix}.pth" + torch.save( + { + "model_state_dict": error_model.state_dict(), + "val_loss": best_val_loss, + "final_mae": {"x": mae_x, "y": mae_y, "Y": mae_Y}, + }, + checkpoint_path, + ) + LOGGER.info("Saved checkpoint: %s", checkpoint_path) + + # Export to ONNX + error_model.cpu().eval() + dummy_input = torch.randn(1, 7) # [munsell_norm(4) + base_pred(3)] + onnx_path = models_dir / f"multi_mlp_multi_error_predictor{suffix}.onnx" + + torch.onnx.export( + error_model, + dummy_input, + onnx_path, + input_names=["combined_input"], + output_names=["error_correction"], + dynamic_axes={ + "combined_input": {0: "batch"}, + "error_correction": {0: "batch"}, + }, + opset_version=17, + ) + LOGGER.info("Saved ONNX: %s", onnx_path) + + # Save normalization parameters alongside model + params_filename = ( + f"multi_mlp_multi_error_predictor{suffix}_normalization_params.npz" + ) + params_file = models_dir / params_filename + input_params = { + # Combined input: munsell_normalized (4) + base_pred (3) + "hue_range": (0.0, 10.0), + "value_range": (0.0, 10.0), + "chroma_range": (0.0, 50.0), + "code_range": (0.0, 10.0), + "base_pred_x_range": (0.0, 1.0), + "base_pred_y_range": (0.0, 1.0), + "base_pred_Y_range": (0.0, 1.0), + } + output_params = { + # Error corrections (unnormalized) + "error_x_range": (-1.0, 1.0), + "error_y_range": (-1.0, 1.0), + "error_Y_range": (-1.0, 1.0), + } + np.savez( + params_file, + input_params=input_params, + output_params=output_params, + ) + LOGGER.info("Normalization parameters saved to: %s", params_file) + + # Log artifacts to MLflow + mlflow.log_artifact(str(checkpoint_path)) + mlflow.log_artifact(str(onnx_path)) + mlflow.log_artifact(str(params_file)) + mlflow.pytorch.log_model(error_model, "model") + LOGGER.info("Artifacts logged to MLflow") + + + return error_model + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO, format="%(message)s", force=True) + + main() diff --git a/learning_munsell/training/to_xyY/train_munsell_to_xyY_approximator.py b/learning_munsell/training/to_xyY/train_munsell_to_xyY_approximator.py new file mode 100644 index 0000000000000000000000000000000000000000..1651ef21384d5e40a5c8f8276d6da8ce4572ac8d --- /dev/null +++ b/learning_munsell/training/to_xyY/train_munsell_to_xyY_approximator.py @@ -0,0 +1,504 @@ +""" +Train a small MLP to approximate the Munsell -> xyY conversion. + +This network will be used in the differentiable Delta-E loss function +to enable end-to-end training with perceptual accuracy. +""" + +from __future__ import annotations + +import copy +import logging + +import click +import colour +import mlflow +import mlflow.pytorch +import numpy as np +import torch +from torch import nn +from torch.utils.data import DataLoader, TensorDataset + +from learning_munsell import PROJECT_ROOT +from learning_munsell.utilities.common import log_training_epoch, setup_mlflow_experiment +from learning_munsell.utilities.data import load_training_data, normalize_munsell_simple + +LOGGER = logging.getLogger(__name__) + + +class MunsellToXYYApproximator(nn.Module): + """ + Small MLP to approximate Munsell to xyY conversion. + + This lightweight network is designed for use in differentiable Delta-E + loss functions, enabling end-to-end training with perceptual accuracy. + + Parameters + ---------- + hidden_dims : list of int, optional + List of hidden layer dimensions. Default is [128, 256, 128]. + + Attributes + ---------- + net : nn.Sequential + Feed-forward network with configurable hidden dimensions, + LayerNorm, and SiLU activations. + """ + + def __init__(self, hidden_dims: list[int] | None = None) -> None: + """Initialize the Munsell to xyY approximator.""" + if hidden_dims is None: + hidden_dims = [128, 256, 128] + super().__init__() + + layers = [] + in_dim = 4 # Hue, Value, Chroma, Code + + for hidden_dim in hidden_dims: + layers.extend( + [ + nn.Linear(in_dim, hidden_dim), + nn.LayerNorm(hidden_dim), + nn.SiLU(), + ] + ) + in_dim = hidden_dim + + layers.append(nn.Linear(in_dim, 3)) # Output: x, y, Y + + self.net = nn.Sequential(*layers) + + def forward(self, munsell: torch.Tensor) -> torch.Tensor: + """ + Forward pass through the approximator. + + Parameters + ---------- + munsell : Tensor + Normalized Munsell specification [hue, value, chroma, code] + of shape (batch_size, 4). + + Returns + ------- + Tensor + Predicted xyY values [x, y, Y] of shape (batch_size, 3). + """ + return self.net(munsell) + + +def denormalize_munsell(normalized: np.ndarray) -> np.ndarray: + """ + Denormalize Munsell specifications back to original range. + + Parameters + ---------- + normalized : ndarray, shape (n_samples, 4) + Normalized Munsell specifications in [0, 1] range. + + Returns + ------- + ndarray, shape (n_samples, 4) + Denormalized Munsell specifications in original scale. + """ + munsell = normalized.copy() + munsell[:, 0] = normalized[:, 0] * 10.0 + munsell[:, 1] = normalized[:, 1] * 10.0 + munsell[:, 2] = normalized[:, 2] * 50.0 + munsell[:, 3] = normalized[:, 3] * 10.0 + return munsell + + +@click.command() +@click.option("--epochs", default=500, help="Maximum number of training epochs.") +@click.option("--batch-size", default=256, help="Training batch size.") +@click.option("--lr", default=1e-3, help="Initial learning rate for AdamW optimizer.") +@click.option( + "--hidden-dims", + default="128,256,128", + help="Comma-separated list of hidden layer dimensions.", +) +def train_approximator( + epochs: int = 500, + batch_size: int = 256, + lr: float = 1e-3, + hidden_dims: str = "128,256,128", +) -> MunsellToXYYApproximator: + """ + Train the Munsell to xyY approximator for Munsell to xyY conversion. + + Parameters + ---------- + epochs : int, optional + Maximum number of training epochs. Default is 500. + batch_size : int, optional + Training batch size. Default is 256. + lr : float, optional + Initial learning rate for AdamW optimizer. Default is 1e-3. + hidden_dims : str, optional + Comma-separated list of hidden layer dimensions. Default is "128,256,128". + + Returns + ------- + MunsellToXYYApproximator + Trained approximator model with best validation loss weights loaded. + + Notes + ----- + The training pipeline: + 1. Loads Munsell training data from cache + 2. Normalizes Munsell specifications to [0, 1] range + 3. Splits data into train/validation sets (90/10) + 4. Trains lightweight MLP approximator + 5. Uses MSE loss function + 6. Learning rate scheduling with CosineAnnealingLR + 7. Saves best model based on validation loss + 8. Exports model to ONNX format + 9. Logs metrics and artifacts to MLflow + """ + # Parse hidden_dims from string + if isinstance(hidden_dims, str): + hidden_dims = [int(x.strip()) for x in hidden_dims.split(",")] + if hidden_dims is None: + hidden_dims = [128, 256, 128] + + device = torch.device("mps" if torch.backends.mps.is_available() else "cpu") + LOGGER.info("Using device: %s", device) + + # Load data + munsell_specs, xyY_values = load_training_data(direction="to_xyY") + LOGGER.info("Loaded %d valid samples", len(munsell_specs)) + + # Normalize inputs + munsell_normalized = normalize_munsell_simple(munsell_specs) + + # Split train/val + n_samples = len(munsell_specs) + indices = np.random.permutation(n_samples) + train_idx = indices[: int(0.9 * n_samples)] + val_idx = indices[int(0.9 * n_samples) :] + + X_train = torch.from_numpy(munsell_normalized[train_idx]).float() + y_train = torch.from_numpy(xyY_values[train_idx]).float() + X_val = torch.from_numpy(munsell_normalized[val_idx]).float() + y_val = torch.from_numpy(xyY_values[val_idx]).float() + + train_dataset = TensorDataset(X_train, y_train) + train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) + + # Create model + model = MunsellToXYYApproximator(hidden_dims=hidden_dims).to(device) + optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4) + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs) + + total_params = sum(p.numel() for p in model.parameters()) + LOGGER.info("Model parameters: %s", f"{total_params:,}") + LOGGER.info( + "Training samples: %d, Validation samples: %d", len(X_train), len(X_val) + ) + + # MLflow setup + run_name = setup_mlflow_experiment("to_xyY", "munsell_to_xyY_approximator") + LOGGER.info("MLflow run: %s", run_name) + + best_val_loss = float("inf") + best_state = None + + LOGGER.info("\nStarting training...") + + with mlflow.start_run(run_name=run_name): + mlflow.log_params( + { + "model": "munsell_to_xyY_approximator", + "epochs": epochs, + "batch_size": batch_size, + "learning_rate": lr, + "hidden_dims": str(hidden_dims), + "total_params": total_params, + } + ) + + for epoch in range(epochs): + # Training + model.train() + train_loss = 0.0 + for X_batch, y_batch in train_loader: + X_batch, y_batch = X_batch.to(device), y_batch.to(device) + optimizer.zero_grad() + pred = model(X_batch) + loss = nn.functional.mse_loss(pred, y_batch) + loss.backward() + optimizer.step() + train_loss += loss.item() * len(X_batch) + + train_loss /= len(X_train) + scheduler.step() + + # Validation + model.eval() + with torch.no_grad(): + X_val_dev = X_val.to(device) + y_val_dev = y_val.to(device) + pred_val = model(X_val_dev) + val_loss = nn.functional.mse_loss(pred_val, y_val_dev).item() + + # Per-component MAE + mae = torch.mean(torch.abs(pred_val - y_val_dev), dim=0).cpu() + + log_training_epoch( + epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"] + ) + mlflow.log_metrics( + { + "mae_x": mae[0].item(), + "mae_y": mae[1].item(), + "mae_Y": mae[2].item(), + }, + step=epoch, + ) + + if val_loss < best_val_loss: + best_val_loss = val_loss + best_state = copy.deepcopy(model.state_dict()) + LOGGER.info( + "Epoch %03d/%d - Train: %.6f, Val: %.6f (best) - " + "MAE: x=%.6f, y=%.6f, Y=%.6f", + epoch + 1, + epochs, + train_loss, + val_loss, + mae[0], + mae[1], + mae[2], + ) + elif (epoch + 1) % 50 == 0: + LOGGER.info( + "Epoch %03d/%d - Train: %.6f, Val: %.6f", + epoch + 1, + epochs, + train_loss, + val_loss, + ) + + # Load best model + model.load_state_dict(best_state) + + # Final evaluation + model.eval() + with torch.no_grad(): + X_val_dev = X_val.to(device) + y_val_dev = y_val.to(device) + pred_val = model(X_val_dev) + + # Per-component MAE + mae_x = torch.mean(torch.abs(pred_val[:, 0] - y_val_dev[:, 0])).item() + mae_y = torch.mean(torch.abs(pred_val[:, 1] - y_val_dev[:, 1])).item() + mae_Y = torch.mean(torch.abs(pred_val[:, 2] - y_val_dev[:, 2])).item() + + LOGGER.info("\nFinal Results:") + LOGGER.info(" Best Val Loss: %.6f", best_val_loss) + LOGGER.info(" MAE x: %.6f", mae_x) + LOGGER.info(" MAE y: %.6f", mae_y) + LOGGER.info(" MAE Y: %.6f", mae_Y) + + # Log final metrics + mlflow.log_metrics( + { + "best_val_loss": best_val_loss, + "final_mae_x": mae_x, + "final_mae_y": mae_y, + "final_mae_Y": mae_Y, + "final_epoch": epoch + 1, + } + ) + + # Save model + models_dir = PROJECT_ROOT / "models" / "to_xyY" + models_dir.mkdir(exist_ok=True) + + # Save PyTorch checkpoint + checkpoint_path = models_dir / "munsell_to_xyY_approximator.pth" + torch.save( + { + "model_state_dict": model.state_dict(), + "hidden_dims": hidden_dims, + "val_loss": best_val_loss, + "mae": {"x": mae_x, "y": mae_y, "Y": mae_Y}, + "normalization": { + "hue_scale": 10.0, # hue_in_decade + "value_scale": 10.0, + "chroma_scale": 50.0, + "code_scale": 10.0, + }, + }, + checkpoint_path, + ) + LOGGER.info("Saved checkpoint: %s", checkpoint_path) + + # Export to ONNX + model.cpu() + model.eval() + dummy_input = torch.randn(1, 4) + onnx_path = models_dir / "munsell_to_xyY_approximator.onnx" + + torch.onnx.export( + model, + dummy_input, + onnx_path, + input_names=["munsell_normalized"], + output_names=["xyY"], + dynamic_axes={ + "munsell_normalized": {0: "batch"}, + "xyY": {0: "batch"}, + }, + opset_version=17, + ) + LOGGER.info("Saved ONNX: %s", onnx_path) + + # Also save JAX-compatible weights as numpy + weights_path = models_dir / "munsell_to_xyY_approximator_weights.npz" + weights = {} + for name, param in model.named_parameters(): + weights[name.replace(".", "_")] = param.detach().numpy() + np.savez(weights_path, **weights) + LOGGER.info("Saved JAX-compatible weights: %s", weights_path) + + # Save normalization parameters alongside model + params_file = ( + models_dir / "munsell_to_xyY_approximator_normalization_params.npz" + ) + input_params = { + "hue_range": (0.0, 10.0), + "value_range": (0.0, 10.0), + "chroma_range": (0.0, 50.0), + "code_range": (0.0, 10.0), + } + output_params = { + "x_range": (0.0, 1.0), + "y_range": (0.0, 1.0), + "Y_range": (0.0, 1.0), + } + np.savez( + params_file, + input_params=input_params, + output_params=output_params, + ) + LOGGER.info("Normalization parameters saved to: %s", params_file) + + # Log artifacts to MLflow + mlflow.log_artifact(str(checkpoint_path)) + mlflow.log_artifact(str(onnx_path)) + mlflow.log_artifact(str(weights_path)) + mlflow.log_artifact(str(params_file)) + mlflow.pytorch.log_model(model, "model") + LOGGER.info("Artifacts logged to MLflow") + + return model + + +def test_approximator() -> None: + """ + Test the trained approximator against colour library. + + Compares approximator predictions with ground truth from the colour + library on a set of test Munsell colors and prints error statistics. + """ + device = torch.device("mps" if torch.backends.mps.is_available() else "cpu") + models_dir = PROJECT_ROOT / "models" / "to_xyY" + + # Load model + checkpoint = torch.load( + models_dir / "munsell_to_xyY_approximator.pth", weights_only=True + ) + model = MunsellToXYYApproximator(hidden_dims=checkpoint["hidden_dims"]) + model.load_state_dict(checkpoint["model_state_dict"]) + model.to(device) + model.eval() + + # Test on some Munsell colors + test_colors = [ + "5R 5/10", + "5YR 6/8", + "5Y 8/6", + "5GY 6/8", + "5G 5/8", + "5BG 5/6", + "5B 5/8", + "5PB 4/10", + "5P 4/8", + "5RP 5/10", + ] + + print("\nApproximator Test Results:") + print("=" * 70) + print(f"{'Munsell':<12} {'Approx xyY':<30} {'Actual xyY':<30}") + print("=" * 70) + + errors = [] + for color in test_colors: + try: + spec = colour.notation.munsell.munsell_colour_to_munsell_specification( + color + ) + actual_xyY = colour.notation.munsell.munsell_specification_to_xyY(spec) + + # spec format: (hue_in_decade, value, chroma, code) + # Normalize using the correct scales for training data + munsell_norm = np.array( + [ + [ + spec[0] / 10.0, # hue_in_decade (0-10) + spec[1] / 10.0, # value (0-10) + spec[2] / 50.0, # chroma (0-50) + spec[3] / 10.0, # code (0-10) + ] + ], + dtype=np.float32, + ) + + with torch.no_grad(): + pred_xyY = ( + model(torch.from_numpy(munsell_norm).to(device)).cpu().numpy()[0] + ) + + error = np.abs(pred_xyY - actual_xyY) + errors.append(error) + + print( + f"{color:<12} " + f"[{pred_xyY[0]:.4f}, {pred_xyY[1]:.4f}, {pred_xyY[2]:.4f}] " + f"[{actual_xyY[0]:.4f}, {actual_xyY[1]:.4f}, {actual_xyY[2]:.4f}]" + ) + except Exception as e: print(f"{color:<12} Error: {e}") + + if errors: + errors = np.array(errors) + print("=" * 70) + print( + f"Mean Absolute Error: x={errors[:, 0].mean():.6f}, " + f"y={errors[:, 1].mean():.6f}, Y={errors[:, 2].mean():.6f}" + ) + + +def main() -> None: + """ + Main entry point for training and testing the approximator. + + Notes + ----- + Trains the approximator and runs testing. Use 'test' as first argument + to skip training and only test existing model. + """ + import sys + + if len(sys.argv) > 1 and sys.argv[1] == "test": + test_approximator() + else: + train_approximator() + print("\n" + "=" * 70) + test_approximator() + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO, format="%(message)s", force=True) + + main() diff --git a/learning_munsell/utilities/__init__.py b/learning_munsell/utilities/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..498a896f52ada5a0c1c66eaf83fc2014a20bf67b --- /dev/null +++ b/learning_munsell/utilities/__init__.py @@ -0,0 +1,82 @@ +"""Utilities for Munsell color learning.""" + +# Data loading and normalization +from learning_munsell.utilities.data import ( + MUNSELL_NORMALIZATION_PARAMS, + XYY_NORMALIZATION_PARAMS, + compute_normalization_params_munsell, + denormalize_munsell, + denormalize_xyY, + load_training_data, + normalize_munsell, + normalize_munsell_simple, + normalize_xyY, +) + +# Loss functions +from learning_munsell.utilities.losses import ( + DEFAULT_WEIGHTS_FROM_XYY, + DEFAULT_WEIGHTS_TO_XYY, + precision_focused_loss, + weighted_mse_loss, +) + +# Training utilities +from learning_munsell.utilities.training import ( + EarlyStopping, + train_epoch, + validate, +) + +# Export utilities +from learning_munsell.utilities.export import ( + export_jax_to_onnx, + export_transformer_to_onnx, +) + +# MLflow and reporting +from learning_munsell.utilities.common import ( + benchmark_inference_speed, + generate_best_models_summary, + generate_html_report_footer, + generate_html_report_header, + generate_ranking_section, + get_model_size_mb, + log_training_epoch, + setup_mlflow_experiment, +) + +__all__ = [ + # Normalization constants + "XYY_NORMALIZATION_PARAMS", + "MUNSELL_NORMALIZATION_PARAMS", + # Data loading and normalization + "load_training_data", + "compute_normalization_params_munsell", + "normalize_munsell", + "normalize_munsell_simple", + "normalize_xyY", + "denormalize_munsell", + "denormalize_xyY", + # Loss functions + "DEFAULT_WEIGHTS_FROM_XYY", + "DEFAULT_WEIGHTS_TO_XYY", + "weighted_mse_loss", + "precision_focused_loss", + # Training utilities + "train_epoch", + "validate", + "EarlyStopping", + # Export utilities + "export_transformer_to_onnx", + "export_jax_to_onnx", + # MLflow and reporting + "setup_mlflow_experiment", + "log_training_epoch", + "get_model_size_mb", + "benchmark_inference_speed", + "generate_html_report_header", + "generate_html_report_footer", + "generate_ranking_section", + "generate_best_models_summary", +] diff --git a/learning_munsell/utilities/common.py b/learning_munsell/utilities/common.py new file mode 100644 index 0000000000000000000000000000000000000000..32411acc8d157539c9cf74e314ab6e45b2a99a9c --- /dev/null +++ b/learning_munsell/utilities/common.py @@ -0,0 +1,340 @@ +""" +Common utilities for learning-munsell. + +Provides shared functions for MLflow tracking, model comparison reports, +and benchmarking across all training and comparison scripts. +""" + +from __future__ import annotations + +import logging +import os +import time +from datetime import datetime +from pathlib import Path +from typing import TYPE_CHECKING + +import mlflow +import numpy as np + +from learning_munsell import PROJECT_ROOT + +if TYPE_CHECKING: + from collections.abc import Callable + +__all__ = [ + "setup_mlflow_experiment", + "log_training_epoch", + "get_model_size_mb", + "benchmark_inference_speed", + "generate_html_report_header", + "generate_html_report_footer", + "generate_best_models_summary", + "generate_ranking_section", +] + + +def setup_mlflow_experiment(direction: str, model_name: str) -> str: + """ + Set up MLflow experiment and return run name. + + Parameters + ---------- + direction + Conversion direction, either "from_xyY" or "to_xyY". + model_name + Name of the model being trained. + + Returns + ------- + str + Generated run name with timestamp. + """ + + mlflow.set_tracking_uri(f"sqlite:///{PROJECT_ROOT / 'mlruns.db'}") + mlflow.set_experiment(f"learning-munsell-{direction}") + + # MLflow changes root logger level to WARNING; restore INFO for training output + logging.getLogger().setLevel(logging.INFO) + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + + return f"{model_name}_{timestamp}" + + +def log_training_epoch( + epoch: int, train_loss: float, val_loss: float, lr: float +) -> None: + """ + Log standard training metrics for an epoch. + + Parameters + ---------- + epoch + Current epoch number. + train_loss + Training loss for the epoch. + val_loss + Validation loss for the epoch. + lr + Current learning rate. + """ + + mlflow.log_metrics( + { + "train_loss": train_loss, + "val_loss": val_loss, + "learning_rate": lr, + }, + step=epoch, + ) + + +def get_model_size_mb(file_paths: list[Path]) -> float: + """Get total size of model files in MB (includes .data files).""" + total_bytes = 0 + for f in file_paths: + if f.exists(): + total_bytes += os.path.getsize(f) + # Also include .data files for ONNX external data + data_file = Path(str(f) + ".data") + if data_file.exists(): + total_bytes += os.path.getsize(data_file) + return total_bytes / (1024 * 1024) + + +def benchmark_inference_speed( + session_callable: Callable, + input_data: np.ndarray, + num_iterations: int = 10, + warmup_iterations: int = 3, +) -> float: + """ + Benchmark inference speed in milliseconds per sample. + + Parameters + ---------- + session_callable + Function that performs inference + input_data + Input data for inference + num_iterations + Number of iterations for benchmarking + warmup_iterations + Number of warmup iterations + + Returns + ------- + float + Average time per sample in milliseconds + """ + # Warmup + for _ in range(warmup_iterations): + session_callable() + + # Benchmark + start_time = time.perf_counter() + for _ in range(num_iterations): + session_callable() + end_time = time.perf_counter() + + total_time_ms = (end_time - start_time) * 1000 + time_per_iteration_ms = total_time_ms / num_iterations + return time_per_iteration_ms / len(input_data) + + +def generate_html_report_header( + title: str, + subtitle: str, + num_samples: int, +) -> str: + """Generate HTML report header with styling.""" + return f""" + + + + + {title} - {datetime.now().strftime("%Y-%m-%d %H:%M")} + + + + + +
+ +
+

{title}

+
+

{subtitle}

+

+ Generated: {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}

+

Test Samples: + {num_samples:,} + real Munsell colors

+
+
+""" + + +def generate_html_report_footer() -> str: + """Generate HTML report footer.""" + return """ +
+ + +""" + + +def generate_best_models_summary( + results: dict, + metrics: list[tuple[str, str, str]], +) -> str: + """ + Generate best models summary section. + + Parameters + ---------- + results + Dictionary of model results + metrics + List of (metric_key, display_name, format_string) tuples + """ + html = """ + +
+

Best Models by Metric

+
+""" + + for metric_key, display_name, fmt in metrics: + # Find best model for this metric + valid_results = [ + (name, res[metric_key]) + for name, res in results.items() + if metric_key in res + and not (isinstance(res[metric_key], float) and np.isnan(res[metric_key])) + ] + if not valid_results: + continue + + best_model, best_value = min(valid_results, key=lambda x: x[1]) + + html += f""" +
+
{display_name}
+
+ {fmt.format(best_value)}
+
{best_model}
+
+""" + + html += """ +
+
+""" + return html + + +def generate_ranking_section( + results: dict, + metric_key: str, + title: str, + lower_is_better: bool = True, +) -> str: + """Generate a ranking bar chart section.""" + # Sort results + sorted_results = sorted( + [ + (name, res[metric_key]) + for name, res in results.items() + if not (isinstance(res[metric_key], float) and np.isnan(res[metric_key])) + ], + key=lambda x: x[1], + reverse=not lower_is_better, + ) + + if not sorted_results: + return "" + + max_value = max(v for _, v in sorted_results) if sorted_results else 1.0 + + html = f""" + +
+

{title}

+
+""" + + for rank, (model_name, value) in enumerate(sorted_results, 1): + width_pct = (value / max_value) * 100 if max_value > 0 else 0 + html += f""" +
+
+ {rank}. + {model_name} +
+
+
+
+
{value:.6f}
+
+""" + + html += """ +
+
+""" + return html diff --git a/learning_munsell/utilities/data.py b/learning_munsell/utilities/data.py new file mode 100644 index 0000000000000000000000000000000000000000..9efbfbb6dfc743e242e971ca36b1a92e9ee68f72 --- /dev/null +++ b/learning_munsell/utilities/data.py @@ -0,0 +1,275 @@ +""" +Data loading and normalization utilities. + +Provides shared functions for loading training data and normalizing +xyY and Munsell values across all training scripts. +""" + +from __future__ import annotations + +import numpy as np +from numpy.typing import NDArray + +from learning_munsell import PROJECT_ROOT + +XYY_NORMALIZATION_PARAMS: dict = { + "x_range": (0.0, 1.0), + "y_range": (0.0, 1.0), + "Y_range": (0.0, 1.0), +} + +MUNSELL_NORMALIZATION_PARAMS: dict = { + "hue_range": (0.0, 10.0), + "value_range": (0.0, 10.0), + "chroma_range": (0.0, 50.0), + "code_range": (1.0, 10.0), +} + +__all__ = [ + "XYY_NORMALIZATION_PARAMS", + "MUNSELL_NORMALIZATION_PARAMS", + "load_training_data", + "compute_normalization_params_munsell", + "normalize_xyY", + "denormalize_xyY", + "normalize_munsell", + "denormalize_munsell", + "normalize_munsell_simple", +] + + +def load_training_data( + direction: str = "from_xyY", +) -> tuple[NDArray, NDArray]: + """ + Load training data for Munsell color conversions. + + Parameters + ---------- + direction : str, optional + Direction of conversion. Either "from_xyY" (default) or "to_xyY". + - "from_xyY": Returns (xyY_values, munsell_specs) for xyY->Munsell + - "to_xyY": Returns (munsell_specs, xyY_values) for Munsell->xyY + + Returns + ------- + tuple[ndarray, ndarray] + For "from_xyY": (xyY values of shape (n, 3), Munsell specs of shape (n, 4)) + For "to_xyY": (Munsell specs of shape (n, 4), xyY values of shape (n, 3)) + Invalid entries (containing NaN) are filtered out. + """ + data = np.load(PROJECT_ROOT / "data" / "training_data.npz") + + if direction == "to_xyY": + # For Munsell -> xyY conversion + xyY_values = data["xyY_all"].astype(np.float32) + munsell_specs = data["munsell_all"].astype(np.float32) + + # Filter invalid entries including negative values + valid_mask = ( + ~np.isnan(xyY_values).any(axis=1) + & ~np.isnan(munsell_specs).any(axis=1) + & (munsell_specs[:, 1] >= 0) # Value >= 0 + & (munsell_specs[:, 2] >= 0) # Chroma >= 0 + ) + + return munsell_specs[valid_mask], xyY_values[valid_mask] + else: + # For xyY -> Munsell conversion (default) + xyY_values = data["xyY_values"] + munsell_specs = data["munsell_specs"] + + # Filter invalid entries + valid_mask = ~np.isnan(xyY_values).any(axis=1) & ~np.isnan( + munsell_specs + ).any(axis=1) + + return xyY_values[valid_mask], munsell_specs[valid_mask] + + +def compute_normalization_params_munsell(y: NDArray) -> dict: + """ + Compute min/max normalization parameters for Munsell data. + + Parameters + ---------- + y : ndarray + Munsell specifications of shape (n, 4). + + Returns + ------- + dict + Dictionary with 'hue_range', 'value_range', 'chroma_range', + 'code_range' keys, each containing (min, max) tuples. + """ + return { + "hue_range": (float(y[:, 0].min()), float(y[:, 0].max())), + "value_range": (float(y[:, 1].min()), float(y[:, 1].max())), + "chroma_range": (float(y[:, 2].min()), float(y[:, 2].max())), + "code_range": (float(y[:, 3].min()), float(y[:, 3].max())), + } + + +def normalize_xyY(X: NDArray, params: dict) -> NDArray: + """ + Normalize xyY input values to [0, 1] range. + + Parameters + ---------- + X : ndarray + xyY values of shape (n, 3) where columns are [x, y, Y]. + params : dict + Normalization parameters with keys 'x_range', 'y_range', 'Y_range', + each containing (min, max) tuples. + + Returns + ------- + ndarray + Normalized values in [0, 1] range, dtype float32. + """ + X_norm = np.empty_like(X, dtype=np.float32) + X_norm[:, 0] = (X[:, 0] - params["x_range"][0]) / ( + params["x_range"][1] - params["x_range"][0] + ) + X_norm[:, 1] = (X[:, 1] - params["y_range"][0]) / ( + params["y_range"][1] - params["y_range"][0] + ) + X_norm[:, 2] = (X[:, 2] - params["Y_range"][0]) / ( + params["Y_range"][1] - params["Y_range"][0] + ) + return X_norm + + +def denormalize_xyY(X_norm: NDArray, params: dict) -> NDArray: + """ + Denormalize xyY values from [0, 1] range back to original scale. + + Parameters + ---------- + X_norm : ndarray + Normalized xyY values of shape (n, 3). + params : dict + Normalization parameters with keys 'x_range', 'y_range', 'Y_range'. + + Returns + ------- + ndarray + Denormalized xyY values. + """ + X = np.empty_like(X_norm) + X[:, 0] = ( + X_norm[:, 0] * (params["x_range"][1] - params["x_range"][0]) + + params["x_range"][0] + ) + X[:, 1] = ( + X_norm[:, 1] * (params["y_range"][1] - params["y_range"][0]) + + params["y_range"][0] + ) + X[:, 2] = ( + X_norm[:, 2] * (params["Y_range"][1] - params["Y_range"][0]) + + params["Y_range"][0] + ) + return X + + +def normalize_munsell(y: NDArray, params: dict) -> NDArray: + """ + Normalize Munsell output values to [0, 1] range. + + Parameters + ---------- + y : ndarray + Munsell specifications [hue, value, chroma, code] of shape (n, 4). + params : dict + Normalization parameters with keys 'hue_range', 'value_range', + 'chroma_range', 'code_range', each containing (min, max) tuples. + + Returns + ------- + ndarray + Normalized values in [0, 1] range. + """ + y_norm = np.empty_like(y, dtype=np.float32) + y_norm[:, 0] = (y[:, 0] - params["hue_range"][0]) / ( + params["hue_range"][1] - params["hue_range"][0] + ) + y_norm[:, 1] = (y[:, 1] - params["value_range"][0]) / ( + params["value_range"][1] - params["value_range"][0] + ) + y_norm[:, 2] = (y[:, 2] - params["chroma_range"][0]) / ( + params["chroma_range"][1] - params["chroma_range"][0] + ) + y_norm[:, 3] = (y[:, 3] - params["code_range"][0]) / ( + params["code_range"][1] - params["code_range"][0] + ) + return y_norm + + +def denormalize_munsell(y_norm: NDArray, params: dict) -> NDArray: + """ + Denormalize Munsell values from [0, 1] range back to original scale. + + Parameters + ---------- + y_norm : ndarray + Normalized Munsell specifications of shape (n, 4). + params : dict + Normalization parameters with keys 'hue_range', 'value_range', + 'chroma_range', 'code_range'. + + Returns + ------- + ndarray + Denormalized Munsell specifications. + """ + y = np.empty_like(y_norm) + y[:, 0] = ( + y_norm[:, 0] * (params["hue_range"][1] - params["hue_range"][0]) + + params["hue_range"][0] + ) + y[:, 1] = ( + y_norm[:, 1] * (params["value_range"][1] - params["value_range"][0]) + + params["value_range"][0] + ) + y[:, 2] = ( + y_norm[:, 2] * (params["chroma_range"][1] - params["chroma_range"][0]) + + params["chroma_range"][0] + ) + y[:, 3] = ( + y_norm[:, 3] * (params["code_range"][1] - params["code_range"][0]) + + params["code_range"][0] + ) + return y + + +def normalize_munsell_simple(munsell: NDArray) -> NDArray: + """ + Normalize Munsell values using simple fixed-range scaling. + + This is a simplified normalization used in to_xyY training scripts + that uses fixed maximum values rather than computed min/max. + + Parameters + ---------- + munsell : ndarray + Munsell specifications [hue, value, chroma, code] of shape (n, 4). + + Returns + ------- + ndarray + Normalized values in approximate [0, 1] range, dtype float32. + + Notes + ----- + Uses fixed scaling factors: + - Hue: /10.0 (decade scale) + - Value: /10.0 + - Chroma: /50.0 + - Code: /10.0 + """ + normalized = munsell.copy().astype(np.float32) + normalized[:, 0] = munsell[:, 0] / 10.0 # Hue (in decade) + normalized[:, 1] = munsell[:, 1] / 10.0 # Value + normalized[:, 2] = munsell[:, 2] / 50.0 # Chroma + normalized[:, 3] = munsell[:, 3] / 10.0 # Code + return normalized diff --git a/learning_munsell/utilities/export.py b/learning_munsell/utilities/export.py new file mode 100644 index 0000000000000000000000000000000000000000..a56b76e06722ed027a39ebfec3b0b53b5a90cc81 --- /dev/null +++ b/learning_munsell/utilities/export.py @@ -0,0 +1,339 @@ +""" +ONNX export utilities. + +Provides functions for exporting trained models to ONNX format, +including PyTorch checkpoints and JAX-trained models. +""" + +from __future__ import annotations + +import logging +from pathlib import Path + +import numpy as np +import torch +from torch import nn + +from learning_munsell import PROJECT_ROOT +from learning_munsell.models.networks import TransformerToMunsell + +LOGGER = logging.getLogger(__name__) + +__all__ = [ + "export_transformer_to_onnx", + "export_jax_to_onnx", +] + + +def export_transformer_to_onnx( + checkpoint_path: Path | None = None, + output_path: Path | None = None, +) -> Path: + """ + Export Transformer model from checkpoint to ONNX format. + + Parameters + ---------- + checkpoint_path : Path, optional + Path to the checkpoint file. Defaults to + models/from_xyY/transformer_large_best.pth. + output_path : Path, optional + Path for the ONNX output file. Defaults to + models/from_xyY/transformer_large.onnx. + + Returns + ------- + Path + Path to the exported ONNX file. + + Raises + ------ + FileNotFoundError + If checkpoint file does not exist. + """ + model_directory = PROJECT_ROOT / "models" / "from_xyY" + + if checkpoint_path is None: + checkpoint_path = model_directory / "transformer_large_best.pth" + if output_path is None: + output_path = model_directory / "transformer_large.onnx" + + if not checkpoint_path.exists(): + msg = f"Checkpoint not found: {checkpoint_path}" + raise FileNotFoundError(msg) + + LOGGER.info("Loading checkpoint from %s...", checkpoint_path) + checkpoint = torch.load(checkpoint_path, weights_only=False, map_location="cpu") + + model = TransformerToMunsell( + num_features=3, + embedding_dim=256, + num_blocks=6, + num_heads=8, + ff_dim=1024, + dropout=0.1, + ) + model.load_state_dict(checkpoint["model_state_dict"]) + model.eval() + + LOGGER.info("Exporting to ONNX...") + dummy_input = torch.randn(1, 3) + + torch.onnx.export( + model, + dummy_input, + output_path, + export_params=True, + opset_version=14, + input_names=["xyY"], + output_names=["munsell_spec"], + dynamic_axes={"xyY": {0: "batch_size"}, "munsell_spec": {0: "batch_size"}}, + do_constant_folding=True, + dynamo=False, + ) + + LOGGER.info("ONNX model exported to: %s", output_path) + + # Verify export + import onnxruntime as ort + + session = ort.InferenceSession(str(output_path)) + test_np = np.random.randn(10, 3).astype(np.float32) + onnx_output = session.run(None, {"xyY": test_np})[0] + + with torch.no_grad(): + torch_output = model(torch.from_numpy(test_np)).numpy() + + max_diff = np.max(np.abs(onnx_output - torch_output)) + LOGGER.info("Max difference between PyTorch and ONNX: %.6f", max_diff) + + if max_diff < 1e-4: + LOGGER.info("ONNX export verified successfully!") + else: + LOGGER.warning("ONNX export may have precision issues") + + return output_path + + +# JAX-specific classes for weight conversion +class _ComponentMLP(nn.Module): + """PyTorch MLP matching JAX ComponentMLP architecture for weight loading.""" + + def __init__(self, input_dim: int = 3, width_multiplier: float = 1.0) -> None: + super().__init__() + + h1 = int(128 * width_multiplier) + h2 = int(256 * width_multiplier) + h3 = int(512 * width_multiplier) + + self.layers = nn.ModuleList( + [ + nn.Linear(input_dim, h1), + nn.LayerNorm(h1), + nn.Linear(h1, h2), + nn.LayerNorm(h2), + nn.Linear(h2, h3), + nn.LayerNorm(h3), + nn.Linear(h3, h2), + nn.LayerNorm(h2), + nn.Linear(h2, h1), + nn.LayerNorm(h1), + nn.Linear(h1, 1), + ] + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.layers[0](x) + x = torch.relu(x) + x = self.layers[1](x) + x = self.layers[2](x) + x = torch.relu(x) + x = self.layers[3](x) + x = self.layers[4](x) + x = torch.relu(x) + x = self.layers[5](x) + x = self.layers[6](x) + x = torch.relu(x) + x = self.layers[7](x) + x = self.layers[8](x) + x = torch.relu(x) + x = self.layers[9](x) + return self.layers[10](x) + + +class _MultiMLPJAX(nn.Module): + """PyTorch Multi-MLP matching JAX architecture for weight loading.""" + + def __init__(self) -> None: + super().__init__() + self.hue_branch = _ComponentMLP(input_dim=3, width_multiplier=1.0) + self.value_branch = _ComponentMLP(input_dim=3, width_multiplier=1.0) + self.chroma_branch = _ComponentMLP(input_dim=3, width_multiplier=2.0) + self.code_branch = _ComponentMLP(input_dim=3, width_multiplier=1.0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + hue = self.hue_branch(x) + value = self.value_branch(x) + chroma = self.chroma_branch(x) + code = self.code_branch(x) + return torch.cat([hue, value, chroma, code], dim=-1) + + +def _load_jax_weights(weights_path: Path, model: nn.Module) -> nn.Module: + """Load JAX/Flax weights into PyTorch model.""" + saved = np.load(weights_path) + + branch_map = { + "ComponentMLP_0": model.hue_branch, + "ComponentMLP_1": model.value_branch, + "ComponentMLP_2": model.chroma_branch, + "ComponentMLP_3": model.code_branch, + } + + layer_map = { + "Dense_0": 0, + "LayerNorm_0": 1, + "Dense_1": 2, + "LayerNorm_1": 3, + "Dense_2": 4, + "LayerNorm_2": 5, + "Dense_3": 6, + "LayerNorm_3": 7, + "Dense_4": 8, + "LayerNorm_4": 9, + "Dense_5": 10, + } + + for key in saved.files: + if key == "metadata": + continue + + parts = key.split("_") + if parts[0] != "params": + continue + + component_name = f"{parts[1]}_{parts[2]}" + layer_name = f"{parts[3]}_{parts[4]}" + param_name = parts[5] + + branch = branch_map[component_name] + layer_idx = layer_map[layer_name] + layer = branch.layers[layer_idx] + + weight = saved[key] + + if "Dense" in layer_name: + if param_name == "kernel": + layer.weight.data = torch.from_numpy(weight.T).float() + elif param_name == "bias": + layer.bias.data = torch.from_numpy(weight).float() + elif "LayerNorm" in layer_name: + if param_name == "scale": + layer.weight.data = torch.from_numpy(weight).float() + elif param_name == "bias": + layer.bias.data = torch.from_numpy(weight).float() + + return model + + +def export_jax_to_onnx( + weights_path: Path | None = None, + output_path: Path | None = None, +) -> Path: + """ + Export JAX-trained Multi-MLP model to ONNX format. + + Loads weights from a JAX-trained model, creates an equivalent PyTorch model, + and exports to ONNX format. + + Parameters + ---------- + weights_path : Path, optional + Path to the JAX weights file (.npz). Defaults to + models/from_xyY/multi_mlp_jax_delta_e.npz. + output_path : Path, optional + Path for the ONNX output file. Defaults to + models/from_xyY/multi_mlp_jax_delta_e.onnx. + + Returns + ------- + Path + Path to the exported ONNX file. + + Raises + ------ + FileNotFoundError + If weights file does not exist. + """ + models_dir = PROJECT_ROOT / "models" / "from_xyY" + + if weights_path is None: + weights_path = models_dir / "multi_mlp_jax_delta_e.npz" + if output_path is None: + output_path = models_dir / "multi_mlp_jax_delta_e.onnx" + + if not weights_path.exists(): + msg = f"JAX weights not found: {weights_path}" + raise FileNotFoundError(msg) + + LOGGER.info("Loading JAX weights from %s", weights_path) + + model = _MultiMLPJAX() + model = _load_jax_weights(weights_path, model) + model.eval() + + total_params = sum(p.numel() for p in model.parameters()) + LOGGER.info("Model parameters: %s", f"{total_params:,}") + + dummy_input = torch.randn(1, 3) + + torch.onnx.export( + model, + dummy_input, + output_path, + input_names=["xyY"], + output_names=["munsell_spec"], + dynamic_axes={"xyY": {0: "batch"}, "munsell_spec": {0: "batch"}}, + opset_version=17, + ) + LOGGER.info("Exported ONNX: %s", output_path) + + # Save normalization params + norm_params_path = output_path.with_name( + output_path.stem + "_normalization_params.npz" + ) + np.savez( + norm_params_path, + output_params={ + "hue_range": [0.5, 10.0], + "value_range": [0.0, 10.0], + "chroma_range": [0.0, 50.0], + "code_range": [1.0, 10.0], + }, + ) + LOGGER.info("Saved normalization params: %s", norm_params_path) + + return output_path + + +def main() -> None: + """Export models to ONNX format.""" + import argparse + + parser = argparse.ArgumentParser(description="Export models to ONNX") + parser.add_argument( + "model", + choices=["transformer", "jax"], + help="Model type to export", + ) + args = parser.parse_args() + + if args.model == "transformer": + export_transformer_to_onnx() + elif args.model == "jax": + export_jax_to_onnx() + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO, format="%(message)s", force=True) + main() diff --git a/learning_munsell/utilities/losses.py b/learning_munsell/utilities/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..d4078b5abefc4af8036fd05f6311a86f9398e3b3 --- /dev/null +++ b/learning_munsell/utilities/losses.py @@ -0,0 +1,105 @@ +""" +PyTorch loss functions. + +Provides shared loss functions for training neural networks, +including weighted MSE and precision-focused losses. +""" + +from __future__ import annotations + +import torch +from torch import Tensor + +__all__ = [ + "DEFAULT_WEIGHTS_FROM_XYY", + "DEFAULT_WEIGHTS_TO_XYY", + "weighted_mse_loss", + "precision_focused_loss", +] + +# Default component weights for from_xyY direction (hue, value, chroma, code) +DEFAULT_WEIGHTS_FROM_XYY = torch.tensor([1.0, 1.0, 5.0, 0.4]) + +# Default component weights for to_xyY direction (x, y, Y) +DEFAULT_WEIGHTS_TO_XYY = torch.tensor([1.0, 1.0, 1.0]) + + +def weighted_mse_loss( + pred: Tensor, + target: Tensor, + weights: Tensor | None = None, +) -> Tensor: + """ + Compute weighted mean squared error loss. + + Parameters + ---------- + pred : Tensor + Predicted values of shape (batch_size, n_components). + target : Tensor + Target values of shape (batch_size, n_components). + weights : Tensor, optional + Component weights of shape (n_components,). + If None, uses equal weights. + + Returns + ------- + Tensor + Scalar weighted MSE loss. + """ + mse = (pred - target) ** 2 + if weights is not None: + weights = weights.to(pred.device) + mse = mse * weights + return mse.mean() + + +def precision_focused_loss(pred: Tensor, target: Tensor) -> Tensor: + """ + Compute precision-focused loss for small residual errors. + + Combines multiple loss terms to heavily penalize small errors, + which is important for achieving sub-JND (Just Noticeable Difference) + accuracy in color prediction. + + Parameters + ---------- + pred : Tensor + Predicted values of shape (batch_size, n_components). + target : Tensor + Target values of shape (batch_size, n_components). + + Returns + ------- + Tensor + Scalar loss value. + + Notes + ----- + The loss combines four components: + - MSE: Standard mean squared error (weight 1.0) + - MAE: Mean absolute error (weight 0.5) + - Log penalty: Penalizes small errors heavily (weight 0.3) + - Huber: Small delta (0.01) for precision on small errors (weight 0.5) + """ + # Standard MSE + mse = torch.mean((pred - target) ** 2) + + # Mean absolute error + mae = torch.mean(torch.abs(pred - target)) + + # Logarithmic penalty - heavily penalizes small errors + log_penalty = torch.mean(torch.log1p(torch.clamp(torch.abs(pred - target) * 1000.0, max=1e6))) + + # Huber loss with small delta for precision + delta = 0.01 + abs_error = torch.abs(pred - target) + huber = torch.where( + abs_error <= delta, + 0.5 * abs_error**2, + delta * (abs_error - 0.5 * delta), + ) + huber_loss = torch.mean(huber) + + # Combine with weights + return 1.0 * mse + 0.5 * mae + 0.3 * log_penalty + 0.5 * huber_loss diff --git a/learning_munsell/utilities/training.py b/learning_munsell/utilities/training.py new file mode 100644 index 0000000000000000000000000000000000000000..44d0a4fe34afe04626419a177502c4ae43b5ea9e --- /dev/null +++ b/learning_munsell/utilities/training.py @@ -0,0 +1,173 @@ +""" +Training utilities. + +Provides shared functions for training loops, validation, and early stopping +across all training scripts. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch +from torch import nn, optim + +if TYPE_CHECKING: + from collections.abc import Callable + + from torch.utils.data import DataLoader + +__all__ = [ + "train_epoch", + "validate", + "EarlyStopping", +] + + +def train_epoch( + model: nn.Module, + dataloader: DataLoader, + optimizer: optim.Optimizer, + criterion: Callable, + device: torch.device, +) -> float: + """ + Train the model for one epoch. + + Parameters + ---------- + model : nn.Module + The neural network model to train. + dataloader : DataLoader + DataLoader providing training batches (X, y). + optimizer : optim.Optimizer + Optimizer for updating model parameters. + criterion : callable + Loss function that takes (predictions, targets) and returns loss. + device : torch.device + Device to run training on. + + Returns + ------- + float + Average loss for the epoch. + """ + model.train() + total_loss = 0.0 + + for X_batch, y_batch in dataloader: + X_batch = X_batch.to(device) + y_batch = y_batch.to(device) + + # Forward pass + outputs = model(X_batch) + loss = criterion(outputs, y_batch) + + # Backward pass + optimizer.zero_grad() + loss.backward() + optimizer.step() + + total_loss += loss.item() + + return total_loss / len(dataloader) + + +def validate( + model: nn.Module, + dataloader: DataLoader, + criterion: Callable, + device: torch.device, +) -> float: + """ + Validate the model on validation data. + + Parameters + ---------- + model : nn.Module + The neural network model to validate. + dataloader : DataLoader + DataLoader providing validation batches (X, y). + criterion : callable + Loss function that takes (predictions, targets) and returns loss. + device : torch.device + Device to run validation on. + + Returns + ------- + float + Average loss for the validation set. + """ + model.eval() + total_loss = 0.0 + + with torch.no_grad(): + for X_batch, y_batch in dataloader: + X_batch = X_batch.to(device) + y_batch = y_batch.to(device) + + outputs = model(X_batch) + loss = criterion(outputs, y_batch) + + total_loss += loss.item() + + return total_loss / len(dataloader) + + +class EarlyStopping: + """ + Early stopping handler to prevent overfitting. + + Parameters + ---------- + patience : int + Number of epochs to wait for improvement before stopping. + min_delta : float + Minimum change to qualify as an improvement. + + Attributes + ---------- + best_loss : float + Best validation loss observed. + counter : int + Number of epochs without improvement. + should_stop : bool + Whether training should stop. + """ + + def __init__(self, patience: int = 30, min_delta: float = 0.0) -> None: + """Initialize early stopping with patience and minimum delta.""" + self.patience = patience + self.min_delta = min_delta + self.best_loss = float("inf") + self.counter = 0 + self.should_stop = False + + def __call__(self, val_loss: float) -> bool: + """ + Check if training should stop. + + Parameters + ---------- + val_loss : float + Current validation loss. + + Returns + ------- + bool + True if training should stop. + """ + if val_loss < self.best_loss - self.min_delta: + self.best_loss = val_loss + self.counter = 0 + else: + self.counter += 1 + if self.counter >= self.patience: + self.should_stop = True + return self.should_stop + + def reset(self) -> None: + """Reset early stopping state.""" + self.best_loss = float("inf") + self.counter = 0 + self.should_stop = False diff --git a/models/from_xyY/deep_wide.onnx b/models/from_xyY/deep_wide.onnx new file mode 100644 index 0000000000000000000000000000000000000000..6fbf01203f7ef23b25e802cf31695d38d4aa8a5d --- /dev/null +++ b/models/from_xyY/deep_wide.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3c2fc6876d23f562be73d806caed85192d78c4736a307bfc03b35d5a328ecae2 +size 20372 diff --git a/models/from_xyY/deep_wide.onnx.data b/models/from_xyY/deep_wide.onnx.data new file mode 100644 index 0000000000000000000000000000000000000000..914dce10cee3f4c8dfbad30c33a47cea6ad49324 --- /dev/null +++ b/models/from_xyY/deep_wide.onnx.data @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7e6e798c7d56e7c47d5a5b06ccc37f8cb09b382bcdf58cecb59caba57c255750 +size 38600704 diff --git a/models/from_xyY/deep_wide_best.pth b/models/from_xyY/deep_wide_best.pth new file mode 100644 index 0000000000000000000000000000000000000000..13bc28074fbd5d9776ef1217274aa7c64b39986d --- /dev/null +++ b/models/from_xyY/deep_wide_best.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dccba7d18b1b1135100b1528f49ac087889637c5b9ec37f7ae5f441c15408cb9 +size 38660545 diff --git a/models/from_xyY/deep_wide_normalization_params.npz b/models/from_xyY/deep_wide_normalization_params.npz new file mode 100644 index 0000000000000000000000000000000000000000..0753e70279be6cb1b80024cfde667cde41b04fdc --- /dev/null +++ b/models/from_xyY/deep_wide_normalization_params.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:91ac6cb85a8fd8b29dd3cc9ad1978746a795fc3949b138b47c75f47bf30dede1 +size 1004 diff --git a/models/from_xyY/ft_transformer.onnx b/models/from_xyY/ft_transformer.onnx new file mode 100644 index 0000000000000000000000000000000000000000..3ab797d9421914a3ba7a1413c545dd841f7cbffb --- /dev/null +++ b/models/from_xyY/ft_transformer.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:45cfaa7caf8cc792c3656cd445a523a62d6ee75fcb4ff5f706fb50fd0e46c44d +size 293071 diff --git a/models/from_xyY/ft_transformer.onnx.data b/models/from_xyY/ft_transformer.onnx.data new file mode 100644 index 0000000000000000000000000000000000000000..2b2188dfa490f411de2eacbdfda3e9aa371b1b87 --- /dev/null +++ b/models/from_xyY/ft_transformer.onnx.data @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d5abfe2829bf5f4d038c74b8cdc4da046f532d03aff41c305f23394b19ff7255 +size 8574464 diff --git a/models/from_xyY/ft_transformer_best.pth b/models/from_xyY/ft_transformer_best.pth new file mode 100644 index 0000000000000000000000000000000000000000..32a533091e503845c83329cc5d838c19ae0a06c3 --- /dev/null +++ b/models/from_xyY/ft_transformer_best.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5621b4a57403755845f841a88f0501af950150b33fac9a98c7c3cd4a55ba2fab +size 8597246 diff --git a/models/from_xyY/ft_transformer_normalization_params.npz b/models/from_xyY/ft_transformer_normalization_params.npz new file mode 100644 index 0000000000000000000000000000000000000000..0753e70279be6cb1b80024cfde667cde41b04fdc --- /dev/null +++ b/models/from_xyY/ft_transformer_normalization_params.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:91ac6cb85a8fd8b29dd3cc9ad1978746a795fc3949b138b47c75f47bf30dede1 +size 1004 diff --git a/models/from_xyY/gamma_sweep_results.npz b/models/from_xyY/gamma_sweep_results.npz new file mode 100644 index 0000000000000000000000000000000000000000..3ff96d4f6a9e5e33f3509bbc565392280df0547d --- /dev/null +++ b/models/from_xyY/gamma_sweep_results.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e5e035e01d28e3ffdf18531860f64ba45266e6367dc57015f15fa951267bd37d +size 1935 diff --git a/models/from_xyY/gamma_sweep_results_averaged.npz b/models/from_xyY/gamma_sweep_results_averaged.npz new file mode 100644 index 0000000000000000000000000000000000000000..f0ccb251cdafb368fa75f2cae0a8d39f84e75bae --- /dev/null +++ b/models/from_xyY/gamma_sweep_results_averaged.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:aeef4c069ab9330d7d2f7201c8891b65993f86f8c5e57a449f36564c64f9942c +size 11746 diff --git a/models/from_xyY/mixture_of_experts.onnx b/models/from_xyY/mixture_of_experts.onnx new file mode 100644 index 0000000000000000000000000000000000000000..b2f41a9d44cc295d697caa07278416274af65bae --- /dev/null +++ b/models/from_xyY/mixture_of_experts.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e9ae99d547d0c6c0f728c15221196f4995469efe0d22bd633dbc899cb6e31075 +size 43390 diff --git a/models/from_xyY/mixture_of_experts.onnx.data b/models/from_xyY/mixture_of_experts.onnx.data new file mode 100644 index 0000000000000000000000000000000000000000..0571757a26440449007a8fe068759c0923d03ab1 --- /dev/null +++ b/models/from_xyY/mixture_of_experts.onnx.data @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e9cbb0a7dbf8b071a7fffe51a1055319687e7405c6e33016567e5508b764b000 +size 4330496 diff --git a/models/from_xyY/mixture_of_experts_best.pth b/models/from_xyY/mixture_of_experts_best.pth new file mode 100644 index 0000000000000000000000000000000000000000..ad014dd191ed85de240953cd29a6d5a4fdb24dea --- /dev/null +++ b/models/from_xyY/mixture_of_experts_best.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8fead787899e97014ed2b0658f548b2276741a698412ed66b9179a5da25b5905 +size 4424599 diff --git a/models/from_xyY/mixture_of_experts_normalization_params.npz b/models/from_xyY/mixture_of_experts_normalization_params.npz new file mode 100644 index 0000000000000000000000000000000000000000..0753e70279be6cb1b80024cfde667cde41b04fdc --- /dev/null +++ b/models/from_xyY/mixture_of_experts_normalization_params.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:91ac6cb85a8fd8b29dd3cc9ad1978746a795fc3949b138b47c75f47bf30dede1 +size 1004 diff --git a/models/from_xyY/mlp.onnx b/models/from_xyY/mlp.onnx new file mode 100644 index 0000000000000000000000000000000000000000..de1b2aae74ef9d16fa064acb6692dc057fd28ca3 --- /dev/null +++ b/models/from_xyY/mlp.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f83276a1d30668e55be919fc6cb6bc445cf1ed1ba35626570ae7fba88fb1be54 +size 7885 diff --git a/models/from_xyY/mlp_attention.onnx b/models/from_xyY/mlp_attention.onnx new file mode 100644 index 0000000000000000000000000000000000000000..e90060b305d9d1ba3e36a45047bcae0edd7cb45b --- /dev/null +++ b/models/from_xyY/mlp_attention.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dec531f7bf9184c39d85add161e07bfd4113d790f7fca33754d64f4531a31b08 +size 36264 diff --git a/models/from_xyY/mlp_attention_normalization_params.npz b/models/from_xyY/mlp_attention_normalization_params.npz new file mode 100644 index 0000000000000000000000000000000000000000..0753e70279be6cb1b80024cfde667cde41b04fdc --- /dev/null +++ b/models/from_xyY/mlp_attention_normalization_params.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:91ac6cb85a8fd8b29dd3cc9ad1978746a795fc3949b138b47c75f47bf30dede1 +size 1004 diff --git a/models/from_xyY/mlp_error_predictor.onnx b/models/from_xyY/mlp_error_predictor.onnx new file mode 100644 index 0000000000000000000000000000000000000000..ff346060705d002a64ebdb126a687506e224897e --- /dev/null +++ b/models/from_xyY/mlp_error_predictor.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b6bcf10a11909afd401f7d47abae96248c47abff330f1d266ca0b1487991359c +size 15661 diff --git a/models/from_xyY/mlp_gamma.onnx b/models/from_xyY/mlp_gamma.onnx new file mode 100644 index 0000000000000000000000000000000000000000..39cab1d3b2610a66ef9f72c2d5b8e0bd2a4f0ceb --- /dev/null +++ b/models/from_xyY/mlp_gamma.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:427f6b0ba9d9a3e9596c07a16ce42e5857de6ac5472fa0ab0b8bc3ed8fc99ca4 +size 8119 diff --git a/models/from_xyY/mlp_gamma_normalization_params.npz b/models/from_xyY/mlp_gamma_normalization_params.npz new file mode 100644 index 0000000000000000000000000000000000000000..b898222c0f5ed3f5ba6932d334ec8d00cf81a27b --- /dev/null +++ b/models/from_xyY/mlp_gamma_normalization_params.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:de07826b1e2f5197b0b579b842eb61694dd527baaf8dea917d819deada51071c +size 1039 diff --git a/models/from_xyY/mlp_normalization_params.npz b/models/from_xyY/mlp_normalization_params.npz new file mode 100644 index 0000000000000000000000000000000000000000..0753e70279be6cb1b80024cfde667cde41b04fdc --- /dev/null +++ b/models/from_xyY/mlp_normalization_params.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:91ac6cb85a8fd8b29dd3cc9ad1978746a795fc3949b138b47c75f47bf30dede1 +size 1004 diff --git a/models/from_xyY/multi_head.onnx b/models/from_xyY/multi_head.onnx new file mode 100644 index 0000000000000000000000000000000000000000..b20aa457b34d8ad7ed612166a1aa3b27121b11aa --- /dev/null +++ b/models/from_xyY/multi_head.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:31e7caf6b4febe0e9cabc215f88ab5f630b86710180fcfb89e219a7457080599 +size 17269 diff --git a/models/from_xyY/multi_head.onnx.data b/models/from_xyY/multi_head.onnx.data new file mode 100644 index 0000000000000000000000000000000000000000..89a02037169f86f3c340bc6562da6cf5f4207264 --- /dev/null +++ b/models/from_xyY/multi_head.onnx.data @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:07a0c775131c4283e38275c66d8c747830c2aa0721e85cb5eba5e6ec03e0f771 +size 3992064 diff --git a/models/from_xyY/multi_head_3stage_error_predictor.onnx b/models/from_xyY/multi_head_3stage_error_predictor.onnx new file mode 100644 index 0000000000000000000000000000000000000000..b4cffd5c54cdc88238f1cd6f897bae648c97edd7 --- /dev/null +++ b/models/from_xyY/multi_head_3stage_error_predictor.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:422f01b192adbf01fadaad99765604f7bce07abbfb4eb267dc57d3f3020e5fc0 +size 40853 diff --git a/models/from_xyY/multi_head_3stage_error_predictor.onnx.data b/models/from_xyY/multi_head_3stage_error_predictor.onnx.data new file mode 100644 index 0000000000000000000000000000000000000000..f3bb812cb7faaf6a15422539165fd6b74cca4e5f --- /dev/null +++ b/models/from_xyY/multi_head_3stage_error_predictor.onnx.data @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:af480607996b41d736e33cdc8c6878e3455d68994a1ebcb24b85475b972627db +size 7077888 diff --git a/models/from_xyY/multi_head_3stage_error_predictor_best.pth b/models/from_xyY/multi_head_3stage_error_predictor_best.pth new file mode 100644 index 0000000000000000000000000000000000000000..7d4452284da43672ac331b3cdf72c3a18ec02451 --- /dev/null +++ b/models/from_xyY/multi_head_3stage_error_predictor_best.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0317b0139fb4dbe5834ac1b051f1ea7d840258cb793960f324b37b96ec667921 +size 7068541 diff --git a/models/from_xyY/multi_head_best.pth b/models/from_xyY/multi_head_best.pth new file mode 100644 index 0000000000000000000000000000000000000000..6117c5bbf062a6d040c15f21bc2a5bf4c7b1a509 --- /dev/null +++ b/models/from_xyY/multi_head_best.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6b9d9e61b547521e71473c4891c31df0db5e97929abce4ecd94c02a4f9090e98 +size 4022111 diff --git a/models/from_xyY/multi_head_circular.onnx b/models/from_xyY/multi_head_circular.onnx new file mode 100644 index 0000000000000000000000000000000000000000..409a26f11d42d989e1309a9b79f5fceb6867f0e6 --- /dev/null +++ b/models/from_xyY/multi_head_circular.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:319b3fe28425771569c4175e0a3413115a6615b3d7e345b17f13cc9ff1f0a88b +size 17935 diff --git a/models/from_xyY/multi_head_circular.onnx.data b/models/from_xyY/multi_head_circular.onnx.data new file mode 100644 index 0000000000000000000000000000000000000000..43e2ac6004e1bc613a7884d1f83ac22e1e37ff51 --- /dev/null +++ b/models/from_xyY/multi_head_circular.onnx.data @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:29c7784bae5ec320e66c2d25de08fff15ad917597c67b1b46493aac5aaeafb1a +size 5162496 diff --git a/models/from_xyY/multi_head_circular.pth b/models/from_xyY/multi_head_circular.pth new file mode 100644 index 0000000000000000000000000000000000000000..712a60844739c81dae1b64aaa8d5d8011cc63f7b --- /dev/null +++ b/models/from_xyY/multi_head_circular.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ddfc4a700e930bb8fb1fd6344dd016ea173bb5969f4ed94dbf7a33bb6b86e1b2 +size 5193319 diff --git a/models/from_xyY/multi_head_circular_normalization_params.npz b/models/from_xyY/multi_head_circular_normalization_params.npz new file mode 100644 index 0000000000000000000000000000000000000000..be6016b1bea29de339e555b877d44a4f8f3908d6 --- /dev/null +++ b/models/from_xyY/multi_head_circular_normalization_params.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2aab39517437f2da9428449138fcfc413e4636e435ba81c47c890a5b0a75ebf5 +size 545 diff --git a/models/from_xyY/multi_head_cross_attention_error_predictor.onnx b/models/from_xyY/multi_head_cross_attention_error_predictor.onnx new file mode 100644 index 0000000000000000000000000000000000000000..aac9119541e19551952f85ad3d62fe2a9d8d3bbf --- /dev/null +++ b/models/from_xyY/multi_head_cross_attention_error_predictor.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:09ce9672f069fe6a5f7990e1ab710ef6c4b8f06811f12cfa9b6c82f2a1f0b6ca +size 173567 diff --git a/models/from_xyY/multi_head_cross_attention_error_predictor.onnx.data b/models/from_xyY/multi_head_cross_attention_error_predictor.onnx.data new file mode 100644 index 0000000000000000000000000000000000000000..b5302ec0e34f7b6b76b5f18e74dcbbad972cff97 --- /dev/null +++ b/models/from_xyY/multi_head_cross_attention_error_predictor.onnx.data @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:807e464bd2c838fbd6c76d6e6fb745344aea2e7e21778165c17f8522062f47f6 +size 4237312 diff --git a/models/from_xyY/multi_head_cross_attention_error_predictor.pth b/models/from_xyY/multi_head_cross_attention_error_predictor.pth new file mode 100644 index 0000000000000000000000000000000000000000..19cc0558eb1addc1e97bdb7e3ae850e15bc01510 --- /dev/null +++ b/models/from_xyY/multi_head_cross_attention_error_predictor.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f50bfb570502e6df2b8185223a846fafff497cc4f324bf8ac0437fec392499ed +size 4260533 diff --git a/models/from_xyY/multi_head_gamma.onnx b/models/from_xyY/multi_head_gamma.onnx new file mode 100644 index 0000000000000000000000000000000000000000..e0e584f12a95b8d4c0a01ea0e223a53ed5ae15d6 --- /dev/null +++ b/models/from_xyY/multi_head_gamma.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d67d982f0913c114105ccd7376ebba202cc13461300f8d974806d20e9d1fa78c +size 17737 diff --git a/models/from_xyY/multi_head_gamma.onnx.data b/models/from_xyY/multi_head_gamma.onnx.data new file mode 100644 index 0000000000000000000000000000000000000000..1ee218137f88718687ac0e5e49d6da5739d0d146 --- /dev/null +++ b/models/from_xyY/multi_head_gamma.onnx.data @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:59e869cdb9327955a1473fbd9d9a9a5820ce09d0c45909636769b4fd2349cb9f +size 3992064 diff --git a/models/from_xyY/multi_head_gamma_best.pth b/models/from_xyY/multi_head_gamma_best.pth new file mode 100644 index 0000000000000000000000000000000000000000..a31d029ef50c24a3b1d287b4bcfb43e42eaa1796 --- /dev/null +++ b/models/from_xyY/multi_head_gamma_best.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:eb02fe5e6301096c842a21f37cc1756a05f34b951f2d6eb92bebff32feb69c15 +size 4022891 diff --git a/models/from_xyY/multi_head_gamma_normalization_params.npz b/models/from_xyY/multi_head_gamma_normalization_params.npz new file mode 100644 index 0000000000000000000000000000000000000000..b898222c0f5ed3f5ba6932d334ec8d00cf81a27b --- /dev/null +++ b/models/from_xyY/multi_head_gamma_normalization_params.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:de07826b1e2f5197b0b579b842eb61694dd527baaf8dea917d819deada51071c +size 1039 diff --git a/models/from_xyY/multi_head_large.onnx b/models/from_xyY/multi_head_large.onnx new file mode 100644 index 0000000000000000000000000000000000000000..54b51ee100d86540fcc33cfe907c4ecb03fcc331 --- /dev/null +++ b/models/from_xyY/multi_head_large.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:63fcbec283ad57b8d68c29ea32395ac65c04159061e4ad9755f265ab5478486b +size 17725 diff --git a/models/from_xyY/multi_head_large.onnx.data b/models/from_xyY/multi_head_large.onnx.data new file mode 100644 index 0000000000000000000000000000000000000000..ecfd8af13c04aef2a365f5deedfa1d57e27f4de8 --- /dev/null +++ b/models/from_xyY/multi_head_large.onnx.data @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:22536a501b092dd8747aba9e55e6fffa6f1e80f7d6f7618d940e1414b5ec1a92 +size 3992064 diff --git a/models/from_xyY/multi_head_large_best.pth b/models/from_xyY/multi_head_large_best.pth new file mode 100644 index 0000000000000000000000000000000000000000..6b8e75c7541a087415308244ba433ac0299bb7f2 --- /dev/null +++ b/models/from_xyY/multi_head_large_best.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cd022e04a251da4e2ad5fba885e4dedff601d38feb3a6226c3abe2fef8b3e640 +size 4022699 diff --git a/models/from_xyY/multi_head_large_normalization_params.npz b/models/from_xyY/multi_head_large_normalization_params.npz new file mode 100644 index 0000000000000000000000000000000000000000..0753e70279be6cb1b80024cfde667cde41b04fdc --- /dev/null +++ b/models/from_xyY/multi_head_large_normalization_params.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:91ac6cb85a8fd8b29dd3cc9ad1978746a795fc3949b138b47c75f47bf30dede1 +size 1004 diff --git a/models/from_xyY/multi_head_multi_error_predictor.onnx b/models/from_xyY/multi_head_multi_error_predictor.onnx new file mode 100644 index 0000000000000000000000000000000000000000..d5f16017aeb4759933d32b982545f5e846f76ca4 --- /dev/null +++ b/models/from_xyY/multi_head_multi_error_predictor.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5878b19ec14c85379b2acd17afc3aaff9c1c7d68d38d2a32cf6e65e3c2712054 +size 40722 diff --git a/models/from_xyY/multi_head_multi_error_predictor.onnx.data b/models/from_xyY/multi_head_multi_error_predictor.onnx.data new file mode 100644 index 0000000000000000000000000000000000000000..c90970914d724cfc1fcaecd55f6688f09c0fa3d6 --- /dev/null +++ b/models/from_xyY/multi_head_multi_error_predictor.onnx.data @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:04bd070e1f9c1bc724631e132c3af8eabfc9115d689101a95f099618793edc53 +size 7077888 diff --git a/models/from_xyY/multi_head_multi_error_predictor_best.pth b/models/from_xyY/multi_head_multi_error_predictor_best.pth new file mode 100644 index 0000000000000000000000000000000000000000..a5237f6f5f9bea89586b9642afa043ac5de46ec2 --- /dev/null +++ b/models/from_xyY/multi_head_multi_error_predictor_best.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4d5756fa243c2fc3aee04c1ba01f8b680e3790cf1dd311e6bcf976076a626ef3 +size 7068387 diff --git a/models/from_xyY/multi_head_multi_error_predictor_large.onnx b/models/from_xyY/multi_head_multi_error_predictor_large.onnx new file mode 100644 index 0000000000000000000000000000000000000000..17b9d4671f965c09c0eebda624e305080012cbd2 --- /dev/null +++ b/models/from_xyY/multi_head_multi_error_predictor_large.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:24ccdfae08b88d439fc9c67b3899e9ec90b7ef90598c42eb98c5dc91ce3d43b1 +size 41520 diff --git a/models/from_xyY/multi_head_multi_error_predictor_large.onnx.data b/models/from_xyY/multi_head_multi_error_predictor_large.onnx.data new file mode 100644 index 0000000000000000000000000000000000000000..04360816fdf58d6a0250c8e9d255595de829e2fa --- /dev/null +++ b/models/from_xyY/multi_head_multi_error_predictor_large.onnx.data @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7f8d3e1f38caccc76a47076e333ebc4fdbac22622ca6b4d4eef07778d4458fe8 +size 7077888 diff --git a/models/from_xyY/multi_head_multi_error_predictor_large_best.pth b/models/from_xyY/multi_head_multi_error_predictor_large_best.pth new file mode 100644 index 0000000000000000000000000000000000000000..6b24c27defc6f86092b107160547ddb94ff83402 --- /dev/null +++ b/models/from_xyY/multi_head_multi_error_predictor_large_best.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ce97f7fd819c44b3ee6fffb90b67fc30b2d25997910c8067705b2bb6dcc93e23 +size 7069311 diff --git a/models/from_xyY/multi_head_normalization_params.npz b/models/from_xyY/multi_head_normalization_params.npz new file mode 100644 index 0000000000000000000000000000000000000000..0753e70279be6cb1b80024cfde667cde41b04fdc --- /dev/null +++ b/models/from_xyY/multi_head_normalization_params.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:91ac6cb85a8fd8b29dd3cc9ad1978746a795fc3949b138b47c75f47bf30dede1 +size 1004 diff --git a/models/from_xyY/multi_head_refined_real.onnx b/models/from_xyY/multi_head_refined_real.onnx new file mode 100644 index 0000000000000000000000000000000000000000..23ef3930f5482f6caf208f29ec55575db0469811 --- /dev/null +++ b/models/from_xyY/multi_head_refined_real.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c158237834cdbad471b8df17b5fcbf86987cdc6e8e61c8a9ab315b55144d125c +size 95959 diff --git a/models/from_xyY/multi_head_refined_real.onnx.data b/models/from_xyY/multi_head_refined_real.onnx.data new file mode 100644 index 0000000000000000000000000000000000000000..91b4beaf556eb1cfe57e4dc0f2baed71f3177d90 --- /dev/null +++ b/models/from_xyY/multi_head_refined_real.onnx.data @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ec83d869c9f984262ffe2c10b163d81c24a25b308b92dccc701433bb1d847e0c +size 3992064 diff --git a/models/from_xyY/multi_head_refined_real_best.pth b/models/from_xyY/multi_head_refined_real_best.pth new file mode 100644 index 0000000000000000000000000000000000000000..f81199a2bab82cfd3048bf06a9b510295a45f9f7 --- /dev/null +++ b/models/from_xyY/multi_head_refined_real_best.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:967915da52266b32d626771624a6c38bbbf64fc9965124124d0d4cd6e8e47858 +size 4023513 diff --git a/models/from_xyY/multi_head_refined_real_normalization_params.npz b/models/from_xyY/multi_head_refined_real_normalization_params.npz new file mode 100644 index 0000000000000000000000000000000000000000..61798f3bf24115f6ce1c1560640c46f2c2ca2905 --- /dev/null +++ b/models/from_xyY/multi_head_refined_real_normalization_params.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6c02287251a96e17ca6c20ba680b7ff1b43ecadce76a29fa4f678bb2b2b61ff4 +size 1022 diff --git a/models/from_xyY/multi_head_st2084.onnx b/models/from_xyY/multi_head_st2084.onnx new file mode 100644 index 0000000000000000000000000000000000000000..13271a6d59b15e03ca288d988f7db80327844b2c --- /dev/null +++ b/models/from_xyY/multi_head_st2084.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dc432c5da6dd424fdca3a5aab75bdb2249f28cf76379734345d7b3c0bdfdb4fa +size 17815 diff --git a/models/from_xyY/multi_head_st2084.onnx.data b/models/from_xyY/multi_head_st2084.onnx.data new file mode 100644 index 0000000000000000000000000000000000000000..c8b488b1aad5c89134b0084ccea6ec5fb3d0a3e5 --- /dev/null +++ b/models/from_xyY/multi_head_st2084.onnx.data @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4bdb84363a918bffdee9476da5ad44765e7e25f8ee511564d2fadae8bbdf2370 +size 3992064 diff --git a/models/from_xyY/multi_head_st2084_best.pth b/models/from_xyY/multi_head_st2084_best.pth new file mode 100644 index 0000000000000000000000000000000000000000..1445675eeae889ada0d5e3bc09dc7e3b279c18cd --- /dev/null +++ b/models/from_xyY/multi_head_st2084_best.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fa0e7ef06b62c3a513a81d1a3d55af0d53a362d33a91f8931966e75aaea1a290 +size 4023053 diff --git a/models/from_xyY/multi_head_st2084_normalization_params.npz b/models/from_xyY/multi_head_st2084_normalization_params.npz new file mode 100644 index 0000000000000000000000000000000000000000..fb9e0936afb5a21e25467be9bf48bac434e3f230 --- /dev/null +++ b/models/from_xyY/multi_head_st2084_normalization_params.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:636227fc4047d597ef16c450f45421e655ad85a7a0c733e3f56e54f09d28d680 +size 1057 diff --git a/models/from_xyY/multi_head_weighted_boundary.onnx b/models/from_xyY/multi_head_weighted_boundary.onnx new file mode 100644 index 0000000000000000000000000000000000000000..01889dd5f29f0ddf8381f196b0f8d17993613829 --- /dev/null +++ b/models/from_xyY/multi_head_weighted_boundary.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:02c0a5bbf4bb58299a83ab31187a5736f390382c51f7e596259cad80fbb9eb27 +size 18637 diff --git a/models/from_xyY/multi_head_weighted_boundary.onnx.data b/models/from_xyY/multi_head_weighted_boundary.onnx.data new file mode 100644 index 0000000000000000000000000000000000000000..7e1d98c94debe5b8b61d7fbfbbfb9d3d0213c4f4 --- /dev/null +++ b/models/from_xyY/multi_head_weighted_boundary.onnx.data @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:168a95d80ad6ff6aa5c2f10a062e8afbe0d61c5535e83496b39426d15e8b5164 +size 3992064 diff --git a/models/from_xyY/multi_head_weighted_boundary_best.pth b/models/from_xyY/multi_head_weighted_boundary_best.pth new file mode 100644 index 0000000000000000000000000000000000000000..9882a5735858e32b25edacb8a39b86c3bcbb4525 --- /dev/null +++ b/models/from_xyY/multi_head_weighted_boundary_best.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:22e0008df9294edc57b5743be0ca0c7321b74c7c32ae1e7a483bf2e1885e8d30 +size 4023939 diff --git a/models/from_xyY/multi_head_weighted_boundary_multi_error_predictor.onnx b/models/from_xyY/multi_head_weighted_boundary_multi_error_predictor.onnx new file mode 100644 index 0000000000000000000000000000000000000000..8826edfee42d70f7c130c8dcf1edb53ffcf50b05 --- /dev/null +++ b/models/from_xyY/multi_head_weighted_boundary_multi_error_predictor.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c110d45d3f1c95e6c1e85c15b31add55805a7409cd485ea17558bbd8cf25522b +size 43010 diff --git a/models/from_xyY/multi_head_weighted_boundary_multi_error_predictor.onnx.data b/models/from_xyY/multi_head_weighted_boundary_multi_error_predictor.onnx.data new file mode 100644 index 0000000000000000000000000000000000000000..98496ba8eacc4ecb5643713580447ddbcfe339ac --- /dev/null +++ b/models/from_xyY/multi_head_weighted_boundary_multi_error_predictor.onnx.data @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d2f0f8b391519dca05b5eb7f0312e8a6ad6a937ceaa86bc402d1d4187c748826 +size 7077888 diff --git a/models/from_xyY/multi_head_weighted_boundary_multi_error_predictor_weighted_boundary.onnx b/models/from_xyY/multi_head_weighted_boundary_multi_error_predictor_weighted_boundary.onnx new file mode 100644 index 0000000000000000000000000000000000000000..92d0e9d38adccc2619eed599262a9c2ca34d880e --- /dev/null +++ b/models/from_xyY/multi_head_weighted_boundary_multi_error_predictor_weighted_boundary.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:034ba9ad10c1ff1e5c465eb94416eafd70a7d83de89c09511743ba996a6fdd15 +size 45242 diff --git a/models/from_xyY/multi_head_weighted_boundary_multi_error_predictor_weighted_boundary.onnx.data b/models/from_xyY/multi_head_weighted_boundary_multi_error_predictor_weighted_boundary.onnx.data new file mode 100644 index 0000000000000000000000000000000000000000..6d34c335a651c5b7730b120828976a2d72b8c318 --- /dev/null +++ b/models/from_xyY/multi_head_weighted_boundary_multi_error_predictor_weighted_boundary.onnx.data @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ed3f69a5115489fe92fd41b1a61d9dcc37a6e20bd4f497e20049cfe264e4d552 +size 7077888 diff --git a/models/from_xyY/multi_head_weighted_boundary_multi_error_predictor_weighted_boundary_best.pth b/models/from_xyY/multi_head_weighted_boundary_multi_error_predictor_weighted_boundary_best.pth new file mode 100644 index 0000000000000000000000000000000000000000..4e1b57c5ef55f4822324481148fce1f9be94f63e --- /dev/null +++ b/models/from_xyY/multi_head_weighted_boundary_multi_error_predictor_weighted_boundary_best.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bd8cbcfc73873e9f4224d3a77757d8a5fed86ff72557ccb607c06bb1a4dce4c1 +size 7083659 diff --git a/models/from_xyY/multi_head_weighted_boundary_normalization_params.npz b/models/from_xyY/multi_head_weighted_boundary_normalization_params.npz new file mode 100644 index 0000000000000000000000000000000000000000..0753e70279be6cb1b80024cfde667cde41b04fdc --- /dev/null +++ b/models/from_xyY/multi_head_weighted_boundary_normalization_params.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:91ac6cb85a8fd8b29dd3cc9ad1978746a795fc3949b138b47c75f47bf30dede1 +size 1004 diff --git a/models/from_xyY/multi_mlp.onnx b/models/from_xyY/multi_mlp.onnx new file mode 100644 index 0000000000000000000000000000000000000000..ca8b353aa7f30729dabf91c816a3657173999347 --- /dev/null +++ b/models/from_xyY/multi_mlp.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fbcec138e47fd1c14f123b7585441655a6657fb77b17a794fbf8cdc1caf80181 +size 184821 diff --git a/models/from_xyY/multi_mlp_large.onnx b/models/from_xyY/multi_mlp_large.onnx new file mode 100644 index 0000000000000000000000000000000000000000..9a643d023abf51ad9aa6fceebd98dccb5e68c7ba --- /dev/null +++ b/models/from_xyY/multi_mlp_large.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1b7194a6aec653427194c7a421a57d5d7c3045eec56dc0266fc364df7cde8af4 +size 184419 diff --git a/models/from_xyY/multi_mlp_large.onnx.data b/models/from_xyY/multi_mlp_large.onnx.data new file mode 100644 index 0000000000000000000000000000000000000000..e1f535f415c3b9f54752e766d5bc8f660fdd8860 --- /dev/null +++ b/models/from_xyY/multi_mlp_large.onnx.data @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c68014ae938cd76fd47f88dc5ebdf7669bb0b9803dcf5b5e2d5ab7b68c5bc071 +size 9371648 diff --git a/models/from_xyY/multi_mlp_large_best.pth b/models/from_xyY/multi_mlp_large_best.pth new file mode 100644 index 0000000000000000000000000000000000000000..1c50a2ab2ad4ee202fea509ce81db7f031de309b --- /dev/null +++ b/models/from_xyY/multi_mlp_large_best.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6ca0ae85e73dc2df2f4002471b12f5772c1293c42a75f0c506affa3e5c92bdba +size 9365033 diff --git a/models/from_xyY/multi_mlp_large_normalization_params.npz b/models/from_xyY/multi_mlp_large_normalization_params.npz new file mode 100644 index 0000000000000000000000000000000000000000..0753e70279be6cb1b80024cfde667cde41b04fdc --- /dev/null +++ b/models/from_xyY/multi_mlp_large_normalization_params.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:91ac6cb85a8fd8b29dd3cc9ad1978746a795fc3949b138b47c75f47bf30dede1 +size 1004 diff --git a/models/from_xyY/multi_mlp_multi_error_predictor.onnx b/models/from_xyY/multi_mlp_multi_error_predictor.onnx new file mode 100644 index 0000000000000000000000000000000000000000..15066f71d475b80495df5ba0da8a1836ba62cab0 --- /dev/null +++ b/models/from_xyY/multi_mlp_multi_error_predictor.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3d897aff41c198f8e9b9bbc6007a2598a63c673d203e87d441ba692202537169 +size 40988 diff --git a/models/from_xyY/multi_mlp_multi_error_predictor_large.onnx b/models/from_xyY/multi_mlp_multi_error_predictor_large.onnx new file mode 100644 index 0000000000000000000000000000000000000000..e50d030e24454ed1bba88e8666bda325c6f0f859 --- /dev/null +++ b/models/from_xyY/multi_mlp_multi_error_predictor_large.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fa2bb842411954ef691ee73c64679751206afc0ab4ab166cc777012189dcecc1 +size 41391 diff --git a/models/from_xyY/multi_mlp_multi_error_predictor_large.onnx.data b/models/from_xyY/multi_mlp_multi_error_predictor_large.onnx.data new file mode 100644 index 0000000000000000000000000000000000000000..ae72efd011b1354c450324ad1869913e69eb6329 --- /dev/null +++ b/models/from_xyY/multi_mlp_multi_error_predictor_large.onnx.data @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:89a1d4c1795987a396a3996f604e309e5cefd7a2bdf37c6abdd84a15bd27ff79 +size 7077888 diff --git a/models/from_xyY/multi_mlp_multi_error_predictor_large_best.pth b/models/from_xyY/multi_mlp_multi_error_predictor_large_best.pth new file mode 100644 index 0000000000000000000000000000000000000000..c81a050bf591a208f975c4194fccd13a1cb55183 --- /dev/null +++ b/models/from_xyY/multi_mlp_multi_error_predictor_large_best.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9a9fc37387dc93c1f4c7b3490506c04734c07e620c097152e985a1f3e6515faa +size 7069157 diff --git a/models/from_xyY/multi_mlp_normalization_params.npz b/models/from_xyY/multi_mlp_normalization_params.npz new file mode 100644 index 0000000000000000000000000000000000000000..0753e70279be6cb1b80024cfde667cde41b04fdc --- /dev/null +++ b/models/from_xyY/multi_mlp_normalization_params.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:91ac6cb85a8fd8b29dd3cc9ad1978746a795fc3949b138b47c75f47bf30dede1 +size 1004 diff --git a/models/from_xyY/multi_mlp_weighted_boundary.onnx b/models/from_xyY/multi_mlp_weighted_boundary.onnx new file mode 100644 index 0000000000000000000000000000000000000000..b5f2b3897c1ff83f09ca6dfc9351db2624105f98 --- /dev/null +++ b/models/from_xyY/multi_mlp_weighted_boundary.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2ef8fe95476fe5d4b6d730040fd88253b69000074cc243289c32f2aaa274fa75 +size 187111 diff --git a/models/from_xyY/multi_mlp_weighted_boundary_multi_error_predictor.onnx b/models/from_xyY/multi_mlp_weighted_boundary_multi_error_predictor.onnx new file mode 100644 index 0000000000000000000000000000000000000000..98666956c09fb0db4da71d60cdad0f82bb50cd03 --- /dev/null +++ b/models/from_xyY/multi_mlp_weighted_boundary_multi_error_predictor.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0e367cefafc45e6bee6f21add62940d76887430ad9f6bbed2a7ba63359799d9b +size 43258 diff --git a/models/from_xyY/multi_mlp_weighted_boundary_normalization_params.npz b/models/from_xyY/multi_mlp_weighted_boundary_normalization_params.npz new file mode 100644 index 0000000000000000000000000000000000000000..0753e70279be6cb1b80024cfde667cde41b04fdc --- /dev/null +++ b/models/from_xyY/multi_mlp_weighted_boundary_normalization_params.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:91ac6cb85a8fd8b29dd3cc9ad1978746a795fc3949b138b47c75f47bf30dede1 +size 1004 diff --git a/models/from_xyY/multi_resnet_error_predictor_large.onnx b/models/from_xyY/multi_resnet_error_predictor_large.onnx new file mode 100644 index 0000000000000000000000000000000000000000..42c5a68cf4f450fba0e3abe8f8f7c0bee6163039 --- /dev/null +++ b/models/from_xyY/multi_resnet_error_predictor_large.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4f416c6d14cfdd7d12fb0e1c239819733b5b9a307498fe38d649a598a671e398 +size 57152 diff --git a/models/from_xyY/multi_resnet_error_predictor_large.onnx.data b/models/from_xyY/multi_resnet_error_predictor_large.onnx.data new file mode 100644 index 0000000000000000000000000000000000000000..7f2f8cca241283dfff2968d13dfae6177989b695 --- /dev/null +++ b/models/from_xyY/multi_resnet_error_predictor_large.onnx.data @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6896c5bcb4cf266670806d2db8cd538fb54cf96d99f4c5c3059c2b34558bfe7a +size 14849024 diff --git a/models/from_xyY/multi_resnet_error_predictor_large_best.pth b/models/from_xyY/multi_resnet_error_predictor_large_best.pth new file mode 100644 index 0000000000000000000000000000000000000000..6ee4db89241bbc7e873c4155c2dc8eb88b64e0e6 --- /dev/null +++ b/models/from_xyY/multi_resnet_error_predictor_large_best.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:da6cc64385c2f1580f773973c4abbe0d76b2383a7ef2abef7c6181cf8d4721f3 +size 15023263 diff --git a/models/from_xyY/multi_resnet_large.onnx b/models/from_xyY/multi_resnet_large.onnx new file mode 100644 index 0000000000000000000000000000000000000000..688695f5a45ef8b81f98b7bb6c0ba4b9d6499520 --- /dev/null +++ b/models/from_xyY/multi_resnet_large.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a73d4a9e75e3b124d54c761f81dfbc563acec889d5dbfd7e189d1dcee7bd64bb +size 54769 diff --git a/models/from_xyY/multi_resnet_large.onnx.data b/models/from_xyY/multi_resnet_large.onnx.data new file mode 100644 index 0000000000000000000000000000000000000000..a1e8e4f40791178c9a6159f5a3a3cee569c17bd3 --- /dev/null +++ b/models/from_xyY/multi_resnet_large.onnx.data @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c2f9da82a609c07f6ffcb40fa42fd9f1768d81326c3a9bdbb68de5cc0485abe8 +size 14828544 diff --git a/models/from_xyY/multi_resnet_large_best.pth b/models/from_xyY/multi_resnet_large_best.pth new file mode 100644 index 0000000000000000000000000000000000000000..8dda8cefc011e71c77b91a0d4be11136716fd1d4 --- /dev/null +++ b/models/from_xyY/multi_resnet_large_best.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:edd0b4f5d01e84529813d266913bb873ba66c4a8a87a5fd757c1f0c9df62c3da +size 14999039 diff --git a/models/from_xyY/multi_resnet_large_normalization_params.npz b/models/from_xyY/multi_resnet_large_normalization_params.npz new file mode 100644 index 0000000000000000000000000000000000000000..0753e70279be6cb1b80024cfde667cde41b04fdc --- /dev/null +++ b/models/from_xyY/multi_resnet_large_normalization_params.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:91ac6cb85a8fd8b29dd3cc9ad1978746a795fc3949b138b47c75f47bf30dede1 +size 1004 diff --git a/models/from_xyY/transformer_error_predictor_large_best.pth b/models/from_xyY/transformer_error_predictor_large_best.pth new file mode 100644 index 0000000000000000000000000000000000000000..a20cce5a1cf543970396873b15b91e428e1b79a4 --- /dev/null +++ b/models/from_xyY/transformer_error_predictor_large_best.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:47b7191af082aaf5a66045af129da016aa7990317ece1d47f9be4a85b5b219f6 +size 7068541 diff --git a/models/from_xyY/transformer_large.onnx b/models/from_xyY/transformer_large.onnx new file mode 100644 index 0000000000000000000000000000000000000000..7bc47c52e235cc9520470cc3835461adedd2389b --- /dev/null +++ b/models/from_xyY/transformer_large.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:606c8612d3f1fd42d59e00f99984941b4bd799c8c52512a47032b06d540b8be8 +size 19919572 diff --git a/models/from_xyY/transformer_large_best.pth b/models/from_xyY/transformer_large_best.pth new file mode 100644 index 0000000000000000000000000000000000000000..b9fddabaec0e1a81b45bc7bceda133ab0fe27633 --- /dev/null +++ b/models/from_xyY/transformer_large_best.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5ae4b85672f68b09721a77b1713b12dd6a33d75a0edc964bf3f76def9c9cac95 +size 19792865 diff --git a/models/from_xyY/transformer_large_normalization_params.npz b/models/from_xyY/transformer_large_normalization_params.npz new file mode 100644 index 0000000000000000000000000000000000000000..0753e70279be6cb1b80024cfde667cde41b04fdc --- /dev/null +++ b/models/from_xyY/transformer_large_normalization_params.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:91ac6cb85a8fd8b29dd3cc9ad1978746a795fc3949b138b47c75f47bf30dede1 +size 1004 diff --git a/models/from_xyY/transformer_multi_error_predictor_large.onnx b/models/from_xyY/transformer_multi_error_predictor_large.onnx new file mode 100644 index 0000000000000000000000000000000000000000..4654d06175998000a1ebf5a1541ac3e4a406851f --- /dev/null +++ b/models/from_xyY/transformer_multi_error_predictor_large.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:44a196fe76c2ee892380687935320822856b9bdbd0d5cc2fcd774d78694ee34c +size 41645 diff --git a/models/from_xyY/transformer_multi_error_predictor_large.onnx.data b/models/from_xyY/transformer_multi_error_predictor_large.onnx.data new file mode 100644 index 0000000000000000000000000000000000000000..6bf6484f4c6ec59749f03616b39d2492f1565fd8 --- /dev/null +++ b/models/from_xyY/transformer_multi_error_predictor_large.onnx.data @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d951b2478f31e335a4475330953cfe8317da56fcb6c3076e08877ebe14883162 +size 7077888 diff --git a/models/from_xyY/unified_mlp.onnx b/models/from_xyY/unified_mlp.onnx new file mode 100644 index 0000000000000000000000000000000000000000..eca80b7f9d7c8662aa4a2fbc378bcca7ea271c9b --- /dev/null +++ b/models/from_xyY/unified_mlp.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:87212fe0e3a054d33d95abae03410f26188fda19b4842f4cc28e0a950cc47b26 +size 16536 diff --git a/models/from_xyY/unified_mlp_normalization_params.npz b/models/from_xyY/unified_mlp_normalization_params.npz new file mode 100644 index 0000000000000000000000000000000000000000..0753e70279be6cb1b80024cfde667cde41b04fdc --- /dev/null +++ b/models/from_xyY/unified_mlp_normalization_params.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:91ac6cb85a8fd8b29dd3cc9ad1978746a795fc3949b138b47c75f47bf30dede1 +size 1004 diff --git a/models/to_xyY/multi_head.onnx b/models/to_xyY/multi_head.onnx new file mode 100644 index 0000000000000000000000000000000000000000..3a21f8a91d216e98a03863b3568bb2a55f1f2e77 --- /dev/null +++ b/models/to_xyY/multi_head.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:eca37693faddf35da3815d4e4971986875f9970b36a99ff904db635bf33e3fea +size 12468 diff --git a/models/to_xyY/multi_head.onnx.data b/models/to_xyY/multi_head.onnx.data new file mode 100644 index 0000000000000000000000000000000000000000..2b602eff392639442557b363a6651451c4e4940f --- /dev/null +++ b/models/to_xyY/multi_head.onnx.data @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ae7f726ec7ce240daf0c5301e0c0b8218294425894cf5cae22e54224930fae3c +size 2665984 diff --git a/models/to_xyY/multi_head.pth b/models/to_xyY/multi_head.pth new file mode 100644 index 0000000000000000000000000000000000000000..7a0f3cd289a33a98924f814edbed8a05aa26c711 --- /dev/null +++ b/models/to_xyY/multi_head.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b5f907fae7661d1e011b1df3c338cece3129dc4c7f1c4914301eb5d92ab7d2b4 +size 2687989 diff --git a/models/to_xyY/multi_head_multi_error_predictor.onnx b/models/to_xyY/multi_head_multi_error_predictor.onnx new file mode 100644 index 0000000000000000000000000000000000000000..c414560157caca1be8bbde531bfe5f9b4289eb80 --- /dev/null +++ b/models/to_xyY/multi_head_multi_error_predictor.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:435236ecea2b58a25000da2bf3d3a794b3fca3e58f38ba09891b4e6a1a19dde6 +size 29496 diff --git a/models/to_xyY/multi_head_multi_error_predictor.onnx.data b/models/to_xyY/multi_head_multi_error_predictor.onnx.data new file mode 100644 index 0000000000000000000000000000000000000000..7f32f8683cf62310c9ae7b3abdc1c48d265d221f --- /dev/null +++ b/models/to_xyY/multi_head_multi_error_predictor.onnx.data @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:56f09660102e82b74bf1ec5cecefb5b7a0dccb1e36e75335d67fb4b62954ac15 +size 4021248 diff --git a/models/to_xyY/multi_head_multi_error_predictor.pth b/models/to_xyY/multi_head_multi_error_predictor.pth new file mode 100644 index 0000000000000000000000000000000000000000..2ebc20ec4ee0a86e6a3f661b0812193c53f6138b --- /dev/null +++ b/models/to_xyY/multi_head_multi_error_predictor.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7b5e4a2104689f64ee4b0c1e22c3f1e12b79cb28057a8214481ed0e8ecac37c8 +size 4060430 diff --git a/models/to_xyY/multi_head_multi_error_predictor_normalization_params.npz b/models/to_xyY/multi_head_multi_error_predictor_normalization_params.npz new file mode 100644 index 0000000000000000000000000000000000000000..d07d6ead6793b397c49273975e93c0eb74e69c41 --- /dev/null +++ b/models/to_xyY/multi_head_multi_error_predictor_normalization_params.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dec1cc5084bdeaf0691eee39a226b539ac9592de3ccda5d6684803423de98ae8 +size 1088 diff --git a/models/to_xyY/multi_head_normalization_params.npz b/models/to_xyY/multi_head_normalization_params.npz new file mode 100644 index 0000000000000000000000000000000000000000..f4aa5f68fc51cc7abadce5461ff2cfc517c1b7ce --- /dev/null +++ b/models/to_xyY/multi_head_normalization_params.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e0de120a7f1896a3ddb29c0c51dfc8e58bd0891af6750a6db9a9bae777a6089d +size 986 diff --git a/models/to_xyY/multi_head_optimized.onnx b/models/to_xyY/multi_head_optimized.onnx new file mode 100644 index 0000000000000000000000000000000000000000..e369b7434417096b7ec6ee6425dac638009aa06f --- /dev/null +++ b/models/to_xyY/multi_head_optimized.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:93517c09086c5c39f616356945eaa27d8fc244203718bdef817686c21060a306 +size 13049 diff --git a/models/to_xyY/multi_head_optimized.onnx.data b/models/to_xyY/multi_head_optimized.onnx.data new file mode 100644 index 0000000000000000000000000000000000000000..9cbcd80720f59872115c72f9d3a378d9b971ed10 --- /dev/null +++ b/models/to_xyY/multi_head_optimized.onnx.data @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cc47ea391a8853f70825f06a6d46d329291b7ebe60e1fb7b5799858857dd9551 +size 6029312 diff --git a/models/to_xyY/multi_head_optimized.pth b/models/to_xyY/multi_head_optimized.pth new file mode 100644 index 0000000000000000000000000000000000000000..5fd3cee6759f1e04d4e096f02308efed54dfc79d --- /dev/null +++ b/models/to_xyY/multi_head_optimized.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c7f99022d82c578c2a27ce057f98249bdee351b9aa341688c81ad954d3963564 +size 5988003 diff --git a/models/to_xyY/multi_head_optimized_normalization_params.npz b/models/to_xyY/multi_head_optimized_normalization_params.npz new file mode 100644 index 0000000000000000000000000000000000000000..f4aa5f68fc51cc7abadce5461ff2cfc517c1b7ce --- /dev/null +++ b/models/to_xyY/multi_head_optimized_normalization_params.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e0de120a7f1896a3ddb29c0c51dfc8e58bd0891af6750a6db9a9bae777a6089d +size 986 diff --git a/models/to_xyY/multi_mlp.onnx b/models/to_xyY/multi_mlp.onnx new file mode 100644 index 0000000000000000000000000000000000000000..0e28106421e824bfd8efefa8479f8601ccffdc37 --- /dev/null +++ b/models/to_xyY/multi_mlp.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:169124985d7b96754e74369e30d2c3275aba2bdbfd55401c20fc7c223ce5a905 +size 23145 diff --git a/models/to_xyY/multi_mlp_error_predictor.onnx b/models/to_xyY/multi_mlp_error_predictor.onnx new file mode 100644 index 0000000000000000000000000000000000000000..3e5899c48f069ec9b97727bacc9714a833ee55ee --- /dev/null +++ b/models/to_xyY/multi_mlp_error_predictor.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c54e745262beaf00523a0a8b65fdfca19c0ebd175082220706d9856e0036fb05 +size 19871 diff --git a/models/to_xyY/multi_mlp_error_predictor_normalization_params.npz b/models/to_xyY/multi_mlp_error_predictor_normalization_params.npz new file mode 100644 index 0000000000000000000000000000000000000000..f4aa5f68fc51cc7abadce5461ff2cfc517c1b7ce --- /dev/null +++ b/models/to_xyY/multi_mlp_error_predictor_normalization_params.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e0de120a7f1896a3ddb29c0c51dfc8e58bd0891af6750a6db9a9bae777a6089d +size 986 diff --git a/models/to_xyY/multi_mlp_multi_error_predictor.onnx b/models/to_xyY/multi_mlp_multi_error_predictor.onnx new file mode 100644 index 0000000000000000000000000000000000000000..b122133eaa6d71eb6cbf10299e926b1dcef1b71b --- /dev/null +++ b/models/to_xyY/multi_mlp_multi_error_predictor.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:35531d1181271549ed445df49ac7753966a61c675ab748f8312fbb2751fb7b6d +size 29594 diff --git a/models/to_xyY/multi_mlp_multi_error_predictor_normalization_params.npz b/models/to_xyY/multi_mlp_multi_error_predictor_normalization_params.npz new file mode 100644 index 0000000000000000000000000000000000000000..f4aa5f68fc51cc7abadce5461ff2cfc517c1b7ce --- /dev/null +++ b/models/to_xyY/multi_mlp_multi_error_predictor_normalization_params.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e0de120a7f1896a3ddb29c0c51dfc8e58bd0891af6750a6db9a9bae777a6089d +size 986 diff --git a/models/to_xyY/multi_mlp_multi_error_predictor_optimized.onnx b/models/to_xyY/multi_mlp_multi_error_predictor_optimized.onnx new file mode 100644 index 0000000000000000000000000000000000000000..2fd1539080e6b6da48454da40dbecde9a96377fa --- /dev/null +++ b/models/to_xyY/multi_mlp_multi_error_predictor_optimized.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f7c4d3b9c8dde2fa42e4d3ec4c62c10b9416a3d8bfd782d4cd6883038282a488 +size 30576 diff --git a/models/to_xyY/multi_mlp_multi_error_predictor_optimized_normalization_params.npz b/models/to_xyY/multi_mlp_multi_error_predictor_optimized_normalization_params.npz new file mode 100644 index 0000000000000000000000000000000000000000..f4aa5f68fc51cc7abadce5461ff2cfc517c1b7ce --- /dev/null +++ b/models/to_xyY/multi_mlp_multi_error_predictor_optimized_normalization_params.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e0de120a7f1896a3ddb29c0c51dfc8e58bd0891af6750a6db9a9bae777a6089d +size 986 diff --git a/models/to_xyY/multi_mlp_normalization_params.npz b/models/to_xyY/multi_mlp_normalization_params.npz new file mode 100644 index 0000000000000000000000000000000000000000..f4aa5f68fc51cc7abadce5461ff2cfc517c1b7ce --- /dev/null +++ b/models/to_xyY/multi_mlp_normalization_params.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e0de120a7f1896a3ddb29c0c51dfc8e58bd0891af6750a6db9a9bae777a6089d +size 986 diff --git a/models/to_xyY/multi_mlp_optimized.onnx b/models/to_xyY/multi_mlp_optimized.onnx new file mode 100644 index 0000000000000000000000000000000000000000..c44f42057ebcdf8d7ef7c22f9725ca360da83a46 --- /dev/null +++ b/models/to_xyY/multi_mlp_optimized.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7dca0c5f7a2c0252c0b11be5773cb46acaef4e9fc4606a0190acdd836c47057b +size 19389 diff --git a/models/to_xyY/multi_mlp_optimized_normalization_params.npz b/models/to_xyY/multi_mlp_optimized_normalization_params.npz new file mode 100644 index 0000000000000000000000000000000000000000..f4aa5f68fc51cc7abadce5461ff2cfc517c1b7ce --- /dev/null +++ b/models/to_xyY/multi_mlp_optimized_normalization_params.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e0de120a7f1896a3ddb29c0c51dfc8e58bd0891af6750a6db9a9bae777a6089d +size 986 diff --git a/models/to_xyY/munsell_to_xyY_approximator.onnx b/models/to_xyY/munsell_to_xyY_approximator.onnx new file mode 100644 index 0000000000000000000000000000000000000000..a8138fb652cf1c9730eaecb2e78ed6ae6b57b80c --- /dev/null +++ b/models/to_xyY/munsell_to_xyY_approximator.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9ebb2f74cec1677e3fba02edb0d6412327c1a17cf48309840ff9adf770cc8769 +size 3550 diff --git a/models/to_xyY/munsell_to_xyY_approximator.onnx.data b/models/to_xyY/munsell_to_xyY_approximator.onnx.data new file mode 100644 index 0000000000000000000000000000000000000000..dea037cecbf7b384c65c996862f52cc3dbf0963f --- /dev/null +++ b/models/to_xyY/munsell_to_xyY_approximator.onnx.data @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e2311563c71b05cf3283fa7cd74921ea98564c8591605ff03d08857e30fdff04 +size 271872 diff --git a/models/to_xyY/munsell_to_xyY_approximator.pth b/models/to_xyY/munsell_to_xyY_approximator.pth new file mode 100644 index 0000000000000000000000000000000000000000..d79e7f3613d5397598971e621c8661f3aa9bc324 --- /dev/null +++ b/models/to_xyY/munsell_to_xyY_approximator.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:edf28cb8e18d24ebc63444e5f11cf081aa2a6fdc4e96ebb2639f93a4b21f3a3e +size 277833 diff --git a/models/to_xyY/munsell_to_xyY_approximator_normalization_params.npz b/models/to_xyY/munsell_to_xyY_approximator_normalization_params.npz new file mode 100644 index 0000000000000000000000000000000000000000..f4aa5f68fc51cc7abadce5461ff2cfc517c1b7ce --- /dev/null +++ b/models/to_xyY/munsell_to_xyY_approximator_normalization_params.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e0de120a7f1896a3ddb29c0c51dfc8e58bd0891af6750a6db9a9bae777a6089d +size 986 diff --git a/models/to_xyY/munsell_to_xyY_approximator_weights.npz b/models/to_xyY/munsell_to_xyY_approximator_weights.npz new file mode 100644 index 0000000000000000000000000000000000000000..6db712b30b63272a07965c1a9b756a00eab32c05 --- /dev/null +++ b/models/to_xyY/munsell_to_xyY_approximator_weights.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d6a59296ca146a1ea7398593e4b4756389c623672ada5583bddc2f3aa2dcb4bb +size 275462 diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..7dfd0b8677cd265346d7fae88409cb8414aa368b --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,121 @@ +[project] +name = "learning-munsell" +version = "0.1.0" +description = "Learning Munsell: ML-based CIE xyY to Munsell specification conversion" +readme = "README.md" +requires-python = ">=3.11,<3.15" +authors = [ + { name = "Colour Developers", email = "colour-developers@colour-science.org" }, +] +license = { text = "BSD-3-Clause" } +dependencies = [ + "click>=8.0.0", + "numpy>=2.0.0,<3", + "onnxruntime>=1.16.0", + "torch>=2.0.0", + "scikit-learn>=1.3.0", + "onnx>=1.15.0", + "matplotlib>=3.9", + "tqdm>=4.66.0", + "mlflow>=2.10.0", + "colour-science>=0.4.7", + "onnxscript>=0.5.6", + "optuna>=3.0.0", + "jax>=0.4.20", + "jaxlib>=0.4.20", + "flax>=0.10.7", + "optax>=0.2.6", + "scipy>=1.12.0,<2", + "tensorboard>=2.20.0", + "netron>=8.7.7", +] + +[dependency-groups] +dev = [ + "pre-commit", + "pyright", + "pytest", + "ruff", +] + +[tool.codespell] +ignore-words-list = "colour" + +[tool.isort] +ensure_newline_before_comments = true +force_grid_wrap = 0 +include_trailing_comma = true +line_length = 88 +multi_line_output = 3 +split_on_trailing_comma = true +use_parentheses = true + +[tool.pyright] +reportMissingImports = false +reportMissingModuleSource = false +reportUnboundVariable = false +reportUnnecessaryCast = true +reportUnnecessaryTypeIgnoreComment = true +reportUnsupportedDunderAll = false +reportUnusedExpression = false + +[tool.ruff] +target-version = "py311" +line-length = 88 + +[tool.ruff.lint] +select = ["ALL"] +ignore = [ + "C", # Pylint - Convention + "C90", # mccabe + "COM", # flake8-commas + "ERA", # eradicate + "FBT", # flake8-boolean-trap + "FIX", # flake8-fixme + "PT", # flake8-pytest-style + "PTH", # flake8-use-pathlib [Enable] + "TD", # flake8-todos + "ANN401", # Dynamically typed expressions (typing.Any) are disallowed in `**kwargs` + "D200", # One-line docstring should fit on one line + "D202", # No blank lines allowed after function docstring + "D205", # 1 blank line required between summary line and description + "D301", # Use `r"""` if any backslashes in a docstring + "D400", # First line should end with a period + "I001", # Import block is un-sorted or un-formatted + "N801", # Class name `.*` should use CapWords convention + "N802", # Function name `.*` should be lowercase + "N803", # Argument name `.*` should be lowercase + "N806", # Variable `.*` in function should be lowercase + "N813", # Camelcase `.*` imported as lowercase `.*` + "N815", # Variable `.*` in class scope should not be mixedCase + "N816", # Variable `.*` in global scope should not be mixedCase + "NPY002", # Replace legacy `np.random.random` call with `np.random.Generator` + "PGH003", # Use specific rule codes when ignoring type issues + "PLR0912", # Too many branches + "PLR0913", # Too many arguments in function definition + "PLR0915", # Too many statements + "PLR2004", # Magic value used in comparison, consider replacing `.*` with a constant variable + "PYI036", # Star-args in `.*` should be annotated with `object` + "PYI051", # `Literal[".*"]` is redundant in a union with `str` + "PYI056", # Calling `.append()` on `__all__` may not be supported by all type checkers (use `+=` instead) + "RUF022", # [*] `__all__` is not sorted + "TRY003", # Avoid specifying long messages outside the exception class + "UP038", # Use `X | Y` in `isinstance` call instead of `(X, Y)` + "N999", # Invalid module name (for from_xyY, to_xyY directories) + "DTZ005", # datetime.now() without tz argument (timestamps for logging) +] + +[tool.ruff.lint.pydocstyle] +convention = "numpy" + +[tool.ruff.lint.per-file-ignores] +"__init__.py" = ["D104"] +"learning_munsell/training/*" = ["INP001", "T201"] +"learning_munsell/comparison/*" = ["INP001", "T201"] +# HTML templates in compare_all_models.py files have long lines +"learning_munsell/comparison/*/compare_all_models.py" = ["E501"] +"learning_munsell/data_generation/*" = ["INP001", "T201"] +"learning_munsell/interpolation/*" = ["INP001", "T201"] + +[tool.ruff.format] +docstring-code-format = true