Commit
·
fa06c67
0
Parent(s):
Initial commit.
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +37 -0
- .gitignore +37 -0
- .pre-commit-config.yaml +39 -0
- LICENSE +11 -0
- README.md +278 -0
- docs/_static/gamma_sweep_plot.pdf +0 -0
- docs/_static/gamma_sweep_plot.png +3 -0
- docs/learning_munsell.md +478 -0
- learning_munsell/__init__.py +7 -0
- learning_munsell/analysis/__init__.py +1 -0
- learning_munsell/analysis/error_analysis.py +304 -0
- learning_munsell/comparison/from_xyY/__init__.py +1 -0
- learning_munsell/comparison/from_xyY/compare_all_models.py +1292 -0
- learning_munsell/comparison/from_xyY/compare_gamma_model.py +390 -0
- learning_munsell/comparison/to_xyY/__init__.py +1 -0
- learning_munsell/comparison/to_xyY/compare_all_models.py +617 -0
- learning_munsell/data_generation/generate_training_data.py +310 -0
- learning_munsell/interpolation/__init__.py +1 -0
- learning_munsell/interpolation/from_xyY/__init__.py +43 -0
- learning_munsell/interpolation/from_xyY/compare_methods.py +208 -0
- learning_munsell/interpolation/from_xyY/delaunay_interpolator.py +283 -0
- learning_munsell/interpolation/from_xyY/kdtree_interpolator.py +263 -0
- learning_munsell/interpolation/from_xyY/rbf_interpolator.py +300 -0
- learning_munsell/losses/__init__.py +17 -0
- learning_munsell/losses/jax_delta_e.py +299 -0
- learning_munsell/models/__init__.py +47 -0
- learning_munsell/models/networks.py +1294 -0
- learning_munsell/training/from_xyY/__init__.py +1 -0
- learning_munsell/training/from_xyY/hyperparameter_search_error_predictor.py +503 -0
- learning_munsell/training/from_xyY/hyperparameter_search_multi_head.py +541 -0
- learning_munsell/training/from_xyY/hyperparameter_search_multi_head_error_predictor.py +552 -0
- learning_munsell/training/from_xyY/hyperparameter_search_multi_mlp.py +471 -0
- learning_munsell/training/from_xyY/refine_multi_head_real.py +358 -0
- learning_munsell/training/from_xyY/train_deep_wide.py +371 -0
- learning_munsell/training/from_xyY/train_ft_transformer.py +356 -0
- learning_munsell/training/from_xyY/train_mixture_of_experts.py +620 -0
- learning_munsell/training/from_xyY/train_mlp.py +269 -0
- learning_munsell/training/from_xyY/train_mlp_attention.py +460 -0
- learning_munsell/training/from_xyY/train_mlp_error_predictor.py +457 -0
- learning_munsell/training/from_xyY/train_mlp_gamma.py +297 -0
- learning_munsell/training/from_xyY/train_multi_head_3stage_error_predictor.py +411 -0
- learning_munsell/training/from_xyY/train_multi_head_circular.py +479 -0
- learning_munsell/training/from_xyY/train_multi_head_cross_attention_error_predictor.py +640 -0
- learning_munsell/training/from_xyY/train_multi_head_gamma.py +300 -0
- learning_munsell/training/from_xyY/train_multi_head_gamma_sweep.py +605 -0
- learning_munsell/training/from_xyY/train_multi_head_large.py +246 -0
- learning_munsell/training/from_xyY/train_multi_head_mlp.py +269 -0
- learning_munsell/training/from_xyY/train_multi_head_multi_error_predictor.py +378 -0
- learning_munsell/training/from_xyY/train_multi_head_multi_error_predictor_large.py +409 -0
- learning_munsell/training/from_xyY/train_multi_head_st2084.py +313 -0
.gitattributes
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.onnx.data filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Common Files
|
| 2 |
+
*.egg-info
|
| 3 |
+
*.pyc
|
| 4 |
+
*.pyo
|
| 5 |
+
.DS_Store
|
| 6 |
+
.coverage*
|
| 7 |
+
uv.lock
|
| 8 |
+
|
| 9 |
+
# Common Directories
|
| 10 |
+
.fleet/
|
| 11 |
+
.idea/
|
| 12 |
+
.ipynb_checkpoints/
|
| 13 |
+
.python-version
|
| 14 |
+
.vs/
|
| 15 |
+
.vscode/
|
| 16 |
+
.sandbox/
|
| 17 |
+
build/
|
| 18 |
+
dist/
|
| 19 |
+
docs/_build/
|
| 20 |
+
docs/generated/
|
| 21 |
+
node_modules/
|
| 22 |
+
references/
|
| 23 |
+
|
| 24 |
+
__pycache__
|
| 25 |
+
|
| 26 |
+
.claude/settings.local.json
|
| 27 |
+
.claude/scratchpad.md
|
| 28 |
+
|
| 29 |
+
# Project Directories
|
| 30 |
+
data/
|
| 31 |
+
logs/
|
| 32 |
+
mlartifacts/
|
| 33 |
+
mlruns/
|
| 34 |
+
mlruns.db
|
| 35 |
+
reports/
|
| 36 |
+
results/
|
| 37 |
+
runs/
|
.pre-commit-config.yaml
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
repos:
|
| 2 |
+
- repo: https://github.com/pre-commit/pre-commit-hooks
|
| 3 |
+
rev: "v5.0.0"
|
| 4 |
+
hooks:
|
| 5 |
+
- id: check-added-large-files
|
| 6 |
+
- id: check-case-conflict
|
| 7 |
+
- id: check-merge-conflict
|
| 8 |
+
- id: check-symlinks
|
| 9 |
+
- id: check-yaml
|
| 10 |
+
- id: debug-statements
|
| 11 |
+
- id: end-of-file-fixer
|
| 12 |
+
- id: mixed-line-ending
|
| 13 |
+
- id: requirements-txt-fixer
|
| 14 |
+
- id: trailing-whitespace
|
| 15 |
+
- repo: https://github.com/codespell-project/codespell
|
| 16 |
+
rev: v2.4.1
|
| 17 |
+
hooks:
|
| 18 |
+
- id: codespell
|
| 19 |
+
args: ["--ignore-words-list=colour"]
|
| 20 |
+
- repo: https://github.com/PyCQA/isort
|
| 21 |
+
rev: "6.0.1"
|
| 22 |
+
hooks:
|
| 23 |
+
- id: isort
|
| 24 |
+
- repo: https://github.com/astral-sh/ruff-pre-commit
|
| 25 |
+
rev: "v0.12.4"
|
| 26 |
+
hooks:
|
| 27 |
+
- id: ruff-format
|
| 28 |
+
- id: ruff
|
| 29 |
+
args: [--fix]
|
| 30 |
+
- repo: https://github.com/pre-commit/mirrors-prettier
|
| 31 |
+
rev: "v4.0.0-alpha.8"
|
| 32 |
+
hooks:
|
| 33 |
+
- id: prettier
|
| 34 |
+
- repo: https://github.com/pre-commit/pygrep-hooks
|
| 35 |
+
rev: "v1.10.0"
|
| 36 |
+
hooks:
|
| 37 |
+
- id: rst-backticks
|
| 38 |
+
- id: rst-directive-colons
|
| 39 |
+
- id: rst-inline-touching-normal
|
LICENSE
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Copyright 2025 Colour Developers
|
| 2 |
+
|
| 3 |
+
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
|
| 4 |
+
|
| 5 |
+
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
|
| 6 |
+
|
| 7 |
+
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
|
| 8 |
+
|
| 9 |
+
3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
|
| 10 |
+
|
| 11 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE
|
README.md
ADDED
|
@@ -0,0 +1,278 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: bsd-3-clause
|
| 3 |
+
language:
|
| 4 |
+
- en
|
| 5 |
+
tags:
|
| 6 |
+
- python
|
| 7 |
+
- colour
|
| 8 |
+
- color
|
| 9 |
+
- colour-science
|
| 10 |
+
- color-science
|
| 11 |
+
- colour-spaces
|
| 12 |
+
- color-spaces
|
| 13 |
+
- colourspace
|
| 14 |
+
- colorspace
|
| 15 |
+
pipeline_tag: tabular-regression
|
| 16 |
+
library_name: onnxruntime
|
| 17 |
+
metrics:
|
| 18 |
+
- mae
|
| 19 |
+
model-index:
|
| 20 |
+
- name: from_xyY (CIE xyY to Munsell)
|
| 21 |
+
results:
|
| 22 |
+
- task:
|
| 23 |
+
type: tabular-regression
|
| 24 |
+
name: CIE xyY to Munsell Specification
|
| 25 |
+
dataset:
|
| 26 |
+
name: CIE xyY to Munsell Specification
|
| 27 |
+
type: munsell-renotation
|
| 28 |
+
metrics:
|
| 29 |
+
- type: delta-e
|
| 30 |
+
value: 0.52
|
| 31 |
+
name: Delta-E CIE2000
|
| 32 |
+
- type: inference_time_ms
|
| 33 |
+
value: 0.089
|
| 34 |
+
name: Inference Time (ms/sample)
|
| 35 |
+
- name: to_xyY (Munsell to CIE xyY)
|
| 36 |
+
results:
|
| 37 |
+
- task:
|
| 38 |
+
type: tabular-regression
|
| 39 |
+
name: Munsell Specification to CIE xyY
|
| 40 |
+
dataset:
|
| 41 |
+
name: Munsell Specification to CIE xyY
|
| 42 |
+
type: munsell-renotation
|
| 43 |
+
metrics:
|
| 44 |
+
- type: delta-e
|
| 45 |
+
value: 0.48
|
| 46 |
+
name: Delta-E CIE2000
|
| 47 |
+
- type: inference_time_ms
|
| 48 |
+
value: 0.008
|
| 49 |
+
name: Inference Time (ms/sample)
|
| 50 |
+
---
|
| 51 |
+
|
| 52 |
+
# Learning Munsell - Machine Learning for Munsell Color Conversions
|
| 53 |
+
|
| 54 |
+
A project implementing machine learning-based methods for bidirectional conversion between CIE xyY colourspace values and Munsell specifications.
|
| 55 |
+
|
| 56 |
+
**Two Conversion Directions:**
|
| 57 |
+
|
| 58 |
+
- **from_xyY**: CIE xyY to Munsell specification
|
| 59 |
+
- **to_xyY**: Munsell specification to CIE xyY
|
| 60 |
+
|
| 61 |
+
## Project Overview
|
| 62 |
+
|
| 63 |
+
### Objective
|
| 64 |
+
|
| 65 |
+
Provide 100-1000x speedup for batch Munsell conversions compared to colour-science routines while maintaining high accuracy.
|
| 66 |
+
|
| 67 |
+
### Results
|
| 68 |
+
|
| 69 |
+
**from_xyY** (CIE xyY to Munsell) — evaluated on all 2,734 REAL Munsell colors:
|
| 70 |
+
|
| 71 |
+
| Model | Delta-E | Speed (ms) |
|
| 72 |
+
|----------------------------------------------------------| ---------- | ---------- |
|
| 73 |
+
| Colour Library (Baseline) | 0.00 | 111.90 |
|
| 74 |
+
| **Multi-ResNet + Multi-Error Predictor (Large Dataset)** | **0.52** | 0.089 |
|
| 75 |
+
| Multi-MLP (W+B) + Multi-Error Predictor (W+B) Large | 0.52 | 0.057 |
|
| 76 |
+
| Multi-MLP + Multi-Error Predictor (Large Dataset) | 0.52 | 0.058 |
|
| 77 |
+
| Multi-MLP + Multi-Error Predictor | 0.53 | 0.058 |
|
| 78 |
+
| MLP + Error Predictor | 0.53 | 0.030 |
|
| 79 |
+
| Multi-ResNet (Large Dataset) | 0.54 | 0.044 |
|
| 80 |
+
| Multi-Head + Multi-Error Predictor | 0.54 | 0.042 |
|
| 81 |
+
| Multi-Head + Multi-Error Predictor (Large Dataset) | 0.56 | 0.043 |
|
| 82 |
+
| Deep + Wide | 0.60 | 0.074 |
|
| 83 |
+
| Multi-Head (Large Dataset) | 0.66 | 0.013 |
|
| 84 |
+
| Mixture of Experts | 0.80 | 0.020 |
|
| 85 |
+
| Transformer (Large Dataset) | 0.82 | 0.123 |
|
| 86 |
+
| Multi-MLP | 0.86 | 0.027 |
|
| 87 |
+
| MLP + Self-Attention | 0.88 | 0.173 |
|
| 88 |
+
| MLP (Base Only) | 1.09 | **0.007** |
|
| 89 |
+
| Unified MLP | 1.12 | 0.072 |
|
| 90 |
+
|
| 91 |
+
- **Best Accuracy**: Multi-ResNet + Multi-Error Predictor (Large Dataset) — Delta-E 0.52, 1,252x faster
|
| 92 |
+
- **Fastest**: MLP Base Only (0.007 ms/sample) — 15,492x faster than Colour library
|
| 93 |
+
- **Best Balance**: Multi-MLP (W+B: Weighted Boundary) + Multi-Error Predictor (W+B) Large — 1,951x faster with Delta-E 0.52
|
| 94 |
+
|
| 95 |
+
**to_xyY** (Munsell to CIE xyY) — evaluated on all 2,734 REAL Munsell colors:
|
| 96 |
+
|
| 97 |
+
| Model | Delta-E | Speed (ms) |
|
| 98 |
+
| --------------------------------------------- | ---------- | ----------- |
|
| 99 |
+
| Colour Library (Baseline) | 0.00 | 1.27 |
|
| 100 |
+
| **Multi-MLP (Optimized)** | **0.48** | 0.008 |
|
| 101 |
+
| Multi-MLP (Opt) + Multi-Error Predictor (Opt) | 0.48 | 0.025 |
|
| 102 |
+
| Multi-MLP + Multi-Error Predictor | 0.65 | 0.030 |
|
| 103 |
+
| Multi-MLP | 0.66 | 0.016 |
|
| 104 |
+
| Multi-MLP + Error Predictor | 0.67 | 0.018 |
|
| 105 |
+
| Multi-Head (Optimized) | 0.71 | 0.015 |
|
| 106 |
+
| Multi-Head | 0.78 | 0.008 |
|
| 107 |
+
| Multi-Head + Multi-Error Predictor | 1.11 | 0.028 |
|
| 108 |
+
| Simple MLP | 1.42 | **0.0008** |
|
| 109 |
+
|
| 110 |
+
- **Best Accuracy**: Multi-MLP (Optimized) — Delta-E 0.48, 154x faster
|
| 111 |
+
- **Fastest**: Simple MLP (0.0008 ms/sample) — 1,654x faster than Colour library
|
| 112 |
+
|
| 113 |
+
### Approach
|
| 114 |
+
|
| 115 |
+
- **25+ architectures** tested for from_xyY (MLP, Multi-Head, Multi-MLP, Multi-ResNet, Transformers, Mixture of Experts)
|
| 116 |
+
- **9 architectures** tested for to_xyY (Simple MLP, Multi-Head, Multi-MLP with error predictors)
|
| 117 |
+
- **Two-stage models** (base + error predictor) on large dataset proved most effective
|
| 118 |
+
- **Best model**: Multi-ResNet + Multi-Error Predictor (Large Dataset) with Delta-E 0.52
|
| 119 |
+
- **Training data**: ~1.4M samples from dense xyY grid with boundary refinement and forward Munsell sampling
|
| 120 |
+
- **Deployment**: ONNX format with ONNX Runtime
|
| 121 |
+
|
| 122 |
+
For detailed architecture comparisons, model benchmarks, training pipeline details, and experimental results, see [docs/learning_munsell.md](docs/learning_munsell.md).
|
| 123 |
+
|
| 124 |
+
## Installation
|
| 125 |
+
|
| 126 |
+
**Dependencies (Runtime)**:
|
| 127 |
+
|
| 128 |
+
- numpy >= 2.0
|
| 129 |
+
- onnxruntime >= 1.16
|
| 130 |
+
|
| 131 |
+
**Dependencies (Training)**:
|
| 132 |
+
|
| 133 |
+
- torch >= 2.0
|
| 134 |
+
- scikit-learn >= 1.3
|
| 135 |
+
- matplotlib >= 3.9
|
| 136 |
+
- mlflow >= 2.10
|
| 137 |
+
- optuna >= 3.0
|
| 138 |
+
- colour-science >= 0.4.7
|
| 139 |
+
- click >= 8.0
|
| 140 |
+
- onnx >= 1.15
|
| 141 |
+
- onnxscript >= 0.5.6
|
| 142 |
+
- tqdm >= 4.66
|
| 143 |
+
- jax >= 0.4.20
|
| 144 |
+
- jaxlib >= 0.4.20
|
| 145 |
+
- flax >= 0.10.7
|
| 146 |
+
- optax >= 0.2.6
|
| 147 |
+
- scipy >= 1.12
|
| 148 |
+
- tensorboard >= 2.20
|
| 149 |
+
|
| 150 |
+
From the project root:
|
| 151 |
+
|
| 152 |
+
```bash
|
| 153 |
+
cd learning-munsell
|
| 154 |
+
|
| 155 |
+
# Install all dependencies (creates virtual environment automatically)
|
| 156 |
+
uv sync
|
| 157 |
+
```
|
| 158 |
+
|
| 159 |
+
## Usage
|
| 160 |
+
|
| 161 |
+
### Generate Training Data
|
| 162 |
+
|
| 163 |
+
```bash
|
| 164 |
+
uv run python learning_munsell/data_generation/generate_training_data.py
|
| 165 |
+
```
|
| 166 |
+
|
| 167 |
+
**Note**: This step is computationally expensive (uses iterative algorithm for ground truth).
|
| 168 |
+
|
| 169 |
+
### Train Models
|
| 170 |
+
|
| 171 |
+
**xyY to Munsell (from_xyY)**
|
| 172 |
+
|
| 173 |
+
Best performing model (Multi-ResNet + Multi-Error Predictor on Large Dataset):
|
| 174 |
+
|
| 175 |
+
```bash
|
| 176 |
+
# Train base Multi-ResNet on large dataset (~1.4M samples)
|
| 177 |
+
uv run python learning_munsell/training/from_xyY/train_multi_resnet_large.py
|
| 178 |
+
|
| 179 |
+
# Train multi-error predictor
|
| 180 |
+
uv run python learning_munsell/training/from_xyY/train_multi_resnet_error_predictor_large.py
|
| 181 |
+
```
|
| 182 |
+
|
| 183 |
+
Alternative (Multi-Head architecture):
|
| 184 |
+
|
| 185 |
+
```bash
|
| 186 |
+
uv run python learning_munsell/training/from_xyY/train_multi_head_large.py
|
| 187 |
+
uv run python learning_munsell/training/from_xyY/train_multi_head_multi_error_predictor_large.py
|
| 188 |
+
```
|
| 189 |
+
|
| 190 |
+
Other architectures:
|
| 191 |
+
|
| 192 |
+
```bash
|
| 193 |
+
uv run python learning_munsell/training/from_xyY/train_unified_mlp.py
|
| 194 |
+
uv run python learning_munsell/training/from_xyY/train_multi_mlp.py
|
| 195 |
+
uv run python learning_munsell/training/from_xyY/train_mlp_attention.py
|
| 196 |
+
uv run python learning_munsell/training/from_xyY/train_deep_wide.py
|
| 197 |
+
uv run python learning_munsell/training/from_xyY/train_ft_transformer.py
|
| 198 |
+
```
|
| 199 |
+
|
| 200 |
+
**Munsell to xyY (to_xyY)**
|
| 201 |
+
|
| 202 |
+
Best performing model (Multi-MLP Optimized):
|
| 203 |
+
|
| 204 |
+
```bash
|
| 205 |
+
uv run python learning_munsell/training/to_xyY/train_multi_mlp.py
|
| 206 |
+
uv run python learning_munsell/training/to_xyY/train_multi_head.py
|
| 207 |
+
uv run python learning_munsell/training/to_xyY/train_multi_mlp_multi_error_predictor.py
|
| 208 |
+
uv run python learning_munsell/training/to_xyY/train_multi_mlp_error_predictor.py
|
| 209 |
+
uv run python learning_munsell/training/to_xyY/train_multi_head_multi_error_predictor.py
|
| 210 |
+
```
|
| 211 |
+
|
| 212 |
+
Train the differentiable approximator for use in Delta-E loss:
|
| 213 |
+
|
| 214 |
+
```bash
|
| 215 |
+
uv run python learning_munsell/training/to_xyY/train_munsell_to_xyY_approximator.py
|
| 216 |
+
```
|
| 217 |
+
|
| 218 |
+
### Hyperparameter Search
|
| 219 |
+
|
| 220 |
+
```bash
|
| 221 |
+
uv run python learning_munsell/training/from_xyY/hyperparameter_search_multi_head.py
|
| 222 |
+
uv run python learning_munsell/training/from_xyY/hyperparameter_search_multi_head_error_predictor.py
|
| 223 |
+
```
|
| 224 |
+
|
| 225 |
+
### Compare All Models
|
| 226 |
+
|
| 227 |
+
```bash
|
| 228 |
+
uv run python learning_munsell/comparison/from_xyY/compare_all_models.py
|
| 229 |
+
```
|
| 230 |
+
|
| 231 |
+
Generates comprehensive HTML report at `reports/from_xyY/model_comparison.html`.
|
| 232 |
+
|
| 233 |
+
### Monitor Training
|
| 234 |
+
|
| 235 |
+
**MLflow**:
|
| 236 |
+
|
| 237 |
+
```bash
|
| 238 |
+
uv run mlflow ui --backend-store-uri "sqlite:///mlruns.db" --port=5000
|
| 239 |
+
```
|
| 240 |
+
|
| 241 |
+
Open <http://localhost:5000> in your browser.
|
| 242 |
+
|
| 243 |
+
## Directory Structure
|
| 244 |
+
|
| 245 |
+
```
|
| 246 |
+
learning-munsell/
|
| 247 |
+
+-- data/ # Training data
|
| 248 |
+
| +-- training_data.npz # Generated training samples
|
| 249 |
+
| +-- training_data_large.npz # Large dataset (~1.4M samples)
|
| 250 |
+
| +-- training_data_params.json # Generation parameters
|
| 251 |
+
| +-- training_data_large_params.json
|
| 252 |
+
+-- models/ # Trained models (ONNX + PyTorch)
|
| 253 |
+
| +-- from_xyY/ # xyY to Munsell models (25+ ONNX models)
|
| 254 |
+
| | +-- multi_resnet_error_predictor_large.onnx # BEST
|
| 255 |
+
| | +-- ... (additional model variants)
|
| 256 |
+
| +-- to_xyY/ # Munsell to xyY models (9 ONNX models)
|
| 257 |
+
| +-- multi_mlp_optimized.onnx # BEST
|
| 258 |
+
| +-- ... (additional model variants)
|
| 259 |
+
+-- learning_munsell/ # Source code
|
| 260 |
+
| +-- analysis/ # Analysis scripts
|
| 261 |
+
| +-- comparison/ # Model comparison scripts
|
| 262 |
+
| +-- data_generation/ # Data generation scripts
|
| 263 |
+
| +-- interpolation/ # Classical interpolation methods
|
| 264 |
+
| +-- losses/ # Loss functions (JAX Delta-E)
|
| 265 |
+
| +-- models/ # Model architecture definitions
|
| 266 |
+
| +-- training/ # Model training scripts
|
| 267 |
+
| +-- utilities/ # Shared utilities
|
| 268 |
+
+-- docs/ # Documentation
|
| 269 |
+
+-- reports/ # HTML comparison reports
|
| 270 |
+
+-- logs/ # Script output logs
|
| 271 |
+
+-- mlruns.db # MLflow experiment tracking database
|
| 272 |
+
```
|
| 273 |
+
|
| 274 |
+
## About
|
| 275 |
+
|
| 276 |
+
**Learning Munsell** by Colour Developers
|
| 277 |
+
Research project for the Colour library
|
| 278 |
+
<https://github.com/colour-science/colour>
|
docs/_static/gamma_sweep_plot.pdf
ADDED
|
Binary file (22.2 kB). View file
|
|
|
docs/_static/gamma_sweep_plot.png
ADDED
|
Git LFS Details
|
docs/learning_munsell.md
ADDED
|
@@ -0,0 +1,478 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Learning Munsell
|
| 2 |
+
|
| 3 |
+
Technical documentation covering performance benchmarks, training methodology, architecture design, and experimental findings.
|
| 4 |
+
|
| 5 |
+
## Overview
|
| 6 |
+
|
| 7 |
+
This project implements ML models for bidirectional conversion between CIE xyY colorspace values and Munsell specifications:
|
| 8 |
+
|
| 9 |
+
- **xyY to Munsell (from_xyY)**: 25+ architectures, best Delta-E 0.52
|
| 10 |
+
- **Munsell to xyY (to_xyY)**: 9 architectures, best Delta-E 0.48
|
| 11 |
+
|
| 12 |
+
### Delta-E Interpretation
|
| 13 |
+
|
| 14 |
+
- **< 1.0**: Not perceptible by human eye
|
| 15 |
+
- **1-2**: Perceptible through close observation
|
| 16 |
+
- **2-10**: Perceptible at a glance
|
| 17 |
+
- **> 10**: Colors are perceived as completely different
|
| 18 |
+
|
| 19 |
+
Our best models achieve **Delta-E 0.48-0.52**, meaning the difference between ML prediction and iterative algorithm is **not perceptible by the human eye**.
|
| 20 |
+
|
| 21 |
+
## xyY to Munsell (from_xyY)
|
| 22 |
+
|
| 23 |
+
### Performance Benchmarks
|
| 24 |
+
|
| 25 |
+
Comprehensive comparison using all 2,734 REAL Munsell colors:
|
| 26 |
+
|
| 27 |
+
| Model | Delta-E | Speed (ms) |
|
| 28 |
+
|----------------------------------------------------------|-------------|------------|
|
| 29 |
+
| Colour Library (Baseline) | 0.00 | 111.90 |
|
| 30 |
+
| **Multi-ResNet + Multi-Error Predictor (Large Dataset)** | **0.52** | 0.089 |
|
| 31 |
+
| Multi-MLP (W+B) + Multi-Error Predictor (W+B) Large | 0.52 | 0.057 |
|
| 32 |
+
| Multi-MLP + Multi-Error Predictor (Large Dataset) | 0.52 | 0.058 |
|
| 33 |
+
| Multi-MLP + Multi-Error Predictor | 0.53 | 0.058 |
|
| 34 |
+
| MLP + Error Predictor | 0.53 | 0.030 |
|
| 35 |
+
| Multi-ResNet (Large Dataset) | 0.54 | 0.044 |
|
| 36 |
+
| Multi-Head + Multi-Error Predictor | 0.54 | 0.042 |
|
| 37 |
+
| Multi-Head + Multi-Error Predictor (Large Dataset) | 0.56 | 0.043 |
|
| 38 |
+
| Deep + Wide | 0.60 | 0.074 |
|
| 39 |
+
| Multi-Head (Large Dataset) | 0.66 | 0.013 |
|
| 40 |
+
| Mixture of Experts | 0.80 | 0.020 |
|
| 41 |
+
| Transformer (Large Dataset) | 0.82 | 0.123 |
|
| 42 |
+
| Multi-MLP | 0.86 | 0.027 |
|
| 43 |
+
| MLP + Self-Attention | 0.88 | 0.173 |
|
| 44 |
+
| MLP (Base Only) | 1.09 | **0.007** |
|
| 45 |
+
| Unified MLP | 1.12 | 0.072 |
|
| 46 |
+
|
| 47 |
+
Note: The Colour library baseline had 171 convergence failures out of 2,734 samples (6.3% failure rate).
|
| 48 |
+
|
| 49 |
+
**Best Models**:
|
| 50 |
+
|
| 51 |
+
- **Best Accuracy**: Multi-ResNet + Multi-Error Predictor (Large Dataset) - Delta-E 0.52
|
| 52 |
+
- **Fastest**: MLP Base Only (0.007 ms/sample) - 15,492x faster than Colour library
|
| 53 |
+
- **Best Balance**: Multi-MLP (W+B: Weighted Boundary) + Multi-Error Predictor (W+B) Large - 1,951x faster with Delta-E 0.52
|
| 54 |
+
|
| 55 |
+
### Model Architectures
|
| 56 |
+
|
| 57 |
+
25+ architectures were systematically evaluated:
|
| 58 |
+
|
| 59 |
+
**Single-Stage Models**
|
| 60 |
+
|
| 61 |
+
1. **MLP (Base Only)** - Simple MLP network, 3 inputs to 4 outputs
|
| 62 |
+
2. **Unified MLP** - Single large MLP with shared features
|
| 63 |
+
3. **Multi-Head** - Shared encoder with 4 independent decoder heads
|
| 64 |
+
4. **Multi-Head (Large Dataset)** - Multi-Head trained on 1.4M samples
|
| 65 |
+
5. **Multi-MLP** - 4 completely independent MLP branches (one per output)
|
| 66 |
+
6. **Multi-MLP (Large Dataset)** - Multi-MLP trained on 1.4M samples
|
| 67 |
+
7. **MLP + Self-Attention** - MLP with attention mechanism for feature weighting
|
| 68 |
+
8. **Deep + Wide** - Combined deep and wide network paths
|
| 69 |
+
9. **Mixture of Experts** - Gating network selecting specialized expert networks
|
| 70 |
+
10. **Transformer (Large Dataset)** - Feature Tokenizer Transformer for tabular data
|
| 71 |
+
11. **FT-Transformer** - Feature Tokenizer Transformer (standard size)
|
| 72 |
+
|
| 73 |
+
**Two-Stage Models**
|
| 74 |
+
|
| 75 |
+
12. **MLP + Error Predictor** - Base MLP with unified error correction
|
| 76 |
+
13. **Multi-Head + Multi-Error Predictor** - Multi-Head with 4 independent error predictors
|
| 77 |
+
14. **Multi-Head + Multi-Error Predictor (Large Dataset)** - Large dataset variant
|
| 78 |
+
15. **Multi-MLP + Multi-Error Predictor** - 4 independent branches with 4 independent error predictors
|
| 79 |
+
16. **Multi-MLP + Multi-Error Predictor (Large Dataset)** - Large dataset variant
|
| 80 |
+
17. **Multi-ResNet + Multi-Error Predictor (Large Dataset)** - Deep ResNet-style branches (BEST)
|
| 81 |
+
|
| 82 |
+
The **Multi-ResNet + Multi-Error Predictor (Large Dataset)** architecture achieved the best results with Delta-E 0.52.
|
| 83 |
+
|
| 84 |
+
### Training Methodology
|
| 85 |
+
|
| 86 |
+
**Data Generation**
|
| 87 |
+
|
| 88 |
+
1. **Dense xyY Grid** (~500K samples)
|
| 89 |
+
- Regular grid in valid xyY space (MacAdam limits for Illuminant C)
|
| 90 |
+
- Captures general input distribution
|
| 91 |
+
2. **Boundary Refinement** (~700K samples)
|
| 92 |
+
- Adaptive dense sampling near Munsell gamut boundaries
|
| 93 |
+
- Uses `maximum_chroma_from_renotation` to detect edges
|
| 94 |
+
- Focuses on regions where iterative algorithm is most complex
|
| 95 |
+
- Includes Y/GY/G hue regions with high value/chroma (challenging areas)
|
| 96 |
+
3. **Forward Augmentation** (~200K samples)
|
| 97 |
+
- Dense Munsell space sampling via `munsell_specification_to_xyY`
|
| 98 |
+
- Ensures coverage of known valid colors
|
| 99 |
+
|
| 100 |
+
Total: ~1.4M samples for large dataset training.
|
| 101 |
+
|
| 102 |
+
**Loss Functions**
|
| 103 |
+
|
| 104 |
+
Two loss function approaches were tested:
|
| 105 |
+
|
| 106 |
+
*Precision-Focused Loss* (Default):
|
| 107 |
+
|
| 108 |
+
```
|
| 109 |
+
total_loss = 1.0 * MSE + 0.5 * MAE + 0.3 * log_penalty + 0.5 * huber_loss
|
| 110 |
+
```
|
| 111 |
+
|
| 112 |
+
- MSE: Standard mean squared error
|
| 113 |
+
- MAE: Mean absolute error
|
| 114 |
+
- Log penalty: Heavily penalizes small errors (pushes toward high precision)
|
| 115 |
+
- Huber loss: Small delta (0.01) for precision on small errors
|
| 116 |
+
|
| 117 |
+
*Pure MSE Loss* (Optimized config):
|
| 118 |
+
|
| 119 |
+
```
|
| 120 |
+
total_loss = MSE
|
| 121 |
+
```
|
| 122 |
+
|
| 123 |
+
Interestingly, the precision-focused loss achieved better Delta-E despite higher validation MSE, suggesting the custom weighting better correlates with perceptual accuracy.
|
| 124 |
+
|
| 125 |
+
### Design Rationale
|
| 126 |
+
|
| 127 |
+
**Two-Stage Architecture**
|
| 128 |
+
|
| 129 |
+
The error predictor stage corrects systematic biases in the base model:
|
| 130 |
+
|
| 131 |
+
1. Base model learns the general xyY to Munsell mapping
|
| 132 |
+
2. Error predictor learns residual corrections specific to each component
|
| 133 |
+
3. Combined prediction: `final = base_prediction + error_correction`
|
| 134 |
+
|
| 135 |
+
This decomposition allows each stage to specialize and reduces the complexity each network must learn.
|
| 136 |
+
|
| 137 |
+
**Independent Branch Design**
|
| 138 |
+
|
| 139 |
+
Munsell components have different characteristics:
|
| 140 |
+
|
| 141 |
+
- **Hue**: Circular (0-10, wrapping), most complex
|
| 142 |
+
- **Value**: Linear (0-10), easiest to predict
|
| 143 |
+
- **Chroma**: Highly variable range depending on hue/value
|
| 144 |
+
- **Code**: Discrete hue sector (0-9)
|
| 145 |
+
|
| 146 |
+
Shared encoders force compromises between these different prediction tasks. Independent branches allow full specialization.
|
| 147 |
+
|
| 148 |
+
**Architecture Details**
|
| 149 |
+
|
| 150 |
+
*MLP (Base Only)*
|
| 151 |
+
|
| 152 |
+
Simple feedforward network predicting all 4 outputs simultaneously:
|
| 153 |
+
|
| 154 |
+
Input (3) ──► Linear Layers ──► Output (4: hue, value, chroma, code)
|
| 155 |
+
|
| 156 |
+
- Smallest model (~8KB ONNX)
|
| 157 |
+
- Fastest inference (0.007 ms)
|
| 158 |
+
- Baseline for comparison
|
| 159 |
+
|
| 160 |
+
*Unified MLP*
|
| 161 |
+
|
| 162 |
+
Single large MLP with shared internal features:
|
| 163 |
+
|
| 164 |
+
Input (3) ──► 128 ──► 256 ──► 512 ──► 256 ──► 128 ──► Output (4)
|
| 165 |
+
|
| 166 |
+
- Shared representations across all outputs
|
| 167 |
+
- Moderate size, good speed
|
| 168 |
+
|
| 169 |
+
*Multi-Head MLP*
|
| 170 |
+
|
| 171 |
+
Shared encoder with specialized decoder heads:
|
| 172 |
+
|
| 173 |
+
Input (3) ──► SHARED ENCODER (3→128→256→512) ──┬──► Hue Head (512→256→128→1)
|
| 174 |
+
├──► Value Head (512→256→128→1)
|
| 175 |
+
├──► Chroma Head (512→384→256→128→1)
|
| 176 |
+
└──► Code Head (512→256→128→1)
|
| 177 |
+
|
| 178 |
+
- Shared encoder learns common color space features
|
| 179 |
+
- 4 specialized decoder heads branch from shared representation
|
| 180 |
+
- Parameter efficient (encoder weights shared)
|
| 181 |
+
- Fast inference (encoder computed once)
|
| 182 |
+
|
| 183 |
+
*Multi-MLP*
|
| 184 |
+
|
| 185 |
+
Fully independent branches with no weight sharing:
|
| 186 |
+
|
| 187 |
+
Input (3) ──► Hue Branch (3→128→256→512→256→128→1)
|
| 188 |
+
Input (3) ──► Value Branch (3→128→256→512→256→128→1)
|
| 189 |
+
Input (3) ──► Chroma Branch (3→256→512→1024→512→256→1) [2x wider]
|
| 190 |
+
Input (3) ──► Code Branch (3→128→256→512→256→128→1)
|
| 191 |
+
|
| 192 |
+
- 4 completely independent MLPs
|
| 193 |
+
- Each branch learns its own features from scratch
|
| 194 |
+
- Chroma branch is wider (2x) to handle its complexity
|
| 195 |
+
- Better accuracy than Multi-Head on large dataset (Delta-E 0.52 vs 0.56 with error predictors)
|
| 196 |
+
|
| 197 |
+
*Multi-ResNet*
|
| 198 |
+
|
| 199 |
+
Deep branches with residual-style connections:
|
| 200 |
+
|
| 201 |
+
Input (3) ──► Hue Branch (3→256→512→512→512→256→1) [6 layers]
|
| 202 |
+
Input (3) ──► Value Branch (3→256→512→512→512→256→1) [6 layers]
|
| 203 |
+
Input (3) ──► Chroma Branch (3→512→1024→1024→1024→512→1) [6 layers, 2x wider]
|
| 204 |
+
Input (3) ──► Code Branch (3→256→512→512→512→256→1) [6 layers]
|
| 205 |
+
|
| 206 |
+
- Deeper architecture than Multi-MLP
|
| 207 |
+
- BatchNorm + SiLU activation
|
| 208 |
+
- Best accuracy when combined with error predictor (Delta-E 0.52)
|
| 209 |
+
- Largest model (~14MB base, ~28MB with error predictor)
|
| 210 |
+
|
| 211 |
+
*Deep + Wide*
|
| 212 |
+
|
| 213 |
+
Combined deep and wide network paths:
|
| 214 |
+
|
| 215 |
+
Input (3) ──┬──► Deep Path (multiple layers) ──┬──► Concat ──► Output (4)
|
| 216 |
+
└──► Wide Path (direct connection) ─┘
|
| 217 |
+
|
| 218 |
+
- Deep path captures complex patterns
|
| 219 |
+
- Wide path preserves direct input information
|
| 220 |
+
- Good for mixed linear/nonlinear relationships
|
| 221 |
+
|
| 222 |
+
*MLP + Self-Attention*
|
| 223 |
+
|
| 224 |
+
MLP with attention mechanism for feature weighting:
|
| 225 |
+
|
| 226 |
+
Input (3) ──► MLP ──► Self-Attention ──► Output (4)
|
| 227 |
+
|
| 228 |
+
- Attention weights learn feature importance
|
| 229 |
+
- Slower due to attention computation (0.173 ms)
|
| 230 |
+
- Did not improve over simpler MLPs
|
| 231 |
+
|
| 232 |
+
*Mixture of Experts*
|
| 233 |
+
|
| 234 |
+
Gating network selecting specialized expert networks:
|
| 235 |
+
|
| 236 |
+
Input (3) ──► Gating Network ──► Weighted sum of Expert outputs ──► Output (4)
|
| 237 |
+
|
| 238 |
+
- Multiple expert networks specialize in different input regions
|
| 239 |
+
- Gating network learns which expert to use
|
| 240 |
+
- More complex but did not outperform Multi-MLP
|
| 241 |
+
|
| 242 |
+
*FT-Transformer*
|
| 243 |
+
|
| 244 |
+
Feature Tokenizer Transformer for tabular data:
|
| 245 |
+
|
| 246 |
+
Input (3) ──► Feature Tokenizer ──► Transformer Blocks ──► Output (4)
|
| 247 |
+
|
| 248 |
+
- Each input feature tokenized separately
|
| 249 |
+
- Self-attention across feature tokens
|
| 250 |
+
- Good for tabular data with feature interactions
|
| 251 |
+
- Slower inference due to attention computation
|
| 252 |
+
|
| 253 |
+
*Error Predictor (Two-Stage)*
|
| 254 |
+
|
| 255 |
+
Second-stage network that corrects base model errors:
|
| 256 |
+
|
| 257 |
+
Stage 1: Input (3) ──► Base Model ──► Base Prediction (4)
|
| 258 |
+
Stage 2: [Input (3), Base Prediction (4)] ──► Error Predictor ──► Error Correction (4)
|
| 259 |
+
Final: Base Prediction + Error Correction = Final Output
|
| 260 |
+
|
| 261 |
+
- Learns residual corrections for each component
|
| 262 |
+
- Can have unified (1 network) or multi (4 networks) error predictors
|
| 263 |
+
- Consistently improves accuracy across all base architectures
|
| 264 |
+
- Best results: Multi-ResNet + Multi-Error Predictor (Delta-E 0.52)
|
| 265 |
+
|
| 266 |
+
**Loss-Metric Mismatch**
|
| 267 |
+
|
| 268 |
+
An important finding: **optimizing MSE does not optimize Delta-E**.
|
| 269 |
+
|
| 270 |
+
The Optuna hyperparameter search minimized validation MSE, but the best MSE configuration did not achieve the best Delta-E. This is because:
|
| 271 |
+
|
| 272 |
+
- MSE treats all component errors equally
|
| 273 |
+
- Delta-E (CIE2000) weights errors based on human perception
|
| 274 |
+
- The precision-focused loss with custom weights better approximates perceptual importance
|
| 275 |
+
|
| 276 |
+
**Weighted Boundary Loss (Experimental)**
|
| 277 |
+
|
| 278 |
+
Analysis of model errors revealed systematic underperformance on Y/GY/G hues (Yellow/Green-Yellow/Green) with high value and chroma. The weighted boundary loss approach was explored to address this by:
|
| 279 |
+
|
| 280 |
+
1. Applying 3x loss weight to samples in challenging regions:
|
| 281 |
+
- Hue: 0.18-0.35 (normalized range covering Y/YG/G)
|
| 282 |
+
- Value > 0.7 (high brightness)
|
| 283 |
+
- Chroma > 0.5 (high saturation)
|
| 284 |
+
2. Adding boundary penalty to prevent predictions exceeding Munsell gamut limits
|
| 285 |
+
|
| 286 |
+
**Finding**: The large dataset approach (~1.4M samples with dense boundary sampling) naturally provides sufficient coverage of these challenging regions. Both the weighted boundary loss model (Multi-MLP W+B + Multi-Error Predictor W+B Large, Delta-E 0.524) and the standard large dataset model (Multi-MLP + Multi-Error Predictor Large, Delta-E 0.525) achieve nearly identical results, making explicit loss weighting optional. The best overall model is Multi-ResNet + Multi-Error Predictor (Large Dataset) with Delta-E 0.52.
|
| 287 |
+
|
| 288 |
+
### Experimental Findings
|
| 289 |
+
|
| 290 |
+
The following experiments were conducted but did not improve results:
|
| 291 |
+
|
| 292 |
+
**Delta-E Training**
|
| 293 |
+
|
| 294 |
+
Training with differentiable Delta-E CIE2000 loss via round-trip through the Munsell-to-xyY approximator.
|
| 295 |
+
|
| 296 |
+
*Hypothesis*: Perceptual Delta-E loss might outperform MSE-trained models.
|
| 297 |
+
|
| 298 |
+
*Implementation*: JAX/Flax model with combined MSE + Delta-E loss. Requires lower learning rate (1e-4 vs 3e-4) for stability; higher rates cause NaN gradients.
|
| 299 |
+
|
| 300 |
+
*Results*: While Delta-E is comparable, **hue accuracy is ~10x worse**:
|
| 301 |
+
|
| 302 |
+
| Metric (Normalized MAE) | Delta-E Model | MSE Model |
|
| 303 |
+
|--------------------------|---------------|-----------|
|
| 304 |
+
| Hue MAE | 0.30 | 0.03 |
|
| 305 |
+
| Value MAE | 0.002 | 0.004 |
|
| 306 |
+
| Chroma MAE | 0.007 | 0.008 |
|
| 307 |
+
| Code MAE | 0.07 | 0.01 |
|
| 308 |
+
| **Delta-E (perceptual)** | **0.52** | **0.50** |
|
| 309 |
+
|
| 310 |
+
*Key Takeaway*: **Perceptual similarity != specification accuracy**. The MSE model's slightly better Delta-E (0.50 vs 0.52) comes at the cost of ~10x worse hue accuracy, making it unsuitable for specification prediction. Delta-E is too permissive for hue, allowing the model to find "shortcuts" that minimize perceptual difference without correctly predicting the Munsell specification.
|
| 311 |
+
|
| 312 |
+
**Classical Interpolation**
|
| 313 |
+
|
| 314 |
+
Classical interpolation methods were tested on 4,995 reference Munsell colors (80% train / 20% test split). ML evaluated on 2,734 REAL Munsell colors.
|
| 315 |
+
|
| 316 |
+
*Results (Validation MAE)*:
|
| 317 |
+
|
| 318 |
+
| Component | RBF | KD-Tree | Delaunay | ML (Best) |
|
| 319 |
+
|-----------|------|---------|----------|-----------|
|
| 320 |
+
| Hue | 1.40 | 1.40 | 1.29 | **0.03** |
|
| 321 |
+
| Value | 0.01 | 0.10 | 0.02 | 0.05 |
|
| 322 |
+
| Chroma | 0.22 | 0.99 | 0.35 | **0.11** |
|
| 323 |
+
| Code | 0.33 | 0.28 | 0.28 | **0.00** |
|
| 324 |
+
|
| 325 |
+
*Key Insight*: The reference dataset (4,995 colors) is too sparse for 3D xyY interpolation. Classical methods fail on hue prediction (MAE ~1.3-1.4), while ML achieves 47x better hue accuracy and 2-3x better chroma/code accuracy.
|
| 326 |
+
|
| 327 |
+
**Circular Hue Loss**
|
| 328 |
+
|
| 329 |
+
Circular distance metrics for hue prediction, accounting for cyclic nature (0-10 wraps).
|
| 330 |
+
|
| 331 |
+
*Results*: The circular loss model performed **21x worse** on hue MAE (5.14 vs 0.24).
|
| 332 |
+
|
| 333 |
+
*Key Takeaway*: **Mathematical correctness != training effectiveness**. The circular distance creates gradient discontinuities that harm optimization.
|
| 334 |
+
|
| 335 |
+
**REAL-Only Refinement**
|
| 336 |
+
|
| 337 |
+
Fine-tuning using only REAL Munsell colors (2,734) instead of ALL colors (4,995).
|
| 338 |
+
|
| 339 |
+
*Results*: Essentially identical performance (Delta-E 1.5233 vs 1.5191).
|
| 340 |
+
|
| 341 |
+
*Key Takeaway*: **Data quality is not the bottleneck**. Both REAL and extrapolated colors are sufficiently accurate.
|
| 342 |
+
|
| 343 |
+
**Gamma Normalization**
|
| 344 |
+
|
| 345 |
+
Gamma correction to the Y (luminance) channel during normalization.
|
| 346 |
+
|
| 347 |
+
*Results*: No consistent improvement across gamma values 1.0-3.0:
|
| 348 |
+
|
| 349 |
+
| Gamma | Median ΔE (± std) |
|
| 350 |
+
|----------------|-------------------|
|
| 351 |
+
| 1.0 (baseline) | 0.730 ± 0.054 |
|
| 352 |
+
| 2.5 (best) | 0.683 ± 0.132 |
|
| 353 |
+
|
| 354 |
+

|
| 355 |
+
|
| 356 |
+
*Key Takeaway*: **Gamma normalization does not provide consistent improvement**. Standard deviations overlap - differences are within noise.
|
| 357 |
+
|
| 358 |
+
## Munsell to xyY (to_xyY)
|
| 359 |
+
|
| 360 |
+
### Performance Benchmarks
|
| 361 |
+
|
| 362 |
+
Comprehensive comparison using all 2,734 REAL Munsell colors:
|
| 363 |
+
|
| 364 |
+
| Model | Delta-E | Speed (ms) |
|
| 365 |
+
|-----------------------------------------------|-------------|------------|
|
| 366 |
+
| Colour Library (Baseline) | 0.00 | 1.27 |
|
| 367 |
+
| **Multi-MLP (Optimized)** | **0.48** | 0.008 |
|
| 368 |
+
| Multi-MLP (Opt) + Multi-Error Predictor (Opt) | 0.48 | 0.025 |
|
| 369 |
+
| Multi-MLP + Multi-Error Predictor | 0.65 | 0.030 |
|
| 370 |
+
| Multi-MLP | 0.66 | 0.016 |
|
| 371 |
+
| Multi-MLP + Error Predictor | 0.67 | 0.018 |
|
| 372 |
+
| Multi-Head (Optimized) | 0.71 | 0.015 |
|
| 373 |
+
| Multi-Head | 0.78 | 0.008 |
|
| 374 |
+
| Multi-Head + Multi-Error Predictor | 1.11 | 0.028 |
|
| 375 |
+
| Simple MLP | 1.42 | **0.0008** |
|
| 376 |
+
|
| 377 |
+
**Best Models**:
|
| 378 |
+
|
| 379 |
+
- **Best Accuracy**: Multi-MLP (Optimized) - Delta-E 0.48
|
| 380 |
+
- **Fastest**: Simple MLP (0.0008 ms/sample) - 1,654x faster than Colour library
|
| 381 |
+
- **Best Balance**: Multi-MLP (Optimized) - 154x faster with Delta-E 0.48
|
| 382 |
+
|
| 383 |
+
### Model Architectures
|
| 384 |
+
|
| 385 |
+
9 architectures were evaluated for the Munsell to xyY direction:
|
| 386 |
+
|
| 387 |
+
**Single-Stage Models**
|
| 388 |
+
|
| 389 |
+
1. **Simple MLP** - Basic MLP network, 4 inputs to 3 outputs
|
| 390 |
+
2. **Multi-Head** - Shared encoder with 3 independent decoder heads (x, y, Y)
|
| 391 |
+
3. **Multi-Head (Optimized)** - Hyperparameter-optimized variant
|
| 392 |
+
4. **Multi-MLP** - 3 completely independent MLP branches
|
| 393 |
+
5. **Multi-MLP (Optimized)** - Hyperparameter-optimized variant (BEST)
|
| 394 |
+
|
| 395 |
+
**Two-Stage Models**
|
| 396 |
+
|
| 397 |
+
6. **Multi-MLP + Error Predictor** - Base Multi-MLP with unified error correction
|
| 398 |
+
7. **Multi-MLP + Multi-Error Predictor** - 3 independent error predictors
|
| 399 |
+
8. **Multi-MLP (Opt) + Multi-Error Predictor (Opt)** - Optimized two-stage
|
| 400 |
+
9. **Multi-Head + Multi-Error Predictor** - Multi-Head with error correction
|
| 401 |
+
|
| 402 |
+
The **Multi-MLP (Optimized)** architecture achieved the best results with Delta-E 0.48.
|
| 403 |
+
|
| 404 |
+
### Differentiable Approximator
|
| 405 |
+
|
| 406 |
+
A small MLP (68K parameters) trained to approximate the Munsell to xyY conversion for use in differentiable Delta-E loss:
|
| 407 |
+
|
| 408 |
+
- **Architecture**: 4 -> 128 -> 256 -> 128 -> 3 with LayerNorm + SiLU
|
| 409 |
+
- **Accuracy**: MAE ~0.0006 for x, y, and Y components
|
| 410 |
+
- **Output formats**: PyTorch (.pth), ONNX, and JAX-compatible weights (.npz)
|
| 411 |
+
|
| 412 |
+
This enables differentiable Munsell to xyY conversion, which was previously only possible through non-differentiable lookup tables.
|
| 413 |
+
|
| 414 |
+
## Shared Infrastructure
|
| 415 |
+
|
| 416 |
+
### Hyperparameter Optimization
|
| 417 |
+
|
| 418 |
+
Optuna was used for systematic hyperparameter search over:
|
| 419 |
+
|
| 420 |
+
- Learning rate (1e-4 to 1e-3)
|
| 421 |
+
- Batch size (256, 512, 1024)
|
| 422 |
+
- Dropout rate (0.0 to 0.2)
|
| 423 |
+
- Chroma branch width multiplier (1.0 to 2.0)
|
| 424 |
+
- Loss function weights (MSE, Huber)
|
| 425 |
+
|
| 426 |
+
Key finding: **No dropout (0.0)** consistently performed better across all models in both conversion directions, contrary to typical deep learning recommendations for regularization.
|
| 427 |
+
|
| 428 |
+
### Training Infrastructure
|
| 429 |
+
|
| 430 |
+
- **Optimizer**: AdamW with weight decay
|
| 431 |
+
- **Scheduler**: ReduceLROnPlateau (patience=10, factor=0.5)
|
| 432 |
+
- **Early stopping**: Patience=20 epochs
|
| 433 |
+
- **Checkpointing**: Best model saved based on validation loss
|
| 434 |
+
- **Logging**: MLflow for experiment tracking
|
| 435 |
+
|
| 436 |
+
### JAX Delta-E Implementation
|
| 437 |
+
|
| 438 |
+
Located in `learning_munsell/losses/jax_delta_e.py`:
|
| 439 |
+
|
| 440 |
+
- Differentiable xyY -> XYZ -> Lab color space conversions
|
| 441 |
+
- Full CIE 2000 Delta-E implementation with gradient support
|
| 442 |
+
- JIT-compiled functions for performance
|
| 443 |
+
|
| 444 |
+
Usage:
|
| 445 |
+
|
| 446 |
+
```python
|
| 447 |
+
from learning_munsell.losses import delta_E_loss, delta_E_CIE2000
|
| 448 |
+
|
| 449 |
+
# Compute perceptual loss between predicted and target xyY
|
| 450 |
+
loss = delta_E_loss(pred_xyY, target_xyY)
|
| 451 |
+
```
|
| 452 |
+
|
| 453 |
+
## Limitations
|
| 454 |
+
|
| 455 |
+
### BatchNorm Instability on MPS
|
| 456 |
+
|
| 457 |
+
Models using `BatchNorm1d` layers exhibit numerical instability when trained on Apple Silicon GPUs via the MPS backend:
|
| 458 |
+
|
| 459 |
+
1. **Validation loss spikes** during training
|
| 460 |
+
2. **Occasional extreme outputs** during inference (e.g., 20M instead of ~0.1)
|
| 461 |
+
3. **Non-reproducible behavior**
|
| 462 |
+
|
| 463 |
+
**Affected Models**: Large dataset error predictors using BatchNorm.
|
| 464 |
+
|
| 465 |
+
**Workarounds**:
|
| 466 |
+
|
| 467 |
+
1. Use CPU for training
|
| 468 |
+
2. Replace BatchNorm with LayerNorm
|
| 469 |
+
3. Use smaller models (300K samples vs 2M)
|
| 470 |
+
4. Skip error predictor stage for affected models
|
| 471 |
+
|
| 472 |
+
The recommended production model (`multi_resnet_error_predictor_large.onnx`) was trained on the large dataset and does not exhibit this instability.
|
| 473 |
+
|
| 474 |
+
**References**:
|
| 475 |
+
|
| 476 |
+
- [BatchNorm non-trainable exception](https://github.com/pytorch/pytorch/issues/98602)
|
| 477 |
+
- [ONNX export incorrect on MPS](https://github.com/pytorch/pytorch/issues/83230)
|
| 478 |
+
- [MPS kernel bugs](https://elanapearl.github.io/blog/2025/the-bug-that-taught-me-pytorch/)
|
learning_munsell/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Learning Munsell - Machine Learning for Munsell Color Conversions."""
|
| 2 |
+
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
__all__ = ["PROJECT_ROOT"]
|
| 6 |
+
|
| 7 |
+
PROJECT_ROOT = Path(__file__).parent.parent
|
learning_munsell/analysis/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Analysis utilities for Munsell color conversion models."""
|
learning_munsell/analysis/error_analysis.py
ADDED
|
@@ -0,0 +1,304 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Analyze error distribution to identify problematic regions in Munsell space.
|
| 3 |
+
|
| 4 |
+
This script:
|
| 5 |
+
1. Runs the best model on all REAL Munsell colors
|
| 6 |
+
2. Computes Delta-E for each sample
|
| 7 |
+
3. Identifies samples with high error (Delta-E > threshold)
|
| 8 |
+
4. Analyzes patterns: which hue families, value ranges, chroma ranges have issues
|
| 9 |
+
5. Outputs statistics and visualizations
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import logging
|
| 13 |
+
from collections import defaultdict
|
| 14 |
+
|
| 15 |
+
import numpy as np
|
| 16 |
+
import onnxruntime as ort
|
| 17 |
+
from colour import XYZ_to_Lab, xyY_to_XYZ
|
| 18 |
+
from colour.difference import delta_E_CIE2000
|
| 19 |
+
from colour.notation.datasets.munsell import MUNSELL_COLOURS_REAL
|
| 20 |
+
from colour.notation.munsell import (
|
| 21 |
+
CCS_ILLUMINANT_MUNSELL,
|
| 22 |
+
munsell_colour_to_munsell_specification,
|
| 23 |
+
munsell_specification_to_xyY,
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
from learning_munsell import PROJECT_ROOT
|
| 27 |
+
|
| 28 |
+
logging.basicConfig(level=logging.INFO, format="%(message)s")
|
| 29 |
+
LOGGER = logging.getLogger(__name__)
|
| 30 |
+
|
| 31 |
+
HUE_NAMES = {
|
| 32 |
+
1: "R",
|
| 33 |
+
2: "YR",
|
| 34 |
+
3: "Y",
|
| 35 |
+
4: "GY",
|
| 36 |
+
5: "G",
|
| 37 |
+
6: "BG",
|
| 38 |
+
7: "B",
|
| 39 |
+
8: "PB",
|
| 40 |
+
9: "P",
|
| 41 |
+
10: "RP",
|
| 42 |
+
0: "RP",
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def load_model_and_params(model_name: str):
|
| 47 |
+
"""Load ONNX model and normalization parameters."""
|
| 48 |
+
model_dir = PROJECT_ROOT / "models" / "from_xyY"
|
| 49 |
+
|
| 50 |
+
model_path = model_dir / f"{model_name}.onnx"
|
| 51 |
+
params_path = model_dir / f"{model_name}_normalization_params.npz"
|
| 52 |
+
|
| 53 |
+
if not model_path.exists():
|
| 54 |
+
raise FileNotFoundError(f"Model not found: {model_path}")
|
| 55 |
+
if not params_path.exists():
|
| 56 |
+
raise FileNotFoundError(f"Params not found: {params_path}")
|
| 57 |
+
|
| 58 |
+
session = ort.InferenceSession(str(model_path))
|
| 59 |
+
params = np.load(params_path, allow_pickle=True)
|
| 60 |
+
input_params = params["input_params"].item()
|
| 61 |
+
output_params = params["output_params"].item()
|
| 62 |
+
|
| 63 |
+
return session, input_params, output_params
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def normalize_input(xyY: np.ndarray, params: dict) -> np.ndarray:
|
| 67 |
+
"""Normalize xyY input."""
|
| 68 |
+
normalized = np.copy(xyY).astype(np.float32)
|
| 69 |
+
# Scale Y from 0-100 to 0-1 range before normalization
|
| 70 |
+
normalized[..., 2] = xyY[..., 2] / 100.0
|
| 71 |
+
normalized[..., 0] = (xyY[..., 0] - params["x_range"][0]) / (
|
| 72 |
+
params["x_range"][1] - params["x_range"][0]
|
| 73 |
+
)
|
| 74 |
+
normalized[..., 1] = (xyY[..., 1] - params["y_range"][0]) / (
|
| 75 |
+
params["y_range"][1] - params["y_range"][0]
|
| 76 |
+
)
|
| 77 |
+
normalized[..., 2] = (normalized[..., 2] - params["Y_range"][0]) / (
|
| 78 |
+
params["Y_range"][1] - params["Y_range"][0]
|
| 79 |
+
)
|
| 80 |
+
return normalized
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def denormalize_output(pred: np.ndarray, params: dict) -> np.ndarray:
|
| 84 |
+
"""Denormalize Munsell output."""
|
| 85 |
+
denorm = np.copy(pred)
|
| 86 |
+
denorm[..., 0] = (
|
| 87 |
+
pred[..., 0] * (params["hue_range"][1] - params["hue_range"][0])
|
| 88 |
+
+ params["hue_range"][0]
|
| 89 |
+
)
|
| 90 |
+
denorm[..., 1] = (
|
| 91 |
+
pred[..., 1] * (params["value_range"][1] - params["value_range"][0])
|
| 92 |
+
+ params["value_range"][0]
|
| 93 |
+
)
|
| 94 |
+
denorm[..., 2] = (
|
| 95 |
+
pred[..., 2] * (params["chroma_range"][1] - params["chroma_range"][0])
|
| 96 |
+
+ params["chroma_range"][0]
|
| 97 |
+
)
|
| 98 |
+
denorm[..., 3] = (
|
| 99 |
+
pred[..., 3] * (params["code_range"][1] - params["code_range"][0])
|
| 100 |
+
+ params["code_range"][0]
|
| 101 |
+
)
|
| 102 |
+
return denorm
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def compute_delta_e(pred_spec: np.ndarray, gt_xyY: np.ndarray) -> float:
|
| 106 |
+
"""Compute Delta-E between predicted spec (via xyY) and ground truth xyY."""
|
| 107 |
+
try:
|
| 108 |
+
pred_xyY = munsell_specification_to_xyY(pred_spec)
|
| 109 |
+
pred_XYZ = xyY_to_XYZ(pred_xyY)
|
| 110 |
+
pred_Lab = XYZ_to_Lab(pred_XYZ, CCS_ILLUMINANT_MUNSELL)
|
| 111 |
+
|
| 112 |
+
# Ground truth Y is in 0-100 range, need to scale to 0-1
|
| 113 |
+
gt_xyY_scaled = gt_xyY.copy()
|
| 114 |
+
gt_xyY_scaled[2] = gt_xyY[2] / 100.0
|
| 115 |
+
gt_XYZ = xyY_to_XYZ(gt_xyY_scaled)
|
| 116 |
+
gt_Lab = XYZ_to_Lab(gt_XYZ, CCS_ILLUMINANT_MUNSELL)
|
| 117 |
+
|
| 118 |
+
return delta_E_CIE2000(gt_Lab, pred_Lab)
|
| 119 |
+
except Exception:
|
| 120 |
+
return np.nan
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def analyze_errors(model_name: str = "multi_head_large", threshold: float = 3.0):
|
| 124 |
+
"""Analyze error distribution for a model."""
|
| 125 |
+
LOGGER.info("=" * 80)
|
| 126 |
+
LOGGER.info("Error Analysis for %s", model_name)
|
| 127 |
+
LOGGER.info("=" * 80)
|
| 128 |
+
|
| 129 |
+
# Load model
|
| 130 |
+
session, input_params, output_params = load_model_and_params(model_name)
|
| 131 |
+
input_name = session.get_inputs()[0].name
|
| 132 |
+
|
| 133 |
+
# Collect data
|
| 134 |
+
results = []
|
| 135 |
+
|
| 136 |
+
for munsell_spec_tuple, xyY_gt in MUNSELL_COLOURS_REAL:
|
| 137 |
+
hue_code_str, value, chroma = munsell_spec_tuple
|
| 138 |
+
munsell_str = f"{hue_code_str} {value}/{chroma}"
|
| 139 |
+
|
| 140 |
+
try:
|
| 141 |
+
gt_spec = munsell_colour_to_munsell_specification(munsell_str)
|
| 142 |
+
gt_xyY = np.array(xyY_gt)
|
| 143 |
+
|
| 144 |
+
# Predict
|
| 145 |
+
xyY_norm = normalize_input(gt_xyY.reshape(1, 3), input_params)
|
| 146 |
+
pred_norm = session.run(None, {input_name: xyY_norm})[0]
|
| 147 |
+
pred_spec = denormalize_output(pred_norm, output_params)[0]
|
| 148 |
+
|
| 149 |
+
# Clamp to valid ranges
|
| 150 |
+
pred_spec[0] = np.clip(pred_spec[0], 0.5, 10.0)
|
| 151 |
+
pred_spec[1] = np.clip(pred_spec[1], 1.0, 9.0)
|
| 152 |
+
pred_spec[2] = np.clip(pred_spec[2], 0.0, 50.0)
|
| 153 |
+
pred_spec[3] = np.clip(pred_spec[3], 1.0, 10.0)
|
| 154 |
+
pred_spec[3] = np.round(pred_spec[3])
|
| 155 |
+
|
| 156 |
+
# Compute Delta-E
|
| 157 |
+
delta_e = compute_delta_e(pred_spec, gt_xyY)
|
| 158 |
+
|
| 159 |
+
if not np.isnan(delta_e):
|
| 160 |
+
results.append({
|
| 161 |
+
"munsell_str": munsell_str,
|
| 162 |
+
"gt_spec": gt_spec,
|
| 163 |
+
"pred_spec": pred_spec,
|
| 164 |
+
"delta_e": delta_e,
|
| 165 |
+
"hue": gt_spec[0],
|
| 166 |
+
"value": gt_spec[1],
|
| 167 |
+
"chroma": gt_spec[2],
|
| 168 |
+
"code": int(gt_spec[3]),
|
| 169 |
+
"gt_xyY": gt_xyY,
|
| 170 |
+
})
|
| 171 |
+
except Exception as e:
|
| 172 |
+
LOGGER.warning("Failed for %s: %s", munsell_str, e)
|
| 173 |
+
|
| 174 |
+
LOGGER.info("\nTotal samples evaluated: %d", len(results))
|
| 175 |
+
|
| 176 |
+
# Overall statistics
|
| 177 |
+
delta_es = [r["delta_e"] for r in results]
|
| 178 |
+
LOGGER.info("\nOverall Delta-E Statistics:")
|
| 179 |
+
LOGGER.info(" Mean: %.4f", np.mean(delta_es))
|
| 180 |
+
LOGGER.info(" Median: %.4f", np.median(delta_es))
|
| 181 |
+
LOGGER.info(" Std: %.4f", np.std(delta_es))
|
| 182 |
+
LOGGER.info(" Min: %.4f", np.min(delta_es))
|
| 183 |
+
LOGGER.info(" Max: %.4f", np.max(delta_es))
|
| 184 |
+
|
| 185 |
+
# Distribution
|
| 186 |
+
LOGGER.info("\nDelta-E Distribution:")
|
| 187 |
+
for thresh in [1.0, 2.0, 3.0, 5.0, 10.0]:
|
| 188 |
+
count = sum(1 for d in delta_es if d <= thresh)
|
| 189 |
+
pct = 100 * count / len(delta_es)
|
| 190 |
+
LOGGER.info(" <= %.1f: %4d (%.1f%%)", thresh, count, pct)
|
| 191 |
+
|
| 192 |
+
# High error samples
|
| 193 |
+
high_error = [r for r in results if r["delta_e"] > threshold]
|
| 194 |
+
LOGGER.info("\nSamples with Delta-E > %.1f: %d (%.1f%%)",
|
| 195 |
+
threshold, len(high_error), 100 * len(high_error) / len(results))
|
| 196 |
+
|
| 197 |
+
# Analyze by hue family
|
| 198 |
+
LOGGER.info("\n" + "=" * 40)
|
| 199 |
+
LOGGER.info("Analysis by Hue Family")
|
| 200 |
+
LOGGER.info("=" * 40)
|
| 201 |
+
|
| 202 |
+
by_hue = defaultdict(list)
|
| 203 |
+
for r in results:
|
| 204 |
+
hue_name = HUE_NAMES.get(r["code"], f"?{r['code']}")
|
| 205 |
+
by_hue[hue_name].append(r["delta_e"])
|
| 206 |
+
|
| 207 |
+
LOGGER.info("\n%-4s %5s %6s %6s %6s %s",
|
| 208 |
+
"Hue", "Count", "Mean", "Median", "Max", ">3.0")
|
| 209 |
+
for hue_name in ["R", "YR", "Y", "GY", "G", "BG", "B", "PB", "P", "RP"]:
|
| 210 |
+
if hue_name in by_hue:
|
| 211 |
+
des = by_hue[hue_name]
|
| 212 |
+
high = sum(1 for d in des if d > 3.0)
|
| 213 |
+
LOGGER.info("%-4s %5d %6.2f %6.2f %6.2f %d (%.0f%%)",
|
| 214 |
+
hue_name, len(des), np.mean(des), np.median(des),
|
| 215 |
+
np.max(des), high, 100*high/len(des))
|
| 216 |
+
|
| 217 |
+
# Analyze by value range
|
| 218 |
+
LOGGER.info("\n" + "=" * 40)
|
| 219 |
+
LOGGER.info("Analysis by Value Range")
|
| 220 |
+
LOGGER.info("=" * 40)
|
| 221 |
+
|
| 222 |
+
value_ranges = [(1, 3), (3, 5), (5, 7), (7, 9)]
|
| 223 |
+
LOGGER.info("\n%-8s %5s %6s %6s %6s %s",
|
| 224 |
+
"Value", "Count", "Mean", "Median", "Max", ">3.0")
|
| 225 |
+
for v_min, v_max in value_ranges:
|
| 226 |
+
des = [r["delta_e"] for r in results if v_min <= r["value"] < v_max]
|
| 227 |
+
if des:
|
| 228 |
+
high = sum(1 for d in des if d > 3.0)
|
| 229 |
+
LOGGER.info("[%d-%d) %5d %6.2f %6.2f %6.2f %d (%.0f%%)",
|
| 230 |
+
v_min, v_max, len(des), np.mean(des), np.median(des),
|
| 231 |
+
np.max(des), high, 100*high/len(des) if des else 0)
|
| 232 |
+
|
| 233 |
+
# Analyze by chroma range
|
| 234 |
+
LOGGER.info("\n" + "=" * 40)
|
| 235 |
+
LOGGER.info("Analysis by Chroma Range")
|
| 236 |
+
LOGGER.info("=" * 40)
|
| 237 |
+
|
| 238 |
+
chroma_ranges = [(0, 4), (4, 8), (8, 12), (12, 20), (20, 50)]
|
| 239 |
+
LOGGER.info("\n%-8s %5s %6s %6s %6s %s",
|
| 240 |
+
"Chroma", "Count", "Mean", "Median", "Max", ">3.0")
|
| 241 |
+
for c_min, c_max in chroma_ranges:
|
| 242 |
+
des = [r["delta_e"] for r in results if c_min <= r["chroma"] < c_max]
|
| 243 |
+
if des:
|
| 244 |
+
high = sum(1 for d in des if d > 3.0)
|
| 245 |
+
LOGGER.info("[%2d-%2d) %5d %6.2f %6.2f %6.2f %d (%.0f%%)",
|
| 246 |
+
c_min, c_max, len(des), np.mean(des), np.median(des),
|
| 247 |
+
np.max(des), high, 100*high/len(des) if des else 0)
|
| 248 |
+
|
| 249 |
+
# Top 20 worst samples
|
| 250 |
+
LOGGER.info("\n" + "=" * 40)
|
| 251 |
+
LOGGER.info("Top 20 Worst Samples")
|
| 252 |
+
LOGGER.info("=" * 40)
|
| 253 |
+
|
| 254 |
+
worst = sorted(results, key=lambda r: r["delta_e"], reverse=True)[:20]
|
| 255 |
+
LOGGER.info("\n%-15s %6s %-20s %-20s",
|
| 256 |
+
"Munsell", "DeltaE", "GT Spec", "Pred Spec")
|
| 257 |
+
for r in worst:
|
| 258 |
+
gt = f"[{r['gt_spec'][0]:.1f}, {r['gt_spec'][1]:.1f}, {r['gt_spec'][2]:.1f}, {int(r['gt_spec'][3])}]"
|
| 259 |
+
pred = f"[{r['pred_spec'][0]:.1f}, {r['pred_spec'][1]:.1f}, {r['pred_spec'][2]:.1f}, {int(r['pred_spec'][3])}]"
|
| 260 |
+
LOGGER.info("%-15s %6.2f %-20s %-20s",
|
| 261 |
+
r["munsell_str"], r["delta_e"], gt, pred)
|
| 262 |
+
|
| 263 |
+
# Analyze component errors for high-error samples
|
| 264 |
+
LOGGER.info("\n" + "=" * 40)
|
| 265 |
+
LOGGER.info("Component Errors for High-Error Samples (Delta-E > %.1f)", threshold)
|
| 266 |
+
LOGGER.info("=" * 40)
|
| 267 |
+
|
| 268 |
+
if high_error:
|
| 269 |
+
hue_errors = [abs(r["pred_spec"][0] - r["gt_spec"][0]) for r in high_error]
|
| 270 |
+
value_errors = [abs(r["pred_spec"][1] - r["gt_spec"][1]) for r in high_error]
|
| 271 |
+
chroma_errors = [abs(r["pred_spec"][2] - r["gt_spec"][2]) for r in high_error]
|
| 272 |
+
code_errors = [abs(r["pred_spec"][3] - r["gt_spec"][3]) for r in high_error]
|
| 273 |
+
|
| 274 |
+
LOGGER.info("\n%-10s %6s %6s %6s",
|
| 275 |
+
"Component", "Mean", "Median", "Max")
|
| 276 |
+
LOGGER.info("%-10s %6.2f %6.2f %6.2f", "Hue",
|
| 277 |
+
np.mean(hue_errors), np.median(hue_errors), np.max(hue_errors))
|
| 278 |
+
LOGGER.info("%-10s %6.2f %6.2f %6.2f", "Value",
|
| 279 |
+
np.mean(value_errors), np.median(value_errors), np.max(value_errors))
|
| 280 |
+
LOGGER.info("%-10s %6.2f %6.2f %6.2f", "Chroma",
|
| 281 |
+
np.mean(chroma_errors), np.median(chroma_errors), np.max(chroma_errors))
|
| 282 |
+
LOGGER.info("%-10s %6.2f %6.2f %6.2f", "Code",
|
| 283 |
+
np.mean(code_errors), np.median(code_errors), np.max(code_errors))
|
| 284 |
+
|
| 285 |
+
return results
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
def main():
|
| 289 |
+
"""Run error analysis."""
|
| 290 |
+
# Try the best models
|
| 291 |
+
models = [
|
| 292 |
+
"multi_head_large",
|
| 293 |
+
]
|
| 294 |
+
|
| 295 |
+
for model_name in models:
|
| 296 |
+
try:
|
| 297 |
+
analyze_errors(model_name, threshold=3.0)
|
| 298 |
+
except FileNotFoundError as e:
|
| 299 |
+
LOGGER.warning("Skipping %s: %s", model_name, e)
|
| 300 |
+
LOGGER.info("\n")
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
if __name__ == "__main__":
|
| 304 |
+
main()
|
learning_munsell/comparison/from_xyY/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Comparison scripts for xyY to Munsell conversion models."""
|
learning_munsell/comparison/from_xyY/compare_all_models.py
ADDED
|
@@ -0,0 +1,1292 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Compare all ML models for xyY to Munsell conversion on real Munsell data.
|
| 3 |
+
|
| 4 |
+
Models to compare:
|
| 5 |
+
1. MLP (Base only)
|
| 6 |
+
2. MLP + Error Predictor (Two-stage)
|
| 7 |
+
3. Unified MLP
|
| 8 |
+
4. MLP + Self-Attention
|
| 9 |
+
5. MLP + Self-Attention + Error Predictor
|
| 10 |
+
6. Deep + Wide
|
| 11 |
+
7. Mixture of Experts
|
| 12 |
+
8. FT-Transformer
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import logging
|
| 16 |
+
import time
|
| 17 |
+
import warnings
|
| 18 |
+
from datetime import datetime
|
| 19 |
+
from pathlib import Path
|
| 20 |
+
from typing import Any
|
| 21 |
+
|
| 22 |
+
import numpy as np
|
| 23 |
+
import onnxruntime as ort
|
| 24 |
+
from colour import XYZ_to_Lab, xyY_to_XYZ
|
| 25 |
+
from colour.difference import delta_E_CIE2000
|
| 26 |
+
from colour.notation.datasets.munsell import MUNSELL_COLOURS_REAL
|
| 27 |
+
from colour.notation.munsell import (
|
| 28 |
+
CCS_ILLUMINANT_MUNSELL,
|
| 29 |
+
munsell_colour_to_munsell_specification,
|
| 30 |
+
munsell_specification_to_xyY,
|
| 31 |
+
xyY_to_munsell_specification,
|
| 32 |
+
)
|
| 33 |
+
from numpy.typing import NDArray
|
| 34 |
+
|
| 35 |
+
from learning_munsell import PROJECT_ROOT
|
| 36 |
+
from learning_munsell.utilities.common import (
|
| 37 |
+
benchmark_inference_speed,
|
| 38 |
+
get_model_size_mb,
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
logging.basicConfig(level=logging.INFO, format="%(message)s")
|
| 42 |
+
LOGGER = logging.getLogger(__name__)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def normalize_input(X: NDArray, params: dict[str, Any] | None) -> NDArray:
|
| 46 |
+
"""Normalize xyY input.
|
| 47 |
+
|
| 48 |
+
If params is None, xyY is assumed to already be in [0, 1] range (no normalization needed).
|
| 49 |
+
"""
|
| 50 |
+
if params is None:
|
| 51 |
+
# xyY is already in [0, 1] range - no normalization needed
|
| 52 |
+
return X.astype(np.float32)
|
| 53 |
+
|
| 54 |
+
X_norm = np.copy(X)
|
| 55 |
+
X_norm[..., 0] = (X[..., 0] - params["x_range"][0]) / (
|
| 56 |
+
params["x_range"][1] - params["x_range"][0]
|
| 57 |
+
)
|
| 58 |
+
X_norm[..., 1] = (X[..., 1] - params["y_range"][0]) / (
|
| 59 |
+
params["y_range"][1] - params["y_range"][0]
|
| 60 |
+
)
|
| 61 |
+
X_norm[..., 2] = (X[..., 2] - params["Y_range"][0]) / (
|
| 62 |
+
params["Y_range"][1] - params["Y_range"][0]
|
| 63 |
+
)
|
| 64 |
+
return X_norm.astype(np.float32)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def denormalize_output(y_norm: NDArray, params: dict[str, Any]) -> NDArray:
|
| 68 |
+
"""Denormalize Munsell output."""
|
| 69 |
+
y = np.copy(y_norm)
|
| 70 |
+
y[..., 0] = (
|
| 71 |
+
y_norm[..., 0] * (params["hue_range"][1] - params["hue_range"][0])
|
| 72 |
+
+ params["hue_range"][0]
|
| 73 |
+
)
|
| 74 |
+
y[..., 1] = (
|
| 75 |
+
y_norm[..., 1] * (params["value_range"][1] - params["value_range"][0])
|
| 76 |
+
+ params["value_range"][0]
|
| 77 |
+
)
|
| 78 |
+
y[..., 2] = (
|
| 79 |
+
y_norm[..., 2] * (params["chroma_range"][1] - params["chroma_range"][0])
|
| 80 |
+
+ params["chroma_range"][0]
|
| 81 |
+
)
|
| 82 |
+
y[..., 3] = (
|
| 83 |
+
y_norm[..., 3] * (params["code_range"][1] - params["code_range"][0])
|
| 84 |
+
+ params["code_range"][0]
|
| 85 |
+
)
|
| 86 |
+
return y
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def clamp_munsell_specification(specification: NDArray) -> NDArray:
|
| 90 |
+
"""Clamp Munsell specification to valid ranges."""
|
| 91 |
+
|
| 92 |
+
clamped = np.copy(specification)
|
| 93 |
+
clamped[..., 0] = np.clip(specification[..., 0], 0.0, 10.0) # Hue: [0, 10]
|
| 94 |
+
clamped[..., 1] = np.clip(specification[..., 1], 1.0, 9.0) # Value: [1, 9] (colour library constraint)
|
| 95 |
+
clamped[..., 2] = np.clip(specification[..., 2], 0.0, 50.0) # Chroma: [0, 50]
|
| 96 |
+
clamped[..., 3] = np.clip(specification[..., 3], 1.0, 10.0) # Code: [1, 10]
|
| 97 |
+
|
| 98 |
+
return clamped
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def evaluate_model(
|
| 102 |
+
session: ort.InferenceSession,
|
| 103 |
+
X_norm: NDArray,
|
| 104 |
+
ground_truth: NDArray,
|
| 105 |
+
params: dict[str, Any],
|
| 106 |
+
input_name: str = "xyY",
|
| 107 |
+
reference_Lab: NDArray | None = None,
|
| 108 |
+
) -> dict[str, Any]:
|
| 109 |
+
"""Evaluate a single model."""
|
| 110 |
+
pred_norm = session.run(None, {input_name: X_norm})[0]
|
| 111 |
+
pred = denormalize_output(pred_norm, params)
|
| 112 |
+
errors = np.abs(pred - ground_truth)
|
| 113 |
+
|
| 114 |
+
result = {
|
| 115 |
+
"hue_mae": np.mean(errors[:, 0]),
|
| 116 |
+
"value_mae": np.mean(errors[:, 1]),
|
| 117 |
+
"chroma_mae": np.mean(errors[:, 2]),
|
| 118 |
+
"code_mae": np.mean(errors[:, 3]),
|
| 119 |
+
"max_errors": np.max(errors, axis=1),
|
| 120 |
+
"hue_errors": errors[:, 0],
|
| 121 |
+
"value_errors": errors[:, 1],
|
| 122 |
+
"chroma_errors": errors[:, 2],
|
| 123 |
+
"code_errors": errors[:, 3],
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
# Compute Delta-E against ground truth
|
| 127 |
+
if reference_Lab is not None:
|
| 128 |
+
delta_E_values = []
|
| 129 |
+
for idx in range(len(pred)):
|
| 130 |
+
try:
|
| 131 |
+
# Convert ML prediction to Lab: Munsell spec → xyY → XYZ → Lab
|
| 132 |
+
ml_spec = clamp_munsell_specification(pred[idx])
|
| 133 |
+
|
| 134 |
+
# Round Code to nearest integer before round-trip conversion
|
| 135 |
+
ml_spec_for_conversion = ml_spec.copy()
|
| 136 |
+
ml_spec_for_conversion[3] = round(ml_spec[3])
|
| 137 |
+
|
| 138 |
+
ml_xyy = munsell_specification_to_xyY(ml_spec_for_conversion)
|
| 139 |
+
ml_XYZ = xyY_to_XYZ(ml_xyy)
|
| 140 |
+
ml_Lab = XYZ_to_Lab(ml_XYZ, CCS_ILLUMINANT_MUNSELL)
|
| 141 |
+
|
| 142 |
+
delta_E = delta_E_CIE2000(reference_Lab[idx], ml_Lab)
|
| 143 |
+
delta_E_values.append(delta_E)
|
| 144 |
+
except (RuntimeError, ValueError):
|
| 145 |
+
# Skip if conversion fails
|
| 146 |
+
continue
|
| 147 |
+
|
| 148 |
+
result["delta_E"] = np.mean(delta_E_values) if delta_E_values else np.nan
|
| 149 |
+
else:
|
| 150 |
+
result["delta_E"] = np.nan
|
| 151 |
+
|
| 152 |
+
return result
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def generate_html_report(
|
| 156 |
+
results: dict[str, dict[str, Any]],
|
| 157 |
+
num_samples: int,
|
| 158 |
+
output_file: Path,
|
| 159 |
+
baseline_inference_time_ms: float | None = None,
|
| 160 |
+
) -> None:
|
| 161 |
+
"""Generate HTML report with visualizations."""
|
| 162 |
+
# Calculate metrics
|
| 163 |
+
avg_maes = {}
|
| 164 |
+
for model_name, result in results.items():
|
| 165 |
+
avg_maes[model_name] = np.mean(
|
| 166 |
+
[
|
| 167 |
+
result["hue_mae"],
|
| 168 |
+
result["value_mae"],
|
| 169 |
+
result["chroma_mae"],
|
| 170 |
+
result["code_mae"],
|
| 171 |
+
]
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
# Sort by average MAE
|
| 175 |
+
sorted_models = sorted(avg_maes.items(), key=lambda x: x[1])
|
| 176 |
+
|
| 177 |
+
# Precision thresholds
|
| 178 |
+
thresholds = [1e-4, 1e-3, 1e-2, 1e-1, 1.0]
|
| 179 |
+
|
| 180 |
+
html = f"""<!DOCTYPE html>
|
| 181 |
+
<html lang="en" class="dark">
|
| 182 |
+
<head>
|
| 183 |
+
<meta charset="UTF-8">
|
| 184 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
| 185 |
+
<title>ML Model Comparison Report - {datetime.now().strftime("%Y-%m-%d %H:%M")}</title>
|
| 186 |
+
<script src="https://cdn.tailwindcss.com"></script>
|
| 187 |
+
<script>
|
| 188 |
+
tailwind.config = {{
|
| 189 |
+
darkMode: 'class',
|
| 190 |
+
theme: {{
|
| 191 |
+
extend: {{
|
| 192 |
+
colors: {{
|
| 193 |
+
border: "hsl(240 3.7% 15.9%)",
|
| 194 |
+
input: "hsl(240 3.7% 15.9%)",
|
| 195 |
+
ring: "hsl(240 4.9% 83.9%)",
|
| 196 |
+
background: "hsl(240 10% 3.9%)",
|
| 197 |
+
foreground: "hsl(0 0% 98%)",
|
| 198 |
+
primary: {{
|
| 199 |
+
DEFAULT: "hsl(263 70% 60%)",
|
| 200 |
+
foreground: "hsl(0 0% 98%)",
|
| 201 |
+
}},
|
| 202 |
+
secondary: {{
|
| 203 |
+
DEFAULT: "hsl(240 3.7% 15.9%)",
|
| 204 |
+
foreground: "hsl(0 0% 98%)",
|
| 205 |
+
}},
|
| 206 |
+
muted: {{
|
| 207 |
+
DEFAULT: "hsl(240 3.7% 15.9%)",
|
| 208 |
+
foreground: "hsl(240 5% 64.9%)",
|
| 209 |
+
}},
|
| 210 |
+
accent: {{
|
| 211 |
+
DEFAULT: "hsl(240 3.7% 15.9%)",
|
| 212 |
+
foreground: "hsl(0 0% 98%)",
|
| 213 |
+
}},
|
| 214 |
+
card: {{
|
| 215 |
+
DEFAULT: "hsl(240 10% 6%)",
|
| 216 |
+
foreground: "hsl(0 0% 98%)",
|
| 217 |
+
}},
|
| 218 |
+
}}
|
| 219 |
+
}}
|
| 220 |
+
}}
|
| 221 |
+
}}
|
| 222 |
+
</script>
|
| 223 |
+
<style>
|
| 224 |
+
.gradient-primary {{
|
| 225 |
+
background: linear-gradient(135deg, hsl(263 70% 50%) 0%, hsl(280 70% 45%) 100%);
|
| 226 |
+
}}
|
| 227 |
+
.bar-fill {{
|
| 228 |
+
background: linear-gradient(90deg, hsl(263 70% 60%) 0%, hsl(280 70% 55%) 100%);
|
| 229 |
+
transition: width 0.5s cubic-bezier(0.4, 0, 0.2, 1);
|
| 230 |
+
}}
|
| 231 |
+
</style>
|
| 232 |
+
</head>
|
| 233 |
+
<body class="bg-background text-foreground antialiased">
|
| 234 |
+
<div class="max-w-7xl mx-auto p-6 space-y-6">
|
| 235 |
+
<!-- Header -->
|
| 236 |
+
<div class="gradient-primary rounded-lg p-8 shadow-2xl border border-primary/20">
|
| 237 |
+
<h1 class="text-4xl font-bold text-white mb-2">ML Model Comparison Report</h1>
|
| 238 |
+
<div class="text-white/90 space-y-1">
|
| 239 |
+
<p class="text-lg">xyY to Munsell Specification Conversion</p>
|
| 240 |
+
<p class="text-sm">Generated: {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}</p>
|
| 241 |
+
<p class="text-sm">Test Samples: <span class="font-semibold">{num_samples:,}</span> real Munsell colors</p>
|
| 242 |
+
</div>
|
| 243 |
+
</div>
|
| 244 |
+
"""
|
| 245 |
+
|
| 246 |
+
# Best Models Summary (FIRST - moved to top)
|
| 247 |
+
# Find best models for each metric
|
| 248 |
+
delta_E_values = [
|
| 249 |
+
r["delta_E"] for r in results.values() if not np.isnan(r["delta_E"])
|
| 250 |
+
]
|
| 251 |
+
|
| 252 |
+
best_delta_E = (
|
| 253 |
+
min(
|
| 254 |
+
results.items(),
|
| 255 |
+
key=lambda x: x[1]["delta_E"]
|
| 256 |
+
if not np.isnan(x[1]["delta_E"])
|
| 257 |
+
else float("inf"),
|
| 258 |
+
)[0]
|
| 259 |
+
if delta_E_values
|
| 260 |
+
else None
|
| 261 |
+
)
|
| 262 |
+
best_avg = sorted_models[0][0]
|
| 263 |
+
|
| 264 |
+
# Performance Metrics Table (FIRST - as summary)
|
| 265 |
+
# Find best for each metric
|
| 266 |
+
best_size = min(results.items(), key=lambda x: x[1]["model_size_mb"])[0]
|
| 267 |
+
best_speed = min(results.items(), key=lambda x: x[1]["inference_time_ms"])[0]
|
| 268 |
+
|
| 269 |
+
# Add Best Models Summary HTML
|
| 270 |
+
html += f"""
|
| 271 |
+
<!-- Best Models Summary -->
|
| 272 |
+
<div class="bg-card rounded-lg border border-border p-6 shadow-lg">
|
| 273 |
+
<h2 class="text-2xl font-semibold mb-6 pb-3 border-b border-primary/30">Best Models by Metric</h2>
|
| 274 |
+
<div class="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-4">
|
| 275 |
+
<div class="bg-gradient-to-br from-primary/10 to-primary/5 rounded-lg p-5 border border-primary/20">
|
| 276 |
+
<div class="text-xs font-semibold text-muted-foreground uppercase tracking-wide mb-2">Smallest Size</div>
|
| 277 |
+
<div class="text-3xl font-bold text-primary mb-3">{results[best_size]["model_size_mb"]:.2f} MB</div>
|
| 278 |
+
<div class="text-sm text-foreground/80">{best_size}</div>
|
| 279 |
+
</div>
|
| 280 |
+
<div class="bg-gradient-to-br from-primary/10 to-primary/5 rounded-lg p-5 border border-primary/20">
|
| 281 |
+
<div class="text-xs font-semibold text-muted-foreground uppercase tracking-wide mb-2">Fastest Speed</div>
|
| 282 |
+
<div class="text-3xl font-bold text-primary mb-3">{results[best_speed]["inference_time_ms"]:.4f} ms</div>
|
| 283 |
+
<div class="text-sm text-foreground/80">{best_speed}</div>
|
| 284 |
+
</div>
|
| 285 |
+
<div class="bg-gradient-to-br from-primary/10 to-primary/5 rounded-lg p-5 border border-primary/20">
|
| 286 |
+
<div class="text-xs font-semibold text-muted-foreground uppercase tracking-wide mb-2">Best Delta-E</div>
|
| 287 |
+
<div class="text-3xl font-bold text-primary mb-3">{results[best_delta_E]["delta_E"]:.4f}</div>
|
| 288 |
+
<div class="text-sm text-foreground/80">{best_delta_E}</div>
|
| 289 |
+
</div>
|
| 290 |
+
<div class="bg-gradient-to-br from-primary/10 to-primary/5 rounded-lg p-5 border border-primary/20">
|
| 291 |
+
<div class="text-xs font-semibold text-muted-foreground uppercase tracking-wide mb-2">Best Average MAE</div>
|
| 292 |
+
<div class="text-3xl font-bold text-primary mb-3">{avg_maes[best_avg]:.4f}</div>
|
| 293 |
+
<div class="text-sm text-foreground/80">{best_avg}</div>
|
| 294 |
+
</div>
|
| 295 |
+
</div>
|
| 296 |
+
</div>
|
| 297 |
+
"""
|
| 298 |
+
|
| 299 |
+
# Get baseline speed (Colour Library Iterative)
|
| 300 |
+
baseline_speed = baseline_inference_time_ms
|
| 301 |
+
|
| 302 |
+
# Sort by Delta-E for performance table (best first)
|
| 303 |
+
sorted_by_delta_E = sorted(
|
| 304 |
+
results.items(),
|
| 305 |
+
key=lambda x: x[1]["delta_E"]
|
| 306 |
+
if not np.isnan(x[1]["delta_E"])
|
| 307 |
+
else float("inf"),
|
| 308 |
+
)
|
| 309 |
+
|
| 310 |
+
# Calculate maximum speed multiplier (fastest model) for highlighting
|
| 311 |
+
max_speed_multiplier = 0.0
|
| 312 |
+
best_multiplier_model = None
|
| 313 |
+
for model_name, result in results.items():
|
| 314 |
+
speed_ms = result["inference_time_ms"]
|
| 315 |
+
if speed_ms > 0:
|
| 316 |
+
speed_multiplier = baseline_speed / speed_ms
|
| 317 |
+
if speed_multiplier > max_speed_multiplier:
|
| 318 |
+
max_speed_multiplier = speed_multiplier
|
| 319 |
+
best_multiplier_model = model_name
|
| 320 |
+
|
| 321 |
+
html += """
|
| 322 |
+
<!-- Performance Metrics Table -->
|
| 323 |
+
<div class="bg-card rounded-lg border border-border p-6 shadow-lg">
|
| 324 |
+
<h2 class="text-2xl font-semibold mb-6 pb-3 border-b border-primary/30">Model Performance Metrics</h2>
|
| 325 |
+
<div class="overflow-x-auto">
|
| 326 |
+
<table class="w-full text-sm">
|
| 327 |
+
<thead>
|
| 328 |
+
<tr class="border-b border-border">
|
| 329 |
+
<th class="text-left py-3 px-4 font-semibold text-muted-foreground">Model</th>
|
| 330 |
+
<th class="text-right py-3 px-4 font-semibold text-muted-foreground">
|
| 331 |
+
Size (MB)
|
| 332 |
+
<div class="text-xs font-normal text-muted-foreground/70 mt-1">ONNX files</div>
|
| 333 |
+
</th>
|
| 334 |
+
<th class="text-right py-3 px-4 font-semibold text-muted-foreground">
|
| 335 |
+
Speed (ms/sample)
|
| 336 |
+
<div class="text-xs font-normal text-muted-foreground/70 mt-1">10 iterations</div>
|
| 337 |
+
</th>
|
| 338 |
+
<th class="text-right py-3 px-4 font-semibold text-muted-foreground">
|
| 339 |
+
vs Baseline
|
| 340 |
+
<div class="text-xs font-normal text-muted-foreground/70 mt-1">Colour Iterative</div>
|
| 341 |
+
</th>
|
| 342 |
+
<th class="text-right py-3 px-4 font-semibold text-muted-foreground">
|
| 343 |
+
Delta-E
|
| 344 |
+
<div class="text-xs font-normal text-muted-foreground/70 mt-1">vs Colour Lib</div>
|
| 345 |
+
</th>
|
| 346 |
+
<th class="text-right py-3 px-4 font-semibold text-muted-foreground">Average MAE</th>
|
| 347 |
+
</tr>
|
| 348 |
+
</thead>
|
| 349 |
+
<tbody>
|
| 350 |
+
"""
|
| 351 |
+
|
| 352 |
+
for model_name, result in sorted_by_delta_E:
|
| 353 |
+
size_mb = result["model_size_mb"]
|
| 354 |
+
speed_ms = result["inference_time_ms"]
|
| 355 |
+
avg_mae = avg_maes[model_name]
|
| 356 |
+
delta_E = result["delta_E"]
|
| 357 |
+
|
| 358 |
+
# Calculate relative speed (how many times faster than baseline)
|
| 359 |
+
speed_multiplier = baseline_speed / speed_ms if speed_ms > 0 else 0
|
| 360 |
+
|
| 361 |
+
size_class = "text-primary font-semibold" if model_name == best_size else ""
|
| 362 |
+
speed_class = "text-primary font-semibold" if model_name == best_speed else ""
|
| 363 |
+
avg_class = "text-primary font-semibold" if model_name == best_avg else ""
|
| 364 |
+
delta_E_class = (
|
| 365 |
+
"text-primary font-semibold" if model_name == best_delta_E else ""
|
| 366 |
+
)
|
| 367 |
+
|
| 368 |
+
# Format Delta-E value
|
| 369 |
+
delta_E_str = f"{delta_E:.4f}" if not np.isnan(delta_E) else "—"
|
| 370 |
+
|
| 371 |
+
# Highlight only the fastest model
|
| 372 |
+
if abs(speed_multiplier - 1.0) < 0.01:
|
| 373 |
+
# Baseline
|
| 374 |
+
multiplier_class = "text-muted-foreground"
|
| 375 |
+
multiplier_text = "1.0x"
|
| 376 |
+
elif model_name == best_multiplier_model:
|
| 377 |
+
# Fastest model (highest multiplier)
|
| 378 |
+
multiplier_class = "text-primary font-semibold"
|
| 379 |
+
if speed_multiplier > 1000:
|
| 380 |
+
multiplier_text = f"{speed_multiplier:.0f}x"
|
| 381 |
+
elif speed_multiplier > 100:
|
| 382 |
+
multiplier_text = f"{speed_multiplier:.1f}x"
|
| 383 |
+
else:
|
| 384 |
+
multiplier_text = f"{speed_multiplier:.2f}x"
|
| 385 |
+
elif speed_multiplier > 1.0:
|
| 386 |
+
# Faster than baseline but not the fastest
|
| 387 |
+
multiplier_class = ""
|
| 388 |
+
if speed_multiplier > 1000:
|
| 389 |
+
multiplier_text = f"{speed_multiplier:.0f}x"
|
| 390 |
+
elif speed_multiplier > 100:
|
| 391 |
+
multiplier_text = f"{speed_multiplier:.1f}x"
|
| 392 |
+
else:
|
| 393 |
+
multiplier_text = f"{speed_multiplier:.2f}x"
|
| 394 |
+
else:
|
| 395 |
+
# Slower than baseline
|
| 396 |
+
multiplier_class = "text-destructive"
|
| 397 |
+
multiplier_text = f"{speed_multiplier:.2f}x"
|
| 398 |
+
|
| 399 |
+
html += f"""
|
| 400 |
+
<tr class="border-b border-border/50 hover:bg-muted/30 transition-colors">
|
| 401 |
+
<td class="py-3 px-4 font-medium">{model_name}</td>
|
| 402 |
+
<td class="py-3 px-4 text-right {size_class}">{size_mb:.2f}</td>
|
| 403 |
+
<td class="py-3 px-4 text-right {speed_class}">{speed_ms:.4f}</td>
|
| 404 |
+
<td class="py-3 px-4 text-right {multiplier_class}">{multiplier_text}</td>
|
| 405 |
+
<td class="py-3 px-4 text-right {delta_E_class}">{delta_E_str}</td>
|
| 406 |
+
<td class="py-3 px-4 text-right {avg_class}">{avg_mae:.4f}</td>
|
| 407 |
+
</tr>
|
| 408 |
+
"""
|
| 409 |
+
|
| 410 |
+
html += """
|
| 411 |
+
</tbody>
|
| 412 |
+
</table>
|
| 413 |
+
</div>
|
| 414 |
+
<div class="mt-6 p-4 bg-muted/30 rounded-md border border-primary/20">
|
| 415 |
+
<div class="text-sm space-y-2">
|
| 416 |
+
<div><span class="text-primary font-semibold">Note:</span> Speed measured with 10 iterations (3 warmup + 10 benchmark) on 2,734 samples.</div>
|
| 417 |
+
<div class="text-xs text-muted-foreground">Two-stage models include both base and error predictor. Highlighted values show best in each metric.</div>
|
| 418 |
+
<div class="text-xs text-muted-foreground">Baseline comparison: Speed multipliers show relative performance vs Colour Library's iterative xyY_to_munsell_specification(). Values <1.0x are faster.</div>
|
| 419 |
+
</div>
|
| 420 |
+
</div>
|
| 421 |
+
</div>
|
| 422 |
+
"""
|
| 423 |
+
|
| 424 |
+
# Overall ranking by Delta-E
|
| 425 |
+
html += """
|
| 426 |
+
<!-- Overall Ranking -->
|
| 427 |
+
<div class="bg-card rounded-lg border border-border p-6 shadow-lg">
|
| 428 |
+
<h2 class="text-2xl font-semibold mb-4 pb-2 border-b border-primary/30">Overall Ranking (by Delta-E)</h2>
|
| 429 |
+
<div class="space-y-1">
|
| 430 |
+
"""
|
| 431 |
+
|
| 432 |
+
# Sort by Delta-E (best = lowest)
|
| 433 |
+
sorted_by_delta_E_ranking = sorted(
|
| 434 |
+
[
|
| 435 |
+
(name, res["delta_E"])
|
| 436 |
+
for name, res in results.items()
|
| 437 |
+
if not np.isnan(res["delta_E"])
|
| 438 |
+
],
|
| 439 |
+
key=lambda x: x[1],
|
| 440 |
+
)
|
| 441 |
+
|
| 442 |
+
max_delta_E = (
|
| 443 |
+
max(delta_E for _, delta_E in sorted_by_delta_E_ranking)
|
| 444 |
+
if sorted_by_delta_E_ranking
|
| 445 |
+
else 1.0
|
| 446 |
+
)
|
| 447 |
+
for rank, (model_name, delta_E) in enumerate(sorted_by_delta_E_ranking, 1):
|
| 448 |
+
width_pct = (delta_E / max_delta_E) * 100
|
| 449 |
+
html += f"""
|
| 450 |
+
<div class="flex items-center gap-3 p-2 rounded-md hover:bg-muted/50 transition-colors">
|
| 451 |
+
<div class="flex-none w-80 text-sm font-medium">
|
| 452 |
+
<span class="text-muted-foreground">{rank}.</span> {model_name}
|
| 453 |
+
</div>
|
| 454 |
+
<div class="flex-1 h-6 bg-muted rounded-md overflow-hidden">
|
| 455 |
+
<div class="bar-fill h-full rounded-md" style="width: {width_pct}%"></div>
|
| 456 |
+
</div>
|
| 457 |
+
<div class="flex-none w-20 text-right font-bold text-primary">{delta_E:.4f}</div>
|
| 458 |
+
</div>
|
| 459 |
+
"""
|
| 460 |
+
|
| 461 |
+
html += """
|
| 462 |
+
</div>
|
| 463 |
+
</div>
|
| 464 |
+
"""
|
| 465 |
+
|
| 466 |
+
# Precision Threshold Table
|
| 467 |
+
html += """
|
| 468 |
+
<div class="bg-card rounded-lg border border-border p-6 shadow-lg">
|
| 469 |
+
<h2 class="text-2xl font-semibold mb-3 pb-3 border-b border-primary/30">Accuracy at Precision Thresholds</h2>
|
| 470 |
+
<p class="text-sm text-muted-foreground mb-6">Percentage of predictions where max error across all components is below threshold:</p>
|
| 471 |
+
<div class="overflow-x-auto">
|
| 472 |
+
<table class="w-full text-sm">
|
| 473 |
+
<thead>
|
| 474 |
+
<tr class="border-b border-border">
|
| 475 |
+
<th class="text-left py-3 px-4 font-semibold text-muted-foreground">Model</th>
|
| 476 |
+
"""
|
| 477 |
+
|
| 478 |
+
for threshold in thresholds:
|
| 479 |
+
html += f' <th class="text-right py-3 px-4 font-semibold text-muted-foreground">< {threshold:.0e}</th>\n'
|
| 480 |
+
|
| 481 |
+
html += """
|
| 482 |
+
</tr>
|
| 483 |
+
</thead>
|
| 484 |
+
<tbody>
|
| 485 |
+
"""
|
| 486 |
+
|
| 487 |
+
# Find best (highest) accuracy for each threshold column
|
| 488 |
+
best_accuracies = {}
|
| 489 |
+
min_accuracies = {}
|
| 490 |
+
for threshold in thresholds:
|
| 491 |
+
accuracies = [
|
| 492 |
+
np.mean(results[model_name]["max_errors"] < threshold) * 100
|
| 493 |
+
for model_name, _ in sorted_models
|
| 494 |
+
]
|
| 495 |
+
best_accuracies[threshold] = max(accuracies)
|
| 496 |
+
min_accuracies[threshold] = min(accuracies)
|
| 497 |
+
|
| 498 |
+
for model_name, _ in sorted_models:
|
| 499 |
+
result = results[model_name]
|
| 500 |
+
row_class = (
|
| 501 |
+
"bg-primary/10 border-l-2 border-l-primary"
|
| 502 |
+
if model_name == best_avg
|
| 503 |
+
else ""
|
| 504 |
+
)
|
| 505 |
+
html += f"""
|
| 506 |
+
<tr class="border-b border-border hover:bg-muted/30 transition-colors {row_class}">
|
| 507 |
+
<td class="text-left py-3 px-4 font-medium">{model_name}</td>
|
| 508 |
+
"""
|
| 509 |
+
for threshold in thresholds:
|
| 510 |
+
accuracy_pct = np.mean(result["max_errors"] < threshold) * 100
|
| 511 |
+
# Only highlight if there's meaningful variation
|
| 512 |
+
# (>0.1% difference between best and worst)
|
| 513 |
+
has_variation = (
|
| 514 |
+
best_accuracies[threshold] - min_accuracies[threshold]
|
| 515 |
+
) > 0.1
|
| 516 |
+
is_best = abs(accuracy_pct - best_accuracies[threshold]) < 0.01
|
| 517 |
+
cell_class = (
|
| 518 |
+
"text-right py-3 px-4 font-bold text-primary"
|
| 519 |
+
if (has_variation and is_best)
|
| 520 |
+
else "text-right py-3 px-4"
|
| 521 |
+
)
|
| 522 |
+
html += f' <td class="{cell_class}">{accuracy_pct:.2f}%</td>\n'
|
| 523 |
+
|
| 524 |
+
html += """
|
| 525 |
+
</tr>
|
| 526 |
+
"""
|
| 527 |
+
|
| 528 |
+
html += """
|
| 529 |
+
</tbody>
|
| 530 |
+
</table>
|
| 531 |
+
</div>
|
| 532 |
+
</div>
|
| 533 |
+
|
| 534 |
+
</div>
|
| 535 |
+
</body>
|
| 536 |
+
</html>
|
| 537 |
+
"""
|
| 538 |
+
|
| 539 |
+
# Write HTML file
|
| 540 |
+
with open(output_file, "w") as f:
|
| 541 |
+
f.write(html)
|
| 542 |
+
|
| 543 |
+
LOGGER.info("")
|
| 544 |
+
LOGGER.info("HTML report saved to: %s", output_file)
|
| 545 |
+
|
| 546 |
+
|
| 547 |
+
def main() -> None:
|
| 548 |
+
"""Compare all models."""
|
| 549 |
+
LOGGER.info("=" * 80)
|
| 550 |
+
LOGGER.info("Comprehensive Model Comparison")
|
| 551 |
+
LOGGER.info("=" * 80)
|
| 552 |
+
|
| 553 |
+
# Paths
|
| 554 |
+
model_directory = PROJECT_ROOT / "models" / "from_xyY"
|
| 555 |
+
|
| 556 |
+
# Load real Munsell dataset
|
| 557 |
+
LOGGER.info("")
|
| 558 |
+
LOGGER.info("Loading real Munsell dataset...")
|
| 559 |
+
xyY_samples = []
|
| 560 |
+
ground_truth = []
|
| 561 |
+
|
| 562 |
+
for munsell_spec_tuple, xyY in MUNSELL_COLOURS_REAL:
|
| 563 |
+
try:
|
| 564 |
+
hue_code, value, chroma = munsell_spec_tuple
|
| 565 |
+
munsell_str = f"{hue_code} {value}/{chroma}"
|
| 566 |
+
spec = munsell_colour_to_munsell_specification(munsell_str)
|
| 567 |
+
xyY_scaled = np.array([xyY[0], xyY[1], xyY[2] / 100.0])
|
| 568 |
+
xyY_samples.append(xyY_scaled)
|
| 569 |
+
ground_truth.append(spec)
|
| 570 |
+
except Exception: # noqa: BLE001, S112
|
| 571 |
+
continue
|
| 572 |
+
|
| 573 |
+
xyY_samples = np.array(xyY_samples)
|
| 574 |
+
ground_truth = np.array(ground_truth)
|
| 575 |
+
LOGGER.info("Loaded %d valid Munsell colors", len(xyY_samples))
|
| 576 |
+
|
| 577 |
+
# Define models to compare
|
| 578 |
+
models = [
|
| 579 |
+
{
|
| 580 |
+
"name": "MLP (Base Only)",
|
| 581 |
+
"files": [model_directory / "mlp.onnx"],
|
| 582 |
+
"params_file": model_directory / "mlp_normalization_params.npz",
|
| 583 |
+
"type": "single",
|
| 584 |
+
},
|
| 585 |
+
{
|
| 586 |
+
"name": "MLP + Error Predictor",
|
| 587 |
+
"files": [
|
| 588 |
+
model_directory / "mlp.onnx",
|
| 589 |
+
model_directory / "mlp_error_predictor.onnx",
|
| 590 |
+
],
|
| 591 |
+
"params_file": model_directory / "mlp_normalization_params.npz",
|
| 592 |
+
"type": "two_stage",
|
| 593 |
+
},
|
| 594 |
+
{
|
| 595 |
+
"name": "Unified MLP",
|
| 596 |
+
"files": [model_directory / "unified_mlp.onnx"],
|
| 597 |
+
"params_file": model_directory / "unified_mlp_normalization_params.npz",
|
| 598 |
+
"type": "single",
|
| 599 |
+
},
|
| 600 |
+
{
|
| 601 |
+
"name": "MLP + Self-Attention",
|
| 602 |
+
"files": [model_directory / "mlp_attention.onnx"],
|
| 603 |
+
"params_file": model_directory
|
| 604 |
+
/ "mlp_attention_normalization_params.npz",
|
| 605 |
+
"type": "single",
|
| 606 |
+
},
|
| 607 |
+
{
|
| 608 |
+
"name": "MLP + Self-Attention + Error Predictor",
|
| 609 |
+
"files": [
|
| 610 |
+
model_directory / "mlp_attention.onnx",
|
| 611 |
+
model_directory / "mlp_attention_error_predictor.onnx",
|
| 612 |
+
],
|
| 613 |
+
"params_file": model_directory
|
| 614 |
+
/ "mlp_attention_normalization_params.npz",
|
| 615 |
+
"type": "two_stage",
|
| 616 |
+
},
|
| 617 |
+
{
|
| 618 |
+
"name": "Deep + Wide",
|
| 619 |
+
"files": [model_directory / "deep_wide.onnx"],
|
| 620 |
+
"params_file": model_directory / "deep_wide_normalization_params.npz",
|
| 621 |
+
"type": "single",
|
| 622 |
+
},
|
| 623 |
+
{
|
| 624 |
+
"name": "Mixture of Experts",
|
| 625 |
+
"files": [model_directory / "mixture_of_experts.onnx"],
|
| 626 |
+
"params_file": model_directory
|
| 627 |
+
/ "mixture_of_experts_normalization_params.npz",
|
| 628 |
+
"type": "single",
|
| 629 |
+
},
|
| 630 |
+
{
|
| 631 |
+
"name": "FT-Transformer",
|
| 632 |
+
"files": [model_directory / "ft_transformer.onnx"],
|
| 633 |
+
"params_file": model_directory / "ft_transformer_normalization_params.npz",
|
| 634 |
+
"type": "single",
|
| 635 |
+
},
|
| 636 |
+
{
|
| 637 |
+
"name": "Multi-Head",
|
| 638 |
+
"files": [model_directory / "multi_head.onnx"],
|
| 639 |
+
"params_file": model_directory / "multi_head_normalization_params.npz",
|
| 640 |
+
"type": "single",
|
| 641 |
+
},
|
| 642 |
+
{
|
| 643 |
+
"name": "Multi-Head (Optimized)",
|
| 644 |
+
"files": [model_directory / "multi_head_optimized.onnx"],
|
| 645 |
+
"params_file": model_directory / "multi_head_optimized_normalization_params.npz",
|
| 646 |
+
"type": "single",
|
| 647 |
+
},
|
| 648 |
+
{
|
| 649 |
+
"name": "Multi-Head + Error Predictor",
|
| 650 |
+
"files": [
|
| 651 |
+
model_directory / "multi_head.onnx",
|
| 652 |
+
model_directory / "multi_head_error_predictor.onnx",
|
| 653 |
+
],
|
| 654 |
+
"params_file": model_directory / "multi_head_normalization_params.npz",
|
| 655 |
+
"type": "two_stage",
|
| 656 |
+
},
|
| 657 |
+
{
|
| 658 |
+
"name": "Multi-MLP",
|
| 659 |
+
"files": [model_directory / "multi_mlp.onnx"],
|
| 660 |
+
"params_file": model_directory / "multi_mlp_normalization_params.npz",
|
| 661 |
+
"type": "single",
|
| 662 |
+
},
|
| 663 |
+
{
|
| 664 |
+
"name": "Multi-MLP + Error Predictor",
|
| 665 |
+
"files": [
|
| 666 |
+
model_directory / "multi_mlp.onnx",
|
| 667 |
+
model_directory / "multi_mlp_error_predictor.onnx",
|
| 668 |
+
],
|
| 669 |
+
"params_file": model_directory / "multi_mlp_normalization_params.npz",
|
| 670 |
+
"type": "two_stage",
|
| 671 |
+
},
|
| 672 |
+
{
|
| 673 |
+
"name": "Multi-MLP + Multi-Error Predictor",
|
| 674 |
+
"files": [
|
| 675 |
+
model_directory / "multi_mlp.onnx",
|
| 676 |
+
model_directory / "multi_mlp_multi_error_predictor.onnx",
|
| 677 |
+
],
|
| 678 |
+
"params_file": model_directory / "multi_mlp_normalization_params.npz",
|
| 679 |
+
"type": "two_stage",
|
| 680 |
+
},
|
| 681 |
+
{
|
| 682 |
+
"name": "Multi-MLP + Multi-Error Predictor (Optimized)",
|
| 683 |
+
"files": [
|
| 684 |
+
model_directory / "multi_mlp.onnx",
|
| 685 |
+
model_directory / "multi_mlp_multi_error_predictor_optimized.onnx",
|
| 686 |
+
],
|
| 687 |
+
"params_file": model_directory / "multi_mlp_normalization_params.npz",
|
| 688 |
+
"type": "two_stage",
|
| 689 |
+
},
|
| 690 |
+
{
|
| 691 |
+
"name": "Multi-MLP (Optimized)",
|
| 692 |
+
"files": [model_directory / "multi_mlp_optimized.onnx"],
|
| 693 |
+
"params_file": model_directory / "multi_mlp_optimized_normalization_params.npz",
|
| 694 |
+
"type": "single",
|
| 695 |
+
},
|
| 696 |
+
{
|
| 697 |
+
"name": "Multi-Head + Multi-Error Predictor",
|
| 698 |
+
"files": [
|
| 699 |
+
model_directory / "multi_head.onnx",
|
| 700 |
+
model_directory / "multi_head_multi_error_predictor.onnx",
|
| 701 |
+
],
|
| 702 |
+
"params_file": model_directory / "multi_head_normalization_params.npz",
|
| 703 |
+
"type": "two_stage",
|
| 704 |
+
},
|
| 705 |
+
{
|
| 706 |
+
"name": "Multi-Head + Cross-Attention Error Predictor",
|
| 707 |
+
"files": [
|
| 708 |
+
model_directory / "multi_head.onnx",
|
| 709 |
+
model_directory / "multi_head_cross_attention_error_predictor.onnx",
|
| 710 |
+
],
|
| 711 |
+
"params_file": model_directory / "multi_head_normalization_params.npz",
|
| 712 |
+
"type": "two_stage",
|
| 713 |
+
},
|
| 714 |
+
{
|
| 715 |
+
"name": "Multi-Head (Optimized) + Multi-Error Predictor (Optimized)",
|
| 716 |
+
"files": [
|
| 717 |
+
model_directory / "multi_head_optimized.onnx",
|
| 718 |
+
model_directory / "multi_head_error_predictor_optimized.onnx",
|
| 719 |
+
],
|
| 720 |
+
"params_file": model_directory / "multi_head_optimized_normalization_params.npz",
|
| 721 |
+
"type": "two_stage",
|
| 722 |
+
},
|
| 723 |
+
{
|
| 724 |
+
"name": "Multi-Head (Circular Loss)",
|
| 725 |
+
"files": [model_directory / "multi_head_circular.onnx"],
|
| 726 |
+
"params_file": model_directory / "multi_head_circular_normalization_params.npz",
|
| 727 |
+
"type": "single",
|
| 728 |
+
},
|
| 729 |
+
{
|
| 730 |
+
"name": "Multi-Head (Large Dataset)",
|
| 731 |
+
"files": [model_directory / "multi_head_large.onnx"],
|
| 732 |
+
"params_file": model_directory / "multi_head_large_normalization_params.npz",
|
| 733 |
+
"type": "single",
|
| 734 |
+
},
|
| 735 |
+
{
|
| 736 |
+
"name": "Multi-Head + Multi-Error Predictor (Large Dataset)",
|
| 737 |
+
"files": [
|
| 738 |
+
model_directory / "multi_head_large.onnx",
|
| 739 |
+
model_directory / "multi_head_multi_error_predictor_large.onnx",
|
| 740 |
+
],
|
| 741 |
+
"params_file": model_directory / "multi_head_large_normalization_params.npz",
|
| 742 |
+
"type": "two_stage",
|
| 743 |
+
},
|
| 744 |
+
{
|
| 745 |
+
"name": "Multi-MLP (Large Dataset)",
|
| 746 |
+
"files": [model_directory / "multi_mlp_large.onnx"],
|
| 747 |
+
"params_file": model_directory / "multi_mlp_large_normalization_params.npz",
|
| 748 |
+
"type": "single",
|
| 749 |
+
},
|
| 750 |
+
{
|
| 751 |
+
"name": "Multi-MLP + Multi-Error Predictor (Large Dataset)",
|
| 752 |
+
"files": [
|
| 753 |
+
model_directory / "multi_mlp_large.onnx",
|
| 754 |
+
model_directory / "multi_mlp_multi_error_predictor_large.onnx",
|
| 755 |
+
],
|
| 756 |
+
"params_file": model_directory / "multi_mlp_large_normalization_params.npz",
|
| 757 |
+
"type": "two_stage",
|
| 758 |
+
},
|
| 759 |
+
{
|
| 760 |
+
"name": "Transformer (Large Dataset)",
|
| 761 |
+
"files": [model_directory / "transformer_large.onnx"],
|
| 762 |
+
"params_file": model_directory / "transformer_large_normalization_params.npz",
|
| 763 |
+
"type": "single",
|
| 764 |
+
},
|
| 765 |
+
{
|
| 766 |
+
"name": "Transformer + Error Predictor (Large Dataset)",
|
| 767 |
+
"files": [
|
| 768 |
+
model_directory / "transformer_large.onnx",
|
| 769 |
+
model_directory / "transformer_multi_error_predictor_large.onnx",
|
| 770 |
+
],
|
| 771 |
+
"params_file": model_directory / "transformer_large_normalization_params.npz",
|
| 772 |
+
"type": "two_stage",
|
| 773 |
+
},
|
| 774 |
+
{
|
| 775 |
+
"name": "Multi-Head Refined (REAL Only)",
|
| 776 |
+
"files": [model_directory / "multi_head_refined_real.onnx"],
|
| 777 |
+
"params_file": model_directory / "multi_head_refined_real_normalization_params.npz",
|
| 778 |
+
"type": "single",
|
| 779 |
+
},
|
| 780 |
+
{
|
| 781 |
+
"name": "Multi-Head Refined + Error Predictor (REAL Only)",
|
| 782 |
+
"files": [
|
| 783 |
+
model_directory / "multi_head_refined_real.onnx",
|
| 784 |
+
model_directory / "multi_head_multi_error_predictor_refined_real.onnx",
|
| 785 |
+
],
|
| 786 |
+
"params_file": model_directory / "multi_head_refined_real_normalization_params.npz",
|
| 787 |
+
"type": "two_stage",
|
| 788 |
+
},
|
| 789 |
+
{
|
| 790 |
+
"name": "Multi-Head + Multi-Error Predictor + Multi-Error Predictor (3-Stage)",
|
| 791 |
+
"files": [
|
| 792 |
+
model_directory / "multi_head_large.onnx",
|
| 793 |
+
model_directory / "multi_head_multi_error_predictor_large.onnx",
|
| 794 |
+
model_directory / "multi_head_3stage_error_predictor.onnx",
|
| 795 |
+
],
|
| 796 |
+
"params_file": model_directory / "multi_head_large_normalization_params.npz",
|
| 797 |
+
"type": "three_stage",
|
| 798 |
+
},
|
| 799 |
+
{
|
| 800 |
+
"name": "Multi-Head (Weighted + Boundary Loss)",
|
| 801 |
+
"files": [model_directory / "multi_head_weighted_boundary.onnx"],
|
| 802 |
+
"params_file": model_directory / "multi_head_weighted_boundary_normalization_params.npz",
|
| 803 |
+
"type": "single",
|
| 804 |
+
},
|
| 805 |
+
{
|
| 806 |
+
"name": "Multi-Head (Weighted + Boundary Loss) + Multi-Error Predictor",
|
| 807 |
+
"files": [
|
| 808 |
+
model_directory / "multi_head_weighted_boundary.onnx",
|
| 809 |
+
model_directory / "multi_head_weighted_boundary_multi_error_predictor.onnx",
|
| 810 |
+
],
|
| 811 |
+
"params_file": model_directory / "multi_head_weighted_boundary_normalization_params.npz",
|
| 812 |
+
"type": "two_stage",
|
| 813 |
+
},
|
| 814 |
+
{
|
| 815 |
+
"name": "Multi-Head (Weighted + Boundary Loss) + Multi-Error Predictor (Weighted + Boundary Loss)",
|
| 816 |
+
"files": [
|
| 817 |
+
model_directory / "multi_head_weighted_boundary.onnx",
|
| 818 |
+
model_directory / "multi_head_weighted_boundary_multi_error_predictor_weighted_boundary.onnx",
|
| 819 |
+
],
|
| 820 |
+
"params_file": model_directory / "multi_head_weighted_boundary_normalization_params.npz",
|
| 821 |
+
"type": "two_stage",
|
| 822 |
+
},
|
| 823 |
+
{
|
| 824 |
+
"name": "Multi-MLP (Weighted + Boundary Loss) (Large Dataset)",
|
| 825 |
+
"files": [model_directory / "multi_mlp_weighted_boundary.onnx"],
|
| 826 |
+
"params_file": model_directory / "multi_mlp_weighted_boundary_normalization_params.npz",
|
| 827 |
+
"type": "single",
|
| 828 |
+
},
|
| 829 |
+
{
|
| 830 |
+
"name": "Multi-MLP (Weighted + Boundary Loss) + Multi-Error Predictor (Weighted + Boundary Loss) (Large Dataset)",
|
| 831 |
+
"files": [
|
| 832 |
+
model_directory / "multi_mlp_weighted_boundary.onnx",
|
| 833 |
+
model_directory / "multi_mlp_weighted_boundary_multi_error_predictor.onnx",
|
| 834 |
+
],
|
| 835 |
+
"params_file": model_directory / "multi_mlp_weighted_boundary_normalization_params.npz",
|
| 836 |
+
"type": "two_stage",
|
| 837 |
+
},
|
| 838 |
+
{
|
| 839 |
+
"name": "Multi-ResNet (Large Dataset)",
|
| 840 |
+
"files": [model_directory / "multi_resnet_large.onnx"],
|
| 841 |
+
"params_file": model_directory / "multi_resnet_large_normalization_params.npz",
|
| 842 |
+
"type": "single",
|
| 843 |
+
},
|
| 844 |
+
{
|
| 845 |
+
"name": "Multi-ResNet + Multi-Error Predictor (Large Dataset)",
|
| 846 |
+
"files": [
|
| 847 |
+
model_directory / "multi_resnet_large.onnx",
|
| 848 |
+
model_directory / "multi_resnet_error_predictor_large.onnx",
|
| 849 |
+
],
|
| 850 |
+
"params_file": model_directory / "multi_resnet_large_normalization_params.npz",
|
| 851 |
+
"type": "two_stage",
|
| 852 |
+
},
|
| 853 |
+
]
|
| 854 |
+
|
| 855 |
+
# Benchmark colour library's iterative implementation first
|
| 856 |
+
LOGGER.info("")
|
| 857 |
+
LOGGER.info("=" * 80)
|
| 858 |
+
LOGGER.info("Colour Library (Iterative)")
|
| 859 |
+
LOGGER.info("=" * 80)
|
| 860 |
+
|
| 861 |
+
# Benchmark the iterative xyY_to_munsell_specification function
|
| 862 |
+
# Note: Using full dataset (100% of samples)
|
| 863 |
+
|
| 864 |
+
# Set random seed for reproducibility
|
| 865 |
+
np.random.seed(42)
|
| 866 |
+
|
| 867 |
+
# Use 100% of samples for comprehensive benchmarking
|
| 868 |
+
sample_count = len(xyY_samples)
|
| 869 |
+
sampled_indices = np.arange(len(xyY_samples))
|
| 870 |
+
xyY_benchmark_samples = xyY_samples[sampled_indices]
|
| 871 |
+
|
| 872 |
+
# Measure inference time on sampled Munsell colors
|
| 873 |
+
start_time = time.perf_counter()
|
| 874 |
+
convergence_failures = 0
|
| 875 |
+
successful_inferences = 0
|
| 876 |
+
|
| 877 |
+
with warnings.catch_warnings():
|
| 878 |
+
warnings.simplefilter("ignore")
|
| 879 |
+
for xyy in xyY_benchmark_samples:
|
| 880 |
+
try:
|
| 881 |
+
xyY_to_munsell_specification(xyy)
|
| 882 |
+
successful_inferences += 1
|
| 883 |
+
except (RuntimeError, ValueError):
|
| 884 |
+
# Out-of-gamut color that doesn't converge or not in renotation system
|
| 885 |
+
convergence_failures += 1
|
| 886 |
+
|
| 887 |
+
end_time = time.perf_counter()
|
| 888 |
+
|
| 889 |
+
# Calculate average time per successful inference (in milliseconds)
|
| 890 |
+
total_time_s = end_time - start_time
|
| 891 |
+
colour_inference_time_ms = (
|
| 892 |
+
(total_time_s / successful_inferences) * 1000
|
| 893 |
+
if successful_inferences > 0
|
| 894 |
+
else 0
|
| 895 |
+
)
|
| 896 |
+
|
| 897 |
+
LOGGER.info("")
|
| 898 |
+
LOGGER.info("Performance Metrics:")
|
| 899 |
+
LOGGER.info(" Successful inferences: %d", successful_inferences)
|
| 900 |
+
LOGGER.info(" Convergence failures: %d", convergence_failures)
|
| 901 |
+
LOGGER.info(" Inference Speed: %.4f ms/sample", colour_inference_time_ms)
|
| 902 |
+
LOGGER.info(" Note: This is the baseline iterative implementation")
|
| 903 |
+
|
| 904 |
+
# Store the baseline speed
|
| 905 |
+
baseline_inference_time_ms = colour_inference_time_ms
|
| 906 |
+
|
| 907 |
+
# Convert ground truth Munsell specs to CIE Lab for Delta-E comparison
|
| 908 |
+
# Path: Munsell spec → xyY → XYZ → Lab
|
| 909 |
+
LOGGER.info("")
|
| 910 |
+
LOGGER.info(
|
| 911 |
+
"Converting ground truth to CIE Lab for Delta-E comparison..."
|
| 912 |
+
)
|
| 913 |
+
LOGGER.info(" Path: Munsell spec \u2192 xyY \u2192 XYZ \u2192 Lab")
|
| 914 |
+
reference_Lab = []
|
| 915 |
+
for spec in ground_truth:
|
| 916 |
+
try:
|
| 917 |
+
# Munsell specification → xyY
|
| 918 |
+
xyy = munsell_specification_to_xyY(spec)
|
| 919 |
+
# xyY → XYZ
|
| 920 |
+
XYZ = xyY_to_XYZ(xyy)
|
| 921 |
+
# XYZ → Lab (Illuminant C for Munsell)
|
| 922 |
+
Lab = XYZ_to_Lab(XYZ, CCS_ILLUMINANT_MUNSELL)
|
| 923 |
+
reference_Lab.append(Lab)
|
| 924 |
+
except (RuntimeError, ValueError):
|
| 925 |
+
# If conversion fails, use NaN
|
| 926 |
+
reference_Lab.append(np.array([np.nan, np.nan, np.nan]))
|
| 927 |
+
|
| 928 |
+
reference_Lab = np.array(reference_Lab)
|
| 929 |
+
LOGGER.info(
|
| 930 |
+
" Converted %d ground truth specs to CIE Lab",
|
| 931 |
+
len(reference_Lab),
|
| 932 |
+
)
|
| 933 |
+
|
| 934 |
+
# Use the same sampled subset for ML model evaluations (for fair comparison)
|
| 935 |
+
xyY_samples = xyY_benchmark_samples
|
| 936 |
+
ground_truth = ground_truth[sampled_indices]
|
| 937 |
+
|
| 938 |
+
# Evaluate each model
|
| 939 |
+
results = {}
|
| 940 |
+
|
| 941 |
+
for model_info in models:
|
| 942 |
+
model_name = model_info["name"]
|
| 943 |
+
LOGGER.info("")
|
| 944 |
+
LOGGER.info("=" * 80)
|
| 945 |
+
LOGGER.info(model_name)
|
| 946 |
+
LOGGER.info("=" * 80)
|
| 947 |
+
|
| 948 |
+
# Load normalization params for this model
|
| 949 |
+
params = np.load(model_info["params_file"], allow_pickle=True)
|
| 950 |
+
# input_params may not exist if xyY is already in [0, 1] range
|
| 951 |
+
input_params = (
|
| 952 |
+
params["input_params"].item()
|
| 953 |
+
if "input_params" in params.files
|
| 954 |
+
else None
|
| 955 |
+
)
|
| 956 |
+
output_params = params["output_params"].item()
|
| 957 |
+
|
| 958 |
+
# Normalize input with this model's params (None means no normalization)
|
| 959 |
+
X_norm = normalize_input(xyY_samples, input_params)
|
| 960 |
+
|
| 961 |
+
# Calculate model size
|
| 962 |
+
model_size_mb = get_model_size_mb(model_info["files"])
|
| 963 |
+
|
| 964 |
+
if model_info["type"] == "two_stage":
|
| 965 |
+
# Two-stage model
|
| 966 |
+
base_session = ort.InferenceSession(str(model_info["files"][0]))
|
| 967 |
+
error_session = ort.InferenceSession(str(model_info["files"][1]))
|
| 968 |
+
|
| 969 |
+
# Define inference callable for benchmarking
|
| 970 |
+
def two_stage_inference(
|
| 971 |
+
_base_session: ort.InferenceSession = base_session,
|
| 972 |
+
_error_session: ort.InferenceSession = error_session,
|
| 973 |
+
_X_norm: NDArray = X_norm,
|
| 974 |
+
) -> NDArray:
|
| 975 |
+
base_pred = _base_session.run(None, {"xyY": _X_norm})[0]
|
| 976 |
+
combined = np.concatenate([_X_norm, base_pred], axis=1).astype(
|
| 977 |
+
np.float32
|
| 978 |
+
)
|
| 979 |
+
error_corr = _error_session.run(None, {"combined_input": combined})[
|
| 980 |
+
0
|
| 981 |
+
]
|
| 982 |
+
return base_pred + error_corr
|
| 983 |
+
|
| 984 |
+
# Benchmark speed
|
| 985 |
+
inference_time_ms = benchmark_inference_speed(
|
| 986 |
+
two_stage_inference, X_norm
|
| 987 |
+
)
|
| 988 |
+
|
| 989 |
+
# Get predictions
|
| 990 |
+
base_pred_norm = base_session.run(None, {"xyY": X_norm})[0]
|
| 991 |
+
combined_input = np.concatenate(
|
| 992 |
+
[X_norm, base_pred_norm], axis=1
|
| 993 |
+
).astype(np.float32)
|
| 994 |
+
error_correction_norm = error_session.run(
|
| 995 |
+
None, {"combined_input": combined_input}
|
| 996 |
+
)[0]
|
| 997 |
+
final_pred_norm = base_pred_norm + error_correction_norm
|
| 998 |
+
pred = denormalize_output(final_pred_norm, output_params)
|
| 999 |
+
errors = np.abs(pred - ground_truth)
|
| 1000 |
+
|
| 1001 |
+
result = {
|
| 1002 |
+
"hue_mae": np.mean(errors[:, 0]),
|
| 1003 |
+
"value_mae": np.mean(errors[:, 1]),
|
| 1004 |
+
"chroma_mae": np.mean(errors[:, 2]),
|
| 1005 |
+
"code_mae": np.mean(errors[:, 3]),
|
| 1006 |
+
"max_errors": np.max(errors, axis=1),
|
| 1007 |
+
"hue_errors": errors[:, 0],
|
| 1008 |
+
"value_errors": errors[:, 1],
|
| 1009 |
+
"chroma_errors": errors[:, 2],
|
| 1010 |
+
"code_errors": errors[:, 3],
|
| 1011 |
+
"model_size_mb": model_size_mb,
|
| 1012 |
+
"inference_time_ms": inference_time_ms,
|
| 1013 |
+
}
|
| 1014 |
+
|
| 1015 |
+
# Compute Delta-E against ground truth
|
| 1016 |
+
delta_E_values = []
|
| 1017 |
+
for idx in range(len(pred)):
|
| 1018 |
+
try:
|
| 1019 |
+
# Convert ML prediction to Lab: Munsell spec → xyY → XYZ → Lab
|
| 1020 |
+
ml_spec = clamp_munsell_specification(pred[idx])
|
| 1021 |
+
|
| 1022 |
+
# Round Code to nearest integer before round-trip conversion
|
| 1023 |
+
ml_spec_for_conversion = ml_spec.copy()
|
| 1024 |
+
ml_spec_for_conversion[3] = round(ml_spec[3])
|
| 1025 |
+
|
| 1026 |
+
ml_xyy = munsell_specification_to_xyY(ml_spec_for_conversion)
|
| 1027 |
+
ml_XYZ = xyY_to_XYZ(ml_xyy)
|
| 1028 |
+
ml_Lab = XYZ_to_Lab(ml_XYZ, CCS_ILLUMINANT_MUNSELL)
|
| 1029 |
+
|
| 1030 |
+
# Get ground truth Lab
|
| 1031 |
+
reference_Lab_sample = reference_Lab[idx]
|
| 1032 |
+
|
| 1033 |
+
# Compute Delta-E CIE2000
|
| 1034 |
+
delta_E = delta_E_CIE2000(reference_Lab_sample, ml_Lab)
|
| 1035 |
+
delta_E_values.append(delta_E)
|
| 1036 |
+
except (RuntimeError, ValueError):
|
| 1037 |
+
# Skip if conversion fails
|
| 1038 |
+
continue
|
| 1039 |
+
|
| 1040 |
+
result["delta_E"] = (
|
| 1041 |
+
np.mean(delta_E_values) if delta_E_values else np.nan
|
| 1042 |
+
)
|
| 1043 |
+
elif model_info["type"] == "three_stage":
|
| 1044 |
+
# Three-stage model: base + error predictor 1 + error predictor 2
|
| 1045 |
+
base_session = ort.InferenceSession(str(model_info["files"][0]))
|
| 1046 |
+
error1_session = ort.InferenceSession(str(model_info["files"][1]))
|
| 1047 |
+
error2_session = ort.InferenceSession(str(model_info["files"][2]))
|
| 1048 |
+
|
| 1049 |
+
# Define inference callable for benchmarking
|
| 1050 |
+
def three_stage_inference(
|
| 1051 |
+
_base_session: ort.InferenceSession = base_session,
|
| 1052 |
+
_error1_session: ort.InferenceSession = error1_session,
|
| 1053 |
+
_error2_session: ort.InferenceSession = error2_session,
|
| 1054 |
+
_X_norm: NDArray = X_norm,
|
| 1055 |
+
) -> NDArray:
|
| 1056 |
+
# Stage 1: Base model
|
| 1057 |
+
base_pred = _base_session.run(None, {"xyY": _X_norm})[0]
|
| 1058 |
+
# Stage 2: First error correction
|
| 1059 |
+
combined1 = np.concatenate([_X_norm, base_pred], axis=1).astype(
|
| 1060 |
+
np.float32
|
| 1061 |
+
)
|
| 1062 |
+
error1_corr = _error1_session.run(
|
| 1063 |
+
None, {"combined_input": combined1}
|
| 1064 |
+
)[0]
|
| 1065 |
+
stage2_pred = base_pred + error1_corr
|
| 1066 |
+
# Stage 3: Second error correction
|
| 1067 |
+
combined2 = np.concatenate([_X_norm, stage2_pred], axis=1).astype(
|
| 1068 |
+
np.float32
|
| 1069 |
+
)
|
| 1070 |
+
error2_corr = _error2_session.run(
|
| 1071 |
+
None, {"combined_input": combined2}
|
| 1072 |
+
)[0]
|
| 1073 |
+
return stage2_pred + error2_corr
|
| 1074 |
+
|
| 1075 |
+
# Benchmark speed
|
| 1076 |
+
inference_time_ms = benchmark_inference_speed(
|
| 1077 |
+
three_stage_inference, X_norm
|
| 1078 |
+
)
|
| 1079 |
+
|
| 1080 |
+
# Get predictions
|
| 1081 |
+
base_pred_norm = base_session.run(None, {"xyY": X_norm})[0]
|
| 1082 |
+
combined1 = np.concatenate([X_norm, base_pred_norm], axis=1).astype(
|
| 1083 |
+
np.float32
|
| 1084 |
+
)
|
| 1085 |
+
error1_corr_norm = error1_session.run(
|
| 1086 |
+
None, {"combined_input": combined1}
|
| 1087 |
+
)[0]
|
| 1088 |
+
stage2_pred_norm = base_pred_norm + error1_corr_norm
|
| 1089 |
+
combined2 = np.concatenate([X_norm, stage2_pred_norm], axis=1).astype(
|
| 1090 |
+
np.float32
|
| 1091 |
+
)
|
| 1092 |
+
error2_corr_norm = error2_session.run(
|
| 1093 |
+
None, {"combined_input": combined2}
|
| 1094 |
+
)[0]
|
| 1095 |
+
final_pred_norm = stage2_pred_norm + error2_corr_norm
|
| 1096 |
+
pred = denormalize_output(final_pred_norm, output_params)
|
| 1097 |
+
errors = np.abs(pred - ground_truth)
|
| 1098 |
+
|
| 1099 |
+
result = {
|
| 1100 |
+
"hue_mae": np.mean(errors[:, 0]),
|
| 1101 |
+
"value_mae": np.mean(errors[:, 1]),
|
| 1102 |
+
"chroma_mae": np.mean(errors[:, 2]),
|
| 1103 |
+
"code_mae": np.mean(errors[:, 3]),
|
| 1104 |
+
"max_errors": np.max(errors, axis=1),
|
| 1105 |
+
"hue_errors": errors[:, 0],
|
| 1106 |
+
"value_errors": errors[:, 1],
|
| 1107 |
+
"chroma_errors": errors[:, 2],
|
| 1108 |
+
"code_errors": errors[:, 3],
|
| 1109 |
+
"model_size_mb": model_size_mb,
|
| 1110 |
+
"inference_time_ms": inference_time_ms,
|
| 1111 |
+
}
|
| 1112 |
+
|
| 1113 |
+
# Compute Delta-E against ground truth for three-stage model
|
| 1114 |
+
delta_E_values = []
|
| 1115 |
+
for idx in range(len(pred)):
|
| 1116 |
+
try:
|
| 1117 |
+
ml_spec = clamp_munsell_specification(pred[idx])
|
| 1118 |
+
ml_spec_for_conversion = ml_spec.copy()
|
| 1119 |
+
ml_spec_for_conversion[3] = round(ml_spec[3])
|
| 1120 |
+
ml_xyy = munsell_specification_to_xyY(ml_spec_for_conversion)
|
| 1121 |
+
ml_XYZ = xyY_to_XYZ(ml_xyy)
|
| 1122 |
+
ml_Lab = XYZ_to_Lab(ml_XYZ, CCS_ILLUMINANT_MUNSELL)
|
| 1123 |
+
delta_E = delta_E_CIE2000(reference_Lab[idx], ml_Lab)
|
| 1124 |
+
delta_E_values.append(delta_E)
|
| 1125 |
+
except (RuntimeError, ValueError):
|
| 1126 |
+
continue
|
| 1127 |
+
|
| 1128 |
+
result["delta_E"] = (
|
| 1129 |
+
np.mean(delta_E_values) if delta_E_values else np.nan
|
| 1130 |
+
)
|
| 1131 |
+
else:
|
| 1132 |
+
# Single model
|
| 1133 |
+
session = ort.InferenceSession(str(model_info["files"][0]))
|
| 1134 |
+
|
| 1135 |
+
# Define inference callable for benchmarking
|
| 1136 |
+
def single_inference(
|
| 1137 |
+
_session: ort.InferenceSession = session, _X_norm: NDArray = X_norm
|
| 1138 |
+
) -> NDArray:
|
| 1139 |
+
return _session.run(None, {"xyY": _X_norm})[0]
|
| 1140 |
+
|
| 1141 |
+
# Benchmark speed
|
| 1142 |
+
inference_time_ms = benchmark_inference_speed(single_inference, X_norm)
|
| 1143 |
+
|
| 1144 |
+
result = evaluate_model(
|
| 1145 |
+
session,
|
| 1146 |
+
X_norm,
|
| 1147 |
+
ground_truth,
|
| 1148 |
+
output_params,
|
| 1149 |
+
reference_Lab=reference_Lab,
|
| 1150 |
+
)
|
| 1151 |
+
result["model_size_mb"] = model_size_mb
|
| 1152 |
+
result["inference_time_ms"] = inference_time_ms
|
| 1153 |
+
|
| 1154 |
+
results[model_name] = result
|
| 1155 |
+
|
| 1156 |
+
# Print results
|
| 1157 |
+
LOGGER.info("")
|
| 1158 |
+
LOGGER.info("Mean Absolute Errors:")
|
| 1159 |
+
LOGGER.info(" Hue: %.4f", result["hue_mae"])
|
| 1160 |
+
LOGGER.info(" Value: %.4f", result["value_mae"])
|
| 1161 |
+
LOGGER.info(" Chroma: %.4f", result["chroma_mae"])
|
| 1162 |
+
LOGGER.info(" Code: %.4f", result["code_mae"])
|
| 1163 |
+
if not np.isnan(result["delta_E"]):
|
| 1164 |
+
LOGGER.info(" Delta-E (vs Ground Truth): %.4f", result["delta_E"])
|
| 1165 |
+
LOGGER.info("")
|
| 1166 |
+
LOGGER.info("Performance Metrics:")
|
| 1167 |
+
LOGGER.info(" Model Size: %.2f MB", result["model_size_mb"])
|
| 1168 |
+
LOGGER.info(
|
| 1169 |
+
" Inference Speed: %.4f ms/sample", result["inference_time_ms"]
|
| 1170 |
+
)
|
| 1171 |
+
|
| 1172 |
+
|
| 1173 |
+
# Summary comparison
|
| 1174 |
+
LOGGER.info("")
|
| 1175 |
+
LOGGER.info("=" * 80)
|
| 1176 |
+
LOGGER.info("SUMMARY COMPARISON")
|
| 1177 |
+
LOGGER.info("=" * 80)
|
| 1178 |
+
LOGGER.info("")
|
| 1179 |
+
|
| 1180 |
+
if not results:
|
| 1181 |
+
LOGGER.info("⚠️ No models were successfully evaluated")
|
| 1182 |
+
return
|
| 1183 |
+
|
| 1184 |
+
# MAE comparison table
|
| 1185 |
+
LOGGER.info("Mean Absolute Error Comparison:")
|
| 1186 |
+
LOGGER.info("")
|
| 1187 |
+
header = "{:<35} {:>8} {:>8} {:>8} {:>8} {:>10}".format(
|
| 1188 |
+
"Model",
|
| 1189 |
+
"Hue",
|
| 1190 |
+
"Value",
|
| 1191 |
+
"Chroma",
|
| 1192 |
+
"Code",
|
| 1193 |
+
"Delta-E",
|
| 1194 |
+
)
|
| 1195 |
+
LOGGER.info(header)
|
| 1196 |
+
LOGGER.info("-" * 90)
|
| 1197 |
+
|
| 1198 |
+
for model_name, result in results.items():
|
| 1199 |
+
delta_E_str = (
|
| 1200 |
+
f"{result['delta_E']:.4f}" if not np.isnan(result["delta_E"]) else "N/A"
|
| 1201 |
+
)
|
| 1202 |
+
LOGGER.info(
|
| 1203 |
+
"%-35s %8.4f %8.4f %8.4f %8.4f %10s",
|
| 1204 |
+
model_name[:35],
|
| 1205 |
+
result["hue_mae"],
|
| 1206 |
+
result["value_mae"],
|
| 1207 |
+
result["chroma_mae"],
|
| 1208 |
+
result["code_mae"],
|
| 1209 |
+
delta_E_str,
|
| 1210 |
+
)
|
| 1211 |
+
|
| 1212 |
+
# Precision threshold comparison
|
| 1213 |
+
LOGGER.info("")
|
| 1214 |
+
LOGGER.info("Accuracy at Precision Thresholds:")
|
| 1215 |
+
LOGGER.info("")
|
| 1216 |
+
|
| 1217 |
+
thresholds = [1e-4, 1e-3, 1e-2, 1e-1, 1.0]
|
| 1218 |
+
header_parts = [f"{'Model/Threshold':<35}"]
|
| 1219 |
+
header_parts.extend(f"{f'< {threshold:.0e}':>10}" for threshold in thresholds)
|
| 1220 |
+
LOGGER.info(" ".join(header_parts))
|
| 1221 |
+
LOGGER.info("-" * 80)
|
| 1222 |
+
|
| 1223 |
+
for model_name, result in results.items():
|
| 1224 |
+
row_parts = [f"{model_name[:35]:<35}"]
|
| 1225 |
+
for threshold in thresholds:
|
| 1226 |
+
accuracy_pct = np.mean(result["max_errors"] < threshold) * 100
|
| 1227 |
+
row_parts.append(f"{accuracy_pct:9.2f}%")
|
| 1228 |
+
LOGGER.info(" ".join(row_parts))
|
| 1229 |
+
|
| 1230 |
+
# Performance metrics comparison
|
| 1231 |
+
LOGGER.info("")
|
| 1232 |
+
LOGGER.info("Model Size and Inference Speed Comparison:")
|
| 1233 |
+
LOGGER.info("")
|
| 1234 |
+
header = f"{'Model':<35} {'Size (MB)':>12} {'Speed (ms/sample)':>18}"
|
| 1235 |
+
LOGGER.info(header)
|
| 1236 |
+
LOGGER.info("-" * 80)
|
| 1237 |
+
|
| 1238 |
+
for model_name, result in results.items():
|
| 1239 |
+
LOGGER.info(
|
| 1240 |
+
"%-35s %11.2f %17.4f",
|
| 1241 |
+
model_name[:35],
|
| 1242 |
+
result["model_size_mb"],
|
| 1243 |
+
result["inference_time_ms"],
|
| 1244 |
+
)
|
| 1245 |
+
|
| 1246 |
+
# Find best model
|
| 1247 |
+
LOGGER.info("")
|
| 1248 |
+
LOGGER.info("=" * 80)
|
| 1249 |
+
LOGGER.info("BEST MODELS BY METRIC")
|
| 1250 |
+
LOGGER.info("=" * 80)
|
| 1251 |
+
LOGGER.info("")
|
| 1252 |
+
|
| 1253 |
+
metrics = ["hue_mae", "value_mae", "chroma_mae", "code_mae"]
|
| 1254 |
+
metric_names = ["Hue MAE", "Value MAE", "Chroma MAE", "Code MAE"]
|
| 1255 |
+
|
| 1256 |
+
for metric, metric_name in zip(metrics, metric_names, strict=False):
|
| 1257 |
+
best_model = min(results.items(), key=lambda x: x[1][metric])
|
| 1258 |
+
LOGGER.info(
|
| 1259 |
+
"%-15s: %s (%.4f)",
|
| 1260 |
+
metric_name,
|
| 1261 |
+
best_model[0],
|
| 1262 |
+
best_model[1][metric],
|
| 1263 |
+
)
|
| 1264 |
+
|
| 1265 |
+
# Overall best (average rank)
|
| 1266 |
+
LOGGER.info("")
|
| 1267 |
+
LOGGER.info("Overall Best (by average component MAE):")
|
| 1268 |
+
for model_name, result in results.items():
|
| 1269 |
+
avg_mae = np.mean(
|
| 1270 |
+
[
|
| 1271 |
+
result["hue_mae"],
|
| 1272 |
+
result["value_mae"],
|
| 1273 |
+
result["chroma_mae"],
|
| 1274 |
+
result["code_mae"],
|
| 1275 |
+
]
|
| 1276 |
+
)
|
| 1277 |
+
LOGGER.info(" %s: %.4f", model_name, avg_mae)
|
| 1278 |
+
|
| 1279 |
+
LOGGER.info("")
|
| 1280 |
+
LOGGER.info("=" * 80)
|
| 1281 |
+
|
| 1282 |
+
# Generate HTML report
|
| 1283 |
+
report_dir = PROJECT_ROOT / "reports" / "from_xyY"
|
| 1284 |
+
report_dir.mkdir(exist_ok=True)
|
| 1285 |
+
report_file = report_dir / "model_comparison.html"
|
| 1286 |
+
generate_html_report(
|
| 1287 |
+
results, len(xyY_samples), report_file, baseline_inference_time_ms
|
| 1288 |
+
)
|
| 1289 |
+
|
| 1290 |
+
|
| 1291 |
+
if __name__ == "__main__":
|
| 1292 |
+
main()
|
learning_munsell/comparison/from_xyY/compare_gamma_model.py
ADDED
|
@@ -0,0 +1,390 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Quick comparison of the gamma-corrected models against baselines.
|
| 3 |
+
|
| 4 |
+
This script compares:
|
| 5 |
+
1. MLP (Base) vs MLP (Gamma 2.33)
|
| 6 |
+
2. Multi-Head (Base) vs Multi-Head (Gamma 2.33) vs Multi-Head (ST.2084)
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import logging
|
| 10 |
+
from typing import Any
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
import onnxruntime as ort
|
| 14 |
+
from colour import XYZ_to_Lab, xyY_to_XYZ
|
| 15 |
+
from colour.difference import delta_E_CIE2000
|
| 16 |
+
from colour.models import eotf_inverse_ST2084
|
| 17 |
+
from colour.notation.datasets.munsell import MUNSELL_COLOURS_REAL
|
| 18 |
+
from colour.notation.munsell import (
|
| 19 |
+
CCS_ILLUMINANT_MUNSELL,
|
| 20 |
+
munsell_colour_to_munsell_specification,
|
| 21 |
+
munsell_specification_to_xyY,
|
| 22 |
+
)
|
| 23 |
+
from numpy.typing import NDArray
|
| 24 |
+
|
| 25 |
+
from learning_munsell import PROJECT_ROOT
|
| 26 |
+
|
| 27 |
+
logging.basicConfig(level=logging.INFO, format="%(message)s")
|
| 28 |
+
LOGGER = logging.getLogger(__name__)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def normalize_input_standard(X: NDArray, params: dict[str, Any]) -> NDArray:
|
| 32 |
+
"""Standard xyY normalization."""
|
| 33 |
+
X_norm = np.copy(X)
|
| 34 |
+
X_norm[..., 0] = (X[..., 0] - params["x_range"][0]) / (
|
| 35 |
+
params["x_range"][1] - params["x_range"][0]
|
| 36 |
+
)
|
| 37 |
+
X_norm[..., 1] = (X[..., 1] - params["y_range"][0]) / (
|
| 38 |
+
params["y_range"][1] - params["y_range"][0]
|
| 39 |
+
)
|
| 40 |
+
X_norm[..., 2] = (X[..., 2] - params["Y_range"][0]) / (
|
| 41 |
+
params["Y_range"][1] - params["Y_range"][0]
|
| 42 |
+
)
|
| 43 |
+
return X_norm.astype(np.float32)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def normalize_input_gamma(X: NDArray, params: dict[str, Any]) -> NDArray:
|
| 47 |
+
"""Gamma-corrected xyY normalization."""
|
| 48 |
+
gamma = params.get("gamma", 2.33)
|
| 49 |
+
X_norm = np.copy(X)
|
| 50 |
+
X_norm[..., 0] = (X[..., 0] - params["x_range"][0]) / (
|
| 51 |
+
params["x_range"][1] - params["x_range"][0]
|
| 52 |
+
)
|
| 53 |
+
X_norm[..., 1] = (X[..., 1] - params["y_range"][0]) / (
|
| 54 |
+
params["y_range"][1] - params["y_range"][0]
|
| 55 |
+
)
|
| 56 |
+
# Normalize Y then apply gamma
|
| 57 |
+
Y_normalized = (X[..., 2] - params["Y_range"][0]) / (
|
| 58 |
+
params["Y_range"][1] - params["Y_range"][0]
|
| 59 |
+
)
|
| 60 |
+
Y_normalized = np.clip(Y_normalized, 0, 1)
|
| 61 |
+
X_norm[..., 2] = np.power(Y_normalized, 1.0 / gamma)
|
| 62 |
+
return X_norm.astype(np.float32)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def normalize_input_st2084(X: NDArray, params: dict[str, Any]) -> NDArray:
|
| 66 |
+
"""ST.2084 (PQ) encoded xyY normalization."""
|
| 67 |
+
L_p = params.get("L_p", 100.0)
|
| 68 |
+
X_norm = np.copy(X)
|
| 69 |
+
X_norm[..., 0] = (X[..., 0] - params["x_range"][0]) / (
|
| 70 |
+
params["x_range"][1] - params["x_range"][0]
|
| 71 |
+
)
|
| 72 |
+
X_norm[..., 1] = (X[..., 1] - params["y_range"][0]) / (
|
| 73 |
+
params["y_range"][1] - params["y_range"][0]
|
| 74 |
+
)
|
| 75 |
+
# Normalize Y then apply ST.2084
|
| 76 |
+
Y_normalized = (X[..., 2] - params["Y_range"][0]) / (
|
| 77 |
+
params["Y_range"][1] - params["Y_range"][0]
|
| 78 |
+
)
|
| 79 |
+
Y_normalized = np.clip(Y_normalized, 0, 1)
|
| 80 |
+
# Scale to cd/m² and apply ST.2084 inverse EOTF
|
| 81 |
+
Y_cdm2 = Y_normalized * L_p
|
| 82 |
+
X_norm[..., 2] = eotf_inverse_ST2084(Y_cdm2, L_p=L_p)
|
| 83 |
+
return X_norm.astype(np.float32)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def denormalize_output(y_norm: NDArray, params: dict[str, Any]) -> NDArray:
|
| 87 |
+
"""Denormalize Munsell output."""
|
| 88 |
+
y = np.copy(y_norm)
|
| 89 |
+
y[..., 0] = (
|
| 90 |
+
y_norm[..., 0] * (params["hue_range"][1] - params["hue_range"][0])
|
| 91 |
+
+ params["hue_range"][0]
|
| 92 |
+
)
|
| 93 |
+
y[..., 1] = (
|
| 94 |
+
y_norm[..., 1] * (params["value_range"][1] - params["value_range"][0])
|
| 95 |
+
+ params["value_range"][0]
|
| 96 |
+
)
|
| 97 |
+
y[..., 2] = (
|
| 98 |
+
y_norm[..., 2] * (params["chroma_range"][1] - params["chroma_range"][0])
|
| 99 |
+
+ params["chroma_range"][0]
|
| 100 |
+
)
|
| 101 |
+
y[..., 3] = (
|
| 102 |
+
y_norm[..., 3] * (params["code_range"][1] - params["code_range"][0])
|
| 103 |
+
+ params["code_range"][0]
|
| 104 |
+
)
|
| 105 |
+
return y
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def clamp_munsell_specification(spec: NDArray) -> NDArray:
|
| 109 |
+
"""Clamp Munsell specification to valid ranges."""
|
| 110 |
+
clamped = np.copy(spec)
|
| 111 |
+
clamped[..., 0] = np.clip(spec[..., 0], 0.0, 10.0) # Hue: [0, 10]
|
| 112 |
+
clamped[..., 1] = np.clip(spec[..., 1], 1.0, 9.0) # Value: [1, 9] (colour library constraint)
|
| 113 |
+
clamped[..., 2] = np.clip(spec[..., 2], 0.0, 50.0) # Chroma: [0, 50]
|
| 114 |
+
clamped[..., 3] = np.clip(spec[..., 3], 1.0, 10.0) # Code: [1, 10]
|
| 115 |
+
return clamped
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def compute_delta_e(pred: NDArray, reference_Lab: NDArray) -> list[float]:
|
| 119 |
+
"""Compute Delta-E for predictions."""
|
| 120 |
+
delta_E_values = []
|
| 121 |
+
for idx in range(len(pred)):
|
| 122 |
+
try:
|
| 123 |
+
ml_spec = clamp_munsell_specification(pred[idx])
|
| 124 |
+
ml_spec_for_conversion = ml_spec.copy()
|
| 125 |
+
ml_spec_for_conversion[3] = round(ml_spec[3])
|
| 126 |
+
ml_xyy = munsell_specification_to_xyY(ml_spec_for_conversion)
|
| 127 |
+
ml_XYZ = xyY_to_XYZ(ml_xyy)
|
| 128 |
+
ml_Lab = XYZ_to_Lab(ml_XYZ, CCS_ILLUMINANT_MUNSELL)
|
| 129 |
+
delta_E = delta_E_CIE2000(reference_Lab[idx], ml_Lab)
|
| 130 |
+
delta_E_values.append(delta_E)
|
| 131 |
+
except (RuntimeError, ValueError):
|
| 132 |
+
continue
|
| 133 |
+
return delta_E_values
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def main() -> None:
|
| 137 |
+
"""Compare gamma model against baseline."""
|
| 138 |
+
LOGGER.info("=" * 80)
|
| 139 |
+
LOGGER.info("Gamma Model Comparison: MLP vs MLP (Gamma 2.33)")
|
| 140 |
+
LOGGER.info("=" * 80)
|
| 141 |
+
|
| 142 |
+
models_dir = PROJECT_ROOT / "models" / "from_xyY"
|
| 143 |
+
|
| 144 |
+
# Load real Munsell data
|
| 145 |
+
LOGGER.info("\nLoading real Munsell colours...")
|
| 146 |
+
xyY_values = []
|
| 147 |
+
munsell_specs = []
|
| 148 |
+
reference_Lab = []
|
| 149 |
+
|
| 150 |
+
for munsell_spec_tuple, xyY in MUNSELL_COLOURS_REAL:
|
| 151 |
+
try:
|
| 152 |
+
hue_code, value, chroma = munsell_spec_tuple
|
| 153 |
+
munsell_str = f"{hue_code} {value}/{chroma}"
|
| 154 |
+
spec = munsell_colour_to_munsell_specification(munsell_str)
|
| 155 |
+
xyY_scaled = np.array([xyY[0], xyY[1], xyY[2] / 100.0])
|
| 156 |
+
|
| 157 |
+
XYZ = xyY_to_XYZ(xyY_scaled)
|
| 158 |
+
Lab = XYZ_to_Lab(XYZ, CCS_ILLUMINANT_MUNSELL)
|
| 159 |
+
|
| 160 |
+
xyY_values.append(xyY_scaled)
|
| 161 |
+
munsell_specs.append(spec)
|
| 162 |
+
reference_Lab.append(Lab)
|
| 163 |
+
except (RuntimeError, ValueError):
|
| 164 |
+
continue
|
| 165 |
+
|
| 166 |
+
xyY_array = np.array(xyY_values)
|
| 167 |
+
ground_truth = np.array(munsell_specs)
|
| 168 |
+
reference_Lab = np.array(reference_Lab)
|
| 169 |
+
|
| 170 |
+
LOGGER.info("Loaded %d real Munsell colours", len(xyY_array))
|
| 171 |
+
|
| 172 |
+
# Test baseline MLP
|
| 173 |
+
LOGGER.info("\n" + "-" * 40)
|
| 174 |
+
LOGGER.info("1. MLP (Base) - Standard Normalization")
|
| 175 |
+
LOGGER.info("-" * 40)
|
| 176 |
+
|
| 177 |
+
base_onnx = models_dir / "mlp.onnx"
|
| 178 |
+
base_params_file = models_dir / "mlp_normalization_params.npz"
|
| 179 |
+
|
| 180 |
+
if base_onnx.exists() and base_params_file.exists():
|
| 181 |
+
base_session = ort.InferenceSession(str(base_onnx))
|
| 182 |
+
base_params_data = np.load(base_params_file, allow_pickle=True)
|
| 183 |
+
base_input_params = base_params_data["input_params"].item()
|
| 184 |
+
base_output_params = base_params_data["output_params"].item()
|
| 185 |
+
|
| 186 |
+
X_norm_base = normalize_input_standard(xyY_array, base_input_params)
|
| 187 |
+
pred_norm = base_session.run(None, {"xyY": X_norm_base})[0]
|
| 188 |
+
pred_base = denormalize_output(pred_norm, base_output_params)
|
| 189 |
+
|
| 190 |
+
errors_base = np.abs(pred_base - ground_truth)
|
| 191 |
+
delta_E_base = compute_delta_e(pred_base, reference_Lab)
|
| 192 |
+
|
| 193 |
+
LOGGER.info(" Hue MAE: %.4f", np.mean(errors_base[:, 0]))
|
| 194 |
+
LOGGER.info(" Value MAE: %.4f", np.mean(errors_base[:, 1]))
|
| 195 |
+
LOGGER.info(" Chroma MAE: %.4f", np.mean(errors_base[:, 2]))
|
| 196 |
+
LOGGER.info(" Code MAE: %.4f", np.mean(errors_base[:, 3]))
|
| 197 |
+
LOGGER.info(" Delta-E: %.4f (mean), %.4f (median)",
|
| 198 |
+
np.mean(delta_E_base), np.median(delta_E_base))
|
| 199 |
+
else:
|
| 200 |
+
LOGGER.info(" Model not found, skipping...")
|
| 201 |
+
delta_E_base = []
|
| 202 |
+
|
| 203 |
+
# Test gamma MLP
|
| 204 |
+
LOGGER.info("\n" + "-" * 40)
|
| 205 |
+
LOGGER.info("2. MLP (Gamma 2.33) - Gamma-Corrected Y")
|
| 206 |
+
LOGGER.info("-" * 40)
|
| 207 |
+
|
| 208 |
+
gamma_onnx = models_dir / "mlp_gamma.onnx"
|
| 209 |
+
gamma_params_file = models_dir / "mlp_gamma_normalization_params.npz"
|
| 210 |
+
|
| 211 |
+
if gamma_onnx.exists() and gamma_params_file.exists():
|
| 212 |
+
gamma_session = ort.InferenceSession(str(gamma_onnx))
|
| 213 |
+
gamma_params_data = np.load(gamma_params_file, allow_pickle=True)
|
| 214 |
+
gamma_input_params = gamma_params_data["input_params"].item()
|
| 215 |
+
gamma_output_params = gamma_params_data["output_params"].item()
|
| 216 |
+
|
| 217 |
+
X_norm_gamma = normalize_input_gamma(xyY_array, gamma_input_params)
|
| 218 |
+
pred_norm = gamma_session.run(None, {"xyY_gamma": X_norm_gamma})[0]
|
| 219 |
+
pred_gamma = denormalize_output(pred_norm, gamma_output_params)
|
| 220 |
+
|
| 221 |
+
errors_gamma = np.abs(pred_gamma - ground_truth)
|
| 222 |
+
delta_E_gamma = compute_delta_e(pred_gamma, reference_Lab)
|
| 223 |
+
|
| 224 |
+
LOGGER.info(" Hue MAE: %.4f", np.mean(errors_gamma[:, 0]))
|
| 225 |
+
LOGGER.info(" Value MAE: %.4f", np.mean(errors_gamma[:, 1]))
|
| 226 |
+
LOGGER.info(" Chroma MAE: %.4f", np.mean(errors_gamma[:, 2]))
|
| 227 |
+
LOGGER.info(" Code MAE: %.4f", np.mean(errors_gamma[:, 3]))
|
| 228 |
+
LOGGER.info(" Delta-E: %.4f (mean), %.4f (median)",
|
| 229 |
+
np.mean(delta_E_gamma), np.median(delta_E_gamma))
|
| 230 |
+
else:
|
| 231 |
+
LOGGER.info(" Model not found, skipping...")
|
| 232 |
+
delta_E_gamma = []
|
| 233 |
+
|
| 234 |
+
# Summary comparison for MLP
|
| 235 |
+
if delta_E_base and delta_E_gamma:
|
| 236 |
+
LOGGER.info("\n" + "=" * 80)
|
| 237 |
+
LOGGER.info("MLP COMPARISON SUMMARY")
|
| 238 |
+
LOGGER.info("=" * 80)
|
| 239 |
+
LOGGER.info("")
|
| 240 |
+
LOGGER.info("Delta-E (lower is better):")
|
| 241 |
+
LOGGER.info(" MLP (Base): %.4f mean, %.4f median",
|
| 242 |
+
np.mean(delta_E_base), np.median(delta_E_base))
|
| 243 |
+
LOGGER.info(" MLP (Gamma): %.4f mean, %.4f median",
|
| 244 |
+
np.mean(delta_E_gamma), np.median(delta_E_gamma))
|
| 245 |
+
LOGGER.info("")
|
| 246 |
+
|
| 247 |
+
improvement = (np.mean(delta_E_base) - np.mean(delta_E_gamma)) / np.mean(delta_E_base) * 100
|
| 248 |
+
if improvement > 0:
|
| 249 |
+
LOGGER.info(" Gamma model is %.1f%% BETTER", improvement)
|
| 250 |
+
else:
|
| 251 |
+
LOGGER.info(" Gamma model is %.1f%% WORSE", -improvement)
|
| 252 |
+
|
| 253 |
+
# Test Multi-Head baseline
|
| 254 |
+
LOGGER.info("\n" + "=" * 80)
|
| 255 |
+
LOGGER.info("MULTI-HEAD GAMMA EXPERIMENT")
|
| 256 |
+
LOGGER.info("=" * 80)
|
| 257 |
+
|
| 258 |
+
LOGGER.info("\n" + "-" * 40)
|
| 259 |
+
LOGGER.info("3. Multi-Head (Base) - Standard Normalization")
|
| 260 |
+
LOGGER.info("-" * 40)
|
| 261 |
+
|
| 262 |
+
mh_base_onnx = models_dir / "multi_head.onnx"
|
| 263 |
+
mh_base_params_file = models_dir / "multi_head_normalization_params.npz"
|
| 264 |
+
|
| 265 |
+
if mh_base_onnx.exists() and mh_base_params_file.exists():
|
| 266 |
+
mh_base_session = ort.InferenceSession(str(mh_base_onnx))
|
| 267 |
+
mh_base_params_data = np.load(mh_base_params_file, allow_pickle=True)
|
| 268 |
+
mh_base_input_params = mh_base_params_data["input_params"].item()
|
| 269 |
+
mh_base_output_params = mh_base_params_data["output_params"].item()
|
| 270 |
+
|
| 271 |
+
X_norm_mh_base = normalize_input_standard(xyY_array, mh_base_input_params)
|
| 272 |
+
pred_norm = mh_base_session.run(None, {"xyY": X_norm_mh_base})[0]
|
| 273 |
+
pred_mh_base = denormalize_output(pred_norm, mh_base_output_params)
|
| 274 |
+
|
| 275 |
+
errors_mh_base = np.abs(pred_mh_base - ground_truth)
|
| 276 |
+
delta_E_mh_base = compute_delta_e(pred_mh_base, reference_Lab)
|
| 277 |
+
|
| 278 |
+
LOGGER.info(" Hue MAE: %.4f", np.mean(errors_mh_base[:, 0]))
|
| 279 |
+
LOGGER.info(" Value MAE: %.4f", np.mean(errors_mh_base[:, 1]))
|
| 280 |
+
LOGGER.info(" Chroma MAE: %.4f", np.mean(errors_mh_base[:, 2]))
|
| 281 |
+
LOGGER.info(" Code MAE: %.4f", np.mean(errors_mh_base[:, 3]))
|
| 282 |
+
LOGGER.info(" Delta-E: %.4f (mean), %.4f (median)",
|
| 283 |
+
np.mean(delta_E_mh_base), np.median(delta_E_mh_base))
|
| 284 |
+
else:
|
| 285 |
+
LOGGER.info(" Model not found, skipping...")
|
| 286 |
+
delta_E_mh_base = []
|
| 287 |
+
|
| 288 |
+
# Test Multi-Head gamma
|
| 289 |
+
LOGGER.info("\n" + "-" * 40)
|
| 290 |
+
LOGGER.info("4. Multi-Head (Gamma 2.33) - Gamma-Corrected Y")
|
| 291 |
+
LOGGER.info("-" * 40)
|
| 292 |
+
|
| 293 |
+
mh_gamma_onnx = models_dir / "multi_head_gamma.onnx"
|
| 294 |
+
mh_gamma_params_file = models_dir / "multi_head_gamma_normalization_params.npz"
|
| 295 |
+
|
| 296 |
+
if mh_gamma_onnx.exists() and mh_gamma_params_file.exists():
|
| 297 |
+
mh_gamma_session = ort.InferenceSession(str(mh_gamma_onnx))
|
| 298 |
+
mh_gamma_params_data = np.load(mh_gamma_params_file, allow_pickle=True)
|
| 299 |
+
mh_gamma_input_params = mh_gamma_params_data["input_params"].item()
|
| 300 |
+
mh_gamma_output_params = mh_gamma_params_data["output_params"].item()
|
| 301 |
+
|
| 302 |
+
X_norm_mh_gamma = normalize_input_gamma(xyY_array, mh_gamma_input_params)
|
| 303 |
+
pred_norm = mh_gamma_session.run(None, {"xyY_gamma": X_norm_mh_gamma})[0]
|
| 304 |
+
pred_mh_gamma = denormalize_output(pred_norm, mh_gamma_output_params)
|
| 305 |
+
|
| 306 |
+
errors_mh_gamma = np.abs(pred_mh_gamma - ground_truth)
|
| 307 |
+
delta_E_mh_gamma = compute_delta_e(pred_mh_gamma, reference_Lab)
|
| 308 |
+
|
| 309 |
+
LOGGER.info(" Hue MAE: %.4f", np.mean(errors_mh_gamma[:, 0]))
|
| 310 |
+
LOGGER.info(" Value MAE: %.4f", np.mean(errors_mh_gamma[:, 1]))
|
| 311 |
+
LOGGER.info(" Chroma MAE: %.4f", np.mean(errors_mh_gamma[:, 2]))
|
| 312 |
+
LOGGER.info(" Code MAE: %.4f", np.mean(errors_mh_gamma[:, 3]))
|
| 313 |
+
LOGGER.info(" Delta-E: %.4f (mean), %.4f (median)",
|
| 314 |
+
np.mean(delta_E_mh_gamma), np.median(delta_E_mh_gamma))
|
| 315 |
+
else:
|
| 316 |
+
LOGGER.info(" Model not found, skipping...")
|
| 317 |
+
delta_E_mh_gamma = []
|
| 318 |
+
|
| 319 |
+
# Test Multi-Head ST.2084
|
| 320 |
+
LOGGER.info("\n" + "-" * 40)
|
| 321 |
+
LOGGER.info("5. Multi-Head (ST.2084) - PQ-Encoded Y")
|
| 322 |
+
LOGGER.info("-" * 40)
|
| 323 |
+
|
| 324 |
+
mh_st2084_onnx = models_dir / "multi_head_st2084.onnx"
|
| 325 |
+
mh_st2084_params_file = models_dir / "multi_head_st2084_normalization_params.npz"
|
| 326 |
+
|
| 327 |
+
if mh_st2084_onnx.exists() and mh_st2084_params_file.exists():
|
| 328 |
+
mh_st2084_session = ort.InferenceSession(str(mh_st2084_onnx))
|
| 329 |
+
mh_st2084_params_data = np.load(mh_st2084_params_file, allow_pickle=True)
|
| 330 |
+
mh_st2084_input_params = mh_st2084_params_data["input_params"].item()
|
| 331 |
+
mh_st2084_output_params = mh_st2084_params_data["output_params"].item()
|
| 332 |
+
|
| 333 |
+
X_norm_mh_st2084 = normalize_input_st2084(xyY_array, mh_st2084_input_params)
|
| 334 |
+
pred_norm = mh_st2084_session.run(None, {"xyY_st2084": X_norm_mh_st2084})[0]
|
| 335 |
+
pred_mh_st2084 = denormalize_output(pred_norm, mh_st2084_output_params)
|
| 336 |
+
|
| 337 |
+
errors_mh_st2084 = np.abs(pred_mh_st2084 - ground_truth)
|
| 338 |
+
delta_E_mh_st2084 = compute_delta_e(pred_mh_st2084, reference_Lab)
|
| 339 |
+
|
| 340 |
+
LOGGER.info(" Hue MAE: %.4f", np.mean(errors_mh_st2084[:, 0]))
|
| 341 |
+
LOGGER.info(" Value MAE: %.4f", np.mean(errors_mh_st2084[:, 1]))
|
| 342 |
+
LOGGER.info(" Chroma MAE: %.4f", np.mean(errors_mh_st2084[:, 2]))
|
| 343 |
+
LOGGER.info(" Code MAE: %.4f", np.mean(errors_mh_st2084[:, 3]))
|
| 344 |
+
LOGGER.info(" Delta-E: %.4f (mean), %.4f (median)",
|
| 345 |
+
np.mean(delta_E_mh_st2084), np.median(delta_E_mh_st2084))
|
| 346 |
+
else:
|
| 347 |
+
LOGGER.info(" Model not found, skipping...")
|
| 348 |
+
delta_E_mh_st2084 = []
|
| 349 |
+
|
| 350 |
+
# Summary comparison for Multi-Head
|
| 351 |
+
if delta_E_mh_base and delta_E_mh_gamma:
|
| 352 |
+
LOGGER.info("\n" + "=" * 80)
|
| 353 |
+
LOGGER.info("MULTI-HEAD COMPARISON SUMMARY")
|
| 354 |
+
LOGGER.info("=" * 80)
|
| 355 |
+
LOGGER.info("")
|
| 356 |
+
LOGGER.info("Delta-E (lower is better):")
|
| 357 |
+
LOGGER.info(" Multi-Head (Base): %.4f mean, %.4f median",
|
| 358 |
+
np.mean(delta_E_mh_base), np.median(delta_E_mh_base))
|
| 359 |
+
LOGGER.info(" Multi-Head (Gamma): %.4f mean, %.4f median",
|
| 360 |
+
np.mean(delta_E_mh_gamma), np.median(delta_E_mh_gamma))
|
| 361 |
+
if delta_E_mh_st2084:
|
| 362 |
+
LOGGER.info(" Multi-Head (ST.2084): %.4f mean, %.4f median",
|
| 363 |
+
np.mean(delta_E_mh_st2084), np.median(delta_E_mh_st2084))
|
| 364 |
+
LOGGER.info("")
|
| 365 |
+
|
| 366 |
+
mh_gamma_improvement = (np.mean(delta_E_mh_base) - np.mean(delta_E_mh_gamma)) / np.mean(delta_E_mh_base) * 100
|
| 367 |
+
if mh_gamma_improvement > 0:
|
| 368 |
+
LOGGER.info(" Multi-Head Gamma vs Base: %.1f%% BETTER", mh_gamma_improvement)
|
| 369 |
+
else:
|
| 370 |
+
LOGGER.info(" Multi-Head Gamma vs Base: %.1f%% WORSE", -mh_gamma_improvement)
|
| 371 |
+
|
| 372 |
+
if delta_E_mh_st2084:
|
| 373 |
+
mh_st2084_improvement = (np.mean(delta_E_mh_base) - np.mean(delta_E_mh_st2084)) / np.mean(delta_E_mh_base) * 100
|
| 374 |
+
if mh_st2084_improvement > 0:
|
| 375 |
+
LOGGER.info(" Multi-Head ST.2084 vs Base: %.1f%% BETTER", mh_st2084_improvement)
|
| 376 |
+
else:
|
| 377 |
+
LOGGER.info(" Multi-Head ST.2084 vs Base: %.1f%% WORSE", -mh_st2084_improvement)
|
| 378 |
+
|
| 379 |
+
# Compare ST.2084 vs Gamma
|
| 380 |
+
st2084_vs_gamma = (np.mean(delta_E_mh_gamma) - np.mean(delta_E_mh_st2084)) / np.mean(delta_E_mh_gamma) * 100
|
| 381 |
+
if st2084_vs_gamma > 0:
|
| 382 |
+
LOGGER.info(" Multi-Head ST.2084 vs Gamma: %.1f%% BETTER", st2084_vs_gamma)
|
| 383 |
+
else:
|
| 384 |
+
LOGGER.info(" Multi-Head ST.2084 vs Gamma: %.1f%% WORSE", -st2084_vs_gamma)
|
| 385 |
+
|
| 386 |
+
LOGGER.info("\n" + "=" * 80)
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
if __name__ == "__main__":
|
| 390 |
+
main()
|
learning_munsell/comparison/to_xyY/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Comparison scripts for Munsell to xyY conversion models."""
|
learning_munsell/comparison/to_xyY/compare_all_models.py
ADDED
|
@@ -0,0 +1,617 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Compare all ML models for Munsell to xyY conversion on real Munsell data.
|
| 3 |
+
|
| 4 |
+
Models to compare:
|
| 5 |
+
1. Simple MLP Approximator
|
| 6 |
+
2. Multi-Head MLP
|
| 7 |
+
3. Multi-Head MLP (Optimized) - with hyperparameter optimization
|
| 8 |
+
4. Multi-Head + Multi-Error Predictor
|
| 9 |
+
5. Multi-MLP - 3 independent branches
|
| 10 |
+
6. Multi-MLP (Optimized) - 3 independent branches with optimized hyperparameters
|
| 11 |
+
7. Multi-MLP + Error Predictor
|
| 12 |
+
8. Multi-MLP + Multi-Error Predictor
|
| 13 |
+
9. Multi-MLP (Optimized) + Multi-Error Predictor (Optimized)
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
import logging
|
| 19 |
+
import time
|
| 20 |
+
import warnings
|
| 21 |
+
from typing import TYPE_CHECKING
|
| 22 |
+
|
| 23 |
+
import numpy as np
|
| 24 |
+
import onnxruntime as ort
|
| 25 |
+
from colour import XYZ_to_Lab, xyY_to_XYZ
|
| 26 |
+
from colour.difference import delta_E_CIE2000
|
| 27 |
+
from colour.notation.datasets.munsell import MUNSELL_COLOURS_REAL
|
| 28 |
+
from colour.notation.munsell import (
|
| 29 |
+
CCS_ILLUMINANT_MUNSELL,
|
| 30 |
+
munsell_colour_to_munsell_specification,
|
| 31 |
+
munsell_specification_to_xyY,
|
| 32 |
+
)
|
| 33 |
+
from numpy.typing import NDArray # noqa: TC002
|
| 34 |
+
|
| 35 |
+
from learning_munsell import PROJECT_ROOT
|
| 36 |
+
from learning_munsell.utilities.common import (
|
| 37 |
+
benchmark_inference_speed,
|
| 38 |
+
generate_html_report_footer,
|
| 39 |
+
generate_html_report_header,
|
| 40 |
+
generate_ranking_section,
|
| 41 |
+
get_model_size_mb,
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
if TYPE_CHECKING:
|
| 45 |
+
from pathlib import Path
|
| 46 |
+
|
| 47 |
+
logging.basicConfig(level=logging.INFO, format="%(message)s")
|
| 48 |
+
LOGGER = logging.getLogger(__name__)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def normalize_munsell(munsell: np.ndarray) -> np.ndarray:
|
| 52 |
+
"""Normalize Munsell specs to [0, 1] range."""
|
| 53 |
+
normalized = munsell.copy()
|
| 54 |
+
normalized[..., 0] = munsell[..., 0] / 10.0 # Hue (in decade)
|
| 55 |
+
normalized[..., 1] = munsell[..., 1] / 10.0 # Value
|
| 56 |
+
normalized[..., 2] = munsell[..., 2] / 50.0 # Chroma
|
| 57 |
+
normalized[..., 3] = munsell[..., 3] / 10.0 # Code
|
| 58 |
+
return normalized.astype(np.float32)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def evaluate_model(
|
| 62 |
+
session: ort.InferenceSession,
|
| 63 |
+
X_norm: np.ndarray,
|
| 64 |
+
ground_truth: np.ndarray,
|
| 65 |
+
input_name: str = "munsell_normalized",
|
| 66 |
+
) -> dict:
|
| 67 |
+
"""Evaluate a single model."""
|
| 68 |
+
pred = session.run(None, {input_name: X_norm})[0]
|
| 69 |
+
errors = np.abs(pred - ground_truth)
|
| 70 |
+
|
| 71 |
+
return {
|
| 72 |
+
"x_mae": np.mean(errors[:, 0]),
|
| 73 |
+
"y_mae": np.mean(errors[:, 1]),
|
| 74 |
+
"Y_mae": np.mean(errors[:, 2]),
|
| 75 |
+
"predictions": pred,
|
| 76 |
+
"errors": errors,
|
| 77 |
+
"max_errors": np.max(errors, axis=1),
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def compute_delta_E(
|
| 82 |
+
ml_predictions: np.ndarray,
|
| 83 |
+
reference_xyY: np.ndarray,
|
| 84 |
+
) -> float:
|
| 85 |
+
"""Compute Delta-E CIE2000 between ML predictions and reference xyY (ground truth)."""
|
| 86 |
+
delta_E_values = []
|
| 87 |
+
|
| 88 |
+
for ml_xyY, ref_xyY in zip(ml_predictions, reference_xyY, strict=False):
|
| 89 |
+
try:
|
| 90 |
+
ml_XYZ = xyY_to_XYZ(ml_xyY)
|
| 91 |
+
ml_Lab = XYZ_to_Lab(ml_XYZ, CCS_ILLUMINANT_MUNSELL)
|
| 92 |
+
|
| 93 |
+
ref_XYZ = xyY_to_XYZ(ref_xyY)
|
| 94 |
+
ref_Lab = XYZ_to_Lab(ref_XYZ, CCS_ILLUMINANT_MUNSELL)
|
| 95 |
+
|
| 96 |
+
delta_E = delta_E_CIE2000(ref_Lab, ml_Lab)
|
| 97 |
+
if not np.isnan(delta_E):
|
| 98 |
+
delta_E_values.append(delta_E)
|
| 99 |
+
except (RuntimeError, ValueError):
|
| 100 |
+
continue
|
| 101 |
+
|
| 102 |
+
return np.mean(delta_E_values) if delta_E_values else np.nan
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def generate_html_report(
|
| 106 |
+
results: dict,
|
| 107 |
+
num_samples: int,
|
| 108 |
+
output_file: Path,
|
| 109 |
+
baseline_inference_time_ms: float,
|
| 110 |
+
) -> None:
|
| 111 |
+
"""Generate HTML report with visualizations."""
|
| 112 |
+
# Calculate average MAE
|
| 113 |
+
avg_maes = {}
|
| 114 |
+
for model_name, result in results.items():
|
| 115 |
+
avg_maes[model_name] = np.mean(
|
| 116 |
+
[
|
| 117 |
+
result["x_mae"],
|
| 118 |
+
result["y_mae"],
|
| 119 |
+
result["Y_mae"],
|
| 120 |
+
]
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
# Sort by average MAE
|
| 124 |
+
sorted_models = sorted(avg_maes.items(), key=lambda x: x[1])
|
| 125 |
+
|
| 126 |
+
# Start HTML
|
| 127 |
+
html = generate_html_report_header(
|
| 128 |
+
title="ML Model Comparison Report",
|
| 129 |
+
subtitle="Munsell to xyY Conversion",
|
| 130 |
+
num_samples=num_samples,
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
# Best Models Summary
|
| 134 |
+
best_size = min(results.items(), key=lambda x: x[1]["model_size_mb"])[0]
|
| 135 |
+
best_speed = min(results.items(), key=lambda x: x[1]["inference_time_ms"])[0]
|
| 136 |
+
best_avg = sorted_models[0][0]
|
| 137 |
+
|
| 138 |
+
# Find best Delta-E
|
| 139 |
+
delta_E_results = [
|
| 140 |
+
(n, r["delta_E"]) for n, r in results.items() if not np.isnan(r["delta_E"])
|
| 141 |
+
]
|
| 142 |
+
best_delta_E = (
|
| 143 |
+
min(delta_E_results, key=lambda x: x[1])[0] if delta_E_results else None
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
html += f"""
|
| 147 |
+
<!-- Best Models Summary -->
|
| 148 |
+
<div class="bg-card rounded-lg border border-border p-6 shadow-lg">
|
| 149 |
+
<h2 class="text-2xl font-semibold mb-6 pb-3 border-b border-primary/30">Best Models by Metric</h2>
|
| 150 |
+
<div class="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-4">
|
| 151 |
+
<div class="bg-gradient-to-br from-primary/10 to-primary/5 rounded-lg p-5 border border-primary/20">
|
| 152 |
+
<div class="text-xs font-semibold text-muted-foreground uppercase tracking-wide mb-2">Smallest Size</div>
|
| 153 |
+
<div class="text-3xl font-bold text-primary mb-3">{results[best_size]["model_size_mb"]:.2f} MB</div>
|
| 154 |
+
<div class="text-sm text-foreground/80">{best_size}</div>
|
| 155 |
+
</div>
|
| 156 |
+
<div class="bg-gradient-to-br from-primary/10 to-primary/5 rounded-lg p-5 border border-primary/20">
|
| 157 |
+
<div class="text-xs font-semibold text-muted-foreground uppercase tracking-wide mb-2">Fastest Speed</div>
|
| 158 |
+
<div class="text-3xl font-bold text-primary mb-3">{results[best_speed]["inference_time_ms"]:.4f} ms</div>
|
| 159 |
+
<div class="text-sm text-foreground/80">{best_speed}</div>
|
| 160 |
+
</div>
|
| 161 |
+
"""
|
| 162 |
+
|
| 163 |
+
if best_delta_E:
|
| 164 |
+
html += f"""
|
| 165 |
+
<div class="bg-gradient-to-br from-primary/10 to-primary/5 rounded-lg p-5 border border-primary/20">
|
| 166 |
+
<div class="text-xs font-semibold text-muted-foreground uppercase tracking-wide mb-2">Best Delta-E</div>
|
| 167 |
+
<div class="text-3xl font-bold text-primary mb-3">{results[best_delta_E]["delta_E"]:.6f}</div>
|
| 168 |
+
<div class="text-sm text-foreground/80">{best_delta_E}</div>
|
| 169 |
+
</div>
|
| 170 |
+
"""
|
| 171 |
+
|
| 172 |
+
html += f"""
|
| 173 |
+
<div class="bg-gradient-to-br from-primary/10 to-primary/5 rounded-lg p-5 border border-primary/20">
|
| 174 |
+
<div class="text-xs font-semibold text-muted-foreground uppercase tracking-wide mb-2">Best Average MAE</div>
|
| 175 |
+
<div class="text-3xl font-bold text-primary mb-3">{avg_maes[best_avg]:.6f}</div>
|
| 176 |
+
<div class="text-sm text-foreground/80">{best_avg}</div>
|
| 177 |
+
</div>
|
| 178 |
+
</div>
|
| 179 |
+
</div>
|
| 180 |
+
"""
|
| 181 |
+
|
| 182 |
+
# Performance Metrics Table
|
| 183 |
+
sorted_by_avg_mae = sorted(results.items(), key=lambda x: avg_maes[x[0]])
|
| 184 |
+
|
| 185 |
+
html += """
|
| 186 |
+
<!-- Performance Metrics Table -->
|
| 187 |
+
<div class="bg-card rounded-lg border border-border p-6 shadow-lg">
|
| 188 |
+
<h2 class="text-2xl font-semibold mb-6 pb-3 border-b border-primary/30">Model Performance Metrics</h2>
|
| 189 |
+
<div class="overflow-x-auto">
|
| 190 |
+
<table class="w-full text-sm">
|
| 191 |
+
<thead>
|
| 192 |
+
<tr class="border-b border-border">
|
| 193 |
+
<th class="text-left py-3 px-4 font-semibold text-muted-foreground">Model</th>
|
| 194 |
+
<th class="text-right py-3 px-4 font-semibold text-muted-foreground">Size (MB)</th>
|
| 195 |
+
<th class="text-right py-3 px-4 font-semibold text-muted-foreground">Speed (ms/sample)</th>
|
| 196 |
+
<th class="text-right py-3 px-4 font-semibold text-muted-foreground">vs Baseline</th>
|
| 197 |
+
<th class="text-right py-3 px-4 font-semibold text-muted-foreground">MAE x</th>
|
| 198 |
+
<th class="text-right py-3 px-4 font-semibold text-muted-foreground">MAE y</th>
|
| 199 |
+
<th class="text-right py-3 px-4 font-semibold text-muted-foreground">MAE Y</th>
|
| 200 |
+
<th class="text-right py-3 px-4 font-semibold text-muted-foreground">Delta-E</th>
|
| 201 |
+
</tr>
|
| 202 |
+
</thead>
|
| 203 |
+
<tbody>
|
| 204 |
+
"""
|
| 205 |
+
|
| 206 |
+
for model_name, result in sorted_by_avg_mae:
|
| 207 |
+
size_mb = result["model_size_mb"]
|
| 208 |
+
speed_ms = result["inference_time_ms"]
|
| 209 |
+
delta_E = result["delta_E"]
|
| 210 |
+
|
| 211 |
+
# Calculate speedup vs baseline
|
| 212 |
+
speedup = baseline_inference_time_ms / speed_ms if speed_ms > 0 else 0
|
| 213 |
+
|
| 214 |
+
size_class = "text-primary font-semibold" if model_name == best_size else ""
|
| 215 |
+
speed_class = "text-primary font-semibold" if model_name == best_speed else ""
|
| 216 |
+
delta_E_class = (
|
| 217 |
+
"text-primary font-semibold" if model_name == best_delta_E else ""
|
| 218 |
+
)
|
| 219 |
+
|
| 220 |
+
delta_E_str = f"{delta_E:.6f}" if not np.isnan(delta_E) else "—"
|
| 221 |
+
|
| 222 |
+
speedup_text = f"{speedup:.0f}x" if speedup > 100 else f"{speedup:.1f}x"
|
| 223 |
+
|
| 224 |
+
html += f"""
|
| 225 |
+
<tr class="border-b border-border/50 hover:bg-muted/30 transition-colors">
|
| 226 |
+
<td class="py-3 px-4 font-medium">{model_name}</td>
|
| 227 |
+
<td class="py-3 px-4 text-right {size_class}">{size_mb:.2f}</td>
|
| 228 |
+
<td class="py-3 px-4 text-right {speed_class}">{speed_ms:.4f}</td>
|
| 229 |
+
<td class="py-3 px-4 text-right text-primary font-semibold">{speedup_text}</td>
|
| 230 |
+
<td class="py-3 px-4 text-right">{result["x_mae"]:.6f}</td>
|
| 231 |
+
<td class="py-3 px-4 text-right">{result["y_mae"]:.6f}</td>
|
| 232 |
+
<td class="py-3 px-4 text-right">{result["Y_mae"]:.6f}</td>
|
| 233 |
+
<td class="py-3 px-4 text-right {delta_E_class}">{delta_E_str}</td>
|
| 234 |
+
</tr>
|
| 235 |
+
"""
|
| 236 |
+
|
| 237 |
+
html += """
|
| 238 |
+
</tbody>
|
| 239 |
+
</table>
|
| 240 |
+
</div>
|
| 241 |
+
</div>
|
| 242 |
+
"""
|
| 243 |
+
|
| 244 |
+
# Add ranking section
|
| 245 |
+
html += generate_ranking_section(
|
| 246 |
+
results,
|
| 247 |
+
metric_key="avg_mae",
|
| 248 |
+
title="Overall Ranking (by Average MAE)",
|
| 249 |
+
)
|
| 250 |
+
|
| 251 |
+
# Precision thresholds
|
| 252 |
+
thresholds = [1e-5, 1e-4, 1e-3, 1e-2, 1e-1]
|
| 253 |
+
|
| 254 |
+
html += """
|
| 255 |
+
<div class="bg-card rounded-lg border border-border p-6 shadow-lg">
|
| 256 |
+
<h2 class="text-2xl font-semibold mb-3 pb-3 border-b border-primary/30">Accuracy at Precision Thresholds</h2>
|
| 257 |
+
<p class="text-sm text-muted-foreground mb-6">Percentage of predictions where max error across all components is below threshold:</p>
|
| 258 |
+
<div class="overflow-x-auto">
|
| 259 |
+
<table class="w-full text-sm">
|
| 260 |
+
<thead>
|
| 261 |
+
<tr class="border-b border-border">
|
| 262 |
+
<th class="text-left py-3 px-4 font-semibold text-muted-foreground">Model</th>
|
| 263 |
+
"""
|
| 264 |
+
|
| 265 |
+
for threshold in thresholds:
|
| 266 |
+
html += f' <th class="text-right py-3 px-4 font-semibold text-muted-foreground">< {threshold:.0e}</th>\n'
|
| 267 |
+
|
| 268 |
+
html += """
|
| 269 |
+
</tr>
|
| 270 |
+
</thead>
|
| 271 |
+
<tbody>
|
| 272 |
+
"""
|
| 273 |
+
|
| 274 |
+
for model_name, _ in sorted_models:
|
| 275 |
+
result = results[model_name]
|
| 276 |
+
html += f"""
|
| 277 |
+
<tr class="border-b border-border hover:bg-muted/30 transition-colors">
|
| 278 |
+
<td class="text-left py-3 px-4 font-medium">{model_name}</td>
|
| 279 |
+
"""
|
| 280 |
+
for threshold in thresholds:
|
| 281 |
+
accuracy_pct = np.mean(result["max_errors"] < threshold) * 100
|
| 282 |
+
html += f' <td class="text-right py-3 px-4">{accuracy_pct:.2f}%</td>\n'
|
| 283 |
+
|
| 284 |
+
html += """
|
| 285 |
+
</tr>
|
| 286 |
+
"""
|
| 287 |
+
|
| 288 |
+
html += """
|
| 289 |
+
</tbody>
|
| 290 |
+
</table>
|
| 291 |
+
</div>
|
| 292 |
+
</div>
|
| 293 |
+
"""
|
| 294 |
+
|
| 295 |
+
html += generate_html_report_footer()
|
| 296 |
+
|
| 297 |
+
# Write HTML file
|
| 298 |
+
with open(output_file, "w") as f:
|
| 299 |
+
f.write(html)
|
| 300 |
+
|
| 301 |
+
LOGGER.info("")
|
| 302 |
+
LOGGER.info("HTML report saved to: %s", output_file)
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
def main() -> None:
|
| 306 |
+
"""Compare all models."""
|
| 307 |
+
LOGGER.info("=" * 80)
|
| 308 |
+
LOGGER.info("Munsell to xyY Model Comparison")
|
| 309 |
+
LOGGER.info("=" * 80)
|
| 310 |
+
|
| 311 |
+
# Paths
|
| 312 |
+
model_directory = PROJECT_ROOT / "models" / "to_xyY"
|
| 313 |
+
|
| 314 |
+
# Load real Munsell dataset
|
| 315 |
+
LOGGER.info("")
|
| 316 |
+
LOGGER.info("Loading real Munsell dataset...")
|
| 317 |
+
munsell_specs = []
|
| 318 |
+
xyY_ground_truth = []
|
| 319 |
+
|
| 320 |
+
for munsell_spec_tuple, xyY in MUNSELL_COLOURS_REAL:
|
| 321 |
+
try:
|
| 322 |
+
hue_code, value, chroma = munsell_spec_tuple
|
| 323 |
+
munsell_str = f"{hue_code} {value}/{chroma}"
|
| 324 |
+
spec = munsell_colour_to_munsell_specification(munsell_str)
|
| 325 |
+
xyY_scaled = np.array([xyY[0], xyY[1], xyY[2] / 100.0])
|
| 326 |
+
munsell_specs.append(spec)
|
| 327 |
+
xyY_ground_truth.append(xyY_scaled)
|
| 328 |
+
except Exception: # noqa: BLE001, S112
|
| 329 |
+
continue
|
| 330 |
+
|
| 331 |
+
munsell_specs = np.array(munsell_specs, dtype=np.float32)
|
| 332 |
+
xyY_ground_truth = np.array(xyY_ground_truth, dtype=np.float32)
|
| 333 |
+
LOGGER.info("Loaded %d valid Munsell colors", len(munsell_specs))
|
| 334 |
+
|
| 335 |
+
# Normalize inputs
|
| 336 |
+
munsell_normalized = normalize_munsell(munsell_specs)
|
| 337 |
+
|
| 338 |
+
# Benchmark colour library first
|
| 339 |
+
LOGGER.info("")
|
| 340 |
+
LOGGER.info("=" * 80)
|
| 341 |
+
LOGGER.info("Colour Library (munsell_specification_to_xyY)")
|
| 342 |
+
LOGGER.info("=" * 80)
|
| 343 |
+
|
| 344 |
+
# Benchmark the munsell_specification_to_xyY function
|
| 345 |
+
# Note: Using full dataset (100% of samples)
|
| 346 |
+
|
| 347 |
+
# Set random seed for reproducibility
|
| 348 |
+
np.random.seed(42)
|
| 349 |
+
|
| 350 |
+
# Use 100% of samples for comprehensive benchmarking
|
| 351 |
+
sampled_indices = np.arange(len(munsell_specs))
|
| 352 |
+
munsell_benchmark = munsell_specs[sampled_indices]
|
| 353 |
+
|
| 354 |
+
start_time = time.perf_counter()
|
| 355 |
+
colour_predictions = []
|
| 356 |
+
successful_inferences = 0
|
| 357 |
+
|
| 358 |
+
with warnings.catch_warnings():
|
| 359 |
+
warnings.simplefilter("ignore")
|
| 360 |
+
for spec in munsell_benchmark:
|
| 361 |
+
try:
|
| 362 |
+
xyY = munsell_specification_to_xyY(spec)
|
| 363 |
+
colour_predictions.append(xyY)
|
| 364 |
+
successful_inferences += 1
|
| 365 |
+
except (RuntimeError, ValueError):
|
| 366 |
+
colour_predictions.append(np.array([np.nan, np.nan, np.nan]))
|
| 367 |
+
|
| 368 |
+
end_time = time.perf_counter()
|
| 369 |
+
|
| 370 |
+
total_time_s = end_time - start_time
|
| 371 |
+
baseline_inference_time_ms = (
|
| 372 |
+
(total_time_s / successful_inferences) * 1000
|
| 373 |
+
if successful_inferences > 0
|
| 374 |
+
else 0
|
| 375 |
+
)
|
| 376 |
+
colour_predictions = np.array(colour_predictions)
|
| 377 |
+
|
| 378 |
+
LOGGER.info(" Successful inferences: %d", successful_inferences)
|
| 379 |
+
LOGGER.info(" Inference Speed: %.4f ms/sample", baseline_inference_time_ms)
|
| 380 |
+
|
| 381 |
+
# Define models to compare
|
| 382 |
+
models = [
|
| 383 |
+
{
|
| 384 |
+
"name": "Simple MLP",
|
| 385 |
+
"files": [model_directory / "munsell_to_xyY_approximator.onnx"],
|
| 386 |
+
"params_file": model_directory
|
| 387 |
+
/ "munsell_to_xyY_approximator_normalization_params.npz",
|
| 388 |
+
"type": "single",
|
| 389 |
+
},
|
| 390 |
+
{
|
| 391 |
+
"name": "Multi-Head",
|
| 392 |
+
"files": [model_directory / "multi_head.onnx"],
|
| 393 |
+
"params_file": model_directory / "multi_head_normalization_params.npz",
|
| 394 |
+
"type": "single",
|
| 395 |
+
},
|
| 396 |
+
{
|
| 397 |
+
"name": "Multi-Head (Optimized)",
|
| 398 |
+
"files": [model_directory / "multi_head_optimized.onnx"],
|
| 399 |
+
"params_file": model_directory
|
| 400 |
+
/ "multi_head_optimized_normalization_params.npz",
|
| 401 |
+
"type": "single",
|
| 402 |
+
},
|
| 403 |
+
{
|
| 404 |
+
"name": "Multi-Head + Multi-Error Predictor",
|
| 405 |
+
"files": [
|
| 406 |
+
model_directory / "multi_head.onnx",
|
| 407 |
+
model_directory / "multi_head_multi_error_predictor.onnx",
|
| 408 |
+
],
|
| 409 |
+
"params_file": model_directory
|
| 410 |
+
/ "multi_head_multi_error_predictor_normalization_params.npz",
|
| 411 |
+
"type": "two_stage",
|
| 412 |
+
},
|
| 413 |
+
{
|
| 414 |
+
"name": "Multi-MLP",
|
| 415 |
+
"files": [model_directory / "multi_mlp.onnx"],
|
| 416 |
+
"params_file": model_directory / "multi_mlp_normalization_params.npz",
|
| 417 |
+
"type": "single",
|
| 418 |
+
},
|
| 419 |
+
{
|
| 420 |
+
"name": "Multi-MLP (Optimized)",
|
| 421 |
+
"files": [model_directory / "multi_mlp_optimized.onnx"],
|
| 422 |
+
"params_file": model_directory
|
| 423 |
+
/ "multi_mlp_optimized_normalization_params.npz",
|
| 424 |
+
"type": "single",
|
| 425 |
+
},
|
| 426 |
+
{
|
| 427 |
+
"name": "Multi-MLP + Error Predictor",
|
| 428 |
+
"files": [
|
| 429 |
+
model_directory / "multi_mlp.onnx",
|
| 430 |
+
model_directory / "multi_mlp_error_predictor.onnx",
|
| 431 |
+
],
|
| 432 |
+
"params_file": model_directory
|
| 433 |
+
/ "multi_mlp_error_predictor_normalization_params.npz",
|
| 434 |
+
"type": "two_stage",
|
| 435 |
+
},
|
| 436 |
+
{
|
| 437 |
+
"name": "Multi-MLP + Multi-Error Predictor",
|
| 438 |
+
"files": [
|
| 439 |
+
model_directory / "multi_mlp.onnx",
|
| 440 |
+
model_directory / "multi_mlp_multi_error_predictor.onnx",
|
| 441 |
+
],
|
| 442 |
+
"params_file": model_directory
|
| 443 |
+
/ "multi_mlp_multi_error_predictor_normalization_params.npz",
|
| 444 |
+
"type": "two_stage",
|
| 445 |
+
},
|
| 446 |
+
{
|
| 447 |
+
"name": "Multi-MLP (Optimized) + Multi-Error Predictor (Optimized)",
|
| 448 |
+
"files": [
|
| 449 |
+
model_directory / "multi_mlp_optimized.onnx",
|
| 450 |
+
model_directory / "multi_mlp_multi_error_predictor_optimized.onnx",
|
| 451 |
+
],
|
| 452 |
+
"params_file": model_directory
|
| 453 |
+
/ "multi_mlp_multi_error_predictor_optimized_normalization_params.npz",
|
| 454 |
+
"type": "two_stage",
|
| 455 |
+
},
|
| 456 |
+
]
|
| 457 |
+
|
| 458 |
+
# Evaluate each model
|
| 459 |
+
results = {}
|
| 460 |
+
|
| 461 |
+
for model_info in models:
|
| 462 |
+
model_name = model_info["name"]
|
| 463 |
+
LOGGER.info("")
|
| 464 |
+
LOGGER.info("=" * 80)
|
| 465 |
+
LOGGER.info(model_name)
|
| 466 |
+
LOGGER.info("=" * 80)
|
| 467 |
+
|
| 468 |
+
# Calculate model size
|
| 469 |
+
model_size_mb = get_model_size_mb(model_info["files"])
|
| 470 |
+
|
| 471 |
+
if model_info["type"] == "two_stage":
|
| 472 |
+
# Two-stage model
|
| 473 |
+
base_session = ort.InferenceSession(str(model_info["files"][0]))
|
| 474 |
+
error_session = ort.InferenceSession(str(model_info["files"][1]))
|
| 475 |
+
error_input_name = error_session.get_inputs()[0].name
|
| 476 |
+
|
| 477 |
+
# Define inference callable
|
| 478 |
+
def two_stage_inference(
|
| 479 |
+
_base_session: ort.InferenceSession = base_session,
|
| 480 |
+
_error_session: ort.InferenceSession = error_session,
|
| 481 |
+
_munsell_normalized: NDArray = munsell_normalized,
|
| 482 |
+
_error_input_name: str = error_input_name,
|
| 483 |
+
) -> NDArray:
|
| 484 |
+
base_pred = _base_session.run(
|
| 485 |
+
None, {"munsell_normalized": _munsell_normalized}
|
| 486 |
+
)[0]
|
| 487 |
+
combined = np.concatenate(
|
| 488 |
+
[_munsell_normalized, base_pred], axis=1
|
| 489 |
+
).astype(np.float32)
|
| 490 |
+
error_corr = _error_session.run(
|
| 491 |
+
None, {_error_input_name: combined}
|
| 492 |
+
)[0]
|
| 493 |
+
return base_pred + error_corr
|
| 494 |
+
|
| 495 |
+
# Benchmark speed
|
| 496 |
+
inference_time_ms = benchmark_inference_speed(
|
| 497 |
+
two_stage_inference, munsell_normalized
|
| 498 |
+
)
|
| 499 |
+
|
| 500 |
+
# Get predictions
|
| 501 |
+
base_pred = base_session.run(
|
| 502 |
+
None, {"munsell_normalized": munsell_normalized}
|
| 503 |
+
)[0]
|
| 504 |
+
combined = np.concatenate(
|
| 505 |
+
[munsell_normalized, base_pred], axis=1
|
| 506 |
+
).astype(np.float32)
|
| 507 |
+
error_corr = error_session.run(
|
| 508 |
+
None, {error_input_name: combined}
|
| 509 |
+
)[0]
|
| 510 |
+
pred = base_pred + error_corr
|
| 511 |
+
|
| 512 |
+
errors = np.abs(pred - xyY_ground_truth)
|
| 513 |
+
result = {
|
| 514 |
+
"x_mae": np.mean(errors[:, 0]),
|
| 515 |
+
"y_mae": np.mean(errors[:, 1]),
|
| 516 |
+
"Y_mae": np.mean(errors[:, 2]),
|
| 517 |
+
"predictions": pred,
|
| 518 |
+
"errors": errors,
|
| 519 |
+
"max_errors": np.max(errors, axis=1),
|
| 520 |
+
}
|
| 521 |
+
else:
|
| 522 |
+
# Single model
|
| 523 |
+
session = ort.InferenceSession(str(model_info["files"][0]))
|
| 524 |
+
|
| 525 |
+
# Define inference callable
|
| 526 |
+
def single_inference(
|
| 527 |
+
_session: ort.InferenceSession = session,
|
| 528 |
+
_munsell_normalized: NDArray = munsell_normalized,
|
| 529 |
+
) -> NDArray:
|
| 530 |
+
return _session.run(
|
| 531 |
+
None, {"munsell_normalized": _munsell_normalized}
|
| 532 |
+
)[0]
|
| 533 |
+
|
| 534 |
+
# Benchmark speed
|
| 535 |
+
inference_time_ms = benchmark_inference_speed(
|
| 536 |
+
single_inference, munsell_normalized
|
| 537 |
+
)
|
| 538 |
+
|
| 539 |
+
result = evaluate_model(session, munsell_normalized, xyY_ground_truth)
|
| 540 |
+
|
| 541 |
+
result["model_size_mb"] = model_size_mb
|
| 542 |
+
result["inference_time_ms"] = inference_time_ms
|
| 543 |
+
result["avg_mae"] = np.mean(
|
| 544 |
+
[result["x_mae"], result["y_mae"], result["Y_mae"]]
|
| 545 |
+
)
|
| 546 |
+
|
| 547 |
+
# Compute Delta-E against ground truth (measured xyY)
|
| 548 |
+
sampled_predictions = result["predictions"][sampled_indices]
|
| 549 |
+
result["delta_E"] = compute_delta_E(
|
| 550 |
+
sampled_predictions,
|
| 551 |
+
xyY_ground_truth,
|
| 552 |
+
)
|
| 553 |
+
|
| 554 |
+
results[model_name] = result
|
| 555 |
+
|
| 556 |
+
# Print results
|
| 557 |
+
LOGGER.info("")
|
| 558 |
+
LOGGER.info("Mean Absolute Errors:")
|
| 559 |
+
LOGGER.info(" x: %.6f", result["x_mae"])
|
| 560 |
+
LOGGER.info(" y: %.6f", result["y_mae"])
|
| 561 |
+
LOGGER.info(" Y: %.6f", result["Y_mae"])
|
| 562 |
+
if not np.isnan(result["delta_E"]):
|
| 563 |
+
LOGGER.info(" Delta-E (vs Ground Truth): %.6f", result["delta_E"])
|
| 564 |
+
LOGGER.info("")
|
| 565 |
+
LOGGER.info("Performance Metrics:")
|
| 566 |
+
LOGGER.info(" Model Size: %.2f MB", result["model_size_mb"])
|
| 567 |
+
LOGGER.info(
|
| 568 |
+
" Inference Speed: %.4f ms/sample", result["inference_time_ms"]
|
| 569 |
+
)
|
| 570 |
+
LOGGER.info(
|
| 571 |
+
" Speedup vs Colour: %.1fx",
|
| 572 |
+
baseline_inference_time_ms / inference_time_ms,
|
| 573 |
+
)
|
| 574 |
+
|
| 575 |
+
|
| 576 |
+
# Summary
|
| 577 |
+
LOGGER.info("")
|
| 578 |
+
LOGGER.info("=" * 80)
|
| 579 |
+
LOGGER.info("SUMMARY COMPARISON")
|
| 580 |
+
LOGGER.info("=" * 80)
|
| 581 |
+
LOGGER.info("")
|
| 582 |
+
|
| 583 |
+
if not results:
|
| 584 |
+
LOGGER.info("No models were successfully evaluated")
|
| 585 |
+
return
|
| 586 |
+
|
| 587 |
+
# MAE comparison table
|
| 588 |
+
LOGGER.info("Mean Absolute Error Comparison:")
|
| 589 |
+
LOGGER.info("")
|
| 590 |
+
header = f"{'Model':<40} {'x':>10} {'y':>10} {'Y':>10} {'Delta-E':>12}"
|
| 591 |
+
LOGGER.info(header)
|
| 592 |
+
LOGGER.info("-" * 85)
|
| 593 |
+
|
| 594 |
+
for model_name, result in results.items():
|
| 595 |
+
delta_E_str = (
|
| 596 |
+
f"{result['delta_E']:.6f}" if not np.isnan(result["delta_E"]) else "N/A"
|
| 597 |
+
)
|
| 598 |
+
LOGGER.info(
|
| 599 |
+
"%-40s %10.6f %10.6f %10.6f %12s",
|
| 600 |
+
model_name,
|
| 601 |
+
result["x_mae"],
|
| 602 |
+
result["y_mae"],
|
| 603 |
+
result["Y_mae"],
|
| 604 |
+
delta_E_str,
|
| 605 |
+
)
|
| 606 |
+
|
| 607 |
+
# Generate HTML report
|
| 608 |
+
report_dir = PROJECT_ROOT / "reports" / "to_xyY"
|
| 609 |
+
report_dir.mkdir(parents=True, exist_ok=True)
|
| 610 |
+
report_file = report_dir / "model_comparison.html"
|
| 611 |
+
generate_html_report(
|
| 612 |
+
results, len(munsell_specs), report_file, baseline_inference_time_ms
|
| 613 |
+
)
|
| 614 |
+
|
| 615 |
+
|
| 616 |
+
if __name__ == "__main__":
|
| 617 |
+
main()
|
learning_munsell/data_generation/generate_training_data.py
ADDED
|
@@ -0,0 +1,310 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Generate training data for ML-based xyY to Munsell conversion.
|
| 3 |
+
|
| 4 |
+
Generates samples by sampling in Munsell space and converting to xyY via
|
| 5 |
+
forward conversion, guaranteeing 100% valid samples.
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
uv run python -m learning_munsell.data_generation.generate_training_data
|
| 9 |
+
uv run python -m learning_munsell.data_generation.generate_training_data \\
|
| 10 |
+
--n-samples 2000000 --perturbation 0.10 --output training_data_large
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import argparse
|
| 14 |
+
import json
|
| 15 |
+
import logging
|
| 16 |
+
import multiprocessing as mp
|
| 17 |
+
import warnings
|
| 18 |
+
from datetime import datetime, timezone
|
| 19 |
+
|
| 20 |
+
import numpy as np
|
| 21 |
+
from colour.notation.datasets.munsell import MUNSELL_COLOURS_ALL
|
| 22 |
+
from colour.notation.munsell import (
|
| 23 |
+
munsell_colour_to_munsell_specification,
|
| 24 |
+
munsell_specification_to_xyY,
|
| 25 |
+
)
|
| 26 |
+
from colour.utilities import ColourUsageWarning
|
| 27 |
+
from numpy.typing import NDArray
|
| 28 |
+
from sklearn.model_selection import train_test_split
|
| 29 |
+
|
| 30 |
+
from learning_munsell import PROJECT_ROOT
|
| 31 |
+
|
| 32 |
+
logging.basicConfig(level=logging.INFO, format="%(message)s")
|
| 33 |
+
LOGGER = logging.getLogger(__name__)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def _worker_generate_samples(
|
| 37 |
+
args: tuple[int, NDArray, int, float],
|
| 38 |
+
) -> tuple[list[NDArray], list[NDArray]]:
|
| 39 |
+
"""
|
| 40 |
+
Worker function to generate samples in parallel.
|
| 41 |
+
|
| 42 |
+
Parameters
|
| 43 |
+
----------
|
| 44 |
+
args : tuple
|
| 45 |
+
- worker_id: Worker identifier
|
| 46 |
+
- base_specs: Array of base Munsell specifications
|
| 47 |
+
- samples_per_base: Number of samples to generate per base color
|
| 48 |
+
- perturbation_pct: Perturbation percentage
|
| 49 |
+
|
| 50 |
+
Returns
|
| 51 |
+
-------
|
| 52 |
+
tuple
|
| 53 |
+
- xyY_samples: List of xyY arrays
|
| 54 |
+
- munsell_samples: List of Munsell specification arrays
|
| 55 |
+
"""
|
| 56 |
+
worker_id, base_specs, samples_per_base, perturbation_pct = args
|
| 57 |
+
|
| 58 |
+
np.random.seed(42 + worker_id)
|
| 59 |
+
|
| 60 |
+
warnings.filterwarnings("ignore", category=ColourUsageWarning)
|
| 61 |
+
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
| 62 |
+
|
| 63 |
+
xyY_samples = []
|
| 64 |
+
munsell_samples = []
|
| 65 |
+
|
| 66 |
+
hue_range = 9.5
|
| 67 |
+
value_range = 9.0
|
| 68 |
+
chroma_range = 50.0
|
| 69 |
+
|
| 70 |
+
for base_spec in base_specs:
|
| 71 |
+
for _ in range(samples_per_base):
|
| 72 |
+
hue_delta = np.random.uniform(
|
| 73 |
+
-perturbation_pct * hue_range, perturbation_pct * hue_range
|
| 74 |
+
)
|
| 75 |
+
value_delta = np.random.uniform(
|
| 76 |
+
-perturbation_pct * value_range, perturbation_pct * value_range
|
| 77 |
+
)
|
| 78 |
+
chroma_delta = np.random.uniform(
|
| 79 |
+
-perturbation_pct * chroma_range, perturbation_pct * chroma_range
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
perturbed_spec = base_spec.copy()
|
| 83 |
+
perturbed_spec[0] = np.clip(base_spec[0] + hue_delta, 0.5, 10.0)
|
| 84 |
+
perturbed_spec[1] = np.clip(base_spec[1] + value_delta, 1.0, 10.0)
|
| 85 |
+
perturbed_spec[2] = np.clip(base_spec[2] + chroma_delta, 0.0, 50.0)
|
| 86 |
+
|
| 87 |
+
try:
|
| 88 |
+
xyY = munsell_specification_to_xyY(perturbed_spec)
|
| 89 |
+
xyY_samples.append(xyY)
|
| 90 |
+
munsell_samples.append(perturbed_spec)
|
| 91 |
+
except Exception: # noqa: BLE001, S110
|
| 92 |
+
continue
|
| 93 |
+
|
| 94 |
+
return xyY_samples, munsell_samples
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def generate_forward_munsell_samples(
|
| 98 |
+
n_samples: int = 500000,
|
| 99 |
+
perturbation_pct: float = 0.05,
|
| 100 |
+
n_workers: int | None = None,
|
| 101 |
+
) -> tuple[NDArray, NDArray]:
|
| 102 |
+
"""
|
| 103 |
+
Generate samples by sampling directly in Munsell space and converting to xyY.
|
| 104 |
+
|
| 105 |
+
Parameters
|
| 106 |
+
----------
|
| 107 |
+
n_samples : int
|
| 108 |
+
Target number of samples to generate.
|
| 109 |
+
perturbation_pct : float
|
| 110 |
+
Perturbation as percentage of valid range.
|
| 111 |
+
n_workers : int, optional
|
| 112 |
+
Number of parallel workers. Defaults to CPU count.
|
| 113 |
+
|
| 114 |
+
Returns
|
| 115 |
+
-------
|
| 116 |
+
tuple
|
| 117 |
+
- xyY_samples: Array of shape (n, 3) with xyY values
|
| 118 |
+
- munsell_samples: Array of shape (n, 4) with Munsell specifications
|
| 119 |
+
"""
|
| 120 |
+
if n_workers is None:
|
| 121 |
+
n_workers = mp.cpu_count()
|
| 122 |
+
|
| 123 |
+
LOGGER.info(
|
| 124 |
+
"Generating %d samples with %.0f%% perturbations using %d workers...",
|
| 125 |
+
n_samples,
|
| 126 |
+
perturbation_pct * 100,
|
| 127 |
+
n_workers,
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
# Extract base Munsell specifications
|
| 131 |
+
base_specs = []
|
| 132 |
+
for munsell_spec_tuple, _ in MUNSELL_COLOURS_ALL:
|
| 133 |
+
hue_code_str, value, chroma = munsell_spec_tuple
|
| 134 |
+
munsell_str = f"{hue_code_str} {value}/{chroma}"
|
| 135 |
+
spec = munsell_colour_to_munsell_specification(munsell_str)
|
| 136 |
+
base_specs.append(spec)
|
| 137 |
+
|
| 138 |
+
base_specs = np.array(base_specs)
|
| 139 |
+
samples_per_base = n_samples // len(base_specs) + 1
|
| 140 |
+
|
| 141 |
+
LOGGER.info("Using %d base Munsell colors", len(base_specs))
|
| 142 |
+
LOGGER.info("Generating ~%d samples per base color", samples_per_base)
|
| 143 |
+
|
| 144 |
+
# Split base specs across workers
|
| 145 |
+
specs_per_worker = len(base_specs) // n_workers
|
| 146 |
+
worker_args = []
|
| 147 |
+
|
| 148 |
+
for i in range(n_workers):
|
| 149 |
+
start_idx = i * specs_per_worker
|
| 150 |
+
end_idx = start_idx + specs_per_worker if i < n_workers - 1 else len(base_specs)
|
| 151 |
+
worker_specs = base_specs[start_idx:end_idx]
|
| 152 |
+
worker_args.append((i, worker_specs, samples_per_base, perturbation_pct))
|
| 153 |
+
|
| 154 |
+
# Run in parallel
|
| 155 |
+
LOGGER.info("Starting %d parallel workers...", n_workers)
|
| 156 |
+
with mp.Pool(n_workers) as pool:
|
| 157 |
+
results = pool.map(_worker_generate_samples, worker_args)
|
| 158 |
+
|
| 159 |
+
# Combine results
|
| 160 |
+
all_xyY = []
|
| 161 |
+
all_munsell = []
|
| 162 |
+
for xyY_samples, munsell_samples in results:
|
| 163 |
+
all_xyY.extend(xyY_samples)
|
| 164 |
+
all_munsell.extend(munsell_samples)
|
| 165 |
+
|
| 166 |
+
# Trim to exact sample count
|
| 167 |
+
all_xyY = all_xyY[:n_samples]
|
| 168 |
+
all_munsell = all_munsell[:n_samples]
|
| 169 |
+
|
| 170 |
+
LOGGER.info("Generated %d valid samples", len(all_xyY))
|
| 171 |
+
return np.array(all_xyY), np.array(all_munsell)
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def main(
|
| 175 |
+
n_samples: int = 500000,
|
| 176 |
+
perturbation_pct: float = 0.05,
|
| 177 |
+
output: str = "training_data",
|
| 178 |
+
) -> None:
|
| 179 |
+
"""Generate and save training data."""
|
| 180 |
+
LOGGER.info("=" * 80)
|
| 181 |
+
LOGGER.info("Training Data Generation")
|
| 182 |
+
LOGGER.info("=" * 80)
|
| 183 |
+
|
| 184 |
+
output_dir = PROJECT_ROOT / "data"
|
| 185 |
+
output_dir.mkdir(exist_ok=True)
|
| 186 |
+
|
| 187 |
+
LOGGER.info("")
|
| 188 |
+
LOGGER.info("SAMPLING STRATEGY")
|
| 189 |
+
LOGGER.info("=" * 80)
|
| 190 |
+
LOGGER.info("Forward Munsell->xyY sampling:")
|
| 191 |
+
LOGGER.info(
|
| 192 |
+
" - Base: %d colors from MUNSELL_COLOURS_ALL", len(MUNSELL_COLOURS_ALL)
|
| 193 |
+
)
|
| 194 |
+
LOGGER.info(
|
| 195 |
+
" - Perturbations: +/-%.0f%% of valid range per component",
|
| 196 |
+
perturbation_pct * 100,
|
| 197 |
+
)
|
| 198 |
+
LOGGER.info(
|
| 199 |
+
" - Hue: +/-%.2f (+/-%.0f%% of 9.5 range)",
|
| 200 |
+
perturbation_pct * 9.5,
|
| 201 |
+
perturbation_pct * 100,
|
| 202 |
+
)
|
| 203 |
+
LOGGER.info(
|
| 204 |
+
" - Value: +/-%.2f (+/-%.0f%% of 9.0 range)",
|
| 205 |
+
perturbation_pct * 9.0,
|
| 206 |
+
perturbation_pct * 100,
|
| 207 |
+
)
|
| 208 |
+
LOGGER.info(
|
| 209 |
+
" - Chroma: +/-%.1f (+/-%.0f%% of 50.0 range)",
|
| 210 |
+
perturbation_pct * 50.0,
|
| 211 |
+
perturbation_pct * 100,
|
| 212 |
+
)
|
| 213 |
+
LOGGER.info(" - Target samples: %d", n_samples)
|
| 214 |
+
LOGGER.info("=" * 80)
|
| 215 |
+
LOGGER.info("")
|
| 216 |
+
|
| 217 |
+
# Generate samples
|
| 218 |
+
xyY_all, munsell_all = generate_forward_munsell_samples(
|
| 219 |
+
n_samples=n_samples,
|
| 220 |
+
perturbation_pct=perturbation_pct,
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
valid_mask = np.ones(len(xyY_all), dtype=bool)
|
| 224 |
+
|
| 225 |
+
LOGGER.info("")
|
| 226 |
+
LOGGER.info("Sample statistics:")
|
| 227 |
+
LOGGER.info(" Total samples generated: %d", len(xyY_all))
|
| 228 |
+
LOGGER.info(" All samples are valid (100%% by forward conversion)")
|
| 229 |
+
|
| 230 |
+
LOGGER.info("")
|
| 231 |
+
LOGGER.info("Using %d valid samples for training", len(xyY_all))
|
| 232 |
+
|
| 233 |
+
# Split into train/validation/test (70/15/15)
|
| 234 |
+
X_temp, X_test, y_temp, y_test = train_test_split(
|
| 235 |
+
xyY_all, munsell_all, test_size=0.15, random_state=42
|
| 236 |
+
)
|
| 237 |
+
X_train, X_val, y_train, y_val = train_test_split(
|
| 238 |
+
X_temp, y_temp, test_size=0.15 / 0.85, random_state=42
|
| 239 |
+
)
|
| 240 |
+
|
| 241 |
+
LOGGER.info("")
|
| 242 |
+
LOGGER.info("Data split:")
|
| 243 |
+
LOGGER.info(" Train: %d samples", len(X_train))
|
| 244 |
+
LOGGER.info(" Validation: %d samples", len(X_val))
|
| 245 |
+
LOGGER.info(" Test: %d samples", len(X_test))
|
| 246 |
+
|
| 247 |
+
# Save training data
|
| 248 |
+
cache_file = output_dir / f"{output}.npz"
|
| 249 |
+
np.savez_compressed(
|
| 250 |
+
cache_file,
|
| 251 |
+
X_train=X_train,
|
| 252 |
+
y_train=y_train,
|
| 253 |
+
X_val=X_val,
|
| 254 |
+
y_val=y_val,
|
| 255 |
+
X_test=X_test,
|
| 256 |
+
y_test=y_test,
|
| 257 |
+
xyY_all=xyY_all,
|
| 258 |
+
munsell_all=munsell_all,
|
| 259 |
+
valid_mask=valid_mask,
|
| 260 |
+
)
|
| 261 |
+
|
| 262 |
+
# Save parameters to sidecar file
|
| 263 |
+
params_file = output_dir / f"{output}_params.json"
|
| 264 |
+
params = {
|
| 265 |
+
"n_samples": n_samples,
|
| 266 |
+
"perturbation_pct": perturbation_pct,
|
| 267 |
+
"n_base_colors": len(MUNSELL_COLOURS_ALL),
|
| 268 |
+
"train_samples": len(X_train),
|
| 269 |
+
"val_samples": len(X_val),
|
| 270 |
+
"test_samples": len(X_test),
|
| 271 |
+
"generated_at": datetime.now(timezone.utc).isoformat(),
|
| 272 |
+
}
|
| 273 |
+
with open(params_file, "w") as f:
|
| 274 |
+
json.dump(params, f, indent=2)
|
| 275 |
+
|
| 276 |
+
LOGGER.info("")
|
| 277 |
+
LOGGER.info("Training data saved to: %s", cache_file)
|
| 278 |
+
LOGGER.info("Parameters saved to: %s", params_file)
|
| 279 |
+
LOGGER.info("=" * 80)
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
if __name__ == "__main__":
|
| 283 |
+
parser = argparse.ArgumentParser(
|
| 284 |
+
description="Generate training data for xyY to Munsell conversion"
|
| 285 |
+
)
|
| 286 |
+
parser.add_argument(
|
| 287 |
+
"--n-samples",
|
| 288 |
+
type=int,
|
| 289 |
+
default=500000,
|
| 290 |
+
help="Number of samples to generate (default: 500000)",
|
| 291 |
+
)
|
| 292 |
+
parser.add_argument(
|
| 293 |
+
"--perturbation",
|
| 294 |
+
type=float,
|
| 295 |
+
default=0.05,
|
| 296 |
+
help="Perturbation as fraction of valid range (default: 0.05)",
|
| 297 |
+
)
|
| 298 |
+
parser.add_argument(
|
| 299 |
+
"--output",
|
| 300 |
+
type=str,
|
| 301 |
+
default="training_data",
|
| 302 |
+
help="Output filename without extension (default: training_data)",
|
| 303 |
+
)
|
| 304 |
+
args = parser.parse_args()
|
| 305 |
+
|
| 306 |
+
main(
|
| 307 |
+
n_samples=args.n_samples,
|
| 308 |
+
perturbation_pct=args.perturbation,
|
| 309 |
+
output=args.output,
|
| 310 |
+
)
|
learning_munsell/interpolation/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Interpolation-based methods for Munsell conversions."""
|
learning_munsell/interpolation/from_xyY/__init__.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Interpolation-based methods for xyY to Munsell conversions."""
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
from colour.notation.datasets.munsell import MUNSELL_COLOURS_ALL
|
| 5 |
+
from colour.notation.munsell import munsell_colour_to_munsell_specification
|
| 6 |
+
from numpy.typing import NDArray
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def load_munsell_reference_data() -> tuple[NDArray, NDArray]:
|
| 10 |
+
"""
|
| 11 |
+
Load reference Munsell data from colour library.
|
| 12 |
+
|
| 13 |
+
Returns xyY coordinates and corresponding Munsell specifications
|
| 14 |
+
[hue, value, chroma, code] for all 4,995 reference colors.
|
| 15 |
+
|
| 16 |
+
The Y values are normalized to [0, 1] range (originally 0-102.57).
|
| 17 |
+
|
| 18 |
+
Returns
|
| 19 |
+
-------
|
| 20 |
+
Tuple[NDArray, NDArray]
|
| 21 |
+
X : xyY values of shape (4995, 3) with Y normalized to [0, 1]
|
| 22 |
+
y : Munsell specifications of shape (4995, 4)
|
| 23 |
+
"""
|
| 24 |
+
xyY_list = []
|
| 25 |
+
munsell_list = []
|
| 26 |
+
|
| 27 |
+
for munsell_tuple, xyY in MUNSELL_COLOURS_ALL:
|
| 28 |
+
hue_name, value, chroma = munsell_tuple
|
| 29 |
+
munsell_string = f"{hue_name} {value}/{chroma}"
|
| 30 |
+
|
| 31 |
+
# Convert to numeric specification [hue, value, chroma, code]
|
| 32 |
+
spec = munsell_colour_to_munsell_specification(munsell_string)
|
| 33 |
+
|
| 34 |
+
# Normalize Y to [0, 1] range (max ~102.57)
|
| 35 |
+
xyY_normalized = np.array([xyY[0], xyY[1], xyY[2] / 100.0])
|
| 36 |
+
|
| 37 |
+
xyY_list.append(xyY_normalized)
|
| 38 |
+
munsell_list.append(spec)
|
| 39 |
+
|
| 40 |
+
return np.array(xyY_list), np.array(munsell_list)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
__all__ = ["load_munsell_reference_data"]
|
learning_munsell/interpolation/from_xyY/compare_methods.py
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Compare classical interpolation methods against the best ML model.
|
| 3 |
+
|
| 4 |
+
Evaluates RBF, KD-Tree, and Delaunay interpolation on REAL Munsell colors
|
| 5 |
+
and compares with the Multi-Head (W+B) + Multi-Error Predictor (W+B) model.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import logging
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import onnxruntime as ort
|
| 12 |
+
from colour.notation.datasets.munsell import MUNSELL_COLOURS_ALL
|
| 13 |
+
from colour.notation.munsell import munsell_colour_to_munsell_specification
|
| 14 |
+
from scipy.interpolate import LinearNDInterpolator, RBFInterpolator
|
| 15 |
+
from scipy.spatial import KDTree
|
| 16 |
+
from sklearn.model_selection import train_test_split
|
| 17 |
+
|
| 18 |
+
from learning_munsell import PROJECT_ROOT
|
| 19 |
+
|
| 20 |
+
logging.basicConfig(level=logging.INFO, format="%(message)s")
|
| 21 |
+
LOGGER = logging.getLogger(__name__)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def load_reference_data():
|
| 25 |
+
"""Load ALL Munsell colors as training data for interpolators."""
|
| 26 |
+
X, y = [], []
|
| 27 |
+
for munsell_tuple, xyY in MUNSELL_COLOURS_ALL:
|
| 28 |
+
hue_name, value, chroma = munsell_tuple
|
| 29 |
+
munsell_str = f"{hue_name} {value}/{chroma}"
|
| 30 |
+
spec = munsell_colour_to_munsell_specification(munsell_str)
|
| 31 |
+
# Normalize Y to [0, 1]
|
| 32 |
+
X.append([xyY[0], xyY[1], xyY[2] / 100.0])
|
| 33 |
+
y.append(spec)
|
| 34 |
+
return np.array(X), np.array(y)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def evaluate(predictions, y_true, method_name):
|
| 40 |
+
"""Calculate MAE for each component."""
|
| 41 |
+
errors = np.abs(predictions - y_true)
|
| 42 |
+
results = {
|
| 43 |
+
"hue": errors[:, 0].mean(),
|
| 44 |
+
"value": errors[:, 1].mean(),
|
| 45 |
+
"chroma": errors[:, 2].mean(),
|
| 46 |
+
"code": errors[:, 3].mean(),
|
| 47 |
+
}
|
| 48 |
+
LOGGER.info(" %s:", method_name)
|
| 49 |
+
for comp in ["hue", "value", "chroma", "code"]:
|
| 50 |
+
LOGGER.info(" %s MAE: %.4f", comp.capitalize(), results[comp])
|
| 51 |
+
return results
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def rbf_predict(X_train, y_train, X_test):
|
| 55 |
+
"""RBF interpolation prediction."""
|
| 56 |
+
predictions = np.zeros((len(X_test), 4))
|
| 57 |
+
for i in range(4):
|
| 58 |
+
rbf = RBFInterpolator(X_train, y_train[:, i], kernel="thin_plate_spline")
|
| 59 |
+
predictions[:, i] = rbf(X_test)
|
| 60 |
+
return predictions
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def kdtree_predict(X_train, y_train, X_test, k=5):
|
| 64 |
+
"""KD-Tree with inverse distance weighting prediction."""
|
| 65 |
+
tree = KDTree(X_train)
|
| 66 |
+
distances, indices = tree.query(X_test, k=k)
|
| 67 |
+
distances = np.maximum(distances, 1e-10)
|
| 68 |
+
weights = 1.0 / (distances**2)
|
| 69 |
+
weights /= weights.sum(axis=1, keepdims=True)
|
| 70 |
+
|
| 71 |
+
predictions = np.zeros((len(X_test), 4))
|
| 72 |
+
for i in range(len(X_test)):
|
| 73 |
+
predictions[i] = np.sum(weights[i, :, np.newaxis] * y_train[indices[i]], axis=0)
|
| 74 |
+
return predictions
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def delaunay_predict(X_train, y_train, X_test):
|
| 78 |
+
"""Delaunay interpolation with NN fallback."""
|
| 79 |
+
predictions = np.zeros((len(X_test), 4))
|
| 80 |
+
tree = KDTree(X_train)
|
| 81 |
+
|
| 82 |
+
for i in range(4):
|
| 83 |
+
interp = LinearNDInterpolator(X_train, y_train[:, i])
|
| 84 |
+
predictions[:, i] = interp(X_test)
|
| 85 |
+
|
| 86 |
+
# Fallback to nearest neighbor for NaN
|
| 87 |
+
nan_mask = np.any(np.isnan(predictions), axis=1)
|
| 88 |
+
if nan_mask.sum() > 0:
|
| 89 |
+
_, indices = tree.query(X_test[nan_mask])
|
| 90 |
+
predictions[nan_mask] = y_train[indices]
|
| 91 |
+
|
| 92 |
+
return predictions
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def ml_predict(X_test):
|
| 96 |
+
"""ML model prediction using base + error predictor."""
|
| 97 |
+
base_path = PROJECT_ROOT / "models" / "from_xyY" / "multi_head_weighted_boundary.onnx"
|
| 98 |
+
error_path = (
|
| 99 |
+
PROJECT_ROOT
|
| 100 |
+
/ "models"
|
| 101 |
+
/ "from_xyY"
|
| 102 |
+
/ "multi_head_weighted_boundary_multi_error_predictor_weighted_boundary.onnx"
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
if not base_path.exists() or not error_path.exists():
|
| 106 |
+
return None
|
| 107 |
+
|
| 108 |
+
# Input is already normalized to [0, 1] for x, y, Y
|
| 109 |
+
X_norm = X_test.astype(np.float32)
|
| 110 |
+
|
| 111 |
+
# Base model prediction
|
| 112 |
+
base_session = ort.InferenceSession(str(base_path))
|
| 113 |
+
base_out = base_session.run(None, {"xyY": X_norm})[0]
|
| 114 |
+
|
| 115 |
+
# Error predictor (takes xyY + base predictions)
|
| 116 |
+
error_session = ort.InferenceSession(str(error_path))
|
| 117 |
+
combined_input = np.concatenate([X_norm, base_out], axis=1).astype(np.float32)
|
| 118 |
+
error_out = error_session.run(None, {"combined_input": combined_input})[0]
|
| 119 |
+
|
| 120 |
+
# Combined prediction (normalized)
|
| 121 |
+
pred_norm = base_out + error_out
|
| 122 |
+
|
| 123 |
+
# Denormalize using actual ranges from params file
|
| 124 |
+
predictions = np.zeros_like(pred_norm)
|
| 125 |
+
predictions[:, 0] = pred_norm[:, 0] * (10.0 - 0.5) + 0.5 # Hue: [0.5, 10]
|
| 126 |
+
predictions[:, 1] = pred_norm[:, 1] * (10.0 - 0.0) + 0.0 # Value: [0, 10]
|
| 127 |
+
predictions[:, 2] = pred_norm[:, 2] * (50.0 - 0.0) + 0.0 # Chroma: [0, 50]
|
| 128 |
+
predictions[:, 3] = pred_norm[:, 3] * (10.0 - 1.0) + 1.0 # Code: [1, 10]
|
| 129 |
+
|
| 130 |
+
return predictions
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def main():
|
| 134 |
+
"""Compare all methods using held-out test set."""
|
| 135 |
+
LOGGER.info("=" * 80)
|
| 136 |
+
LOGGER.info("Classical Interpolation vs ML Model Comparison")
|
| 137 |
+
LOGGER.info("=" * 80)
|
| 138 |
+
|
| 139 |
+
LOGGER.info("")
|
| 140 |
+
LOGGER.info("Loading data...")
|
| 141 |
+
X_all, y_all = load_reference_data()
|
| 142 |
+
|
| 143 |
+
# 80/20 train/test split for fair comparison
|
| 144 |
+
X_train, X_test, y_train, y_test = train_test_split(
|
| 145 |
+
X_all, y_all, test_size=0.2, random_state=42
|
| 146 |
+
)
|
| 147 |
+
LOGGER.info(" Total: %d colors", len(X_all))
|
| 148 |
+
LOGGER.info(" Training: %d colors (80%%)", len(X_train))
|
| 149 |
+
LOGGER.info(" Test: %d colors (20%%)", len(X_test))
|
| 150 |
+
|
| 151 |
+
results = {}
|
| 152 |
+
|
| 153 |
+
# RBF
|
| 154 |
+
LOGGER.info("")
|
| 155 |
+
LOGGER.info("-" * 60)
|
| 156 |
+
LOGGER.info("RBF Interpolation (thin_plate_spline)")
|
| 157 |
+
rbf_pred = rbf_predict(X_train, y_train, X_test)
|
| 158 |
+
results["RBF"] = evaluate(rbf_pred, y_test, "RBF")
|
| 159 |
+
|
| 160 |
+
# KD-Tree
|
| 161 |
+
LOGGER.info("")
|
| 162 |
+
LOGGER.info("-" * 60)
|
| 163 |
+
LOGGER.info("KD-Tree Interpolation (k=5, IDW)")
|
| 164 |
+
kdt_pred = kdtree_predict(X_train, y_train, X_test, k=5)
|
| 165 |
+
results["KD-Tree"] = evaluate(kdt_pred, y_test, "KD-Tree")
|
| 166 |
+
|
| 167 |
+
# Delaunay
|
| 168 |
+
LOGGER.info("")
|
| 169 |
+
LOGGER.info("-" * 60)
|
| 170 |
+
LOGGER.info("Delaunay Interpolation (with NN fallback)")
|
| 171 |
+
del_pred = delaunay_predict(X_train, y_train, X_test)
|
| 172 |
+
results["Delaunay"] = evaluate(del_pred, y_test, "Delaunay")
|
| 173 |
+
|
| 174 |
+
# ML
|
| 175 |
+
LOGGER.info("")
|
| 176 |
+
LOGGER.info("-" * 60)
|
| 177 |
+
LOGGER.info("ML Model (Multi-Head W+B + Multi-Error Predictor W+B)")
|
| 178 |
+
ml_pred = ml_predict(X_test)
|
| 179 |
+
if ml_pred is not None:
|
| 180 |
+
results["ML"] = evaluate(ml_pred, y_test, "ML")
|
| 181 |
+
else:
|
| 182 |
+
LOGGER.info(" Skipped (model not found)")
|
| 183 |
+
|
| 184 |
+
# Summary
|
| 185 |
+
LOGGER.info("")
|
| 186 |
+
LOGGER.info("=" * 80)
|
| 187 |
+
LOGGER.info("SUMMARY (MAE on %d held-out test colors)", len(X_test))
|
| 188 |
+
LOGGER.info("=" * 80)
|
| 189 |
+
LOGGER.info("")
|
| 190 |
+
LOGGER.info("%-12s %8s %8s %8s %8s", "Method", "Hue", "Value", "Chroma", "Code")
|
| 191 |
+
LOGGER.info("-" * 52)
|
| 192 |
+
|
| 193 |
+
for method, mae in results.items():
|
| 194 |
+
LOGGER.info(
|
| 195 |
+
"%-12s %8.4f %8.4f %8.4f %8.4f",
|
| 196 |
+
method,
|
| 197 |
+
mae["hue"],
|
| 198 |
+
mae["value"],
|
| 199 |
+
mae["chroma"],
|
| 200 |
+
mae["code"],
|
| 201 |
+
)
|
| 202 |
+
|
| 203 |
+
LOGGER.info("")
|
| 204 |
+
LOGGER.info("=" * 80)
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
if __name__ == "__main__":
|
| 208 |
+
main()
|
learning_munsell/interpolation/from_xyY/delaunay_interpolator.py
ADDED
|
@@ -0,0 +1,283 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Delaunay triangulation based interpolation for xyY to Munsell conversion.
|
| 3 |
+
|
| 4 |
+
This approach uses scipy's LinearNDInterpolator which performs piecewise
|
| 5 |
+
linear interpolation based on Delaunay triangulation.
|
| 6 |
+
|
| 7 |
+
Uses the 4,995 reference colors from MUNSELL_COLOURS_ALL directly.
|
| 8 |
+
|
| 9 |
+
Advantages:
|
| 10 |
+
- Piecewise linear: exact at data points, linear between
|
| 11 |
+
- Handles irregular point distributions
|
| 12 |
+
- No hyperparameters to tune
|
| 13 |
+
|
| 14 |
+
Disadvantages:
|
| 15 |
+
- Returns NaN outside convex hull of data points
|
| 16 |
+
- Non-convex Munsell boundary may cause issues
|
| 17 |
+
- C0 continuous only (discontinuous gradients at cell boundaries)
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
import logging
|
| 21 |
+
import pickle
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
|
| 24 |
+
import numpy as np
|
| 25 |
+
from numpy.typing import NDArray
|
| 26 |
+
from scipy.interpolate import LinearNDInterpolator
|
| 27 |
+
from scipy.spatial import KDTree
|
| 28 |
+
from sklearn.model_selection import train_test_split
|
| 29 |
+
|
| 30 |
+
from learning_munsell import PROJECT_ROOT, setup_logging
|
| 31 |
+
from learning_munsell.interpolation.from_xyY import load_munsell_reference_data
|
| 32 |
+
|
| 33 |
+
logging.basicConfig(level=logging.INFO, format="%(message)s")
|
| 34 |
+
LOGGER = logging.getLogger(__name__)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class MunsellDelaunayInterpolator:
|
| 38 |
+
"""
|
| 39 |
+
Delaunay triangulation based interpolator for xyY to Munsell conversion.
|
| 40 |
+
|
| 41 |
+
Uses LinearNDInterpolator for piecewise linear interpolation within
|
| 42 |
+
the Delaunay triangulation. Falls back to nearest neighbor for points
|
| 43 |
+
outside the convex hull.
|
| 44 |
+
"""
|
| 45 |
+
|
| 46 |
+
def __init__(self, fallback_to_nearest: bool = True) -> None:
|
| 47 |
+
"""
|
| 48 |
+
Initialize the Delaunay interpolator.
|
| 49 |
+
|
| 50 |
+
Parameters
|
| 51 |
+
----------
|
| 52 |
+
fallback_to_nearest
|
| 53 |
+
If True, use nearest neighbor for points outside convex hull.
|
| 54 |
+
If False, return NaN for such points.
|
| 55 |
+
"""
|
| 56 |
+
self.fallback_to_nearest = fallback_to_nearest
|
| 57 |
+
self.interpolators: dict = {}
|
| 58 |
+
self.kdtree: KDTree | None = None
|
| 59 |
+
self.y_data: NDArray | None = None
|
| 60 |
+
self.fitted = False
|
| 61 |
+
|
| 62 |
+
def fit(self, X: NDArray, y: NDArray) -> "MunsellDelaunayInterpolator":
|
| 63 |
+
"""
|
| 64 |
+
Build the Delaunay interpolator from training data.
|
| 65 |
+
|
| 66 |
+
Parameters
|
| 67 |
+
----------
|
| 68 |
+
X
|
| 69 |
+
xyY input values of shape (n, 3)
|
| 70 |
+
y
|
| 71 |
+
Munsell output values [hue, value, chroma, code] of shape (n, 4)
|
| 72 |
+
|
| 73 |
+
Returns
|
| 74 |
+
-------
|
| 75 |
+
self
|
| 76 |
+
"""
|
| 77 |
+
LOGGER.info("Building Delaunay interpolator...")
|
| 78 |
+
LOGGER.info(" Fallback to nearest: %s", self.fallback_to_nearest)
|
| 79 |
+
LOGGER.info(" Data points: %d", len(X))
|
| 80 |
+
|
| 81 |
+
component_names = ["hue", "value", "chroma", "code"]
|
| 82 |
+
|
| 83 |
+
for i, name in enumerate(component_names):
|
| 84 |
+
LOGGER.info(" Building %s interpolator...", name)
|
| 85 |
+
self.interpolators[name] = LinearNDInterpolator(X, y[:, i])
|
| 86 |
+
|
| 87 |
+
# Build KDTree for nearest neighbor fallback
|
| 88 |
+
if self.fallback_to_nearest:
|
| 89 |
+
LOGGER.info(" Building KD-Tree for fallback...")
|
| 90 |
+
self.kdtree = KDTree(X)
|
| 91 |
+
self.y_data = y.copy()
|
| 92 |
+
|
| 93 |
+
self.fitted = True
|
| 94 |
+
LOGGER.info("Delaunay interpolator built successfully")
|
| 95 |
+
return self
|
| 96 |
+
|
| 97 |
+
def predict(self, X: NDArray) -> NDArray:
|
| 98 |
+
"""
|
| 99 |
+
Predict Munsell values using Delaunay interpolation.
|
| 100 |
+
|
| 101 |
+
Parameters
|
| 102 |
+
----------
|
| 103 |
+
X
|
| 104 |
+
xyY input values of shape (n, 3)
|
| 105 |
+
|
| 106 |
+
Returns
|
| 107 |
+
-------
|
| 108 |
+
NDArray
|
| 109 |
+
Predicted Munsell values [hue, value, chroma, code] of shape (n, 4)
|
| 110 |
+
"""
|
| 111 |
+
if not self.fitted:
|
| 112 |
+
msg = "Interpolator not fitted. Call fit() first."
|
| 113 |
+
raise RuntimeError(msg)
|
| 114 |
+
|
| 115 |
+
results = np.zeros((len(X), 4))
|
| 116 |
+
|
| 117 |
+
for i, name in enumerate(["hue", "value", "chroma", "code"]):
|
| 118 |
+
results[:, i] = self.interpolators[name](X)
|
| 119 |
+
|
| 120 |
+
# Handle NaN values (points outside convex hull)
|
| 121 |
+
if self.fallback_to_nearest:
|
| 122 |
+
nan_mask = np.any(np.isnan(results), axis=1)
|
| 123 |
+
n_nan = nan_mask.sum()
|
| 124 |
+
|
| 125 |
+
if n_nan > 0:
|
| 126 |
+
LOGGER.debug(" %d points outside hull, using nearest neighbor", n_nan)
|
| 127 |
+
# Find nearest neighbors for NaN points
|
| 128 |
+
_, indices = self.kdtree.query(X[nan_mask])
|
| 129 |
+
results[nan_mask] = self.y_data[indices]
|
| 130 |
+
|
| 131 |
+
return results
|
| 132 |
+
|
| 133 |
+
def save(self, path: Path) -> None:
|
| 134 |
+
"""Save the interpolator to disk."""
|
| 135 |
+
with open(path, "wb") as f:
|
| 136 |
+
pickle.dump(
|
| 137 |
+
{
|
| 138 |
+
"fallback_to_nearest": self.fallback_to_nearest,
|
| 139 |
+
"interpolators": self.interpolators,
|
| 140 |
+
"kdtree": self.kdtree,
|
| 141 |
+
"y_data": self.y_data,
|
| 142 |
+
},
|
| 143 |
+
f,
|
| 144 |
+
)
|
| 145 |
+
LOGGER.info("Saved Delaunay interpolator to %s", path)
|
| 146 |
+
|
| 147 |
+
@classmethod
|
| 148 |
+
def load(cls, path: Path) -> "MunsellDelaunayInterpolator":
|
| 149 |
+
"""Load the interpolator from disk."""
|
| 150 |
+
with open(path, "rb") as f:
|
| 151 |
+
data = pickle.load(f) # noqa: S301
|
| 152 |
+
|
| 153 |
+
instance = cls(fallback_to_nearest=data["fallback_to_nearest"])
|
| 154 |
+
instance.interpolators = data["interpolators"]
|
| 155 |
+
instance.kdtree = data["kdtree"]
|
| 156 |
+
instance.y_data = data["y_data"]
|
| 157 |
+
instance.fitted = True
|
| 158 |
+
|
| 159 |
+
LOGGER.info("Loaded Delaunay interpolator from %s", path)
|
| 160 |
+
return instance
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def evaluate_delaunay(
|
| 164 |
+
interpolator: MunsellDelaunayInterpolator,
|
| 165 |
+
X: NDArray,
|
| 166 |
+
y: NDArray,
|
| 167 |
+
name: str = "Test",
|
| 168 |
+
) -> dict:
|
| 169 |
+
"""Evaluate Delaunay interpolator performance."""
|
| 170 |
+
predictions = interpolator.predict(X)
|
| 171 |
+
|
| 172 |
+
# Check for NaN values
|
| 173 |
+
nan_count = np.isnan(predictions).any(axis=1).sum()
|
| 174 |
+
if nan_count > 0:
|
| 175 |
+
LOGGER.warning(" %d/%d predictions contain NaN", nan_count, len(X))
|
| 176 |
+
|
| 177 |
+
# Filter out NaN for error calculation
|
| 178 |
+
valid_mask = ~np.isnan(predictions).any(axis=1)
|
| 179 |
+
if valid_mask.sum() == 0:
|
| 180 |
+
LOGGER.error(" All predictions are NaN!")
|
| 181 |
+
return {
|
| 182 |
+
"hue": float("nan"),
|
| 183 |
+
"value": float("nan"),
|
| 184 |
+
"chroma": float("nan"),
|
| 185 |
+
"code": float("nan"),
|
| 186 |
+
}
|
| 187 |
+
|
| 188 |
+
errors = np.abs(predictions[valid_mask] - y[valid_mask])
|
| 189 |
+
|
| 190 |
+
component_names = ["Hue", "Value", "Chroma", "Code"]
|
| 191 |
+
results = {}
|
| 192 |
+
|
| 193 |
+
LOGGER.info("%s set MAE (%d/%d valid):", name, valid_mask.sum(), len(X))
|
| 194 |
+
for i, comp_name in enumerate(component_names):
|
| 195 |
+
mae = errors[:, i].mean()
|
| 196 |
+
results[comp_name.lower()] = mae
|
| 197 |
+
LOGGER.info(" %s: %.4f", comp_name, mae)
|
| 198 |
+
|
| 199 |
+
return results
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def main() -> None:
|
| 203 |
+
"""Build and evaluate Delaunay interpolator using reference Munsell data."""
|
| 204 |
+
|
| 205 |
+
log_file = setup_logging("delaunay_interpolator", "from_xyY")
|
| 206 |
+
|
| 207 |
+
LOGGER.info("=" * 80)
|
| 208 |
+
LOGGER.info("Delaunay Interpolation for xyY to Munsell Conversion")
|
| 209 |
+
LOGGER.info("Using MUNSELL_COLOURS_ALL reference data (4,995 colors)")
|
| 210 |
+
LOGGER.info("=" * 80)
|
| 211 |
+
|
| 212 |
+
# Load reference data from colour library
|
| 213 |
+
LOGGER.info("")
|
| 214 |
+
LOGGER.info("Loading reference Munsell data...")
|
| 215 |
+
X_all, y_all = load_munsell_reference_data()
|
| 216 |
+
LOGGER.info("Total reference colors: %d", len(X_all))
|
| 217 |
+
|
| 218 |
+
# Split into train/validation (80/20)
|
| 219 |
+
X_train, X_val, y_train, y_val = train_test_split(
|
| 220 |
+
X_all, y_all, test_size=0.2, random_state=42
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
LOGGER.info("Train samples: %d", len(X_train))
|
| 224 |
+
LOGGER.info("Validation samples: %d", len(X_val))
|
| 225 |
+
|
| 226 |
+
# Test with and without fallback
|
| 227 |
+
LOGGER.info("")
|
| 228 |
+
LOGGER.info("Testing Delaunay interpolation...")
|
| 229 |
+
LOGGER.info("-" * 60)
|
| 230 |
+
|
| 231 |
+
best_config = None
|
| 232 |
+
best_mae = float("inf")
|
| 233 |
+
|
| 234 |
+
for fallback in [True, False]:
|
| 235 |
+
LOGGER.info("")
|
| 236 |
+
LOGGER.info("Fallback to nearest: %s", fallback)
|
| 237 |
+
|
| 238 |
+
interpolator = MunsellDelaunayInterpolator(fallback_to_nearest=fallback)
|
| 239 |
+
interpolator.fit(X_train, y_train)
|
| 240 |
+
|
| 241 |
+
results = evaluate_delaunay(interpolator, X_val, y_val, "Validation")
|
| 242 |
+
|
| 243 |
+
# Skip if results contain NaN
|
| 244 |
+
if any(np.isnan(v) for v in results.values()):
|
| 245 |
+
LOGGER.info(" Skipping due to NaN results")
|
| 246 |
+
continue
|
| 247 |
+
|
| 248 |
+
total_mae = sum(results.values())
|
| 249 |
+
|
| 250 |
+
if total_mae < best_mae:
|
| 251 |
+
best_mae = total_mae
|
| 252 |
+
best_config = fallback
|
| 253 |
+
|
| 254 |
+
LOGGER.info("")
|
| 255 |
+
LOGGER.info("=" * 60)
|
| 256 |
+
LOGGER.info("Best configuration: fallback_to_nearest=%s", best_config)
|
| 257 |
+
LOGGER.info("=" * 60)
|
| 258 |
+
|
| 259 |
+
# Train final model on ALL data
|
| 260 |
+
LOGGER.info("")
|
| 261 |
+
LOGGER.info("Training final model on all %d reference colors...", len(X_all))
|
| 262 |
+
|
| 263 |
+
final_interpolator = MunsellDelaunayInterpolator(fallback_to_nearest=best_config)
|
| 264 |
+
final_interpolator.fit(X_all, y_all)
|
| 265 |
+
|
| 266 |
+
LOGGER.info("")
|
| 267 |
+
LOGGER.info("Final evaluation (training set = all data):")
|
| 268 |
+
evaluate_delaunay(final_interpolator, X_all, y_all, "All data")
|
| 269 |
+
|
| 270 |
+
# Save the model
|
| 271 |
+
model_dir = PROJECT_ROOT / "models" / "from_xyY"
|
| 272 |
+
model_dir.mkdir(parents=True, exist_ok=True)
|
| 273 |
+
model_path = model_dir / "delaunay_interpolator.pkl"
|
| 274 |
+
final_interpolator.save(model_path)
|
| 275 |
+
|
| 276 |
+
LOGGER.info("")
|
| 277 |
+
LOGGER.info("=" * 80)
|
| 278 |
+
|
| 279 |
+
log_file.close()
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
if __name__ == "__main__":
|
| 283 |
+
main()
|
learning_munsell/interpolation/from_xyY/kdtree_interpolator.py
ADDED
|
@@ -0,0 +1,263 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
KD-Tree based interpolation for xyY to Munsell conversion.
|
| 3 |
+
|
| 4 |
+
This approach uses scipy's KDTree for fast nearest neighbor lookups,
|
| 5 |
+
with optional weighted interpolation using k nearest neighbors.
|
| 6 |
+
|
| 7 |
+
Uses the 4,995 reference colors from MUNSELL_COLOURS_ALL directly.
|
| 8 |
+
|
| 9 |
+
Advantages over RBF:
|
| 10 |
+
- O(n) memory, O(log n) query time
|
| 11 |
+
- Scales to millions of data points
|
| 12 |
+
- No matrix inversion required
|
| 13 |
+
|
| 14 |
+
Advantages over ML:
|
| 15 |
+
- Deterministic
|
| 16 |
+
- No training required
|
| 17 |
+
- Easy to understand
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
import logging
|
| 21 |
+
import pickle
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
|
| 24 |
+
import numpy as np
|
| 25 |
+
from numpy.typing import NDArray
|
| 26 |
+
from scipy.spatial import KDTree
|
| 27 |
+
from sklearn.model_selection import train_test_split
|
| 28 |
+
|
| 29 |
+
from learning_munsell import PROJECT_ROOT, setup_logging
|
| 30 |
+
from learning_munsell.interpolation.from_xyY import load_munsell_reference_data
|
| 31 |
+
|
| 32 |
+
logging.basicConfig(level=logging.INFO, format="%(message)s")
|
| 33 |
+
LOGGER = logging.getLogger(__name__)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class MunsellKDTreeInterpolator:
|
| 37 |
+
"""
|
| 38 |
+
KD-Tree based interpolator for xyY to Munsell conversion.
|
| 39 |
+
|
| 40 |
+
Uses k-nearest neighbors with inverse distance weighting
|
| 41 |
+
for smooth interpolation.
|
| 42 |
+
"""
|
| 43 |
+
|
| 44 |
+
def __init__(self, k: int = 5, power: float = 2.0) -> None:
|
| 45 |
+
"""
|
| 46 |
+
Initialize the KD-Tree interpolator.
|
| 47 |
+
|
| 48 |
+
Parameters
|
| 49 |
+
----------
|
| 50 |
+
k
|
| 51 |
+
Number of nearest neighbors to use for interpolation.
|
| 52 |
+
power
|
| 53 |
+
Power for inverse distance weighting. Higher = sharper.
|
| 54 |
+
"""
|
| 55 |
+
self.k = k
|
| 56 |
+
self.power = power
|
| 57 |
+
self.tree: KDTree | None = None
|
| 58 |
+
self.y_data: NDArray | None = None
|
| 59 |
+
self.fitted = False
|
| 60 |
+
|
| 61 |
+
def fit(self, X: NDArray, y: NDArray) -> "MunsellKDTreeInterpolator":
|
| 62 |
+
"""
|
| 63 |
+
Build the KD-Tree from training data.
|
| 64 |
+
|
| 65 |
+
Parameters
|
| 66 |
+
----------
|
| 67 |
+
X
|
| 68 |
+
xyY input values of shape (n, 3)
|
| 69 |
+
y
|
| 70 |
+
Munsell output values [hue, value, chroma, code] of shape (n, 4)
|
| 71 |
+
|
| 72 |
+
Returns
|
| 73 |
+
-------
|
| 74 |
+
self
|
| 75 |
+
"""
|
| 76 |
+
LOGGER.info("Building KD-Tree interpolator...")
|
| 77 |
+
LOGGER.info(" k neighbors: %d", self.k)
|
| 78 |
+
LOGGER.info(" IDW power: %.1f", self.power)
|
| 79 |
+
LOGGER.info(" Data points: %d", len(X))
|
| 80 |
+
|
| 81 |
+
self.tree = KDTree(X)
|
| 82 |
+
self.y_data = y.copy()
|
| 83 |
+
self.fitted = True
|
| 84 |
+
|
| 85 |
+
LOGGER.info("KD-Tree built successfully")
|
| 86 |
+
return self
|
| 87 |
+
|
| 88 |
+
def predict(self, X: NDArray) -> NDArray:
|
| 89 |
+
"""
|
| 90 |
+
Predict Munsell values using k-NN with IDW.
|
| 91 |
+
|
| 92 |
+
Parameters
|
| 93 |
+
----------
|
| 94 |
+
X
|
| 95 |
+
xyY input values of shape (n, 3)
|
| 96 |
+
|
| 97 |
+
Returns
|
| 98 |
+
-------
|
| 99 |
+
NDArray
|
| 100 |
+
Predicted Munsell values [hue, value, chroma, code] of shape (n, 4)
|
| 101 |
+
"""
|
| 102 |
+
if not self.fitted:
|
| 103 |
+
msg = "Interpolator not fitted. Call fit() first."
|
| 104 |
+
raise RuntimeError(msg)
|
| 105 |
+
|
| 106 |
+
# Query k nearest neighbors
|
| 107 |
+
distances, indices = self.tree.query(X, k=self.k)
|
| 108 |
+
|
| 109 |
+
# Ensure 2D arrays for consistent handling
|
| 110 |
+
if self.k == 1:
|
| 111 |
+
distances = distances.reshape(-1, 1)
|
| 112 |
+
indices = indices.reshape(-1, 1)
|
| 113 |
+
|
| 114 |
+
# Inverse distance weighting
|
| 115 |
+
# Avoid division by zero
|
| 116 |
+
distances = np.maximum(distances, 1e-10)
|
| 117 |
+
weights = 1.0 / (distances**self.power)
|
| 118 |
+
weights /= weights.sum(axis=1, keepdims=True)
|
| 119 |
+
|
| 120 |
+
# Weighted average of neighbor values
|
| 121 |
+
results = np.zeros((len(X), 4))
|
| 122 |
+
for i in range(len(X)):
|
| 123 |
+
neighbor_values = self.y_data[indices[i]]
|
| 124 |
+
if self.k == 1:
|
| 125 |
+
results[i] = neighbor_values.flatten()
|
| 126 |
+
else:
|
| 127 |
+
results[i] = np.sum(weights[i, :, np.newaxis] * neighbor_values, axis=0)
|
| 128 |
+
|
| 129 |
+
return results
|
| 130 |
+
|
| 131 |
+
def save(self, path: Path) -> None:
|
| 132 |
+
"""Save the interpolator to disk."""
|
| 133 |
+
with open(path, "wb") as f:
|
| 134 |
+
pickle.dump(
|
| 135 |
+
{
|
| 136 |
+
"k": self.k,
|
| 137 |
+
"power": self.power,
|
| 138 |
+
"tree": self.tree,
|
| 139 |
+
"y_data": self.y_data,
|
| 140 |
+
},
|
| 141 |
+
f,
|
| 142 |
+
)
|
| 143 |
+
LOGGER.info("Saved KD-Tree interpolator to %s", path)
|
| 144 |
+
|
| 145 |
+
@classmethod
|
| 146 |
+
def load(cls, path: Path) -> "MunsellKDTreeInterpolator":
|
| 147 |
+
"""Load the interpolator from disk."""
|
| 148 |
+
with open(path, "rb") as f:
|
| 149 |
+
data = pickle.load(f) # noqa: S301
|
| 150 |
+
|
| 151 |
+
instance = cls(k=data["k"], power=data["power"])
|
| 152 |
+
instance.tree = data["tree"]
|
| 153 |
+
instance.y_data = data["y_data"]
|
| 154 |
+
instance.fitted = True
|
| 155 |
+
|
| 156 |
+
LOGGER.info("Loaded KD-Tree interpolator from %s", path)
|
| 157 |
+
return instance
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def evaluate_kdtree(
|
| 161 |
+
interpolator: MunsellKDTreeInterpolator,
|
| 162 |
+
X: NDArray,
|
| 163 |
+
y: NDArray,
|
| 164 |
+
name: str = "Test",
|
| 165 |
+
) -> dict:
|
| 166 |
+
"""Evaluate KD-Tree interpolator performance."""
|
| 167 |
+
predictions = interpolator.predict(X)
|
| 168 |
+
errors = np.abs(predictions - y)
|
| 169 |
+
|
| 170 |
+
component_names = ["Hue", "Value", "Chroma", "Code"]
|
| 171 |
+
results = {}
|
| 172 |
+
|
| 173 |
+
LOGGER.info("%s set MAE:", name)
|
| 174 |
+
for i, comp_name in enumerate(component_names):
|
| 175 |
+
mae = errors[:, i].mean()
|
| 176 |
+
results[comp_name.lower()] = mae
|
| 177 |
+
LOGGER.info(" %s: %.4f", comp_name, mae)
|
| 178 |
+
|
| 179 |
+
return results
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def main() -> None:
|
| 183 |
+
"""Build and evaluate KD-Tree interpolator using reference Munsell data."""
|
| 184 |
+
|
| 185 |
+
log_file = setup_logging("kdtree_interpolator", "from_xyY")
|
| 186 |
+
|
| 187 |
+
LOGGER.info("=" * 80)
|
| 188 |
+
LOGGER.info("KD-Tree Interpolation for xyY to Munsell Conversion")
|
| 189 |
+
LOGGER.info("Using MUNSELL_COLOURS_ALL reference data (4,995 colors)")
|
| 190 |
+
LOGGER.info("=" * 80)
|
| 191 |
+
|
| 192 |
+
# Load reference data from colour library
|
| 193 |
+
LOGGER.info("")
|
| 194 |
+
LOGGER.info("Loading reference Munsell data...")
|
| 195 |
+
X_all, y_all = load_munsell_reference_data()
|
| 196 |
+
LOGGER.info("Total reference colors: %d", len(X_all))
|
| 197 |
+
|
| 198 |
+
# Split into train/validation (80/20)
|
| 199 |
+
X_train, X_val, y_train, y_val = train_test_split(
|
| 200 |
+
X_all, y_all, test_size=0.2, random_state=42
|
| 201 |
+
)
|
| 202 |
+
|
| 203 |
+
LOGGER.info("Train samples: %d", len(X_train))
|
| 204 |
+
LOGGER.info("Validation samples: %d", len(X_val))
|
| 205 |
+
|
| 206 |
+
# Test different k values
|
| 207 |
+
k_values = [1, 3, 5, 10, 20, 50]
|
| 208 |
+
|
| 209 |
+
best_k = None
|
| 210 |
+
best_mae = float("inf")
|
| 211 |
+
|
| 212 |
+
LOGGER.info("")
|
| 213 |
+
LOGGER.info("Testing different k values...")
|
| 214 |
+
LOGGER.info("-" * 60)
|
| 215 |
+
|
| 216 |
+
for k in k_values:
|
| 217 |
+
LOGGER.info("")
|
| 218 |
+
LOGGER.info("k = %d:", k)
|
| 219 |
+
|
| 220 |
+
interpolator = MunsellKDTreeInterpolator(k=k, power=2.0)
|
| 221 |
+
interpolator.fit(X_train, y_train)
|
| 222 |
+
|
| 223 |
+
results = evaluate_kdtree(interpolator, X_val, y_val, "Validation")
|
| 224 |
+
total_mae = sum(results.values())
|
| 225 |
+
|
| 226 |
+
if total_mae < best_mae:
|
| 227 |
+
best_mae = total_mae
|
| 228 |
+
best_k = k
|
| 229 |
+
|
| 230 |
+
LOGGER.info("")
|
| 231 |
+
LOGGER.info("=" * 60)
|
| 232 |
+
LOGGER.info("Best k: %d", best_k)
|
| 233 |
+
LOGGER.info("=" * 60)
|
| 234 |
+
|
| 235 |
+
# Train final model with best k on ALL data
|
| 236 |
+
LOGGER.info("")
|
| 237 |
+
LOGGER.info(
|
| 238 |
+
"Training final model on all %d reference colors with k=%d...",
|
| 239 |
+
len(X_all),
|
| 240 |
+
best_k,
|
| 241 |
+
)
|
| 242 |
+
|
| 243 |
+
final_interpolator = MunsellKDTreeInterpolator(k=best_k, power=2.0)
|
| 244 |
+
final_interpolator.fit(X_all, y_all)
|
| 245 |
+
|
| 246 |
+
LOGGER.info("")
|
| 247 |
+
LOGGER.info("Final evaluation (training set = all data):")
|
| 248 |
+
evaluate_kdtree(final_interpolator, X_all, y_all, "All data")
|
| 249 |
+
|
| 250 |
+
# Save the model
|
| 251 |
+
model_dir = PROJECT_ROOT / "models" / "from_xyY"
|
| 252 |
+
model_dir.mkdir(parents=True, exist_ok=True)
|
| 253 |
+
model_path = model_dir / "kdtree_interpolator.pkl"
|
| 254 |
+
final_interpolator.save(model_path)
|
| 255 |
+
|
| 256 |
+
LOGGER.info("")
|
| 257 |
+
LOGGER.info("=" * 80)
|
| 258 |
+
|
| 259 |
+
log_file.close()
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
if __name__ == "__main__":
|
| 263 |
+
main()
|
learning_munsell/interpolation/from_xyY/rbf_interpolator.py
ADDED
|
@@ -0,0 +1,300 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
RBF (Radial Basis Function) interpolation for xyY to Munsell conversion.
|
| 3 |
+
|
| 4 |
+
This approach uses scipy's RBFInterpolator to build a lookup table
|
| 5 |
+
with smooth interpolation between known color samples.
|
| 6 |
+
|
| 7 |
+
Uses the 4,995 reference colors from MUNSELL_COLOURS_ALL directly.
|
| 8 |
+
|
| 9 |
+
Advantages over ML:
|
| 10 |
+
- Deterministic, no training required
|
| 11 |
+
- Exact interpolation at known points
|
| 12 |
+
- Smooth interpolation between points
|
| 13 |
+
- Easy to understand and debug
|
| 14 |
+
|
| 15 |
+
Disadvantages:
|
| 16 |
+
- Memory scales with number of data points
|
| 17 |
+
- Query time scales with data points (O(n) naive, can optimize)
|
| 18 |
+
- May struggle with extrapolation
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
import logging
|
| 22 |
+
import pickle
|
| 23 |
+
from pathlib import Path
|
| 24 |
+
|
| 25 |
+
import numpy as np
|
| 26 |
+
from numpy.typing import NDArray
|
| 27 |
+
from scipy.interpolate import RBFInterpolator
|
| 28 |
+
from sklearn.model_selection import train_test_split
|
| 29 |
+
|
| 30 |
+
from learning_munsell import PROJECT_ROOT, setup_logging
|
| 31 |
+
from learning_munsell.interpolation.from_xyY import load_munsell_reference_data
|
| 32 |
+
|
| 33 |
+
logging.basicConfig(level=logging.INFO, format="%(message)s")
|
| 34 |
+
LOGGER = logging.getLogger(__name__)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class MunsellRBFInterpolator:
|
| 38 |
+
"""
|
| 39 |
+
RBF-based interpolator for xyY to Munsell conversion.
|
| 40 |
+
|
| 41 |
+
Uses separate RBF interpolators for each Munsell component
|
| 42 |
+
(hue, value, chroma, code) to allow independent kernel tuning.
|
| 43 |
+
"""
|
| 44 |
+
|
| 45 |
+
def __init__(
|
| 46 |
+
self,
|
| 47 |
+
kernel: str = "thin_plate_spline",
|
| 48 |
+
smoothing: float = 0.0,
|
| 49 |
+
epsilon: float | None = None,
|
| 50 |
+
) -> None:
|
| 51 |
+
"""
|
| 52 |
+
Initialize the RBF interpolator.
|
| 53 |
+
|
| 54 |
+
Parameters
|
| 55 |
+
----------
|
| 56 |
+
kernel
|
| 57 |
+
RBF kernel type. Options: 'linear', 'thin_plate_spline',
|
| 58 |
+
'cubic', 'quintic', 'multiquadric', 'inverse_multiquadric',
|
| 59 |
+
'inverse_quadratic', 'gaussian'
|
| 60 |
+
smoothing
|
| 61 |
+
Smoothing parameter. 0 = exact interpolation.
|
| 62 |
+
epsilon
|
| 63 |
+
Shape parameter for kernels that use it.
|
| 64 |
+
"""
|
| 65 |
+
self.kernel = kernel
|
| 66 |
+
self.smoothing = smoothing
|
| 67 |
+
self.epsilon = epsilon
|
| 68 |
+
|
| 69 |
+
self.interpolators: dict[str, RBFInterpolator] = {}
|
| 70 |
+
self.fitted = False
|
| 71 |
+
|
| 72 |
+
def fit(self, X: NDArray, y: NDArray) -> "MunsellRBFInterpolator":
|
| 73 |
+
"""
|
| 74 |
+
Fit RBF interpolators to the training data.
|
| 75 |
+
|
| 76 |
+
Parameters
|
| 77 |
+
----------
|
| 78 |
+
X
|
| 79 |
+
xyY input values of shape (n, 3)
|
| 80 |
+
y
|
| 81 |
+
Munsell output values [hue, value, chroma, code] of shape (n, 4)
|
| 82 |
+
|
| 83 |
+
Returns
|
| 84 |
+
-------
|
| 85 |
+
self
|
| 86 |
+
"""
|
| 87 |
+
LOGGER.info("Fitting RBF interpolators...")
|
| 88 |
+
LOGGER.info(" Kernel: %s", self.kernel)
|
| 89 |
+
LOGGER.info(" Smoothing: %s", self.smoothing)
|
| 90 |
+
LOGGER.info(" Data points: %d", len(X))
|
| 91 |
+
|
| 92 |
+
component_names = ["hue", "value", "chroma", "code"]
|
| 93 |
+
|
| 94 |
+
for i, name in enumerate(component_names):
|
| 95 |
+
LOGGER.info(" Building %s interpolator...", name)
|
| 96 |
+
|
| 97 |
+
kwargs = {
|
| 98 |
+
"kernel": self.kernel,
|
| 99 |
+
"smoothing": self.smoothing,
|
| 100 |
+
}
|
| 101 |
+
if self.epsilon is not None:
|
| 102 |
+
kwargs["epsilon"] = self.epsilon
|
| 103 |
+
|
| 104 |
+
self.interpolators[name] = RBFInterpolator(X, y[:, i], **kwargs)
|
| 105 |
+
|
| 106 |
+
self.fitted = True
|
| 107 |
+
LOGGER.info("RBF interpolators fitted successfully")
|
| 108 |
+
|
| 109 |
+
return self
|
| 110 |
+
|
| 111 |
+
def predict(self, X: NDArray) -> NDArray:
|
| 112 |
+
"""
|
| 113 |
+
Predict Munsell values for given xyY inputs.
|
| 114 |
+
|
| 115 |
+
Parameters
|
| 116 |
+
----------
|
| 117 |
+
X
|
| 118 |
+
xyY input values of shape (n, 3)
|
| 119 |
+
|
| 120 |
+
Returns
|
| 121 |
+
-------
|
| 122 |
+
NDArray
|
| 123 |
+
Predicted Munsell values [hue, value, chroma, code] of shape (n, 4)
|
| 124 |
+
"""
|
| 125 |
+
if not self.fitted:
|
| 126 |
+
msg = "Interpolator not fitted. Call fit() first."
|
| 127 |
+
raise RuntimeError(msg)
|
| 128 |
+
|
| 129 |
+
results = np.zeros((len(X), 4))
|
| 130 |
+
|
| 131 |
+
for i, name in enumerate(["hue", "value", "chroma", "code"]):
|
| 132 |
+
results[:, i] = self.interpolators[name](X)
|
| 133 |
+
|
| 134 |
+
return results
|
| 135 |
+
|
| 136 |
+
def save(self, path: Path) -> None:
|
| 137 |
+
"""Save the interpolator to disk."""
|
| 138 |
+
with open(path, "wb") as f:
|
| 139 |
+
pickle.dump(
|
| 140 |
+
{
|
| 141 |
+
"kernel": self.kernel,
|
| 142 |
+
"smoothing": self.smoothing,
|
| 143 |
+
"epsilon": self.epsilon,
|
| 144 |
+
"interpolators": self.interpolators,
|
| 145 |
+
},
|
| 146 |
+
f,
|
| 147 |
+
)
|
| 148 |
+
LOGGER.info("Saved RBF interpolator to %s", path)
|
| 149 |
+
|
| 150 |
+
@classmethod
|
| 151 |
+
def load(cls, path: Path) -> "MunsellRBFInterpolator":
|
| 152 |
+
"""Load the interpolator from disk."""
|
| 153 |
+
with open(path, "rb") as f:
|
| 154 |
+
data = pickle.load(f) # noqa: S301
|
| 155 |
+
|
| 156 |
+
instance = cls(
|
| 157 |
+
kernel=data["kernel"],
|
| 158 |
+
smoothing=data["smoothing"],
|
| 159 |
+
epsilon=data["epsilon"],
|
| 160 |
+
)
|
| 161 |
+
instance.interpolators = data["interpolators"]
|
| 162 |
+
instance.fitted = True
|
| 163 |
+
|
| 164 |
+
LOGGER.info("Loaded RBF interpolator from %s", path)
|
| 165 |
+
return instance
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def evaluate_rbf(
|
| 169 |
+
interpolator: MunsellRBFInterpolator,
|
| 170 |
+
X: NDArray,
|
| 171 |
+
y: NDArray,
|
| 172 |
+
name: str = "Test",
|
| 173 |
+
) -> dict[str, float]:
|
| 174 |
+
"""
|
| 175 |
+
Evaluate RBF interpolator performance.
|
| 176 |
+
|
| 177 |
+
Parameters
|
| 178 |
+
----------
|
| 179 |
+
interpolator
|
| 180 |
+
Fitted RBF interpolator
|
| 181 |
+
X
|
| 182 |
+
Input xyY values
|
| 183 |
+
y
|
| 184 |
+
Ground truth Munsell values
|
| 185 |
+
name
|
| 186 |
+
Name for logging
|
| 187 |
+
|
| 188 |
+
Returns
|
| 189 |
+
-------
|
| 190 |
+
dict
|
| 191 |
+
Dictionary of MAE values for each component
|
| 192 |
+
"""
|
| 193 |
+
predictions = interpolator.predict(X)
|
| 194 |
+
errors = np.abs(predictions - y)
|
| 195 |
+
|
| 196 |
+
component_names = ["Hue", "Value", "Chroma", "Code"]
|
| 197 |
+
results = {}
|
| 198 |
+
|
| 199 |
+
LOGGER.info("%s set MAE:", name)
|
| 200 |
+
for i, comp_name in enumerate(component_names):
|
| 201 |
+
mae = errors[:, i].mean()
|
| 202 |
+
results[comp_name.lower()] = mae
|
| 203 |
+
LOGGER.info(" %s: %.4f", comp_name, mae)
|
| 204 |
+
|
| 205 |
+
return results
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def main() -> None:
|
| 209 |
+
"""Build and evaluate RBF interpolator using reference Munsell data."""
|
| 210 |
+
|
| 211 |
+
log_file = setup_logging("rbf_interpolator", "from_xyY")
|
| 212 |
+
|
| 213 |
+
LOGGER.info("=" * 80)
|
| 214 |
+
LOGGER.info("RBF Interpolation for xyY to Munsell Conversion")
|
| 215 |
+
LOGGER.info("Using MUNSELL_COLOURS_ALL reference data (4,995 colors)")
|
| 216 |
+
LOGGER.info("=" * 80)
|
| 217 |
+
|
| 218 |
+
# Load reference data from colour library
|
| 219 |
+
LOGGER.info("")
|
| 220 |
+
LOGGER.info("Loading reference Munsell data...")
|
| 221 |
+
X_all, y_all = load_munsell_reference_data()
|
| 222 |
+
LOGGER.info("Total reference colors: %d", len(X_all))
|
| 223 |
+
|
| 224 |
+
# Split into train/validation (80/20)
|
| 225 |
+
X_train, X_val, y_train, y_val = train_test_split(
|
| 226 |
+
X_all, y_all, test_size=0.2, random_state=42
|
| 227 |
+
)
|
| 228 |
+
|
| 229 |
+
LOGGER.info("Train samples: %d", len(X_train))
|
| 230 |
+
LOGGER.info("Validation samples: %d", len(X_val))
|
| 231 |
+
|
| 232 |
+
# Test different kernels
|
| 233 |
+
kernels_to_test = [
|
| 234 |
+
("thin_plate_spline", 0.0),
|
| 235 |
+
("thin_plate_spline", 0.001),
|
| 236 |
+
("thin_plate_spline", 0.01),
|
| 237 |
+
("cubic", 0.0),
|
| 238 |
+
("linear", 0.0),
|
| 239 |
+
("multiquadric", 0.0),
|
| 240 |
+
]
|
| 241 |
+
|
| 242 |
+
best_kernel = None
|
| 243 |
+
best_smoothing = None
|
| 244 |
+
best_mae = float("inf")
|
| 245 |
+
|
| 246 |
+
LOGGER.info("")
|
| 247 |
+
LOGGER.info("Testing different RBF kernels...")
|
| 248 |
+
LOGGER.info("-" * 60)
|
| 249 |
+
|
| 250 |
+
for kernel, smoothing in kernels_to_test:
|
| 251 |
+
LOGGER.info("")
|
| 252 |
+
LOGGER.info("Kernel: %s, Smoothing: %s", kernel, smoothing)
|
| 253 |
+
|
| 254 |
+
try:
|
| 255 |
+
interpolator = MunsellRBFInterpolator(kernel=kernel, smoothing=smoothing)
|
| 256 |
+
interpolator.fit(X_train, y_train)
|
| 257 |
+
|
| 258 |
+
results = evaluate_rbf(interpolator, X_val, y_val, "Validation")
|
| 259 |
+
total_mae = sum(results.values())
|
| 260 |
+
|
| 261 |
+
if total_mae < best_mae:
|
| 262 |
+
best_mae = total_mae
|
| 263 |
+
best_kernel = kernel
|
| 264 |
+
best_smoothing = smoothing
|
| 265 |
+
|
| 266 |
+
except Exception:
|
| 267 |
+
LOGGER.exception(" Failed")
|
| 268 |
+
|
| 269 |
+
LOGGER.info("")
|
| 270 |
+
LOGGER.info("=" * 60)
|
| 271 |
+
LOGGER.info("Best configuration: %s with smoothing=%s", best_kernel, best_smoothing)
|
| 272 |
+
LOGGER.info("=" * 60)
|
| 273 |
+
|
| 274 |
+
# Train final model with best kernel on ALL data
|
| 275 |
+
LOGGER.info("")
|
| 276 |
+
LOGGER.info("Training final model on all %d reference colors...", len(X_all))
|
| 277 |
+
|
| 278 |
+
final_interpolator = MunsellRBFInterpolator(
|
| 279 |
+
kernel=best_kernel, smoothing=best_smoothing
|
| 280 |
+
)
|
| 281 |
+
final_interpolator.fit(X_all, y_all)
|
| 282 |
+
|
| 283 |
+
LOGGER.info("")
|
| 284 |
+
LOGGER.info("Final evaluation (training set = all data):")
|
| 285 |
+
evaluate_rbf(final_interpolator, X_all, y_all, "All data")
|
| 286 |
+
|
| 287 |
+
# Save the model
|
| 288 |
+
model_dir = PROJECT_ROOT / "models" / "from_xyY"
|
| 289 |
+
model_dir.mkdir(parents=True, exist_ok=True)
|
| 290 |
+
model_path = model_dir / "rbf_interpolator.pkl"
|
| 291 |
+
final_interpolator.save(model_path)
|
| 292 |
+
|
| 293 |
+
LOGGER.info("")
|
| 294 |
+
LOGGER.info("=" * 80)
|
| 295 |
+
|
| 296 |
+
log_file.close()
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
if __name__ == "__main__":
|
| 300 |
+
main()
|
learning_munsell/losses/__init__.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Loss functions for Munsell ML training."""
|
| 2 |
+
|
| 3 |
+
from learning_munsell.losses.jax_delta_e import (
|
| 4 |
+
XYZ_to_Lab,
|
| 5 |
+
delta_E_CIE2000,
|
| 6 |
+
delta_E_loss,
|
| 7 |
+
xyY_to_Lab,
|
| 8 |
+
xyY_to_XYZ,
|
| 9 |
+
)
|
| 10 |
+
|
| 11 |
+
__all__ = [
|
| 12 |
+
"delta_E_CIE2000",
|
| 13 |
+
"delta_E_loss",
|
| 14 |
+
"xyY_to_Lab",
|
| 15 |
+
"xyY_to_XYZ",
|
| 16 |
+
"XYZ_to_Lab",
|
| 17 |
+
]
|
learning_munsell/losses/jax_delta_e.py
ADDED
|
@@ -0,0 +1,299 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Differentiable Delta-E Loss Functions using JAX
|
| 3 |
+
================================================
|
| 4 |
+
|
| 5 |
+
This module provides JAX implementations of color space conversions
|
| 6 |
+
and Delta-E (CIE2000) loss function for use in training.
|
| 7 |
+
|
| 8 |
+
The key insight is that we can compute Delta-E between:
|
| 9 |
+
- The input xyY (which we convert to Lab as the "target")
|
| 10 |
+
- The predicted Munsell converted back to Lab
|
| 11 |
+
|
| 12 |
+
For the Munsell -> xyY conversion, we either:
|
| 13 |
+
1. Use a pre-trained neural network approximator
|
| 14 |
+
2. Use differentiable interpolation on the Munsell Renotation data
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
import colour
|
| 20 |
+
import jax
|
| 21 |
+
import jax.numpy as jnp
|
| 22 |
+
import numpy as np
|
| 23 |
+
from jax import Array
|
| 24 |
+
|
| 25 |
+
# D65 illuminant XYZ reference values (standard for sRGB)
|
| 26 |
+
D65_XYZ = jnp.array([95.047, 100.0, 108.883])
|
| 27 |
+
|
| 28 |
+
# Illuminant C XYZ reference values (used by Munsell system)
|
| 29 |
+
ILLUMINANT_C_XYZ = jnp.array([98.074, 100.0, 118.232])
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def xyY_to_XYZ(xyY: Array, scale_Y: bool = True) -> Array:
|
| 33 |
+
"""
|
| 34 |
+
Convert CIE xyY to CIE XYZ.
|
| 35 |
+
|
| 36 |
+
Parameters
|
| 37 |
+
----------
|
| 38 |
+
xyY : Array
|
| 39 |
+
CIE xyY values with shape (..., 3)
|
| 40 |
+
scale_Y : bool
|
| 41 |
+
If True, scale Y from 0-1 to 0-100 range (required for Lab conversion)
|
| 42 |
+
|
| 43 |
+
Returns
|
| 44 |
+
-------
|
| 45 |
+
Array
|
| 46 |
+
CIE XYZ values with shape (..., 3)
|
| 47 |
+
"""
|
| 48 |
+
x = xyY[..., 0]
|
| 49 |
+
y = xyY[..., 1]
|
| 50 |
+
Y = xyY[..., 2]
|
| 51 |
+
|
| 52 |
+
# Scale Y to 0-100 range if needed (colour library uses 0-100)
|
| 53 |
+
if scale_Y:
|
| 54 |
+
Y = Y * 100.0
|
| 55 |
+
|
| 56 |
+
# Avoid division by zero
|
| 57 |
+
y_safe = jnp.where(y == 0, 1e-10, y)
|
| 58 |
+
|
| 59 |
+
X = (x * Y) / y_safe
|
| 60 |
+
Z = ((1 - x - y) * Y) / y_safe
|
| 61 |
+
|
| 62 |
+
# Handle y=0 case (set X=Z=0)
|
| 63 |
+
X = jnp.where(y == 0, 0.0, X)
|
| 64 |
+
Z = jnp.where(y == 0, 0.0, Z)
|
| 65 |
+
|
| 66 |
+
return jnp.stack([X, Y, Z], axis=-1)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def XYZ_to_Lab(XYZ: Array, illuminant: Array = ILLUMINANT_C_XYZ) -> Array:
|
| 70 |
+
"""
|
| 71 |
+
Convert CIE XYZ to CIE Lab.
|
| 72 |
+
|
| 73 |
+
Parameters
|
| 74 |
+
----------
|
| 75 |
+
XYZ : Array
|
| 76 |
+
CIE XYZ values with shape (..., 3)
|
| 77 |
+
illuminant : Array
|
| 78 |
+
Reference white XYZ values
|
| 79 |
+
|
| 80 |
+
Returns
|
| 81 |
+
-------
|
| 82 |
+
Array
|
| 83 |
+
CIE Lab values with shape (..., 3)
|
| 84 |
+
"""
|
| 85 |
+
# Normalize by illuminant
|
| 86 |
+
XYZ_n = XYZ / illuminant
|
| 87 |
+
|
| 88 |
+
# CIE Lab transfer function
|
| 89 |
+
delta = 6.0 / 29.0
|
| 90 |
+
delta_cube = delta**3
|
| 91 |
+
|
| 92 |
+
# f(t) = t^(1/3) if t > delta^3, else t/(3*delta^2) + 4/29
|
| 93 |
+
def f(t: Array) -> Array:
|
| 94 |
+
return jnp.where(t > delta_cube, jnp.cbrt(t), t / (3 * delta**2) + 4.0 / 29.0)
|
| 95 |
+
|
| 96 |
+
f_X = f(XYZ_n[..., 0])
|
| 97 |
+
f_Y = f(XYZ_n[..., 1])
|
| 98 |
+
f_Z = f(XYZ_n[..., 2])
|
| 99 |
+
|
| 100 |
+
L = 116.0 * f_Y - 16.0
|
| 101 |
+
a = 500.0 * (f_X - f_Y)
|
| 102 |
+
b = 200.0 * (f_Y - f_Z)
|
| 103 |
+
|
| 104 |
+
return jnp.stack([L, a, b], axis=-1)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def xyY_to_Lab(xyY: Array, illuminant: Array = ILLUMINANT_C_XYZ) -> Array:
|
| 108 |
+
"""Convert CIE xyY directly to CIE Lab."""
|
| 109 |
+
return XYZ_to_Lab(xyY_to_XYZ(xyY), illuminant)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def delta_E_CIE2000(Lab_1: Array, Lab_2: Array) -> Array:
|
| 113 |
+
"""
|
| 114 |
+
Compute CIE 2000 Delta-E color difference.
|
| 115 |
+
|
| 116 |
+
This is a differentiable JAX implementation of the CIE 2000 Delta-E formula.
|
| 117 |
+
|
| 118 |
+
Parameters
|
| 119 |
+
----------
|
| 120 |
+
Lab_1 : Array
|
| 121 |
+
First CIE Lab color(s) with shape (..., 3)
|
| 122 |
+
Lab_2 : Array
|
| 123 |
+
Second CIE Lab color(s) with shape (..., 3)
|
| 124 |
+
|
| 125 |
+
Returns
|
| 126 |
+
-------
|
| 127 |
+
Array
|
| 128 |
+
Delta-E values with shape (...)
|
| 129 |
+
"""
|
| 130 |
+
L_1, a_1, b_1 = Lab_1[..., 0], Lab_1[..., 1], Lab_1[..., 2]
|
| 131 |
+
L_2, a_2, b_2 = Lab_2[..., 0], Lab_2[..., 1], Lab_2[..., 2]
|
| 132 |
+
|
| 133 |
+
# Chroma
|
| 134 |
+
C_1_ab = jnp.sqrt(a_1**2 + b_1**2)
|
| 135 |
+
C_2_ab = jnp.sqrt(a_2**2 + b_2**2)
|
| 136 |
+
|
| 137 |
+
C_bar_ab = (C_1_ab + C_2_ab) / 2
|
| 138 |
+
C_bar_ab_7 = C_bar_ab**7
|
| 139 |
+
|
| 140 |
+
# G factor for a' adjustment (25^7 = 6103515625.0)
|
| 141 |
+
G = 0.5 * (1 - jnp.sqrt(C_bar_ab_7 / (C_bar_ab_7 + 6103515625.0)))
|
| 142 |
+
|
| 143 |
+
# Adjusted a'
|
| 144 |
+
a_p_1 = (1 + G) * a_1
|
| 145 |
+
a_p_2 = (1 + G) * a_2
|
| 146 |
+
|
| 147 |
+
# Adjusted chroma C'
|
| 148 |
+
C_p_1 = jnp.sqrt(a_p_1**2 + b_1**2)
|
| 149 |
+
C_p_2 = jnp.sqrt(a_p_2**2 + b_2**2)
|
| 150 |
+
|
| 151 |
+
# Hue angle h' (in degrees)
|
| 152 |
+
h_p_1 = jnp.degrees(jnp.arctan2(b_1, a_p_1)) % 360
|
| 153 |
+
h_p_2 = jnp.degrees(jnp.arctan2(b_2, a_p_2)) % 360
|
| 154 |
+
|
| 155 |
+
# Handle achromatic case
|
| 156 |
+
h_p_1 = jnp.where((b_1 == 0) & (a_p_1 == 0), 0.0, h_p_1)
|
| 157 |
+
h_p_2 = jnp.where((b_2 == 0) & (a_p_2 == 0), 0.0, h_p_2)
|
| 158 |
+
|
| 159 |
+
# Delta L', C'
|
| 160 |
+
delta_L_p = L_2 - L_1
|
| 161 |
+
delta_C_p = C_p_2 - C_p_1
|
| 162 |
+
|
| 163 |
+
# Delta h'
|
| 164 |
+
h_p_diff = h_p_2 - h_p_1
|
| 165 |
+
C_p_product = C_p_1 * C_p_2
|
| 166 |
+
|
| 167 |
+
delta_h_p = jnp.where(
|
| 168 |
+
C_p_product == 0,
|
| 169 |
+
0.0,
|
| 170 |
+
jnp.where(
|
| 171 |
+
jnp.abs(h_p_diff) <= 180,
|
| 172 |
+
h_p_diff,
|
| 173 |
+
jnp.where(h_p_diff > 180, h_p_diff - 360, h_p_diff + 360),
|
| 174 |
+
),
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
# Delta H'
|
| 178 |
+
delta_H_p = 2 * jnp.sqrt(C_p_product) * jnp.sin(jnp.radians(delta_h_p / 2))
|
| 179 |
+
|
| 180 |
+
# Mean L', C'
|
| 181 |
+
L_bar_p = (L_1 + L_2) / 2
|
| 182 |
+
C_bar_p = (C_p_1 + C_p_2) / 2
|
| 183 |
+
|
| 184 |
+
# Mean h'
|
| 185 |
+
h_p_sum = h_p_1 + h_p_2
|
| 186 |
+
h_p_abs_diff = jnp.abs(h_p_1 - h_p_2)
|
| 187 |
+
|
| 188 |
+
h_bar_p = jnp.where(
|
| 189 |
+
C_p_product == 0,
|
| 190 |
+
h_p_sum,
|
| 191 |
+
jnp.where(
|
| 192 |
+
h_p_abs_diff <= 180,
|
| 193 |
+
h_p_sum / 2,
|
| 194 |
+
jnp.where(h_p_sum < 360, (h_p_sum + 360) / 2, (h_p_sum - 360) / 2),
|
| 195 |
+
),
|
| 196 |
+
)
|
| 197 |
+
|
| 198 |
+
# T factor
|
| 199 |
+
T = (
|
| 200 |
+
1
|
| 201 |
+
- 0.17 * jnp.cos(jnp.radians(h_bar_p - 30))
|
| 202 |
+
+ 0.24 * jnp.cos(jnp.radians(2 * h_bar_p))
|
| 203 |
+
+ 0.32 * jnp.cos(jnp.radians(3 * h_bar_p + 6))
|
| 204 |
+
- 0.20 * jnp.cos(jnp.radians(4 * h_bar_p - 63))
|
| 205 |
+
)
|
| 206 |
+
|
| 207 |
+
# Delta theta
|
| 208 |
+
delta_theta = 30 * jnp.exp(-(((h_bar_p - 275) / 25) ** 2))
|
| 209 |
+
|
| 210 |
+
# R_C (25^7 = 6103515625.0)
|
| 211 |
+
C_bar_p_7 = C_bar_p**7
|
| 212 |
+
R_C = 2 * jnp.sqrt(C_bar_p_7 / (C_bar_p_7 + 6103515625.0))
|
| 213 |
+
|
| 214 |
+
# S_L, S_C, S_H
|
| 215 |
+
L_bar_p_minus_50_sq = (L_bar_p - 50) ** 2
|
| 216 |
+
S_L = 1 + (0.015 * L_bar_p_minus_50_sq) / jnp.sqrt(20 + L_bar_p_minus_50_sq)
|
| 217 |
+
S_C = 1 + 0.045 * C_bar_p
|
| 218 |
+
S_H = 1 + 0.015 * C_bar_p * T
|
| 219 |
+
|
| 220 |
+
# R_T
|
| 221 |
+
R_T = -jnp.sin(jnp.radians(2 * delta_theta)) * R_C
|
| 222 |
+
|
| 223 |
+
# Final Delta E
|
| 224 |
+
k_L, k_C, k_H = 1.0, 1.0, 1.0
|
| 225 |
+
|
| 226 |
+
term_L = delta_L_p / (k_L * S_L)
|
| 227 |
+
term_C = delta_C_p / (k_C * S_C)
|
| 228 |
+
term_H = delta_H_p / (k_H * S_H)
|
| 229 |
+
|
| 230 |
+
return jnp.sqrt(term_L**2 + term_C**2 + term_H**2 + R_T * term_C * term_H)
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
def delta_E_loss(pred_xyY: Array, target_xyY: Array) -> Array:
|
| 234 |
+
"""
|
| 235 |
+
Compute mean Delta-E loss between predicted and target xyY values.
|
| 236 |
+
|
| 237 |
+
This is the primary loss function for training with perceptual accuracy.
|
| 238 |
+
|
| 239 |
+
Parameters
|
| 240 |
+
----------
|
| 241 |
+
pred_xyY : Array
|
| 242 |
+
Predicted xyY values with shape (batch, 3)
|
| 243 |
+
target_xyY : Array
|
| 244 |
+
Target xyY values with shape (batch, 3)
|
| 245 |
+
|
| 246 |
+
Returns
|
| 247 |
+
-------
|
| 248 |
+
Array
|
| 249 |
+
Scalar mean Delta-E loss
|
| 250 |
+
"""
|
| 251 |
+
pred_Lab = xyY_to_Lab(pred_xyY)
|
| 252 |
+
target_Lab = xyY_to_Lab(target_xyY)
|
| 253 |
+
return jnp.mean(delta_E_CIE2000(pred_Lab, target_Lab))
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
# JIT-compiled versions for performance
|
| 257 |
+
xyY_to_XYZ_jit = jax.jit(xyY_to_XYZ)
|
| 258 |
+
XYZ_to_Lab_jit = jax.jit(XYZ_to_Lab)
|
| 259 |
+
xyY_to_Lab_jit = jax.jit(xyY_to_Lab)
|
| 260 |
+
delta_E_CIE2000_jit = jax.jit(delta_E_CIE2000)
|
| 261 |
+
delta_E_loss_jit = jax.jit(delta_E_loss)
|
| 262 |
+
|
| 263 |
+
# Gradient functions
|
| 264 |
+
grad_delta_E_loss = jax.grad(delta_E_loss)
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
def test_jax_delta_e() -> None:
|
| 268 |
+
"""Test the JAX Delta-E implementation against colour library."""
|
| 269 |
+
# Test xyY values
|
| 270 |
+
xyY_1 = np.array([0.3127, 0.3290, 0.5]) # D65 white point, Y=0.5
|
| 271 |
+
xyY_2 = np.array([0.35, 0.35, 0.5]) # Slightly shifted
|
| 272 |
+
|
| 273 |
+
# Convert using JAX
|
| 274 |
+
Lab_1_jax = xyY_to_Lab(jnp.array(xyY_1))
|
| 275 |
+
Lab_2_jax = xyY_to_Lab(jnp.array(xyY_2))
|
| 276 |
+
delta_E_CIE2000(Lab_1_jax, Lab_2_jax)
|
| 277 |
+
|
| 278 |
+
# Convert using colour library
|
| 279 |
+
XYZ_1 = colour.xyY_to_XYZ(xyY_1)
|
| 280 |
+
XYZ_2 = colour.xyY_to_XYZ(xyY_2)
|
| 281 |
+
Lab_1_colour = colour.XYZ_to_Lab(
|
| 282 |
+
XYZ_1, colour.CCS_ILLUMINANTS["CIE 1931 2 Degree Standard Observer"]["C"]
|
| 283 |
+
)
|
| 284 |
+
Lab_2_colour = colour.XYZ_to_Lab(
|
| 285 |
+
XYZ_2, colour.CCS_ILLUMINANTS["CIE 1931 2 Degree Standard Observer"]["C"]
|
| 286 |
+
)
|
| 287 |
+
colour.delta_E(Lab_1_colour, Lab_2_colour, method="CIE 2000")
|
| 288 |
+
|
| 289 |
+
# Test gradient computation
|
| 290 |
+
pred_xyY = jnp.array([[0.35, 0.35, 0.5]])
|
| 291 |
+
target_xyY = jnp.array([[0.3127, 0.3290, 0.5]])
|
| 292 |
+
|
| 293 |
+
# Compute gradient
|
| 294 |
+
grad_fn = jax.grad(lambda x: delta_E_loss(x, target_xyY))
|
| 295 |
+
grad_fn(pred_xyY)
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
if __name__ == "__main__":
|
| 299 |
+
test_jax_delta_e()
|
learning_munsell/models/__init__.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Neural network models for Munsell color conversions."""
|
| 2 |
+
|
| 3 |
+
from learning_munsell.models.networks import (
|
| 4 |
+
# Building blocks
|
| 5 |
+
ResidualBlock,
|
| 6 |
+
# Component networks
|
| 7 |
+
ComponentMLP,
|
| 8 |
+
ComponentErrorPredictor,
|
| 9 |
+
# Transformer building blocks
|
| 10 |
+
FeatureTokenizer,
|
| 11 |
+
TransformerBlock,
|
| 12 |
+
# Composite models: xyY → Munsell
|
| 13 |
+
MLPToMunsell,
|
| 14 |
+
MultiHeadMLPToMunsell,
|
| 15 |
+
MultiMLPToMunsell,
|
| 16 |
+
TransformerToMunsell,
|
| 17 |
+
# Error predictors: xyY → Munsell
|
| 18 |
+
MultiHeadErrorPredictorToMunsell,
|
| 19 |
+
MultiMLPErrorPredictorToMunsell,
|
| 20 |
+
# Composite models: Munsell → xyY
|
| 21 |
+
MultiMLPToxyY,
|
| 22 |
+
# Error predictors: Munsell → xyY
|
| 23 |
+
MultiMLPErrorPredictorToxyY,
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
__all__ = [
|
| 27 |
+
# Building blocks
|
| 28 |
+
"ResidualBlock",
|
| 29 |
+
# Component networks (single output)
|
| 30 |
+
"ComponentMLP",
|
| 31 |
+
"ComponentErrorPredictor",
|
| 32 |
+
# Transformer building blocks
|
| 33 |
+
"FeatureTokenizer",
|
| 34 |
+
"TransformerBlock",
|
| 35 |
+
# Composite models: xyY → Munsell
|
| 36 |
+
"MLPToMunsell",
|
| 37 |
+
"MultiHeadMLPToMunsell",
|
| 38 |
+
"MultiMLPToMunsell",
|
| 39 |
+
"TransformerToMunsell",
|
| 40 |
+
# Error predictors: xyY → Munsell
|
| 41 |
+
"MultiHeadErrorPredictorToMunsell",
|
| 42 |
+
"MultiMLPErrorPredictorToMunsell",
|
| 43 |
+
# Composite models: Munsell → xyY
|
| 44 |
+
"MultiMLPToxyY",
|
| 45 |
+
# Error predictors: Munsell → xyY
|
| 46 |
+
"MultiMLPErrorPredictorToxyY",
|
| 47 |
+
]
|
learning_munsell/models/networks.py
ADDED
|
@@ -0,0 +1,1294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Reusable neural network building blocks.
|
| 3 |
+
|
| 4 |
+
Provides shared network architectures for training scripts,
|
| 5 |
+
including MLP components and error predictors.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
from torch import nn, Tensor
|
| 12 |
+
|
| 13 |
+
__all__ = [
|
| 14 |
+
# Building blocks
|
| 15 |
+
"ResidualBlock",
|
| 16 |
+
# Component networks (single output)
|
| 17 |
+
"ComponentMLP",
|
| 18 |
+
"ComponentResNet",
|
| 19 |
+
"ComponentErrorPredictor",
|
| 20 |
+
# Transformer building blocks
|
| 21 |
+
"FeatureTokenizer",
|
| 22 |
+
"TransformerBlock",
|
| 23 |
+
# Composite models: xyY → Munsell
|
| 24 |
+
"MLPToMunsell",
|
| 25 |
+
"MultiHeadMLPToMunsell",
|
| 26 |
+
"MultiMLPToMunsell",
|
| 27 |
+
"MultiResNetToMunsell",
|
| 28 |
+
"TransformerToMunsell",
|
| 29 |
+
# Error predictors: xyY → Munsell
|
| 30 |
+
"MultiHeadErrorPredictorToMunsell",
|
| 31 |
+
"MultiMLPErrorPredictorToMunsell",
|
| 32 |
+
"MultiResNetErrorPredictorToMunsell",
|
| 33 |
+
# Composite models: Munsell → xyY
|
| 34 |
+
"MultiMLPToxyY",
|
| 35 |
+
# Error predictors: Munsell → xyY
|
| 36 |
+
"MultiMLPErrorPredictorToxyY",
|
| 37 |
+
]
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# =============================================================================
|
| 41 |
+
# Building Blocks
|
| 42 |
+
# =============================================================================
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class ResidualBlock(nn.Module):
|
| 46 |
+
"""
|
| 47 |
+
Residual block with GELU activation and batch normalization.
|
| 48 |
+
|
| 49 |
+
Architecture: input → Linear → GELU → BatchNorm → Linear → BatchNorm → add input → GELU
|
| 50 |
+
|
| 51 |
+
Parameters
|
| 52 |
+
----------
|
| 53 |
+
dim : int
|
| 54 |
+
Dimension of input and output features.
|
| 55 |
+
|
| 56 |
+
Attributes
|
| 57 |
+
----------
|
| 58 |
+
block : nn.Sequential
|
| 59 |
+
Sequential block with linear layers, GELU, and BatchNorm.
|
| 60 |
+
activation : nn.GELU
|
| 61 |
+
Final activation after residual addition.
|
| 62 |
+
"""
|
| 63 |
+
|
| 64 |
+
def __init__(self, dim: int) -> None:
|
| 65 |
+
"""Initialize residual block."""
|
| 66 |
+
super().__init__()
|
| 67 |
+
self.block = nn.Sequential(
|
| 68 |
+
nn.Linear(dim, dim),
|
| 69 |
+
nn.GELU(),
|
| 70 |
+
nn.BatchNorm1d(dim),
|
| 71 |
+
nn.Linear(dim, dim),
|
| 72 |
+
nn.BatchNorm1d(dim),
|
| 73 |
+
)
|
| 74 |
+
self.activation = nn.GELU()
|
| 75 |
+
|
| 76 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 77 |
+
"""
|
| 78 |
+
Forward pass with residual connection.
|
| 79 |
+
|
| 80 |
+
Parameters
|
| 81 |
+
----------
|
| 82 |
+
x : Tensor
|
| 83 |
+
Input tensor of shape (batch_size, dim).
|
| 84 |
+
|
| 85 |
+
Returns
|
| 86 |
+
-------
|
| 87 |
+
Tensor
|
| 88 |
+
Output tensor of shape (batch_size, dim).
|
| 89 |
+
"""
|
| 90 |
+
return self.activation(x + self.block(x))
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
# =============================================================================
|
| 94 |
+
# Component Networks (Single Output)
|
| 95 |
+
# =============================================================================
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class ComponentMLP(nn.Module):
|
| 99 |
+
"""
|
| 100 |
+
Independent MLP for a single Munsell component.
|
| 101 |
+
|
| 102 |
+
Architecture: input_dim → 128 → 256 → 512 → 256 → 128 → 1
|
| 103 |
+
|
| 104 |
+
Parameters
|
| 105 |
+
----------
|
| 106 |
+
input_dim : int, optional
|
| 107 |
+
Input feature dimension. Default is 3 (for xyY).
|
| 108 |
+
width_multiplier : float, optional
|
| 109 |
+
Multiplier for hidden layer dimensions. Default is 1.0.
|
| 110 |
+
dropout : float, optional
|
| 111 |
+
Dropout probability between layers. Default is 0.0.
|
| 112 |
+
|
| 113 |
+
Attributes
|
| 114 |
+
----------
|
| 115 |
+
network : nn.Sequential
|
| 116 |
+
Feed-forward network with encoder-decoder structure.
|
| 117 |
+
|
| 118 |
+
Notes
|
| 119 |
+
-----
|
| 120 |
+
Uses ReLU activations and batch normalization. The encoder-decoder
|
| 121 |
+
architecture expands to 512-dim (or scaled by width_multiplier) and
|
| 122 |
+
then contracts back to a single output. Optional dropout can be
|
| 123 |
+
applied between layers for regularization.
|
| 124 |
+
"""
|
| 125 |
+
|
| 126 |
+
def __init__(
|
| 127 |
+
self,
|
| 128 |
+
input_dim: int = 3,
|
| 129 |
+
width_multiplier: float = 1.0,
|
| 130 |
+
dropout: float = 0.0,
|
| 131 |
+
) -> None:
|
| 132 |
+
"""Initialize the component-specific MLP."""
|
| 133 |
+
super().__init__()
|
| 134 |
+
|
| 135 |
+
# Scale hidden dimensions
|
| 136 |
+
h1 = int(128 * width_multiplier)
|
| 137 |
+
h2 = int(256 * width_multiplier)
|
| 138 |
+
h3 = int(512 * width_multiplier)
|
| 139 |
+
|
| 140 |
+
layers: list[nn.Module] = [
|
| 141 |
+
# Encoder
|
| 142 |
+
nn.Linear(input_dim, h1),
|
| 143 |
+
nn.ReLU(),
|
| 144 |
+
nn.BatchNorm1d(h1),
|
| 145 |
+
]
|
| 146 |
+
|
| 147 |
+
if dropout > 0:
|
| 148 |
+
layers.append(nn.Dropout(dropout))
|
| 149 |
+
|
| 150 |
+
layers.extend(
|
| 151 |
+
[
|
| 152 |
+
nn.Linear(h1, h2),
|
| 153 |
+
nn.ReLU(),
|
| 154 |
+
nn.BatchNorm1d(h2),
|
| 155 |
+
]
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
if dropout > 0:
|
| 159 |
+
layers.append(nn.Dropout(dropout))
|
| 160 |
+
|
| 161 |
+
layers.extend(
|
| 162 |
+
[
|
| 163 |
+
nn.Linear(h2, h3),
|
| 164 |
+
nn.ReLU(),
|
| 165 |
+
nn.BatchNorm1d(h3),
|
| 166 |
+
]
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
+
if dropout > 0:
|
| 170 |
+
layers.append(nn.Dropout(dropout))
|
| 171 |
+
|
| 172 |
+
layers.extend(
|
| 173 |
+
[
|
| 174 |
+
# Decoder
|
| 175 |
+
nn.Linear(h3, h2),
|
| 176 |
+
nn.ReLU(),
|
| 177 |
+
nn.BatchNorm1d(h2),
|
| 178 |
+
]
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
if dropout > 0:
|
| 182 |
+
layers.append(nn.Dropout(dropout))
|
| 183 |
+
|
| 184 |
+
layers.extend(
|
| 185 |
+
[
|
| 186 |
+
nn.Linear(h2, h1),
|
| 187 |
+
nn.ReLU(),
|
| 188 |
+
nn.BatchNorm1d(h1),
|
| 189 |
+
# Output
|
| 190 |
+
nn.Linear(h1, 1),
|
| 191 |
+
]
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
self.network = nn.Sequential(*layers)
|
| 195 |
+
|
| 196 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 197 |
+
"""
|
| 198 |
+
Forward pass through the component-specific network.
|
| 199 |
+
|
| 200 |
+
Parameters
|
| 201 |
+
----------
|
| 202 |
+
x : Tensor
|
| 203 |
+
Input tensor of shape (batch_size, input_dim).
|
| 204 |
+
|
| 205 |
+
Returns
|
| 206 |
+
-------
|
| 207 |
+
Tensor
|
| 208 |
+
Output tensor of shape (batch_size, 1) containing the predicted
|
| 209 |
+
component value.
|
| 210 |
+
"""
|
| 211 |
+
return self.network(x)
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
class ComponentResNet(nn.Module):
|
| 215 |
+
"""
|
| 216 |
+
Independent ResNet for a single Munsell component with true skip connections.
|
| 217 |
+
|
| 218 |
+
Architecture: input → projection → ResidualBlock × num_blocks → output
|
| 219 |
+
|
| 220 |
+
Unlike ComponentMLP, this uses actual residual blocks where:
|
| 221 |
+
output = activation(x + f(x))
|
| 222 |
+
|
| 223 |
+
Parameters
|
| 224 |
+
----------
|
| 225 |
+
input_dim : int, optional
|
| 226 |
+
Input feature dimension. Default is 3 (for xyY).
|
| 227 |
+
hidden_dim : int, optional
|
| 228 |
+
Hidden dimension for residual blocks. Default is 256.
|
| 229 |
+
num_blocks : int, optional
|
| 230 |
+
Number of residual blocks. Default is 4.
|
| 231 |
+
|
| 232 |
+
Attributes
|
| 233 |
+
----------
|
| 234 |
+
input_proj : nn.Sequential
|
| 235 |
+
Projects input to hidden dimension with GELU activation.
|
| 236 |
+
res_blocks : nn.ModuleList
|
| 237 |
+
List of ResidualBlock modules with skip connections.
|
| 238 |
+
output_proj : nn.Linear
|
| 239 |
+
Projects hidden dimension to single output.
|
| 240 |
+
"""
|
| 241 |
+
|
| 242 |
+
def __init__(
|
| 243 |
+
self,
|
| 244 |
+
input_dim: int = 3,
|
| 245 |
+
hidden_dim: int = 256,
|
| 246 |
+
num_blocks: int = 4,
|
| 247 |
+
) -> None:
|
| 248 |
+
"""Initialize the component-specific ResNet."""
|
| 249 |
+
super().__init__()
|
| 250 |
+
|
| 251 |
+
# Project input to hidden dimension
|
| 252 |
+
self.input_proj = nn.Sequential(
|
| 253 |
+
nn.Linear(input_dim, hidden_dim),
|
| 254 |
+
nn.GELU(),
|
| 255 |
+
)
|
| 256 |
+
|
| 257 |
+
# Stack of residual blocks with skip connections
|
| 258 |
+
self.res_blocks = nn.ModuleList(
|
| 259 |
+
[ResidualBlock(hidden_dim) for _ in range(num_blocks)]
|
| 260 |
+
)
|
| 261 |
+
|
| 262 |
+
# Project to output
|
| 263 |
+
self.output_proj = nn.Linear(hidden_dim, 1)
|
| 264 |
+
|
| 265 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 266 |
+
"""
|
| 267 |
+
Forward pass through the ResNet with skip connections.
|
| 268 |
+
|
| 269 |
+
Parameters
|
| 270 |
+
----------
|
| 271 |
+
x : Tensor
|
| 272 |
+
Input tensor of shape (batch_size, input_dim).
|
| 273 |
+
|
| 274 |
+
Returns
|
| 275 |
+
-------
|
| 276 |
+
Tensor
|
| 277 |
+
Output tensor of shape (batch_size, 1).
|
| 278 |
+
"""
|
| 279 |
+
x = self.input_proj(x)
|
| 280 |
+
for block in self.res_blocks:
|
| 281 |
+
x = block(x) # Each block applies: activation(x + f(x))
|
| 282 |
+
return self.output_proj(x)
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
class ComponentErrorPredictor(nn.Module):
|
| 286 |
+
"""
|
| 287 |
+
Independent error predictor for a single Munsell component.
|
| 288 |
+
|
| 289 |
+
A deep MLP that learns to predict residual errors for one Munsell
|
| 290 |
+
component (hue, value, chroma, or code).
|
| 291 |
+
|
| 292 |
+
Parameters
|
| 293 |
+
----------
|
| 294 |
+
input_dim : int, optional
|
| 295 |
+
Input feature dimension. Default is 7 (xyY_norm + base_pred_norm).
|
| 296 |
+
width_multiplier : float, optional
|
| 297 |
+
Multiplier for hidden layer widths. Default is 1.0.
|
| 298 |
+
Use 1.5 for chroma which requires more capacity.
|
| 299 |
+
|
| 300 |
+
Attributes
|
| 301 |
+
----------
|
| 302 |
+
network : nn.Sequential
|
| 303 |
+
Feed-forward network: input → 128 → 256 → 512 → 256 → 128 → 1
|
| 304 |
+
with GELU activations and BatchNorm after each hidden layer.
|
| 305 |
+
|
| 306 |
+
Notes
|
| 307 |
+
-----
|
| 308 |
+
Default input is [xyY_norm (3) + base_pred_norm (4)] = 7 features.
|
| 309 |
+
Output is a single scalar error correction for the component.
|
| 310 |
+
"""
|
| 311 |
+
|
| 312 |
+
def __init__(
|
| 313 |
+
self,
|
| 314 |
+
input_dim: int = 7,
|
| 315 |
+
width_multiplier: float = 1.0,
|
| 316 |
+
) -> None:
|
| 317 |
+
"""Initialize the error predictor."""
|
| 318 |
+
super().__init__()
|
| 319 |
+
|
| 320 |
+
# Scale hidden dimensions
|
| 321 |
+
h1 = int(128 * width_multiplier)
|
| 322 |
+
h2 = int(256 * width_multiplier)
|
| 323 |
+
h3 = int(512 * width_multiplier)
|
| 324 |
+
|
| 325 |
+
self.network = nn.Sequential(
|
| 326 |
+
# Encoder
|
| 327 |
+
nn.Linear(input_dim, h1),
|
| 328 |
+
nn.GELU(),
|
| 329 |
+
nn.BatchNorm1d(h1),
|
| 330 |
+
nn.Linear(h1, h2),
|
| 331 |
+
nn.GELU(),
|
| 332 |
+
nn.BatchNorm1d(h2),
|
| 333 |
+
nn.Linear(h2, h3),
|
| 334 |
+
nn.GELU(),
|
| 335 |
+
nn.BatchNorm1d(h3),
|
| 336 |
+
# Decoder
|
| 337 |
+
nn.Linear(h3, h2),
|
| 338 |
+
nn.GELU(),
|
| 339 |
+
nn.BatchNorm1d(h2),
|
| 340 |
+
nn.Linear(h2, h1),
|
| 341 |
+
nn.GELU(),
|
| 342 |
+
nn.BatchNorm1d(h1),
|
| 343 |
+
# Output
|
| 344 |
+
nn.Linear(h1, 1),
|
| 345 |
+
)
|
| 346 |
+
|
| 347 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 348 |
+
"""
|
| 349 |
+
Forward pass through the error predictor.
|
| 350 |
+
|
| 351 |
+
Parameters
|
| 352 |
+
----------
|
| 353 |
+
x : Tensor
|
| 354 |
+
Combined input of shape (batch_size, input_dim).
|
| 355 |
+
|
| 356 |
+
Returns
|
| 357 |
+
-------
|
| 358 |
+
Tensor
|
| 359 |
+
Predicted error correction of shape (batch_size, 1).
|
| 360 |
+
"""
|
| 361 |
+
return self.network(x)
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
# =============================================================================
|
| 365 |
+
# Transformer Building Blocks
|
| 366 |
+
# =============================================================================
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
class FeatureTokenizer(nn.Module):
|
| 370 |
+
"""
|
| 371 |
+
Tokenize each input feature into high-dimensional embedding.
|
| 372 |
+
|
| 373 |
+
Converts each scalar input feature into a learned embedding vector,
|
| 374 |
+
similar to word embeddings in NLP. Also prepends a learnable CLS token
|
| 375 |
+
used for regression output.
|
| 376 |
+
|
| 377 |
+
Parameters
|
| 378 |
+
----------
|
| 379 |
+
num_features : int
|
| 380 |
+
Number of input features to tokenize.
|
| 381 |
+
embedding_dim : int
|
| 382 |
+
Dimensionality of each token embedding.
|
| 383 |
+
|
| 384 |
+
Attributes
|
| 385 |
+
----------
|
| 386 |
+
feature_embeddings : nn.ModuleList
|
| 387 |
+
List of linear layers, one per input feature.
|
| 388 |
+
cls_token : nn.Parameter
|
| 389 |
+
Learnable classification token prepended to feature tokens.
|
| 390 |
+
"""
|
| 391 |
+
|
| 392 |
+
def __init__(self, num_features: int, embedding_dim: int) -> None:
|
| 393 |
+
"""Initialize the feature tokenizer."""
|
| 394 |
+
super().__init__()
|
| 395 |
+
# Each feature gets its own embedding
|
| 396 |
+
self.feature_embeddings = nn.ModuleList(
|
| 397 |
+
[nn.Linear(1, embedding_dim) for _ in range(num_features)]
|
| 398 |
+
)
|
| 399 |
+
# CLS token for regression
|
| 400 |
+
self.cls_token = nn.Parameter(torch.randn(1, 1, embedding_dim))
|
| 401 |
+
|
| 402 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 403 |
+
"""
|
| 404 |
+
Transform input features into token embeddings.
|
| 405 |
+
|
| 406 |
+
Parameters
|
| 407 |
+
----------
|
| 408 |
+
x : Tensor
|
| 409 |
+
Input tensor of shape (batch_size, num_features).
|
| 410 |
+
|
| 411 |
+
Returns
|
| 412 |
+
-------
|
| 413 |
+
Tensor
|
| 414 |
+
Token embeddings of shape (batch_size, 1+num_features, embedding_dim).
|
| 415 |
+
First token is CLS, followed by feature tokens.
|
| 416 |
+
"""
|
| 417 |
+
batch_size = x.size(0)
|
| 418 |
+
|
| 419 |
+
# Tokenize each feature
|
| 420 |
+
tokens = []
|
| 421 |
+
for i, embedding in enumerate(self.feature_embeddings):
|
| 422 |
+
feature_val = x[:, i : i + 1] # (batch_size, 1)
|
| 423 |
+
token = embedding(feature_val) # (batch_size, embedding_dim)
|
| 424 |
+
tokens.append(token.unsqueeze(1)) # (batch_size, 1, embedding_dim)
|
| 425 |
+
|
| 426 |
+
# Concatenate feature tokens
|
| 427 |
+
feature_tokens = torch.cat(
|
| 428 |
+
tokens, dim=1
|
| 429 |
+
) # (batch_size, num_features, embedding_dim)
|
| 430 |
+
|
| 431 |
+
# Prepend CLS token
|
| 432 |
+
cls_tokens = self.cls_token.expand(
|
| 433 |
+
batch_size, -1, -1
|
| 434 |
+
) # (batch_size, 1, embedding_dim)
|
| 435 |
+
return torch.cat(
|
| 436 |
+
[cls_tokens, feature_tokens], dim=1
|
| 437 |
+
) # (batch_size, 1+num_features, embedding_dim)
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
class TransformerBlock(nn.Module):
|
| 441 |
+
"""
|
| 442 |
+
Standard transformer block with multi-head attention and feedforward network.
|
| 443 |
+
|
| 444 |
+
Implements the classic transformer architecture with self-attention,
|
| 445 |
+
feedforward layers, layer normalization, and residual connections.
|
| 446 |
+
|
| 447 |
+
Parameters
|
| 448 |
+
----------
|
| 449 |
+
embedding_dim : int
|
| 450 |
+
Dimension of token embeddings.
|
| 451 |
+
num_heads : int
|
| 452 |
+
Number of attention heads.
|
| 453 |
+
ff_dim : int
|
| 454 |
+
Hidden dimension of feedforward network.
|
| 455 |
+
dropout : float, optional
|
| 456 |
+
Dropout probability, default is 0.1.
|
| 457 |
+
|
| 458 |
+
Attributes
|
| 459 |
+
----------
|
| 460 |
+
attention : nn.MultiheadAttention
|
| 461 |
+
Multi-head self-attention mechanism.
|
| 462 |
+
norm1 : nn.LayerNorm
|
| 463 |
+
Layer normalization after attention.
|
| 464 |
+
feedforward : nn.Sequential
|
| 465 |
+
Feedforward network with GELU activation.
|
| 466 |
+
norm2 : nn.LayerNorm
|
| 467 |
+
Layer normalization after feedforward.
|
| 468 |
+
"""
|
| 469 |
+
|
| 470 |
+
def __init__(
|
| 471 |
+
self, embedding_dim: int, num_heads: int, ff_dim: int, dropout: float = 0.1
|
| 472 |
+
) -> None:
|
| 473 |
+
"""Initialize the transformer block."""
|
| 474 |
+
super().__init__()
|
| 475 |
+
|
| 476 |
+
self.attention = nn.MultiheadAttention(
|
| 477 |
+
embedding_dim, num_heads, dropout=dropout, batch_first=True
|
| 478 |
+
)
|
| 479 |
+
self.norm1 = nn.LayerNorm(embedding_dim)
|
| 480 |
+
|
| 481 |
+
self.feedforward = nn.Sequential(
|
| 482 |
+
nn.Linear(embedding_dim, ff_dim),
|
| 483 |
+
nn.GELU(),
|
| 484 |
+
nn.Dropout(dropout),
|
| 485 |
+
nn.Linear(ff_dim, embedding_dim),
|
| 486 |
+
nn.Dropout(dropout),
|
| 487 |
+
)
|
| 488 |
+
self.norm2 = nn.LayerNorm(embedding_dim)
|
| 489 |
+
|
| 490 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 491 |
+
"""
|
| 492 |
+
Apply transformer block to input tokens.
|
| 493 |
+
|
| 494 |
+
Parameters
|
| 495 |
+
----------
|
| 496 |
+
x : Tensor
|
| 497 |
+
Input tokens of shape (batch_size, num_tokens, embedding_dim).
|
| 498 |
+
|
| 499 |
+
Returns
|
| 500 |
+
-------
|
| 501 |
+
Tensor
|
| 502 |
+
Transformed tokens of shape (batch_size, num_tokens, embedding_dim).
|
| 503 |
+
"""
|
| 504 |
+
# Self-attention with residual
|
| 505 |
+
attn_output, _ = self.attention(x, x, x)
|
| 506 |
+
x = self.norm1(x + attn_output)
|
| 507 |
+
|
| 508 |
+
# Feedforward with residual
|
| 509 |
+
ff_output = self.feedforward(x)
|
| 510 |
+
return self.norm2(x + ff_output)
|
| 511 |
+
|
| 512 |
+
|
| 513 |
+
# =============================================================================
|
| 514 |
+
# Composite Models: xyY → Munsell
|
| 515 |
+
# =============================================================================
|
| 516 |
+
|
| 517 |
+
|
| 518 |
+
class MLPToMunsell(nn.Module):
|
| 519 |
+
"""
|
| 520 |
+
Large MLP for xyY to Munsell conversion.
|
| 521 |
+
|
| 522 |
+
Architecture: 3 → 128 → 256 → 512 → 512 → 256 → 128 → 4
|
| 523 |
+
|
| 524 |
+
Attributes
|
| 525 |
+
----------
|
| 526 |
+
network : nn.Sequential
|
| 527 |
+
Feed-forward network with ReLU activations and BatchNorm.
|
| 528 |
+
"""
|
| 529 |
+
|
| 530 |
+
def __init__(self) -> None:
|
| 531 |
+
"""Initialize the MunsellMLP network."""
|
| 532 |
+
super().__init__()
|
| 533 |
+
|
| 534 |
+
self.network = nn.Sequential(
|
| 535 |
+
nn.Linear(3, 128),
|
| 536 |
+
nn.ReLU(),
|
| 537 |
+
nn.BatchNorm1d(128),
|
| 538 |
+
nn.Linear(128, 256),
|
| 539 |
+
nn.ReLU(),
|
| 540 |
+
nn.BatchNorm1d(256),
|
| 541 |
+
nn.Linear(256, 512),
|
| 542 |
+
nn.ReLU(),
|
| 543 |
+
nn.BatchNorm1d(512),
|
| 544 |
+
nn.Linear(512, 512),
|
| 545 |
+
nn.ReLU(),
|
| 546 |
+
nn.BatchNorm1d(512),
|
| 547 |
+
nn.Linear(512, 256),
|
| 548 |
+
nn.ReLU(),
|
| 549 |
+
nn.BatchNorm1d(256),
|
| 550 |
+
nn.Linear(256, 128),
|
| 551 |
+
nn.ReLU(),
|
| 552 |
+
nn.BatchNorm1d(128),
|
| 553 |
+
nn.Linear(128, 4),
|
| 554 |
+
)
|
| 555 |
+
|
| 556 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 557 |
+
"""
|
| 558 |
+
Forward pass through the network.
|
| 559 |
+
|
| 560 |
+
Parameters
|
| 561 |
+
----------
|
| 562 |
+
x : Tensor
|
| 563 |
+
Input tensor of shape (batch_size, 3) containing normalized xyY values.
|
| 564 |
+
|
| 565 |
+
Returns
|
| 566 |
+
-------
|
| 567 |
+
Tensor
|
| 568 |
+
Output tensor of shape (batch_size, 4) containing normalized Munsell
|
| 569 |
+
specifications [hue, value, chroma, code].
|
| 570 |
+
"""
|
| 571 |
+
return self.network(x)
|
| 572 |
+
|
| 573 |
+
|
| 574 |
+
class MultiHeadMLPToMunsell(nn.Module):
|
| 575 |
+
"""
|
| 576 |
+
Multi-head MLP for xyY to Munsell conversion.
|
| 577 |
+
|
| 578 |
+
Each component (hue, value, chroma, code) has a specialized decoder head
|
| 579 |
+
after a shared encoder. The chroma head is wider to handle the more complex
|
| 580 |
+
non-linear relationship between xyY and chroma.
|
| 581 |
+
|
| 582 |
+
Attributes
|
| 583 |
+
----------
|
| 584 |
+
encoder : nn.Sequential
|
| 585 |
+
Shared encoder: 3 → 128 → 256 → 512 with ReLU and BatchNorm.
|
| 586 |
+
hue_head : nn.Sequential
|
| 587 |
+
Hue decoder: 512 → 256 → 128 → 1 (circular component).
|
| 588 |
+
value_head : nn.Sequential
|
| 589 |
+
Value decoder: 512 → 256 → 128 → 1 (linear component).
|
| 590 |
+
chroma_head : nn.Sequential
|
| 591 |
+
Chroma decoder: 512 → 384 → 256 → 128 → 1 (wider for complexity).
|
| 592 |
+
code_head : nn.Sequential
|
| 593 |
+
Code decoder: 512 → 256 → 128 → 1 (discrete component).
|
| 594 |
+
|
| 595 |
+
Notes
|
| 596 |
+
-----
|
| 597 |
+
The chroma head has increased capacity (384 units in first layer) to handle
|
| 598 |
+
the more complex non-linear relationship between xyY and chroma.
|
| 599 |
+
"""
|
| 600 |
+
|
| 601 |
+
def __init__(self) -> None:
|
| 602 |
+
"""Initialize the multi-head MLP model."""
|
| 603 |
+
super().__init__()
|
| 604 |
+
|
| 605 |
+
# Shared encoder - learns general color space features
|
| 606 |
+
self.encoder = nn.Sequential(
|
| 607 |
+
nn.Linear(3, 128),
|
| 608 |
+
nn.ReLU(),
|
| 609 |
+
nn.BatchNorm1d(128),
|
| 610 |
+
nn.Linear(128, 256),
|
| 611 |
+
nn.ReLU(),
|
| 612 |
+
nn.BatchNorm1d(256),
|
| 613 |
+
nn.Linear(256, 512),
|
| 614 |
+
nn.ReLU(),
|
| 615 |
+
nn.BatchNorm1d(512),
|
| 616 |
+
)
|
| 617 |
+
|
| 618 |
+
# Hue head - circular/angular component
|
| 619 |
+
self.hue_head = nn.Sequential(
|
| 620 |
+
nn.Linear(512, 256),
|
| 621 |
+
nn.ReLU(),
|
| 622 |
+
nn.BatchNorm1d(256),
|
| 623 |
+
nn.Linear(256, 128),
|
| 624 |
+
nn.ReLU(),
|
| 625 |
+
nn.BatchNorm1d(128),
|
| 626 |
+
nn.Linear(128, 1),
|
| 627 |
+
)
|
| 628 |
+
|
| 629 |
+
# Value head - linear lightness
|
| 630 |
+
self.value_head = nn.Sequential(
|
| 631 |
+
nn.Linear(512, 256),
|
| 632 |
+
nn.ReLU(),
|
| 633 |
+
nn.BatchNorm1d(256),
|
| 634 |
+
nn.Linear(256, 128),
|
| 635 |
+
nn.ReLU(),
|
| 636 |
+
nn.BatchNorm1d(128),
|
| 637 |
+
nn.Linear(128, 1),
|
| 638 |
+
)
|
| 639 |
+
|
| 640 |
+
# Chroma head - non-linear saturation (WIDER for harder task)
|
| 641 |
+
self.chroma_head = nn.Sequential(
|
| 642 |
+
nn.Linear(512, 384), # Wider than other heads
|
| 643 |
+
nn.ReLU(),
|
| 644 |
+
nn.BatchNorm1d(384),
|
| 645 |
+
nn.Linear(384, 256),
|
| 646 |
+
nn.ReLU(),
|
| 647 |
+
nn.BatchNorm1d(256),
|
| 648 |
+
nn.Linear(256, 128),
|
| 649 |
+
nn.ReLU(),
|
| 650 |
+
nn.BatchNorm1d(128),
|
| 651 |
+
nn.Linear(128, 1),
|
| 652 |
+
)
|
| 653 |
+
|
| 654 |
+
# Code head - discrete categorical
|
| 655 |
+
self.code_head = nn.Sequential(
|
| 656 |
+
nn.Linear(512, 256),
|
| 657 |
+
nn.ReLU(),
|
| 658 |
+
nn.BatchNorm1d(256),
|
| 659 |
+
nn.Linear(256, 128),
|
| 660 |
+
nn.ReLU(),
|
| 661 |
+
nn.BatchNorm1d(128),
|
| 662 |
+
nn.Linear(128, 1),
|
| 663 |
+
)
|
| 664 |
+
|
| 665 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 666 |
+
"""
|
| 667 |
+
Forward pass through the multi-head network.
|
| 668 |
+
|
| 669 |
+
Parameters
|
| 670 |
+
----------
|
| 671 |
+
x : Tensor
|
| 672 |
+
Input xyY values of shape (batch_size, 3).
|
| 673 |
+
|
| 674 |
+
Returns
|
| 675 |
+
-------
|
| 676 |
+
Tensor
|
| 677 |
+
Concatenated Munsell predictions [hue, value, chroma, code]
|
| 678 |
+
of shape (batch_size, 4).
|
| 679 |
+
"""
|
| 680 |
+
# Shared feature extraction
|
| 681 |
+
features = self.encoder(x)
|
| 682 |
+
|
| 683 |
+
# Component-specific predictions
|
| 684 |
+
hue = self.hue_head(features)
|
| 685 |
+
value = self.value_head(features)
|
| 686 |
+
chroma = self.chroma_head(features)
|
| 687 |
+
code = self.code_head(features)
|
| 688 |
+
|
| 689 |
+
# Concatenate: [Hue, Value, Chroma, Code]
|
| 690 |
+
return torch.cat([hue, value, chroma, code], dim=1)
|
| 691 |
+
|
| 692 |
+
|
| 693 |
+
class MultiMLPToMunsell(nn.Module):
|
| 694 |
+
"""
|
| 695 |
+
Multi-MLP for xyY to Munsell conversion.
|
| 696 |
+
|
| 697 |
+
Uses 4 independent ComponentMLP branches, one for each Munsell component.
|
| 698 |
+
The chroma branch can be wider to handle the more complex relationship.
|
| 699 |
+
|
| 700 |
+
Parameters
|
| 701 |
+
----------
|
| 702 |
+
chroma_width_multiplier : float, optional
|
| 703 |
+
Width multiplier for the chroma branch. Default is 2.0.
|
| 704 |
+
dropout : float, optional
|
| 705 |
+
Dropout probability for all branches. Default is 0.1.
|
| 706 |
+
|
| 707 |
+
Attributes
|
| 708 |
+
----------
|
| 709 |
+
hue_branch : ComponentMLP
|
| 710 |
+
MLP for hue component (1.0x width).
|
| 711 |
+
value_branch : ComponentMLP
|
| 712 |
+
MLP for value component (1.0x width).
|
| 713 |
+
chroma_branch : ComponentMLP
|
| 714 |
+
MLP for chroma component (configurable width).
|
| 715 |
+
code_branch : ComponentMLP
|
| 716 |
+
MLP for hue code component (1.0x width).
|
| 717 |
+
"""
|
| 718 |
+
|
| 719 |
+
def __init__(
|
| 720 |
+
self, chroma_width_multiplier: float = 2.0, dropout: float = 0.1
|
| 721 |
+
) -> None:
|
| 722 |
+
"""Initialize the multi-branch MLP model."""
|
| 723 |
+
super().__init__()
|
| 724 |
+
|
| 725 |
+
self.hue_branch = ComponentMLP(
|
| 726 |
+
input_dim=3, width_multiplier=1.0, dropout=dropout
|
| 727 |
+
)
|
| 728 |
+
self.value_branch = ComponentMLP(
|
| 729 |
+
input_dim=3, width_multiplier=1.0, dropout=dropout
|
| 730 |
+
)
|
| 731 |
+
self.chroma_branch = ComponentMLP(
|
| 732 |
+
input_dim=3, width_multiplier=chroma_width_multiplier, dropout=dropout
|
| 733 |
+
)
|
| 734 |
+
self.code_branch = ComponentMLP(
|
| 735 |
+
input_dim=3, width_multiplier=1.0, dropout=dropout
|
| 736 |
+
)
|
| 737 |
+
|
| 738 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 739 |
+
"""
|
| 740 |
+
Forward pass through all 4 independent branches.
|
| 741 |
+
|
| 742 |
+
Parameters
|
| 743 |
+
----------
|
| 744 |
+
x : Tensor
|
| 745 |
+
Input tensor of shape (batch_size, 3) containing normalized xyY values.
|
| 746 |
+
|
| 747 |
+
Returns
|
| 748 |
+
-------
|
| 749 |
+
Tensor
|
| 750 |
+
Concatenated predictions [hue, value, chroma, code]
|
| 751 |
+
of shape (batch_size, 4).
|
| 752 |
+
"""
|
| 753 |
+
hue = self.hue_branch(x)
|
| 754 |
+
value = self.value_branch(x)
|
| 755 |
+
chroma = self.chroma_branch(x)
|
| 756 |
+
code = self.code_branch(x)
|
| 757 |
+
return torch.cat([hue, value, chroma, code], dim=1)
|
| 758 |
+
|
| 759 |
+
|
| 760 |
+
class MultiResNetToMunsell(nn.Module):
|
| 761 |
+
"""
|
| 762 |
+
Multi-ResNet for xyY to Munsell conversion with true skip connections.
|
| 763 |
+
|
| 764 |
+
Uses 4 independent ComponentResNet branches, one for each Munsell component.
|
| 765 |
+
Each branch contains actual residual blocks with skip connections.
|
| 766 |
+
|
| 767 |
+
Parameters
|
| 768 |
+
----------
|
| 769 |
+
hidden_dim : int, optional
|
| 770 |
+
Hidden dimension for residual blocks. Default is 256.
|
| 771 |
+
num_blocks : int, optional
|
| 772 |
+
Number of residual blocks per branch. Default is 4.
|
| 773 |
+
chroma_hidden_dim : int, optional
|
| 774 |
+
Hidden dimension for chroma branch (typically larger). Default is 512.
|
| 775 |
+
|
| 776 |
+
Attributes
|
| 777 |
+
----------
|
| 778 |
+
hue_branch : ComponentResNet
|
| 779 |
+
ResNet for hue component.
|
| 780 |
+
value_branch : ComponentResNet
|
| 781 |
+
ResNet for value component.
|
| 782 |
+
chroma_branch : ComponentResNet
|
| 783 |
+
ResNet for chroma component (larger hidden dim).
|
| 784 |
+
code_branch : ComponentResNet
|
| 785 |
+
ResNet for hue code component.
|
| 786 |
+
"""
|
| 787 |
+
|
| 788 |
+
def __init__(
|
| 789 |
+
self,
|
| 790 |
+
hidden_dim: int = 256,
|
| 791 |
+
num_blocks: int = 4,
|
| 792 |
+
chroma_hidden_dim: int = 512,
|
| 793 |
+
) -> None:
|
| 794 |
+
"""Initialize the multi-branch ResNet model."""
|
| 795 |
+
super().__init__()
|
| 796 |
+
|
| 797 |
+
self.hue_branch = ComponentResNet(
|
| 798 |
+
input_dim=3, hidden_dim=hidden_dim, num_blocks=num_blocks
|
| 799 |
+
)
|
| 800 |
+
self.value_branch = ComponentResNet(
|
| 801 |
+
input_dim=3, hidden_dim=hidden_dim, num_blocks=num_blocks
|
| 802 |
+
)
|
| 803 |
+
self.chroma_branch = ComponentResNet(
|
| 804 |
+
input_dim=3, hidden_dim=chroma_hidden_dim, num_blocks=num_blocks
|
| 805 |
+
)
|
| 806 |
+
self.code_branch = ComponentResNet(
|
| 807 |
+
input_dim=3, hidden_dim=hidden_dim, num_blocks=num_blocks
|
| 808 |
+
)
|
| 809 |
+
|
| 810 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 811 |
+
"""
|
| 812 |
+
Forward pass through all 4 independent ResNet branches.
|
| 813 |
+
|
| 814 |
+
Parameters
|
| 815 |
+
----------
|
| 816 |
+
x : Tensor
|
| 817 |
+
Input tensor of shape (batch_size, 3) containing normalized xyY values.
|
| 818 |
+
|
| 819 |
+
Returns
|
| 820 |
+
-------
|
| 821 |
+
Tensor
|
| 822 |
+
Concatenated predictions [hue, value, chroma, code]
|
| 823 |
+
of shape (batch_size, 4).
|
| 824 |
+
"""
|
| 825 |
+
hue = self.hue_branch(x)
|
| 826 |
+
value = self.value_branch(x)
|
| 827 |
+
chroma = self.chroma_branch(x)
|
| 828 |
+
code = self.code_branch(x)
|
| 829 |
+
return torch.cat([hue, value, chroma, code], dim=1)
|
| 830 |
+
|
| 831 |
+
|
| 832 |
+
class TransformerToMunsell(nn.Module):
|
| 833 |
+
"""
|
| 834 |
+
Transformer for xyY to Munsell conversion.
|
| 835 |
+
|
| 836 |
+
Uses a feature tokenizer to convert input features to embeddings,
|
| 837 |
+
followed by transformer blocks with self-attention, and separate
|
| 838 |
+
output heads for each Munsell component.
|
| 839 |
+
|
| 840 |
+
Parameters
|
| 841 |
+
----------
|
| 842 |
+
num_features : int, optional
|
| 843 |
+
Number of input features (default is 3 for xyY).
|
| 844 |
+
embedding_dim : int, optional
|
| 845 |
+
Dimension of token embeddings (default is 256).
|
| 846 |
+
num_blocks : int, optional
|
| 847 |
+
Number of transformer blocks (default is 6).
|
| 848 |
+
num_heads : int, optional
|
| 849 |
+
Number of attention heads (default is 8).
|
| 850 |
+
ff_dim : int, optional
|
| 851 |
+
Feedforward network hidden dimension (default is 1024).
|
| 852 |
+
dropout : float, optional
|
| 853 |
+
Dropout probability (default is 0.1).
|
| 854 |
+
|
| 855 |
+
Attributes
|
| 856 |
+
----------
|
| 857 |
+
tokenizer : FeatureTokenizer
|
| 858 |
+
Converts input features to token embeddings with CLS token.
|
| 859 |
+
transformer_blocks : nn.ModuleList
|
| 860 |
+
Stack of transformer blocks with self-attention.
|
| 861 |
+
final_norm : nn.LayerNorm
|
| 862 |
+
Final layer normalization before output heads.
|
| 863 |
+
hue_head : nn.Sequential
|
| 864 |
+
Output head for hue prediction.
|
| 865 |
+
value_head : nn.Sequential
|
| 866 |
+
Output head for value prediction.
|
| 867 |
+
chroma_head : nn.Sequential
|
| 868 |
+
Deeper output head for chroma prediction.
|
| 869 |
+
code_head : nn.Sequential
|
| 870 |
+
Output head for hue code prediction.
|
| 871 |
+
|
| 872 |
+
Notes
|
| 873 |
+
-----
|
| 874 |
+
Architecture: 3 xyY features → 3 tokens + 1 CLS token → transformer blocks
|
| 875 |
+
with self-attention → multi-head output with specialized component heads.
|
| 876 |
+
The chroma head has additional depth due to prediction difficulty.
|
| 877 |
+
"""
|
| 878 |
+
|
| 879 |
+
def __init__(
|
| 880 |
+
self,
|
| 881 |
+
num_features: int = 3,
|
| 882 |
+
embedding_dim: int = 256,
|
| 883 |
+
num_blocks: int = 6,
|
| 884 |
+
num_heads: int = 8,
|
| 885 |
+
ff_dim: int = 1024,
|
| 886 |
+
dropout: float = 0.1,
|
| 887 |
+
) -> None:
|
| 888 |
+
"""Initialize the transformer model."""
|
| 889 |
+
super().__init__()
|
| 890 |
+
|
| 891 |
+
self.tokenizer = FeatureTokenizer(num_features, embedding_dim)
|
| 892 |
+
|
| 893 |
+
self.transformer_blocks = nn.ModuleList(
|
| 894 |
+
[
|
| 895 |
+
TransformerBlock(embedding_dim, num_heads, ff_dim, dropout)
|
| 896 |
+
for _ in range(num_blocks)
|
| 897 |
+
]
|
| 898 |
+
)
|
| 899 |
+
|
| 900 |
+
self.final_norm = nn.LayerNorm(embedding_dim)
|
| 901 |
+
|
| 902 |
+
# Multi-head output - separate heads for each Munsell component
|
| 903 |
+
self.hue_head = nn.Sequential(
|
| 904 |
+
nn.Linear(embedding_dim, 128),
|
| 905 |
+
nn.GELU(),
|
| 906 |
+
nn.Dropout(dropout),
|
| 907 |
+
nn.Linear(128, 1),
|
| 908 |
+
)
|
| 909 |
+
self.value_head = nn.Sequential(
|
| 910 |
+
nn.Linear(embedding_dim, 128),
|
| 911 |
+
nn.GELU(),
|
| 912 |
+
nn.Dropout(dropout),
|
| 913 |
+
nn.Linear(128, 1),
|
| 914 |
+
)
|
| 915 |
+
self.chroma_head = nn.Sequential(
|
| 916 |
+
nn.Linear(embedding_dim, 256),
|
| 917 |
+
nn.GELU(),
|
| 918 |
+
nn.Dropout(dropout),
|
| 919 |
+
nn.Linear(256, 128),
|
| 920 |
+
nn.GELU(),
|
| 921 |
+
nn.Linear(128, 1),
|
| 922 |
+
)
|
| 923 |
+
self.code_head = nn.Sequential(
|
| 924 |
+
nn.Linear(embedding_dim, 128),
|
| 925 |
+
nn.GELU(),
|
| 926 |
+
nn.Dropout(dropout),
|
| 927 |
+
nn.Linear(128, 1),
|
| 928 |
+
)
|
| 929 |
+
|
| 930 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 931 |
+
"""
|
| 932 |
+
Forward pass through the transformer.
|
| 933 |
+
|
| 934 |
+
Parameters
|
| 935 |
+
----------
|
| 936 |
+
x : Tensor
|
| 937 |
+
Input xyY values of shape (batch_size, 3).
|
| 938 |
+
|
| 939 |
+
Returns
|
| 940 |
+
-------
|
| 941 |
+
Tensor
|
| 942 |
+
Predicted Munsell specification [hue, value, chroma, code]
|
| 943 |
+
of shape (batch_size, 4).
|
| 944 |
+
|
| 945 |
+
Notes
|
| 946 |
+
-----
|
| 947 |
+
The CLS token representation is used for the final prediction through
|
| 948 |
+
separate task-specific heads for each Munsell component.
|
| 949 |
+
"""
|
| 950 |
+
tokens = self.tokenizer(x)
|
| 951 |
+
|
| 952 |
+
for block in self.transformer_blocks:
|
| 953 |
+
tokens = block(tokens)
|
| 954 |
+
|
| 955 |
+
tokens = self.final_norm(tokens)
|
| 956 |
+
cls_token = tokens[:, 0, :]
|
| 957 |
+
|
| 958 |
+
hue = self.hue_head(cls_token)
|
| 959 |
+
value = self.value_head(cls_token)
|
| 960 |
+
chroma = self.chroma_head(cls_token)
|
| 961 |
+
code = self.code_head(cls_token)
|
| 962 |
+
|
| 963 |
+
return torch.cat([hue, value, chroma, code], dim=1)
|
| 964 |
+
|
| 965 |
+
|
| 966 |
+
# =============================================================================
|
| 967 |
+
# Error Predictors: xyY → Munsell
|
| 968 |
+
# =============================================================================
|
| 969 |
+
|
| 970 |
+
|
| 971 |
+
class MultiHeadErrorPredictorToMunsell(nn.Module):
|
| 972 |
+
"""
|
| 973 |
+
Multi-Head error predictor for xyY to Munsell conversion.
|
| 974 |
+
|
| 975 |
+
Each branch is a ComponentErrorPredictor specialized for one
|
| 976 |
+
Munsell component. The chroma branch is wider (1.5x) to handle
|
| 977 |
+
the more complex error patterns in chroma prediction.
|
| 978 |
+
|
| 979 |
+
Parameters
|
| 980 |
+
----------
|
| 981 |
+
input_dim : int, optional
|
| 982 |
+
Input feature dimension. Default is 7.
|
| 983 |
+
chroma_width : float, optional
|
| 984 |
+
Width multiplier for chroma branch. Default is 1.5.
|
| 985 |
+
|
| 986 |
+
Attributes
|
| 987 |
+
----------
|
| 988 |
+
hue_branch : ComponentErrorPredictor
|
| 989 |
+
Error predictor for hue component (1.0x width).
|
| 990 |
+
value_branch : ComponentErrorPredictor
|
| 991 |
+
Error predictor for value component (1.0x width).
|
| 992 |
+
chroma_branch : ComponentErrorPredictor
|
| 993 |
+
Error predictor for chroma component (1.5x width by default).
|
| 994 |
+
code_branch : ComponentErrorPredictor
|
| 995 |
+
Error predictor for hue code component (1.0x width).
|
| 996 |
+
"""
|
| 997 |
+
|
| 998 |
+
def __init__(
|
| 999 |
+
self,
|
| 1000 |
+
input_dim: int = 7,
|
| 1001 |
+
chroma_width: float = 1.5,
|
| 1002 |
+
) -> None:
|
| 1003 |
+
"""Initialize the multi-head error predictor."""
|
| 1004 |
+
super().__init__()
|
| 1005 |
+
|
| 1006 |
+
# Independent error predictor for each component
|
| 1007 |
+
self.hue_branch = ComponentErrorPredictor(
|
| 1008 |
+
input_dim=input_dim, width_multiplier=1.0
|
| 1009 |
+
)
|
| 1010 |
+
self.value_branch = ComponentErrorPredictor(
|
| 1011 |
+
input_dim=input_dim, width_multiplier=1.0
|
| 1012 |
+
)
|
| 1013 |
+
self.chroma_branch = ComponentErrorPredictor(
|
| 1014 |
+
input_dim=input_dim, width_multiplier=chroma_width
|
| 1015 |
+
)
|
| 1016 |
+
self.code_branch = ComponentErrorPredictor(
|
| 1017 |
+
input_dim=input_dim, width_multiplier=1.0
|
| 1018 |
+
)
|
| 1019 |
+
|
| 1020 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 1021 |
+
"""
|
| 1022 |
+
Forward pass through all error predictor branches.
|
| 1023 |
+
|
| 1024 |
+
Parameters
|
| 1025 |
+
----------
|
| 1026 |
+
x : Tensor
|
| 1027 |
+
Combined input of shape (batch_size, input_dim).
|
| 1028 |
+
|
| 1029 |
+
Returns
|
| 1030 |
+
-------
|
| 1031 |
+
Tensor
|
| 1032 |
+
Concatenated error corrections [hue, value, chroma, code]
|
| 1033 |
+
of shape (batch_size, 4).
|
| 1034 |
+
"""
|
| 1035 |
+
# Each branch processes the same combined input independently
|
| 1036 |
+
hue_error = self.hue_branch(x)
|
| 1037 |
+
value_error = self.value_branch(x)
|
| 1038 |
+
chroma_error = self.chroma_branch(x)
|
| 1039 |
+
code_error = self.code_branch(x)
|
| 1040 |
+
|
| 1041 |
+
# Concatenate: [Hue_error, Value_error, Chroma_error, Code_error]
|
| 1042 |
+
return torch.cat([hue_error, value_error, chroma_error, code_error], dim=1)
|
| 1043 |
+
|
| 1044 |
+
|
| 1045 |
+
class MultiMLPErrorPredictorToMunsell(nn.Module):
|
| 1046 |
+
"""
|
| 1047 |
+
Multi-MLP error predictor for xyY to Munsell conversion.
|
| 1048 |
+
|
| 1049 |
+
Uses 4 independent ComponentErrorPredictor branches, one for each
|
| 1050 |
+
Munsell component error.
|
| 1051 |
+
|
| 1052 |
+
Parameters
|
| 1053 |
+
----------
|
| 1054 |
+
chroma_width : float, optional
|
| 1055 |
+
Width multiplier for chroma branch. Default is 1.5.
|
| 1056 |
+
|
| 1057 |
+
Attributes
|
| 1058 |
+
----------
|
| 1059 |
+
hue_branch : ComponentErrorPredictor
|
| 1060 |
+
Error predictor for hue component (1.0x width).
|
| 1061 |
+
value_branch : ComponentErrorPredictor
|
| 1062 |
+
Error predictor for value component (1.0x width).
|
| 1063 |
+
chroma_branch : ComponentErrorPredictor
|
| 1064 |
+
Error predictor for chroma component (configurable width).
|
| 1065 |
+
code_branch : ComponentErrorPredictor
|
| 1066 |
+
Error predictor for hue code component (1.0x width).
|
| 1067 |
+
"""
|
| 1068 |
+
|
| 1069 |
+
def __init__(self, chroma_width: float = 1.5) -> None:
|
| 1070 |
+
"""Initialize the multi-head error predictor."""
|
| 1071 |
+
super().__init__()
|
| 1072 |
+
|
| 1073 |
+
self.hue_branch = ComponentErrorPredictor(input_dim=7, width_multiplier=1.0)
|
| 1074 |
+
self.value_branch = ComponentErrorPredictor(input_dim=7, width_multiplier=1.0)
|
| 1075 |
+
self.chroma_branch = ComponentErrorPredictor(
|
| 1076 |
+
input_dim=7, width_multiplier=chroma_width
|
| 1077 |
+
)
|
| 1078 |
+
self.code_branch = ComponentErrorPredictor(input_dim=7, width_multiplier=1.0)
|
| 1079 |
+
|
| 1080 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 1081 |
+
"""
|
| 1082 |
+
Forward pass through all error predictor branches.
|
| 1083 |
+
|
| 1084 |
+
Parameters
|
| 1085 |
+
----------
|
| 1086 |
+
x : Tensor
|
| 1087 |
+
Combined input [xyY_norm, base_pred_norm] of shape (batch_size, 7).
|
| 1088 |
+
|
| 1089 |
+
Returns
|
| 1090 |
+
-------
|
| 1091 |
+
Tensor
|
| 1092 |
+
Concatenated error corrections [hue, value, chroma, code]
|
| 1093 |
+
of shape (batch_size, 4).
|
| 1094 |
+
"""
|
| 1095 |
+
hue_error = self.hue_branch(x)
|
| 1096 |
+
value_error = self.value_branch(x)
|
| 1097 |
+
chroma_error = self.chroma_branch(x)
|
| 1098 |
+
code_error = self.code_branch(x)
|
| 1099 |
+
return torch.cat([hue_error, value_error, chroma_error, code_error], dim=1)
|
| 1100 |
+
|
| 1101 |
+
|
| 1102 |
+
class MultiResNetErrorPredictorToMunsell(nn.Module):
|
| 1103 |
+
"""
|
| 1104 |
+
Multi-ResNet error predictor for xyY to Munsell conversion.
|
| 1105 |
+
|
| 1106 |
+
Uses 4 independent ComponentResNet branches with true skip connections,
|
| 1107 |
+
one for each Munsell component error.
|
| 1108 |
+
|
| 1109 |
+
Parameters
|
| 1110 |
+
----------
|
| 1111 |
+
hidden_dim : int, optional
|
| 1112 |
+
Hidden dimension for residual blocks. Default is 256.
|
| 1113 |
+
num_blocks : int, optional
|
| 1114 |
+
Number of residual blocks per branch. Default is 4.
|
| 1115 |
+
chroma_hidden_dim : int, optional
|
| 1116 |
+
Hidden dimension for chroma branch. Default is 384.
|
| 1117 |
+
|
| 1118 |
+
Attributes
|
| 1119 |
+
----------
|
| 1120 |
+
hue_branch : ComponentResNet
|
| 1121 |
+
ResNet error predictor for hue component.
|
| 1122 |
+
value_branch : ComponentResNet
|
| 1123 |
+
ResNet error predictor for value component.
|
| 1124 |
+
chroma_branch : ComponentResNet
|
| 1125 |
+
ResNet error predictor for chroma component.
|
| 1126 |
+
code_branch : ComponentResNet
|
| 1127 |
+
ResNet error predictor for code component.
|
| 1128 |
+
"""
|
| 1129 |
+
|
| 1130 |
+
def __init__(
|
| 1131 |
+
self,
|
| 1132 |
+
hidden_dim: int = 256,
|
| 1133 |
+
num_blocks: int = 4,
|
| 1134 |
+
chroma_hidden_dim: int = 384,
|
| 1135 |
+
) -> None:
|
| 1136 |
+
"""Initialize the multi-ResNet error predictor."""
|
| 1137 |
+
super().__init__()
|
| 1138 |
+
|
| 1139 |
+
# Input: xyY (3) + base prediction (4) = 7
|
| 1140 |
+
self.hue_branch = ComponentResNet(
|
| 1141 |
+
input_dim=7, hidden_dim=hidden_dim, num_blocks=num_blocks
|
| 1142 |
+
)
|
| 1143 |
+
self.value_branch = ComponentResNet(
|
| 1144 |
+
input_dim=7, hidden_dim=hidden_dim, num_blocks=num_blocks
|
| 1145 |
+
)
|
| 1146 |
+
self.chroma_branch = ComponentResNet(
|
| 1147 |
+
input_dim=7, hidden_dim=chroma_hidden_dim, num_blocks=num_blocks
|
| 1148 |
+
)
|
| 1149 |
+
self.code_branch = ComponentResNet(
|
| 1150 |
+
input_dim=7, hidden_dim=hidden_dim, num_blocks=num_blocks
|
| 1151 |
+
)
|
| 1152 |
+
|
| 1153 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 1154 |
+
"""
|
| 1155 |
+
Forward pass through all error predictor branches.
|
| 1156 |
+
|
| 1157 |
+
Parameters
|
| 1158 |
+
----------
|
| 1159 |
+
x : Tensor
|
| 1160 |
+
Combined input [xyY_norm, base_pred_norm] of shape (batch_size, 7).
|
| 1161 |
+
|
| 1162 |
+
Returns
|
| 1163 |
+
-------
|
| 1164 |
+
Tensor
|
| 1165 |
+
Concatenated error corrections [hue, value, chroma, code]
|
| 1166 |
+
of shape (batch_size, 4).
|
| 1167 |
+
"""
|
| 1168 |
+
hue_error = self.hue_branch(x)
|
| 1169 |
+
value_error = self.value_branch(x)
|
| 1170 |
+
chroma_error = self.chroma_branch(x)
|
| 1171 |
+
code_error = self.code_branch(x)
|
| 1172 |
+
return torch.cat([hue_error, value_error, chroma_error, code_error], dim=1)
|
| 1173 |
+
|
| 1174 |
+
|
| 1175 |
+
# =============================================================================
|
| 1176 |
+
# Composite Models: Munsell → xyY
|
| 1177 |
+
# =============================================================================
|
| 1178 |
+
|
| 1179 |
+
|
| 1180 |
+
class MultiMLPToxyY(nn.Module):
|
| 1181 |
+
"""
|
| 1182 |
+
Multi-MLP for Munsell to xyY conversion.
|
| 1183 |
+
|
| 1184 |
+
Uses 3 independent ComponentMLP branches, one for each xyY component.
|
| 1185 |
+
|
| 1186 |
+
Parameters
|
| 1187 |
+
----------
|
| 1188 |
+
width_multiplier : float, optional
|
| 1189 |
+
Width multiplier for x and y branches. Default is 1.0.
|
| 1190 |
+
y_width_multiplier : float, optional
|
| 1191 |
+
Width multiplier for Y (luminance) branch. Default is 1.25.
|
| 1192 |
+
|
| 1193 |
+
Attributes
|
| 1194 |
+
----------
|
| 1195 |
+
x_branch : ComponentMLP
|
| 1196 |
+
MLP for x chromaticity component.
|
| 1197 |
+
y_branch : ComponentMLP
|
| 1198 |
+
MLP for y chromaticity component.
|
| 1199 |
+
Y_branch : ComponentMLP
|
| 1200 |
+
MLP for Y luminance component.
|
| 1201 |
+
"""
|
| 1202 |
+
|
| 1203 |
+
def __init__(
|
| 1204 |
+
self, width_multiplier: float = 1.0, y_width_multiplier: float = 1.25
|
| 1205 |
+
) -> None:
|
| 1206 |
+
"""Initialize the multi-MLP model."""
|
| 1207 |
+
super().__init__()
|
| 1208 |
+
|
| 1209 |
+
self.x_branch = ComponentMLP(input_dim=4, width_multiplier=width_multiplier)
|
| 1210 |
+
self.y_branch = ComponentMLP(input_dim=4, width_multiplier=width_multiplier)
|
| 1211 |
+
self.Y_branch = ComponentMLP(
|
| 1212 |
+
input_dim=4, width_multiplier=y_width_multiplier
|
| 1213 |
+
)
|
| 1214 |
+
|
| 1215 |
+
def forward(self, munsell: Tensor) -> Tensor:
|
| 1216 |
+
"""
|
| 1217 |
+
Forward pass through all branches.
|
| 1218 |
+
|
| 1219 |
+
Parameters
|
| 1220 |
+
----------
|
| 1221 |
+
munsell : Tensor
|
| 1222 |
+
Normalized Munsell specification [hue, value, chroma, code]
|
| 1223 |
+
of shape (batch_size, 4).
|
| 1224 |
+
|
| 1225 |
+
Returns
|
| 1226 |
+
-------
|
| 1227 |
+
Tensor
|
| 1228 |
+
Predicted xyY values [x, y, Y] of shape (batch_size, 3).
|
| 1229 |
+
"""
|
| 1230 |
+
x = self.x_branch(munsell)
|
| 1231 |
+
y = self.y_branch(munsell)
|
| 1232 |
+
Y = self.Y_branch(munsell)
|
| 1233 |
+
return torch.cat([x, y, Y], dim=1)
|
| 1234 |
+
|
| 1235 |
+
|
| 1236 |
+
# =============================================================================
|
| 1237 |
+
# Error Predictors: Munsell → xyY
|
| 1238 |
+
# =============================================================================
|
| 1239 |
+
|
| 1240 |
+
|
| 1241 |
+
class MultiMLPErrorPredictorToxyY(nn.Module):
|
| 1242 |
+
"""
|
| 1243 |
+
Multi-MLP error predictor for Munsell to xyY conversion.
|
| 1244 |
+
|
| 1245 |
+
Uses 3 independent ComponentErrorPredictor branches, one for each
|
| 1246 |
+
xyY component error.
|
| 1247 |
+
|
| 1248 |
+
Parameters
|
| 1249 |
+
----------
|
| 1250 |
+
width_multiplier : float, optional
|
| 1251 |
+
Width multiplier for all branches. Default is 1.0.
|
| 1252 |
+
|
| 1253 |
+
Attributes
|
| 1254 |
+
----------
|
| 1255 |
+
x_branch : ComponentErrorPredictor
|
| 1256 |
+
Error predictor for x chromaticity component.
|
| 1257 |
+
y_branch : ComponentErrorPredictor
|
| 1258 |
+
Error predictor for y chromaticity component.
|
| 1259 |
+
Y_branch : ComponentErrorPredictor
|
| 1260 |
+
Error predictor for Y luminance component.
|
| 1261 |
+
"""
|
| 1262 |
+
|
| 1263 |
+
def __init__(self, width_multiplier: float = 1.0) -> None:
|
| 1264 |
+
"""Initialize the multi-head error predictor."""
|
| 1265 |
+
super().__init__()
|
| 1266 |
+
|
| 1267 |
+
self.x_branch = ComponentErrorPredictor(
|
| 1268 |
+
input_dim=7, width_multiplier=width_multiplier
|
| 1269 |
+
)
|
| 1270 |
+
self.y_branch = ComponentErrorPredictor(
|
| 1271 |
+
input_dim=7, width_multiplier=width_multiplier
|
| 1272 |
+
)
|
| 1273 |
+
self.Y_branch = ComponentErrorPredictor(
|
| 1274 |
+
input_dim=7, width_multiplier=width_multiplier
|
| 1275 |
+
)
|
| 1276 |
+
|
| 1277 |
+
def forward(self, combined_input: Tensor) -> Tensor:
|
| 1278 |
+
"""
|
| 1279 |
+
Forward pass through all error predictor branches.
|
| 1280 |
+
|
| 1281 |
+
Parameters
|
| 1282 |
+
----------
|
| 1283 |
+
combined_input : Tensor
|
| 1284 |
+
Combined input [munsell_norm, base_pred] of shape (batch_size, 7).
|
| 1285 |
+
|
| 1286 |
+
Returns
|
| 1287 |
+
-------
|
| 1288 |
+
Tensor
|
| 1289 |
+
Concatenated error corrections [x, y, Y] of shape (batch_size, 3).
|
| 1290 |
+
"""
|
| 1291 |
+
x_error = self.x_branch(combined_input)
|
| 1292 |
+
y_error = self.y_branch(combined_input)
|
| 1293 |
+
Y_error = self.Y_branch(combined_input)
|
| 1294 |
+
return torch.cat([x_error, y_error, Y_error], dim=1)
|
learning_munsell/training/from_xyY/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Training scripts for xyY to Munsell conversion."""
|
learning_munsell/training/from_xyY/hyperparameter_search_error_predictor.py
ADDED
|
@@ -0,0 +1,503 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Hyperparameter search for Multi-Error Predictor using Optuna.
|
| 3 |
+
|
| 4 |
+
Optimizes:
|
| 5 |
+
- Learning rate
|
| 6 |
+
- Batch size
|
| 7 |
+
- Chroma width multiplier
|
| 8 |
+
- Loss function weights (MSE, MAE, log penalty, Huber)
|
| 9 |
+
- Huber delta
|
| 10 |
+
- Dropout
|
| 11 |
+
|
| 12 |
+
Objective: Minimize validation loss
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import logging
|
| 16 |
+
from datetime import datetime
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
|
| 19 |
+
import mlflow
|
| 20 |
+
import numpy as np
|
| 21 |
+
import onnxruntime as ort
|
| 22 |
+
import optuna
|
| 23 |
+
import torch
|
| 24 |
+
from numpy.typing import NDArray
|
| 25 |
+
from optuna.trial import Trial
|
| 26 |
+
from torch import nn, optim
|
| 27 |
+
from torch.utils.data import DataLoader, TensorDataset
|
| 28 |
+
|
| 29 |
+
from learning_munsell import PROJECT_ROOT
|
| 30 |
+
from learning_munsell.models.networks import (
|
| 31 |
+
ComponentErrorPredictor,
|
| 32 |
+
MultiMLPErrorPredictorToMunsell,
|
| 33 |
+
)
|
| 34 |
+
from learning_munsell.utilities.common import setup_mlflow_experiment
|
| 35 |
+
from learning_munsell.utilities.data import normalize_xyY, normalize_munsell
|
| 36 |
+
|
| 37 |
+
LOGGER = logging.getLogger(__name__)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def precision_focused_loss(
|
| 41 |
+
pred: torch.Tensor,
|
| 42 |
+
target: torch.Tensor,
|
| 43 |
+
mse_weight: float = 1.0,
|
| 44 |
+
mae_weight: float = 0.5,
|
| 45 |
+
log_weight: float = 0.3,
|
| 46 |
+
huber_weight: float = 0.5,
|
| 47 |
+
huber_delta: float = 0.01,
|
| 48 |
+
) -> torch.Tensor:
|
| 49 |
+
"""
|
| 50 |
+
Precision-focused loss function with configurable weights.
|
| 51 |
+
|
| 52 |
+
Combines multiple loss components to encourage accurate error prediction:
|
| 53 |
+
- MSE: Standard mean squared error
|
| 54 |
+
- MAE: Mean absolute error for robustness
|
| 55 |
+
- Log penalty: Penalizes small errors more heavily
|
| 56 |
+
- Huber loss: Robust to outliers with adjustable delta
|
| 57 |
+
|
| 58 |
+
Parameters
|
| 59 |
+
----------
|
| 60 |
+
pred : torch.Tensor
|
| 61 |
+
Predicted values, shape (batch_size, n_components).
|
| 62 |
+
target : torch.Tensor
|
| 63 |
+
Target values, shape (batch_size, n_components).
|
| 64 |
+
mse_weight : float, optional
|
| 65 |
+
Weight for MSE component. Default is 1.0.
|
| 66 |
+
mae_weight : float, optional
|
| 67 |
+
Weight for MAE component. Default is 0.5.
|
| 68 |
+
log_weight : float, optional
|
| 69 |
+
Weight for logarithmic penalty component. Default is 0.3.
|
| 70 |
+
huber_weight : float, optional
|
| 71 |
+
Weight for Huber loss component. Default is 0.5.
|
| 72 |
+
huber_delta : float, optional
|
| 73 |
+
Delta parameter for Huber loss transition point. Default is 0.01.
|
| 74 |
+
|
| 75 |
+
Returns
|
| 76 |
+
-------
|
| 77 |
+
torch.Tensor
|
| 78 |
+
Weighted combination of loss components, scalar tensor.
|
| 79 |
+
"""
|
| 80 |
+
|
| 81 |
+
mse = torch.mean((pred - target) ** 2)
|
| 82 |
+
mae = torch.mean(torch.abs(pred - target))
|
| 83 |
+
log_penalty = torch.mean(torch.log1p(torch.abs(pred - target) * 1000.0))
|
| 84 |
+
|
| 85 |
+
abs_error = torch.abs(pred - target)
|
| 86 |
+
huber = torch.where(
|
| 87 |
+
abs_error <= huber_delta,
|
| 88 |
+
0.5 * abs_error**2,
|
| 89 |
+
huber_delta * (abs_error - 0.5 * huber_delta),
|
| 90 |
+
)
|
| 91 |
+
huber_loss = torch.mean(huber)
|
| 92 |
+
|
| 93 |
+
return (
|
| 94 |
+
mse_weight * mse
|
| 95 |
+
+ mae_weight * mae
|
| 96 |
+
+ log_weight * log_penalty
|
| 97 |
+
+ huber_weight * huber_loss
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def load_base_model(
|
| 102 |
+
model_path: Path, params_path: Path
|
| 103 |
+
) -> tuple[ort.InferenceSession, dict, dict]:
|
| 104 |
+
"""
|
| 105 |
+
Load the base ONNX model and its normalization parameters.
|
| 106 |
+
|
| 107 |
+
Parameters
|
| 108 |
+
----------
|
| 109 |
+
model_path : Path
|
| 110 |
+
Path to the base model ONNX file.
|
| 111 |
+
params_path : Path
|
| 112 |
+
Path to the normalization parameters NPZ file.
|
| 113 |
+
|
| 114 |
+
Returns
|
| 115 |
+
-------
|
| 116 |
+
ort.InferenceSession
|
| 117 |
+
ONNX Runtime inference session for the base model.
|
| 118 |
+
dict
|
| 119 |
+
Input normalization parameters (x_range, y_range, Y_range).
|
| 120 |
+
dict
|
| 121 |
+
Output normalization parameters (hue_range, value_range, chroma_range, code_range).
|
| 122 |
+
"""
|
| 123 |
+
session = ort.InferenceSession(str(model_path))
|
| 124 |
+
params = np.load(params_path, allow_pickle=True)
|
| 125 |
+
return session, params["input_params"].item(), params["output_params"].item()
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def train_epoch(
|
| 129 |
+
model: nn.Module,
|
| 130 |
+
dataloader: DataLoader,
|
| 131 |
+
optimizer: optim.Optimizer,
|
| 132 |
+
device: torch.device,
|
| 133 |
+
loss_params: dict[str, float],
|
| 134 |
+
) -> float:
|
| 135 |
+
"""
|
| 136 |
+
Train the model for one epoch.
|
| 137 |
+
|
| 138 |
+
Parameters
|
| 139 |
+
----------
|
| 140 |
+
model : nn.Module
|
| 141 |
+
Error predictor model to train.
|
| 142 |
+
dataloader : DataLoader
|
| 143 |
+
DataLoader providing training batches.
|
| 144 |
+
optimizer : optim.Optimizer
|
| 145 |
+
Optimizer for updating model parameters.
|
| 146 |
+
device : torch.device
|
| 147 |
+
Device to run training on (CPU, CUDA, or MPS).
|
| 148 |
+
loss_params : dict of str to float
|
| 149 |
+
Parameters for precision_focused_loss function.
|
| 150 |
+
|
| 151 |
+
Returns
|
| 152 |
+
-------
|
| 153 |
+
float
|
| 154 |
+
Average training loss over the epoch.
|
| 155 |
+
"""
|
| 156 |
+
model.train()
|
| 157 |
+
total_loss = 0.0
|
| 158 |
+
|
| 159 |
+
for X_batch, y_batch in dataloader:
|
| 160 |
+
X_batch = X_batch.to(device)
|
| 161 |
+
y_batch = y_batch.to(device)
|
| 162 |
+
outputs = model(X_batch)
|
| 163 |
+
loss = precision_focused_loss(outputs, y_batch, **loss_params)
|
| 164 |
+
|
| 165 |
+
optimizer.zero_grad()
|
| 166 |
+
loss.backward()
|
| 167 |
+
optimizer.step()
|
| 168 |
+
|
| 169 |
+
total_loss += loss.item()
|
| 170 |
+
|
| 171 |
+
return total_loss / len(dataloader)
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def validate(
|
| 175 |
+
model: nn.Module,
|
| 176 |
+
dataloader: DataLoader,
|
| 177 |
+
device: torch.device,
|
| 178 |
+
loss_params: dict[str, float],
|
| 179 |
+
) -> float:
|
| 180 |
+
"""
|
| 181 |
+
Validate the model on the validation set.
|
| 182 |
+
|
| 183 |
+
Parameters
|
| 184 |
+
----------
|
| 185 |
+
model : nn.Module
|
| 186 |
+
Error predictor model to validate.
|
| 187 |
+
dataloader : DataLoader
|
| 188 |
+
DataLoader providing validation batches.
|
| 189 |
+
device : torch.device
|
| 190 |
+
Device to run validation on (CPU, CUDA, or MPS).
|
| 191 |
+
loss_params : dict of str to float
|
| 192 |
+
Parameters for precision_focused_loss function.
|
| 193 |
+
|
| 194 |
+
Returns
|
| 195 |
+
-------
|
| 196 |
+
float
|
| 197 |
+
Average validation loss.
|
| 198 |
+
"""
|
| 199 |
+
model.eval()
|
| 200 |
+
total_loss = 0.0
|
| 201 |
+
|
| 202 |
+
with torch.no_grad():
|
| 203 |
+
for X_batch, y_batch in dataloader:
|
| 204 |
+
X_batch = X_batch.to(device)
|
| 205 |
+
y_batch = y_batch.to(device)
|
| 206 |
+
outputs = model(X_batch)
|
| 207 |
+
loss = precision_focused_loss(outputs, y_batch, **loss_params)
|
| 208 |
+
|
| 209 |
+
total_loss += loss.item()
|
| 210 |
+
|
| 211 |
+
return total_loss / len(dataloader)
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def objective(trial: Trial) -> float:
|
| 215 |
+
"""
|
| 216 |
+
Optuna objective function to minimize validation loss.
|
| 217 |
+
|
| 218 |
+
This function defines the hyperparameter search space and training
|
| 219 |
+
procedure for each trial. It optimizes:
|
| 220 |
+
- Learning rate (5e-4 to 1e-3, log scale)
|
| 221 |
+
- Batch size (512 or 1024)
|
| 222 |
+
- Chroma branch width multiplier (1.0 to 1.5)
|
| 223 |
+
- Dropout rate (0.1 to 0.2)
|
| 224 |
+
- Loss function weights (MSE, Huber)
|
| 225 |
+
- Huber delta parameter (0.01 to 0.05)
|
| 226 |
+
|
| 227 |
+
Parameters
|
| 228 |
+
----------
|
| 229 |
+
trial : Trial
|
| 230 |
+
Optuna trial object for suggesting hyperparameters.
|
| 231 |
+
|
| 232 |
+
Returns
|
| 233 |
+
-------
|
| 234 |
+
float
|
| 235 |
+
Best validation loss achieved during training.
|
| 236 |
+
|
| 237 |
+
Raises
|
| 238 |
+
------
|
| 239 |
+
FileNotFoundError
|
| 240 |
+
If base model or training data files are not found.
|
| 241 |
+
optuna.TrialPruned
|
| 242 |
+
If trial is pruned based on intermediate results.
|
| 243 |
+
"""
|
| 244 |
+
|
| 245 |
+
# Hyperparameters to optimize - constrained based on Trial 0 insights
|
| 246 |
+
lr = trial.suggest_float("lr", 5e-4, 1e-3, log=True) # Higher LR worked well
|
| 247 |
+
batch_size = trial.suggest_categorical(
|
| 248 |
+
"batch_size", [512, 1024]
|
| 249 |
+
) # Smaller batches better
|
| 250 |
+
chroma_width = trial.suggest_float(
|
| 251 |
+
"chroma_width", 1.0, 1.5, step=0.25
|
| 252 |
+
) # Smaller worked
|
| 253 |
+
dropout = trial.suggest_float("dropout", 0.1, 0.2, step=0.05)
|
| 254 |
+
|
| 255 |
+
# Simplified loss - just MSE + optional small Huber (no log penalty!)
|
| 256 |
+
mse_weight = trial.suggest_float("mse_weight", 1.0, 2.0, step=0.25)
|
| 257 |
+
huber_weight = trial.suggest_float("huber_weight", 0.0, 0.5, step=0.25)
|
| 258 |
+
huber_delta = trial.suggest_float("huber_delta", 0.01, 0.05, step=0.01)
|
| 259 |
+
|
| 260 |
+
loss_params = {
|
| 261 |
+
"mse_weight": mse_weight,
|
| 262 |
+
"mae_weight": 0.0, # Fixed at 0
|
| 263 |
+
"log_weight": 0.0, # Fixed at 0 (was causing scale issues)
|
| 264 |
+
"huber_weight": huber_weight,
|
| 265 |
+
"huber_delta": huber_delta,
|
| 266 |
+
}
|
| 267 |
+
|
| 268 |
+
LOGGER.info("")
|
| 269 |
+
LOGGER.info("=" * 80)
|
| 270 |
+
LOGGER.info("Trial %d", trial.number)
|
| 271 |
+
LOGGER.info("=" * 80)
|
| 272 |
+
LOGGER.info(" lr: %.6f", lr)
|
| 273 |
+
LOGGER.info(" batch_size: %d", batch_size)
|
| 274 |
+
LOGGER.info(" chroma_width: %.2f", chroma_width)
|
| 275 |
+
LOGGER.info(" dropout: %.2f", dropout)
|
| 276 |
+
LOGGER.info(" mse_weight: %.2f", mse_weight)
|
| 277 |
+
LOGGER.info(" huber_weight: %.2f", huber_weight)
|
| 278 |
+
LOGGER.info(" huber_delta: %.3f", huber_delta)
|
| 279 |
+
|
| 280 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 281 |
+
|
| 282 |
+
# Load base model and data
|
| 283 |
+
model_dir = PROJECT_ROOT / "models" / "from_xyY"
|
| 284 |
+
data_dir = PROJECT_ROOT / "data"
|
| 285 |
+
|
| 286 |
+
base_model_path = model_dir / "multi_mlp.onnx"
|
| 287 |
+
params_path = model_dir / "multi_mlp_normalization_params.npz"
|
| 288 |
+
cache_file = data_dir / "training_data.npz"
|
| 289 |
+
|
| 290 |
+
if not base_model_path.exists():
|
| 291 |
+
msg = f"Base model not found: {base_model_path}"
|
| 292 |
+
raise FileNotFoundError(msg)
|
| 293 |
+
|
| 294 |
+
base_session, input_params, output_params = load_base_model(
|
| 295 |
+
base_model_path, params_path
|
| 296 |
+
)
|
| 297 |
+
|
| 298 |
+
# Load data
|
| 299 |
+
data = np.load(cache_file)
|
| 300 |
+
X_train = data["X_train"]
|
| 301 |
+
y_train = data["y_train"]
|
| 302 |
+
X_val = data["X_val"]
|
| 303 |
+
y_val = data["y_val"]
|
| 304 |
+
|
| 305 |
+
# Normalize and generate base predictions
|
| 306 |
+
X_train_norm = normalize_xyY(X_train, input_params)
|
| 307 |
+
y_train_norm = normalize_munsell(y_train, output_params)
|
| 308 |
+
base_pred_train_norm = base_session.run(None, {"xyY": X_train_norm})[0]
|
| 309 |
+
|
| 310 |
+
X_val_norm = normalize_xyY(X_val, input_params)
|
| 311 |
+
y_val_norm = normalize_munsell(y_val, output_params)
|
| 312 |
+
base_pred_val_norm = base_session.run(None, {"xyY": X_val_norm})[0]
|
| 313 |
+
|
| 314 |
+
# Compute errors
|
| 315 |
+
error_train = y_train_norm - base_pred_train_norm
|
| 316 |
+
error_val = y_val_norm - base_pred_val_norm
|
| 317 |
+
|
| 318 |
+
# Combined input
|
| 319 |
+
X_train_combined = np.concatenate([X_train_norm, base_pred_train_norm], axis=1)
|
| 320 |
+
X_val_combined = np.concatenate([X_val_norm, base_pred_val_norm], axis=1)
|
| 321 |
+
|
| 322 |
+
# PyTorch tensors
|
| 323 |
+
X_train_t = torch.FloatTensor(X_train_combined)
|
| 324 |
+
error_train_t = torch.FloatTensor(error_train)
|
| 325 |
+
X_val_t = torch.FloatTensor(X_val_combined)
|
| 326 |
+
error_val_t = torch.FloatTensor(error_val)
|
| 327 |
+
|
| 328 |
+
# Data loaders
|
| 329 |
+
train_dataset = TensorDataset(X_train_t, error_train_t)
|
| 330 |
+
val_dataset = TensorDataset(X_val_t, error_val_t)
|
| 331 |
+
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
|
| 332 |
+
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
|
| 333 |
+
|
| 334 |
+
# Initialize model
|
| 335 |
+
model = MultiMLPErrorPredictorToMunsell(chroma_width=chroma_width, dropout=dropout).to(
|
| 336 |
+
device
|
| 337 |
+
)
|
| 338 |
+
|
| 339 |
+
total_params = sum(p.numel() for p in model.parameters())
|
| 340 |
+
LOGGER.info(" Total parameters: %s", f"{total_params:,}")
|
| 341 |
+
|
| 342 |
+
optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
|
| 343 |
+
|
| 344 |
+
# MLflow setup
|
| 345 |
+
run_name = setup_mlflow_experiment(
|
| 346 |
+
"from_xyY", f"hparam_error_predictor_trial_{trial.number}"
|
| 347 |
+
)
|
| 348 |
+
|
| 349 |
+
# Training loop
|
| 350 |
+
num_epochs = 100
|
| 351 |
+
patience = 15
|
| 352 |
+
best_val_loss = float("inf")
|
| 353 |
+
patience_counter = 0
|
| 354 |
+
|
| 355 |
+
with mlflow.start_run(run_name=run_name):
|
| 356 |
+
mlflow.log_params(
|
| 357 |
+
{
|
| 358 |
+
"trial": trial.number,
|
| 359 |
+
"lr": lr,
|
| 360 |
+
"batch_size": batch_size,
|
| 361 |
+
"chroma_width": chroma_width,
|
| 362 |
+
"dropout": dropout,
|
| 363 |
+
"mse_weight": mse_weight,
|
| 364 |
+
"huber_weight": huber_weight,
|
| 365 |
+
"huber_delta": huber_delta,
|
| 366 |
+
"total_params": total_params,
|
| 367 |
+
}
|
| 368 |
+
)
|
| 369 |
+
|
| 370 |
+
for epoch in range(num_epochs):
|
| 371 |
+
train_loss = train_epoch(
|
| 372 |
+
model, train_loader, optimizer, device, loss_params
|
| 373 |
+
)
|
| 374 |
+
val_loss = validate(model, val_loader, device, loss_params)
|
| 375 |
+
|
| 376 |
+
mlflow.log_metrics(
|
| 377 |
+
{
|
| 378 |
+
"train_loss": train_loss,
|
| 379 |
+
"val_loss": val_loss,
|
| 380 |
+
},
|
| 381 |
+
step=epoch,
|
| 382 |
+
)
|
| 383 |
+
|
| 384 |
+
if (epoch + 1) % 10 == 0:
|
| 385 |
+
LOGGER.info(
|
| 386 |
+
" Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f",
|
| 387 |
+
epoch + 1,
|
| 388 |
+
num_epochs,
|
| 389 |
+
train_loss,
|
| 390 |
+
val_loss,
|
| 391 |
+
)
|
| 392 |
+
|
| 393 |
+
if val_loss < best_val_loss:
|
| 394 |
+
best_val_loss = val_loss
|
| 395 |
+
patience_counter = 0
|
| 396 |
+
else:
|
| 397 |
+
patience_counter += 1
|
| 398 |
+
if patience_counter >= patience:
|
| 399 |
+
LOGGER.info(" Early stopping at epoch %d", epoch + 1)
|
| 400 |
+
break
|
| 401 |
+
|
| 402 |
+
trial.report(val_loss, epoch)
|
| 403 |
+
|
| 404 |
+
if trial.should_prune():
|
| 405 |
+
LOGGER.info(" Trial pruned at epoch %d", epoch + 1)
|
| 406 |
+
mlflow.log_metrics({"pruned": 1, "pruned_epoch": epoch})
|
| 407 |
+
raise optuna.TrialPruned
|
| 408 |
+
|
| 409 |
+
# Log final results
|
| 410 |
+
mlflow.log_metrics(
|
| 411 |
+
{
|
| 412 |
+
"best_val_loss": best_val_loss,
|
| 413 |
+
"final_train_loss": train_loss,
|
| 414 |
+
"final_epoch": epoch + 1,
|
| 415 |
+
}
|
| 416 |
+
)
|
| 417 |
+
|
| 418 |
+
LOGGER.info(" Final validation loss: %.6f", best_val_loss)
|
| 419 |
+
|
| 420 |
+
return best_val_loss
|
| 421 |
+
|
| 422 |
+
|
| 423 |
+
def main() -> None:
|
| 424 |
+
"""
|
| 425 |
+
Run hyperparameter search for Multi-MLP Error Predictor.
|
| 426 |
+
|
| 427 |
+
Performs systematic hyperparameter optimization using Optuna with:
|
| 428 |
+
- MedianPruner for early stopping of unpromising trials
|
| 429 |
+
- 15 total trials
|
| 430 |
+
- MLflow logging for each trial
|
| 431 |
+
- Result visualization and saving
|
| 432 |
+
|
| 433 |
+
The search aims to find optimal hyperparameters for predicting errors
|
| 434 |
+
in a base Munsell prediction model, which can then be used to improve
|
| 435 |
+
predictions by correcting systematic biases.
|
| 436 |
+
"""
|
| 437 |
+
|
| 438 |
+
LOGGER.info("=" * 80)
|
| 439 |
+
LOGGER.info("Multi-Error Predictor Hyperparameter Search with Optuna")
|
| 440 |
+
LOGGER.info("=" * 80)
|
| 441 |
+
|
| 442 |
+
study = optuna.create_study(
|
| 443 |
+
direction="minimize",
|
| 444 |
+
study_name="multi_mlp_error_predictor_hparam_search",
|
| 445 |
+
pruner=optuna.pruners.MedianPruner(n_startup_trials=3, n_warmup_steps=10),
|
| 446 |
+
)
|
| 447 |
+
|
| 448 |
+
n_trials = 15
|
| 449 |
+
|
| 450 |
+
LOGGER.info("")
|
| 451 |
+
LOGGER.info("Starting hyperparameter search with %d trials...", n_trials)
|
| 452 |
+
LOGGER.info("")
|
| 453 |
+
|
| 454 |
+
study.optimize(objective, n_trials=n_trials, timeout=None)
|
| 455 |
+
|
| 456 |
+
# Print results
|
| 457 |
+
LOGGER.info("")
|
| 458 |
+
LOGGER.info("=" * 80)
|
| 459 |
+
LOGGER.info("Hyperparameter Search Results")
|
| 460 |
+
LOGGER.info("=" * 80)
|
| 461 |
+
LOGGER.info("")
|
| 462 |
+
LOGGER.info("Best trial:")
|
| 463 |
+
LOGGER.info(" Value (val_loss): %.6f", study.best_value)
|
| 464 |
+
LOGGER.info("")
|
| 465 |
+
LOGGER.info("Best hyperparameters:")
|
| 466 |
+
for key, value in study.best_params.items():
|
| 467 |
+
LOGGER.info(" %s: %s", key, value)
|
| 468 |
+
|
| 469 |
+
# Save results
|
| 470 |
+
results_dir = PROJECT_ROOT / "results" / "from_xyY"
|
| 471 |
+
results_dir.mkdir(exist_ok=True)
|
| 472 |
+
|
| 473 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 474 |
+
results_file = results_dir / f"error_predictor_hparam_search_{timestamp}.txt"
|
| 475 |
+
|
| 476 |
+
with open(results_file, "w") as f:
|
| 477 |
+
f.write("=" * 80 + "\n")
|
| 478 |
+
f.write("Multi-Error Predictor Hyperparameter Search Results\n")
|
| 479 |
+
f.write("=" * 80 + "\n\n")
|
| 480 |
+
f.write(f"Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
|
| 481 |
+
f.write(f"Number of trials: {len(study.trials)}\n")
|
| 482 |
+
f.write(f"Best validation loss: {study.best_value:.6f}\n\n")
|
| 483 |
+
f.write("Best hyperparameters:\n")
|
| 484 |
+
for key, value in study.best_params.items():
|
| 485 |
+
f.write(f" {key}: {value}\n")
|
| 486 |
+
f.write("\n\nAll trials:\n")
|
| 487 |
+
f.write("-" * 80 + "\n")
|
| 488 |
+
|
| 489 |
+
for trial in study.trials:
|
| 490 |
+
f.write(f"\nTrial {trial.number}:\n")
|
| 491 |
+
f.write(f" Value: {trial.value:.6f if trial.value else 'Pruned'}\n")
|
| 492 |
+
f.write(" Params:\n")
|
| 493 |
+
for key, value in trial.params.items():
|
| 494 |
+
f.write(f" {key}: {value}\n")
|
| 495 |
+
|
| 496 |
+
LOGGER.info("")
|
| 497 |
+
LOGGER.info("Results saved to: %s", results_file)
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
if __name__ == "__main__":
|
| 501 |
+
logging.basicConfig(level=logging.INFO, format="%(message)s", force=True)
|
| 502 |
+
|
| 503 |
+
main()
|
learning_munsell/training/from_xyY/hyperparameter_search_multi_head.py
ADDED
|
@@ -0,0 +1,541 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Hyperparameter search for Multi-Head model (xyY to Munsell) using Optuna.
|
| 3 |
+
|
| 4 |
+
Optimizes:
|
| 5 |
+
- Learning rate
|
| 6 |
+
- Batch size
|
| 7 |
+
- Encoder width multiplier (shared encoder capacity)
|
| 8 |
+
- Head width multiplier (component-specific head capacity)
|
| 9 |
+
- Chroma head width (specialized for chroma prediction)
|
| 10 |
+
- Dropout
|
| 11 |
+
- Weight decay
|
| 12 |
+
|
| 13 |
+
Objective: Minimize validation loss
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
import logging
|
| 19 |
+
from datetime import datetime
|
| 20 |
+
|
| 21 |
+
import matplotlib.pyplot as plt
|
| 22 |
+
import mlflow
|
| 23 |
+
import numpy as np
|
| 24 |
+
import optuna
|
| 25 |
+
import torch
|
| 26 |
+
from optuna.trial import Trial
|
| 27 |
+
from torch import nn, optim
|
| 28 |
+
from torch.utils.data import DataLoader, TensorDataset
|
| 29 |
+
|
| 30 |
+
from learning_munsell import PROJECT_ROOT
|
| 31 |
+
from learning_munsell.utilities.common import setup_mlflow_experiment
|
| 32 |
+
from learning_munsell.utilities.data import MUNSELL_NORMALIZATION_PARAMS, normalize_munsell
|
| 33 |
+
from learning_munsell.utilities.losses import weighted_mse_loss
|
| 34 |
+
from learning_munsell.utilities.training import train_epoch, validate
|
| 35 |
+
|
| 36 |
+
LOGGER = logging.getLogger(__name__)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class MultiHeadParametric(nn.Module):
|
| 40 |
+
"""
|
| 41 |
+
Parametric Multi-Head model for hyperparameter search (xyY to Munsell).
|
| 42 |
+
|
| 43 |
+
This model uses a shared encoder to extract general color space features
|
| 44 |
+
from xyY inputs, followed by component-specific heads for predicting
|
| 45 |
+
each Munsell component independently.
|
| 46 |
+
|
| 47 |
+
Architecture:
|
| 48 |
+
- Shared encoder: 3 → h1 → h2 → h3 (scaled by encoder_width)
|
| 49 |
+
- hue, value, code heads: h3 → h2' → h1' → 1 (scaled by head_width)
|
| 50 |
+
- chroma head: h3 → h2'' → h1'' → 1 (scaled by chroma_head_width)
|
| 51 |
+
|
| 52 |
+
Parameters
|
| 53 |
+
----------
|
| 54 |
+
encoder_width : float, optional
|
| 55 |
+
Width multiplier for shared encoder layers. Default is 1.0.
|
| 56 |
+
Base dimensions: h1=128, h2=256, h3=512.
|
| 57 |
+
head_width : float, optional
|
| 58 |
+
Width multiplier for hue, value, and code heads. Default is 1.0.
|
| 59 |
+
Base dimensions: h1=128, h2=256.
|
| 60 |
+
chroma_head_width : float, optional
|
| 61 |
+
Width multiplier for chroma head (typically wider). Default is 1.0.
|
| 62 |
+
Base dimensions: h1=128, h2=256, h3=384.
|
| 63 |
+
dropout : float, optional
|
| 64 |
+
Dropout rate applied after hidden layers. Default is 0.0.
|
| 65 |
+
"""
|
| 66 |
+
|
| 67 |
+
def __init__(
|
| 68 |
+
self,
|
| 69 |
+
encoder_width: float = 1.0,
|
| 70 |
+
head_width: float = 1.0,
|
| 71 |
+
chroma_head_width: float = 1.0,
|
| 72 |
+
dropout: float = 0.0,
|
| 73 |
+
) -> None:
|
| 74 |
+
super().__init__()
|
| 75 |
+
|
| 76 |
+
# Encoder dimensions (shared)
|
| 77 |
+
e_h1 = int(128 * encoder_width)
|
| 78 |
+
e_h2 = int(256 * encoder_width)
|
| 79 |
+
e_h3 = int(512 * encoder_width)
|
| 80 |
+
|
| 81 |
+
# Head dimensions (component-specific)
|
| 82 |
+
h_h1 = int(128 * head_width)
|
| 83 |
+
h_h2 = int(256 * head_width)
|
| 84 |
+
|
| 85 |
+
# Chroma head dimensions (specialized)
|
| 86 |
+
c_h1 = int(128 * chroma_head_width)
|
| 87 |
+
c_h2 = int(256 * chroma_head_width)
|
| 88 |
+
c_h3 = int(384 * chroma_head_width)
|
| 89 |
+
|
| 90 |
+
# Shared encoder - learns general color space features
|
| 91 |
+
encoder_layers = [
|
| 92 |
+
nn.Linear(3, e_h1),
|
| 93 |
+
nn.ReLU(),
|
| 94 |
+
nn.BatchNorm1d(e_h1),
|
| 95 |
+
]
|
| 96 |
+
|
| 97 |
+
if dropout > 0:
|
| 98 |
+
encoder_layers.append(nn.Dropout(dropout))
|
| 99 |
+
|
| 100 |
+
encoder_layers.extend(
|
| 101 |
+
[
|
| 102 |
+
nn.Linear(e_h1, e_h2),
|
| 103 |
+
nn.ReLU(),
|
| 104 |
+
nn.BatchNorm1d(e_h2),
|
| 105 |
+
]
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
if dropout > 0:
|
| 109 |
+
encoder_layers.append(nn.Dropout(dropout))
|
| 110 |
+
|
| 111 |
+
encoder_layers.extend(
|
| 112 |
+
[
|
| 113 |
+
nn.Linear(e_h2, e_h3),
|
| 114 |
+
nn.ReLU(),
|
| 115 |
+
nn.BatchNorm1d(e_h3),
|
| 116 |
+
]
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
if dropout > 0:
|
| 120 |
+
encoder_layers.append(nn.Dropout(dropout))
|
| 121 |
+
|
| 122 |
+
self.encoder = nn.Sequential(*encoder_layers)
|
| 123 |
+
|
| 124 |
+
# Component-specific heads (hue, value, code)
|
| 125 |
+
def create_head() -> nn.Sequential:
|
| 126 |
+
head_layers = [
|
| 127 |
+
nn.Linear(e_h3, h_h2),
|
| 128 |
+
nn.ReLU(),
|
| 129 |
+
nn.BatchNorm1d(h_h2),
|
| 130 |
+
]
|
| 131 |
+
|
| 132 |
+
if dropout > 0:
|
| 133 |
+
head_layers.append(nn.Dropout(dropout))
|
| 134 |
+
|
| 135 |
+
head_layers.extend(
|
| 136 |
+
[
|
| 137 |
+
nn.Linear(h_h2, h_h1),
|
| 138 |
+
nn.ReLU(),
|
| 139 |
+
nn.BatchNorm1d(h_h1),
|
| 140 |
+
]
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
if dropout > 0:
|
| 144 |
+
head_layers.append(nn.Dropout(dropout))
|
| 145 |
+
|
| 146 |
+
head_layers.append(nn.Linear(h_h1, 1))
|
| 147 |
+
|
| 148 |
+
return nn.Sequential(*head_layers)
|
| 149 |
+
|
| 150 |
+
self.hue_head = create_head()
|
| 151 |
+
self.value_head = create_head()
|
| 152 |
+
self.code_head = create_head()
|
| 153 |
+
|
| 154 |
+
# Chroma head - wider for harder task
|
| 155 |
+
chroma_layers = [
|
| 156 |
+
nn.Linear(e_h3, c_h3),
|
| 157 |
+
nn.ReLU(),
|
| 158 |
+
nn.BatchNorm1d(c_h3),
|
| 159 |
+
]
|
| 160 |
+
|
| 161 |
+
if dropout > 0:
|
| 162 |
+
chroma_layers.append(nn.Dropout(dropout))
|
| 163 |
+
|
| 164 |
+
chroma_layers.extend(
|
| 165 |
+
[
|
| 166 |
+
nn.Linear(c_h3, c_h2),
|
| 167 |
+
nn.ReLU(),
|
| 168 |
+
nn.BatchNorm1d(c_h2),
|
| 169 |
+
]
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
+
if dropout > 0:
|
| 173 |
+
chroma_layers.append(nn.Dropout(dropout))
|
| 174 |
+
|
| 175 |
+
chroma_layers.extend(
|
| 176 |
+
[
|
| 177 |
+
nn.Linear(c_h2, c_h1),
|
| 178 |
+
nn.ReLU(),
|
| 179 |
+
nn.BatchNorm1d(c_h1),
|
| 180 |
+
]
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
if dropout > 0:
|
| 184 |
+
chroma_layers.append(nn.Dropout(dropout))
|
| 185 |
+
|
| 186 |
+
chroma_layers.append(nn.Linear(c_h1, 1))
|
| 187 |
+
|
| 188 |
+
self.chroma_head = nn.Sequential(*chroma_layers)
|
| 189 |
+
|
| 190 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 191 |
+
"""
|
| 192 |
+
Forward pass through shared encoder and component-specific heads.
|
| 193 |
+
|
| 194 |
+
Parameters
|
| 195 |
+
----------
|
| 196 |
+
x : torch.Tensor
|
| 197 |
+
Input tensor of shape (batch_size, 3) containing normalized
|
| 198 |
+
xyY values.
|
| 199 |
+
|
| 200 |
+
Returns
|
| 201 |
+
-------
|
| 202 |
+
torch.Tensor
|
| 203 |
+
Predicted Munsell components, shape (batch_size, 4).
|
| 204 |
+
Output order: [hue, value, chroma, code].
|
| 205 |
+
"""
|
| 206 |
+
# Shared feature extraction
|
| 207 |
+
features = self.encoder(x)
|
| 208 |
+
|
| 209 |
+
# Component-specific predictions
|
| 210 |
+
hue = self.hue_head(features)
|
| 211 |
+
value = self.value_head(features)
|
| 212 |
+
chroma = self.chroma_head(features)
|
| 213 |
+
code = self.code_head(features)
|
| 214 |
+
|
| 215 |
+
# Concatenate: [hue, value, chroma, code]
|
| 216 |
+
return torch.cat([hue, value, chroma, code], dim=1)
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
def objective(trial: Trial) -> float:
|
| 220 |
+
"""
|
| 221 |
+
Optuna objective function to minimize validation loss.
|
| 222 |
+
|
| 223 |
+
This function defines the hyperparameter search space and training
|
| 224 |
+
procedure for each trial. It optimizes:
|
| 225 |
+
- Learning rate (1e-4 to 1e-3, log scale)
|
| 226 |
+
- Batch size (256, 512, or 1024)
|
| 227 |
+
- Encoder width multiplier (0.75 to 1.5)
|
| 228 |
+
- Head width multiplier (0.75 to 1.5)
|
| 229 |
+
- Chroma head width multiplier (1.0 to 1.75)
|
| 230 |
+
- Dropout rate (0.0 to 0.2)
|
| 231 |
+
- Weight decay (1e-5 to 1e-3, log scale)
|
| 232 |
+
|
| 233 |
+
Parameters
|
| 234 |
+
----------
|
| 235 |
+
trial : Trial
|
| 236 |
+
Optuna trial object for suggesting hyperparameters.
|
| 237 |
+
|
| 238 |
+
Returns
|
| 239 |
+
-------
|
| 240 |
+
float
|
| 241 |
+
Best validation loss achieved during training.
|
| 242 |
+
|
| 243 |
+
Raises
|
| 244 |
+
------
|
| 245 |
+
optuna.TrialPruned
|
| 246 |
+
If trial is pruned based on intermediate results.
|
| 247 |
+
"""
|
| 248 |
+
|
| 249 |
+
# Suggest hyperparameters
|
| 250 |
+
lr = trial.suggest_float("lr", 1e-4, 1e-3, log=True)
|
| 251 |
+
batch_size = trial.suggest_categorical("batch_size", [256, 512, 1024])
|
| 252 |
+
encoder_width = trial.suggest_float("encoder_width", 0.75, 1.5, step=0.25)
|
| 253 |
+
head_width = trial.suggest_float("head_width", 0.75, 1.5, step=0.25)
|
| 254 |
+
chroma_head_width = trial.suggest_float("chroma_head_width", 1.0, 1.75, step=0.25)
|
| 255 |
+
dropout = trial.suggest_float("dropout", 0.0, 0.2, step=0.05)
|
| 256 |
+
weight_decay = trial.suggest_float("weight_decay", 1e-5, 1e-3, log=True)
|
| 257 |
+
|
| 258 |
+
LOGGER.info("")
|
| 259 |
+
LOGGER.info("=" * 80)
|
| 260 |
+
LOGGER.info("Trial %d", trial.number)
|
| 261 |
+
LOGGER.info("=" * 80)
|
| 262 |
+
LOGGER.info(" lr: %.6f", lr)
|
| 263 |
+
LOGGER.info(" batch_size: %d", batch_size)
|
| 264 |
+
LOGGER.info(" encoder_width: %.2f", encoder_width)
|
| 265 |
+
LOGGER.info(" head_width: %.2f", head_width)
|
| 266 |
+
LOGGER.info(" chroma_head_width: %.2f", chroma_head_width)
|
| 267 |
+
LOGGER.info(" dropout: %.2f", dropout)
|
| 268 |
+
LOGGER.info(" weight_decay: %.6f", weight_decay)
|
| 269 |
+
|
| 270 |
+
# Set device
|
| 271 |
+
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
|
| 272 |
+
LOGGER.info(" device: %s", device)
|
| 273 |
+
|
| 274 |
+
# Load data
|
| 275 |
+
data_dir = PROJECT_ROOT / "data"
|
| 276 |
+
cache_file = data_dir / "training_data.npz"
|
| 277 |
+
data = np.load(cache_file)
|
| 278 |
+
|
| 279 |
+
X_train = data["X_train"]
|
| 280 |
+
y_train = data["y_train"]
|
| 281 |
+
X_val = data["X_val"]
|
| 282 |
+
y_val = data["y_val"]
|
| 283 |
+
|
| 284 |
+
# Normalize outputs (xyY inputs are already in [0, 1] range)
|
| 285 |
+
output_params = MUNSELL_NORMALIZATION_PARAMS
|
| 286 |
+
y_train_norm = normalize_munsell(y_train, output_params)
|
| 287 |
+
y_val_norm = normalize_munsell(y_val, output_params)
|
| 288 |
+
|
| 289 |
+
# Convert to tensors
|
| 290 |
+
X_train_t = torch.from_numpy(X_train).float()
|
| 291 |
+
y_train_t = torch.from_numpy(y_train_norm).float()
|
| 292 |
+
X_val_t = torch.from_numpy(X_val).float()
|
| 293 |
+
y_val_t = torch.from_numpy(y_val_norm).float()
|
| 294 |
+
|
| 295 |
+
train_loader = DataLoader(
|
| 296 |
+
TensorDataset(X_train_t, y_train_t), batch_size=batch_size, shuffle=True
|
| 297 |
+
)
|
| 298 |
+
val_loader = DataLoader(
|
| 299 |
+
TensorDataset(X_val_t, y_val_t), batch_size=batch_size, shuffle=False
|
| 300 |
+
)
|
| 301 |
+
|
| 302 |
+
LOGGER.info(
|
| 303 |
+
" Training samples: %d, Validation samples: %d", len(X_train_t), len(X_val_t)
|
| 304 |
+
)
|
| 305 |
+
|
| 306 |
+
# Initialize model
|
| 307 |
+
model = MultiHeadParametric(
|
| 308 |
+
encoder_width=encoder_width,
|
| 309 |
+
head_width=head_width,
|
| 310 |
+
chroma_head_width=chroma_head_width,
|
| 311 |
+
dropout=dropout,
|
| 312 |
+
).to(device)
|
| 313 |
+
|
| 314 |
+
# Count parameters
|
| 315 |
+
total_params = sum(p.numel() for p in model.parameters())
|
| 316 |
+
LOGGER.info(" Total parameters: %s", f"{total_params:,}")
|
| 317 |
+
|
| 318 |
+
# Training setup
|
| 319 |
+
optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
|
| 320 |
+
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
|
| 321 |
+
|
| 322 |
+
# MLflow setup
|
| 323 |
+
run_name = setup_mlflow_experiment(
|
| 324 |
+
"from_xyY", f"hparam_multi_head_trial_{trial.number}"
|
| 325 |
+
)
|
| 326 |
+
|
| 327 |
+
# Training loop with early stopping
|
| 328 |
+
num_epochs = 100 # Reduced for hyperparameter search
|
| 329 |
+
patience = 15
|
| 330 |
+
best_val_loss = float("inf")
|
| 331 |
+
patience_counter = 0
|
| 332 |
+
|
| 333 |
+
with mlflow.start_run(run_name=run_name):
|
| 334 |
+
mlflow.log_params(
|
| 335 |
+
{
|
| 336 |
+
"trial": trial.number,
|
| 337 |
+
"lr": lr,
|
| 338 |
+
"batch_size": batch_size,
|
| 339 |
+
"encoder_width": encoder_width,
|
| 340 |
+
"head_width": head_width,
|
| 341 |
+
"chroma_head_width": chroma_head_width,
|
| 342 |
+
"dropout": dropout,
|
| 343 |
+
"weight_decay": weight_decay,
|
| 344 |
+
"total_params": total_params,
|
| 345 |
+
}
|
| 346 |
+
)
|
| 347 |
+
|
| 348 |
+
for epoch in range(num_epochs):
|
| 349 |
+
train_loss = train_epoch(
|
| 350 |
+
model, train_loader, optimizer, weighted_mse_loss, device
|
| 351 |
+
)
|
| 352 |
+
val_loss = validate(model, val_loader, weighted_mse_loss, device)
|
| 353 |
+
scheduler.step()
|
| 354 |
+
|
| 355 |
+
# Per-component MAE
|
| 356 |
+
with torch.no_grad():
|
| 357 |
+
pred_val = model(X_val_t.to(device))
|
| 358 |
+
mae = torch.mean(torch.abs(pred_val - y_val_t.to(device)), dim=0).cpu()
|
| 359 |
+
|
| 360 |
+
# Log to MLflow
|
| 361 |
+
mlflow.log_metrics(
|
| 362 |
+
{
|
| 363 |
+
"train_loss": train_loss,
|
| 364 |
+
"val_loss": val_loss,
|
| 365 |
+
"mae_hue": mae[0].item(),
|
| 366 |
+
"mae_value": mae[1].item(),
|
| 367 |
+
"mae_chroma": mae[2].item(),
|
| 368 |
+
"mae_code": mae[3].item(),
|
| 369 |
+
"learning_rate": optimizer.param_groups[0]["lr"],
|
| 370 |
+
},
|
| 371 |
+
step=epoch,
|
| 372 |
+
)
|
| 373 |
+
|
| 374 |
+
if (epoch + 1) % 10 == 0:
|
| 375 |
+
LOGGER.info(
|
| 376 |
+
" Epoch %03d/%d - Train: %.6f, Val: %.6f - "
|
| 377 |
+
"MAE: hue=%.6f, value=%.6f, chroma=%.6f, code=%.6f",
|
| 378 |
+
epoch + 1,
|
| 379 |
+
num_epochs,
|
| 380 |
+
train_loss,
|
| 381 |
+
val_loss,
|
| 382 |
+
mae[0],
|
| 383 |
+
mae[1],
|
| 384 |
+
mae[2],
|
| 385 |
+
mae[3],
|
| 386 |
+
)
|
| 387 |
+
|
| 388 |
+
# Early stopping
|
| 389 |
+
if val_loss < best_val_loss:
|
| 390 |
+
best_val_loss = val_loss
|
| 391 |
+
patience_counter = 0
|
| 392 |
+
else:
|
| 393 |
+
patience_counter += 1
|
| 394 |
+
if patience_counter >= patience:
|
| 395 |
+
LOGGER.info(" Early stopping at epoch %d", epoch + 1)
|
| 396 |
+
break
|
| 397 |
+
|
| 398 |
+
# Report intermediate value for pruning
|
| 399 |
+
trial.report(val_loss, epoch)
|
| 400 |
+
|
| 401 |
+
# Handle pruning
|
| 402 |
+
if trial.should_prune():
|
| 403 |
+
LOGGER.info(" Trial pruned at epoch %d", epoch + 1)
|
| 404 |
+
mlflow.log_metrics({"pruned": 1, "pruned_epoch": epoch})
|
| 405 |
+
raise optuna.TrialPruned
|
| 406 |
+
|
| 407 |
+
# Log final results
|
| 408 |
+
mlflow.log_metrics(
|
| 409 |
+
{
|
| 410 |
+
"best_val_loss": best_val_loss,
|
| 411 |
+
"final_train_loss": train_loss,
|
| 412 |
+
"final_mae_hue": mae[0].item(),
|
| 413 |
+
"final_mae_value": mae[1].item(),
|
| 414 |
+
"final_mae_chroma": mae[2].item(),
|
| 415 |
+
"final_mae_code": mae[3].item(),
|
| 416 |
+
"final_epoch": epoch + 1,
|
| 417 |
+
}
|
| 418 |
+
)
|
| 419 |
+
|
| 420 |
+
LOGGER.info(" Final validation loss: %.6f", best_val_loss)
|
| 421 |
+
|
| 422 |
+
return best_val_loss
|
| 423 |
+
|
| 424 |
+
|
| 425 |
+
def main() -> None:
|
| 426 |
+
"""
|
| 427 |
+
Run hyperparameter search for Multi-Head model (xyY to Munsell).
|
| 428 |
+
|
| 429 |
+
Performs systematic hyperparameter optimization using Optuna with:
|
| 430 |
+
- MedianPruner for early stopping of unpromising trials
|
| 431 |
+
- 20 total trials
|
| 432 |
+
- MLflow logging for each trial
|
| 433 |
+
- Result visualization using matplotlib (optimization history,
|
| 434 |
+
parameter importances, parallel coordinate plot)
|
| 435 |
+
|
| 436 |
+
The search aims to find optimal hyperparameters for converting xyY
|
| 437 |
+
color coordinates to Munsell color specifications using a multi-head
|
| 438 |
+
architecture with shared encoder and component-specific heads.
|
| 439 |
+
"""
|
| 440 |
+
|
| 441 |
+
LOGGER.info("=" * 80)
|
| 442 |
+
LOGGER.info("Multi-Head (from_xyY) Hyperparameter Search with Optuna")
|
| 443 |
+
LOGGER.info("=" * 80)
|
| 444 |
+
|
| 445 |
+
# Create study
|
| 446 |
+
study = optuna.create_study(
|
| 447 |
+
direction="minimize",
|
| 448 |
+
study_name="multi_head_from_xyY_hparam_search",
|
| 449 |
+
pruner=optuna.pruners.MedianPruner(n_startup_trials=3, n_warmup_steps=10),
|
| 450 |
+
)
|
| 451 |
+
|
| 452 |
+
# Run optimization
|
| 453 |
+
n_trials = 20 # Number of trials to run
|
| 454 |
+
|
| 455 |
+
LOGGER.info("")
|
| 456 |
+
LOGGER.info("Starting hyperparameter search with %d trials...", n_trials)
|
| 457 |
+
LOGGER.info("")
|
| 458 |
+
|
| 459 |
+
study.optimize(objective, n_trials=n_trials, timeout=None)
|
| 460 |
+
|
| 461 |
+
# Print results
|
| 462 |
+
LOGGER.info("")
|
| 463 |
+
LOGGER.info("=" * 80)
|
| 464 |
+
LOGGER.info("Hyperparameter Search Results")
|
| 465 |
+
LOGGER.info("=" * 80)
|
| 466 |
+
LOGGER.info("")
|
| 467 |
+
LOGGER.info("Best trial:")
|
| 468 |
+
LOGGER.info(" Value (val_loss): %.6f", study.best_value)
|
| 469 |
+
LOGGER.info("")
|
| 470 |
+
LOGGER.info("Best hyperparameters:")
|
| 471 |
+
for key, value in study.best_params.items():
|
| 472 |
+
LOGGER.info(" %s: %s", key, value)
|
| 473 |
+
|
| 474 |
+
# Save results
|
| 475 |
+
results_dir = PROJECT_ROOT / "results" / "from_xyY"
|
| 476 |
+
results_dir.mkdir(exist_ok=True, parents=True)
|
| 477 |
+
|
| 478 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 479 |
+
results_file = results_dir / f"hparam_search_multi_head_{timestamp}.txt"
|
| 480 |
+
|
| 481 |
+
with open(results_file, "w") as f:
|
| 482 |
+
f.write("=" * 80 + "\n")
|
| 483 |
+
f.write("Multi-Head (from_xyY) Hyperparameter Search Results\n")
|
| 484 |
+
f.write("=" * 80 + "\n\n")
|
| 485 |
+
f.write(f"Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
|
| 486 |
+
f.write(f"Number of trials: {len(study.trials)}\n")
|
| 487 |
+
f.write(f"Best validation loss: {study.best_value:.6f}\n\n")
|
| 488 |
+
f.write("Best hyperparameters:\n")
|
| 489 |
+
for key, value in study.best_params.items():
|
| 490 |
+
f.write(f" {key}: {value}\n")
|
| 491 |
+
f.write("\n\nAll trials:\n")
|
| 492 |
+
f.write("-" * 80 + "\n")
|
| 493 |
+
|
| 494 |
+
for t in study.trials:
|
| 495 |
+
f.write(f"\nTrial {t.number}:\n")
|
| 496 |
+
if t.value is not None:
|
| 497 |
+
f.write(f" Value: {t.value:.6f}\n")
|
| 498 |
+
else:
|
| 499 |
+
f.write(" Value: Pruned\n")
|
| 500 |
+
f.write(" Params:\n")
|
| 501 |
+
for key, value in t.params.items():
|
| 502 |
+
f.write(f" {key}: {value}\n")
|
| 503 |
+
|
| 504 |
+
LOGGER.info("")
|
| 505 |
+
LOGGER.info("Results saved to: %s", results_file)
|
| 506 |
+
|
| 507 |
+
# Generate visualizations using matplotlib
|
| 508 |
+
from optuna.visualization.matplotlib import (
|
| 509 |
+
plot_optimization_history,
|
| 510 |
+
plot_param_importances,
|
| 511 |
+
plot_parallel_coordinate,
|
| 512 |
+
)
|
| 513 |
+
|
| 514 |
+
# Optimization history
|
| 515 |
+
ax = plot_optimization_history(study)
|
| 516 |
+
ax.figure.savefig(
|
| 517 |
+
results_dir / f"optimization_history_multi_head_{timestamp}.png", dpi=150
|
| 518 |
+
)
|
| 519 |
+
plt.close(ax.figure)
|
| 520 |
+
|
| 521 |
+
# Parameter importances
|
| 522 |
+
ax = plot_param_importances(study)
|
| 523 |
+
ax.figure.savefig(
|
| 524 |
+
results_dir / f"param_importances_multi_head_{timestamp}.png", dpi=150
|
| 525 |
+
)
|
| 526 |
+
plt.close(ax.figure)
|
| 527 |
+
|
| 528 |
+
# Parallel coordinate plot
|
| 529 |
+
ax = plot_parallel_coordinate(study)
|
| 530 |
+
ax.figure.savefig(
|
| 531 |
+
results_dir / f"parallel_coordinate_multi_head_{timestamp}.png", dpi=150
|
| 532 |
+
)
|
| 533 |
+
plt.close(ax.figure)
|
| 534 |
+
|
| 535 |
+
LOGGER.info("Visualizations saved to: %s", results_dir)
|
| 536 |
+
|
| 537 |
+
|
| 538 |
+
if __name__ == "__main__":
|
| 539 |
+
logging.basicConfig(level=logging.INFO, format="%(message)s", force=True)
|
| 540 |
+
|
| 541 |
+
main()
|
learning_munsell/training/from_xyY/hyperparameter_search_multi_head_error_predictor.py
ADDED
|
@@ -0,0 +1,552 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Hyperparameter search for Multi-Head Error Predictor using Optuna.
|
| 3 |
+
|
| 4 |
+
Optimizes:
|
| 5 |
+
- Learning rate
|
| 6 |
+
- Batch size
|
| 7 |
+
- Width multipliers for each component branch (hue, value, chroma, code)
|
| 8 |
+
- Loss function component weights
|
| 9 |
+
|
| 10 |
+
Objective: Minimize validation loss (combined base + error predictor)
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
|
| 15 |
+
import logging
|
| 16 |
+
from datetime import datetime
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
|
| 19 |
+
import matplotlib.pyplot as plt
|
| 20 |
+
import mlflow
|
| 21 |
+
import numpy as np
|
| 22 |
+
import onnxruntime as ort
|
| 23 |
+
import optuna
|
| 24 |
+
import torch
|
| 25 |
+
from numpy.typing import NDArray
|
| 26 |
+
from optuna.trial import Trial
|
| 27 |
+
from torch import nn, optim
|
| 28 |
+
from torch.utils.data import DataLoader, TensorDataset
|
| 29 |
+
|
| 30 |
+
from learning_munsell import PROJECT_ROOT
|
| 31 |
+
from learning_munsell.models.networks import ComponentErrorPredictor
|
| 32 |
+
from learning_munsell.utilities.common import setup_mlflow_experiment
|
| 33 |
+
from learning_munsell.utilities.data import normalize_xyY, normalize_munsell
|
| 34 |
+
from learning_munsell.utilities.training import train_epoch, validate
|
| 35 |
+
|
| 36 |
+
LOGGER = logging.getLogger(__name__)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class MultiHeadErrorPredictorParametric(nn.Module):
|
| 40 |
+
"""
|
| 41 |
+
Parametric Multi-Head error predictor with 4 independent branches.
|
| 42 |
+
|
| 43 |
+
This model consists of four independent ComponentErrorPredictor
|
| 44 |
+
networks, one for each Munsell component (hue, value, chroma, code).
|
| 45 |
+
Each branch can have different widths for hyperparameter optimization.
|
| 46 |
+
|
| 47 |
+
Parameters
|
| 48 |
+
----------
|
| 49 |
+
hue_width : float, optional
|
| 50 |
+
Width multiplier for the hue branch. Default is 1.0.
|
| 51 |
+
value_width : float, optional
|
| 52 |
+
Width multiplier for the value branch. Default is 1.0.
|
| 53 |
+
chroma_width : float, optional
|
| 54 |
+
Width multiplier for the chroma branch. Default is 1.5.
|
| 55 |
+
code_width : float, optional
|
| 56 |
+
Width multiplier for the code branch. Default is 1.0.
|
| 57 |
+
"""
|
| 58 |
+
|
| 59 |
+
def __init__(
|
| 60 |
+
self,
|
| 61 |
+
hue_width: float = 1.0,
|
| 62 |
+
value_width: float = 1.0,
|
| 63 |
+
chroma_width: float = 1.5,
|
| 64 |
+
code_width: float = 1.0,
|
| 65 |
+
) -> None:
|
| 66 |
+
super().__init__()
|
| 67 |
+
|
| 68 |
+
# Independent error predictor for each component
|
| 69 |
+
self.hue_branch = ComponentErrorPredictor(width_multiplier=hue_width)
|
| 70 |
+
self.value_branch = ComponentErrorPredictor(
|
| 71 |
+
width_multiplier=value_width
|
| 72 |
+
)
|
| 73 |
+
self.chroma_branch = ComponentErrorPredictor(
|
| 74 |
+
width_multiplier=chroma_width
|
| 75 |
+
)
|
| 76 |
+
self.code_branch = ComponentErrorPredictor(
|
| 77 |
+
width_multiplier=code_width
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 81 |
+
"""
|
| 82 |
+
Forward pass through all four error predictor branches.
|
| 83 |
+
|
| 84 |
+
Parameters
|
| 85 |
+
----------
|
| 86 |
+
x : torch.Tensor
|
| 87 |
+
Input tensor of shape (batch_size, 7) containing normalized
|
| 88 |
+
xyY values and base model predictions.
|
| 89 |
+
|
| 90 |
+
Returns
|
| 91 |
+
-------
|
| 92 |
+
torch.Tensor
|
| 93 |
+
Predicted errors for all components, shape (batch_size, 4).
|
| 94 |
+
Output order: [hue_error, value_error, chroma_error, code_error].
|
| 95 |
+
"""
|
| 96 |
+
# Each branch processes the same combined input independently
|
| 97 |
+
hue_error = self.hue_branch(x)
|
| 98 |
+
value_error = self.value_branch(x)
|
| 99 |
+
chroma_error = self.chroma_branch(x)
|
| 100 |
+
code_error = self.code_branch(x)
|
| 101 |
+
|
| 102 |
+
# Concatenate: [Hue_error, Value_error, Chroma_error, Code_error]
|
| 103 |
+
return torch.cat([hue_error, value_error, chroma_error, code_error], dim=1)
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def load_base_model(
|
| 107 |
+
model_path: Path, params_path: Path
|
| 108 |
+
) -> tuple[ort.InferenceSession, dict, dict]:
|
| 109 |
+
"""
|
| 110 |
+
Load the base Multi-Head ONNX model and its normalization parameters.
|
| 111 |
+
|
| 112 |
+
Parameters
|
| 113 |
+
----------
|
| 114 |
+
model_path : Path
|
| 115 |
+
Path to the base Multi-Head model ONNX file.
|
| 116 |
+
params_path : Path
|
| 117 |
+
Path to the normalization parameters NPZ file.
|
| 118 |
+
|
| 119 |
+
Returns
|
| 120 |
+
-------
|
| 121 |
+
ort.InferenceSession
|
| 122 |
+
ONNX Runtime inference session for the base model.
|
| 123 |
+
dict
|
| 124 |
+
Input normalization parameters (x_range, y_range, Y_range).
|
| 125 |
+
dict
|
| 126 |
+
Output normalization parameters (hue_range, value_range, chroma_range, code_range).
|
| 127 |
+
"""
|
| 128 |
+
session = ort.InferenceSession(str(model_path))
|
| 129 |
+
params = np.load(params_path, allow_pickle=True)
|
| 130 |
+
return session, params["input_params"].item(), params["output_params"].item()
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def create_weighted_loss(
|
| 134 |
+
mse_weight: float,
|
| 135 |
+
mae_weight: float,
|
| 136 |
+
log_weight: float,
|
| 137 |
+
huber_weight: float,
|
| 138 |
+
huber_delta: float,
|
| 139 |
+
):
|
| 140 |
+
"""
|
| 141 |
+
Create a weighted loss function combining multiple loss components.
|
| 142 |
+
|
| 143 |
+
Parameters
|
| 144 |
+
----------
|
| 145 |
+
mse_weight : float
|
| 146 |
+
Weight for MSE component.
|
| 147 |
+
mae_weight : float
|
| 148 |
+
Weight for MAE component.
|
| 149 |
+
log_weight : float
|
| 150 |
+
Weight for logarithmic penalty component.
|
| 151 |
+
huber_weight : float
|
| 152 |
+
Weight for Huber loss component.
|
| 153 |
+
huber_delta : float
|
| 154 |
+
Delta parameter for Huber loss transition point.
|
| 155 |
+
|
| 156 |
+
Returns
|
| 157 |
+
-------
|
| 158 |
+
callable
|
| 159 |
+
Loss function that accepts (pred, target) and returns a scalar loss.
|
| 160 |
+
"""
|
| 161 |
+
|
| 162 |
+
def weighted_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
| 163 |
+
"""
|
| 164 |
+
Compute weighted combination of loss components.
|
| 165 |
+
|
| 166 |
+
Parameters
|
| 167 |
+
----------
|
| 168 |
+
pred : torch.Tensor
|
| 169 |
+
Predicted values, shape (batch_size, n_components).
|
| 170 |
+
target : torch.Tensor
|
| 171 |
+
Target values, shape (batch_size, n_components).
|
| 172 |
+
|
| 173 |
+
Returns
|
| 174 |
+
-------
|
| 175 |
+
torch.Tensor
|
| 176 |
+
Weighted combination of loss components, scalar tensor.
|
| 177 |
+
"""
|
| 178 |
+
# Standard MSE
|
| 179 |
+
mse = torch.mean((pred - target) ** 2)
|
| 180 |
+
|
| 181 |
+
# Mean absolute error
|
| 182 |
+
mae = torch.mean(torch.abs(pred - target))
|
| 183 |
+
|
| 184 |
+
# Logarithmic penalty
|
| 185 |
+
log_penalty = torch.mean(torch.log1p(torch.abs(pred - target) * 1000.0))
|
| 186 |
+
|
| 187 |
+
# Huber loss
|
| 188 |
+
abs_error = torch.abs(pred - target)
|
| 189 |
+
huber = torch.where(
|
| 190 |
+
abs_error <= huber_delta,
|
| 191 |
+
0.5 * abs_error**2,
|
| 192 |
+
huber_delta * (abs_error - 0.5 * huber_delta),
|
| 193 |
+
)
|
| 194 |
+
huber_loss = torch.mean(huber)
|
| 195 |
+
|
| 196 |
+
# Combine with weights
|
| 197 |
+
return (
|
| 198 |
+
mse_weight * mse
|
| 199 |
+
+ mae_weight * mae
|
| 200 |
+
+ log_weight * log_penalty
|
| 201 |
+
+ huber_weight * huber_loss
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
return weighted_loss
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def objective(trial: Trial) -> float:
|
| 208 |
+
"""
|
| 209 |
+
Optuna objective function to minimize validation loss.
|
| 210 |
+
|
| 211 |
+
This function defines the hyperparameter search space and training
|
| 212 |
+
procedure for each trial. It optimizes:
|
| 213 |
+
- Learning rate (1e-4 to 1e-3, log scale)
|
| 214 |
+
- Batch size (512, 1024, or 2048)
|
| 215 |
+
- Width multipliers for each component branch
|
| 216 |
+
- Loss function weights (MSE, MAE, log penalty, Huber)
|
| 217 |
+
- Huber delta parameter (0.005 to 0.02)
|
| 218 |
+
|
| 219 |
+
Parameters
|
| 220 |
+
----------
|
| 221 |
+
trial : Trial
|
| 222 |
+
Optuna trial object for suggesting hyperparameters.
|
| 223 |
+
|
| 224 |
+
Returns
|
| 225 |
+
-------
|
| 226 |
+
float
|
| 227 |
+
Best validation loss achieved during training.
|
| 228 |
+
|
| 229 |
+
Raises
|
| 230 |
+
------
|
| 231 |
+
optuna.TrialPruned
|
| 232 |
+
If trial is pruned based on intermediate results.
|
| 233 |
+
"""
|
| 234 |
+
|
| 235 |
+
# Suggest hyperparameters
|
| 236 |
+
lr = trial.suggest_float("lr", 1e-4, 1e-3, log=True)
|
| 237 |
+
batch_size = trial.suggest_categorical("batch_size", [512, 1024, 2048])
|
| 238 |
+
hue_width = trial.suggest_float("hue_width", 0.75, 1.5, step=0.25)
|
| 239 |
+
value_width = trial.suggest_float("value_width", 0.75, 1.5, step=0.25)
|
| 240 |
+
chroma_width = trial.suggest_float("chroma_width", 1.0, 2.0, step=0.25)
|
| 241 |
+
code_width = trial.suggest_float("code_width", 0.75, 1.5, step=0.25)
|
| 242 |
+
|
| 243 |
+
# Loss function weights
|
| 244 |
+
mse_weight = trial.suggest_float("mse_weight", 0.5, 2.0, step=0.5)
|
| 245 |
+
mae_weight = trial.suggest_float("mae_weight", 0.0, 1.0, step=0.25)
|
| 246 |
+
log_weight = trial.suggest_float("log_weight", 0.0, 0.5, step=0.1)
|
| 247 |
+
huber_weight = trial.suggest_float("huber_weight", 0.0, 1.0, step=0.25)
|
| 248 |
+
huber_delta = trial.suggest_float("huber_delta", 0.005, 0.02, step=0.005)
|
| 249 |
+
|
| 250 |
+
LOGGER.info("")
|
| 251 |
+
LOGGER.info("=" * 80)
|
| 252 |
+
LOGGER.info("Trial %d", trial.number)
|
| 253 |
+
LOGGER.info("=" * 80)
|
| 254 |
+
LOGGER.info(" lr: %.6f", lr)
|
| 255 |
+
LOGGER.info(" batch_size: %d", batch_size)
|
| 256 |
+
LOGGER.info(" hue_width: %.2f", hue_width)
|
| 257 |
+
LOGGER.info(" value_width: %.2f", value_width)
|
| 258 |
+
LOGGER.info(" chroma_width: %.2f", chroma_width)
|
| 259 |
+
LOGGER.info(" code_width: %.2f", code_width)
|
| 260 |
+
LOGGER.info(" mse_weight: %.2f", mse_weight)
|
| 261 |
+
LOGGER.info(" mae_weight: %.2f", mae_weight)
|
| 262 |
+
LOGGER.info(" log_weight: %.2f", log_weight)
|
| 263 |
+
LOGGER.info(" huber_weight: %.2f", huber_weight)
|
| 264 |
+
LOGGER.info(" huber_delta: %.3f", huber_delta)
|
| 265 |
+
|
| 266 |
+
# Set device
|
| 267 |
+
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
|
| 268 |
+
LOGGER.info(" device: %s", device)
|
| 269 |
+
|
| 270 |
+
# Paths
|
| 271 |
+
model_directory = PROJECT_ROOT / "models" / "from_xyY"
|
| 272 |
+
data_dir = PROJECT_ROOT / "data"
|
| 273 |
+
|
| 274 |
+
base_model_path = model_directory / "multi_head.onnx"
|
| 275 |
+
params_path = model_directory / "multi_head_normalization_params.npz"
|
| 276 |
+
cache_file = data_dir / "training_data.npz"
|
| 277 |
+
|
| 278 |
+
# Load base model
|
| 279 |
+
base_session, input_params, output_params = load_base_model(
|
| 280 |
+
base_model_path, params_path
|
| 281 |
+
)
|
| 282 |
+
|
| 283 |
+
# Load training data
|
| 284 |
+
data = np.load(cache_file)
|
| 285 |
+
X_train = data["X_train"]
|
| 286 |
+
y_train = data["y_train"]
|
| 287 |
+
X_val = data["X_val"]
|
| 288 |
+
y_val = data["y_val"]
|
| 289 |
+
|
| 290 |
+
# Normalize
|
| 291 |
+
X_train_norm = normalize_xyY(X_train, input_params)
|
| 292 |
+
y_train_norm = normalize_munsell(y_train, output_params)
|
| 293 |
+
X_val_norm = normalize_xyY(X_val, input_params)
|
| 294 |
+
y_val_norm = normalize_munsell(y_val, output_params)
|
| 295 |
+
|
| 296 |
+
# Generate base model predictions
|
| 297 |
+
base_pred_train_norm = base_session.run(None, {"xyY": X_train_norm})[0]
|
| 298 |
+
base_pred_val_norm = base_session.run(None, {"xyY": X_val_norm})[0]
|
| 299 |
+
|
| 300 |
+
# Compute errors
|
| 301 |
+
error_train = y_train_norm - base_pred_train_norm
|
| 302 |
+
error_val = y_val_norm - base_pred_val_norm
|
| 303 |
+
|
| 304 |
+
# Create combined input: [xyY_norm, base_prediction_norm]
|
| 305 |
+
X_train_combined = np.concatenate([X_train_norm, base_pred_train_norm], axis=1)
|
| 306 |
+
X_val_combined = np.concatenate([X_val_norm, base_pred_val_norm], axis=1)
|
| 307 |
+
|
| 308 |
+
# Convert to PyTorch tensors
|
| 309 |
+
X_train_t = torch.FloatTensor(X_train_combined)
|
| 310 |
+
error_train_t = torch.FloatTensor(error_train)
|
| 311 |
+
X_val_t = torch.FloatTensor(X_val_combined)
|
| 312 |
+
error_val_t = torch.FloatTensor(error_val)
|
| 313 |
+
|
| 314 |
+
# Create data loaders
|
| 315 |
+
train_loader = DataLoader(
|
| 316 |
+
TensorDataset(X_train_t, error_train_t), batch_size=batch_size, shuffle=True
|
| 317 |
+
)
|
| 318 |
+
val_loader = DataLoader(
|
| 319 |
+
TensorDataset(X_val_t, error_val_t), batch_size=batch_size, shuffle=False
|
| 320 |
+
)
|
| 321 |
+
|
| 322 |
+
LOGGER.info(
|
| 323 |
+
" Training samples: %d, Validation samples: %d", len(X_train_t), len(X_val_t)
|
| 324 |
+
)
|
| 325 |
+
|
| 326 |
+
# Initialize error predictor model
|
| 327 |
+
model = MultiHeadErrorPredictorParametric(
|
| 328 |
+
hue_width=hue_width,
|
| 329 |
+
value_width=value_width,
|
| 330 |
+
chroma_width=chroma_width,
|
| 331 |
+
code_width=code_width,
|
| 332 |
+
).to(device)
|
| 333 |
+
|
| 334 |
+
# Count parameters
|
| 335 |
+
total_params = sum(p.numel() for p in model.parameters())
|
| 336 |
+
LOGGER.info(" Total parameters: %s", f"{total_params:,}")
|
| 337 |
+
|
| 338 |
+
# Training setup
|
| 339 |
+
optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
|
| 340 |
+
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
|
| 341 |
+
optimizer, mode="min", factor=0.5, patience=5
|
| 342 |
+
)
|
| 343 |
+
|
| 344 |
+
# Create loss function
|
| 345 |
+
criterion = create_weighted_loss(
|
| 346 |
+
mse_weight, mae_weight, log_weight, huber_weight, huber_delta
|
| 347 |
+
)
|
| 348 |
+
|
| 349 |
+
# MLflow setup
|
| 350 |
+
run_name = setup_mlflow_experiment(
|
| 351 |
+
"from_xyY", f"hparam_multi_head_error_trial_{trial.number}"
|
| 352 |
+
)
|
| 353 |
+
|
| 354 |
+
# Training loop with early stopping
|
| 355 |
+
num_epochs = 50 # Reduced for hyperparameter search
|
| 356 |
+
patience = 10
|
| 357 |
+
best_val_loss = float("inf")
|
| 358 |
+
patience_counter = 0
|
| 359 |
+
|
| 360 |
+
with mlflow.start_run(run_name=run_name):
|
| 361 |
+
mlflow.log_params(
|
| 362 |
+
{
|
| 363 |
+
"lr": lr,
|
| 364 |
+
"batch_size": batch_size,
|
| 365 |
+
"hue_width": hue_width,
|
| 366 |
+
"value_width": value_width,
|
| 367 |
+
"chroma_width": chroma_width,
|
| 368 |
+
"code_width": code_width,
|
| 369 |
+
"mse_weight": mse_weight,
|
| 370 |
+
"mae_weight": mae_weight,
|
| 371 |
+
"log_weight": log_weight,
|
| 372 |
+
"huber_weight": huber_weight,
|
| 373 |
+
"huber_delta": huber_delta,
|
| 374 |
+
"total_params": total_params,
|
| 375 |
+
"trial_number": trial.number,
|
| 376 |
+
}
|
| 377 |
+
)
|
| 378 |
+
|
| 379 |
+
for epoch in range(num_epochs):
|
| 380 |
+
train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
|
| 381 |
+
val_loss = validate(model, val_loader, criterion, device)
|
| 382 |
+
scheduler.step(val_loss)
|
| 383 |
+
|
| 384 |
+
# Log to MLflow
|
| 385 |
+
mlflow.log_metrics(
|
| 386 |
+
{
|
| 387 |
+
"train_loss": train_loss,
|
| 388 |
+
"val_loss": val_loss,
|
| 389 |
+
"learning_rate": optimizer.param_groups[0]["lr"],
|
| 390 |
+
},
|
| 391 |
+
step=epoch,
|
| 392 |
+
)
|
| 393 |
+
|
| 394 |
+
if (epoch + 1) % 10 == 0:
|
| 395 |
+
LOGGER.info(
|
| 396 |
+
" Epoch %03d/%d - Train: %.6f, Val: %.6f, LR: %.6f",
|
| 397 |
+
epoch + 1,
|
| 398 |
+
num_epochs,
|
| 399 |
+
train_loss,
|
| 400 |
+
val_loss,
|
| 401 |
+
optimizer.param_groups[0]["lr"],
|
| 402 |
+
)
|
| 403 |
+
|
| 404 |
+
# Early stopping
|
| 405 |
+
if val_loss < best_val_loss:
|
| 406 |
+
best_val_loss = val_loss
|
| 407 |
+
patience_counter = 0
|
| 408 |
+
else:
|
| 409 |
+
patience_counter += 1
|
| 410 |
+
if patience_counter >= patience:
|
| 411 |
+
LOGGER.info(" Early stopping at epoch %d", epoch + 1)
|
| 412 |
+
break
|
| 413 |
+
|
| 414 |
+
# Report intermediate value for pruning
|
| 415 |
+
trial.report(val_loss, epoch)
|
| 416 |
+
|
| 417 |
+
# Handle pruning
|
| 418 |
+
if trial.should_prune():
|
| 419 |
+
LOGGER.info(" Trial pruned at epoch %d", epoch + 1)
|
| 420 |
+
mlflow.log_metrics({"pruned": 1, "pruned_epoch": epoch})
|
| 421 |
+
raise optuna.TrialPruned
|
| 422 |
+
|
| 423 |
+
# Log final results
|
| 424 |
+
mlflow.log_metrics(
|
| 425 |
+
{
|
| 426 |
+
"best_val_loss": best_val_loss,
|
| 427 |
+
"final_train_loss": train_loss,
|
| 428 |
+
}
|
| 429 |
+
)
|
| 430 |
+
|
| 431 |
+
LOGGER.info(" Final validation loss: %.6f", best_val_loss)
|
| 432 |
+
|
| 433 |
+
return best_val_loss
|
| 434 |
+
|
| 435 |
+
|
| 436 |
+
def main() -> None:
|
| 437 |
+
"""
|
| 438 |
+
Run hyperparameter search for Multi-Head Error Predictor.
|
| 439 |
+
|
| 440 |
+
Performs systematic hyperparameter optimization using Optuna with:
|
| 441 |
+
- MedianPruner for early stopping of unpromising trials
|
| 442 |
+
- 30 total trials
|
| 443 |
+
- MLflow logging for each trial
|
| 444 |
+
- Result visualization using matplotlib (optimization history,
|
| 445 |
+
parameter importances, parallel coordinate plot)
|
| 446 |
+
|
| 447 |
+
The search aims to find optimal hyperparameters for predicting errors
|
| 448 |
+
in a base Multi-Head model, allowing for error correction and improved
|
| 449 |
+
Munsell predictions.
|
| 450 |
+
"""
|
| 451 |
+
|
| 452 |
+
LOGGER.info("=" * 80)
|
| 453 |
+
LOGGER.info("Multi-Head Error Predictor Hyperparameter Search with Optuna")
|
| 454 |
+
LOGGER.info("=" * 80)
|
| 455 |
+
|
| 456 |
+
# Create study
|
| 457 |
+
study = optuna.create_study(
|
| 458 |
+
direction="minimize",
|
| 459 |
+
study_name="multi_head_error_predictor_hparam_search",
|
| 460 |
+
pruner=optuna.pruners.MedianPruner(n_startup_trials=3, n_warmup_steps=5),
|
| 461 |
+
)
|
| 462 |
+
|
| 463 |
+
# Run optimization
|
| 464 |
+
n_trials = 30 # Number of trials to run
|
| 465 |
+
|
| 466 |
+
LOGGER.info("")
|
| 467 |
+
LOGGER.info("Starting hyperparameter search with %d trials...", n_trials)
|
| 468 |
+
LOGGER.info("")
|
| 469 |
+
|
| 470 |
+
study.optimize(objective, n_trials=n_trials, timeout=None)
|
| 471 |
+
|
| 472 |
+
# Print results
|
| 473 |
+
LOGGER.info("")
|
| 474 |
+
LOGGER.info("=" * 80)
|
| 475 |
+
LOGGER.info("Hyperparameter Search Results")
|
| 476 |
+
LOGGER.info("=" * 80)
|
| 477 |
+
LOGGER.info("")
|
| 478 |
+
LOGGER.info("Best trial:")
|
| 479 |
+
LOGGER.info(" Value (val_loss): %.6f", study.best_value)
|
| 480 |
+
LOGGER.info("")
|
| 481 |
+
LOGGER.info("Best hyperparameters:")
|
| 482 |
+
for key, value in study.best_params.items():
|
| 483 |
+
LOGGER.info(" %s: %s", key, value)
|
| 484 |
+
|
| 485 |
+
# Save results
|
| 486 |
+
results_dir = PROJECT_ROOT / "results" / "from_xyY"
|
| 487 |
+
results_dir.mkdir(exist_ok=True, parents=True)
|
| 488 |
+
|
| 489 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 490 |
+
results_file = results_dir / f"hparam_search_multi_head_error_{timestamp}.txt"
|
| 491 |
+
|
| 492 |
+
with open(results_file, "w") as f:
|
| 493 |
+
f.write("=" * 80 + "\n")
|
| 494 |
+
f.write("Multi-Head Error Predictor Hyperparameter Search Results\n")
|
| 495 |
+
f.write("=" * 80 + "\n\n")
|
| 496 |
+
f.write(f"Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
|
| 497 |
+
f.write(f"Number of trials: {len(study.trials)}\n")
|
| 498 |
+
f.write(f"Best validation loss: {study.best_value:.6f}\n\n")
|
| 499 |
+
f.write("Best hyperparameters:\n")
|
| 500 |
+
for key, value in study.best_params.items():
|
| 501 |
+
f.write(f" {key}: {value}\n")
|
| 502 |
+
f.write("\n\nAll trials:\n")
|
| 503 |
+
f.write("-" * 80 + "\n")
|
| 504 |
+
|
| 505 |
+
for t in study.trials:
|
| 506 |
+
f.write(f"\nTrial {t.number}:\n")
|
| 507 |
+
if t.value is not None:
|
| 508 |
+
f.write(f" Value: {t.value:.6f}\n")
|
| 509 |
+
else:
|
| 510 |
+
f.write(" Value: Pruned\n")
|
| 511 |
+
f.write(" Params:\n")
|
| 512 |
+
for key, value in t.params.items():
|
| 513 |
+
f.write(f" {key}: {value}\n")
|
| 514 |
+
|
| 515 |
+
LOGGER.info("")
|
| 516 |
+
LOGGER.info("Results saved to: %s", results_file)
|
| 517 |
+
|
| 518 |
+
# Generate visualizations using matplotlib
|
| 519 |
+
from optuna.visualization.matplotlib import (
|
| 520 |
+
plot_optimization_history,
|
| 521 |
+
plot_param_importances,
|
| 522 |
+
plot_parallel_coordinate,
|
| 523 |
+
)
|
| 524 |
+
|
| 525 |
+
# Optimization history
|
| 526 |
+
ax = plot_optimization_history(study)
|
| 527 |
+
ax.figure.savefig(
|
| 528 |
+
results_dir / f"optimization_history_multi_head_error_{timestamp}.png", dpi=150
|
| 529 |
+
)
|
| 530 |
+
plt.close(ax.figure)
|
| 531 |
+
|
| 532 |
+
# Parameter importances
|
| 533 |
+
ax = plot_param_importances(study)
|
| 534 |
+
ax.figure.savefig(
|
| 535 |
+
results_dir / f"param_importances_multi_head_error_{timestamp}.png", dpi=150
|
| 536 |
+
)
|
| 537 |
+
plt.close(ax.figure)
|
| 538 |
+
|
| 539 |
+
# Parallel coordinate plot
|
| 540 |
+
ax = plot_parallel_coordinate(study)
|
| 541 |
+
ax.figure.savefig(
|
| 542 |
+
results_dir / f"parallel_coordinate_multi_head_error_{timestamp}.png", dpi=150
|
| 543 |
+
)
|
| 544 |
+
plt.close(ax.figure)
|
| 545 |
+
|
| 546 |
+
LOGGER.info("Visualizations saved to: %s", results_dir)
|
| 547 |
+
|
| 548 |
+
|
| 549 |
+
if __name__ == "__main__":
|
| 550 |
+
logging.basicConfig(level=logging.INFO, format="%(message)s", force=True)
|
| 551 |
+
|
| 552 |
+
main()
|
learning_munsell/training/from_xyY/hyperparameter_search_multi_mlp.py
ADDED
|
@@ -0,0 +1,471 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Hyperparameter search for Multi-MLP model using Optuna.
|
| 3 |
+
|
| 4 |
+
Optimizes:
|
| 5 |
+
- Learning rate
|
| 6 |
+
- Batch size
|
| 7 |
+
- Chroma width multiplier
|
| 8 |
+
- Chroma loss weight
|
| 9 |
+
- Code loss weight
|
| 10 |
+
- Dropout (optional)
|
| 11 |
+
|
| 12 |
+
Objective: Minimize validation loss
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import logging
|
| 16 |
+
from datetime import datetime
|
| 17 |
+
|
| 18 |
+
import matplotlib.pyplot as plt
|
| 19 |
+
import mlflow
|
| 20 |
+
import numpy as np
|
| 21 |
+
import optuna
|
| 22 |
+
import torch
|
| 23 |
+
from numpy.typing import NDArray
|
| 24 |
+
from optuna.trial import Trial
|
| 25 |
+
from torch import nn, optim
|
| 26 |
+
from torch.utils.data import DataLoader, TensorDataset
|
| 27 |
+
|
| 28 |
+
from learning_munsell import PROJECT_ROOT
|
| 29 |
+
from learning_munsell.models.networks import MultiMLPToMunsell
|
| 30 |
+
from learning_munsell.utilities.common import setup_mlflow_experiment
|
| 31 |
+
from learning_munsell.utilities.data import MUNSELL_NORMALIZATION_PARAMS, normalize_munsell
|
| 32 |
+
|
| 33 |
+
LOGGER = logging.getLogger(__name__)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def weighted_mse_loss(
|
| 37 |
+
pred: torch.Tensor,
|
| 38 |
+
target: torch.Tensor,
|
| 39 |
+
hue_weight: float = 1.0,
|
| 40 |
+
value_weight: float = 1.0,
|
| 41 |
+
chroma_weight: float = 4.0,
|
| 42 |
+
code_weight: float = 0.5,
|
| 43 |
+
) -> torch.Tensor:
|
| 44 |
+
"""
|
| 45 |
+
Component-wise weighted MSE loss with configurable weights.
|
| 46 |
+
|
| 47 |
+
Applies different weights to each Munsell component to account for
|
| 48 |
+
varying prediction difficulty and importance.
|
| 49 |
+
|
| 50 |
+
Parameters
|
| 51 |
+
----------
|
| 52 |
+
pred : torch.Tensor
|
| 53 |
+
Predicted values, shape (batch_size, 4).
|
| 54 |
+
target : torch.Tensor
|
| 55 |
+
Target values, shape (batch_size, 4).
|
| 56 |
+
hue_weight : float, optional
|
| 57 |
+
Weight for hue component. Default is 1.0.
|
| 58 |
+
value_weight : float, optional
|
| 59 |
+
Weight for value component. Default is 1.0.
|
| 60 |
+
chroma_weight : float, optional
|
| 61 |
+
Weight for chroma component (typically higher). Default is 4.0.
|
| 62 |
+
code_weight : float, optional
|
| 63 |
+
Weight for code component (typically lower). Default is 0.5.
|
| 64 |
+
|
| 65 |
+
Returns
|
| 66 |
+
-------
|
| 67 |
+
torch.Tensor
|
| 68 |
+
Weighted MSE loss, scalar tensor.
|
| 69 |
+
"""
|
| 70 |
+
weights = torch.tensor(
|
| 71 |
+
[hue_weight, value_weight, chroma_weight, code_weight], device=pred.device
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
mse = (pred - target) ** 2
|
| 75 |
+
weighted_mse = mse * weights
|
| 76 |
+
return weighted_mse.mean()
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def train_epoch(
|
| 80 |
+
model: nn.Module,
|
| 81 |
+
dataloader: DataLoader,
|
| 82 |
+
optimizer: optim.Optimizer,
|
| 83 |
+
device: torch.device,
|
| 84 |
+
chroma_weight: float,
|
| 85 |
+
code_weight: float,
|
| 86 |
+
) -> float:
|
| 87 |
+
"""
|
| 88 |
+
Train the model for one epoch.
|
| 89 |
+
|
| 90 |
+
Parameters
|
| 91 |
+
----------
|
| 92 |
+
model : nn.Module
|
| 93 |
+
Multi-MLP model to train.
|
| 94 |
+
dataloader : DataLoader
|
| 95 |
+
DataLoader providing training batches.
|
| 96 |
+
optimizer : optim.Optimizer
|
| 97 |
+
Optimizer for updating model parameters.
|
| 98 |
+
device : torch.device
|
| 99 |
+
Device to run training on (CPU, CUDA, or MPS).
|
| 100 |
+
chroma_weight : float
|
| 101 |
+
Weight for chroma component in loss function.
|
| 102 |
+
code_weight : float
|
| 103 |
+
Weight for code component in loss function.
|
| 104 |
+
|
| 105 |
+
Returns
|
| 106 |
+
-------
|
| 107 |
+
float
|
| 108 |
+
Average training loss over the epoch.
|
| 109 |
+
"""
|
| 110 |
+
model.train()
|
| 111 |
+
total_loss = 0.0
|
| 112 |
+
|
| 113 |
+
for X_batch, y_batch in dataloader:
|
| 114 |
+
X_batch = X_batch.to(device)
|
| 115 |
+
y_batch = y_batch.to(device)
|
| 116 |
+
# Forward pass
|
| 117 |
+
outputs = model(X_batch)
|
| 118 |
+
loss = weighted_mse_loss(
|
| 119 |
+
outputs, y_batch, chroma_weight=chroma_weight, code_weight=code_weight
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
# Backward pass
|
| 123 |
+
optimizer.zero_grad()
|
| 124 |
+
loss.backward()
|
| 125 |
+
optimizer.step()
|
| 126 |
+
|
| 127 |
+
total_loss += loss.item()
|
| 128 |
+
|
| 129 |
+
return total_loss / len(dataloader)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def validate(
|
| 133 |
+
model: nn.Module,
|
| 134 |
+
dataloader: DataLoader,
|
| 135 |
+
device: torch.device,
|
| 136 |
+
chroma_weight: float,
|
| 137 |
+
code_weight: float,
|
| 138 |
+
) -> float:
|
| 139 |
+
"""
|
| 140 |
+
Validate the model on the validation set.
|
| 141 |
+
|
| 142 |
+
Parameters
|
| 143 |
+
----------
|
| 144 |
+
model : nn.Module
|
| 145 |
+
Multi-MLP model to validate.
|
| 146 |
+
dataloader : DataLoader
|
| 147 |
+
DataLoader providing validation batches.
|
| 148 |
+
device : torch.device
|
| 149 |
+
Device to run validation on (CPU, CUDA, or MPS).
|
| 150 |
+
chroma_weight : float
|
| 151 |
+
Weight for chroma component in loss function.
|
| 152 |
+
code_weight : float
|
| 153 |
+
Weight for code component in loss function.
|
| 154 |
+
|
| 155 |
+
Returns
|
| 156 |
+
-------
|
| 157 |
+
float
|
| 158 |
+
Average validation loss.
|
| 159 |
+
"""
|
| 160 |
+
model.eval()
|
| 161 |
+
total_loss = 0.0
|
| 162 |
+
|
| 163 |
+
with torch.no_grad():
|
| 164 |
+
for X_batch, y_batch in dataloader:
|
| 165 |
+
X_batch = X_batch.to(device)
|
| 166 |
+
y_batch = y_batch.to(device)
|
| 167 |
+
outputs = model(X_batch)
|
| 168 |
+
loss = weighted_mse_loss(
|
| 169 |
+
outputs, y_batch, chroma_weight=chroma_weight, code_weight=code_weight
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
+
total_loss += loss.item()
|
| 173 |
+
|
| 174 |
+
return total_loss / len(dataloader)
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def objective(trial: Trial) -> float:
|
| 178 |
+
"""
|
| 179 |
+
Optuna objective function to minimize validation loss.
|
| 180 |
+
|
| 181 |
+
This function defines the hyperparameter search space and training
|
| 182 |
+
procedure for each trial. It optimizes:
|
| 183 |
+
- Learning rate (1e-4 to 1e-3, log scale)
|
| 184 |
+
- Batch size (512, 1024, or 2048)
|
| 185 |
+
- Chroma branch width multiplier (1.5 to 2.5)
|
| 186 |
+
- Chroma loss weight (3.0 to 6.0)
|
| 187 |
+
- Code loss weight (0.3 to 1.0)
|
| 188 |
+
- Dropout rate (0.0 to 0.2)
|
| 189 |
+
|
| 190 |
+
Parameters
|
| 191 |
+
----------
|
| 192 |
+
trial : Trial
|
| 193 |
+
Optuna trial object for suggesting hyperparameters.
|
| 194 |
+
|
| 195 |
+
Returns
|
| 196 |
+
-------
|
| 197 |
+
float
|
| 198 |
+
Best validation loss achieved during training.
|
| 199 |
+
|
| 200 |
+
Raises
|
| 201 |
+
------
|
| 202 |
+
FileNotFoundError
|
| 203 |
+
If training data file is not found.
|
| 204 |
+
optuna.TrialPruned
|
| 205 |
+
If trial is pruned based on intermediate results.
|
| 206 |
+
"""
|
| 207 |
+
|
| 208 |
+
# Suggest hyperparameters
|
| 209 |
+
lr = trial.suggest_float("lr", 1e-4, 1e-3, log=True)
|
| 210 |
+
batch_size = trial.suggest_categorical("batch_size", [512, 1024, 2048])
|
| 211 |
+
chroma_width = trial.suggest_float("chroma_width", 1.5, 2.5, step=0.25)
|
| 212 |
+
chroma_weight = trial.suggest_float("chroma_weight", 3.0, 6.0, step=0.5)
|
| 213 |
+
code_weight = trial.suggest_float("code_weight", 0.3, 1.0, step=0.1)
|
| 214 |
+
dropout = trial.suggest_float("dropout", 0.0, 0.2, step=0.05)
|
| 215 |
+
|
| 216 |
+
LOGGER.info("")
|
| 217 |
+
LOGGER.info("=" * 80)
|
| 218 |
+
LOGGER.info("Trial %d", trial.number)
|
| 219 |
+
LOGGER.info("=" * 80)
|
| 220 |
+
LOGGER.info(" lr: %.6f", lr)
|
| 221 |
+
LOGGER.info(" batch_size: %d", batch_size)
|
| 222 |
+
LOGGER.info(" chroma_width: %.2f", chroma_width)
|
| 223 |
+
LOGGER.info(" chroma_weight: %.1f", chroma_weight)
|
| 224 |
+
LOGGER.info(" code_weight: %.1f", code_weight)
|
| 225 |
+
LOGGER.info(" dropout: %.2f", dropout)
|
| 226 |
+
|
| 227 |
+
# Set device
|
| 228 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 229 |
+
|
| 230 |
+
# Load training data
|
| 231 |
+
data_file = PROJECT_ROOT / "data" / "training_data.npz"
|
| 232 |
+
|
| 233 |
+
if not data_file.exists():
|
| 234 |
+
LOGGER.error("Training data not found at %s", data_file)
|
| 235 |
+
LOGGER.error("Run generate_training_data.py first")
|
| 236 |
+
msg = f"Training data not found: {data_file}"
|
| 237 |
+
raise FileNotFoundError(msg)
|
| 238 |
+
|
| 239 |
+
data = np.load(data_file)
|
| 240 |
+
|
| 241 |
+
# Use pre-split data
|
| 242 |
+
X_train = data["X_train"]
|
| 243 |
+
y_train = data["y_train"]
|
| 244 |
+
X_val = data["X_val"]
|
| 245 |
+
y_val = data["y_val"]
|
| 246 |
+
|
| 247 |
+
LOGGER.info(
|
| 248 |
+
"Loaded %d training samples, %d validation samples", len(X_train), len(X_val)
|
| 249 |
+
)
|
| 250 |
+
|
| 251 |
+
# Normalize outputs (xyY inputs are already in [0, 1] range)
|
| 252 |
+
output_params = MUNSELL_NORMALIZATION_PARAMS
|
| 253 |
+
y_train = normalize_munsell(y_train, output_params)
|
| 254 |
+
y_val = normalize_munsell(y_val, output_params)
|
| 255 |
+
|
| 256 |
+
# Convert to PyTorch tensors
|
| 257 |
+
X_train_t = torch.FloatTensor(X_train)
|
| 258 |
+
y_train_t = torch.FloatTensor(y_train)
|
| 259 |
+
X_val_t = torch.FloatTensor(X_val)
|
| 260 |
+
y_val_t = torch.FloatTensor(y_val)
|
| 261 |
+
|
| 262 |
+
# Create data loaders
|
| 263 |
+
train_dataset = TensorDataset(X_train_t, y_train_t)
|
| 264 |
+
val_dataset = TensorDataset(X_val_t, y_val_t)
|
| 265 |
+
|
| 266 |
+
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
|
| 267 |
+
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
|
| 268 |
+
|
| 269 |
+
# Initialize model
|
| 270 |
+
model = MultiMLPToMunsell(
|
| 271 |
+
chroma_width_multiplier=chroma_width, dropout=dropout
|
| 272 |
+
).to(device)
|
| 273 |
+
|
| 274 |
+
# Count parameters
|
| 275 |
+
total_params = sum(p.numel() for p in model.parameters())
|
| 276 |
+
LOGGER.info("Total parameters: %s", f"{total_params:,}")
|
| 277 |
+
|
| 278 |
+
# Training setup
|
| 279 |
+
optimizer = optim.Adam(model.parameters(), lr=lr)
|
| 280 |
+
|
| 281 |
+
# MLflow setup
|
| 282 |
+
run_name = setup_mlflow_experiment(
|
| 283 |
+
"from_xyY", f"hparam_multi_mlp_trial_{trial.number}"
|
| 284 |
+
)
|
| 285 |
+
|
| 286 |
+
# Training loop with early stopping
|
| 287 |
+
num_epochs = 100 # Reduced for hyperparameter search
|
| 288 |
+
patience = 15
|
| 289 |
+
best_val_loss = float("inf")
|
| 290 |
+
patience_counter = 0
|
| 291 |
+
|
| 292 |
+
with mlflow.start_run(run_name=run_name):
|
| 293 |
+
mlflow.log_params(
|
| 294 |
+
{
|
| 295 |
+
"trial": trial.number,
|
| 296 |
+
"lr": lr,
|
| 297 |
+
"batch_size": batch_size,
|
| 298 |
+
"chroma_width": chroma_width,
|
| 299 |
+
"chroma_weight": chroma_weight,
|
| 300 |
+
"code_weight": code_weight,
|
| 301 |
+
"dropout": dropout,
|
| 302 |
+
"total_params": total_params,
|
| 303 |
+
}
|
| 304 |
+
)
|
| 305 |
+
|
| 306 |
+
for epoch in range(num_epochs):
|
| 307 |
+
train_loss = train_epoch(
|
| 308 |
+
model, train_loader, optimizer, device, chroma_weight, code_weight
|
| 309 |
+
)
|
| 310 |
+
val_loss = validate(model, val_loader, device, chroma_weight, code_weight)
|
| 311 |
+
|
| 312 |
+
# Log to MLflow
|
| 313 |
+
mlflow.log_metrics(
|
| 314 |
+
{
|
| 315 |
+
"train_loss": train_loss,
|
| 316 |
+
"val_loss": val_loss,
|
| 317 |
+
"learning_rate": lr,
|
| 318 |
+
},
|
| 319 |
+
step=epoch,
|
| 320 |
+
)
|
| 321 |
+
|
| 322 |
+
if (epoch + 1) % 10 == 0:
|
| 323 |
+
LOGGER.info(
|
| 324 |
+
" Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f",
|
| 325 |
+
epoch + 1,
|
| 326 |
+
num_epochs,
|
| 327 |
+
train_loss,
|
| 328 |
+
val_loss,
|
| 329 |
+
)
|
| 330 |
+
|
| 331 |
+
# Early stopping
|
| 332 |
+
if val_loss < best_val_loss:
|
| 333 |
+
best_val_loss = val_loss
|
| 334 |
+
patience_counter = 0
|
| 335 |
+
else:
|
| 336 |
+
patience_counter += 1
|
| 337 |
+
if patience_counter >= patience:
|
| 338 |
+
LOGGER.info(" Early stopping at epoch %d", epoch + 1)
|
| 339 |
+
break
|
| 340 |
+
|
| 341 |
+
# Report intermediate value for pruning
|
| 342 |
+
trial.report(val_loss, epoch)
|
| 343 |
+
|
| 344 |
+
# Handle pruning
|
| 345 |
+
if trial.should_prune():
|
| 346 |
+
LOGGER.info(" Trial pruned at epoch %d", epoch + 1)
|
| 347 |
+
mlflow.log_metrics({"pruned": 1, "pruned_epoch": epoch})
|
| 348 |
+
raise optuna.TrialPruned
|
| 349 |
+
|
| 350 |
+
# Log final results
|
| 351 |
+
mlflow.log_metrics(
|
| 352 |
+
{
|
| 353 |
+
"best_val_loss": best_val_loss,
|
| 354 |
+
"final_train_loss": train_loss,
|
| 355 |
+
"final_epoch": epoch + 1,
|
| 356 |
+
}
|
| 357 |
+
)
|
| 358 |
+
|
| 359 |
+
LOGGER.info(" Final validation loss: %.6f", best_val_loss)
|
| 360 |
+
|
| 361 |
+
return best_val_loss
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
def main() -> None:
|
| 365 |
+
"""
|
| 366 |
+
Run hyperparameter search for Multi-MLP model.
|
| 367 |
+
|
| 368 |
+
Performs systematic hyperparameter optimization using Optuna with:
|
| 369 |
+
- MedianPruner for early stopping of unpromising trials
|
| 370 |
+
- 15 total trials
|
| 371 |
+
- MLflow logging for each trial
|
| 372 |
+
- Result visualization using matplotlib (optimization history,
|
| 373 |
+
parameter importances, parallel coordinate plot)
|
| 374 |
+
|
| 375 |
+
The search aims to find optimal hyperparameters for converting xyY
|
| 376 |
+
color coordinates to Munsell color specifications using a multi-MLP
|
| 377 |
+
architecture with independent branches for each component.
|
| 378 |
+
"""
|
| 379 |
+
|
| 380 |
+
LOGGER.info("=" * 80)
|
| 381 |
+
LOGGER.info("Multi-MLP Hyperparameter Search with Optuna")
|
| 382 |
+
LOGGER.info("=" * 80)
|
| 383 |
+
|
| 384 |
+
# Create study
|
| 385 |
+
study = optuna.create_study(
|
| 386 |
+
direction="minimize",
|
| 387 |
+
study_name="multi_mlp_hparam_search",
|
| 388 |
+
pruner=optuna.pruners.MedianPruner(n_startup_trials=3, n_warmup_steps=10),
|
| 389 |
+
)
|
| 390 |
+
|
| 391 |
+
# Run optimization
|
| 392 |
+
n_trials = 15 # Number of trials to run
|
| 393 |
+
|
| 394 |
+
LOGGER.info("")
|
| 395 |
+
LOGGER.info("Starting hyperparameter search with %d trials...", n_trials)
|
| 396 |
+
LOGGER.info("")
|
| 397 |
+
|
| 398 |
+
study.optimize(objective, n_trials=n_trials, timeout=None)
|
| 399 |
+
|
| 400 |
+
# Print results
|
| 401 |
+
LOGGER.info("")
|
| 402 |
+
LOGGER.info("=" * 80)
|
| 403 |
+
LOGGER.info("Hyperparameter Search Results")
|
| 404 |
+
LOGGER.info("=" * 80)
|
| 405 |
+
LOGGER.info("")
|
| 406 |
+
LOGGER.info("Best trial:")
|
| 407 |
+
LOGGER.info(" Value (val_loss): %.6f", study.best_value)
|
| 408 |
+
LOGGER.info("")
|
| 409 |
+
LOGGER.info("Best hyperparameters:")
|
| 410 |
+
for key, value in study.best_params.items():
|
| 411 |
+
LOGGER.info(" %s: %s", key, value)
|
| 412 |
+
|
| 413 |
+
# Save results
|
| 414 |
+
results_dir = PROJECT_ROOT / "results" / "from_xyY"
|
| 415 |
+
results_dir.mkdir(exist_ok=True)
|
| 416 |
+
|
| 417 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 418 |
+
results_file = results_dir / f"hparam_search_{timestamp}.txt"
|
| 419 |
+
|
| 420 |
+
with open(results_file, "w") as f:
|
| 421 |
+
f.write("=" * 80 + "\n")
|
| 422 |
+
f.write("Multi-MLP Hyperparameter Search Results\n")
|
| 423 |
+
f.write("=" * 80 + "\n\n")
|
| 424 |
+
f.write(f"Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
|
| 425 |
+
f.write(f"Number of trials: {len(study.trials)}\n")
|
| 426 |
+
f.write(f"Best validation loss: {study.best_value:.6f}\n\n")
|
| 427 |
+
f.write("Best hyperparameters:\n")
|
| 428 |
+
for key, value in study.best_params.items():
|
| 429 |
+
f.write(f" {key}: {value}\n")
|
| 430 |
+
f.write("\n\nAll trials:\n")
|
| 431 |
+
f.write("-" * 80 + "\n")
|
| 432 |
+
|
| 433 |
+
for trial in study.trials:
|
| 434 |
+
f.write(f"\nTrial {trial.number}:\n")
|
| 435 |
+
f.write(f" Value: {trial.value:.6f if trial.value else 'Pruned'}\n")
|
| 436 |
+
f.write(" Params:\n")
|
| 437 |
+
for key, value in trial.params.items():
|
| 438 |
+
f.write(f" {key}: {value}\n")
|
| 439 |
+
|
| 440 |
+
LOGGER.info("")
|
| 441 |
+
LOGGER.info("Results saved to: %s", results_file)
|
| 442 |
+
|
| 443 |
+
# Generate visualizations using matplotlib
|
| 444 |
+
from optuna.visualization.matplotlib import (
|
| 445 |
+
plot_optimization_history,
|
| 446 |
+
plot_param_importances,
|
| 447 |
+
plot_parallel_coordinate,
|
| 448 |
+
)
|
| 449 |
+
|
| 450 |
+
# Optimization history
|
| 451 |
+
ax = plot_optimization_history(study)
|
| 452 |
+
ax.figure.savefig(results_dir / f"optimization_history_{timestamp}.png", dpi=150)
|
| 453 |
+
plt.close(ax.figure)
|
| 454 |
+
|
| 455 |
+
# Parameter importances
|
| 456 |
+
ax = plot_param_importances(study)
|
| 457 |
+
ax.figure.savefig(results_dir / f"param_importances_{timestamp}.png", dpi=150)
|
| 458 |
+
plt.close(ax.figure)
|
| 459 |
+
|
| 460 |
+
# Parallel coordinate plot
|
| 461 |
+
ax = plot_parallel_coordinate(study)
|
| 462 |
+
ax.figure.savefig(results_dir / f"parallel_coordinate_{timestamp}.png", dpi=150)
|
| 463 |
+
plt.close(ax.figure)
|
| 464 |
+
|
| 465 |
+
LOGGER.info("Visualizations saved to: %s", results_dir)
|
| 466 |
+
|
| 467 |
+
|
| 468 |
+
if __name__ == "__main__":
|
| 469 |
+
logging.basicConfig(level=logging.INFO, format="%(message)s", force=True)
|
| 470 |
+
|
| 471 |
+
main()
|
learning_munsell/training/from_xyY/refine_multi_head_real.py
ADDED
|
@@ -0,0 +1,358 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Refine Multi-Head model on REAL Munsell colors only.
|
| 3 |
+
|
| 4 |
+
This script fine-tunes the best Multi-Head model using only the 2734 real
|
| 5 |
+
(measured) Munsell colors, which should improve accuracy on the evaluation set.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import logging
|
| 9 |
+
from typing import Any
|
| 10 |
+
|
| 11 |
+
import click
|
| 12 |
+
import mlflow
|
| 13 |
+
import mlflow.pytorch
|
| 14 |
+
import numpy as np
|
| 15 |
+
import torch
|
| 16 |
+
from colour.notation.datasets.munsell import MUNSELL_COLOURS_REAL
|
| 17 |
+
from colour.notation.munsell import (
|
| 18 |
+
munsell_colour_to_munsell_specification,
|
| 19 |
+
munsell_specification_to_xyY,
|
| 20 |
+
)
|
| 21 |
+
from numpy.typing import NDArray
|
| 22 |
+
from sklearn.model_selection import train_test_split
|
| 23 |
+
from torch import nn, optim
|
| 24 |
+
from torch.utils.data import DataLoader, TensorDataset
|
| 25 |
+
|
| 26 |
+
from learning_munsell import PROJECT_ROOT
|
| 27 |
+
from learning_munsell.models.networks import MultiHeadMLPToMunsell
|
| 28 |
+
from learning_munsell.utilities.common import (
|
| 29 |
+
log_training_epoch,
|
| 30 |
+
setup_mlflow_experiment,
|
| 31 |
+
)
|
| 32 |
+
from learning_munsell.utilities.data import (
|
| 33 |
+
MUNSELL_NORMALIZATION_PARAMS,
|
| 34 |
+
XYY_NORMALIZATION_PARAMS,
|
| 35 |
+
normalize_munsell,
|
| 36 |
+
)
|
| 37 |
+
from learning_munsell.utilities.training import train_epoch, validate
|
| 38 |
+
|
| 39 |
+
LOGGER = logging.getLogger(__name__)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def generate_real_samples(
|
| 43 |
+
n_samples_per_color: int = 100,
|
| 44 |
+
perturbation_pct: float = 0.05,
|
| 45 |
+
) -> tuple[NDArray, NDArray]:
|
| 46 |
+
"""
|
| 47 |
+
Generate training samples from REAL (measured) Munsell colors only.
|
| 48 |
+
|
| 49 |
+
Creates augmented samples by applying small perturbations to the 2734 real
|
| 50 |
+
Munsell color specifications to increase training data while staying close
|
| 51 |
+
to measured values.
|
| 52 |
+
|
| 53 |
+
Parameters
|
| 54 |
+
----------
|
| 55 |
+
n_samples_per_color : int, optional
|
| 56 |
+
Number of perturbed samples to generate per real color (default is 100).
|
| 57 |
+
perturbation_pct : float, optional
|
| 58 |
+
Percentage of range to use for perturbations (default is 0.05 = 5%).
|
| 59 |
+
|
| 60 |
+
Returns
|
| 61 |
+
-------
|
| 62 |
+
xyY_samples : NDArray
|
| 63 |
+
Array of shape (n_samples, 3) containing xyY coordinates.
|
| 64 |
+
munsell_samples : NDArray
|
| 65 |
+
Array of shape (n_samples, 4) containing Munsell specifications
|
| 66 |
+
[hue, value, chroma, code].
|
| 67 |
+
|
| 68 |
+
Notes
|
| 69 |
+
-----
|
| 70 |
+
Perturbations are applied uniformly within ±perturbation_pct of the
|
| 71 |
+
component ranges:
|
| 72 |
+
- Hue range: 9.5 (0.5 to 10.0)
|
| 73 |
+
- Value range: 9.0 (1.0 to 10.0)
|
| 74 |
+
- Chroma range: 50.0 (0.0 to 50.0)
|
| 75 |
+
|
| 76 |
+
Invalid samples (that cannot be converted to xyY) are skipped.
|
| 77 |
+
"""
|
| 78 |
+
LOGGER.info(
|
| 79 |
+
"Generating samples from %d REAL Munsell colors...", len(MUNSELL_COLOURS_REAL)
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
np.random.seed(42)
|
| 83 |
+
|
| 84 |
+
hue_range = 9.5
|
| 85 |
+
value_range = 9.0
|
| 86 |
+
chroma_range = 50.0
|
| 87 |
+
|
| 88 |
+
xyY_samples = []
|
| 89 |
+
munsell_samples = []
|
| 90 |
+
|
| 91 |
+
for munsell_spec_tuple, _ in MUNSELL_COLOURS_REAL:
|
| 92 |
+
hue_code_str, value, chroma = munsell_spec_tuple
|
| 93 |
+
munsell_str = f"{hue_code_str} {value}/{chroma}"
|
| 94 |
+
base_spec = munsell_colour_to_munsell_specification(munsell_str)
|
| 95 |
+
|
| 96 |
+
for _ in range(n_samples_per_color):
|
| 97 |
+
hue_delta = np.random.uniform(
|
| 98 |
+
-perturbation_pct * hue_range, perturbation_pct * hue_range
|
| 99 |
+
)
|
| 100 |
+
value_delta = np.random.uniform(
|
| 101 |
+
-perturbation_pct * value_range, perturbation_pct * value_range
|
| 102 |
+
)
|
| 103 |
+
chroma_delta = np.random.uniform(
|
| 104 |
+
-perturbation_pct * chroma_range, perturbation_pct * chroma_range
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
perturbed_spec = base_spec.copy()
|
| 108 |
+
perturbed_spec[0] = np.clip(base_spec[0] + hue_delta, 0.5, 10.0)
|
| 109 |
+
perturbed_spec[1] = np.clip(base_spec[1] + value_delta, 1.0, 10.0)
|
| 110 |
+
perturbed_spec[2] = np.clip(base_spec[2] + chroma_delta, 0.0, 50.0)
|
| 111 |
+
|
| 112 |
+
try:
|
| 113 |
+
xyY = munsell_specification_to_xyY(perturbed_spec)
|
| 114 |
+
xyY_samples.append(xyY)
|
| 115 |
+
munsell_samples.append(perturbed_spec)
|
| 116 |
+
except Exception:
|
| 117 |
+
continue
|
| 118 |
+
|
| 119 |
+
LOGGER.info("Generated %d samples", len(xyY_samples))
|
| 120 |
+
return np.array(xyY_samples), np.array(munsell_samples)
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
@click.command()
|
| 124 |
+
@click.option("--epochs", default=200, help="Number of training epochs")
|
| 125 |
+
@click.option("--batch-size", default=512, help="Batch size for training")
|
| 126 |
+
@click.option("--lr", default=1e-5, help="Learning rate")
|
| 127 |
+
@click.option("--patience", default=30, help="Early stopping patience")
|
| 128 |
+
def main(epochs: int, batch_size: int, lr: float, patience: int) -> None:
|
| 129 |
+
"""
|
| 130 |
+
Refine Multi-Head model on REAL Munsell colors only.
|
| 131 |
+
|
| 132 |
+
Fine-tunes a pretrained Multi-Head MLP model using only the 2734 real
|
| 133 |
+
(measured) Munsell colors with small perturbations. This refinement step
|
| 134 |
+
aims to improve accuracy on actual measured colors by focusing the model
|
| 135 |
+
on the real color gamut.
|
| 136 |
+
|
| 137 |
+
Notes
|
| 138 |
+
-----
|
| 139 |
+
Training configuration:
|
| 140 |
+
- Dataset: 2734 real Munsell colors with 200 samples per color
|
| 141 |
+
- Perturbation: 3% of component ranges (smaller than initial training)
|
| 142 |
+
- Learning rate: 1e-5 (lower for fine-tuning)
|
| 143 |
+
- Batch size: 512
|
| 144 |
+
- Early stopping: patience of 30 epochs
|
| 145 |
+
- Optimizer: AdamW with weight decay 0.01
|
| 146 |
+
- Scheduler: ReduceLROnPlateau with factor 0.5, patience 15
|
| 147 |
+
|
| 148 |
+
Workflow:
|
| 149 |
+
1. Generate augmented samples from real Munsell colors
|
| 150 |
+
2. Load pretrained model (multi_head_large_best.pth)
|
| 151 |
+
3. Fine-tune with lower learning rate
|
| 152 |
+
4. Save best model based on validation loss
|
| 153 |
+
5. Export to ONNX format
|
| 154 |
+
6. Log metrics to MLflow
|
| 155 |
+
|
| 156 |
+
Files generated:
|
| 157 |
+
- multi_head_refined_real_best.pth: Best checkpoint
|
| 158 |
+
- multi_head_refined_real.onnx: ONNX model
|
| 159 |
+
- multi_head_refined_real_normalization_params.npz: Normalization params
|
| 160 |
+
"""
|
| 161 |
+
LOGGER.info("=" * 80)
|
| 162 |
+
LOGGER.info("Multi-Head Refinement on REAL Munsell Colors")
|
| 163 |
+
LOGGER.info("=" * 80)
|
| 164 |
+
|
| 165 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 166 |
+
if torch.backends.mps.is_available():
|
| 167 |
+
device = torch.device("mps")
|
| 168 |
+
LOGGER.info("Using device: %s", device)
|
| 169 |
+
|
| 170 |
+
# Generate REAL-only samples
|
| 171 |
+
LOGGER.info("")
|
| 172 |
+
xyY_all, munsell_all = generate_real_samples(
|
| 173 |
+
n_samples_per_color=200, # 200 samples per real color
|
| 174 |
+
perturbation_pct=0.03, # Smaller perturbations for refinement
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
# Split data
|
| 178 |
+
X_train, X_val, y_train, y_val = train_test_split(
|
| 179 |
+
xyY_all, munsell_all, test_size=0.15, random_state=42
|
| 180 |
+
)
|
| 181 |
+
|
| 182 |
+
LOGGER.info("Train samples: %d", len(X_train))
|
| 183 |
+
LOGGER.info("Validation samples: %d", len(X_val))
|
| 184 |
+
|
| 185 |
+
# Normalize outputs (xyY inputs are already in [0, 1] range)
|
| 186 |
+
# Use hardcoded ranges covering the full Munsell space for generalization
|
| 187 |
+
output_params = MUNSELL_NORMALIZATION_PARAMS
|
| 188 |
+
y_train_norm = normalize_munsell(y_train, output_params)
|
| 189 |
+
y_val_norm = normalize_munsell(y_val, output_params)
|
| 190 |
+
|
| 191 |
+
# Convert to tensors
|
| 192 |
+
X_train_t = torch.FloatTensor(X_train)
|
| 193 |
+
y_train_t = torch.FloatTensor(y_train_norm)
|
| 194 |
+
X_val_t = torch.FloatTensor(X_val)
|
| 195 |
+
y_val_t = torch.FloatTensor(y_val_norm)
|
| 196 |
+
|
| 197 |
+
# Data loaders
|
| 198 |
+
train_dataset = TensorDataset(X_train_t, y_train_t)
|
| 199 |
+
val_dataset = TensorDataset(X_val_t, y_val_t)
|
| 200 |
+
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
|
| 201 |
+
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
|
| 202 |
+
|
| 203 |
+
# Load pretrained model
|
| 204 |
+
model_directory = PROJECT_ROOT / "models" / "from_xyY"
|
| 205 |
+
pretrained_path = model_directory / "multi_head_large_best.pth"
|
| 206 |
+
|
| 207 |
+
model = MultiHeadMLPToMunsell().to(device)
|
| 208 |
+
|
| 209 |
+
if pretrained_path.exists():
|
| 210 |
+
LOGGER.info("")
|
| 211 |
+
LOGGER.info("Loading pretrained model from %s...", pretrained_path)
|
| 212 |
+
checkpoint = torch.load(
|
| 213 |
+
pretrained_path, weights_only=False, map_location=device
|
| 214 |
+
)
|
| 215 |
+
model.load_state_dict(checkpoint["model_state_dict"])
|
| 216 |
+
LOGGER.info("Pretrained model loaded successfully")
|
| 217 |
+
else:
|
| 218 |
+
LOGGER.info("")
|
| 219 |
+
LOGGER.info("No pretrained model found, training from scratch")
|
| 220 |
+
|
| 221 |
+
total_params = sum(p.numel() for p in model.parameters())
|
| 222 |
+
LOGGER.info("Total parameters: %s", f"{total_params:,}")
|
| 223 |
+
|
| 224 |
+
# Fine-tuning with lower learning rate
|
| 225 |
+
optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
|
| 226 |
+
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
|
| 227 |
+
optimizer, mode="min", factor=0.5, patience=15
|
| 228 |
+
)
|
| 229 |
+
criterion = nn.MSELoss()
|
| 230 |
+
|
| 231 |
+
# MLflow setup
|
| 232 |
+
run_name = setup_mlflow_experiment("from_xyY", "multi_head_refined_real")
|
| 233 |
+
|
| 234 |
+
LOGGER.info("")
|
| 235 |
+
LOGGER.info("MLflow run: %s", run_name)
|
| 236 |
+
LOGGER.info("Learning rate: %e (fine-tuning)", lr)
|
| 237 |
+
|
| 238 |
+
# Training loop
|
| 239 |
+
best_val_loss = float("inf")
|
| 240 |
+
patience_counter = 0
|
| 241 |
+
|
| 242 |
+
LOGGER.info("")
|
| 243 |
+
LOGGER.info("Starting refinement training...")
|
| 244 |
+
|
| 245 |
+
with mlflow.start_run(run_name=run_name):
|
| 246 |
+
mlflow.log_params(
|
| 247 |
+
{
|
| 248 |
+
"model": "multi_head_refined_real",
|
| 249 |
+
"learning_rate": lr,
|
| 250 |
+
"batch_size": batch_size,
|
| 251 |
+
"num_epochs": epochs,
|
| 252 |
+
"patience": patience,
|
| 253 |
+
"total_params": total_params,
|
| 254 |
+
"train_samples": len(X_train),
|
| 255 |
+
"val_samples": len(X_val),
|
| 256 |
+
"dataset": "REAL_only",
|
| 257 |
+
"perturbation_pct": 0.03,
|
| 258 |
+
"samples_per_color": 200,
|
| 259 |
+
}
|
| 260 |
+
)
|
| 261 |
+
|
| 262 |
+
for epoch in range(epochs):
|
| 263 |
+
train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
|
| 264 |
+
val_loss = validate(model, val_loader, criterion, device)
|
| 265 |
+
|
| 266 |
+
scheduler.step(val_loss)
|
| 267 |
+
|
| 268 |
+
log_training_epoch(
|
| 269 |
+
epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"]
|
| 270 |
+
)
|
| 271 |
+
|
| 272 |
+
LOGGER.info(
|
| 273 |
+
"Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f, LR: %.2e",
|
| 274 |
+
epoch + 1,
|
| 275 |
+
epochs,
|
| 276 |
+
train_loss,
|
| 277 |
+
val_loss,
|
| 278 |
+
optimizer.param_groups[0]["lr"],
|
| 279 |
+
)
|
| 280 |
+
|
| 281 |
+
if val_loss < best_val_loss:
|
| 282 |
+
best_val_loss = val_loss
|
| 283 |
+
patience_counter = 0
|
| 284 |
+
|
| 285 |
+
checkpoint_file = model_directory / "multi_head_refined_real_best.pth"
|
| 286 |
+
|
| 287 |
+
torch.save(
|
| 288 |
+
{
|
| 289 |
+
"model_state_dict": model.state_dict(),
|
| 290 |
+
"output_params": output_params,
|
| 291 |
+
"epoch": epoch,
|
| 292 |
+
"val_loss": val_loss,
|
| 293 |
+
},
|
| 294 |
+
checkpoint_file,
|
| 295 |
+
)
|
| 296 |
+
|
| 297 |
+
LOGGER.info(" -> Saved best model (val_loss: %.6f)", val_loss)
|
| 298 |
+
else:
|
| 299 |
+
patience_counter += 1
|
| 300 |
+
if patience_counter >= patience:
|
| 301 |
+
LOGGER.info("")
|
| 302 |
+
LOGGER.info("Early stopping after %d epochs", epoch + 1)
|
| 303 |
+
break
|
| 304 |
+
|
| 305 |
+
mlflow.log_metrics(
|
| 306 |
+
{
|
| 307 |
+
"best_val_loss": best_val_loss,
|
| 308 |
+
"final_epoch": epoch + 1,
|
| 309 |
+
}
|
| 310 |
+
)
|
| 311 |
+
|
| 312 |
+
# Export to ONNX
|
| 313 |
+
LOGGER.info("")
|
| 314 |
+
LOGGER.info("Exporting refined model to ONNX...")
|
| 315 |
+
model.eval()
|
| 316 |
+
|
| 317 |
+
checkpoint = torch.load(checkpoint_file, weights_only=False)
|
| 318 |
+
model.load_state_dict(checkpoint["model_state_dict"])
|
| 319 |
+
|
| 320 |
+
model_cpu = model.cpu()
|
| 321 |
+
dummy_input = torch.randn(1, 3)
|
| 322 |
+
|
| 323 |
+
onnx_file = model_directory / "multi_head_refined_real.onnx"
|
| 324 |
+
torch.onnx.export(
|
| 325 |
+
model_cpu,
|
| 326 |
+
dummy_input,
|
| 327 |
+
onnx_file,
|
| 328 |
+
export_params=True,
|
| 329 |
+
opset_version=14,
|
| 330 |
+
input_names=["xyY"],
|
| 331 |
+
output_names=["munsell_spec"],
|
| 332 |
+
dynamic_axes={"xyY": {0: "batch_size"}, "munsell_spec": {0: "batch_size"}},
|
| 333 |
+
)
|
| 334 |
+
|
| 335 |
+
params_file = (
|
| 336 |
+
model_directory / "multi_head_refined_real_normalization_params.npz"
|
| 337 |
+
)
|
| 338 |
+
input_params = XYY_NORMALIZATION_PARAMS
|
| 339 |
+
np.savez(
|
| 340 |
+
params_file,
|
| 341 |
+
input_params=input_params,
|
| 342 |
+
output_params=output_params,
|
| 343 |
+
)
|
| 344 |
+
|
| 345 |
+
mlflow.log_artifact(str(checkpoint_file))
|
| 346 |
+
mlflow.log_artifact(str(onnx_file))
|
| 347 |
+
mlflow.log_artifact(str(params_file))
|
| 348 |
+
|
| 349 |
+
LOGGER.info("ONNX model saved to: %s", onnx_file)
|
| 350 |
+
LOGGER.info("Normalization params saved to: %s", params_file)
|
| 351 |
+
|
| 352 |
+
LOGGER.info("=" * 80)
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
if __name__ == "__main__":
|
| 356 |
+
logging.basicConfig(level=logging.INFO, format="%(message)s", force=True)
|
| 357 |
+
|
| 358 |
+
main()
|
learning_munsell/training/from_xyY/train_deep_wide.py
ADDED
|
@@ -0,0 +1,371 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Train Deep + Wide model for xyY to Munsell conversion.
|
| 3 |
+
|
| 4 |
+
Option 5: Hybrid Deep + Wide architecture
|
| 5 |
+
- Input: 3 features (xyY)
|
| 6 |
+
- Deep path: 3 → 512 → 1024 (ResBlocks) → 512
|
| 7 |
+
- Wide path: 3 → 128 (direct linear)
|
| 8 |
+
- Combine: [512, 128] → 256 → 4
|
| 9 |
+
- Output: 4 features (hue, value, chroma, code)
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import logging
|
| 13 |
+
from typing import Any
|
| 14 |
+
|
| 15 |
+
import click
|
| 16 |
+
import mlflow
|
| 17 |
+
import mlflow.pytorch
|
| 18 |
+
import numpy as np
|
| 19 |
+
import torch
|
| 20 |
+
from numpy.typing import NDArray
|
| 21 |
+
from torch import nn, optim
|
| 22 |
+
from torch.utils.data import DataLoader, TensorDataset
|
| 23 |
+
|
| 24 |
+
from learning_munsell import PROJECT_ROOT
|
| 25 |
+
from learning_munsell.models.networks import ResidualBlock
|
| 26 |
+
from learning_munsell.utilities.common import log_training_epoch, setup_mlflow_experiment
|
| 27 |
+
from learning_munsell.utilities.data import (
|
| 28 |
+
MUNSELL_NORMALIZATION_PARAMS,
|
| 29 |
+
XYY_NORMALIZATION_PARAMS,
|
| 30 |
+
normalize_munsell,
|
| 31 |
+
)
|
| 32 |
+
from learning_munsell.utilities.losses import precision_focused_loss
|
| 33 |
+
from learning_munsell.utilities.training import train_epoch, validate
|
| 34 |
+
|
| 35 |
+
LOGGER = logging.getLogger(__name__)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class DeepWideNet(nn.Module):
|
| 39 |
+
"""
|
| 40 |
+
Deep + Wide Network for xyY to Munsell conversion.
|
| 41 |
+
|
| 42 |
+
Architecture:
|
| 43 |
+
- Deep path: Complex non-linear transformation
|
| 44 |
+
- Wide path: Direct linear connections
|
| 45 |
+
- Combines both for final prediction
|
| 46 |
+
|
| 47 |
+
Parameters
|
| 48 |
+
----------
|
| 49 |
+
num_residual_blocks : int, optional
|
| 50 |
+
Number of residual blocks in deep path. Default is 4.
|
| 51 |
+
|
| 52 |
+
Attributes
|
| 53 |
+
----------
|
| 54 |
+
deep_encoder : nn.Sequential
|
| 55 |
+
Deep path encoder: 3 → 512 → 1024.
|
| 56 |
+
deep_residual_blocks : nn.ModuleList
|
| 57 |
+
Stack of residual blocks in deep path.
|
| 58 |
+
deep_decoder : nn.Sequential
|
| 59 |
+
Deep path decoder: 1024 → 512.
|
| 60 |
+
wide_path : nn.Sequential
|
| 61 |
+
Wide path: 3 → 128.
|
| 62 |
+
output_head : nn.Sequential
|
| 63 |
+
Combined output: [512, 128] → 256 → 4.
|
| 64 |
+
|
| 65 |
+
Notes
|
| 66 |
+
-----
|
| 67 |
+
Hybrid architecture inspired by Google's Wide & Deep Learning:
|
| 68 |
+
- Deep path: 3 → 512 → 1024 → (ResBlocks) → 512
|
| 69 |
+
- Wide path: 3 → 128 (direct linear transformation)
|
| 70 |
+
- Combined: Concatenate [512, 128] → 256 → 4
|
| 71 |
+
|
| 72 |
+
The deep path learns complex non-linear transformations while the
|
| 73 |
+
wide path provides direct linear connections to preserve simple
|
| 74 |
+
relationships. Both paths are concatenated before the final output.
|
| 75 |
+
"""
|
| 76 |
+
|
| 77 |
+
def __init__(self, num_residual_blocks: int = 4) -> None:
|
| 78 |
+
"""Initialize the deep and wide network."""
|
| 79 |
+
super().__init__()
|
| 80 |
+
|
| 81 |
+
# Deep path: Complex transformation
|
| 82 |
+
self.deep_encoder = nn.Sequential(
|
| 83 |
+
nn.Linear(3, 512),
|
| 84 |
+
nn.GELU(),
|
| 85 |
+
nn.BatchNorm1d(512),
|
| 86 |
+
nn.Linear(512, 1024),
|
| 87 |
+
nn.GELU(),
|
| 88 |
+
nn.BatchNorm1d(1024),
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
self.deep_residual_blocks = nn.ModuleList(
|
| 92 |
+
[ResidualBlock(1024) for _ in range(num_residual_blocks)]
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
self.deep_decoder = nn.Sequential(
|
| 96 |
+
nn.Linear(1024, 512),
|
| 97 |
+
nn.GELU(),
|
| 98 |
+
nn.BatchNorm1d(512),
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
# Wide path: Direct linear transformation
|
| 102 |
+
self.wide_path = nn.Sequential(
|
| 103 |
+
nn.Linear(3, 128),
|
| 104 |
+
nn.GELU(),
|
| 105 |
+
nn.BatchNorm1d(128),
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
# Combined output: Concatenate deep (512) + wide (128) = 640
|
| 109 |
+
self.output_head = nn.Sequential(
|
| 110 |
+
nn.Linear(640, 256),
|
| 111 |
+
nn.GELU(),
|
| 112 |
+
nn.BatchNorm1d(256),
|
| 113 |
+
nn.Linear(256, 4),
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 117 |
+
"""
|
| 118 |
+
Forward pass through deep and wide paths.
|
| 119 |
+
|
| 120 |
+
Parameters
|
| 121 |
+
----------
|
| 122 |
+
x : Tensor
|
| 123 |
+
Input tensor of shape (batch_size, 3) containing normalized xyY values.
|
| 124 |
+
|
| 125 |
+
Returns
|
| 126 |
+
-------
|
| 127 |
+
Tensor
|
| 128 |
+
Output tensor of shape (batch_size, 4) containing normalized Munsell
|
| 129 |
+
specifications [hue, value, chroma, code].
|
| 130 |
+
|
| 131 |
+
Notes
|
| 132 |
+
-----
|
| 133 |
+
The forward pass processes input through two parallel paths:
|
| 134 |
+
1. Deep path: Complex transformation through encoder, residual blocks,
|
| 135 |
+
and decoder (3 → 512 → 1024 → 512)
|
| 136 |
+
2. Wide path: Direct linear transformation (3 → 128)
|
| 137 |
+
3. Concatenation: Combine deep (512) + wide (128) = 640 features
|
| 138 |
+
4. Output head: Final transformation to 4 components (640 → 256 → 4)
|
| 139 |
+
"""
|
| 140 |
+
# Deep path
|
| 141 |
+
deep = self.deep_encoder(x)
|
| 142 |
+
for block in self.deep_residual_blocks:
|
| 143 |
+
deep = block(deep)
|
| 144 |
+
deep = self.deep_decoder(deep)
|
| 145 |
+
|
| 146 |
+
# Wide path
|
| 147 |
+
wide = self.wide_path(x)
|
| 148 |
+
|
| 149 |
+
# Concatenate and output
|
| 150 |
+
combined = torch.cat([deep, wide], dim=1)
|
| 151 |
+
return self.output_head(combined)
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
@click.command()
|
| 155 |
+
@click.option("--epochs", default=200, help="Number of training epochs")
|
| 156 |
+
@click.option("--batch-size", default=1024, help="Batch size for training")
|
| 157 |
+
@click.option("--lr", default=3e-4, help="Learning rate")
|
| 158 |
+
@click.option("--patience", default=20, help="Early stopping patience")
|
| 159 |
+
def main(epochs: int, batch_size: int, lr: float, patience: int) -> None:
|
| 160 |
+
"""
|
| 161 |
+
Train the DeepWideNet model for xyY to Munsell conversion.
|
| 162 |
+
|
| 163 |
+
Notes
|
| 164 |
+
-----
|
| 165 |
+
The training pipeline:
|
| 166 |
+
1. Loads normalization parameters from existing config
|
| 167 |
+
2. Loads training data from cache
|
| 168 |
+
3. Normalizes inputs and outputs to [0, 1] range
|
| 169 |
+
4. Creates PyTorch DataLoaders
|
| 170 |
+
5. Initializes DeepWideNet with deep and wide paths
|
| 171 |
+
6. Trains with AdamW optimizer and precision-focused loss
|
| 172 |
+
7. Uses learning rate scheduler (ReduceLROnPlateau)
|
| 173 |
+
8. Implements early stopping based on validation loss
|
| 174 |
+
9. Exports best model to ONNX format
|
| 175 |
+
10. Logs all metrics and artifacts to MLflow
|
| 176 |
+
"""
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
LOGGER.info("=" * 80)
|
| 180 |
+
LOGGER.info("Deep + Wide Network: xyY → Munsell")
|
| 181 |
+
LOGGER.info("=" * 80)
|
| 182 |
+
|
| 183 |
+
# Set device
|
| 184 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 185 |
+
LOGGER.info("Using device: %s", device)
|
| 186 |
+
|
| 187 |
+
# Paths
|
| 188 |
+
model_directory = PROJECT_ROOT / "models" / "from_xyY"
|
| 189 |
+
data_dir = PROJECT_ROOT / "data"
|
| 190 |
+
cache_file = data_dir / "training_data.npz"
|
| 191 |
+
|
| 192 |
+
# Load training data
|
| 193 |
+
LOGGER.info("")
|
| 194 |
+
LOGGER.info("Loading training data from %s...", cache_file)
|
| 195 |
+
data = np.load(cache_file)
|
| 196 |
+
X_train = data["X_train"]
|
| 197 |
+
y_train = data["y_train"]
|
| 198 |
+
X_val = data["X_val"]
|
| 199 |
+
y_val = data["y_val"]
|
| 200 |
+
|
| 201 |
+
LOGGER.info("Train samples: %d", len(X_train))
|
| 202 |
+
LOGGER.info("Validation samples: %d", len(X_val))
|
| 203 |
+
|
| 204 |
+
# Normalize outputs (xyY inputs are already in [0, 1] range)
|
| 205 |
+
# Use hardcoded ranges covering the full Munsell space for generalization
|
| 206 |
+
output_params = MUNSELL_NORMALIZATION_PARAMS
|
| 207 |
+
y_train_norm = normalize_munsell(y_train, output_params)
|
| 208 |
+
y_val_norm = normalize_munsell(y_val, output_params)
|
| 209 |
+
|
| 210 |
+
# Convert to PyTorch tensors
|
| 211 |
+
X_train_t = torch.FloatTensor(X_train)
|
| 212 |
+
y_train_t = torch.FloatTensor(y_train_norm)
|
| 213 |
+
X_val_t = torch.FloatTensor(X_val)
|
| 214 |
+
y_val_t = torch.FloatTensor(y_val_norm)
|
| 215 |
+
|
| 216 |
+
# Create data loaders
|
| 217 |
+
train_dataset = TensorDataset(X_train_t, y_train_t)
|
| 218 |
+
val_dataset = TensorDataset(X_val_t, y_val_t)
|
| 219 |
+
|
| 220 |
+
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
|
| 221 |
+
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
|
| 222 |
+
|
| 223 |
+
# Initialize model
|
| 224 |
+
model = DeepWideNet(num_residual_blocks=4).to(device)
|
| 225 |
+
LOGGER.info("")
|
| 226 |
+
LOGGER.info("Deep + Wide architecture:")
|
| 227 |
+
LOGGER.info("%s", model)
|
| 228 |
+
|
| 229 |
+
# Count parameters
|
| 230 |
+
total_params = sum(p.numel() for p in model.parameters())
|
| 231 |
+
LOGGER.info("Total parameters: %s", f"{total_params:,}")
|
| 232 |
+
|
| 233 |
+
# Training setup
|
| 234 |
+
learning_rate = lr
|
| 235 |
+
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-5)
|
| 236 |
+
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
|
| 237 |
+
optimizer, mode="min", factor=0.5, patience=5
|
| 238 |
+
)
|
| 239 |
+
criterion = precision_focused_loss
|
| 240 |
+
|
| 241 |
+
# MLflow setup
|
| 242 |
+
run_name = setup_mlflow_experiment("from_xyY", "deep_wide")
|
| 243 |
+
|
| 244 |
+
LOGGER.info("")
|
| 245 |
+
LOGGER.info("MLflow run: %s", run_name)
|
| 246 |
+
|
| 247 |
+
# Training loop
|
| 248 |
+
best_val_loss = float("inf")
|
| 249 |
+
patience_counter = 0
|
| 250 |
+
|
| 251 |
+
LOGGER.info("")
|
| 252 |
+
LOGGER.info("Starting training...")
|
| 253 |
+
|
| 254 |
+
with mlflow.start_run(run_name=run_name):
|
| 255 |
+
# Log parameters
|
| 256 |
+
mlflow.log_params(
|
| 257 |
+
{
|
| 258 |
+
"model": "deep_wide",
|
| 259 |
+
"learning_rate": learning_rate,
|
| 260 |
+
"batch_size": batch_size,
|
| 261 |
+
"num_epochs": epochs,
|
| 262 |
+
"patience": patience,
|
| 263 |
+
"total_params": total_params,
|
| 264 |
+
}
|
| 265 |
+
)
|
| 266 |
+
|
| 267 |
+
for epoch in range(epochs):
|
| 268 |
+
train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
|
| 269 |
+
val_loss = validate(model, val_loader, criterion, device)
|
| 270 |
+
|
| 271 |
+
scheduler.step(val_loss)
|
| 272 |
+
|
| 273 |
+
# Log to MLflow
|
| 274 |
+
log_training_epoch(
|
| 275 |
+
epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"]
|
| 276 |
+
)
|
| 277 |
+
|
| 278 |
+
LOGGER.info(
|
| 279 |
+
"Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f, LR: %.6f",
|
| 280 |
+
epoch + 1,
|
| 281 |
+
epochs,
|
| 282 |
+
train_loss,
|
| 283 |
+
val_loss,
|
| 284 |
+
optimizer.param_groups[0]["lr"],
|
| 285 |
+
)
|
| 286 |
+
|
| 287 |
+
# Early stopping
|
| 288 |
+
if val_loss < best_val_loss:
|
| 289 |
+
best_val_loss = val_loss
|
| 290 |
+
patience_counter = 0
|
| 291 |
+
|
| 292 |
+
model_directory.mkdir(exist_ok=True)
|
| 293 |
+
checkpoint_file = model_directory / "deep_wide_best.pth"
|
| 294 |
+
|
| 295 |
+
torch.save(
|
| 296 |
+
{
|
| 297 |
+
"model_state_dict": model.state_dict(),
|
| 298 |
+
"epoch": epoch,
|
| 299 |
+
"val_loss": val_loss,
|
| 300 |
+
},
|
| 301 |
+
checkpoint_file,
|
| 302 |
+
)
|
| 303 |
+
|
| 304 |
+
LOGGER.info(" → Saved best model (val_loss: %.6f)", val_loss)
|
| 305 |
+
else:
|
| 306 |
+
patience_counter += 1
|
| 307 |
+
if patience_counter >= patience:
|
| 308 |
+
LOGGER.info("")
|
| 309 |
+
LOGGER.info("Early stopping after %d epochs", epoch + 1)
|
| 310 |
+
break
|
| 311 |
+
|
| 312 |
+
# Log final metrics
|
| 313 |
+
mlflow.log_metrics(
|
| 314 |
+
{
|
| 315 |
+
"best_val_loss": best_val_loss,
|
| 316 |
+
"final_epoch": epoch + 1,
|
| 317 |
+
}
|
| 318 |
+
)
|
| 319 |
+
|
| 320 |
+
# Export to ONNX
|
| 321 |
+
LOGGER.info("")
|
| 322 |
+
LOGGER.info("Exporting to ONNX...")
|
| 323 |
+
model.eval()
|
| 324 |
+
|
| 325 |
+
checkpoint = torch.load(checkpoint_file)
|
| 326 |
+
model.load_state_dict(checkpoint["model_state_dict"])
|
| 327 |
+
|
| 328 |
+
dummy_input = torch.randn(1, 3).to(device)
|
| 329 |
+
|
| 330 |
+
onnx_file = model_directory / "deep_wide.onnx"
|
| 331 |
+
torch.onnx.export(
|
| 332 |
+
model,
|
| 333 |
+
dummy_input,
|
| 334 |
+
onnx_file,
|
| 335 |
+
export_params=True,
|
| 336 |
+
opset_version=15,
|
| 337 |
+
input_names=["xyY"],
|
| 338 |
+
output_names=["munsell_spec"],
|
| 339 |
+
dynamic_axes={
|
| 340 |
+
"xyY": {0: "batch_size"},
|
| 341 |
+
"munsell_spec": {0: "batch_size"},
|
| 342 |
+
},
|
| 343 |
+
)
|
| 344 |
+
|
| 345 |
+
# Save normalization parameters alongside model
|
| 346 |
+
params_file = model_directory / "deep_wide_normalization_params.npz"
|
| 347 |
+
input_params = XYY_NORMALIZATION_PARAMS
|
| 348 |
+
np.savez(
|
| 349 |
+
params_file,
|
| 350 |
+
input_params=input_params,
|
| 351 |
+
output_params=output_params,
|
| 352 |
+
)
|
| 353 |
+
|
| 354 |
+
# Log artifacts to MLflow
|
| 355 |
+
mlflow.log_artifact(str(checkpoint_file))
|
| 356 |
+
mlflow.log_artifact(str(onnx_file))
|
| 357 |
+
mlflow.log_artifact(str(params_file))
|
| 358 |
+
mlflow.pytorch.log_model(model, "model")
|
| 359 |
+
|
| 360 |
+
LOGGER.info("ONNX model saved to: %s", onnx_file)
|
| 361 |
+
LOGGER.info("Normalization parameters saved to: %s", params_file)
|
| 362 |
+
LOGGER.info("Artifacts logged to MLflow")
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
LOGGER.info("=" * 80)
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
if __name__ == "__main__":
|
| 369 |
+
logging.basicConfig(level=logging.INFO, format="%(message)s", force=True)
|
| 370 |
+
|
| 371 |
+
main()
|
learning_munsell/training/from_xyY/train_ft_transformer.py
ADDED
|
@@ -0,0 +1,356 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Train FT-Transformer model for xyY to Munsell conversion.
|
| 3 |
+
|
| 4 |
+
Option 4: Feature Tokenizer + Transformer architecture
|
| 5 |
+
- Input: 3 features (xyY) → each becomes a 256-dim token
|
| 6 |
+
- Add [CLS] token for regression
|
| 7 |
+
- 4-6 transformer blocks with multi-head attention
|
| 8 |
+
- Output: Take [CLS] token → MLP → 4 features
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import logging
|
| 12 |
+
import click
|
| 13 |
+
from typing import Any
|
| 14 |
+
|
| 15 |
+
import mlflow
|
| 16 |
+
import mlflow.pytorch
|
| 17 |
+
import numpy as np
|
| 18 |
+
import torch
|
| 19 |
+
from numpy.typing import NDArray
|
| 20 |
+
from torch import nn, optim
|
| 21 |
+
from torch.utils.data import DataLoader, TensorDataset
|
| 22 |
+
|
| 23 |
+
from learning_munsell import PROJECT_ROOT
|
| 24 |
+
from learning_munsell.models.networks import FeatureTokenizer, TransformerBlock
|
| 25 |
+
from learning_munsell.utilities.common import log_training_epoch, setup_mlflow_experiment
|
| 26 |
+
from learning_munsell.utilities.data import (
|
| 27 |
+
MUNSELL_NORMALIZATION_PARAMS,
|
| 28 |
+
XYY_NORMALIZATION_PARAMS,
|
| 29 |
+
normalize_munsell,
|
| 30 |
+
)
|
| 31 |
+
from learning_munsell.utilities.losses import precision_focused_loss
|
| 32 |
+
from learning_munsell.utilities.training import train_epoch, validate
|
| 33 |
+
|
| 34 |
+
LOGGER = logging.getLogger(__name__)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class FTTransformer(nn.Module):
|
| 38 |
+
"""
|
| 39 |
+
Feature Tokenizer + Transformer for xyY to Munsell conversion.
|
| 40 |
+
|
| 41 |
+
This model adapts transformer architecture for tabular data by tokenizing
|
| 42 |
+
each input feature separately and using self-attention to capture complex
|
| 43 |
+
feature interactions.
|
| 44 |
+
|
| 45 |
+
Architecture
|
| 46 |
+
------------
|
| 47 |
+
- Tokenize each feature (3 features → 3 tokens)
|
| 48 |
+
- Add CLS token (4 tokens total)
|
| 49 |
+
- 4 transformer blocks with multi-head attention
|
| 50 |
+
- Extract CLS token → MLP head → 4 outputs
|
| 51 |
+
|
| 52 |
+
Parameters
|
| 53 |
+
----------
|
| 54 |
+
num_features : int, optional
|
| 55 |
+
Number of input features (xyY), default is 3.
|
| 56 |
+
embedding_dim : int, optional
|
| 57 |
+
Dimension of token embeddings, default is 256.
|
| 58 |
+
num_blocks : int, optional
|
| 59 |
+
Number of transformer blocks, default is 4.
|
| 60 |
+
num_heads : int, optional
|
| 61 |
+
Number of attention heads, default is 4.
|
| 62 |
+
ff_dim : int, optional
|
| 63 |
+
Feedforward network hidden dimension, default is 512.
|
| 64 |
+
dropout : float, optional
|
| 65 |
+
Dropout probability, default is 0.1.
|
| 66 |
+
|
| 67 |
+
Attributes
|
| 68 |
+
----------
|
| 69 |
+
tokenizer : FeatureTokenizer
|
| 70 |
+
Converts input features to token embeddings.
|
| 71 |
+
transformer_blocks : nn.ModuleList
|
| 72 |
+
Stack of transformer blocks.
|
| 73 |
+
output_head : nn.Sequential
|
| 74 |
+
MLP that maps CLS token to output predictions.
|
| 75 |
+
"""
|
| 76 |
+
|
| 77 |
+
def __init__(
|
| 78 |
+
self,
|
| 79 |
+
num_features: int = 3,
|
| 80 |
+
embedding_dim: int = 256,
|
| 81 |
+
num_blocks: int = 4,
|
| 82 |
+
num_heads: int = 4,
|
| 83 |
+
ff_dim: int = 512,
|
| 84 |
+
dropout: float = 0.1,
|
| 85 |
+
) -> None:
|
| 86 |
+
"""Initialize the FT-Transformer model."""
|
| 87 |
+
super().__init__()
|
| 88 |
+
|
| 89 |
+
# Feature tokenizer
|
| 90 |
+
self.tokenizer = FeatureTokenizer(num_features, embedding_dim)
|
| 91 |
+
|
| 92 |
+
# Transformer blocks
|
| 93 |
+
self.transformer_blocks = nn.ModuleList(
|
| 94 |
+
[
|
| 95 |
+
TransformerBlock(embedding_dim, num_heads, ff_dim, dropout)
|
| 96 |
+
for _ in range(num_blocks)
|
| 97 |
+
]
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
# Output head (from CLS token)
|
| 101 |
+
self.output_head = nn.Sequential(
|
| 102 |
+
nn.Linear(embedding_dim, 128),
|
| 103 |
+
nn.GELU(),
|
| 104 |
+
nn.Dropout(dropout),
|
| 105 |
+
nn.Linear(128, 4),
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 109 |
+
"""
|
| 110 |
+
Forward pass through FT-Transformer.
|
| 111 |
+
|
| 112 |
+
Parameters
|
| 113 |
+
----------
|
| 114 |
+
x : Tensor
|
| 115 |
+
Input xyY values of shape (batch_size, 3).
|
| 116 |
+
|
| 117 |
+
Returns
|
| 118 |
+
-------
|
| 119 |
+
Tensor
|
| 120 |
+
Predicted Munsell specification [hue, value, chroma, code]
|
| 121 |
+
of shape (batch_size, 4).
|
| 122 |
+
"""
|
| 123 |
+
# Tokenize features
|
| 124 |
+
tokens = self.tokenizer(x) # (batch_size, 1+num_features, embedding_dim)
|
| 125 |
+
|
| 126 |
+
# Transformer blocks
|
| 127 |
+
for block in self.transformer_blocks:
|
| 128 |
+
tokens = block(tokens)
|
| 129 |
+
|
| 130 |
+
# Extract CLS token (first token)
|
| 131 |
+
cls_token = tokens[:, 0, :] # (batch_size, embedding_dim)
|
| 132 |
+
|
| 133 |
+
# Output head
|
| 134 |
+
return self.output_head(cls_token)
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
@click.command()
|
| 138 |
+
@click.option("--epochs", default=200, help="Number of training epochs")
|
| 139 |
+
@click.option("--batch-size", default=1024, help="Batch size for training")
|
| 140 |
+
@click.option("--lr", default=3e-4, help="Learning rate")
|
| 141 |
+
@click.option("--patience", default=20, help="Early stopping patience")
|
| 142 |
+
def main(epochs: int, batch_size: int, lr: float, patience: int) -> None:
|
| 143 |
+
"""
|
| 144 |
+
Train FT-Transformer model for xyY to Munsell conversion.
|
| 145 |
+
|
| 146 |
+
Notes
|
| 147 |
+
-----
|
| 148 |
+
The training pipeline:
|
| 149 |
+
1. Loads normalization parameters from existing config
|
| 150 |
+
2. Loads training data from cache
|
| 151 |
+
3. Normalizes inputs and outputs to [0, 1] range
|
| 152 |
+
4. Creates PyTorch DataLoaders
|
| 153 |
+
5. Initializes FT-Transformer with feature tokenization
|
| 154 |
+
6. Trains with AdamW optimizer and precision-focused loss
|
| 155 |
+
7. Uses learning rate scheduler (ReduceLROnPlateau)
|
| 156 |
+
8. Implements early stopping based on validation loss
|
| 157 |
+
9. Exports best model to ONNX format
|
| 158 |
+
10. Logs all metrics and artifacts to MLflow
|
| 159 |
+
"""
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
LOGGER.info("=" * 80)
|
| 163 |
+
LOGGER.info("FT-Transformer: xyY → Munsell")
|
| 164 |
+
LOGGER.info("=" * 80)
|
| 165 |
+
|
| 166 |
+
# Set device
|
| 167 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 168 |
+
LOGGER.info("Using device: %s", device)
|
| 169 |
+
|
| 170 |
+
# Paths
|
| 171 |
+
model_directory = PROJECT_ROOT / "models" / "from_xyY"
|
| 172 |
+
data_dir = PROJECT_ROOT / "data"
|
| 173 |
+
cache_file = data_dir / "training_data.npz"
|
| 174 |
+
|
| 175 |
+
# Load training data
|
| 176 |
+
LOGGER.info("")
|
| 177 |
+
LOGGER.info("Loading training data from %s...", cache_file)
|
| 178 |
+
data = np.load(cache_file)
|
| 179 |
+
X_train = data["X_train"]
|
| 180 |
+
y_train = data["y_train"]
|
| 181 |
+
X_val = data["X_val"]
|
| 182 |
+
y_val = data["y_val"]
|
| 183 |
+
|
| 184 |
+
LOGGER.info("Train samples: %d", len(X_train))
|
| 185 |
+
LOGGER.info("Validation samples: %d", len(X_val))
|
| 186 |
+
|
| 187 |
+
# Normalize outputs (xyY inputs are already in [0, 1] range)
|
| 188 |
+
output_params = MUNSELL_NORMALIZATION_PARAMS
|
| 189 |
+
y_train_norm = normalize_munsell(y_train, output_params)
|
| 190 |
+
y_val_norm = normalize_munsell(y_val, output_params)
|
| 191 |
+
|
| 192 |
+
# Convert to PyTorch tensors
|
| 193 |
+
X_train_t = torch.FloatTensor(X_train)
|
| 194 |
+
y_train_t = torch.FloatTensor(y_train_norm)
|
| 195 |
+
X_val_t = torch.FloatTensor(X_val)
|
| 196 |
+
y_val_t = torch.FloatTensor(y_val_norm)
|
| 197 |
+
|
| 198 |
+
# Create data loaders
|
| 199 |
+
train_dataset = TensorDataset(X_train_t, y_train_t)
|
| 200 |
+
val_dataset = TensorDataset(X_val_t, y_val_t)
|
| 201 |
+
|
| 202 |
+
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
|
| 203 |
+
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
|
| 204 |
+
|
| 205 |
+
# Initialize model
|
| 206 |
+
model = FTTransformer(
|
| 207 |
+
num_features=3,
|
| 208 |
+
embedding_dim=256,
|
| 209 |
+
num_blocks=4,
|
| 210 |
+
num_heads=4,
|
| 211 |
+
ff_dim=512,
|
| 212 |
+
dropout=0.1,
|
| 213 |
+
).to(device)
|
| 214 |
+
|
| 215 |
+
LOGGER.info("")
|
| 216 |
+
LOGGER.info("FT-Transformer architecture:")
|
| 217 |
+
LOGGER.info("%s", model)
|
| 218 |
+
|
| 219 |
+
# Count parameters
|
| 220 |
+
total_params = sum(p.numel() for p in model.parameters())
|
| 221 |
+
LOGGER.info("Total parameters: %s", f"{total_params:,}")
|
| 222 |
+
|
| 223 |
+
# Training setup
|
| 224 |
+
optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
|
| 225 |
+
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
|
| 226 |
+
optimizer, mode="min", factor=0.5, patience=5
|
| 227 |
+
)
|
| 228 |
+
criterion = precision_focused_loss
|
| 229 |
+
|
| 230 |
+
# MLflow setup
|
| 231 |
+
run_name = setup_mlflow_experiment("from_xyY", "ft_transformer")
|
| 232 |
+
|
| 233 |
+
LOGGER.info("")
|
| 234 |
+
LOGGER.info("MLflow run: %s", run_name)
|
| 235 |
+
|
| 236 |
+
# Training loop
|
| 237 |
+
best_val_loss = float("inf")
|
| 238 |
+
patience_counter = 0
|
| 239 |
+
|
| 240 |
+
LOGGER.info("")
|
| 241 |
+
LOGGER.info("Starting training...")
|
| 242 |
+
|
| 243 |
+
with mlflow.start_run(run_name=run_name):
|
| 244 |
+
mlflow.log_params(
|
| 245 |
+
{
|
| 246 |
+
"model": "ft_transformer",
|
| 247 |
+
"learning_rate": lr,
|
| 248 |
+
"batch_size": batch_size,
|
| 249 |
+
"num_epochs": epochs,
|
| 250 |
+
"patience": patience,
|
| 251 |
+
"total_params": total_params,
|
| 252 |
+
}
|
| 253 |
+
)
|
| 254 |
+
|
| 255 |
+
for epoch in range(epochs):
|
| 256 |
+
train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
|
| 257 |
+
val_loss = validate(model, val_loader, criterion, device)
|
| 258 |
+
|
| 259 |
+
scheduler.step(val_loss)
|
| 260 |
+
|
| 261 |
+
log_training_epoch(
|
| 262 |
+
epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"]
|
| 263 |
+
)
|
| 264 |
+
|
| 265 |
+
LOGGER.info(
|
| 266 |
+
"Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f, LR: %.6f",
|
| 267 |
+
epoch + 1,
|
| 268 |
+
epochs,
|
| 269 |
+
train_loss,
|
| 270 |
+
val_loss,
|
| 271 |
+
optimizer.param_groups[0]["lr"],
|
| 272 |
+
)
|
| 273 |
+
|
| 274 |
+
# Early stopping
|
| 275 |
+
if val_loss < best_val_loss:
|
| 276 |
+
best_val_loss = val_loss
|
| 277 |
+
patience_counter = 0
|
| 278 |
+
|
| 279 |
+
model_directory.mkdir(exist_ok=True)
|
| 280 |
+
checkpoint_file = model_directory / "ft_transformer_best.pth"
|
| 281 |
+
|
| 282 |
+
torch.save(
|
| 283 |
+
{
|
| 284 |
+
"model_state_dict": model.state_dict(),
|
| 285 |
+
"epoch": epoch,
|
| 286 |
+
"val_loss": val_loss,
|
| 287 |
+
},
|
| 288 |
+
checkpoint_file,
|
| 289 |
+
)
|
| 290 |
+
|
| 291 |
+
LOGGER.info(" → Saved best model (val_loss: %.6f)", val_loss)
|
| 292 |
+
else:
|
| 293 |
+
patience_counter += 1
|
| 294 |
+
if patience_counter >= patience:
|
| 295 |
+
LOGGER.info("")
|
| 296 |
+
LOGGER.info("Early stopping after %d epochs", epoch + 1)
|
| 297 |
+
break
|
| 298 |
+
|
| 299 |
+
mlflow.log_metrics(
|
| 300 |
+
{
|
| 301 |
+
"best_val_loss": best_val_loss,
|
| 302 |
+
"final_epoch": epoch + 1,
|
| 303 |
+
}
|
| 304 |
+
)
|
| 305 |
+
|
| 306 |
+
# Export to ONNX
|
| 307 |
+
LOGGER.info("")
|
| 308 |
+
LOGGER.info("Exporting to ONNX...")
|
| 309 |
+
model.eval()
|
| 310 |
+
|
| 311 |
+
checkpoint = torch.load(checkpoint_file)
|
| 312 |
+
model.load_state_dict(checkpoint["model_state_dict"])
|
| 313 |
+
|
| 314 |
+
dummy_input = torch.randn(1, 3).to(device)
|
| 315 |
+
|
| 316 |
+
onnx_file = model_directory / "ft_transformer.onnx"
|
| 317 |
+
torch.onnx.export(
|
| 318 |
+
model,
|
| 319 |
+
dummy_input,
|
| 320 |
+
onnx_file,
|
| 321 |
+
export_params=True,
|
| 322 |
+
opset_version=15,
|
| 323 |
+
input_names=["xyY"],
|
| 324 |
+
output_names=["munsell_spec"],
|
| 325 |
+
dynamic_axes={
|
| 326 |
+
"xyY": {0: "batch_size"},
|
| 327 |
+
"munsell_spec": {0: "batch_size"},
|
| 328 |
+
},
|
| 329 |
+
)
|
| 330 |
+
|
| 331 |
+
# Save normalization parameters alongside model
|
| 332 |
+
params_file = model_directory / "ft_transformer_normalization_params.npz"
|
| 333 |
+
input_params = XYY_NORMALIZATION_PARAMS
|
| 334 |
+
np.savez(
|
| 335 |
+
params_file,
|
| 336 |
+
input_params=input_params,
|
| 337 |
+
output_params=output_params,
|
| 338 |
+
)
|
| 339 |
+
|
| 340 |
+
mlflow.log_artifact(str(checkpoint_file))
|
| 341 |
+
mlflow.log_artifact(str(onnx_file))
|
| 342 |
+
mlflow.log_artifact(str(params_file))
|
| 343 |
+
mlflow.pytorch.log_model(model, "model")
|
| 344 |
+
|
| 345 |
+
LOGGER.info("ONNX model saved to: %s", onnx_file)
|
| 346 |
+
LOGGER.info("Normalization parameters saved to: %s", params_file)
|
| 347 |
+
LOGGER.info("Artifacts logged to MLflow")
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
LOGGER.info("=" * 80)
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
if __name__ == "__main__":
|
| 354 |
+
logging.basicConfig(level=logging.INFO, format="%(message)s", force=True)
|
| 355 |
+
|
| 356 |
+
main()
|
learning_munsell/training/from_xyY/train_mixture_of_experts.py
ADDED
|
@@ -0,0 +1,620 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Train Mixture of Experts model for xyY to Munsell conversion.
|
| 3 |
+
|
| 4 |
+
Option 6: Mixture of Experts architecture
|
| 5 |
+
- Input: 3 features (xyY)
|
| 6 |
+
- Gating network: 3 → 128 → 64 → 4 (softmax weights)
|
| 7 |
+
- 4 Expert networks: Each 3 → 256 → 256 → 4 (MLP)
|
| 8 |
+
- Output: Weighted combination of expert outputs
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import logging
|
| 12 |
+
import click
|
| 13 |
+
|
| 14 |
+
import mlflow
|
| 15 |
+
import mlflow.pytorch
|
| 16 |
+
import numpy as np
|
| 17 |
+
import torch
|
| 18 |
+
from numpy.typing import NDArray
|
| 19 |
+
from torch import nn, optim
|
| 20 |
+
from torch.utils.data import DataLoader, TensorDataset
|
| 21 |
+
|
| 22 |
+
from learning_munsell import PROJECT_ROOT
|
| 23 |
+
from learning_munsell.models.networks import ResidualBlock
|
| 24 |
+
from learning_munsell.utilities.common import log_training_epoch, setup_mlflow_experiment
|
| 25 |
+
from learning_munsell.utilities.data import (
|
| 26 |
+
MUNSELL_NORMALIZATION_PARAMS,
|
| 27 |
+
XYY_NORMALIZATION_PARAMS,
|
| 28 |
+
normalize_munsell,
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
LOGGER = logging.getLogger(__name__)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class ExpertNetwork(nn.Module):
|
| 35 |
+
"""
|
| 36 |
+
Single expert network with MLP architecture.
|
| 37 |
+
|
| 38 |
+
Each expert is a specialized neural network that learns to handle
|
| 39 |
+
specific regions of the input space. Uses residual connections for
|
| 40 |
+
improved gradient flow.
|
| 41 |
+
|
| 42 |
+
Architecture
|
| 43 |
+
------------
|
| 44 |
+
- Encoder: 3 → 256 with GELU and BatchNorm
|
| 45 |
+
- Residual blocks: Configurable number of ResidualBlock(256)
|
| 46 |
+
- Decoder: 256 → 4
|
| 47 |
+
|
| 48 |
+
Parameters
|
| 49 |
+
----------
|
| 50 |
+
num_residual_blocks : int, optional
|
| 51 |
+
Number of residual blocks, default is 2.
|
| 52 |
+
|
| 53 |
+
Attributes
|
| 54 |
+
----------
|
| 55 |
+
encoder : nn.Sequential
|
| 56 |
+
Input encoding layer.
|
| 57 |
+
residual_blocks : nn.ModuleList
|
| 58 |
+
Stack of residual blocks.
|
| 59 |
+
decoder : nn.Sequential
|
| 60 |
+
Output decoding layer.
|
| 61 |
+
"""
|
| 62 |
+
|
| 63 |
+
def __init__(self, num_residual_blocks: int = 2) -> None:
|
| 64 |
+
"""Initialize the expert network."""
|
| 65 |
+
super().__init__()
|
| 66 |
+
|
| 67 |
+
self.encoder = nn.Sequential(
|
| 68 |
+
nn.Linear(3, 256),
|
| 69 |
+
nn.GELU(),
|
| 70 |
+
nn.BatchNorm1d(256),
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
self.residual_blocks = nn.ModuleList(
|
| 74 |
+
[ResidualBlock(256) for _ in range(num_residual_blocks)]
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
self.decoder = nn.Sequential(
|
| 78 |
+
nn.Linear(256, 4),
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 82 |
+
"""
|
| 83 |
+
Forward pass through expert network.
|
| 84 |
+
|
| 85 |
+
Parameters
|
| 86 |
+
----------
|
| 87 |
+
x : Tensor
|
| 88 |
+
Input xyY values of shape (batch_size, 3).
|
| 89 |
+
|
| 90 |
+
Returns
|
| 91 |
+
-------
|
| 92 |
+
Tensor
|
| 93 |
+
Expert's prediction of shape (batch_size, 4).
|
| 94 |
+
"""
|
| 95 |
+
x = self.encoder(x)
|
| 96 |
+
for block in self.residual_blocks:
|
| 97 |
+
x = block(x)
|
| 98 |
+
return self.decoder(x)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
class GatingNetwork(nn.Module):
|
| 102 |
+
"""
|
| 103 |
+
Gating network to compute expert weights.
|
| 104 |
+
|
| 105 |
+
Learns to route inputs to appropriate experts by outputting a probability
|
| 106 |
+
distribution over all experts. Different inputs activate different experts
|
| 107 |
+
based on learned input characteristics.
|
| 108 |
+
|
| 109 |
+
Architecture
|
| 110 |
+
------------
|
| 111 |
+
3 → 128 → 64 → num_experts → softmax
|
| 112 |
+
|
| 113 |
+
Parameters
|
| 114 |
+
----------
|
| 115 |
+
num_experts : int
|
| 116 |
+
Number of expert networks to gate.
|
| 117 |
+
|
| 118 |
+
Attributes
|
| 119 |
+
----------
|
| 120 |
+
gate : nn.Sequential
|
| 121 |
+
MLP that maps inputs to expert logits.
|
| 122 |
+
"""
|
| 123 |
+
|
| 124 |
+
def __init__(self, num_experts: int) -> None:
|
| 125 |
+
"""Initialize the gating network."""
|
| 126 |
+
super().__init__()
|
| 127 |
+
|
| 128 |
+
self.gate = nn.Sequential(
|
| 129 |
+
nn.Linear(3, 128),
|
| 130 |
+
nn.GELU(),
|
| 131 |
+
nn.BatchNorm1d(128),
|
| 132 |
+
nn.Linear(128, 64),
|
| 133 |
+
nn.GELU(),
|
| 134 |
+
nn.BatchNorm1d(64),
|
| 135 |
+
nn.Linear(64, num_experts),
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 139 |
+
"""
|
| 140 |
+
Compute expert weights for input.
|
| 141 |
+
|
| 142 |
+
Parameters
|
| 143 |
+
----------
|
| 144 |
+
x : Tensor
|
| 145 |
+
Input xyY values of shape (batch_size, 3).
|
| 146 |
+
|
| 147 |
+
Returns
|
| 148 |
+
-------
|
| 149 |
+
Tensor
|
| 150 |
+
Softmax weights over experts of shape (batch_size, num_experts).
|
| 151 |
+
Weights sum to 1 along expert dimension.
|
| 152 |
+
"""
|
| 153 |
+
# Output softmax weights for each expert
|
| 154 |
+
return torch.softmax(self.gate(x), dim=-1)
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
class MixtureOfExperts(nn.Module):
|
| 158 |
+
"""
|
| 159 |
+
Mixture of Experts for xyY to Munsell conversion.
|
| 160 |
+
|
| 161 |
+
Implements a mixture of experts architecture where multiple specialized
|
| 162 |
+
neural networks (experts) are combined via learned gating weights. This
|
| 163 |
+
allows different experts to specialize in different regions of the input
|
| 164 |
+
space (e.g., different color ranges or hue families).
|
| 165 |
+
|
| 166 |
+
Architecture
|
| 167 |
+
------------
|
| 168 |
+
- Gating network: Learns which expert(s) to use for each input
|
| 169 |
+
- Multiple expert networks: Each specializes in different input regions
|
| 170 |
+
- Output: Weighted combination of expert predictions based on gate weights
|
| 171 |
+
- Load balancing: Auxiliary loss encourages balanced expert usage
|
| 172 |
+
|
| 173 |
+
Parameters
|
| 174 |
+
----------
|
| 175 |
+
num_experts : int, optional
|
| 176 |
+
Number of expert networks, default is 4.
|
| 177 |
+
num_residual_blocks : int, optional
|
| 178 |
+
Number of residual blocks per expert, default is 2.
|
| 179 |
+
|
| 180 |
+
Attributes
|
| 181 |
+
----------
|
| 182 |
+
num_experts : int
|
| 183 |
+
Number of expert networks.
|
| 184 |
+
gating_network : GatingNetwork
|
| 185 |
+
Network that computes expert weights.
|
| 186 |
+
experts : nn.ModuleList
|
| 187 |
+
List of expert networks.
|
| 188 |
+
load_balance_weight : float
|
| 189 |
+
Weight for load balancing auxiliary loss.
|
| 190 |
+
"""
|
| 191 |
+
|
| 192 |
+
def __init__(self, num_experts: int = 4, num_residual_blocks: int = 2) -> None:
|
| 193 |
+
"""Initialize the mixture of experts model."""
|
| 194 |
+
super().__init__()
|
| 195 |
+
|
| 196 |
+
self.num_experts = num_experts
|
| 197 |
+
|
| 198 |
+
# Gating network
|
| 199 |
+
self.gating_network = GatingNetwork(num_experts)
|
| 200 |
+
|
| 201 |
+
# Expert networks
|
| 202 |
+
self.experts = nn.ModuleList(
|
| 203 |
+
[ExpertNetwork(num_residual_blocks) for _ in range(num_experts)]
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
# Load balancing loss weight
|
| 207 |
+
self.load_balance_weight = 0.01
|
| 208 |
+
|
| 209 |
+
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
| 210 |
+
"""
|
| 211 |
+
Forward pass through mixture of experts.
|
| 212 |
+
|
| 213 |
+
Parameters
|
| 214 |
+
----------
|
| 215 |
+
x : Tensor
|
| 216 |
+
Input xyY values of shape (batch_size, 3).
|
| 217 |
+
|
| 218 |
+
Returns
|
| 219 |
+
-------
|
| 220 |
+
tuple
|
| 221 |
+
(output, gate_weights) where:
|
| 222 |
+
- output: Weighted expert predictions of shape (batch_size, 4)
|
| 223 |
+
- gate_weights: Expert weights of shape (batch_size, num_experts)
|
| 224 |
+
"""
|
| 225 |
+
# Get gating weights
|
| 226 |
+
gate_weights = self.gating_network(x) # (batch_size, num_experts)
|
| 227 |
+
|
| 228 |
+
# Get expert outputs
|
| 229 |
+
expert_outputs = torch.stack(
|
| 230 |
+
[expert(x) for expert in self.experts], dim=1
|
| 231 |
+
) # (batch_size, num_experts, 4)
|
| 232 |
+
|
| 233 |
+
# Weighted combination
|
| 234 |
+
gate_weights_expanded = gate_weights.unsqueeze(
|
| 235 |
+
-1
|
| 236 |
+
) # (batch_size, num_experts, 1)
|
| 237 |
+
output = torch.sum(
|
| 238 |
+
expert_outputs * gate_weights_expanded, dim=1
|
| 239 |
+
) # (batch_size, 4)
|
| 240 |
+
|
| 241 |
+
return output, gate_weights
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def precision_focused_loss(
|
| 245 |
+
pred: torch.Tensor,
|
| 246 |
+
target: torch.Tensor,
|
| 247 |
+
gate_weights: torch.Tensor,
|
| 248 |
+
load_balance_weight: float = 0.01,
|
| 249 |
+
) -> torch.Tensor:
|
| 250 |
+
"""
|
| 251 |
+
Precision-focused loss function with load balancing for mixture of experts.
|
| 252 |
+
|
| 253 |
+
Combines standard regression losses (MSE, MAE, log penalty, Huber) with
|
| 254 |
+
a load balancing auxiliary loss that encourages uniform expert usage across
|
| 255 |
+
the dataset to prevent expert collapse.
|
| 256 |
+
|
| 257 |
+
Parameters
|
| 258 |
+
----------
|
| 259 |
+
pred : torch.Tensor
|
| 260 |
+
Predicted values.
|
| 261 |
+
target : torch.Tensor
|
| 262 |
+
Target ground truth values.
|
| 263 |
+
gate_weights : torch.Tensor
|
| 264 |
+
Expert gating weights of shape (batch_size, num_experts).
|
| 265 |
+
load_balance_weight : float, optional
|
| 266 |
+
Weight for load balancing auxiliary loss, default is 0.01.
|
| 267 |
+
|
| 268 |
+
Returns
|
| 269 |
+
-------
|
| 270 |
+
torch.Tensor
|
| 271 |
+
Combined loss value including load balancing term.
|
| 272 |
+
|
| 273 |
+
Notes
|
| 274 |
+
-----
|
| 275 |
+
The load balancing loss encourages each expert to handle roughly
|
| 276 |
+
1/num_experts of the data, preventing scenarios where only a few
|
| 277 |
+
experts are used while others remain idle.
|
| 278 |
+
"""
|
| 279 |
+
# Standard precision loss
|
| 280 |
+
mse = torch.mean((pred - target) ** 2)
|
| 281 |
+
mae = torch.mean(torch.abs(pred - target))
|
| 282 |
+
log_penalty = torch.mean(torch.log1p(torch.abs(pred - target) * 1000.0))
|
| 283 |
+
|
| 284 |
+
delta = 0.01
|
| 285 |
+
abs_error = torch.abs(pred - target)
|
| 286 |
+
huber = torch.where(
|
| 287 |
+
abs_error <= delta, 0.5 * abs_error**2, delta * (abs_error - 0.5 * delta)
|
| 288 |
+
)
|
| 289 |
+
huber_loss = torch.mean(huber)
|
| 290 |
+
|
| 291 |
+
# Load balancing loss: Encourage balanced expert usage
|
| 292 |
+
# Compute importance (sum of gate weights per expert)
|
| 293 |
+
importance = gate_weights.sum(dim=0) # (num_experts,)
|
| 294 |
+
# Normalize to probabilities
|
| 295 |
+
importance = importance / importance.sum()
|
| 296 |
+
# Encourage uniform distribution (1/num_experts for each)
|
| 297 |
+
num_experts = gate_weights.size(1)
|
| 298 |
+
target_importance = torch.ones_like(importance) / num_experts
|
| 299 |
+
load_balance_loss = torch.mean((importance - target_importance) ** 2)
|
| 300 |
+
|
| 301 |
+
return (
|
| 302 |
+
1.0 * mse
|
| 303 |
+
+ 0.5 * mae
|
| 304 |
+
+ 0.3 * log_penalty
|
| 305 |
+
+ 0.5 * huber_loss
|
| 306 |
+
+ load_balance_weight * load_balance_loss
|
| 307 |
+
)
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
def train_epoch(
|
| 311 |
+
model: nn.Module,
|
| 312 |
+
dataloader: DataLoader,
|
| 313 |
+
optimizer: optim.Optimizer,
|
| 314 |
+
device: torch.device,
|
| 315 |
+
) -> float:
|
| 316 |
+
"""
|
| 317 |
+
Train the mixture of experts model for one epoch.
|
| 318 |
+
|
| 319 |
+
Parameters
|
| 320 |
+
----------
|
| 321 |
+
model : nn.Module
|
| 322 |
+
The neural network model to train.
|
| 323 |
+
dataloader : DataLoader
|
| 324 |
+
DataLoader providing training batches (X, y).
|
| 325 |
+
optimizer : optim.Optimizer
|
| 326 |
+
Optimizer for updating model parameters.
|
| 327 |
+
device : torch.device
|
| 328 |
+
Device to run training on.
|
| 329 |
+
|
| 330 |
+
Returns
|
| 331 |
+
-------
|
| 332 |
+
float
|
| 333 |
+
Average loss for the epoch.
|
| 334 |
+
|
| 335 |
+
Notes
|
| 336 |
+
-----
|
| 337 |
+
Loss includes both prediction error and load balancing term.
|
| 338 |
+
The loss function is computed by precision_focused_loss which is
|
| 339 |
+
passed gate_weights for load balancing.
|
| 340 |
+
"""
|
| 341 |
+
model.train()
|
| 342 |
+
total_loss = 0.0
|
| 343 |
+
|
| 344 |
+
for X_batch, y_batch in dataloader:
|
| 345 |
+
X_batch = X_batch.to(device)
|
| 346 |
+
y_batch = y_batch.to(device)
|
| 347 |
+
outputs, gate_weights = model(X_batch)
|
| 348 |
+
loss = precision_focused_loss(
|
| 349 |
+
outputs, y_batch, gate_weights, model.load_balance_weight
|
| 350 |
+
)
|
| 351 |
+
|
| 352 |
+
optimizer.zero_grad()
|
| 353 |
+
loss.backward()
|
| 354 |
+
optimizer.step()
|
| 355 |
+
|
| 356 |
+
total_loss += loss.item()
|
| 357 |
+
|
| 358 |
+
return total_loss / len(dataloader)
|
| 359 |
+
|
| 360 |
+
|
| 361 |
+
def validate(model: nn.Module, dataloader: DataLoader, device: torch.device) -> float:
|
| 362 |
+
"""
|
| 363 |
+
Validate the mixture of experts model on validation set.
|
| 364 |
+
|
| 365 |
+
Parameters
|
| 366 |
+
----------
|
| 367 |
+
model : nn.Module
|
| 368 |
+
The neural network model to validate.
|
| 369 |
+
dataloader : DataLoader
|
| 370 |
+
DataLoader providing validation batches (X, y).
|
| 371 |
+
device : torch.device
|
| 372 |
+
Device to run validation on.
|
| 373 |
+
|
| 374 |
+
Returns
|
| 375 |
+
-------
|
| 376 |
+
float
|
| 377 |
+
Average loss for the validation set.
|
| 378 |
+
"""
|
| 379 |
+
model.eval()
|
| 380 |
+
total_loss = 0.0
|
| 381 |
+
|
| 382 |
+
with torch.no_grad():
|
| 383 |
+
for X_batch, y_batch in dataloader:
|
| 384 |
+
X_batch = X_batch.to(device)
|
| 385 |
+
y_batch = y_batch.to(device)
|
| 386 |
+
outputs, gate_weights = model(X_batch)
|
| 387 |
+
loss = precision_focused_loss(
|
| 388 |
+
outputs, y_batch, gate_weights, model.load_balance_weight
|
| 389 |
+
)
|
| 390 |
+
|
| 391 |
+
total_loss += loss.item()
|
| 392 |
+
|
| 393 |
+
return total_loss / len(dataloader)
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
@click.command()
|
| 397 |
+
@click.option("--epochs", default=200, help="Number of training epochs")
|
| 398 |
+
@click.option("--batch-size", default=1024, help="Batch size for training")
|
| 399 |
+
@click.option("--lr", default=3e-4, help="Learning rate")
|
| 400 |
+
@click.option("--patience", default=20, help="Early stopping patience")
|
| 401 |
+
def main(epochs: int, batch_size: int, lr: float, patience: int) -> None:
|
| 402 |
+
"""
|
| 403 |
+
Train mixture of experts model for xyY to Munsell conversion.
|
| 404 |
+
|
| 405 |
+
Notes
|
| 406 |
+
-----
|
| 407 |
+
The training pipeline:
|
| 408 |
+
1. Loads normalization parameters from existing config
|
| 409 |
+
2. Loads training data from cache
|
| 410 |
+
3. Normalizes inputs and outputs to [0, 1] range
|
| 411 |
+
4. Creates PyTorch DataLoaders
|
| 412 |
+
5. Initializes MixtureOfExperts with 4 expert networks
|
| 413 |
+
6. Trains with AdamW optimizer and precision-focused loss
|
| 414 |
+
7. Uses learning rate scheduler (ReduceLROnPlateau)
|
| 415 |
+
8. Implements early stopping based on validation loss
|
| 416 |
+
9. Exports best model to ONNX format
|
| 417 |
+
10. Logs all metrics and artifacts to MLflow
|
| 418 |
+
"""
|
| 419 |
+
|
| 420 |
+
|
| 421 |
+
LOGGER.info("=" * 80)
|
| 422 |
+
LOGGER.info("Mixture of Experts: xyY → Munsell")
|
| 423 |
+
LOGGER.info("=" * 80)
|
| 424 |
+
|
| 425 |
+
# Set device
|
| 426 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 427 |
+
LOGGER.info("Using device: %s", device)
|
| 428 |
+
|
| 429 |
+
# Paths
|
| 430 |
+
model_directory = PROJECT_ROOT / "models" / "from_xyY"
|
| 431 |
+
data_dir = PROJECT_ROOT / "data"
|
| 432 |
+
cache_file = data_dir / "training_data.npz"
|
| 433 |
+
|
| 434 |
+
# Load training data
|
| 435 |
+
LOGGER.info("")
|
| 436 |
+
LOGGER.info("Loading training data from %s...", cache_file)
|
| 437 |
+
data = np.load(cache_file)
|
| 438 |
+
X_train = data["X_train"]
|
| 439 |
+
y_train = data["y_train"]
|
| 440 |
+
X_val = data["X_val"]
|
| 441 |
+
y_val = data["y_val"]
|
| 442 |
+
|
| 443 |
+
LOGGER.info("Train samples: %d", len(X_train))
|
| 444 |
+
LOGGER.info("Validation samples: %d", len(X_val))
|
| 445 |
+
|
| 446 |
+
# Normalize outputs (xyY inputs are already in [0, 1] range)
|
| 447 |
+
# Use hardcoded ranges covering the full Munsell space for generalization
|
| 448 |
+
output_params = MUNSELL_NORMALIZATION_PARAMS
|
| 449 |
+
y_train_norm = normalize_munsell(y_train, output_params)
|
| 450 |
+
y_val_norm = normalize_munsell(y_val, output_params)
|
| 451 |
+
|
| 452 |
+
# Convert to PyTorch tensors
|
| 453 |
+
X_train_t = torch.FloatTensor(X_train)
|
| 454 |
+
y_train_t = torch.FloatTensor(y_train_norm)
|
| 455 |
+
X_val_t = torch.FloatTensor(X_val)
|
| 456 |
+
y_val_t = torch.FloatTensor(y_val_norm)
|
| 457 |
+
|
| 458 |
+
# Create data loaders
|
| 459 |
+
train_dataset = TensorDataset(X_train_t, y_train_t)
|
| 460 |
+
val_dataset = TensorDataset(X_val_t, y_val_t)
|
| 461 |
+
|
| 462 |
+
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
|
| 463 |
+
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
|
| 464 |
+
|
| 465 |
+
# Initialize model
|
| 466 |
+
model = MixtureOfExperts(num_experts=4, num_residual_blocks=2).to(device)
|
| 467 |
+
LOGGER.info("")
|
| 468 |
+
LOGGER.info("Mixture of Experts architecture:")
|
| 469 |
+
LOGGER.info("%s", model)
|
| 470 |
+
|
| 471 |
+
# Count parameters
|
| 472 |
+
total_params = sum(p.numel() for p in model.parameters())
|
| 473 |
+
LOGGER.info("Total parameters: %s", f"{total_params:,}")
|
| 474 |
+
|
| 475 |
+
# Training setup
|
| 476 |
+
optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
|
| 477 |
+
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
|
| 478 |
+
optimizer, mode="min", factor=0.5, patience=5
|
| 479 |
+
)
|
| 480 |
+
|
| 481 |
+
# MLflow setup
|
| 482 |
+
run_name = setup_mlflow_experiment("from_xyY", "mixture_of_experts")
|
| 483 |
+
|
| 484 |
+
LOGGER.info("")
|
| 485 |
+
LOGGER.info("MLflow run: %s", run_name)
|
| 486 |
+
|
| 487 |
+
# Training loop
|
| 488 |
+
best_val_loss = float("inf")
|
| 489 |
+
patience_counter = 0
|
| 490 |
+
|
| 491 |
+
LOGGER.info("")
|
| 492 |
+
LOGGER.info("Starting training...")
|
| 493 |
+
|
| 494 |
+
with mlflow.start_run(run_name=run_name):
|
| 495 |
+
mlflow.log_params(
|
| 496 |
+
{
|
| 497 |
+
"model": "mixture_of_experts",
|
| 498 |
+
"learning_rate": lr,
|
| 499 |
+
"batch_size": batch_size,
|
| 500 |
+
"num_epochs": epochs,
|
| 501 |
+
"patience": patience,
|
| 502 |
+
"total_params": total_params,
|
| 503 |
+
}
|
| 504 |
+
)
|
| 505 |
+
|
| 506 |
+
for epoch in range(epochs):
|
| 507 |
+
train_loss = train_epoch(model, train_loader, optimizer, device)
|
| 508 |
+
val_loss = validate(model, val_loader, device)
|
| 509 |
+
|
| 510 |
+
scheduler.step(val_loss)
|
| 511 |
+
|
| 512 |
+
log_training_epoch(
|
| 513 |
+
epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"]
|
| 514 |
+
)
|
| 515 |
+
|
| 516 |
+
LOGGER.info(
|
| 517 |
+
"Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f, LR: %.6f",
|
| 518 |
+
epoch + 1,
|
| 519 |
+
epochs,
|
| 520 |
+
train_loss,
|
| 521 |
+
val_loss,
|
| 522 |
+
optimizer.param_groups[0]["lr"],
|
| 523 |
+
)
|
| 524 |
+
|
| 525 |
+
# Early stopping
|
| 526 |
+
if val_loss < best_val_loss:
|
| 527 |
+
best_val_loss = val_loss
|
| 528 |
+
patience_counter = 0
|
| 529 |
+
|
| 530 |
+
model_directory.mkdir(exist_ok=True)
|
| 531 |
+
checkpoint_file = model_directory / "mixture_of_experts_best.pth"
|
| 532 |
+
|
| 533 |
+
torch.save(
|
| 534 |
+
{
|
| 535 |
+
"model_state_dict": model.state_dict(),
|
| 536 |
+
"epoch": epoch,
|
| 537 |
+
"val_loss": val_loss,
|
| 538 |
+
},
|
| 539 |
+
checkpoint_file,
|
| 540 |
+
)
|
| 541 |
+
|
| 542 |
+
LOGGER.info(" → Saved best model (val_loss: %.6f)", val_loss)
|
| 543 |
+
else:
|
| 544 |
+
patience_counter += 1
|
| 545 |
+
if patience_counter >= patience:
|
| 546 |
+
LOGGER.info("")
|
| 547 |
+
LOGGER.info("Early stopping after %d epochs", epoch + 1)
|
| 548 |
+
break
|
| 549 |
+
|
| 550 |
+
mlflow.log_metrics(
|
| 551 |
+
{
|
| 552 |
+
"best_val_loss": best_val_loss,
|
| 553 |
+
"final_epoch": epoch + 1,
|
| 554 |
+
}
|
| 555 |
+
)
|
| 556 |
+
|
| 557 |
+
# Export to ONNX (simplified - outputs only prediction, not gate weights)
|
| 558 |
+
LOGGER.info("")
|
| 559 |
+
LOGGER.info("Exporting to ONNX...")
|
| 560 |
+
model.eval()
|
| 561 |
+
|
| 562 |
+
checkpoint = torch.load(checkpoint_file)
|
| 563 |
+
model.load_state_dict(checkpoint["model_state_dict"])
|
| 564 |
+
|
| 565 |
+
# Create wrapper for ONNX export (only return prediction)
|
| 566 |
+
class MoEWrapper(nn.Module):
|
| 567 |
+
def __init__(self, moe_model: nn.Module) -> None:
|
| 568 |
+
super().__init__()
|
| 569 |
+
self.moe_model = moe_model
|
| 570 |
+
|
| 571 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 572 |
+
output, _ = self.moe_model(x)
|
| 573 |
+
return output
|
| 574 |
+
|
| 575 |
+
wrapped_model = MoEWrapper(model).to(device)
|
| 576 |
+
wrapped_model.eval()
|
| 577 |
+
|
| 578 |
+
dummy_input = torch.randn(1, 3).to(device)
|
| 579 |
+
|
| 580 |
+
onnx_file = model_directory / "mixture_of_experts.onnx"
|
| 581 |
+
torch.onnx.export(
|
| 582 |
+
wrapped_model,
|
| 583 |
+
dummy_input,
|
| 584 |
+
onnx_file,
|
| 585 |
+
export_params=True,
|
| 586 |
+
opset_version=15,
|
| 587 |
+
input_names=["xyY"],
|
| 588 |
+
output_names=["munsell_spec"],
|
| 589 |
+
dynamic_axes={
|
| 590 |
+
"xyY": {0: "batch_size"},
|
| 591 |
+
"munsell_spec": {0: "batch_size"},
|
| 592 |
+
},
|
| 593 |
+
)
|
| 594 |
+
|
| 595 |
+
# Save normalization parameters alongside model
|
| 596 |
+
params_file = model_directory / "mixture_of_experts_normalization_params.npz"
|
| 597 |
+
input_params = XYY_NORMALIZATION_PARAMS
|
| 598 |
+
np.savez(
|
| 599 |
+
params_file,
|
| 600 |
+
input_params=input_params,
|
| 601 |
+
output_params=output_params,
|
| 602 |
+
)
|
| 603 |
+
|
| 604 |
+
mlflow.log_artifact(str(checkpoint_file))
|
| 605 |
+
mlflow.log_artifact(str(onnx_file))
|
| 606 |
+
mlflow.log_artifact(str(params_file))
|
| 607 |
+
mlflow.pytorch.log_model(model, "model")
|
| 608 |
+
|
| 609 |
+
LOGGER.info("ONNX model saved to: %s", onnx_file)
|
| 610 |
+
LOGGER.info("Normalization parameters saved to: %s", params_file)
|
| 611 |
+
LOGGER.info("Artifacts logged to MLflow")
|
| 612 |
+
|
| 613 |
+
|
| 614 |
+
LOGGER.info("=" * 80)
|
| 615 |
+
|
| 616 |
+
|
| 617 |
+
if __name__ == "__main__":
|
| 618 |
+
logging.basicConfig(level=logging.INFO, format="%(message)s", force=True)
|
| 619 |
+
|
| 620 |
+
main()
|
learning_munsell/training/from_xyY/train_mlp.py
ADDED
|
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Train ML model for xyY to Munsell conversion.
|
| 3 |
+
|
| 4 |
+
This script trains a compact MLP/DNN model with architecture:
|
| 5 |
+
3 inputs → [64, 128, 128, 64] hidden layers → 4 outputs
|
| 6 |
+
|
| 7 |
+
Target: < 1e-7 accuracy compared to iterative algorithm
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import logging
|
| 11 |
+
|
| 12 |
+
import click
|
| 13 |
+
import mlflow
|
| 14 |
+
import mlflow.pytorch
|
| 15 |
+
import numpy as np
|
| 16 |
+
import torch
|
| 17 |
+
from torch import optim
|
| 18 |
+
from torch.utils.data import DataLoader, TensorDataset
|
| 19 |
+
|
| 20 |
+
from learning_munsell import PROJECT_ROOT
|
| 21 |
+
from learning_munsell.models.networks import MLPToMunsell
|
| 22 |
+
from learning_munsell.utilities.common import log_training_epoch, setup_mlflow_experiment
|
| 23 |
+
from learning_munsell.utilities.data import (
|
| 24 |
+
MUNSELL_NORMALIZATION_PARAMS,
|
| 25 |
+
XYY_NORMALIZATION_PARAMS,
|
| 26 |
+
normalize_munsell,
|
| 27 |
+
)
|
| 28 |
+
from learning_munsell.utilities.losses import weighted_mse_loss
|
| 29 |
+
from learning_munsell.utilities.training import train_epoch, validate
|
| 30 |
+
|
| 31 |
+
LOGGER = logging.getLogger(__name__)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
@click.command()
|
| 35 |
+
@click.option("--epochs", default=200, help="Maximum training epochs.")
|
| 36 |
+
@click.option("--batch-size", default=1024, help="Training batch size.")
|
| 37 |
+
@click.option("--lr", default=5e-4, help="Learning rate.")
|
| 38 |
+
@click.option("--patience", default=20, help="Early stopping patience.")
|
| 39 |
+
def main(epochs: int, batch_size: int, lr: float, patience: int) -> None:
|
| 40 |
+
"""
|
| 41 |
+
Train the MLPToMunsell model for xyY to Munsell conversion.
|
| 42 |
+
|
| 43 |
+
Parameters
|
| 44 |
+
----------
|
| 45 |
+
epochs : int
|
| 46 |
+
Maximum number of training epochs.
|
| 47 |
+
batch_size : int
|
| 48 |
+
Training batch size.
|
| 49 |
+
lr : float
|
| 50 |
+
Learning rate for AdamW optimizer.
|
| 51 |
+
patience : int
|
| 52 |
+
Early stopping patience (epochs without improvement).
|
| 53 |
+
|
| 54 |
+
Notes
|
| 55 |
+
-----
|
| 56 |
+
The training pipeline:
|
| 57 |
+
1. Loads training data from cache
|
| 58 |
+
2. Normalizes Munsell outputs to [0, 1] range
|
| 59 |
+
3. Trains compact MLP model (3 → [64, 128, 128, 64] → 4)
|
| 60 |
+
4. Uses weighted MSE loss function
|
| 61 |
+
5. Learning rate scheduling with ReduceLROnPlateau
|
| 62 |
+
6. Early stopping based on validation loss
|
| 63 |
+
7. Exports model to ONNX format
|
| 64 |
+
8. Logs metrics and artifacts to MLflow
|
| 65 |
+
"""
|
| 66 |
+
LOGGER.info("=" * 80)
|
| 67 |
+
LOGGER.info("ML-Based xyY to Munsell Conversion: Model Training")
|
| 68 |
+
LOGGER.info("=" * 80)
|
| 69 |
+
|
| 70 |
+
# Set device
|
| 71 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 72 |
+
LOGGER.info("Using device: %s", device)
|
| 73 |
+
|
| 74 |
+
# Load training data
|
| 75 |
+
data_dir = PROJECT_ROOT / "data"
|
| 76 |
+
cache_file = data_dir / "training_data.npz"
|
| 77 |
+
|
| 78 |
+
if not cache_file.exists():
|
| 79 |
+
LOGGER.error("Error: Training data not found at %s", cache_file)
|
| 80 |
+
LOGGER.error("Please run 01_generate_training_data.py first")
|
| 81 |
+
return
|
| 82 |
+
|
| 83 |
+
LOGGER.info("Loading training data from %s...", cache_file)
|
| 84 |
+
data = np.load(cache_file)
|
| 85 |
+
|
| 86 |
+
X_train = data["X_train"]
|
| 87 |
+
y_train = data["y_train"]
|
| 88 |
+
X_val = data["X_val"]
|
| 89 |
+
y_val = data["y_val"]
|
| 90 |
+
|
| 91 |
+
# Note: Invalid samples (outside Munsell gamut) are also stored in the cache
|
| 92 |
+
# Available as: data['xyY_all'], data['munsell_all'], data['valid_mask']
|
| 93 |
+
# These can be used for future enhancements like:
|
| 94 |
+
# - Adversarial training to avoid extrapolation
|
| 95 |
+
# - Gamut-aware loss functions
|
| 96 |
+
# - Uncertainty estimation at boundaries
|
| 97 |
+
|
| 98 |
+
LOGGER.info("Train samples: %d", len(X_train))
|
| 99 |
+
LOGGER.info("Validation samples: %d", len(X_val))
|
| 100 |
+
|
| 101 |
+
# Normalize outputs (xyY inputs are already in [0, 1] range)
|
| 102 |
+
# Use hardcoded ranges covering the full Munsell space for generalization
|
| 103 |
+
output_params = MUNSELL_NORMALIZATION_PARAMS
|
| 104 |
+
y_train_norm = normalize_munsell(y_train, output_params)
|
| 105 |
+
y_val_norm = normalize_munsell(y_val, output_params)
|
| 106 |
+
|
| 107 |
+
# Convert to PyTorch tensors
|
| 108 |
+
X_train_t = torch.FloatTensor(X_train)
|
| 109 |
+
y_train_t = torch.FloatTensor(y_train_norm)
|
| 110 |
+
X_val_t = torch.FloatTensor(X_val)
|
| 111 |
+
y_val_t = torch.FloatTensor(y_val_norm)
|
| 112 |
+
|
| 113 |
+
# Create data loaders
|
| 114 |
+
train_dataset = TensorDataset(X_train_t, y_train_t)
|
| 115 |
+
val_dataset = TensorDataset(X_val_t, y_val_t)
|
| 116 |
+
|
| 117 |
+
# Larger batch size for larger dataset (500K samples)
|
| 118 |
+
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
|
| 119 |
+
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
|
| 120 |
+
|
| 121 |
+
# Initialize model
|
| 122 |
+
model = MLPToMunsell().to(device)
|
| 123 |
+
LOGGER.info("")
|
| 124 |
+
LOGGER.info("Model architecture:")
|
| 125 |
+
LOGGER.info("%s", model)
|
| 126 |
+
|
| 127 |
+
# Count parameters
|
| 128 |
+
total_params = sum(p.numel() for p in model.parameters())
|
| 129 |
+
LOGGER.info("Total parameters: %s", f"{total_params:,}")
|
| 130 |
+
|
| 131 |
+
# Training setup - lower learning rate for larger model
|
| 132 |
+
optimizer = optim.Adam(model.parameters(), lr=lr)
|
| 133 |
+
# Use weighted MSE with default weights
|
| 134 |
+
weights = torch.tensor([1.0, 1.0, 2.0, 0.5])
|
| 135 |
+
criterion = lambda pred, target: weighted_mse_loss(pred, target, weights)
|
| 136 |
+
|
| 137 |
+
# MLflow setup
|
| 138 |
+
run_name = setup_mlflow_experiment("from_xyY", "mlp")
|
| 139 |
+
|
| 140 |
+
LOGGER.info("")
|
| 141 |
+
LOGGER.info("MLflow run: %s", run_name)
|
| 142 |
+
|
| 143 |
+
# Training loop
|
| 144 |
+
best_val_loss = float("inf")
|
| 145 |
+
patience_counter = 0
|
| 146 |
+
|
| 147 |
+
LOGGER.info("")
|
| 148 |
+
LOGGER.info("Starting training...")
|
| 149 |
+
|
| 150 |
+
with mlflow.start_run(run_name=run_name):
|
| 151 |
+
# Log hyperparameters
|
| 152 |
+
mlflow.log_params(
|
| 153 |
+
{
|
| 154 |
+
"epochs": epochs,
|
| 155 |
+
"batch_size": batch_size,
|
| 156 |
+
"learning_rate": lr,
|
| 157 |
+
"optimizer": "Adam",
|
| 158 |
+
"criterion": "weighted_mse_loss",
|
| 159 |
+
"patience": patience,
|
| 160 |
+
"total_params": total_params,
|
| 161 |
+
}
|
| 162 |
+
)
|
| 163 |
+
|
| 164 |
+
for epoch in range(epochs):
|
| 165 |
+
train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
|
| 166 |
+
val_loss = validate(model, val_loader, criterion, device)
|
| 167 |
+
|
| 168 |
+
# Log to MLflow
|
| 169 |
+
log_training_epoch(
|
| 170 |
+
epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"]
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
LOGGER.info(
|
| 174 |
+
"Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f",
|
| 175 |
+
epoch + 1,
|
| 176 |
+
epochs,
|
| 177 |
+
train_loss,
|
| 178 |
+
val_loss,
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
# Early stopping
|
| 182 |
+
if val_loss < best_val_loss:
|
| 183 |
+
best_val_loss = val_loss
|
| 184 |
+
patience_counter = 0
|
| 185 |
+
|
| 186 |
+
# Save best model
|
| 187 |
+
model_directory = PROJECT_ROOT / "models" / "from_xyY"
|
| 188 |
+
model_directory.mkdir(exist_ok=True)
|
| 189 |
+
checkpoint_file = model_directory / "mlp_best.pth"
|
| 190 |
+
|
| 191 |
+
torch.save(
|
| 192 |
+
{
|
| 193 |
+
"model_state_dict": model.state_dict(),
|
| 194 |
+
"output_params": output_params,
|
| 195 |
+
"epoch": epoch,
|
| 196 |
+
"val_loss": val_loss,
|
| 197 |
+
},
|
| 198 |
+
checkpoint_file,
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
LOGGER.info(" → Saved best model (val_loss: %.6f)", val_loss)
|
| 202 |
+
else:
|
| 203 |
+
patience_counter += 1
|
| 204 |
+
if patience_counter >= patience:
|
| 205 |
+
LOGGER.info("")
|
| 206 |
+
LOGGER.info("Early stopping after %d epochs", epoch + 1)
|
| 207 |
+
break
|
| 208 |
+
|
| 209 |
+
# Log final metrics
|
| 210 |
+
mlflow.log_metrics(
|
| 211 |
+
{
|
| 212 |
+
"best_val_loss": best_val_loss,
|
| 213 |
+
"final_epoch": epoch + 1,
|
| 214 |
+
}
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
# Export to ONNX
|
| 218 |
+
LOGGER.info("")
|
| 219 |
+
LOGGER.info("Exporting model to ONNX...")
|
| 220 |
+
model.eval()
|
| 221 |
+
|
| 222 |
+
# Load best model
|
| 223 |
+
checkpoint = torch.load(checkpoint_file)
|
| 224 |
+
model.load_state_dict(checkpoint["model_state_dict"])
|
| 225 |
+
|
| 226 |
+
# Create dummy input
|
| 227 |
+
dummy_input = torch.randn(1, 3).to(device)
|
| 228 |
+
|
| 229 |
+
# Export
|
| 230 |
+
onnx_file = model_directory / "mlp.onnx"
|
| 231 |
+
torch.onnx.export(
|
| 232 |
+
model,
|
| 233 |
+
dummy_input,
|
| 234 |
+
onnx_file,
|
| 235 |
+
export_params=True,
|
| 236 |
+
opset_version=15,
|
| 237 |
+
input_names=["xyY"],
|
| 238 |
+
output_names=["munsell_spec"],
|
| 239 |
+
dynamic_axes={"xyY": {0: "batch_size"}, "munsell_spec": {0: "batch_size"}},
|
| 240 |
+
)
|
| 241 |
+
|
| 242 |
+
# Save normalization parameters alongside model
|
| 243 |
+
params_file = model_directory / "mlp_normalization_params.npz"
|
| 244 |
+
input_params = XYY_NORMALIZATION_PARAMS
|
| 245 |
+
np.savez(
|
| 246 |
+
params_file,
|
| 247 |
+
input_params=input_params,
|
| 248 |
+
output_params=output_params,
|
| 249 |
+
)
|
| 250 |
+
|
| 251 |
+
LOGGER.info("ONNX model saved to: %s", onnx_file)
|
| 252 |
+
LOGGER.info("Normalization parameters saved to: %s", params_file)
|
| 253 |
+
|
| 254 |
+
# Log artifacts
|
| 255 |
+
mlflow.log_artifact(str(checkpoint_file))
|
| 256 |
+
mlflow.log_artifact(str(onnx_file))
|
| 257 |
+
mlflow.log_artifact(str(params_file))
|
| 258 |
+
|
| 259 |
+
# Log model
|
| 260 |
+
mlflow.pytorch.log_model(model, "model")
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
LOGGER.info("=" * 80)
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
if __name__ == "__main__":
|
| 267 |
+
logging.basicConfig(level=logging.INFO, format="%(message)s", force=True)
|
| 268 |
+
|
| 269 |
+
main()
|
learning_munsell/training/from_xyY/train_mlp_attention.py
ADDED
|
@@ -0,0 +1,460 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Train MLP + Self-Attention model for xyY to Munsell conversion.
|
| 3 |
+
|
| 4 |
+
Option 1: MLP backbone with multi-head self-attention layers
|
| 5 |
+
- Input: 3 features (xyY)
|
| 6 |
+
- Architecture: 3 -> 512 -> 1024 + [Attention + ResBlock] x 4 -> 512 -> 4
|
| 7 |
+
- Output: 4 features (hue, value, chroma, code)
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import logging
|
| 11 |
+
import click
|
| 12 |
+
import mlflow
|
| 13 |
+
import mlflow.pytorch
|
| 14 |
+
import numpy as np
|
| 15 |
+
import torch
|
| 16 |
+
from numpy.typing import NDArray
|
| 17 |
+
from torch import nn, optim
|
| 18 |
+
from torch.utils.data import DataLoader, TensorDataset
|
| 19 |
+
|
| 20 |
+
from learning_munsell import PROJECT_ROOT
|
| 21 |
+
from learning_munsell.models.networks import ResidualBlock
|
| 22 |
+
from learning_munsell.utilities.common import log_training_epoch, setup_mlflow_experiment
|
| 23 |
+
from learning_munsell.utilities.data import (
|
| 24 |
+
MUNSELL_NORMALIZATION_PARAMS,
|
| 25 |
+
XYY_NORMALIZATION_PARAMS,
|
| 26 |
+
normalize_munsell,
|
| 27 |
+
)
|
| 28 |
+
from learning_munsell.utilities.losses import precision_focused_loss
|
| 29 |
+
from learning_munsell.utilities.training import train_epoch, validate
|
| 30 |
+
|
| 31 |
+
LOGGER = logging.getLogger(__name__)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class MultiHeadSelfAttention(nn.Module):
|
| 35 |
+
"""
|
| 36 |
+
Multi-head self-attention layer for feature interaction.
|
| 37 |
+
|
| 38 |
+
Implements scaled dot-product attention with multiple heads to capture
|
| 39 |
+
different aspects of feature relationships.
|
| 40 |
+
|
| 41 |
+
Parameters
|
| 42 |
+
----------
|
| 43 |
+
dim
|
| 44 |
+
Input and output feature dimension.
|
| 45 |
+
num_heads
|
| 46 |
+
Number of attention heads. Must divide ``dim`` evenly.
|
| 47 |
+
|
| 48 |
+
Attributes
|
| 49 |
+
----------
|
| 50 |
+
query
|
| 51 |
+
Linear projection for query vectors.
|
| 52 |
+
key
|
| 53 |
+
Linear projection for key vectors.
|
| 54 |
+
value
|
| 55 |
+
Linear projection for value vectors.
|
| 56 |
+
out
|
| 57 |
+
Output projection after attention.
|
| 58 |
+
scale
|
| 59 |
+
Scaling factor (1/sqrt(head_dim)) for dot-product attention.
|
| 60 |
+
"""
|
| 61 |
+
|
| 62 |
+
def __init__(self, dim: int, num_heads: int = 4) -> None:
|
| 63 |
+
super().__init__()
|
| 64 |
+
self.num_heads = num_heads
|
| 65 |
+
self.dim = dim
|
| 66 |
+
self.head_dim = dim // num_heads
|
| 67 |
+
|
| 68 |
+
assert dim % num_heads == 0, "dim must be divisible by num_heads" # noqa: S101
|
| 69 |
+
|
| 70 |
+
self.query = nn.Linear(dim, dim)
|
| 71 |
+
self.key = nn.Linear(dim, dim)
|
| 72 |
+
self.value = nn.Linear(dim, dim)
|
| 73 |
+
self.out = nn.Linear(dim, dim)
|
| 74 |
+
|
| 75 |
+
self.scale = self.head_dim**-0.5
|
| 76 |
+
|
| 77 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 78 |
+
"""
|
| 79 |
+
Apply multi-head self-attention.
|
| 80 |
+
|
| 81 |
+
Parameters
|
| 82 |
+
----------
|
| 83 |
+
x
|
| 84 |
+
Input tensor of shape ``(batch_size, dim)``.
|
| 85 |
+
|
| 86 |
+
Returns
|
| 87 |
+
-------
|
| 88 |
+
torch.Tensor
|
| 89 |
+
Output tensor of shape ``(batch_size, dim)`` with attention applied.
|
| 90 |
+
"""
|
| 91 |
+
batch_size = x.size(0)
|
| 92 |
+
|
| 93 |
+
# Linear projections
|
| 94 |
+
Q = self.query(x).view(batch_size, self.num_heads, self.head_dim)
|
| 95 |
+
K = self.key(x).view(batch_size, self.num_heads, self.head_dim)
|
| 96 |
+
V = self.value(x).view(batch_size, self.num_heads, self.head_dim)
|
| 97 |
+
|
| 98 |
+
# Scaled dot-product attention
|
| 99 |
+
attn_weights = torch.softmax(
|
| 100 |
+
torch.matmul(Q, K.transpose(-2, -1)) * self.scale, dim=-1
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
# Apply attention to values
|
| 104 |
+
attn_output = torch.matmul(attn_weights, V)
|
| 105 |
+
|
| 106 |
+
# Concatenate heads and project
|
| 107 |
+
attn_output = attn_output.view(batch_size, self.dim)
|
| 108 |
+
return self.out(attn_output)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
class AttentionResBlock(nn.Module):
|
| 112 |
+
"""
|
| 113 |
+
Combined attention and residual block.
|
| 114 |
+
|
| 115 |
+
Applies self-attention followed by a residual MLP block, each with
|
| 116 |
+
batch normalization and skip connections.
|
| 117 |
+
|
| 118 |
+
Parameters
|
| 119 |
+
----------
|
| 120 |
+
dim
|
| 121 |
+
Input and output feature dimension.
|
| 122 |
+
num_heads
|
| 123 |
+
Number of attention heads for the self-attention layer.
|
| 124 |
+
|
| 125 |
+
Attributes
|
| 126 |
+
----------
|
| 127 |
+
attention
|
| 128 |
+
Multi-head self-attention layer.
|
| 129 |
+
norm1
|
| 130 |
+
Batch normalization after attention.
|
| 131 |
+
residual
|
| 132 |
+
Residual MLP block.
|
| 133 |
+
norm2
|
| 134 |
+
Batch normalization after residual block.
|
| 135 |
+
"""
|
| 136 |
+
|
| 137 |
+
def __init__(self, dim: int, num_heads: int = 4) -> None:
|
| 138 |
+
super().__init__()
|
| 139 |
+
self.attention = MultiHeadSelfAttention(dim, num_heads)
|
| 140 |
+
self.norm1 = nn.BatchNorm1d(dim)
|
| 141 |
+
self.residual = ResidualBlock(dim)
|
| 142 |
+
self.norm2 = nn.BatchNorm1d(dim)
|
| 143 |
+
|
| 144 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 145 |
+
"""
|
| 146 |
+
Apply attention and residual transformations.
|
| 147 |
+
|
| 148 |
+
Parameters
|
| 149 |
+
----------
|
| 150 |
+
x
|
| 151 |
+
Input tensor of shape ``(batch_size, dim)``.
|
| 152 |
+
|
| 153 |
+
Returns
|
| 154 |
+
-------
|
| 155 |
+
torch.Tensor
|
| 156 |
+
Output tensor of shape ``(batch_size, dim)``.
|
| 157 |
+
"""
|
| 158 |
+
# Attention with residual
|
| 159 |
+
attn_out = self.norm1(x + self.attention(x))
|
| 160 |
+
# ResBlock with residual
|
| 161 |
+
return self.norm2(self.residual(attn_out))
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
class MLPAttention(nn.Module):
|
| 165 |
+
"""
|
| 166 |
+
MLP with self-attention for xyY to Munsell conversion.
|
| 167 |
+
|
| 168 |
+
Architecture:
|
| 169 |
+
- Input: 3 features (xyY normalized to [0, 1])
|
| 170 |
+
- Encoder: 3 -> 512 -> 1024
|
| 171 |
+
- Attention-ResBlocks at 1024-dim (configurable count)
|
| 172 |
+
- Decoder: 1024 -> 512 -> 4
|
| 173 |
+
- Output: 4 features (hue, value, chroma, code normalized)
|
| 174 |
+
|
| 175 |
+
Parameters
|
| 176 |
+
----------
|
| 177 |
+
num_blocks
|
| 178 |
+
Number of attention-residual blocks in the middle.
|
| 179 |
+
num_heads
|
| 180 |
+
Number of attention heads in each attention layer.
|
| 181 |
+
|
| 182 |
+
Attributes
|
| 183 |
+
----------
|
| 184 |
+
encoder
|
| 185 |
+
MLP that projects 3D xyY input to 1024D feature space.
|
| 186 |
+
blocks
|
| 187 |
+
List of AttentionResBlock modules.
|
| 188 |
+
decoder
|
| 189 |
+
MLP that projects 1024D features to 4D Munsell output.
|
| 190 |
+
"""
|
| 191 |
+
|
| 192 |
+
def __init__(self, num_blocks: int = 4, num_heads: int = 4) -> None:
|
| 193 |
+
super().__init__()
|
| 194 |
+
|
| 195 |
+
# Encoder
|
| 196 |
+
self.encoder = nn.Sequential(
|
| 197 |
+
nn.Linear(3, 512),
|
| 198 |
+
nn.GELU(),
|
| 199 |
+
nn.BatchNorm1d(512),
|
| 200 |
+
nn.Linear(512, 1024),
|
| 201 |
+
nn.GELU(),
|
| 202 |
+
nn.BatchNorm1d(1024),
|
| 203 |
+
)
|
| 204 |
+
|
| 205 |
+
# Attention-ResBlocks
|
| 206 |
+
self.blocks = nn.ModuleList(
|
| 207 |
+
[AttentionResBlock(1024, num_heads) for _ in range(num_blocks)]
|
| 208 |
+
)
|
| 209 |
+
|
| 210 |
+
# Decoder
|
| 211 |
+
self.decoder = nn.Sequential(
|
| 212 |
+
nn.Linear(1024, 512),
|
| 213 |
+
nn.GELU(),
|
| 214 |
+
nn.BatchNorm1d(512),
|
| 215 |
+
nn.Linear(512, 4),
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 219 |
+
"""
|
| 220 |
+
Predict Munsell specification from xyY input.
|
| 221 |
+
|
| 222 |
+
Parameters
|
| 223 |
+
----------
|
| 224 |
+
x
|
| 225 |
+
Input tensor of shape ``(batch_size, 3)`` containing normalized
|
| 226 |
+
xyY values.
|
| 227 |
+
|
| 228 |
+
Returns
|
| 229 |
+
-------
|
| 230 |
+
torch.Tensor
|
| 231 |
+
Output tensor of shape ``(batch_size, 4)`` containing normalized
|
| 232 |
+
Munsell specification [hue, value, chroma, code].
|
| 233 |
+
"""
|
| 234 |
+
# Encode
|
| 235 |
+
x = self.encoder(x)
|
| 236 |
+
|
| 237 |
+
# Attention-ResBlocks
|
| 238 |
+
for block in self.blocks:
|
| 239 |
+
x = block(x)
|
| 240 |
+
|
| 241 |
+
# Decode
|
| 242 |
+
return self.decoder(x)
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
@click.command()
|
| 246 |
+
@click.option("--epochs", default=200, help="Number of training epochs")
|
| 247 |
+
@click.option("--batch-size", default=1024, help="Batch size for training")
|
| 248 |
+
@click.option("--lr", default=3e-4, help="Learning rate")
|
| 249 |
+
@click.option("--patience", default=20, help="Early stopping patience")
|
| 250 |
+
def main(epochs: int, batch_size: int, lr: float, patience: int) -> None:
|
| 251 |
+
"""
|
| 252 |
+
Train MLP + Self-Attention model for xyY to Munsell conversion.
|
| 253 |
+
|
| 254 |
+
Notes
|
| 255 |
+
-----
|
| 256 |
+
The training pipeline:
|
| 257 |
+
1. Loads normalization parameters and training data from disk
|
| 258 |
+
2. Normalizes inputs (xyY) and outputs (Munsell specification) to [0, 1]
|
| 259 |
+
3. Creates MLPAttention model (4 blocks, 4 attention heads)
|
| 260 |
+
4. Trains with precision-focused loss (MSE + MAE + log + Huber)
|
| 261 |
+
5. Uses AdamW optimizer with ReduceLROnPlateau scheduler
|
| 262 |
+
6. Applies early stopping based on validation loss (patience=20)
|
| 263 |
+
7. Exports best model to ONNX format
|
| 264 |
+
8. Logs metrics and artifacts to MLflow
|
| 265 |
+
"""
|
| 266 |
+
LOGGER.info("=" * 80)
|
| 267 |
+
LOGGER.info("MLP + Self-Attention: xyY → Munsell")
|
| 268 |
+
LOGGER.info("=" * 80)
|
| 269 |
+
|
| 270 |
+
# Set device
|
| 271 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 272 |
+
LOGGER.info("Using device: %s", device)
|
| 273 |
+
|
| 274 |
+
# Paths
|
| 275 |
+
model_directory = PROJECT_ROOT / "models" / "from_xyY"
|
| 276 |
+
data_dir = PROJECT_ROOT / "data"
|
| 277 |
+
cache_file = data_dir / "training_data.npz"
|
| 278 |
+
|
| 279 |
+
# Load training data
|
| 280 |
+
LOGGER.info("")
|
| 281 |
+
LOGGER.info("Loading training data from %s...", cache_file)
|
| 282 |
+
data = np.load(cache_file)
|
| 283 |
+
X_train = data["X_train"]
|
| 284 |
+
y_train = data["y_train"]
|
| 285 |
+
X_val = data["X_val"]
|
| 286 |
+
y_val = data["y_val"]
|
| 287 |
+
|
| 288 |
+
LOGGER.info("Train samples: %d", len(X_train))
|
| 289 |
+
LOGGER.info("Validation samples: %d", len(X_val))
|
| 290 |
+
|
| 291 |
+
output_params = MUNSELL_NORMALIZATION_PARAMS
|
| 292 |
+
y_train_norm = normalize_munsell(y_train, output_params)
|
| 293 |
+
y_val_norm = normalize_munsell(y_val, output_params)
|
| 294 |
+
|
| 295 |
+
# Convert to PyTorch tensors
|
| 296 |
+
X_train_t = torch.FloatTensor(X_train)
|
| 297 |
+
y_train_t = torch.FloatTensor(y_train_norm)
|
| 298 |
+
X_val_t = torch.FloatTensor(X_val)
|
| 299 |
+
y_val_t = torch.FloatTensor(y_val_norm)
|
| 300 |
+
|
| 301 |
+
# Create data loaders
|
| 302 |
+
train_dataset = TensorDataset(X_train_t, y_train_t)
|
| 303 |
+
val_dataset = TensorDataset(X_val_t, y_val_t)
|
| 304 |
+
|
| 305 |
+
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
|
| 306 |
+
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
|
| 307 |
+
|
| 308 |
+
# Initialize model
|
| 309 |
+
model = MLPAttention(num_blocks=4, num_heads=4).to(device)
|
| 310 |
+
LOGGER.info("")
|
| 311 |
+
LOGGER.info("MLP + Attention architecture:")
|
| 312 |
+
LOGGER.info("%s", model)
|
| 313 |
+
|
| 314 |
+
# Count parameters
|
| 315 |
+
total_params = sum(p.numel() for p in model.parameters())
|
| 316 |
+
LOGGER.info("Total parameters: %s", f"{total_params:,}")
|
| 317 |
+
|
| 318 |
+
# Training setup
|
| 319 |
+
optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
|
| 320 |
+
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
|
| 321 |
+
optimizer, mode="min", factor=0.5, patience=5
|
| 322 |
+
)
|
| 323 |
+
criterion = precision_focused_loss
|
| 324 |
+
|
| 325 |
+
# MLflow setup
|
| 326 |
+
run_name = setup_mlflow_experiment("from_xyY", "mlp_attention")
|
| 327 |
+
|
| 328 |
+
LOGGER.info("")
|
| 329 |
+
LOGGER.info("MLflow run: %s", run_name)
|
| 330 |
+
|
| 331 |
+
# Training loop
|
| 332 |
+
best_val_loss = float("inf")
|
| 333 |
+
patience_counter = 0
|
| 334 |
+
|
| 335 |
+
LOGGER.info("")
|
| 336 |
+
LOGGER.info("Starting training...")
|
| 337 |
+
|
| 338 |
+
with mlflow.start_run(run_name=run_name):
|
| 339 |
+
# Log hyperparameters
|
| 340 |
+
mlflow.log_params(
|
| 341 |
+
{
|
| 342 |
+
"num_epochs": epochs,
|
| 343 |
+
"batch_size": batch_size,
|
| 344 |
+
"learning_rate": lr,
|
| 345 |
+
"weight_decay": 1e-5,
|
| 346 |
+
"optimizer": "AdamW",
|
| 347 |
+
"scheduler": "ReduceLROnPlateau",
|
| 348 |
+
"criterion": "precision_focused_loss",
|
| 349 |
+
"patience": patience,
|
| 350 |
+
"total_params": total_params,
|
| 351 |
+
"num_blocks": 4,
|
| 352 |
+
"num_heads": 4,
|
| 353 |
+
}
|
| 354 |
+
)
|
| 355 |
+
|
| 356 |
+
for epoch in range(epochs):
|
| 357 |
+
train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
|
| 358 |
+
val_loss = validate(model, val_loader, criterion, device)
|
| 359 |
+
|
| 360 |
+
scheduler.step(val_loss)
|
| 361 |
+
|
| 362 |
+
log_training_epoch(
|
| 363 |
+
epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"]
|
| 364 |
+
)
|
| 365 |
+
|
| 366 |
+
LOGGER.info(
|
| 367 |
+
"Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f, LR: %.6f",
|
| 368 |
+
epoch + 1,
|
| 369 |
+
epochs,
|
| 370 |
+
train_loss,
|
| 371 |
+
val_loss,
|
| 372 |
+
optimizer.param_groups[0]["lr"],
|
| 373 |
+
)
|
| 374 |
+
|
| 375 |
+
# Early stopping
|
| 376 |
+
if val_loss < best_val_loss:
|
| 377 |
+
best_val_loss = val_loss
|
| 378 |
+
patience_counter = 0
|
| 379 |
+
|
| 380 |
+
model_directory.mkdir(exist_ok=True)
|
| 381 |
+
checkpoint_file = model_directory / "mlp_attention_best.pth"
|
| 382 |
+
|
| 383 |
+
torch.save(
|
| 384 |
+
{
|
| 385 |
+
"model_state_dict": model.state_dict(),
|
| 386 |
+
"epoch": epoch,
|
| 387 |
+
"val_loss": val_loss,
|
| 388 |
+
},
|
| 389 |
+
checkpoint_file,
|
| 390 |
+
)
|
| 391 |
+
|
| 392 |
+
LOGGER.info(" → Saved best model (val_loss: %.6f)", val_loss)
|
| 393 |
+
else:
|
| 394 |
+
patience_counter += 1
|
| 395 |
+
if patience_counter >= patience:
|
| 396 |
+
LOGGER.info("")
|
| 397 |
+
LOGGER.info("Early stopping after %d epochs", epoch + 1)
|
| 398 |
+
break
|
| 399 |
+
|
| 400 |
+
# Log final metrics
|
| 401 |
+
mlflow.log_metrics(
|
| 402 |
+
{
|
| 403 |
+
"best_val_loss": best_val_loss,
|
| 404 |
+
"final_epoch": epoch + 1,
|
| 405 |
+
}
|
| 406 |
+
)
|
| 407 |
+
|
| 408 |
+
# Export to ONNX
|
| 409 |
+
LOGGER.info("")
|
| 410 |
+
LOGGER.info("Exporting to ONNX...")
|
| 411 |
+
model.eval()
|
| 412 |
+
|
| 413 |
+
checkpoint = torch.load(checkpoint_file)
|
| 414 |
+
model.load_state_dict(checkpoint["model_state_dict"])
|
| 415 |
+
|
| 416 |
+
dummy_input = torch.randn(1, 3).to(device)
|
| 417 |
+
|
| 418 |
+
onnx_file = model_directory / "mlp_attention.onnx"
|
| 419 |
+
torch.onnx.export(
|
| 420 |
+
model,
|
| 421 |
+
dummy_input,
|
| 422 |
+
onnx_file,
|
| 423 |
+
export_params=True,
|
| 424 |
+
opset_version=15,
|
| 425 |
+
input_names=["xyY"],
|
| 426 |
+
output_names=["munsell_spec"],
|
| 427 |
+
dynamic_axes={
|
| 428 |
+
"xyY": {0: "batch_size"},
|
| 429 |
+
"munsell_spec": {0: "batch_size"},
|
| 430 |
+
},
|
| 431 |
+
)
|
| 432 |
+
|
| 433 |
+
# Save normalization parameters alongside model
|
| 434 |
+
params_file = model_directory / "mlp_attention_normalization_params.npz"
|
| 435 |
+
input_params = XYY_NORMALIZATION_PARAMS
|
| 436 |
+
np.savez(
|
| 437 |
+
params_file,
|
| 438 |
+
input_params=input_params,
|
| 439 |
+
output_params=output_params,
|
| 440 |
+
)
|
| 441 |
+
|
| 442 |
+
LOGGER.info("ONNX model saved to: %s", onnx_file)
|
| 443 |
+
LOGGER.info("Normalization parameters saved to: %s", params_file)
|
| 444 |
+
|
| 445 |
+
# Log artifacts
|
| 446 |
+
mlflow.log_artifact(str(checkpoint_file))
|
| 447 |
+
mlflow.log_artifact(str(onnx_file))
|
| 448 |
+
mlflow.log_artifact(str(params_file))
|
| 449 |
+
|
| 450 |
+
# Log model
|
| 451 |
+
mlflow.pytorch.log_model(model, "model")
|
| 452 |
+
|
| 453 |
+
|
| 454 |
+
LOGGER.info("=" * 80)
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
if __name__ == "__main__":
|
| 458 |
+
logging.basicConfig(level=logging.INFO, format="%(message)s", force=True)
|
| 459 |
+
|
| 460 |
+
main()
|
learning_munsell/training/from_xyY/train_mlp_error_predictor.py
ADDED
|
@@ -0,0 +1,457 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Train error predictor with advanced MLP architecture.
|
| 3 |
+
|
| 4 |
+
Architecture features:
|
| 5 |
+
- Larger capacity: 7 → 256 → 512 → 512 → 256 → 4
|
| 6 |
+
- Residual connections (MLP-style) for better gradient flow
|
| 7 |
+
- Modern activation functions (GELU instead of ReLU)
|
| 8 |
+
- Precision-focused loss function
|
| 9 |
+
|
| 10 |
+
Generic error predictor that can work with any base model.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import logging
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
from typing import Any
|
| 16 |
+
|
| 17 |
+
import click
|
| 18 |
+
import mlflow
|
| 19 |
+
import mlflow.pytorch
|
| 20 |
+
import numpy as np
|
| 21 |
+
import onnxruntime as ort
|
| 22 |
+
import torch
|
| 23 |
+
from numpy.typing import NDArray
|
| 24 |
+
from torch import nn, optim
|
| 25 |
+
from torch.utils.data import DataLoader, TensorDataset
|
| 26 |
+
|
| 27 |
+
from learning_munsell import PROJECT_ROOT
|
| 28 |
+
from learning_munsell.models.networks import ResidualBlock
|
| 29 |
+
from learning_munsell.utilities.common import log_training_epoch, setup_mlflow_experiment
|
| 30 |
+
from learning_munsell.utilities.data import normalize_munsell, normalize_xyY
|
| 31 |
+
from learning_munsell.utilities.losses import precision_focused_loss
|
| 32 |
+
from learning_munsell.utilities.training import train_epoch, validate
|
| 33 |
+
|
| 34 |
+
LOGGER = logging.getLogger(__name__)
|
| 35 |
+
|
| 36 |
+
# Note: This script has a custom ErrorPredictorMLP architecture
|
| 37 |
+
# so we don't import ComponentErrorPredictor/MultiHeadErrorPredictor from shared modules.
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class ErrorPredictorMLP(nn.Module):
|
| 41 |
+
"""
|
| 42 |
+
Advanced error predictor with residual connections.
|
| 43 |
+
|
| 44 |
+
This model implements a two-stage architecture for Munsell color prediction:
|
| 45 |
+
1. Base model makes initial predictions from xyY coordinates
|
| 46 |
+
2. Error predictor learns residual corrections to improve base predictions
|
| 47 |
+
|
| 48 |
+
The error predictor uses MLP-style residual blocks for better gradient
|
| 49 |
+
flow and deeper representations. It takes both the input xyY coordinates
|
| 50 |
+
and the base model's predictions to predict the error that should be added
|
| 51 |
+
to the base predictions.
|
| 52 |
+
|
| 53 |
+
Architecture:
|
| 54 |
+
- Input: 7 features (xyY_norm + base_pred_norm)
|
| 55 |
+
- Encoder: 7 → 256 → 512
|
| 56 |
+
- Residual blocks at 512-dim
|
| 57 |
+
- Decoder: 512 → 256 → 128 → 4
|
| 58 |
+
- Uses GELU activations and residual connections
|
| 59 |
+
|
| 60 |
+
Parameters
|
| 61 |
+
----------
|
| 62 |
+
num_residual_blocks : int, optional
|
| 63 |
+
Number of residual blocks to use in the middle of the network.
|
| 64 |
+
Default is 3.
|
| 65 |
+
|
| 66 |
+
Attributes
|
| 67 |
+
----------
|
| 68 |
+
encoder : nn.Sequential
|
| 69 |
+
Encoder network that maps 7D input to 512D representation.
|
| 70 |
+
residual_blocks : nn.ModuleList
|
| 71 |
+
List of residual blocks for deep feature extraction.
|
| 72 |
+
decoder : nn.Sequential
|
| 73 |
+
Decoder network that maps 512D representation to 4D error prediction.
|
| 74 |
+
"""
|
| 75 |
+
|
| 76 |
+
def __init__(self, num_residual_blocks: int = 3) -> None:
|
| 77 |
+
super().__init__()
|
| 78 |
+
|
| 79 |
+
# Encoder
|
| 80 |
+
self.encoder = nn.Sequential(
|
| 81 |
+
nn.Linear(7, 256),
|
| 82 |
+
nn.GELU(),
|
| 83 |
+
nn.BatchNorm1d(256),
|
| 84 |
+
nn.Linear(256, 512),
|
| 85 |
+
nn.GELU(),
|
| 86 |
+
nn.BatchNorm1d(512),
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
# Residual blocks
|
| 90 |
+
self.residual_blocks = nn.ModuleList(
|
| 91 |
+
[ResidualBlock(512) for _ in range(num_residual_blocks)]
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
# Decoder
|
| 95 |
+
self.decoder = nn.Sequential(
|
| 96 |
+
nn.Linear(512, 256),
|
| 97 |
+
nn.GELU(),
|
| 98 |
+
nn.BatchNorm1d(256),
|
| 99 |
+
nn.Linear(256, 128),
|
| 100 |
+
nn.GELU(),
|
| 101 |
+
nn.BatchNorm1d(128),
|
| 102 |
+
nn.Linear(128, 4),
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 106 |
+
"""
|
| 107 |
+
Forward pass through the error predictor.
|
| 108 |
+
|
| 109 |
+
Parameters
|
| 110 |
+
----------
|
| 111 |
+
x : Tensor
|
| 112 |
+
Combined input [xyY_norm, base_pred_norm] of shape (batch_size, 7).
|
| 113 |
+
|
| 114 |
+
Returns
|
| 115 |
+
-------
|
| 116 |
+
Tensor
|
| 117 |
+
Predicted error correction of shape (batch_size, 4).
|
| 118 |
+
"""
|
| 119 |
+
# Encode
|
| 120 |
+
x = self.encoder(x)
|
| 121 |
+
|
| 122 |
+
# Residual blocks
|
| 123 |
+
for block in self.residual_blocks:
|
| 124 |
+
x = block(x)
|
| 125 |
+
|
| 126 |
+
# Decode
|
| 127 |
+
return self.decoder(x)
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def load_base_model(
|
| 131 |
+
model_path: Path, params_path: Path
|
| 132 |
+
) -> tuple[ort.InferenceSession, dict, dict]:
|
| 133 |
+
"""
|
| 134 |
+
Load the base ONNX model and its normalization parameters.
|
| 135 |
+
|
| 136 |
+
The base model is the first stage of the two-stage architecture that makes
|
| 137 |
+
initial predictions from xyY coordinates to Munsell specifications.
|
| 138 |
+
|
| 139 |
+
Parameters
|
| 140 |
+
----------
|
| 141 |
+
model_path : Path
|
| 142 |
+
Path to the ONNX model file.
|
| 143 |
+
params_path : Path
|
| 144 |
+
Path to the .npz file containing input and output normalization parameters.
|
| 145 |
+
|
| 146 |
+
Returns
|
| 147 |
+
-------
|
| 148 |
+
session : ort.InferenceSession
|
| 149 |
+
ONNX Runtime inference session for the base model.
|
| 150 |
+
input_params : dict
|
| 151 |
+
Dictionary containing input normalization ranges (x_range, y_range, Y_range).
|
| 152 |
+
output_params : dict
|
| 153 |
+
Dictionary containing output normalization ranges (hue_range, value_range,
|
| 154 |
+
chroma_range, code_range).
|
| 155 |
+
"""
|
| 156 |
+
session = ort.InferenceSession(str(model_path))
|
| 157 |
+
params = np.load(params_path, allow_pickle=True)
|
| 158 |
+
return session, params["input_params"].item(), params["output_params"].item()
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
@click.command()
|
| 162 |
+
@click.option(
|
| 163 |
+
"--base-model",
|
| 164 |
+
type=click.Path(exists=True, path_type=Path),
|
| 165 |
+
help="Path to base model ONNX file",
|
| 166 |
+
)
|
| 167 |
+
@click.option(
|
| 168 |
+
"--params",
|
| 169 |
+
type=click.Path(exists=True, path_type=Path),
|
| 170 |
+
help="Path to normalization params file",
|
| 171 |
+
)
|
| 172 |
+
@click.option(
|
| 173 |
+
"--epochs",
|
| 174 |
+
type=int,
|
| 175 |
+
default=200,
|
| 176 |
+
help="Number of training epochs",
|
| 177 |
+
)
|
| 178 |
+
@click.option(
|
| 179 |
+
"--batch-size",
|
| 180 |
+
type=int,
|
| 181 |
+
default=1024,
|
| 182 |
+
help="Batch size for training",
|
| 183 |
+
)
|
| 184 |
+
@click.option(
|
| 185 |
+
"--lr",
|
| 186 |
+
type=float,
|
| 187 |
+
default=3e-4,
|
| 188 |
+
help="Learning rate",
|
| 189 |
+
)
|
| 190 |
+
@click.option(
|
| 191 |
+
"--patience",
|
| 192 |
+
type=int,
|
| 193 |
+
default=20,
|
| 194 |
+
help="Patience for early stopping",
|
| 195 |
+
)
|
| 196 |
+
def main(
|
| 197 |
+
base_model: Path | None,
|
| 198 |
+
params: Path | None,
|
| 199 |
+
epochs: int,
|
| 200 |
+
batch_size: int,
|
| 201 |
+
lr: float,
|
| 202 |
+
patience: int,
|
| 203 |
+
) -> None:
|
| 204 |
+
"""
|
| 205 |
+
Train error predictor with advanced MLP architecture.
|
| 206 |
+
|
| 207 |
+
Parameters
|
| 208 |
+
----------
|
| 209 |
+
base_model : Path or None
|
| 210 |
+
Path to the base model ONNX file. If None, uses default path.
|
| 211 |
+
params : Path or None
|
| 212 |
+
Path to normalization parameters .npz file. If None, uses default path.
|
| 213 |
+
|
| 214 |
+
Notes
|
| 215 |
+
-----
|
| 216 |
+
The training pipeline:
|
| 217 |
+
1. Loads pre-trained base model
|
| 218 |
+
2. Generates base model predictions for training data
|
| 219 |
+
3. Computes residual errors between predictions and targets
|
| 220 |
+
4. Trains error predictor on these residuals
|
| 221 |
+
5. Uses precision-focused loss function
|
| 222 |
+
6. Learning rate scheduling with ReduceLROnPlateau
|
| 223 |
+
7. Early stopping based on validation loss
|
| 224 |
+
8. Exports model to ONNX format
|
| 225 |
+
9. Logs metrics and artifacts to MLflow
|
| 226 |
+
"""
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
LOGGER.info("=" * 80)
|
| 230 |
+
LOGGER.info("Error Predictor: MLP + GELU + Precision Loss")
|
| 231 |
+
LOGGER.info("=" * 80)
|
| 232 |
+
|
| 233 |
+
# Set device
|
| 234 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 235 |
+
LOGGER.info("Using device: %s", device)
|
| 236 |
+
|
| 237 |
+
# Paths
|
| 238 |
+
model_directory = PROJECT_ROOT / "models" / "from_xyY"
|
| 239 |
+
data_dir = PROJECT_ROOT / "data"
|
| 240 |
+
|
| 241 |
+
base_model_path = base_model
|
| 242 |
+
params_path = params
|
| 243 |
+
cache_file = data_dir / "training_data.npz"
|
| 244 |
+
|
| 245 |
+
# Extract base model name for error predictor naming
|
| 246 |
+
base_model_name = (
|
| 247 |
+
base_model_path.stem if base_model_path else "xyY_to_munsell_specification"
|
| 248 |
+
)
|
| 249 |
+
|
| 250 |
+
# Load base model
|
| 251 |
+
LOGGER.info("")
|
| 252 |
+
LOGGER.info("Loading base model from %s...", base_model_path)
|
| 253 |
+
base_session, input_params, output_params = load_base_model(
|
| 254 |
+
base_model_path, params_path
|
| 255 |
+
)
|
| 256 |
+
|
| 257 |
+
# Load training data
|
| 258 |
+
LOGGER.info("Loading training data from %s...", cache_file)
|
| 259 |
+
data = np.load(cache_file)
|
| 260 |
+
X_train = data["X_train"]
|
| 261 |
+
y_train = data["y_train"]
|
| 262 |
+
X_val = data["X_val"]
|
| 263 |
+
y_val = data["y_val"]
|
| 264 |
+
|
| 265 |
+
LOGGER.info("Train samples: %d", len(X_train))
|
| 266 |
+
LOGGER.info("Validation samples: %d", len(X_val))
|
| 267 |
+
|
| 268 |
+
# Generate base model predictions
|
| 269 |
+
LOGGER.info("")
|
| 270 |
+
LOGGER.info("Generating base model predictions...")
|
| 271 |
+
X_train_norm = normalize_xyY(X_train, input_params)
|
| 272 |
+
y_train_norm = normalize_munsell(y_train, output_params)
|
| 273 |
+
|
| 274 |
+
# Base predictions (normalized)
|
| 275 |
+
base_pred_train_norm = base_session.run(None, {"xyY": X_train_norm})[0]
|
| 276 |
+
|
| 277 |
+
X_val_norm = normalize_xyY(X_val, input_params)
|
| 278 |
+
y_val_norm = normalize_munsell(y_val, output_params)
|
| 279 |
+
base_pred_val_norm = base_session.run(None, {"xyY": X_val_norm})[0]
|
| 280 |
+
|
| 281 |
+
# Compute errors (in normalized space)
|
| 282 |
+
error_train = y_train_norm - base_pred_train_norm
|
| 283 |
+
error_val = y_val_norm - base_pred_val_norm
|
| 284 |
+
|
| 285 |
+
# Statistics
|
| 286 |
+
LOGGER.info("")
|
| 287 |
+
LOGGER.info("Base model error statistics (normalized space):")
|
| 288 |
+
LOGGER.info(" Mean absolute error: %.6f", np.mean(np.abs(error_train)))
|
| 289 |
+
LOGGER.info(" Std of error: %.6f", np.std(error_train))
|
| 290 |
+
LOGGER.info(" Max absolute error: %.6f", np.max(np.abs(error_train)))
|
| 291 |
+
|
| 292 |
+
# Create combined input: [xyY_norm, base_prediction_norm]
|
| 293 |
+
X_train_combined = np.concatenate([X_train_norm, base_pred_train_norm], axis=1)
|
| 294 |
+
X_val_combined = np.concatenate([X_val_norm, base_pred_val_norm], axis=1)
|
| 295 |
+
|
| 296 |
+
# Convert to PyTorch tensors
|
| 297 |
+
X_train_t = torch.FloatTensor(X_train_combined)
|
| 298 |
+
error_train_t = torch.FloatTensor(error_train)
|
| 299 |
+
X_val_t = torch.FloatTensor(X_val_combined)
|
| 300 |
+
error_val_t = torch.FloatTensor(error_val)
|
| 301 |
+
|
| 302 |
+
# Create data loaders
|
| 303 |
+
train_dataset = TensorDataset(X_train_t, error_train_t)
|
| 304 |
+
val_dataset = TensorDataset(X_val_t, error_val_t)
|
| 305 |
+
|
| 306 |
+
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
|
| 307 |
+
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
|
| 308 |
+
|
| 309 |
+
# Initialize error predictor model with MLP architecture
|
| 310 |
+
model = ErrorPredictorMLP(num_residual_blocks=3).to(device)
|
| 311 |
+
LOGGER.info("")
|
| 312 |
+
LOGGER.info("Error predictor architecture:")
|
| 313 |
+
LOGGER.info("%s", model)
|
| 314 |
+
|
| 315 |
+
# Count parameters
|
| 316 |
+
total_params = sum(p.numel() for p in model.parameters())
|
| 317 |
+
LOGGER.info("Total parameters: %s", f"{total_params:,}")
|
| 318 |
+
|
| 319 |
+
# Training setup with precision-focused loss
|
| 320 |
+
LOGGER.info("")
|
| 321 |
+
LOGGER.info("Using precision-focused loss function:")
|
| 322 |
+
LOGGER.info(" - MSE (weight: 1.0)")
|
| 323 |
+
LOGGER.info(" - MAE (weight: 0.5)")
|
| 324 |
+
LOGGER.info(" - Log penalty for small errors (weight: 0.3)")
|
| 325 |
+
LOGGER.info(" - Huber loss with delta=0.01 (weight: 0.5)")
|
| 326 |
+
|
| 327 |
+
optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
|
| 328 |
+
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
|
| 329 |
+
optimizer, mode="min", factor=0.5, patience=5
|
| 330 |
+
)
|
| 331 |
+
criterion = precision_focused_loss
|
| 332 |
+
|
| 333 |
+
# MLflow setup
|
| 334 |
+
model_name = f"{base_model_name}_error_predictor"
|
| 335 |
+
run_name = setup_mlflow_experiment("from_xyY", model_name)
|
| 336 |
+
|
| 337 |
+
LOGGER.info("")
|
| 338 |
+
LOGGER.info("MLflow run: %s", run_name)
|
| 339 |
+
|
| 340 |
+
# Training loop
|
| 341 |
+
best_val_loss = float("inf")
|
| 342 |
+
patience_counter = 0
|
| 343 |
+
|
| 344 |
+
LOGGER.info("")
|
| 345 |
+
LOGGER.info("Starting training...")
|
| 346 |
+
|
| 347 |
+
with mlflow.start_run(run_name=run_name):
|
| 348 |
+
mlflow.log_params(
|
| 349 |
+
{
|
| 350 |
+
"model": model_name,
|
| 351 |
+
"base_model": base_model_name,
|
| 352 |
+
"learning_rate": lr,
|
| 353 |
+
"batch_size": batch_size,
|
| 354 |
+
"num_epochs": epochs,
|
| 355 |
+
"patience": patience,
|
| 356 |
+
"total_params": total_params,
|
| 357 |
+
}
|
| 358 |
+
)
|
| 359 |
+
|
| 360 |
+
for epoch in range(epochs):
|
| 361 |
+
train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
|
| 362 |
+
val_loss = validate(model, val_loader, criterion, device)
|
| 363 |
+
|
| 364 |
+
# Update learning rate
|
| 365 |
+
scheduler.step(val_loss)
|
| 366 |
+
|
| 367 |
+
log_training_epoch(
|
| 368 |
+
epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"]
|
| 369 |
+
)
|
| 370 |
+
|
| 371 |
+
LOGGER.info(
|
| 372 |
+
"Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f, LR: %.6f",
|
| 373 |
+
epoch + 1,
|
| 374 |
+
epochs,
|
| 375 |
+
train_loss,
|
| 376 |
+
val_loss,
|
| 377 |
+
optimizer.param_groups[0]["lr"],
|
| 378 |
+
)
|
| 379 |
+
|
| 380 |
+
# Early stopping
|
| 381 |
+
if val_loss < best_val_loss:
|
| 382 |
+
best_val_loss = val_loss
|
| 383 |
+
patience_counter = 0
|
| 384 |
+
|
| 385 |
+
# Save best model
|
| 386 |
+
model_directory.mkdir(exist_ok=True)
|
| 387 |
+
checkpoint_file = (
|
| 388 |
+
model_directory / f"{base_model_name}_error_predictor_best.pth"
|
| 389 |
+
)
|
| 390 |
+
|
| 391 |
+
torch.save(
|
| 392 |
+
{
|
| 393 |
+
"model_state_dict": model.state_dict(),
|
| 394 |
+
"epoch": epoch,
|
| 395 |
+
"val_loss": val_loss,
|
| 396 |
+
},
|
| 397 |
+
checkpoint_file,
|
| 398 |
+
)
|
| 399 |
+
|
| 400 |
+
LOGGER.info(" → Saved best model (val_loss: %.6f)", val_loss)
|
| 401 |
+
else:
|
| 402 |
+
patience_counter += 1
|
| 403 |
+
if patience_counter >= patience:
|
| 404 |
+
LOGGER.info("")
|
| 405 |
+
LOGGER.info("Early stopping after %d epochs", epoch + 1)
|
| 406 |
+
break
|
| 407 |
+
|
| 408 |
+
mlflow.log_metrics(
|
| 409 |
+
{
|
| 410 |
+
"best_val_loss": best_val_loss,
|
| 411 |
+
"final_epoch": epoch + 1,
|
| 412 |
+
}
|
| 413 |
+
)
|
| 414 |
+
|
| 415 |
+
# Export to ONNX
|
| 416 |
+
LOGGER.info("")
|
| 417 |
+
LOGGER.info("Exporting error predictor to ONNX...")
|
| 418 |
+
model.eval()
|
| 419 |
+
|
| 420 |
+
# Load best model
|
| 421 |
+
checkpoint = torch.load(checkpoint_file)
|
| 422 |
+
model.load_state_dict(checkpoint["model_state_dict"])
|
| 423 |
+
|
| 424 |
+
# Create dummy input (xyY_norm + base_pred_norm = 7 inputs)
|
| 425 |
+
dummy_input = torch.randn(1, 7).to(device)
|
| 426 |
+
|
| 427 |
+
# Export
|
| 428 |
+
onnx_file = model_directory / f"{base_model_name}_error_predictor.onnx"
|
| 429 |
+
torch.onnx.export(
|
| 430 |
+
model,
|
| 431 |
+
dummy_input,
|
| 432 |
+
onnx_file,
|
| 433 |
+
export_params=True,
|
| 434 |
+
opset_version=15,
|
| 435 |
+
input_names=["combined_input"],
|
| 436 |
+
output_names=["error_correction"],
|
| 437 |
+
dynamic_axes={
|
| 438 |
+
"combined_input": {0: "batch_size"},
|
| 439 |
+
"error_correction": {0: "batch_size"},
|
| 440 |
+
},
|
| 441 |
+
)
|
| 442 |
+
|
| 443 |
+
mlflow.log_artifact(str(checkpoint_file))
|
| 444 |
+
mlflow.log_artifact(str(onnx_file))
|
| 445 |
+
mlflow.pytorch.log_model(model, "model")
|
| 446 |
+
|
| 447 |
+
LOGGER.info("Error predictor ONNX model saved to: %s", onnx_file)
|
| 448 |
+
LOGGER.info("Artifacts logged to MLflow")
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
LOGGER.info("=" * 80)
|
| 452 |
+
|
| 453 |
+
|
| 454 |
+
if __name__ == "__main__":
|
| 455 |
+
logging.basicConfig(level=logging.INFO, format="%(message)s", force=True)
|
| 456 |
+
|
| 457 |
+
main()
|
learning_munsell/training/from_xyY/train_mlp_gamma.py
ADDED
|
@@ -0,0 +1,297 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Train ML model for xyY to Munsell conversion with gamma-corrected Y.
|
| 3 |
+
|
| 4 |
+
Experiment: Apply gamma 2.33 to Y before normalization to better align
|
| 5 |
+
with perceptual lightness (Munsell Value scale is perceptually uniform).
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import logging
|
| 9 |
+
from typing import Any
|
| 10 |
+
|
| 11 |
+
import click
|
| 12 |
+
import mlflow
|
| 13 |
+
import mlflow.pytorch
|
| 14 |
+
import numpy as np
|
| 15 |
+
import torch
|
| 16 |
+
from numpy.typing import NDArray
|
| 17 |
+
from torch import optim
|
| 18 |
+
from torch.utils.data import DataLoader, TensorDataset
|
| 19 |
+
|
| 20 |
+
from learning_munsell import PROJECT_ROOT
|
| 21 |
+
from learning_munsell.models.networks import MLPToMunsell
|
| 22 |
+
from learning_munsell.utilities.common import log_training_epoch, setup_mlflow_experiment
|
| 23 |
+
from learning_munsell.utilities.data import (
|
| 24 |
+
MUNSELL_NORMALIZATION_PARAMS,
|
| 25 |
+
normalize_munsell,
|
| 26 |
+
)
|
| 27 |
+
from learning_munsell.utilities.losses import weighted_mse_loss
|
| 28 |
+
from learning_munsell.utilities.training import train_epoch, validate
|
| 29 |
+
|
| 30 |
+
LOGGER = logging.getLogger(__name__)
|
| 31 |
+
|
| 32 |
+
# Gamma value for Y transformation
|
| 33 |
+
GAMMA = 2.33
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def normalize_inputs(
|
| 37 |
+
X: NDArray, gamma: float = GAMMA
|
| 38 |
+
) -> tuple[NDArray, dict[str, Any]]:
|
| 39 |
+
"""
|
| 40 |
+
Normalize xyY inputs to [0, 1] range with gamma correction on Y.
|
| 41 |
+
|
| 42 |
+
Parameters
|
| 43 |
+
----------
|
| 44 |
+
X : ndarray
|
| 45 |
+
xyY values of shape (n, 3) where columns are [x, y, Y].
|
| 46 |
+
gamma : float
|
| 47 |
+
Gamma value to apply to Y component.
|
| 48 |
+
|
| 49 |
+
Returns
|
| 50 |
+
-------
|
| 51 |
+
ndarray
|
| 52 |
+
Normalized values with gamma-corrected Y, dtype float32.
|
| 53 |
+
dict
|
| 54 |
+
Normalization parameters including gamma value.
|
| 55 |
+
"""
|
| 56 |
+
# Typical ranges for xyY
|
| 57 |
+
x_range = (0.0, 1.0)
|
| 58 |
+
y_range = (0.0, 1.0)
|
| 59 |
+
Y_range = (0.0, 1.0)
|
| 60 |
+
|
| 61 |
+
X_norm = X.copy()
|
| 62 |
+
X_norm[:, 0] = (X[:, 0] - x_range[0]) / (x_range[1] - x_range[0])
|
| 63 |
+
X_norm[:, 1] = (X[:, 1] - y_range[0]) / (y_range[1] - y_range[0])
|
| 64 |
+
|
| 65 |
+
# Normalize Y first, then apply gamma
|
| 66 |
+
Y_normalized = (X[:, 2] - Y_range[0]) / (Y_range[1] - Y_range[0])
|
| 67 |
+
# Clip to avoid numerical issues with negative values
|
| 68 |
+
Y_normalized = np.clip(Y_normalized, 0, 1)
|
| 69 |
+
# Apply gamma: Y_gamma = Y^(1/gamma) - this spreads dark values, compresses light
|
| 70 |
+
X_norm[:, 2] = np.power(Y_normalized, 1.0 / gamma)
|
| 71 |
+
|
| 72 |
+
params = {
|
| 73 |
+
"x_range": x_range,
|
| 74 |
+
"y_range": y_range,
|
| 75 |
+
"Y_range": Y_range,
|
| 76 |
+
"gamma": gamma,
|
| 77 |
+
}
|
| 78 |
+
|
| 79 |
+
return X_norm, params
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
@click.command()
|
| 83 |
+
@click.option("--epochs", default=200, help="Number of training epochs")
|
| 84 |
+
@click.option("--batch-size", default=1024, help="Batch size for training")
|
| 85 |
+
@click.option("--lr", default=5e-4, help="Learning rate")
|
| 86 |
+
@click.option("--patience", default=20, help="Early stopping patience")
|
| 87 |
+
def main(epochs: int, batch_size: int, lr: float, patience: int) -> None:
|
| 88 |
+
"""
|
| 89 |
+
Train MLP model with gamma-corrected Y input.
|
| 90 |
+
|
| 91 |
+
Notes
|
| 92 |
+
-----
|
| 93 |
+
The training pipeline:
|
| 94 |
+
1. Loads training and validation data from cache
|
| 95 |
+
2. Normalizes inputs with gamma correction (gamma=2.33) on Y
|
| 96 |
+
3. Normalizes Munsell outputs to [0, 1] range
|
| 97 |
+
4. Trains MLP with weighted MSE loss
|
| 98 |
+
5. Uses early stopping based on validation loss
|
| 99 |
+
6. Exports best model to ONNX format
|
| 100 |
+
7. Logs metrics and artifacts to MLflow
|
| 101 |
+
|
| 102 |
+
The gamma correction on Y aligns with perceptual lightness. The gamma
|
| 103 |
+
transformation spreads dark values and compresses light values, matching
|
| 104 |
+
human lightness perception and the perceptually uniform Munsell Value scale.
|
| 105 |
+
"""
|
| 106 |
+
|
| 107 |
+
LOGGER.info("=" * 80)
|
| 108 |
+
LOGGER.info("ML-Based xyY to Munsell Conversion: Gamma Experiment")
|
| 109 |
+
LOGGER.info("Gamma = %.2f applied to Y component", GAMMA)
|
| 110 |
+
LOGGER.info("=" * 80)
|
| 111 |
+
|
| 112 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 113 |
+
LOGGER.info("Using device: %s", device)
|
| 114 |
+
|
| 115 |
+
# Load training data
|
| 116 |
+
data_dir = PROJECT_ROOT / "data"
|
| 117 |
+
cache_file = data_dir / "training_data.npz"
|
| 118 |
+
|
| 119 |
+
if not cache_file.exists():
|
| 120 |
+
LOGGER.error("Error: Training data not found at %s", cache_file)
|
| 121 |
+
LOGGER.error("Please run 01_generate_training_data.py first")
|
| 122 |
+
return
|
| 123 |
+
|
| 124 |
+
LOGGER.info("Loading training data from %s...", cache_file)
|
| 125 |
+
data = np.load(cache_file)
|
| 126 |
+
|
| 127 |
+
X_train = data["X_train"]
|
| 128 |
+
y_train = data["y_train"]
|
| 129 |
+
X_val = data["X_val"]
|
| 130 |
+
y_val = data["y_val"]
|
| 131 |
+
|
| 132 |
+
LOGGER.info("Train samples: %d", len(X_train))
|
| 133 |
+
LOGGER.info("Validation samples: %d", len(X_val))
|
| 134 |
+
|
| 135 |
+
# Normalize data with gamma correction
|
| 136 |
+
X_train_norm, input_params = normalize_inputs(X_train, gamma=GAMMA)
|
| 137 |
+
X_val_norm, _ = normalize_inputs(X_val, gamma=GAMMA)
|
| 138 |
+
|
| 139 |
+
output_params = MUNSELL_NORMALIZATION_PARAMS
|
| 140 |
+
y_train_norm = normalize_munsell(y_train, output_params)
|
| 141 |
+
y_val_norm = normalize_munsell(y_val, output_params)
|
| 142 |
+
|
| 143 |
+
LOGGER.info("")
|
| 144 |
+
LOGGER.info("Input normalization with gamma=%.2f:", GAMMA)
|
| 145 |
+
LOGGER.info(" Y range after gamma: [%.4f, %.4f]", X_train_norm[:, 2].min(), X_train_norm[:, 2].max())
|
| 146 |
+
|
| 147 |
+
# Convert to PyTorch tensors
|
| 148 |
+
X_train_t = torch.FloatTensor(X_train_norm)
|
| 149 |
+
y_train_t = torch.FloatTensor(y_train_norm)
|
| 150 |
+
X_val_t = torch.FloatTensor(X_val_norm)
|
| 151 |
+
y_val_t = torch.FloatTensor(y_val_norm)
|
| 152 |
+
|
| 153 |
+
# Create data loaders
|
| 154 |
+
train_dataset = TensorDataset(X_train_t, y_train_t)
|
| 155 |
+
val_dataset = TensorDataset(X_val_t, y_val_t)
|
| 156 |
+
|
| 157 |
+
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
|
| 158 |
+
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
|
| 159 |
+
|
| 160 |
+
# Initialize model
|
| 161 |
+
model = MLPToMunsell().to(device)
|
| 162 |
+
LOGGER.info("")
|
| 163 |
+
LOGGER.info("Model architecture:")
|
| 164 |
+
LOGGER.info("%s", model)
|
| 165 |
+
|
| 166 |
+
total_params = sum(p.numel() for p in model.parameters())
|
| 167 |
+
LOGGER.info("Total parameters: %s", f"{total_params:,}")
|
| 168 |
+
|
| 169 |
+
# Training setup
|
| 170 |
+
optimizer = optim.Adam(model.parameters(), lr=lr)
|
| 171 |
+
# Component weights: emphasize chroma (2.0), de-emphasize code (0.5)
|
| 172 |
+
weights = torch.tensor([1.0, 1.0, 2.0, 0.5])
|
| 173 |
+
criterion = lambda pred, target: weighted_mse_loss(pred, target, weights)
|
| 174 |
+
|
| 175 |
+
# MLflow setup
|
| 176 |
+
run_name = setup_mlflow_experiment("from_xyY", f"mlp_gamma_{GAMMA}")
|
| 177 |
+
|
| 178 |
+
LOGGER.info("")
|
| 179 |
+
LOGGER.info("MLflow run: %s", run_name)
|
| 180 |
+
|
| 181 |
+
# Training loop
|
| 182 |
+
best_val_loss = float("inf")
|
| 183 |
+
patience_counter = 0
|
| 184 |
+
|
| 185 |
+
LOGGER.info("")
|
| 186 |
+
LOGGER.info("Starting training...")
|
| 187 |
+
|
| 188 |
+
with mlflow.start_run(run_name=run_name):
|
| 189 |
+
mlflow.log_params(
|
| 190 |
+
{
|
| 191 |
+
"num_epochs": epochs,
|
| 192 |
+
"batch_size": batch_size,
|
| 193 |
+
"learning_rate": lr,
|
| 194 |
+
"optimizer": "Adam",
|
| 195 |
+
"criterion": "weighted_mse_loss",
|
| 196 |
+
"patience": patience,
|
| 197 |
+
"total_params": total_params,
|
| 198 |
+
"gamma": GAMMA,
|
| 199 |
+
}
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
for epoch in range(epochs):
|
| 203 |
+
train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
|
| 204 |
+
val_loss = validate(model, val_loader, criterion, device)
|
| 205 |
+
|
| 206 |
+
log_training_epoch(
|
| 207 |
+
epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"]
|
| 208 |
+
)
|
| 209 |
+
|
| 210 |
+
LOGGER.info(
|
| 211 |
+
"Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f",
|
| 212 |
+
epoch + 1,
|
| 213 |
+
epochs,
|
| 214 |
+
train_loss,
|
| 215 |
+
val_loss,
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
if val_loss < best_val_loss:
|
| 219 |
+
best_val_loss = val_loss
|
| 220 |
+
patience_counter = 0
|
| 221 |
+
|
| 222 |
+
model_directory = PROJECT_ROOT / "models" / "from_xyY"
|
| 223 |
+
model_directory.mkdir(exist_ok=True)
|
| 224 |
+
checkpoint_file = model_directory / "mlp_gamma_best.pth"
|
| 225 |
+
|
| 226 |
+
torch.save(
|
| 227 |
+
{
|
| 228 |
+
"model_state_dict": model.state_dict(),
|
| 229 |
+
"input_params": input_params,
|
| 230 |
+
"output_params": output_params,
|
| 231 |
+
"epoch": epoch,
|
| 232 |
+
"val_loss": val_loss,
|
| 233 |
+
},
|
| 234 |
+
checkpoint_file,
|
| 235 |
+
)
|
| 236 |
+
|
| 237 |
+
LOGGER.info(" → Saved best model (val_loss: %.6f)", val_loss)
|
| 238 |
+
else:
|
| 239 |
+
patience_counter += 1
|
| 240 |
+
if patience_counter >= patience:
|
| 241 |
+
LOGGER.info("")
|
| 242 |
+
LOGGER.info("Early stopping after %d epochs", epoch + 1)
|
| 243 |
+
break
|
| 244 |
+
|
| 245 |
+
mlflow.log_metrics(
|
| 246 |
+
{
|
| 247 |
+
"best_val_loss": best_val_loss,
|
| 248 |
+
"final_epoch": epoch + 1,
|
| 249 |
+
}
|
| 250 |
+
)
|
| 251 |
+
|
| 252 |
+
# Export to ONNX
|
| 253 |
+
LOGGER.info("")
|
| 254 |
+
LOGGER.info("Exporting model to ONNX...")
|
| 255 |
+
model.eval()
|
| 256 |
+
|
| 257 |
+
checkpoint = torch.load(checkpoint_file)
|
| 258 |
+
model.load_state_dict(checkpoint["model_state_dict"])
|
| 259 |
+
|
| 260 |
+
dummy_input = torch.randn(1, 3).to(device)
|
| 261 |
+
|
| 262 |
+
onnx_file = model_directory / "mlp_gamma.onnx"
|
| 263 |
+
torch.onnx.export(
|
| 264 |
+
model,
|
| 265 |
+
dummy_input,
|
| 266 |
+
onnx_file,
|
| 267 |
+
export_params=True,
|
| 268 |
+
opset_version=15,
|
| 269 |
+
input_names=["xyY_gamma"],
|
| 270 |
+
output_names=["munsell_spec"],
|
| 271 |
+
dynamic_axes={"xyY_gamma": {0: "batch_size"}, "munsell_spec": {0: "batch_size"}},
|
| 272 |
+
)
|
| 273 |
+
|
| 274 |
+
# Save normalization parameters (including gamma)
|
| 275 |
+
params_file = model_directory / "mlp_gamma_normalization_params.npz"
|
| 276 |
+
np.savez(
|
| 277 |
+
params_file,
|
| 278 |
+
input_params=input_params,
|
| 279 |
+
output_params=output_params,
|
| 280 |
+
)
|
| 281 |
+
|
| 282 |
+
LOGGER.info("ONNX model saved to: %s", onnx_file)
|
| 283 |
+
LOGGER.info("Normalization parameters saved to: %s", params_file)
|
| 284 |
+
LOGGER.info("IMPORTANT: Input Y must be gamma-corrected with gamma=%.2f", GAMMA)
|
| 285 |
+
|
| 286 |
+
mlflow.log_artifact(str(checkpoint_file))
|
| 287 |
+
mlflow.log_artifact(str(onnx_file))
|
| 288 |
+
mlflow.log_artifact(str(params_file))
|
| 289 |
+
mlflow.pytorch.log_model(model, "model")
|
| 290 |
+
|
| 291 |
+
LOGGER.info("=" * 80)
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
if __name__ == "__main__":
|
| 295 |
+
logging.basicConfig(level=logging.INFO, format="%(message)s", force=True)
|
| 296 |
+
|
| 297 |
+
main()
|
learning_munsell/training/from_xyY/train_multi_head_3stage_error_predictor.py
ADDED
|
@@ -0,0 +1,411 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Train second-stage error predictor for 3-stage model.
|
| 3 |
+
|
| 4 |
+
Architecture: Multi-Head + Multi-Error Predictor + Multi-Error Predictor
|
| 5 |
+
- Stage 1: Multi-Head base model (existing)
|
| 6 |
+
- Stage 2: First error predictor (existing)
|
| 7 |
+
- Stage 3: Second error predictor (this script) - learns residuals from stage 2
|
| 8 |
+
|
| 9 |
+
The second error predictor has the same architecture as the first but learns
|
| 10 |
+
the remaining errors after the first error correction is applied.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import logging
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
from typing import Any
|
| 16 |
+
|
| 17 |
+
import click
|
| 18 |
+
import mlflow
|
| 19 |
+
import mlflow.pytorch
|
| 20 |
+
import numpy as np
|
| 21 |
+
import onnxruntime as ort
|
| 22 |
+
import torch
|
| 23 |
+
from numpy.typing import NDArray
|
| 24 |
+
from torch import nn, optim
|
| 25 |
+
from torch.utils.data import DataLoader, TensorDataset
|
| 26 |
+
|
| 27 |
+
from learning_munsell import PROJECT_ROOT
|
| 28 |
+
from learning_munsell.models.networks import (
|
| 29 |
+
ComponentErrorPredictor,
|
| 30 |
+
MultiHeadErrorPredictorToMunsell,
|
| 31 |
+
)
|
| 32 |
+
from learning_munsell.utilities.common import log_training_epoch, setup_mlflow_experiment
|
| 33 |
+
from learning_munsell.utilities.data import normalize_munsell, normalize_xyY
|
| 34 |
+
from learning_munsell.utilities.losses import precision_focused_loss
|
| 35 |
+
from learning_munsell.utilities.training import train_epoch, validate
|
| 36 |
+
|
| 37 |
+
LOGGER = logging.getLogger(__name__)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
@click.command()
|
| 41 |
+
@click.option(
|
| 42 |
+
"--base-model",
|
| 43 |
+
type=click.Path(exists=True, path_type=Path),
|
| 44 |
+
default=None,
|
| 45 |
+
help="Path to Multi-Head base model ONNX file",
|
| 46 |
+
)
|
| 47 |
+
@click.option(
|
| 48 |
+
"--first-error-predictor",
|
| 49 |
+
type=click.Path(exists=True, path_type=Path),
|
| 50 |
+
default=None,
|
| 51 |
+
help="Path to first error predictor ONNX file",
|
| 52 |
+
)
|
| 53 |
+
@click.option(
|
| 54 |
+
"--params",
|
| 55 |
+
type=click.Path(exists=True, path_type=Path),
|
| 56 |
+
default=None,
|
| 57 |
+
help="Path to normalization params file",
|
| 58 |
+
)
|
| 59 |
+
@click.option(
|
| 60 |
+
"--epochs",
|
| 61 |
+
type=int,
|
| 62 |
+
default=300,
|
| 63 |
+
help="Number of training epochs (default: 300)",
|
| 64 |
+
)
|
| 65 |
+
@click.option(
|
| 66 |
+
"--batch-size",
|
| 67 |
+
type=int,
|
| 68 |
+
default=2048,
|
| 69 |
+
help="Batch size for training (default: 2048)",
|
| 70 |
+
)
|
| 71 |
+
@click.option(
|
| 72 |
+
"--lr",
|
| 73 |
+
type=float,
|
| 74 |
+
default=3e-4,
|
| 75 |
+
help="Learning rate (default: 3e-4)",
|
| 76 |
+
)
|
| 77 |
+
@click.option(
|
| 78 |
+
"--patience",
|
| 79 |
+
type=int,
|
| 80 |
+
default=30,
|
| 81 |
+
help="Early stopping patience (default: 30)",
|
| 82 |
+
)
|
| 83 |
+
def main(
|
| 84 |
+
base_model: Path | None,
|
| 85 |
+
first_error_predictor: Path | None,
|
| 86 |
+
params: Path | None,
|
| 87 |
+
epochs: int,
|
| 88 |
+
batch_size: int,
|
| 89 |
+
lr: float,
|
| 90 |
+
patience: int,
|
| 91 |
+
) -> None:
|
| 92 |
+
"""
|
| 93 |
+
Train the second-stage error predictor for the 3-stage model.
|
| 94 |
+
|
| 95 |
+
This script trains the third stage of a 3-stage model:
|
| 96 |
+
- Stage 1: Multi-Head base model (pre-trained)
|
| 97 |
+
- Stage 2: First error predictor (pre-trained)
|
| 98 |
+
- Stage 3: Second error predictor (trained by this script)
|
| 99 |
+
|
| 100 |
+
The second error predictor learns the residual errors remaining after
|
| 101 |
+
the first error correction is applied, further refining the predictions.
|
| 102 |
+
|
| 103 |
+
Parameters
|
| 104 |
+
----------
|
| 105 |
+
base_model : Path, optional
|
| 106 |
+
Path to the Multi-Head base model ONNX file.
|
| 107 |
+
Default: models/from_xyY/multi_head_large.onnx
|
| 108 |
+
first_error_predictor : Path, optional
|
| 109 |
+
Path to the first error predictor ONNX file.
|
| 110 |
+
Default: models/from_xyY/multi_head_multi_error_predictor_large.onnx
|
| 111 |
+
params : Path, optional
|
| 112 |
+
Path to the normalization parameters file.
|
| 113 |
+
Default: models/from_xyY/multi_head_large_normalization_params.npz
|
| 114 |
+
|
| 115 |
+
Notes
|
| 116 |
+
-----
|
| 117 |
+
The training pipeline:
|
| 118 |
+
1. Loads pre-trained Stage 1 and Stage 2 models
|
| 119 |
+
2. Generates Stage 2 predictions (base + first error correction)
|
| 120 |
+
3. Computes remaining residual errors
|
| 121 |
+
4. Trains Stage 3 error predictor on these residuals
|
| 122 |
+
5. Exports the model to ONNX format
|
| 123 |
+
6. Logs metrics and artifacts to MLflow
|
| 124 |
+
"""
|
| 125 |
+
|
| 126 |
+
LOGGER.info("=" * 80)
|
| 127 |
+
LOGGER.info("Second Error Predictor: 3-Stage Model Training")
|
| 128 |
+
LOGGER.info("Multi-Head + Multi-Error Predictor + Multi-Error Predictor")
|
| 129 |
+
LOGGER.info("=" * 80)
|
| 130 |
+
|
| 131 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 132 |
+
if torch.backends.mps.is_available():
|
| 133 |
+
device = torch.device("mps")
|
| 134 |
+
LOGGER.info("Using device: %s", device)
|
| 135 |
+
|
| 136 |
+
model_directory = PROJECT_ROOT / "models" / "from_xyY"
|
| 137 |
+
data_dir = PROJECT_ROOT / "data"
|
| 138 |
+
|
| 139 |
+
if base_model is None:
|
| 140 |
+
base_model = model_directory / "multi_head_large.onnx"
|
| 141 |
+
if first_error_predictor is None:
|
| 142 |
+
first_error_predictor = model_directory / "multi_head_multi_error_predictor_large.onnx"
|
| 143 |
+
if params is None:
|
| 144 |
+
params = model_directory / "multi_head_large_normalization_params.npz"
|
| 145 |
+
|
| 146 |
+
cache_file = data_dir / "training_data_large.npz"
|
| 147 |
+
|
| 148 |
+
if not cache_file.exists():
|
| 149 |
+
LOGGER.error("Error: Large training data not found at %s", cache_file)
|
| 150 |
+
return
|
| 151 |
+
|
| 152 |
+
if not base_model.exists():
|
| 153 |
+
LOGGER.error("Error: Base model not found at %s", base_model)
|
| 154 |
+
return
|
| 155 |
+
|
| 156 |
+
if not first_error_predictor.exists():
|
| 157 |
+
LOGGER.error("Error: First error predictor not found at %s", first_error_predictor)
|
| 158 |
+
return
|
| 159 |
+
|
| 160 |
+
# Load models
|
| 161 |
+
LOGGER.info("")
|
| 162 |
+
LOGGER.info("Loading Stage 1: Multi-Head base model from %s...", base_model)
|
| 163 |
+
base_session = ort.InferenceSession(str(base_model))
|
| 164 |
+
|
| 165 |
+
LOGGER.info("Loading Stage 2: First error predictor from %s...", first_error_predictor)
|
| 166 |
+
error_predictor_session = ort.InferenceSession(str(first_error_predictor))
|
| 167 |
+
|
| 168 |
+
# Load normalization params
|
| 169 |
+
params_data = np.load(params, allow_pickle=True)
|
| 170 |
+
input_params = params_data["input_params"].item()
|
| 171 |
+
output_params = params_data["output_params"].item()
|
| 172 |
+
|
| 173 |
+
# Load training data
|
| 174 |
+
LOGGER.info("Loading large training data from %s...", cache_file)
|
| 175 |
+
data = np.load(cache_file)
|
| 176 |
+
X_train = data["X_train"]
|
| 177 |
+
y_train = data["y_train"]
|
| 178 |
+
X_val = data["X_val"]
|
| 179 |
+
y_val = data["y_val"]
|
| 180 |
+
|
| 181 |
+
LOGGER.info("Train samples: %d", len(X_train))
|
| 182 |
+
LOGGER.info("Validation samples: %d", len(X_val))
|
| 183 |
+
|
| 184 |
+
# Generate stage 2 predictions (base + first error correction)
|
| 185 |
+
LOGGER.info("")
|
| 186 |
+
LOGGER.info("Computing Stage 2 predictions (base + first error correction)...")
|
| 187 |
+
|
| 188 |
+
X_train_norm = normalize_xyY(X_train, input_params)
|
| 189 |
+
y_train_norm = normalize_munsell(y_train, output_params)
|
| 190 |
+
X_val_norm = normalize_xyY(X_val, input_params)
|
| 191 |
+
y_val_norm = normalize_munsell(y_val, output_params)
|
| 192 |
+
|
| 193 |
+
inference_batch_size = 50000
|
| 194 |
+
|
| 195 |
+
# Stage 1: Base model predictions
|
| 196 |
+
LOGGER.info(" Stage 1: Base model predictions (training set)...")
|
| 197 |
+
base_pred_train = []
|
| 198 |
+
for i in range(0, len(X_train_norm), inference_batch_size):
|
| 199 |
+
batch = X_train_norm[i : i + inference_batch_size]
|
| 200 |
+
pred = base_session.run(None, {"xyY": batch})[0]
|
| 201 |
+
base_pred_train.append(pred)
|
| 202 |
+
base_pred_train = np.concatenate(base_pred_train, axis=0)
|
| 203 |
+
|
| 204 |
+
LOGGER.info(" Stage 1: Base model predictions (validation set)...")
|
| 205 |
+
base_pred_val = []
|
| 206 |
+
for i in range(0, len(X_val_norm), inference_batch_size):
|
| 207 |
+
batch = X_val_norm[i : i + inference_batch_size]
|
| 208 |
+
pred = base_session.run(None, {"xyY": batch})[0]
|
| 209 |
+
base_pred_val.append(pred)
|
| 210 |
+
base_pred_val = np.concatenate(base_pred_val, axis=0)
|
| 211 |
+
|
| 212 |
+
# Stage 2: First error predictor corrections
|
| 213 |
+
LOGGER.info(" Stage 2: First error predictor corrections (training set)...")
|
| 214 |
+
combined_train = np.concatenate([X_train_norm, base_pred_train], axis=1).astype(np.float32)
|
| 215 |
+
error_correction_train = []
|
| 216 |
+
for i in range(0, len(combined_train), inference_batch_size):
|
| 217 |
+
batch = combined_train[i : i + inference_batch_size]
|
| 218 |
+
correction = error_predictor_session.run(None, {"combined_input": batch})[0]
|
| 219 |
+
error_correction_train.append(correction)
|
| 220 |
+
error_correction_train = np.concatenate(error_correction_train, axis=0)
|
| 221 |
+
|
| 222 |
+
LOGGER.info(" Stage 2: First error predictor corrections (validation set)...")
|
| 223 |
+
combined_val = np.concatenate([X_val_norm, base_pred_val], axis=1).astype(np.float32)
|
| 224 |
+
error_correction_val = []
|
| 225 |
+
for i in range(0, len(combined_val), inference_batch_size):
|
| 226 |
+
batch = combined_val[i : i + inference_batch_size]
|
| 227 |
+
correction = error_predictor_session.run(None, {"combined_input": batch})[0]
|
| 228 |
+
error_correction_val.append(correction)
|
| 229 |
+
error_correction_val = np.concatenate(error_correction_val, axis=0)
|
| 230 |
+
|
| 231 |
+
# Stage 2 predictions (base + first error correction)
|
| 232 |
+
stage2_pred_train = base_pred_train + error_correction_train
|
| 233 |
+
stage2_pred_val = base_pred_val + error_correction_val
|
| 234 |
+
|
| 235 |
+
# Compute remaining errors for stage 3
|
| 236 |
+
error_train = y_train_norm - stage2_pred_train
|
| 237 |
+
error_val = y_val_norm - stage2_pred_val
|
| 238 |
+
|
| 239 |
+
# Statistics
|
| 240 |
+
LOGGER.info("")
|
| 241 |
+
LOGGER.info("Stage 2 prediction error statistics (normalized space):")
|
| 242 |
+
LOGGER.info(" Mean absolute error: %.6f", np.mean(np.abs(error_train)))
|
| 243 |
+
LOGGER.info(" Std of error: %.6f", np.std(error_train))
|
| 244 |
+
LOGGER.info(" Max absolute error: %.6f", np.max(np.abs(error_train)))
|
| 245 |
+
|
| 246 |
+
# Compare with stage 1 errors
|
| 247 |
+
stage1_error_train = y_train_norm - base_pred_train
|
| 248 |
+
LOGGER.info("")
|
| 249 |
+
LOGGER.info("Stage 1 (base only) error statistics for comparison:")
|
| 250 |
+
LOGGER.info(" Mean absolute error: %.6f", np.mean(np.abs(stage1_error_train)))
|
| 251 |
+
LOGGER.info(" Std of error: %.6f", np.std(stage1_error_train))
|
| 252 |
+
|
| 253 |
+
error_reduction = (
|
| 254 |
+
(np.mean(np.abs(stage1_error_train)) - np.mean(np.abs(error_train)))
|
| 255 |
+
/ np.mean(np.abs(stage1_error_train))
|
| 256 |
+
* 100
|
| 257 |
+
)
|
| 258 |
+
LOGGER.info("")
|
| 259 |
+
LOGGER.info("Stage 2 error reduction vs Stage 1: %.1f%%", error_reduction)
|
| 260 |
+
|
| 261 |
+
# Create combined input for stage 3: [xyY_norm, stage2_pred_norm]
|
| 262 |
+
X_train_combined = np.concatenate([X_train_norm, stage2_pred_train], axis=1)
|
| 263 |
+
X_val_combined = np.concatenate([X_val_norm, stage2_pred_val], axis=1)
|
| 264 |
+
|
| 265 |
+
# Convert to PyTorch tensors
|
| 266 |
+
X_train_t = torch.FloatTensor(X_train_combined)
|
| 267 |
+
error_train_t = torch.FloatTensor(error_train)
|
| 268 |
+
X_val_t = torch.FloatTensor(X_val_combined)
|
| 269 |
+
error_val_t = torch.FloatTensor(error_val)
|
| 270 |
+
|
| 271 |
+
train_dataset = TensorDataset(X_train_t, error_train_t)
|
| 272 |
+
val_dataset = TensorDataset(X_val_t, error_val_t)
|
| 273 |
+
|
| 274 |
+
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
|
| 275 |
+
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
|
| 276 |
+
|
| 277 |
+
# Initialize second error predictor (same architecture as first)
|
| 278 |
+
model = MultiHeadErrorPredictorToMunsell().to(device)
|
| 279 |
+
LOGGER.info("")
|
| 280 |
+
LOGGER.info("Stage 3: Second error predictor architecture:")
|
| 281 |
+
LOGGER.info("%s", model)
|
| 282 |
+
|
| 283 |
+
total_params = sum(p.numel() for p in model.parameters())
|
| 284 |
+
LOGGER.info("Total parameters: %s", f"{total_params:,}")
|
| 285 |
+
|
| 286 |
+
# Training setup
|
| 287 |
+
optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
|
| 288 |
+
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
|
| 289 |
+
optimizer, mode="min", factor=0.5, patience=10
|
| 290 |
+
)
|
| 291 |
+
criterion = precision_focused_loss
|
| 292 |
+
|
| 293 |
+
run_name = setup_mlflow_experiment("from_xyY", "multi_head_3stage_error_predictor")
|
| 294 |
+
|
| 295 |
+
LOGGER.info("")
|
| 296 |
+
LOGGER.info("MLflow run: %s", run_name)
|
| 297 |
+
|
| 298 |
+
# Training loop
|
| 299 |
+
best_val_loss = float("inf")
|
| 300 |
+
patience_counter = 0
|
| 301 |
+
|
| 302 |
+
LOGGER.info("")
|
| 303 |
+
LOGGER.info("Starting Stage 3 training...")
|
| 304 |
+
|
| 305 |
+
with mlflow.start_run(run_name=run_name):
|
| 306 |
+
mlflow.log_params(
|
| 307 |
+
{
|
| 308 |
+
"model": "multi_head_3stage_error_predictor",
|
| 309 |
+
"num_epochs": epochs,
|
| 310 |
+
"batch_size": batch_size,
|
| 311 |
+
"learning_rate": lr,
|
| 312 |
+
"weight_decay": 1e-5,
|
| 313 |
+
"optimizer": "AdamW",
|
| 314 |
+
"scheduler": "ReduceLROnPlateau",
|
| 315 |
+
"criterion": "precision_focused_loss",
|
| 316 |
+
"patience": patience,
|
| 317 |
+
"total_params": total_params,
|
| 318 |
+
"train_samples": len(X_train),
|
| 319 |
+
"val_samples": len(X_val),
|
| 320 |
+
"stage2_error_reduction_pct": error_reduction,
|
| 321 |
+
}
|
| 322 |
+
)
|
| 323 |
+
|
| 324 |
+
for epoch in range(epochs):
|
| 325 |
+
train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
|
| 326 |
+
val_loss = validate(model, val_loader, criterion, device)
|
| 327 |
+
|
| 328 |
+
scheduler.step(val_loss)
|
| 329 |
+
|
| 330 |
+
log_training_epoch(
|
| 331 |
+
epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"]
|
| 332 |
+
)
|
| 333 |
+
|
| 334 |
+
LOGGER.info(
|
| 335 |
+
"Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f, LR: %.6f",
|
| 336 |
+
epoch + 1,
|
| 337 |
+
epochs,
|
| 338 |
+
train_loss,
|
| 339 |
+
val_loss,
|
| 340 |
+
optimizer.param_groups[0]["lr"],
|
| 341 |
+
)
|
| 342 |
+
|
| 343 |
+
if val_loss < best_val_loss:
|
| 344 |
+
best_val_loss = val_loss
|
| 345 |
+
patience_counter = 0
|
| 346 |
+
|
| 347 |
+
model_directory.mkdir(exist_ok=True)
|
| 348 |
+
checkpoint_file = model_directory / "multi_head_3stage_error_predictor_best.pth"
|
| 349 |
+
|
| 350 |
+
torch.save(
|
| 351 |
+
{
|
| 352 |
+
"model_state_dict": model.state_dict(),
|
| 353 |
+
"epoch": epoch,
|
| 354 |
+
"val_loss": val_loss,
|
| 355 |
+
},
|
| 356 |
+
checkpoint_file,
|
| 357 |
+
)
|
| 358 |
+
|
| 359 |
+
LOGGER.info(" → Saved best model (val_loss: %.6f)", val_loss)
|
| 360 |
+
else:
|
| 361 |
+
patience_counter += 1
|
| 362 |
+
if patience_counter >= patience:
|
| 363 |
+
LOGGER.info("")
|
| 364 |
+
LOGGER.info("Early stopping after %d epochs", epoch + 1)
|
| 365 |
+
break
|
| 366 |
+
|
| 367 |
+
mlflow.log_metrics(
|
| 368 |
+
{
|
| 369 |
+
"best_val_loss": best_val_loss,
|
| 370 |
+
"final_epoch": epoch + 1,
|
| 371 |
+
}
|
| 372 |
+
)
|
| 373 |
+
|
| 374 |
+
# Export to ONNX
|
| 375 |
+
LOGGER.info("")
|
| 376 |
+
LOGGER.info("Exporting Stage 3 error predictor to ONNX...")
|
| 377 |
+
model.eval()
|
| 378 |
+
|
| 379 |
+
checkpoint = torch.load(checkpoint_file, weights_only=False)
|
| 380 |
+
model.load_state_dict(checkpoint["model_state_dict"])
|
| 381 |
+
|
| 382 |
+
dummy_input = torch.randn(1, 7).to(device)
|
| 383 |
+
|
| 384 |
+
onnx_file = model_directory / "multi_head_3stage_error_predictor.onnx"
|
| 385 |
+
torch.onnx.export(
|
| 386 |
+
model,
|
| 387 |
+
dummy_input,
|
| 388 |
+
onnx_file,
|
| 389 |
+
export_params=True,
|
| 390 |
+
opset_version=15,
|
| 391 |
+
input_names=["combined_input"],
|
| 392 |
+
output_names=["error_correction"],
|
| 393 |
+
dynamic_axes={
|
| 394 |
+
"combined_input": {0: "batch_size"},
|
| 395 |
+
"error_correction": {0: "batch_size"},
|
| 396 |
+
},
|
| 397 |
+
)
|
| 398 |
+
|
| 399 |
+
LOGGER.info("Stage 3 error predictor ONNX model saved to: %s", onnx_file)
|
| 400 |
+
|
| 401 |
+
mlflow.log_artifact(str(checkpoint_file))
|
| 402 |
+
mlflow.log_artifact(str(onnx_file))
|
| 403 |
+
mlflow.pytorch.log_model(model, "model")
|
| 404 |
+
|
| 405 |
+
LOGGER.info("=" * 80)
|
| 406 |
+
|
| 407 |
+
|
| 408 |
+
if __name__ == "__main__":
|
| 409 |
+
logging.basicConfig(level=logging.INFO, format="%(message)s", force=True)
|
| 410 |
+
|
| 411 |
+
main()
|
learning_munsell/training/from_xyY/train_multi_head_circular.py
ADDED
|
@@ -0,0 +1,479 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Train Multi-Head model with circular hue loss for xyY to Munsell conversion.
|
| 3 |
+
|
| 4 |
+
This version uses circular loss for the hue component (which wraps from 0-10)
|
| 5 |
+
to avoid penalizing predictions near the boundary.
|
| 6 |
+
|
| 7 |
+
Key Difference from Standard Training:
|
| 8 |
+
- Uses munsell_component_loss() which applies circular MSE for hue
|
| 9 |
+
- and regular MSE for value/chroma/code components
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
import copy
|
| 15 |
+
import logging
|
| 16 |
+
|
| 17 |
+
import click
|
| 18 |
+
import mlflow
|
| 19 |
+
import mlflow.pytorch
|
| 20 |
+
import numpy as np
|
| 21 |
+
import torch
|
| 22 |
+
from torch import nn, optim
|
| 23 |
+
from torch.utils.data import DataLoader, TensorDataset
|
| 24 |
+
|
| 25 |
+
from learning_munsell import PROJECT_ROOT
|
| 26 |
+
from learning_munsell.utilities.common import setup_mlflow_experiment
|
| 27 |
+
from learning_munsell.utilities.data import (
|
| 28 |
+
MUNSELL_NORMALIZATION_PARAMS,
|
| 29 |
+
normalize_munsell,
|
| 30 |
+
)
|
| 31 |
+
from learning_munsell.training.from_xyY.hyperparameter_search_multi_head import (
|
| 32 |
+
MultiHeadParametric,
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
LOGGER = logging.getLogger(__name__)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def circular_mse_loss(
|
| 39 |
+
pred_hue: torch.Tensor, target_hue: torch.Tensor, hue_range: float = 1.0
|
| 40 |
+
) -> torch.Tensor:
|
| 41 |
+
"""
|
| 42 |
+
Circular MSE loss for hue component (normalized 0-1).
|
| 43 |
+
|
| 44 |
+
Parameters
|
| 45 |
+
----------
|
| 46 |
+
pred_hue : Tensor
|
| 47 |
+
Predicted hue values (normalized 0-1)
|
| 48 |
+
target_hue : Tensor
|
| 49 |
+
Target hue values (normalized 0-1)
|
| 50 |
+
hue_range : float
|
| 51 |
+
Range of hue values (1.0 for normalized)
|
| 52 |
+
|
| 53 |
+
Returns
|
| 54 |
+
-------
|
| 55 |
+
Tensor
|
| 56 |
+
Circular MSE loss
|
| 57 |
+
"""
|
| 58 |
+
diff = torch.abs(pred_hue - target_hue)
|
| 59 |
+
circular_diff = torch.min(diff, hue_range - diff)
|
| 60 |
+
return torch.mean(circular_diff**2)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def munsell_component_loss(
|
| 64 |
+
pred: torch.Tensor, target: torch.Tensor, hue_range: float = 1.0
|
| 65 |
+
) -> torch.Tensor:
|
| 66 |
+
"""
|
| 67 |
+
Component-wise loss for Munsell predictions.
|
| 68 |
+
|
| 69 |
+
Uses circular MSE for hue (component 0) and regular MSE
|
| 70 |
+
for value, chroma, code (components 1-3).
|
| 71 |
+
|
| 72 |
+
Parameters
|
| 73 |
+
----------
|
| 74 |
+
pred : Tensor
|
| 75 |
+
Predictions [hue, value, chroma, code] (shape: [batch, 4])
|
| 76 |
+
target : Tensor
|
| 77 |
+
Ground truth [hue, value, chroma, code] (shape: [batch, 4])
|
| 78 |
+
hue_range : float
|
| 79 |
+
Range of normalized hue values (default 1.0)
|
| 80 |
+
|
| 81 |
+
Returns
|
| 82 |
+
-------
|
| 83 |
+
Tensor
|
| 84 |
+
Combined loss
|
| 85 |
+
"""
|
| 86 |
+
hue_loss = circular_mse_loss(pred[:, 0], target[:, 0], hue_range)
|
| 87 |
+
other_loss = nn.functional.mse_loss(pred[:, 1:], target[:, 1:])
|
| 88 |
+
return hue_loss + other_loss
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
@click.command()
|
| 92 |
+
@click.option("--epochs", default=300, help="Number of training epochs")
|
| 93 |
+
@click.option("--batch-size", default=512, help="Batch size for training")
|
| 94 |
+
@click.option("--lr", default=0.000837, help="Learning rate")
|
| 95 |
+
@click.option("--patience", default=30, help="Early stopping patience")
|
| 96 |
+
def main(
|
| 97 |
+
epochs: int,
|
| 98 |
+
batch_size: int,
|
| 99 |
+
lr: float,
|
| 100 |
+
patience: int,
|
| 101 |
+
encoder_width: float = 0.75,
|
| 102 |
+
head_width: float = 1.5,
|
| 103 |
+
chroma_head_width: float = 1.5,
|
| 104 |
+
dropout: float = 0.0,
|
| 105 |
+
weight_decay: float = 0.000013,
|
| 106 |
+
) -> tuple[MultiHeadParametric, float]:
|
| 107 |
+
"""
|
| 108 |
+
Train Multi-Head model with circular hue loss.
|
| 109 |
+
|
| 110 |
+
This script uses circular loss for the hue component (which wraps from
|
| 111 |
+
0-10) to avoid penalizing predictions near the boundary.
|
| 112 |
+
|
| 113 |
+
Parameters
|
| 114 |
+
----------
|
| 115 |
+
epochs : int, optional
|
| 116 |
+
Maximum number of training epochs.
|
| 117 |
+
batch_size : int, optional
|
| 118 |
+
Training batch size.
|
| 119 |
+
lr : float, optional
|
| 120 |
+
Learning rate for AdamW optimizer.
|
| 121 |
+
encoder_width : float, optional
|
| 122 |
+
Width multiplier for the shared encoder.
|
| 123 |
+
head_width : float, optional
|
| 124 |
+
Width multiplier for hue, value, and code heads.
|
| 125 |
+
chroma_head_width : float, optional
|
| 126 |
+
Width multiplier for chroma head (typically larger).
|
| 127 |
+
dropout : float, optional
|
| 128 |
+
Dropout rate for regularization.
|
| 129 |
+
weight_decay : float, optional
|
| 130 |
+
Weight decay for AdamW optimizer.
|
| 131 |
+
|
| 132 |
+
Returns
|
| 133 |
+
-------
|
| 134 |
+
model : MultiHeadParametric
|
| 135 |
+
Trained model with best validation loss weights.
|
| 136 |
+
best_val_loss : float
|
| 137 |
+
Best validation loss achieved during training.
|
| 138 |
+
|
| 139 |
+
Notes
|
| 140 |
+
-----
|
| 141 |
+
The training pipeline:
|
| 142 |
+
1. Loads training data from cache
|
| 143 |
+
2. Normalizes outputs to [0, 1] range
|
| 144 |
+
3. Trains with circular MSE for hue and regular MSE for other components
|
| 145 |
+
4. Uses CosineAnnealingLR scheduler
|
| 146 |
+
5. Early stopping based on validation loss
|
| 147 |
+
6. Exports model to ONNX format
|
| 148 |
+
7. Logs metrics and artifacts to MLflow
|
| 149 |
+
|
| 150 |
+
The circular loss experiment showed that while mathematically correct,
|
| 151 |
+
the circular distance creates gradient discontinuities that harm
|
| 152 |
+
optimization. This model is included for comparison purposes.
|
| 153 |
+
"""
|
| 154 |
+
|
| 155 |
+
LOGGER.info("=" * 80)
|
| 156 |
+
LOGGER.info("Training Multi-Head (Circular Hue Loss) for xyY to Munsell conversion")
|
| 157 |
+
LOGGER.info("=" * 80)
|
| 158 |
+
LOGGER.info("")
|
| 159 |
+
LOGGER.info("Using Circular Loss for Hue Component")
|
| 160 |
+
LOGGER.info("=" * 80)
|
| 161 |
+
LOGGER.info("")
|
| 162 |
+
LOGGER.info("Hyperparameters:")
|
| 163 |
+
LOGGER.info(" lr: %.6f", lr)
|
| 164 |
+
LOGGER.info(" batch_size: %d", batch_size)
|
| 165 |
+
LOGGER.info(" encoder_width: %.2f", encoder_width)
|
| 166 |
+
LOGGER.info(" head_width: %.2f", head_width)
|
| 167 |
+
LOGGER.info(" chroma_head_width: %.2f", chroma_head_width)
|
| 168 |
+
LOGGER.info(" dropout: %.2f", dropout)
|
| 169 |
+
LOGGER.info(" weight_decay: %.6f", weight_decay)
|
| 170 |
+
LOGGER.info("")
|
| 171 |
+
|
| 172 |
+
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
|
| 173 |
+
LOGGER.info("Using device: %s", device)
|
| 174 |
+
|
| 175 |
+
# Load data from cache
|
| 176 |
+
data_dir = PROJECT_ROOT / "data"
|
| 177 |
+
cache_file = data_dir / "training_data.npz"
|
| 178 |
+
data = np.load(cache_file)
|
| 179 |
+
|
| 180 |
+
X_train = data["X_train"]
|
| 181 |
+
y_train = data["y_train"]
|
| 182 |
+
X_val = data["X_val"]
|
| 183 |
+
y_val = data["y_val"]
|
| 184 |
+
|
| 185 |
+
LOGGER.info("Training samples: %d", len(X_train))
|
| 186 |
+
LOGGER.info("Validation samples: %d", len(X_val))
|
| 187 |
+
|
| 188 |
+
# Normalize outputs (xyY inputs already in [0, 1] range)
|
| 189 |
+
# Use shared normalization parameters covering the full Munsell space for generalization
|
| 190 |
+
output_params = MUNSELL_NORMALIZATION_PARAMS
|
| 191 |
+
y_train_norm = normalize_munsell(y_train, output_params)
|
| 192 |
+
y_val_norm = normalize_munsell(y_val, output_params)
|
| 193 |
+
|
| 194 |
+
# Convert to tensors
|
| 195 |
+
X_train_t = torch.from_numpy(X_train).float()
|
| 196 |
+
y_train_t = torch.from_numpy(y_train_norm).float()
|
| 197 |
+
X_val_t = torch.from_numpy(X_val).float()
|
| 198 |
+
y_val_t = torch.from_numpy(y_val_norm).float()
|
| 199 |
+
|
| 200 |
+
train_loader = DataLoader(
|
| 201 |
+
TensorDataset(X_train_t, y_train_t), batch_size=batch_size, shuffle=True
|
| 202 |
+
)
|
| 203 |
+
val_loader = DataLoader(
|
| 204 |
+
TensorDataset(X_val_t, y_val_t), batch_size=batch_size, shuffle=False
|
| 205 |
+
)
|
| 206 |
+
|
| 207 |
+
# Create model
|
| 208 |
+
model = MultiHeadParametric(
|
| 209 |
+
encoder_width=encoder_width,
|
| 210 |
+
head_width=head_width,
|
| 211 |
+
chroma_head_width=chroma_head_width,
|
| 212 |
+
dropout=dropout,
|
| 213 |
+
).to(device)
|
| 214 |
+
|
| 215 |
+
optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
|
| 216 |
+
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
|
| 217 |
+
|
| 218 |
+
total_params = sum(p.numel() for p in model.parameters())
|
| 219 |
+
LOGGER.info("")
|
| 220 |
+
LOGGER.info("Model parameters: %s", f"{total_params:,}")
|
| 221 |
+
|
| 222 |
+
encoder_params = sum(p.numel() for p in model.encoder.parameters())
|
| 223 |
+
hue_params = sum(p.numel() for p in model.hue_head.parameters())
|
| 224 |
+
value_params = sum(p.numel() for p in model.value_head.parameters())
|
| 225 |
+
chroma_params = sum(p.numel() for p in model.chroma_head.parameters())
|
| 226 |
+
code_params = sum(p.numel() for p in model.code_head.parameters())
|
| 227 |
+
|
| 228 |
+
LOGGER.info(" - Shared encoder (%.2fx): %s", encoder_width, f"{encoder_params:,}")
|
| 229 |
+
LOGGER.info(" - Hue head (%.2fx): %s", head_width, f"{hue_params:,}")
|
| 230 |
+
LOGGER.info(" - Value head (%.2fx): %s", head_width, f"{value_params:,}")
|
| 231 |
+
LOGGER.info(" - Chroma head (%.2fx): %s", chroma_head_width, f"{chroma_params:,}")
|
| 232 |
+
LOGGER.info(" - Code head (%.2fx): %s", head_width, f"{code_params:,}")
|
| 233 |
+
|
| 234 |
+
# MLflow setup
|
| 235 |
+
run_name = setup_mlflow_experiment("from_xyY", "multi_head_circular")
|
| 236 |
+
LOGGER.info("")
|
| 237 |
+
LOGGER.info("MLflow run: %s", run_name)
|
| 238 |
+
|
| 239 |
+
best_val_loss = float("inf")
|
| 240 |
+
best_state = None
|
| 241 |
+
patience_counter = 0
|
| 242 |
+
|
| 243 |
+
LOGGER.info("")
|
| 244 |
+
LOGGER.info("Starting training with circular hue loss...")
|
| 245 |
+
|
| 246 |
+
with mlflow.start_run(run_name=run_name):
|
| 247 |
+
mlflow.log_params(
|
| 248 |
+
{
|
| 249 |
+
"model": "multi_head_circular",
|
| 250 |
+
"encoder_width": encoder_width,
|
| 251 |
+
"head_width": head_width,
|
| 252 |
+
"chroma_head_width": chroma_head_width,
|
| 253 |
+
"dropout": dropout,
|
| 254 |
+
"learning_rate": lr,
|
| 255 |
+
"batch_size": batch_size,
|
| 256 |
+
"weight_decay": weight_decay,
|
| 257 |
+
"epochs": epochs,
|
| 258 |
+
"patience": patience,
|
| 259 |
+
"total_params": total_params,
|
| 260 |
+
"loss_type": "circular_hue",
|
| 261 |
+
}
|
| 262 |
+
)
|
| 263 |
+
|
| 264 |
+
for epoch in range(epochs):
|
| 265 |
+
# Training
|
| 266 |
+
model.train()
|
| 267 |
+
train_loss = 0.0
|
| 268 |
+
for X_batch, y_batch in train_loader:
|
| 269 |
+
X_batch, y_batch = X_batch.to(device), y_batch.to(device)
|
| 270 |
+
|
| 271 |
+
optimizer.zero_grad()
|
| 272 |
+
pred = model(X_batch)
|
| 273 |
+
|
| 274 |
+
# Use circular loss for hue component
|
| 275 |
+
loss = munsell_component_loss(pred, y_batch, hue_range=1.0)
|
| 276 |
+
|
| 277 |
+
loss.backward()
|
| 278 |
+
optimizer.step()
|
| 279 |
+
train_loss += loss.item() * len(X_batch)
|
| 280 |
+
|
| 281 |
+
train_loss /= len(X_train_t)
|
| 282 |
+
scheduler.step()
|
| 283 |
+
|
| 284 |
+
# Validation
|
| 285 |
+
model.eval()
|
| 286 |
+
val_loss = 0.0
|
| 287 |
+
with torch.no_grad():
|
| 288 |
+
for X_batch, y_batch in val_loader:
|
| 289 |
+
X_batch, y_batch = X_batch.to(device), y_batch.to(device)
|
| 290 |
+
pred = model(X_batch)
|
| 291 |
+
val_loss += munsell_component_loss(
|
| 292 |
+
pred, y_batch, hue_range=1.0
|
| 293 |
+
).item() * len(X_batch)
|
| 294 |
+
val_loss /= len(X_val_t)
|
| 295 |
+
|
| 296 |
+
# Per-component MAE (denormalized for interpretability)
|
| 297 |
+
with torch.no_grad():
|
| 298 |
+
pred_val = model(X_val_t.to(device)).cpu()
|
| 299 |
+
# Denormalize predictions and ground truth
|
| 300 |
+
pred_denorm = pred_val.numpy()
|
| 301 |
+
hue_min, hue_max = output_params["hue_range"]
|
| 302 |
+
value_min, value_max = output_params["value_range"]
|
| 303 |
+
chroma_min, chroma_max = output_params["chroma_range"]
|
| 304 |
+
code_min, code_max = output_params["code_range"]
|
| 305 |
+
|
| 306 |
+
pred_denorm[:, 0] = pred_val[:, 0].numpy() * (hue_max - hue_min) + hue_min # hue
|
| 307 |
+
pred_denorm[:, 1] = pred_val[:, 1].numpy() * (value_max - value_min) + value_min # value
|
| 308 |
+
pred_denorm[:, 2] = pred_val[:, 2].numpy() * (chroma_max - chroma_min) + chroma_min # chroma
|
| 309 |
+
pred_denorm[:, 3] = pred_val[:, 3].numpy() * (code_max - code_min) + code_min # code
|
| 310 |
+
|
| 311 |
+
y_denorm = y_val_norm.copy()
|
| 312 |
+
y_denorm[:, 0] = y_val_norm[:, 0] * (hue_max - hue_min) + hue_min
|
| 313 |
+
y_denorm[:, 1] = y_val_norm[:, 1] * (value_max - value_min) + value_min
|
| 314 |
+
y_denorm[:, 2] = y_val_norm[:, 2] * (chroma_max - chroma_min) + chroma_min
|
| 315 |
+
y_denorm[:, 3] = y_val_norm[:, 3] * (code_max - code_min) + code_min
|
| 316 |
+
|
| 317 |
+
mae = np.mean(np.abs(pred_denorm - y_denorm), axis=0)
|
| 318 |
+
|
| 319 |
+
mlflow.log_metrics(
|
| 320 |
+
{
|
| 321 |
+
"train_loss": train_loss,
|
| 322 |
+
"val_loss": val_loss,
|
| 323 |
+
"mae_hue": mae[0],
|
| 324 |
+
"mae_value": mae[1],
|
| 325 |
+
"mae_chroma": mae[2],
|
| 326 |
+
"mae_code": mae[3],
|
| 327 |
+
},
|
| 328 |
+
step=epoch,
|
| 329 |
+
)
|
| 330 |
+
|
| 331 |
+
if val_loss < best_val_loss:
|
| 332 |
+
best_val_loss = val_loss
|
| 333 |
+
best_state = copy.deepcopy(model.state_dict())
|
| 334 |
+
patience_counter = 0
|
| 335 |
+
LOGGER.info(
|
| 336 |
+
"Epoch %03d/%d - Train: %.6f, Val: %.6f (best) - "
|
| 337 |
+
"MAE: hue=%.4f, value=%.4f, chroma=%.4f, code=%.4f",
|
| 338 |
+
epoch + 1,
|
| 339 |
+
epochs,
|
| 340 |
+
train_loss,
|
| 341 |
+
val_loss,
|
| 342 |
+
mae[0],
|
| 343 |
+
mae[1],
|
| 344 |
+
mae[2],
|
| 345 |
+
mae[3],
|
| 346 |
+
)
|
| 347 |
+
else:
|
| 348 |
+
patience_counter += 1
|
| 349 |
+
if (epoch + 1) % 50 == 0:
|
| 350 |
+
LOGGER.info(
|
| 351 |
+
"Epoch %03d/%d - Train: %.6f, Val: %.6f",
|
| 352 |
+
epoch + 1,
|
| 353 |
+
epochs,
|
| 354 |
+
train_loss,
|
| 355 |
+
val_loss,
|
| 356 |
+
)
|
| 357 |
+
|
| 358 |
+
if patience_counter >= patience:
|
| 359 |
+
LOGGER.info("Early stopping at epoch %d", epoch + 1)
|
| 360 |
+
break
|
| 361 |
+
|
| 362 |
+
# Load best model
|
| 363 |
+
model.load_state_dict(best_state)
|
| 364 |
+
|
| 365 |
+
# Final evaluation
|
| 366 |
+
model.eval()
|
| 367 |
+
with torch.no_grad():
|
| 368 |
+
pred_val = model(X_val_t.to(device)).cpu()
|
| 369 |
+
pred_denorm = pred_val.numpy()
|
| 370 |
+
hue_min, hue_max = output_params["hue_range"]
|
| 371 |
+
value_min, value_max = output_params["value_range"]
|
| 372 |
+
chroma_min, chroma_max = output_params["chroma_range"]
|
| 373 |
+
code_min, code_max = output_params["code_range"]
|
| 374 |
+
|
| 375 |
+
pred_denorm[:, 0] = pred_val[:, 0].numpy() * (hue_max - hue_min) + hue_min
|
| 376 |
+
pred_denorm[:, 1] = pred_val[:, 1].numpy() * (value_max - value_min) + value_min
|
| 377 |
+
pred_denorm[:, 2] = pred_val[:, 2].numpy() * (chroma_max - chroma_min) + chroma_min
|
| 378 |
+
pred_denorm[:, 3] = pred_val[:, 3].numpy() * (code_max - code_min) + code_min
|
| 379 |
+
|
| 380 |
+
y_denorm = y_val_norm.copy()
|
| 381 |
+
y_denorm[:, 0] = y_val_norm[:, 0] * (hue_max - hue_min) + hue_min
|
| 382 |
+
y_denorm[:, 1] = y_val_norm[:, 1] * (value_max - value_min) + value_min
|
| 383 |
+
y_denorm[:, 2] = y_val_norm[:, 2] * (chroma_max - chroma_min) + chroma_min
|
| 384 |
+
y_denorm[:, 3] = y_val_norm[:, 3] * (code_max - code_min) + code_min
|
| 385 |
+
|
| 386 |
+
mae = np.mean(np.abs(pred_denorm - y_denorm), axis=0)
|
| 387 |
+
|
| 388 |
+
# Log final metrics
|
| 389 |
+
mlflow.log_metrics(
|
| 390 |
+
{
|
| 391 |
+
"best_val_loss": best_val_loss,
|
| 392 |
+
"final_mae_hue": mae[0],
|
| 393 |
+
"final_mae_value": mae[1],
|
| 394 |
+
"final_mae_chroma": mae[2],
|
| 395 |
+
"final_mae_code": mae[3],
|
| 396 |
+
"final_epoch": epoch + 1,
|
| 397 |
+
}
|
| 398 |
+
)
|
| 399 |
+
|
| 400 |
+
LOGGER.info("")
|
| 401 |
+
LOGGER.info("Final Results:")
|
| 402 |
+
LOGGER.info(" Best Val Loss: %.6f", best_val_loss)
|
| 403 |
+
LOGGER.info(" MAE hue: %.6f", mae[0])
|
| 404 |
+
LOGGER.info(" MAE value: %.6f", mae[1])
|
| 405 |
+
LOGGER.info(" MAE chroma: %.6f", mae[2])
|
| 406 |
+
LOGGER.info(" MAE code: %.6f", mae[3])
|
| 407 |
+
|
| 408 |
+
# Save model
|
| 409 |
+
models_dir = PROJECT_ROOT / "models" / "from_xyY"
|
| 410 |
+
models_dir.mkdir(exist_ok=True)
|
| 411 |
+
|
| 412 |
+
checkpoint_path = models_dir / "multi_head_circular.pth"
|
| 413 |
+
torch.save(
|
| 414 |
+
{
|
| 415 |
+
"model_state_dict": model.state_dict(),
|
| 416 |
+
"output_params": output_params,
|
| 417 |
+
"val_loss": best_val_loss,
|
| 418 |
+
"mae": {
|
| 419 |
+
"hue": float(mae[0]),
|
| 420 |
+
"value": float(mae[1]),
|
| 421 |
+
"chroma": float(mae[2]),
|
| 422 |
+
"code": float(mae[3]),
|
| 423 |
+
},
|
| 424 |
+
"hyperparameters": {
|
| 425 |
+
"encoder_width": encoder_width,
|
| 426 |
+
"head_width": head_width,
|
| 427 |
+
"chroma_head_width": chroma_head_width,
|
| 428 |
+
"dropout": dropout,
|
| 429 |
+
"lr": lr,
|
| 430 |
+
"batch_size": batch_size,
|
| 431 |
+
"weight_decay": weight_decay,
|
| 432 |
+
},
|
| 433 |
+
"loss_type": "circular_hue",
|
| 434 |
+
},
|
| 435 |
+
checkpoint_path,
|
| 436 |
+
)
|
| 437 |
+
LOGGER.info("")
|
| 438 |
+
LOGGER.info("Saved checkpoint: %s", checkpoint_path)
|
| 439 |
+
|
| 440 |
+
# Export to ONNX
|
| 441 |
+
model.cpu().eval()
|
| 442 |
+
dummy_input = torch.randn(1, 3)
|
| 443 |
+
onnx_path = models_dir / "multi_head_circular.onnx"
|
| 444 |
+
|
| 445 |
+
torch.onnx.export(
|
| 446 |
+
model,
|
| 447 |
+
dummy_input,
|
| 448 |
+
onnx_path,
|
| 449 |
+
input_names=["xyY"], # Match other models for comparison compatibility
|
| 450 |
+
output_names=["munsell_spec"],
|
| 451 |
+
dynamic_axes={"xyY": {0: "batch"}, "munsell_spec": {0: "batch"}},
|
| 452 |
+
opset_version=17,
|
| 453 |
+
)
|
| 454 |
+
LOGGER.info("Saved ONNX: %s", onnx_path)
|
| 455 |
+
|
| 456 |
+
# Save normalization parameters
|
| 457 |
+
params_path = models_dir / "multi_head_circular_normalization_params.npz"
|
| 458 |
+
np.savez(
|
| 459 |
+
params_path,
|
| 460 |
+
output_params=output_params,
|
| 461 |
+
)
|
| 462 |
+
LOGGER.info("Saved normalization parameters: %s", params_path)
|
| 463 |
+
|
| 464 |
+
# Log artifacts to MLflow
|
| 465 |
+
mlflow.log_artifact(str(checkpoint_path))
|
| 466 |
+
mlflow.log_artifact(str(onnx_path))
|
| 467 |
+
mlflow.log_artifact(str(params_path))
|
| 468 |
+
mlflow.pytorch.log_model(model, "model")
|
| 469 |
+
LOGGER.info("Artifacts logged to MLflow")
|
| 470 |
+
|
| 471 |
+
LOGGER.info("=" * 80)
|
| 472 |
+
|
| 473 |
+
return model, best_val_loss
|
| 474 |
+
|
| 475 |
+
|
| 476 |
+
if __name__ == "__main__":
|
| 477 |
+
logging.basicConfig(level=logging.INFO, format="%(message)s", force=True)
|
| 478 |
+
|
| 479 |
+
main()
|
learning_munsell/training/from_xyY/train_multi_head_cross_attention_error_predictor.py
ADDED
|
@@ -0,0 +1,640 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Train Multi-Head + Cross-Attention Error Predictor for xyY to Munsell conversion.
|
| 3 |
+
|
| 4 |
+
This version uses cross-attention between component branches to learn
|
| 5 |
+
correlations between errors in different Munsell components.
|
| 6 |
+
|
| 7 |
+
Key Features:
|
| 8 |
+
- Shared context encoder
|
| 9 |
+
- Multi-head cross-attention between components
|
| 10 |
+
- Component-specific prediction heads
|
| 11 |
+
- Residual connections
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
import copy
|
| 17 |
+
import logging
|
| 18 |
+
|
| 19 |
+
import mlflow
|
| 20 |
+
import mlflow.pytorch
|
| 21 |
+
import numpy as np
|
| 22 |
+
import onnxruntime as ort
|
| 23 |
+
import torch
|
| 24 |
+
from torch import nn, optim
|
| 25 |
+
from torch.utils.data import DataLoader, TensorDataset
|
| 26 |
+
|
| 27 |
+
from learning_munsell import PROJECT_ROOT
|
| 28 |
+
from learning_munsell.utilities.common import log_training_epoch, setup_mlflow_experiment
|
| 29 |
+
from learning_munsell.utilities.data import normalize_xyY, normalize_munsell
|
| 30 |
+
|
| 31 |
+
LOGGER = logging.getLogger(__name__)
|
| 32 |
+
|
| 33 |
+
# Note: This script has a custom CrossAttentionErrorPredictor architecture
|
| 34 |
+
# so we don't import ComponentErrorPredictor/MultiHeadErrorPredictor from shared modules.
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class CustomMultiheadAttention(nn.Module):
|
| 38 |
+
"""
|
| 39 |
+
Custom multi-head attention that exports cleanly to ONNX.
|
| 40 |
+
|
| 41 |
+
Uses basic operations instead of nn.MultiheadAttention to avoid
|
| 42 |
+
reshape issues with dynamic batch sizes during ONNX export.
|
| 43 |
+
|
| 44 |
+
Parameters
|
| 45 |
+
----------
|
| 46 |
+
embed_dim : int
|
| 47 |
+
Total dimension of the model (must be divisible by num_heads).
|
| 48 |
+
num_heads : int
|
| 49 |
+
Number of parallel attention heads.
|
| 50 |
+
dropout : float, optional
|
| 51 |
+
Dropout probability on attention weights.
|
| 52 |
+
|
| 53 |
+
Attributes
|
| 54 |
+
----------
|
| 55 |
+
embed_dim : int
|
| 56 |
+
Total embedding dimension.
|
| 57 |
+
num_heads : int
|
| 58 |
+
Number of attention heads.
|
| 59 |
+
head_dim : int
|
| 60 |
+
Dimension of each attention head (embed_dim // num_heads).
|
| 61 |
+
scale : float
|
| 62 |
+
Scaling factor for attention scores (head_dim ** -0.5).
|
| 63 |
+
q_proj : nn.Linear
|
| 64 |
+
Query projection layer.
|
| 65 |
+
k_proj : nn.Linear
|
| 66 |
+
Key projection layer.
|
| 67 |
+
v_proj : nn.Linear
|
| 68 |
+
Value projection layer.
|
| 69 |
+
out_proj : nn.Linear
|
| 70 |
+
Output projection layer.
|
| 71 |
+
dropout : nn.Dropout
|
| 72 |
+
Dropout layer for attention weights.
|
| 73 |
+
"""
|
| 74 |
+
|
| 75 |
+
def __init__(
|
| 76 |
+
self,
|
| 77 |
+
embed_dim: int,
|
| 78 |
+
num_heads: int,
|
| 79 |
+
dropout: float = 0.0,
|
| 80 |
+
) -> None:
|
| 81 |
+
"""Initialize the custom multi-head attention module."""
|
| 82 |
+
super().__init__()
|
| 83 |
+
|
| 84 |
+
assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
|
| 85 |
+
|
| 86 |
+
self.embed_dim = embed_dim
|
| 87 |
+
self.num_heads = num_heads
|
| 88 |
+
self.head_dim = embed_dim // num_heads
|
| 89 |
+
self.scale = self.head_dim**-0.5
|
| 90 |
+
|
| 91 |
+
# Linear projections for Q, K, V
|
| 92 |
+
self.q_proj = nn.Linear(embed_dim, embed_dim)
|
| 93 |
+
self.k_proj = nn.Linear(embed_dim, embed_dim)
|
| 94 |
+
self.v_proj = nn.Linear(embed_dim, embed_dim)
|
| 95 |
+
self.out_proj = nn.Linear(embed_dim, embed_dim)
|
| 96 |
+
|
| 97 |
+
self.dropout = nn.Dropout(dropout)
|
| 98 |
+
|
| 99 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 100 |
+
"""
|
| 101 |
+
Forward pass for self-attention.
|
| 102 |
+
|
| 103 |
+
Parameters
|
| 104 |
+
----------
|
| 105 |
+
x : Tensor
|
| 106 |
+
Input tensor [batch, seq_len, embed_dim]
|
| 107 |
+
|
| 108 |
+
Returns
|
| 109 |
+
-------
|
| 110 |
+
Tensor
|
| 111 |
+
Output tensor [batch, seq_len, embed_dim]
|
| 112 |
+
"""
|
| 113 |
+
batch_size, seq_len, embed_dim = x.shape
|
| 114 |
+
|
| 115 |
+
# Project to Q, K, V
|
| 116 |
+
q = self.q_proj(x) # [batch, seq_len, embed_dim]
|
| 117 |
+
k = self.k_proj(x) # [batch, seq_len, embed_dim]
|
| 118 |
+
v = self.v_proj(x) # [batch, seq_len, embed_dim]
|
| 119 |
+
|
| 120 |
+
# Reshape for multi-head attention: [batch, seq_len, num_heads, head_dim]
|
| 121 |
+
# Then transpose to: [batch, num_heads, seq_len, head_dim]
|
| 122 |
+
# Use -1 for batch dimension to enable dynamic batch size in ONNX
|
| 123 |
+
q = q.reshape(-1, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 124 |
+
k = k.reshape(-1, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 125 |
+
v = v.reshape(-1, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 126 |
+
|
| 127 |
+
# Scaled dot-product attention
|
| 128 |
+
# Q @ K^T: [batch, heads, seq, dim] @ [batch, heads, dim, seq]
|
| 129 |
+
# -> [batch, heads, seq, seq]
|
| 130 |
+
attn_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
|
| 131 |
+
attn_weights = torch.softmax(attn_scores, dim=-1)
|
| 132 |
+
attn_weights = self.dropout(attn_weights)
|
| 133 |
+
|
| 134 |
+
# Apply attention to values
|
| 135 |
+
# [batch, num_heads, seq_len, seq_len] @ [batch, num_heads, seq_len, head_dim]
|
| 136 |
+
# -> [batch, num_heads, seq_len, head_dim]
|
| 137 |
+
attn_output = torch.matmul(attn_weights, v)
|
| 138 |
+
|
| 139 |
+
# Transpose back and reshape: [batch, num_heads, seq_len, head_dim]
|
| 140 |
+
# -> [batch, seq_len, num_heads, head_dim]
|
| 141 |
+
# -> [batch, seq_len, embed_dim]
|
| 142 |
+
# Use -1 for batch dimension to enable dynamic batch size in ONNX
|
| 143 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 144 |
+
attn_output = attn_output.reshape(-1, seq_len, self.embed_dim)
|
| 145 |
+
|
| 146 |
+
# Final projection
|
| 147 |
+
output = self.out_proj(attn_output)
|
| 148 |
+
|
| 149 |
+
return output
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
class CrossAttentionErrorPredictor(nn.Module):
|
| 153 |
+
"""
|
| 154 |
+
Error predictor with cross-attention between Munsell components.
|
| 155 |
+
|
| 156 |
+
Uses cross-attention to learn correlations between errors in different
|
| 157 |
+
Munsell components (hue, value, chroma, code).
|
| 158 |
+
|
| 159 |
+
Parameters
|
| 160 |
+
----------
|
| 161 |
+
input_dim : int, optional
|
| 162 |
+
Input dimension (7 = xyY_norm + base_pred_norm).
|
| 163 |
+
context_dim : int, optional
|
| 164 |
+
Dimension of shared context features.
|
| 165 |
+
component_dim : int, optional
|
| 166 |
+
Dimension of component-specific features.
|
| 167 |
+
n_components : int, optional
|
| 168 |
+
Number of Munsell components (4).
|
| 169 |
+
n_attention_heads : int, optional
|
| 170 |
+
Number of attention heads for cross-attention.
|
| 171 |
+
dropout : float, optional
|
| 172 |
+
Dropout probability.
|
| 173 |
+
|
| 174 |
+
Attributes
|
| 175 |
+
----------
|
| 176 |
+
context_encoder : nn.Sequential
|
| 177 |
+
Shared encoder: input_dim → 256 → context_dim.
|
| 178 |
+
component_encoders : nn.ModuleList
|
| 179 |
+
Component-specific encoders: context_dim → component_dim (x4).
|
| 180 |
+
cross_attention : CustomMultiheadAttention
|
| 181 |
+
Cross-attention module between component features.
|
| 182 |
+
attention_norm : nn.LayerNorm
|
| 183 |
+
Layer normalization after attention.
|
| 184 |
+
component_decoders : nn.ModuleList
|
| 185 |
+
Component-specific decoders: component_dim → 128 → 1 (x4).
|
| 186 |
+
|
| 187 |
+
Notes
|
| 188 |
+
-----
|
| 189 |
+
Architecture:
|
| 190 |
+
1. Shared context encoder: 7 → 256 → 512
|
| 191 |
+
2. Component-specific encoders: 512 → 256 (x4)
|
| 192 |
+
3. Multi-head cross-attention between components
|
| 193 |
+
4. Residual connection + layer norm
|
| 194 |
+
5. Component-specific decoders: 256 → 128 → 1
|
| 195 |
+
"""
|
| 196 |
+
|
| 197 |
+
def __init__(
|
| 198 |
+
self,
|
| 199 |
+
input_dim: int = 7,
|
| 200 |
+
context_dim: int = 512,
|
| 201 |
+
component_dim: int = 256,
|
| 202 |
+
n_components: int = 4,
|
| 203 |
+
n_attention_heads: int = 4,
|
| 204 |
+
dropout: float = 0.1,
|
| 205 |
+
) -> None:
|
| 206 |
+
"""Initialize the cross-attention error predictor."""
|
| 207 |
+
super().__init__()
|
| 208 |
+
|
| 209 |
+
self.n_components = n_components
|
| 210 |
+
self.component_dim = component_dim
|
| 211 |
+
|
| 212 |
+
# Shared context encoder
|
| 213 |
+
self.context_encoder = nn.Sequential(
|
| 214 |
+
nn.Linear(input_dim, 256),
|
| 215 |
+
nn.GELU(),
|
| 216 |
+
nn.LayerNorm(256),
|
| 217 |
+
nn.Dropout(dropout),
|
| 218 |
+
nn.Linear(256, context_dim),
|
| 219 |
+
nn.GELU(),
|
| 220 |
+
nn.LayerNorm(context_dim),
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
# Component-specific encoders
|
| 224 |
+
self.component_encoders = nn.ModuleList(
|
| 225 |
+
[
|
| 226 |
+
nn.Sequential(
|
| 227 |
+
nn.Linear(context_dim, component_dim),
|
| 228 |
+
nn.GELU(),
|
| 229 |
+
nn.LayerNorm(component_dim),
|
| 230 |
+
)
|
| 231 |
+
for _ in range(n_components)
|
| 232 |
+
]
|
| 233 |
+
)
|
| 234 |
+
|
| 235 |
+
# Multi-head cross-attention (using custom implementation)
|
| 236 |
+
self.cross_attention = CustomMultiheadAttention(
|
| 237 |
+
embed_dim=component_dim,
|
| 238 |
+
num_heads=n_attention_heads,
|
| 239 |
+
dropout=dropout,
|
| 240 |
+
)
|
| 241 |
+
|
| 242 |
+
# Layer norm after attention
|
| 243 |
+
self.attention_norm = nn.LayerNorm(component_dim)
|
| 244 |
+
|
| 245 |
+
# Component-specific decoders
|
| 246 |
+
self.component_decoders = nn.ModuleList(
|
| 247 |
+
[
|
| 248 |
+
nn.Sequential(
|
| 249 |
+
nn.Linear(component_dim, 128),
|
| 250 |
+
nn.GELU(),
|
| 251 |
+
nn.LayerNorm(128),
|
| 252 |
+
nn.Dropout(dropout),
|
| 253 |
+
nn.Linear(128, 1),
|
| 254 |
+
)
|
| 255 |
+
for _ in range(n_components)
|
| 256 |
+
]
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 260 |
+
"""
|
| 261 |
+
Forward pass with cross-attention.
|
| 262 |
+
|
| 263 |
+
Parameters
|
| 264 |
+
----------
|
| 265 |
+
x : Tensor
|
| 266 |
+
Input [xyY_norm (3) + base_pred_norm (4)] = 7 features
|
| 267 |
+
|
| 268 |
+
Returns
|
| 269 |
+
-------
|
| 270 |
+
Tensor
|
| 271 |
+
Predicted errors [hue_err, value_err, chroma_err, code_err]
|
| 272 |
+
"""
|
| 273 |
+
# Shared context encoding
|
| 274 |
+
context = self.context_encoder(x) # [batch, 512]
|
| 275 |
+
|
| 276 |
+
# Component-specific encoding
|
| 277 |
+
component_features = []
|
| 278 |
+
for encoder in self.component_encoders:
|
| 279 |
+
feat = encoder(context) # [batch, 256]
|
| 280 |
+
component_features.append(feat)
|
| 281 |
+
|
| 282 |
+
# Stack for cross-attention: [batch, 4, 256]
|
| 283 |
+
component_stack = torch.stack(component_features, dim=1)
|
| 284 |
+
|
| 285 |
+
# Cross-attention between components
|
| 286 |
+
attended = self.cross_attention(component_stack) # [batch, 4, 256]
|
| 287 |
+
|
| 288 |
+
# Residual connection + layer norm
|
| 289 |
+
component_stack = self.attention_norm(component_stack + attended)
|
| 290 |
+
|
| 291 |
+
# Component-specific decoding (unrolled for ONNX compatibility)
|
| 292 |
+
# Use unbind to split the tensor instead of indexing to preserve batch dimension
|
| 293 |
+
components = torch.unbind(
|
| 294 |
+
component_stack, dim=1
|
| 295 |
+
) # Split into 4 tensors of shape [batch, 256]
|
| 296 |
+
|
| 297 |
+
# Decode each component explicitly
|
| 298 |
+
pred_0 = self.component_decoders[0](components[0]) # [batch, 1]
|
| 299 |
+
pred_1 = self.component_decoders[1](components[1]) # [batch, 1]
|
| 300 |
+
pred_2 = self.component_decoders[2](components[2]) # [batch, 1]
|
| 301 |
+
pred_3 = self.component_decoders[3](components[3]) # [batch, 1]
|
| 302 |
+
|
| 303 |
+
# Concatenate along dimension 1 and squeeze
|
| 304 |
+
predictions = torch.cat([pred_0, pred_1, pred_2, pred_3], dim=1) # [batch, 4]
|
| 305 |
+
|
| 306 |
+
return predictions
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
def train_cross_attention_error_predictor(
|
| 310 |
+
epochs: int = 300,
|
| 311 |
+
batch_size: int = 1024,
|
| 312 |
+
lr: float = 0.0005,
|
| 313 |
+
dropout: float = 0.1,
|
| 314 |
+
context_dim: int = 512,
|
| 315 |
+
component_dim: int = 256,
|
| 316 |
+
n_attention_heads: int = 4,
|
| 317 |
+
) -> tuple[CrossAttentionErrorPredictor, float]:
|
| 318 |
+
"""
|
| 319 |
+
Train cross-attention error predictor.
|
| 320 |
+
|
| 321 |
+
This model uses cross-attention between component branches to learn
|
| 322 |
+
correlations between errors in different Munsell components.
|
| 323 |
+
|
| 324 |
+
Parameters
|
| 325 |
+
----------
|
| 326 |
+
epochs : int, optional
|
| 327 |
+
Maximum number of training epochs.
|
| 328 |
+
batch_size : int, optional
|
| 329 |
+
Training batch size.
|
| 330 |
+
lr : float, optional
|
| 331 |
+
Learning rate for AdamW optimizer.
|
| 332 |
+
dropout : float, optional
|
| 333 |
+
Dropout rate for regularization.
|
| 334 |
+
context_dim : int, optional
|
| 335 |
+
Dimension of shared context features.
|
| 336 |
+
component_dim : int, optional
|
| 337 |
+
Dimension of component-specific features.
|
| 338 |
+
n_attention_heads : int, optional
|
| 339 |
+
Number of attention heads for cross-attention.
|
| 340 |
+
|
| 341 |
+
Returns
|
| 342 |
+
-------
|
| 343 |
+
model : CrossAttentionErrorPredictor
|
| 344 |
+
Trained model with best validation loss weights.
|
| 345 |
+
best_val_loss : float
|
| 346 |
+
Best validation loss achieved during training.
|
| 347 |
+
|
| 348 |
+
Notes
|
| 349 |
+
-----
|
| 350 |
+
The training pipeline:
|
| 351 |
+
1. Loads pre-trained Multi-Head base model
|
| 352 |
+
2. Generates base model predictions for training data
|
| 353 |
+
3. Computes residual errors between predictions and targets
|
| 354 |
+
4. Trains cross-attention error predictor on these residuals
|
| 355 |
+
5. Uses CosineAnnealingLR scheduler
|
| 356 |
+
6. Early stopping based on validation loss
|
| 357 |
+
7. Exports model to ONNX format
|
| 358 |
+
8. Logs metrics and artifacts to MLflow
|
| 359 |
+
"""
|
| 360 |
+
|
| 361 |
+
LOGGER.info("=" * 80)
|
| 362 |
+
LOGGER.info("Training Multi-Head + Cross-Attention Error Predictor")
|
| 363 |
+
LOGGER.info("=" * 80)
|
| 364 |
+
LOGGER.info("")
|
| 365 |
+
LOGGER.info("Architecture:")
|
| 366 |
+
LOGGER.info(" - Shared context encoder: 7 → 256 → %d", context_dim)
|
| 367 |
+
LOGGER.info(" - Component encoders: %d → %d (x4)", context_dim, component_dim)
|
| 368 |
+
LOGGER.info(" - Cross-attention: %d heads", n_attention_heads)
|
| 369 |
+
LOGGER.info(" - Component decoders: %d → 128 → 1 (x4)", component_dim)
|
| 370 |
+
LOGGER.info("")
|
| 371 |
+
LOGGER.info("Hyperparameters:")
|
| 372 |
+
LOGGER.info(" lr: %.6f", lr)
|
| 373 |
+
LOGGER.info(" batch_size: %d", batch_size)
|
| 374 |
+
LOGGER.info(" dropout: %.2f", dropout)
|
| 375 |
+
LOGGER.info("")
|
| 376 |
+
|
| 377 |
+
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
|
| 378 |
+
LOGGER.info("Using device: %s", device)
|
| 379 |
+
|
| 380 |
+
# Paths
|
| 381 |
+
model_directory = PROJECT_ROOT / "models" / "from_xyY"
|
| 382 |
+
data_dir = PROJECT_ROOT / "data"
|
| 383 |
+
base_model_path = model_directory / "multi_head.onnx"
|
| 384 |
+
params_path = model_directory / "multi_head_normalization_params.npz"
|
| 385 |
+
cache_file = data_dir / "training_data.npz"
|
| 386 |
+
|
| 387 |
+
# Load base model
|
| 388 |
+
LOGGER.info("")
|
| 389 |
+
LOGGER.info("Loading Multi-Head base model from %s...", base_model_path)
|
| 390 |
+
base_session = ort.InferenceSession(str(base_model_path))
|
| 391 |
+
params = np.load(params_path, allow_pickle=True)
|
| 392 |
+
input_params = params["input_params"].item()
|
| 393 |
+
output_params = params["output_params"].item()
|
| 394 |
+
|
| 395 |
+
# Load training data
|
| 396 |
+
LOGGER.info("Loading training data from %s...", cache_file)
|
| 397 |
+
data = np.load(cache_file)
|
| 398 |
+
X_train = data["X_train"]
|
| 399 |
+
y_train = data["y_train"]
|
| 400 |
+
X_val = data["X_val"]
|
| 401 |
+
y_val = data["y_val"]
|
| 402 |
+
|
| 403 |
+
LOGGER.info("Train samples: %d", len(X_train))
|
| 404 |
+
LOGGER.info("Validation samples: %d", len(X_val))
|
| 405 |
+
|
| 406 |
+
# Generate base model predictions
|
| 407 |
+
LOGGER.info("")
|
| 408 |
+
LOGGER.info("Generating Multi-Head base model predictions...")
|
| 409 |
+
X_train_norm = normalize_xyY(X_train, input_params)
|
| 410 |
+
y_train_norm = normalize_munsell(y_train, output_params)
|
| 411 |
+
base_pred_train_norm = base_session.run(None, {"xyY": X_train_norm})[0]
|
| 412 |
+
|
| 413 |
+
X_val_norm = normalize_xyY(X_val, input_params)
|
| 414 |
+
y_val_norm = normalize_munsell(y_val, output_params)
|
| 415 |
+
base_pred_val_norm = base_session.run(None, {"xyY": X_val_norm})[0]
|
| 416 |
+
|
| 417 |
+
# Compute errors
|
| 418 |
+
error_train = y_train_norm - base_pred_train_norm
|
| 419 |
+
error_val = y_val_norm - base_pred_val_norm
|
| 420 |
+
|
| 421 |
+
LOGGER.info("")
|
| 422 |
+
LOGGER.info("Base model error statistics (normalized space):")
|
| 423 |
+
LOGGER.info(" Mean absolute error: %.6f", np.mean(np.abs(error_train)))
|
| 424 |
+
LOGGER.info(" Std of error: %.6f", np.std(error_train))
|
| 425 |
+
LOGGER.info(" Max absolute error: %.6f", np.max(np.abs(error_train)))
|
| 426 |
+
|
| 427 |
+
# Create combined input: [xyY_norm, base_prediction_norm]
|
| 428 |
+
X_train_combined = np.concatenate([X_train_norm, base_pred_train_norm], axis=1)
|
| 429 |
+
X_val_combined = np.concatenate([X_val_norm, base_pred_val_norm], axis=1)
|
| 430 |
+
|
| 431 |
+
# Convert to PyTorch tensors
|
| 432 |
+
X_train_t = torch.FloatTensor(X_train_combined)
|
| 433 |
+
error_train_t = torch.FloatTensor(error_train)
|
| 434 |
+
X_val_t = torch.FloatTensor(X_val_combined)
|
| 435 |
+
error_val_t = torch.FloatTensor(error_val)
|
| 436 |
+
|
| 437 |
+
# Create data loaders
|
| 438 |
+
train_dataset = TensorDataset(X_train_t, error_train_t)
|
| 439 |
+
val_dataset = TensorDataset(X_val_t, error_val_t)
|
| 440 |
+
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
|
| 441 |
+
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
|
| 442 |
+
|
| 443 |
+
# Initialize model
|
| 444 |
+
model = CrossAttentionErrorPredictor(
|
| 445 |
+
input_dim=7,
|
| 446 |
+
context_dim=context_dim,
|
| 447 |
+
component_dim=component_dim,
|
| 448 |
+
n_attention_heads=n_attention_heads,
|
| 449 |
+
dropout=dropout,
|
| 450 |
+
).to(device)
|
| 451 |
+
|
| 452 |
+
# Count parameters
|
| 453 |
+
total_params = sum(p.numel() for p in model.parameters())
|
| 454 |
+
LOGGER.info("")
|
| 455 |
+
LOGGER.info("Total parameters: %s", f"{total_params:,}")
|
| 456 |
+
|
| 457 |
+
context_params = sum(p.numel() for p in model.context_encoder.parameters())
|
| 458 |
+
attention_params = sum(p.numel() for p in model.cross_attention.parameters())
|
| 459 |
+
LOGGER.info(" - Context encoder: %s", f"{context_params:,}")
|
| 460 |
+
LOGGER.info(" - Cross-attention: %s", f"{attention_params:,}")
|
| 461 |
+
|
| 462 |
+
# Training setup
|
| 463 |
+
criterion = nn.MSELoss()
|
| 464 |
+
optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
|
| 465 |
+
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
|
| 466 |
+
|
| 467 |
+
# MLflow setup
|
| 468 |
+
run_name = setup_mlflow_experiment("from_xyY", "cross_attention_error_predictor")
|
| 469 |
+
LOGGER.info("")
|
| 470 |
+
LOGGER.info("MLflow run: %s", run_name)
|
| 471 |
+
|
| 472 |
+
# Training loop
|
| 473 |
+
best_val_loss = float("inf")
|
| 474 |
+
best_state = None
|
| 475 |
+
patience = 30
|
| 476 |
+
patience_counter = 0
|
| 477 |
+
|
| 478 |
+
LOGGER.info("")
|
| 479 |
+
LOGGER.info("Starting training...")
|
| 480 |
+
|
| 481 |
+
with mlflow.start_run(run_name=run_name):
|
| 482 |
+
mlflow.log_params(
|
| 483 |
+
{
|
| 484 |
+
"model": "cross_attention_error_predictor",
|
| 485 |
+
"context_dim": context_dim,
|
| 486 |
+
"component_dim": component_dim,
|
| 487 |
+
"n_attention_heads": n_attention_heads,
|
| 488 |
+
"dropout": dropout,
|
| 489 |
+
"learning_rate": lr,
|
| 490 |
+
"batch_size": batch_size,
|
| 491 |
+
"epochs": epochs,
|
| 492 |
+
"patience": patience,
|
| 493 |
+
"total_params": total_params,
|
| 494 |
+
}
|
| 495 |
+
)
|
| 496 |
+
|
| 497 |
+
for epoch in range(epochs):
|
| 498 |
+
# Training
|
| 499 |
+
model.train()
|
| 500 |
+
train_loss = 0.0
|
| 501 |
+
for X_batch, y_batch in train_loader:
|
| 502 |
+
X_batch, y_batch = X_batch.to(device), y_batch.to(device)
|
| 503 |
+
|
| 504 |
+
optimizer.zero_grad()
|
| 505 |
+
pred = model(X_batch)
|
| 506 |
+
loss = criterion(pred, y_batch)
|
| 507 |
+
loss.backward()
|
| 508 |
+
optimizer.step()
|
| 509 |
+
train_loss += loss.item() * len(X_batch)
|
| 510 |
+
|
| 511 |
+
train_loss /= len(X_train_t)
|
| 512 |
+
scheduler.step()
|
| 513 |
+
|
| 514 |
+
# Validation
|
| 515 |
+
model.eval()
|
| 516 |
+
val_loss = 0.0
|
| 517 |
+
with torch.no_grad():
|
| 518 |
+
for X_batch, y_batch in val_loader:
|
| 519 |
+
X_batch, y_batch = X_batch.to(device), y_batch.to(device)
|
| 520 |
+
pred = model(X_batch)
|
| 521 |
+
val_loss += criterion(pred, y_batch).item() * len(X_batch)
|
| 522 |
+
val_loss /= len(X_val_t)
|
| 523 |
+
|
| 524 |
+
log_training_epoch(
|
| 525 |
+
epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"]
|
| 526 |
+
)
|
| 527 |
+
|
| 528 |
+
if val_loss < best_val_loss:
|
| 529 |
+
best_val_loss = val_loss
|
| 530 |
+
best_state = copy.deepcopy(model.state_dict())
|
| 531 |
+
patience_counter = 0
|
| 532 |
+
LOGGER.info(
|
| 533 |
+
"Epoch %03d/%d - Train: %.6f, Val: %.6f (best) - LR: %.6f",
|
| 534 |
+
epoch + 1,
|
| 535 |
+
epochs,
|
| 536 |
+
train_loss,
|
| 537 |
+
val_loss,
|
| 538 |
+
optimizer.param_groups[0]["lr"],
|
| 539 |
+
)
|
| 540 |
+
else:
|
| 541 |
+
patience_counter += 1
|
| 542 |
+
if (epoch + 1) % 50 == 0:
|
| 543 |
+
LOGGER.info(
|
| 544 |
+
"Epoch %03d/%d - Train: %.6f, Val: %.6f",
|
| 545 |
+
epoch + 1,
|
| 546 |
+
epochs,
|
| 547 |
+
train_loss,
|
| 548 |
+
val_loss,
|
| 549 |
+
)
|
| 550 |
+
|
| 551 |
+
if patience_counter >= patience:
|
| 552 |
+
LOGGER.info("Early stopping at epoch %d", epoch + 1)
|
| 553 |
+
break
|
| 554 |
+
|
| 555 |
+
# Load best model
|
| 556 |
+
model.load_state_dict(best_state)
|
| 557 |
+
|
| 558 |
+
mlflow.log_metrics(
|
| 559 |
+
{
|
| 560 |
+
"best_val_loss": best_val_loss,
|
| 561 |
+
"final_epoch": epoch + 1,
|
| 562 |
+
}
|
| 563 |
+
)
|
| 564 |
+
|
| 565 |
+
LOGGER.info("")
|
| 566 |
+
LOGGER.info("Final Results:")
|
| 567 |
+
LOGGER.info(" Best Val Loss: %.6f", best_val_loss)
|
| 568 |
+
|
| 569 |
+
# Save model
|
| 570 |
+
model_directory.mkdir(exist_ok=True)
|
| 571 |
+
checkpoint_path = (
|
| 572 |
+
model_directory / "multi_head_cross_attention_error_predictor.pth"
|
| 573 |
+
)
|
| 574 |
+
|
| 575 |
+
torch.save(
|
| 576 |
+
{
|
| 577 |
+
"model_state_dict": model.state_dict(),
|
| 578 |
+
"val_loss": best_val_loss,
|
| 579 |
+
"hyperparameters": {
|
| 580 |
+
"context_dim": context_dim,
|
| 581 |
+
"component_dim": component_dim,
|
| 582 |
+
"n_attention_heads": n_attention_heads,
|
| 583 |
+
"dropout": dropout,
|
| 584 |
+
"lr": lr,
|
| 585 |
+
"batch_size": batch_size,
|
| 586 |
+
},
|
| 587 |
+
},
|
| 588 |
+
checkpoint_path,
|
| 589 |
+
)
|
| 590 |
+
LOGGER.info("")
|
| 591 |
+
LOGGER.info("Saved checkpoint: %s", checkpoint_path)
|
| 592 |
+
|
| 593 |
+
# Export to ONNX
|
| 594 |
+
LOGGER.info("")
|
| 595 |
+
LOGGER.info("Exporting error predictor to ONNX...")
|
| 596 |
+
model.eval()
|
| 597 |
+
model.cpu()
|
| 598 |
+
|
| 599 |
+
dummy_input = torch.randn(1, 7)
|
| 600 |
+
onnx_path = model_directory / "multi_head_cross_attention_error_predictor.onnx"
|
| 601 |
+
|
| 602 |
+
torch.onnx.export(
|
| 603 |
+
model,
|
| 604 |
+
dummy_input,
|
| 605 |
+
onnx_path,
|
| 606 |
+
export_params=True,
|
| 607 |
+
opset_version=17,
|
| 608 |
+
input_names=["combined_input"],
|
| 609 |
+
output_names=["error_correction"],
|
| 610 |
+
dynamic_axes={
|
| 611 |
+
"combined_input": {0: "batch_size"},
|
| 612 |
+
"error_correction": {0: "batch_size"},
|
| 613 |
+
},
|
| 614 |
+
)
|
| 615 |
+
|
| 616 |
+
mlflow.log_artifact(str(checkpoint_path))
|
| 617 |
+
mlflow.log_artifact(str(onnx_path))
|
| 618 |
+
mlflow.pytorch.log_model(model, "model")
|
| 619 |
+
|
| 620 |
+
LOGGER.info("ONNX model saved to: %s", onnx_path)
|
| 621 |
+
LOGGER.info("Artifacts logged to MLflow")
|
| 622 |
+
|
| 623 |
+
LOGGER.info("=" * 80)
|
| 624 |
+
|
| 625 |
+
|
| 626 |
+
return model, best_val_loss
|
| 627 |
+
|
| 628 |
+
|
| 629 |
+
if __name__ == "__main__":
|
| 630 |
+
logging.basicConfig(level=logging.INFO, format="%(message)s", force=True)
|
| 631 |
+
|
| 632 |
+
train_cross_attention_error_predictor(
|
| 633 |
+
epochs=300,
|
| 634 |
+
batch_size=1024,
|
| 635 |
+
lr=0.0005,
|
| 636 |
+
dropout=0.1,
|
| 637 |
+
context_dim=512,
|
| 638 |
+
component_dim=256,
|
| 639 |
+
n_attention_heads=4,
|
| 640 |
+
)
|
learning_munsell/training/from_xyY/train_multi_head_gamma.py
ADDED
|
@@ -0,0 +1,300 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Train multi-head ML model for xyY to Munsell conversion with gamma-corrected Y.
|
| 3 |
+
|
| 4 |
+
Experiment: Apply gamma 2.33 to Y before normalization to better align
|
| 5 |
+
with perceptual lightness (Munsell Value scale is perceptually uniform).
|
| 6 |
+
|
| 7 |
+
The multi-head architecture has separate heads for each Munsell component,
|
| 8 |
+
so gamma correction on Y should primarily benefit Value prediction without
|
| 9 |
+
negatively impacting Chroma prediction (unlike the single MLP).
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import logging
|
| 13 |
+
from typing import Any
|
| 14 |
+
|
| 15 |
+
import click
|
| 16 |
+
import mlflow
|
| 17 |
+
import mlflow.pytorch
|
| 18 |
+
import numpy as np
|
| 19 |
+
import torch
|
| 20 |
+
from numpy.typing import NDArray
|
| 21 |
+
from torch import nn, optim
|
| 22 |
+
from torch.utils.data import DataLoader, TensorDataset
|
| 23 |
+
|
| 24 |
+
from learning_munsell import PROJECT_ROOT
|
| 25 |
+
from learning_munsell.models.networks import MultiHeadMLPToMunsell
|
| 26 |
+
from learning_munsell.utilities.common import log_training_epoch, setup_mlflow_experiment
|
| 27 |
+
from learning_munsell.utilities.data import MUNSELL_NORMALIZATION_PARAMS, normalize_munsell
|
| 28 |
+
from learning_munsell.utilities.losses import weighted_mse_loss
|
| 29 |
+
from learning_munsell.utilities.training import train_epoch, validate
|
| 30 |
+
|
| 31 |
+
LOGGER = logging.getLogger(__name__)
|
| 32 |
+
|
| 33 |
+
# Gamma value for Y transformation
|
| 34 |
+
GAMMA = 2.33
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def normalize_inputs(
|
| 38 |
+
X: NDArray, gamma: float = GAMMA
|
| 39 |
+
) -> tuple[NDArray, dict[str, Any]]:
|
| 40 |
+
"""
|
| 41 |
+
Normalize xyY inputs to [0, 1] range with gamma correction on Y.
|
| 42 |
+
|
| 43 |
+
Parameters
|
| 44 |
+
----------
|
| 45 |
+
X : ndarray
|
| 46 |
+
xyY values of shape (n, 3) where columns are [x, y, Y].
|
| 47 |
+
gamma : float
|
| 48 |
+
Gamma value to apply to Y component.
|
| 49 |
+
|
| 50 |
+
Returns
|
| 51 |
+
-------
|
| 52 |
+
ndarray
|
| 53 |
+
Normalized values with gamma-corrected Y, dtype float32.
|
| 54 |
+
dict
|
| 55 |
+
Normalization parameters including gamma value.
|
| 56 |
+
"""
|
| 57 |
+
# xyY chromaticity and luminance ranges (all [0, 1])
|
| 58 |
+
x_range = (0.0, 1.0)
|
| 59 |
+
y_range = (0.0, 1.0)
|
| 60 |
+
Y_range = (0.0, 1.0)
|
| 61 |
+
|
| 62 |
+
X_norm = X.copy()
|
| 63 |
+
X_norm[:, 0] = (X[:, 0] - x_range[0]) / (x_range[1] - x_range[0])
|
| 64 |
+
X_norm[:, 1] = (X[:, 1] - y_range[0]) / (y_range[1] - y_range[0])
|
| 65 |
+
|
| 66 |
+
# Normalize Y first, then apply gamma
|
| 67 |
+
Y_normalized = (X[:, 2] - Y_range[0]) / (Y_range[1] - Y_range[0])
|
| 68 |
+
# Clip to avoid numerical issues with negative values
|
| 69 |
+
Y_normalized = np.clip(Y_normalized, 0, 1)
|
| 70 |
+
# Apply gamma: Y_gamma = Y^(1/gamma) - this spreads dark values, compresses light
|
| 71 |
+
X_norm[:, 2] = np.power(Y_normalized, 1.0 / gamma)
|
| 72 |
+
|
| 73 |
+
params = {
|
| 74 |
+
"x_range": x_range,
|
| 75 |
+
"y_range": y_range,
|
| 76 |
+
"Y_range": Y_range,
|
| 77 |
+
"gamma": gamma,
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
return X_norm, params
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
@click.command()
|
| 86 |
+
@click.option("--epochs", default=200, help="Number of training epochs")
|
| 87 |
+
@click.option("--batch-size", default=1024, help="Batch size for training")
|
| 88 |
+
@click.option("--lr", default=5e-4, help="Learning rate")
|
| 89 |
+
@click.option("--patience", default=20, help="Early stopping patience")
|
| 90 |
+
def main(epochs: int, batch_size: int, lr: float, patience: int) -> None:
|
| 91 |
+
"""
|
| 92 |
+
Train the multi-head model with gamma-corrected Y input.
|
| 93 |
+
|
| 94 |
+
Notes
|
| 95 |
+
-----
|
| 96 |
+
The training pipeline:
|
| 97 |
+
1. Loads training and validation data from cache
|
| 98 |
+
2. Normalizes inputs with gamma correction (gamma=2.33) on Y
|
| 99 |
+
3. Normalizes Munsell outputs to [0, 1] range
|
| 100 |
+
4. Trains multi-head MLP with weighted MSE loss
|
| 101 |
+
5. Uses early stopping based on validation loss
|
| 102 |
+
6. Exports best model to ONNX format
|
| 103 |
+
7. Logs metrics and artifacts to MLflow
|
| 104 |
+
|
| 105 |
+
The gamma correction on Y aligns with perceptual lightness. The Munsell
|
| 106 |
+
Value scale is perceptually uniform, so gamma correction should primarily
|
| 107 |
+
benefit Value prediction without negatively impacting Chroma prediction.
|
| 108 |
+
"""
|
| 109 |
+
|
| 110 |
+
LOGGER.info("=" * 80)
|
| 111 |
+
LOGGER.info("ML-Based xyY to Munsell Conversion: Multi-Head Gamma Experiment")
|
| 112 |
+
LOGGER.info("Gamma = %.2f applied to Y component", GAMMA)
|
| 113 |
+
LOGGER.info("=" * 80)
|
| 114 |
+
|
| 115 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 116 |
+
LOGGER.info("Using device: %s", device)
|
| 117 |
+
|
| 118 |
+
# Load training data
|
| 119 |
+
data_dir = PROJECT_ROOT / "data"
|
| 120 |
+
cache_file = data_dir / "training_data.npz"
|
| 121 |
+
|
| 122 |
+
if not cache_file.exists():
|
| 123 |
+
LOGGER.error("Error: Training data not found at %s", cache_file)
|
| 124 |
+
LOGGER.error("Please run 01_generate_training_data.py first")
|
| 125 |
+
return
|
| 126 |
+
|
| 127 |
+
LOGGER.info("Loading training data from %s...", cache_file)
|
| 128 |
+
data = np.load(cache_file)
|
| 129 |
+
|
| 130 |
+
X_train = data["X_train"]
|
| 131 |
+
y_train = data["y_train"]
|
| 132 |
+
X_val = data["X_val"]
|
| 133 |
+
y_val = data["y_val"]
|
| 134 |
+
|
| 135 |
+
LOGGER.info("Train samples: %d", len(X_train))
|
| 136 |
+
LOGGER.info("Validation samples: %d", len(X_val))
|
| 137 |
+
|
| 138 |
+
# Normalize data with gamma correction
|
| 139 |
+
X_train_norm, input_params = normalize_inputs(X_train, gamma=GAMMA)
|
| 140 |
+
X_val_norm, _ = normalize_inputs(X_val, gamma=GAMMA)
|
| 141 |
+
|
| 142 |
+
# Use shared normalization parameters for Munsell outputs
|
| 143 |
+
output_params = MUNSELL_NORMALIZATION_PARAMS
|
| 144 |
+
y_train_norm = normalize_munsell(y_train, output_params)
|
| 145 |
+
y_val_norm = normalize_munsell(y_val, output_params)
|
| 146 |
+
|
| 147 |
+
LOGGER.info("")
|
| 148 |
+
LOGGER.info("Input normalization with gamma=%.2f:", GAMMA)
|
| 149 |
+
LOGGER.info(" Y range after gamma: [%.4f, %.4f]", X_train_norm[:, 2].min(), X_train_norm[:, 2].max())
|
| 150 |
+
|
| 151 |
+
# Convert to PyTorch tensors
|
| 152 |
+
X_train_t = torch.FloatTensor(X_train_norm)
|
| 153 |
+
y_train_t = torch.FloatTensor(y_train_norm)
|
| 154 |
+
X_val_t = torch.FloatTensor(X_val_norm)
|
| 155 |
+
y_val_t = torch.FloatTensor(y_val_norm)
|
| 156 |
+
|
| 157 |
+
# Create data loaders
|
| 158 |
+
train_dataset = TensorDataset(X_train_t, y_train_t)
|
| 159 |
+
val_dataset = TensorDataset(X_val_t, y_val_t)
|
| 160 |
+
|
| 161 |
+
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
|
| 162 |
+
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
|
| 163 |
+
|
| 164 |
+
# Initialize model
|
| 165 |
+
model = MultiHeadMLPToMunsell().to(device)
|
| 166 |
+
LOGGER.info("")
|
| 167 |
+
LOGGER.info("Model architecture:")
|
| 168 |
+
LOGGER.info("%s", model)
|
| 169 |
+
|
| 170 |
+
total_params = sum(p.numel() for p in model.parameters())
|
| 171 |
+
LOGGER.info("Total parameters: %s", f"{total_params:,}")
|
| 172 |
+
|
| 173 |
+
# Training setup
|
| 174 |
+
optimizer = optim.Adam(model.parameters(), lr=lr)
|
| 175 |
+
criterion = weighted_mse_loss
|
| 176 |
+
|
| 177 |
+
# MLflow setup
|
| 178 |
+
run_name = setup_mlflow_experiment("from_xyY", f"multi_head_gamma_{GAMMA}")
|
| 179 |
+
|
| 180 |
+
LOGGER.info("")
|
| 181 |
+
LOGGER.info("MLflow run: %s", run_name)
|
| 182 |
+
|
| 183 |
+
# Training loop
|
| 184 |
+
best_val_loss = float("inf")
|
| 185 |
+
patience_counter = 0
|
| 186 |
+
|
| 187 |
+
LOGGER.info("")
|
| 188 |
+
LOGGER.info("Starting training...")
|
| 189 |
+
|
| 190 |
+
with mlflow.start_run(run_name=run_name):
|
| 191 |
+
mlflow.log_params(
|
| 192 |
+
{
|
| 193 |
+
"model": "multi_head_gamma",
|
| 194 |
+
"num_epochs": epochs,
|
| 195 |
+
"batch_size": batch_size,
|
| 196 |
+
"learning_rate": lr,
|
| 197 |
+
"optimizer": "Adam",
|
| 198 |
+
"criterion": "weighted_mse_loss",
|
| 199 |
+
"patience": patience,
|
| 200 |
+
"total_params": total_params,
|
| 201 |
+
"gamma": GAMMA,
|
| 202 |
+
}
|
| 203 |
+
)
|
| 204 |
+
|
| 205 |
+
for epoch in range(epochs):
|
| 206 |
+
train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
|
| 207 |
+
val_loss = validate(model, val_loader, criterion, device)
|
| 208 |
+
|
| 209 |
+
log_training_epoch(
|
| 210 |
+
epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"]
|
| 211 |
+
)
|
| 212 |
+
|
| 213 |
+
LOGGER.info(
|
| 214 |
+
"Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f",
|
| 215 |
+
epoch + 1,
|
| 216 |
+
epochs,
|
| 217 |
+
train_loss,
|
| 218 |
+
val_loss,
|
| 219 |
+
)
|
| 220 |
+
|
| 221 |
+
if val_loss < best_val_loss:
|
| 222 |
+
best_val_loss = val_loss
|
| 223 |
+
patience_counter = 0
|
| 224 |
+
|
| 225 |
+
model_directory = PROJECT_ROOT / "models" / "from_xyY"
|
| 226 |
+
model_directory.mkdir(exist_ok=True)
|
| 227 |
+
checkpoint_file = model_directory / "multi_head_gamma_best.pth"
|
| 228 |
+
|
| 229 |
+
torch.save(
|
| 230 |
+
{
|
| 231 |
+
"model_state_dict": model.state_dict(),
|
| 232 |
+
"input_params": input_params,
|
| 233 |
+
"output_params": output_params,
|
| 234 |
+
"epoch": epoch,
|
| 235 |
+
"val_loss": val_loss,
|
| 236 |
+
},
|
| 237 |
+
checkpoint_file,
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
+
LOGGER.info(" → Saved best model (val_loss: %.6f)", val_loss)
|
| 241 |
+
else:
|
| 242 |
+
patience_counter += 1
|
| 243 |
+
if patience_counter >= patience:
|
| 244 |
+
LOGGER.info("")
|
| 245 |
+
LOGGER.info("Early stopping after %d epochs", epoch + 1)
|
| 246 |
+
break
|
| 247 |
+
|
| 248 |
+
mlflow.log_metrics(
|
| 249 |
+
{
|
| 250 |
+
"best_val_loss": best_val_loss,
|
| 251 |
+
"final_epoch": epoch + 1,
|
| 252 |
+
}
|
| 253 |
+
)
|
| 254 |
+
|
| 255 |
+
# Export to ONNX
|
| 256 |
+
LOGGER.info("")
|
| 257 |
+
LOGGER.info("Exporting model to ONNX...")
|
| 258 |
+
model.eval()
|
| 259 |
+
|
| 260 |
+
checkpoint = torch.load(checkpoint_file)
|
| 261 |
+
model.load_state_dict(checkpoint["model_state_dict"])
|
| 262 |
+
|
| 263 |
+
dummy_input = torch.randn(1, 3).to(device)
|
| 264 |
+
|
| 265 |
+
onnx_file = model_directory / "multi_head_gamma.onnx"
|
| 266 |
+
torch.onnx.export(
|
| 267 |
+
model,
|
| 268 |
+
dummy_input,
|
| 269 |
+
onnx_file,
|
| 270 |
+
export_params=True,
|
| 271 |
+
opset_version=15,
|
| 272 |
+
input_names=["xyY_gamma"],
|
| 273 |
+
output_names=["munsell_spec"],
|
| 274 |
+
dynamic_axes={"xyY_gamma": {0: "batch_size"}, "munsell_spec": {0: "batch_size"}},
|
| 275 |
+
)
|
| 276 |
+
|
| 277 |
+
# Save normalization parameters (including gamma)
|
| 278 |
+
params_file = model_directory / "multi_head_gamma_normalization_params.npz"
|
| 279 |
+
np.savez(
|
| 280 |
+
params_file,
|
| 281 |
+
input_params=input_params,
|
| 282 |
+
output_params=output_params,
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
LOGGER.info("ONNX model saved to: %s", onnx_file)
|
| 286 |
+
LOGGER.info("Normalization parameters saved to: %s", params_file)
|
| 287 |
+
LOGGER.info("IMPORTANT: Input Y must be gamma-corrected with gamma=%.2f", GAMMA)
|
| 288 |
+
|
| 289 |
+
mlflow.log_artifact(str(checkpoint_file))
|
| 290 |
+
mlflow.log_artifact(str(onnx_file))
|
| 291 |
+
mlflow.log_artifact(str(params_file))
|
| 292 |
+
mlflow.pytorch.log_model(model, "model")
|
| 293 |
+
|
| 294 |
+
LOGGER.info("=" * 80)
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
if __name__ == "__main__":
|
| 298 |
+
logging.basicConfig(level=logging.INFO, format="%(message)s", force=True)
|
| 299 |
+
|
| 300 |
+
main()
|
learning_munsell/training/from_xyY/train_multi_head_gamma_sweep.py
ADDED
|
@@ -0,0 +1,605 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Train multi-head ML models with various gamma values to find optimal gamma.
|
| 3 |
+
|
| 4 |
+
Sweeps gamma from 1.0 to 3.0 in increments of 0.1 and evaluates each model
|
| 5 |
+
on real Munsell colours using Delta-E CIE2000.
|
| 6 |
+
|
| 7 |
+
Supports parallel execution with multiple runs per gamma for averaging.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import logging
|
| 11 |
+
from concurrent.futures import ProcessPoolExecutor, as_completed
|
| 12 |
+
from typing import Any
|
| 13 |
+
|
| 14 |
+
import numpy as np
|
| 15 |
+
import torch
|
| 16 |
+
from colour import XYZ_to_Lab, xyY_to_XYZ
|
| 17 |
+
from colour.difference import delta_E_CIE2000
|
| 18 |
+
from colour.notation.datasets.munsell import MUNSELL_COLOURS_REAL
|
| 19 |
+
from colour.notation.munsell import (
|
| 20 |
+
CCS_ILLUMINANT_MUNSELL,
|
| 21 |
+
munsell_specification_to_xyY,
|
| 22 |
+
)
|
| 23 |
+
from numpy.typing import NDArray
|
| 24 |
+
from torch import nn, optim
|
| 25 |
+
from torch.utils.data import DataLoader, TensorDataset
|
| 26 |
+
|
| 27 |
+
from learning_munsell import PROJECT_ROOT
|
| 28 |
+
from learning_munsell.models.networks import MultiHeadMLPToMunsell
|
| 29 |
+
from learning_munsell.utilities.data import (
|
| 30 |
+
MUNSELL_NORMALIZATION_PARAMS,
|
| 31 |
+
normalize_munsell,
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
LOGGER = logging.getLogger(__name__)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def normalize_inputs(X: NDArray, gamma: float) -> tuple[NDArray, dict[str, Any]]:
|
| 38 |
+
"""
|
| 39 |
+
Normalize xyY inputs to [0, 1] range with gamma correction on Y.
|
| 40 |
+
|
| 41 |
+
Parameters
|
| 42 |
+
----------
|
| 43 |
+
X : ndarray
|
| 44 |
+
xyY values of shape (n, 3) where columns are [x, y, Y].
|
| 45 |
+
gamma : float
|
| 46 |
+
Gamma value to apply to Y component.
|
| 47 |
+
|
| 48 |
+
Returns
|
| 49 |
+
-------
|
| 50 |
+
ndarray
|
| 51 |
+
Normalized values with gamma-corrected Y, dtype float32.
|
| 52 |
+
dict
|
| 53 |
+
Normalization parameters including gamma value.
|
| 54 |
+
"""
|
| 55 |
+
x_range = (0.0, 1.0)
|
| 56 |
+
y_range = (0.0, 1.0)
|
| 57 |
+
Y_range = (0.0, 1.0)
|
| 58 |
+
|
| 59 |
+
X_norm = X.copy()
|
| 60 |
+
X_norm[:, 0] = (X[:, 0] - x_range[0]) / (x_range[1] - x_range[0])
|
| 61 |
+
X_norm[:, 1] = (X[:, 1] - y_range[0]) / (y_range[1] - y_range[0])
|
| 62 |
+
|
| 63 |
+
Y_normalized = (X[:, 2] - Y_range[0]) / (Y_range[1] - Y_range[0])
|
| 64 |
+
Y_normalized = np.clip(Y_normalized, 0, 1)
|
| 65 |
+
X_norm[:, 2] = np.power(Y_normalized, 1.0 / gamma)
|
| 66 |
+
|
| 67 |
+
params = {
|
| 68 |
+
"x_range": x_range,
|
| 69 |
+
"y_range": y_range,
|
| 70 |
+
"Y_range": Y_range,
|
| 71 |
+
"gamma": gamma,
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
return X_norm, params
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def denormalize_output(y_norm: NDArray, params: dict[str, Any]) -> NDArray:
|
| 78 |
+
"""
|
| 79 |
+
Denormalize Munsell output from [0, 1] to original ranges.
|
| 80 |
+
|
| 81 |
+
Parameters
|
| 82 |
+
----------
|
| 83 |
+
y_norm : ndarray
|
| 84 |
+
Normalized Munsell values in [0, 1] range.
|
| 85 |
+
params : dict
|
| 86 |
+
Normalization parameters containing range information.
|
| 87 |
+
|
| 88 |
+
Returns
|
| 89 |
+
-------
|
| 90 |
+
ndarray
|
| 91 |
+
Denormalized Munsell values in original ranges.
|
| 92 |
+
"""
|
| 93 |
+
y = np.copy(y_norm)
|
| 94 |
+
y[..., 0] = (
|
| 95 |
+
y_norm[..., 0] * (params["hue_range"][1] - params["hue_range"][0])
|
| 96 |
+
+ params["hue_range"][0]
|
| 97 |
+
)
|
| 98 |
+
y[..., 1] = (
|
| 99 |
+
y_norm[..., 1] * (params["value_range"][1] - params["value_range"][0])
|
| 100 |
+
+ params["value_range"][0]
|
| 101 |
+
)
|
| 102 |
+
y[..., 2] = (
|
| 103 |
+
y_norm[..., 2] * (params["chroma_range"][1] - params["chroma_range"][0])
|
| 104 |
+
+ params["chroma_range"][0]
|
| 105 |
+
)
|
| 106 |
+
y[..., 3] = (
|
| 107 |
+
y_norm[..., 3] * (params["code_range"][1] - params["code_range"][0])
|
| 108 |
+
+ params["code_range"][0]
|
| 109 |
+
)
|
| 110 |
+
return y
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def weighted_mse_loss(
|
| 114 |
+
pred: torch.Tensor, target: torch.Tensor, weights: torch.Tensor = None
|
| 115 |
+
) -> torch.Tensor:
|
| 116 |
+
"""
|
| 117 |
+
Component-wise weighted MSE loss.
|
| 118 |
+
|
| 119 |
+
Parameters
|
| 120 |
+
----------
|
| 121 |
+
pred : Tensor
|
| 122 |
+
Predicted Munsell values.
|
| 123 |
+
target : Tensor
|
| 124 |
+
Ground truth Munsell values.
|
| 125 |
+
weights : Tensor, optional
|
| 126 |
+
Component weights [w_hue, w_value, w_chroma, w_code].
|
| 127 |
+
|
| 128 |
+
Returns
|
| 129 |
+
-------
|
| 130 |
+
Tensor
|
| 131 |
+
Weighted mean squared error loss.
|
| 132 |
+
"""
|
| 133 |
+
if weights is None:
|
| 134 |
+
weights = torch.tensor([1.0, 1.0, 3.0, 0.5], device=pred.device)
|
| 135 |
+
mse = (pred - target) ** 2
|
| 136 |
+
weighted_mse = mse * weights
|
| 137 |
+
return weighted_mse.mean()
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def clamp_munsell_specification(spec: NDArray) -> NDArray:
|
| 141 |
+
"""
|
| 142 |
+
Clamp Munsell specification to valid ranges.
|
| 143 |
+
|
| 144 |
+
Parameters
|
| 145 |
+
----------
|
| 146 |
+
spec : ndarray
|
| 147 |
+
Munsell specification [hue, value, chroma, code].
|
| 148 |
+
|
| 149 |
+
Returns
|
| 150 |
+
-------
|
| 151 |
+
ndarray
|
| 152 |
+
Clamped Munsell specification within valid ranges.
|
| 153 |
+
"""
|
| 154 |
+
clamped = np.copy(spec)
|
| 155 |
+
clamped[..., 0] = np.clip(spec[..., 0], 0.5, 10.0)
|
| 156 |
+
clamped[..., 1] = np.clip(spec[..., 1], 1.0, 9.0)
|
| 157 |
+
clamped[..., 2] = np.clip(spec[..., 2], 0.0, 50.0)
|
| 158 |
+
clamped[..., 3] = np.clip(spec[..., 3], 1.0, 10.0)
|
| 159 |
+
return clamped
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def compute_delta_e(pred: NDArray, reference_Lab: NDArray) -> list[float]:
|
| 163 |
+
"""
|
| 164 |
+
Compute Delta-E CIE2000 for predicted Munsell specifications.
|
| 165 |
+
|
| 166 |
+
Parameters
|
| 167 |
+
----------
|
| 168 |
+
pred : ndarray
|
| 169 |
+
Predicted Munsell specifications.
|
| 170 |
+
reference_Lab : ndarray
|
| 171 |
+
Reference CIELAB values for comparison.
|
| 172 |
+
|
| 173 |
+
Returns
|
| 174 |
+
-------
|
| 175 |
+
list of float
|
| 176 |
+
Delta-E CIE2000 values for valid predictions.
|
| 177 |
+
|
| 178 |
+
Notes
|
| 179 |
+
-----
|
| 180 |
+
Predictions that cannot be converted to valid xyY are skipped.
|
| 181 |
+
"""
|
| 182 |
+
delta_E_values = []
|
| 183 |
+
for idx in range(len(pred)):
|
| 184 |
+
try:
|
| 185 |
+
ml_spec = clamp_munsell_specification(pred[idx])
|
| 186 |
+
ml_spec_for_conversion = ml_spec.copy()
|
| 187 |
+
ml_spec_for_conversion[3] = round(ml_spec[3])
|
| 188 |
+
ml_xyy = munsell_specification_to_xyY(ml_spec_for_conversion)
|
| 189 |
+
ml_XYZ = xyY_to_XYZ(ml_xyy)
|
| 190 |
+
ml_Lab = XYZ_to_Lab(ml_XYZ, CCS_ILLUMINANT_MUNSELL)
|
| 191 |
+
delta_E = delta_E_CIE2000(reference_Lab[idx], ml_Lab)
|
| 192 |
+
delta_E_values.append(delta_E)
|
| 193 |
+
except (RuntimeError, ValueError):
|
| 194 |
+
continue
|
| 195 |
+
return delta_E_values
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def train_model(
|
| 199 |
+
gamma: float,
|
| 200 |
+
X_train: NDArray,
|
| 201 |
+
y_train: NDArray,
|
| 202 |
+
X_val: NDArray,
|
| 203 |
+
y_val: NDArray,
|
| 204 |
+
device: torch.device,
|
| 205 |
+
num_epochs: int = 100,
|
| 206 |
+
patience: int = 15,
|
| 207 |
+
) -> tuple[nn.Module, dict[str, Any], dict[str, Any], float]:
|
| 208 |
+
"""
|
| 209 |
+
Train a multi-head model with specified gamma value.
|
| 210 |
+
|
| 211 |
+
Parameters
|
| 212 |
+
----------
|
| 213 |
+
gamma : float
|
| 214 |
+
Gamma value for Y correction.
|
| 215 |
+
X_train : ndarray
|
| 216 |
+
Training inputs (xyY values).
|
| 217 |
+
y_train : ndarray
|
| 218 |
+
Training targets (Munsell specifications).
|
| 219 |
+
X_val : ndarray
|
| 220 |
+
Validation inputs.
|
| 221 |
+
y_val : ndarray
|
| 222 |
+
Validation targets.
|
| 223 |
+
device : torch.device
|
| 224 |
+
Device to run training on.
|
| 225 |
+
num_epochs : int, optional
|
| 226 |
+
Maximum number of training epochs. Default is 100.
|
| 227 |
+
patience : int, optional
|
| 228 |
+
Early stopping patience. Default is 15.
|
| 229 |
+
|
| 230 |
+
Returns
|
| 231 |
+
-------
|
| 232 |
+
nn.Module
|
| 233 |
+
Trained model with best validation loss.
|
| 234 |
+
dict
|
| 235 |
+
Input normalization parameters.
|
| 236 |
+
dict
|
| 237 |
+
Output normalization parameters.
|
| 238 |
+
float
|
| 239 |
+
Best validation loss achieved.
|
| 240 |
+
"""
|
| 241 |
+
# Normalize data
|
| 242 |
+
X_train_norm, input_params = normalize_inputs(X_train, gamma=gamma)
|
| 243 |
+
X_val_norm, _ = normalize_inputs(X_val, gamma=gamma)
|
| 244 |
+
|
| 245 |
+
# Use shared normalization parameters covering the full Munsell space for generalization
|
| 246 |
+
output_params = MUNSELL_NORMALIZATION_PARAMS
|
| 247 |
+
y_train_norm = normalize_munsell(y_train, output_params)
|
| 248 |
+
y_val_norm = normalize_munsell(y_val, output_params)
|
| 249 |
+
|
| 250 |
+
# Convert to tensors
|
| 251 |
+
X_train_t = torch.FloatTensor(X_train_norm)
|
| 252 |
+
y_train_t = torch.FloatTensor(y_train_norm)
|
| 253 |
+
X_val_t = torch.FloatTensor(X_val_norm)
|
| 254 |
+
y_val_t = torch.FloatTensor(y_val_norm)
|
| 255 |
+
|
| 256 |
+
# Create data loaders
|
| 257 |
+
train_dataset = TensorDataset(X_train_t, y_train_t)
|
| 258 |
+
val_dataset = TensorDataset(X_val_t, y_val_t)
|
| 259 |
+
|
| 260 |
+
train_loader = DataLoader(train_dataset, batch_size=1024, shuffle=True)
|
| 261 |
+
val_loader = DataLoader(val_dataset, batch_size=1024, shuffle=False)
|
| 262 |
+
|
| 263 |
+
# Initialize model
|
| 264 |
+
model = MultiHeadMLPToMunsell().to(device)
|
| 265 |
+
optimizer = optim.Adam(model.parameters(), lr=5e-4)
|
| 266 |
+
criterion = weighted_mse_loss
|
| 267 |
+
|
| 268 |
+
best_val_loss = float("inf")
|
| 269 |
+
patience_counter = 0
|
| 270 |
+
best_state = None
|
| 271 |
+
|
| 272 |
+
for epoch in range(num_epochs):
|
| 273 |
+
# Train
|
| 274 |
+
model.train()
|
| 275 |
+
for X_batch, y_batch in train_loader:
|
| 276 |
+
X_batch = X_batch.to(device)
|
| 277 |
+
y_batch = y_batch.to(device)
|
| 278 |
+
|
| 279 |
+
outputs = model(X_batch)
|
| 280 |
+
loss = criterion(outputs, y_batch)
|
| 281 |
+
|
| 282 |
+
optimizer.zero_grad()
|
| 283 |
+
loss.backward()
|
| 284 |
+
optimizer.step()
|
| 285 |
+
|
| 286 |
+
# Validate
|
| 287 |
+
model.eval()
|
| 288 |
+
total_val_loss = 0.0
|
| 289 |
+
with torch.no_grad():
|
| 290 |
+
for X_batch, y_batch in val_loader:
|
| 291 |
+
X_batch = X_batch.to(device)
|
| 292 |
+
y_batch = y_batch.to(device)
|
| 293 |
+
outputs = model(X_batch)
|
| 294 |
+
loss = criterion(outputs, y_batch)
|
| 295 |
+
total_val_loss += loss.item()
|
| 296 |
+
val_loss = total_val_loss / len(val_loader)
|
| 297 |
+
|
| 298 |
+
if val_loss < best_val_loss:
|
| 299 |
+
best_val_loss = val_loss
|
| 300 |
+
patience_counter = 0
|
| 301 |
+
best_state = model.state_dict().copy()
|
| 302 |
+
else:
|
| 303 |
+
patience_counter += 1
|
| 304 |
+
if patience_counter >= patience:
|
| 305 |
+
break
|
| 306 |
+
|
| 307 |
+
# Load best state
|
| 308 |
+
if best_state is not None:
|
| 309 |
+
model.load_state_dict(best_state)
|
| 310 |
+
|
| 311 |
+
return model, input_params, output_params, best_val_loss
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
def evaluate_on_real_munsell(
|
| 315 |
+
model: nn.Module,
|
| 316 |
+
input_params: dict[str, Any],
|
| 317 |
+
output_params: dict[str, Any],
|
| 318 |
+
xyY_array: NDArray,
|
| 319 |
+
reference_Lab: NDArray,
|
| 320 |
+
device: torch.device,
|
| 321 |
+
) -> tuple[float, float]:
|
| 322 |
+
"""
|
| 323 |
+
Evaluate model on real Munsell colors using Delta-E CIE2000.
|
| 324 |
+
|
| 325 |
+
Parameters
|
| 326 |
+
----------
|
| 327 |
+
model : nn.Module
|
| 328 |
+
Trained model to evaluate.
|
| 329 |
+
input_params : dict
|
| 330 |
+
Input normalization parameters.
|
| 331 |
+
output_params : dict
|
| 332 |
+
Output normalization parameters.
|
| 333 |
+
xyY_array : ndarray
|
| 334 |
+
Real Munsell xyY values.
|
| 335 |
+
reference_Lab : ndarray
|
| 336 |
+
Reference CIELAB values for Delta-E computation.
|
| 337 |
+
device : torch.device
|
| 338 |
+
Device to run evaluation on.
|
| 339 |
+
|
| 340 |
+
Returns
|
| 341 |
+
-------
|
| 342 |
+
float
|
| 343 |
+
Mean Delta-E CIE2000.
|
| 344 |
+
float
|
| 345 |
+
Median Delta-E CIE2000.
|
| 346 |
+
"""
|
| 347 |
+
model.eval()
|
| 348 |
+
gamma = input_params["gamma"]
|
| 349 |
+
|
| 350 |
+
# Normalize inputs
|
| 351 |
+
X_norm, _ = normalize_inputs(xyY_array, gamma=gamma)
|
| 352 |
+
X_t = torch.FloatTensor(X_norm).to(device)
|
| 353 |
+
|
| 354 |
+
# Predict
|
| 355 |
+
with torch.no_grad():
|
| 356 |
+
pred_norm = model(X_t).cpu().numpy()
|
| 357 |
+
|
| 358 |
+
pred = denormalize_output(pred_norm, output_params)
|
| 359 |
+
delta_E_values = compute_delta_e(pred, reference_Lab)
|
| 360 |
+
|
| 361 |
+
return np.mean(delta_E_values), np.median(delta_E_values)
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
def run_single_trial(
|
| 365 |
+
gamma: float,
|
| 366 |
+
run_id: int,
|
| 367 |
+
X_train: NDArray,
|
| 368 |
+
y_train: NDArray,
|
| 369 |
+
X_val: NDArray,
|
| 370 |
+
y_val: NDArray,
|
| 371 |
+
xyY_array: NDArray,
|
| 372 |
+
reference_Lab: NDArray,
|
| 373 |
+
) -> dict[str, Any]:
|
| 374 |
+
"""
|
| 375 |
+
Run a single training trial for a given gamma value.
|
| 376 |
+
|
| 377 |
+
Parameters
|
| 378 |
+
----------
|
| 379 |
+
gamma : float
|
| 380 |
+
Gamma value for Y correction.
|
| 381 |
+
run_id : int
|
| 382 |
+
Run identifier for this trial.
|
| 383 |
+
X_train : ndarray
|
| 384 |
+
Training inputs.
|
| 385 |
+
y_train : ndarray
|
| 386 |
+
Training targets.
|
| 387 |
+
X_val : ndarray
|
| 388 |
+
Validation inputs.
|
| 389 |
+
y_val : ndarray
|
| 390 |
+
Validation targets.
|
| 391 |
+
xyY_array : ndarray
|
| 392 |
+
Real Munsell xyY values for evaluation.
|
| 393 |
+
reference_Lab : ndarray
|
| 394 |
+
Reference CIELAB values for Delta-E computation.
|
| 395 |
+
|
| 396 |
+
Returns
|
| 397 |
+
-------
|
| 398 |
+
dict
|
| 399 |
+
Results dictionary containing gamma, run_id, val_loss,
|
| 400 |
+
mean_delta_e, and median_delta_e.
|
| 401 |
+
|
| 402 |
+
Notes
|
| 403 |
+
-----
|
| 404 |
+
Uses CPU to avoid MPS multiprocessing issues.
|
| 405 |
+
"""
|
| 406 |
+
# Each process uses CPU to avoid MPS multiprocessing issues
|
| 407 |
+
device = torch.device("cpu")
|
| 408 |
+
|
| 409 |
+
model, input_params, output_params, val_loss = train_model(
|
| 410 |
+
gamma=gamma,
|
| 411 |
+
X_train=X_train,
|
| 412 |
+
y_train=y_train,
|
| 413 |
+
X_val=X_val,
|
| 414 |
+
y_val=y_val,
|
| 415 |
+
device=device,
|
| 416 |
+
num_epochs=100,
|
| 417 |
+
patience=15,
|
| 418 |
+
)
|
| 419 |
+
|
| 420 |
+
mean_delta_e, median_delta_e = evaluate_on_real_munsell(
|
| 421 |
+
model, input_params, output_params, xyY_array, reference_Lab, device
|
| 422 |
+
)
|
| 423 |
+
|
| 424 |
+
return {
|
| 425 |
+
"gamma": gamma,
|
| 426 |
+
"run_id": run_id,
|
| 427 |
+
"val_loss": val_loss,
|
| 428 |
+
"mean_delta_e": mean_delta_e,
|
| 429 |
+
"median_delta_e": median_delta_e,
|
| 430 |
+
}
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
def main() -> None:
|
| 434 |
+
"""
|
| 435 |
+
Run gamma sweep experiment to find optimal gamma value.
|
| 436 |
+
|
| 437 |
+
Notes
|
| 438 |
+
-----
|
| 439 |
+
The training pipeline:
|
| 440 |
+
1. Loads training and validation data from cache
|
| 441 |
+
2. Loads real Munsell colors for evaluation
|
| 442 |
+
3. Sweeps gamma values from 1.0 to 3.0 in 0.1 increments
|
| 443 |
+
4. Trains multiple models per gamma value for averaging
|
| 444 |
+
5. Evaluates each model on real Munsell colors using Delta-E CIE2000
|
| 445 |
+
6. Aggregates results and identifies best gamma value
|
| 446 |
+
7. Saves results to NPZ file for analysis
|
| 447 |
+
|
| 448 |
+
Uses parallel execution with ProcessPoolExecutor for efficiency.
|
| 449 |
+
Each model is trained with early stopping and evaluated on validation set.
|
| 450 |
+
"""
|
| 451 |
+
import argparse
|
| 452 |
+
|
| 453 |
+
parser = argparse.ArgumentParser(description="Gamma sweep with averaging")
|
| 454 |
+
parser.add_argument("--runs", type=int, default=3, help="Number of runs per gamma")
|
| 455 |
+
parser.add_argument("--workers", type=int, default=4, help="Number of parallel workers")
|
| 456 |
+
args = parser.parse_args()
|
| 457 |
+
|
| 458 |
+
num_runs = args.runs
|
| 459 |
+
num_workers = args.workers
|
| 460 |
+
|
| 461 |
+
LOGGER.info("=" * 80)
|
| 462 |
+
LOGGER.info("Multi-Head Gamma Sweep: Finding Optimal Gamma Value")
|
| 463 |
+
LOGGER.info("Testing gamma values from 1.0 to 3.0 in increments of 0.1")
|
| 464 |
+
LOGGER.info("Runs per gamma: %d, Parallel workers: %d", num_runs, num_workers)
|
| 465 |
+
LOGGER.info("=" * 80)
|
| 466 |
+
|
| 467 |
+
# Load training data
|
| 468 |
+
data_dir = PROJECT_ROOT / "data"
|
| 469 |
+
cache_file = data_dir / "training_data.npz"
|
| 470 |
+
|
| 471 |
+
if not cache_file.exists():
|
| 472 |
+
LOGGER.error("Error: Training data not found at %s", cache_file)
|
| 473 |
+
return
|
| 474 |
+
|
| 475 |
+
LOGGER.info("\nLoading training data...")
|
| 476 |
+
data = np.load(cache_file)
|
| 477 |
+
X_train = data["X_train"]
|
| 478 |
+
y_train = data["y_train"]
|
| 479 |
+
X_val = data["X_val"]
|
| 480 |
+
y_val = data["y_val"]
|
| 481 |
+
LOGGER.info("Train samples: %d, Validation samples: %d", len(X_train), len(X_val))
|
| 482 |
+
|
| 483 |
+
# Load real Munsell data for evaluation
|
| 484 |
+
LOGGER.info("Loading real Munsell colours for evaluation...")
|
| 485 |
+
xyY_values = []
|
| 486 |
+
reference_Lab = []
|
| 487 |
+
|
| 488 |
+
for munsell_spec_tuple, xyY in MUNSELL_COLOURS_REAL:
|
| 489 |
+
try:
|
| 490 |
+
xyY_scaled = np.array([xyY[0], xyY[1], xyY[2] / 100.0])
|
| 491 |
+
XYZ = xyY_to_XYZ(xyY_scaled)
|
| 492 |
+
Lab = XYZ_to_Lab(XYZ, CCS_ILLUMINANT_MUNSELL)
|
| 493 |
+
xyY_values.append(xyY_scaled)
|
| 494 |
+
reference_Lab.append(Lab)
|
| 495 |
+
except (RuntimeError, ValueError):
|
| 496 |
+
continue
|
| 497 |
+
|
| 498 |
+
xyY_array = np.array(xyY_values)
|
| 499 |
+
reference_Lab = np.array(reference_Lab)
|
| 500 |
+
LOGGER.info("Loaded %d real Munsell colours", len(xyY_array))
|
| 501 |
+
|
| 502 |
+
# Gamma values to test
|
| 503 |
+
gamma_values = [round(1.0 + i * 0.1, 1) for i in range(21)] # 1.0 to 3.0
|
| 504 |
+
|
| 505 |
+
# Create all tasks: (gamma, run_id) pairs
|
| 506 |
+
tasks = [(gamma, run_id) for gamma in gamma_values for run_id in range(num_runs)]
|
| 507 |
+
total_tasks = len(tasks)
|
| 508 |
+
|
| 509 |
+
LOGGER.info("\n" + "-" * 80)
|
| 510 |
+
LOGGER.info("Starting gamma sweep: %d total tasks (%d gamma values x %d runs)",
|
| 511 |
+
total_tasks, len(gamma_values), num_runs)
|
| 512 |
+
LOGGER.info("-" * 80)
|
| 513 |
+
|
| 514 |
+
all_results = []
|
| 515 |
+
completed = 0
|
| 516 |
+
|
| 517 |
+
with ProcessPoolExecutor(max_workers=num_workers) as executor:
|
| 518 |
+
futures = {
|
| 519 |
+
executor.submit(
|
| 520 |
+
run_single_trial, gamma, run_id,
|
| 521 |
+
X_train, y_train, X_val, y_val, xyY_array, reference_Lab
|
| 522 |
+
): (gamma, run_id)
|
| 523 |
+
for gamma, run_id in tasks
|
| 524 |
+
}
|
| 525 |
+
|
| 526 |
+
for future in as_completed(futures):
|
| 527 |
+
gamma, run_id = futures[future]
|
| 528 |
+
try:
|
| 529 |
+
result = future.result()
|
| 530 |
+
all_results.append(result)
|
| 531 |
+
completed += 1
|
| 532 |
+
LOGGER.info(
|
| 533 |
+
"[%3d/%3d] gamma=%.1f run=%d: mean_ΔE=%.4f, median_ΔE=%.4f",
|
| 534 |
+
completed, total_tasks, gamma, run_id,
|
| 535 |
+
result["mean_delta_e"], result["median_delta_e"]
|
| 536 |
+
)
|
| 537 |
+
except Exception as e:
|
| 538 |
+
LOGGER.error("Task failed for gamma=%.1f run=%d: %s", gamma, run_id, e)
|
| 539 |
+
completed += 1
|
| 540 |
+
|
| 541 |
+
# Aggregate results by gamma (average across runs)
|
| 542 |
+
aggregated = {}
|
| 543 |
+
for r in all_results:
|
| 544 |
+
gamma = r["gamma"]
|
| 545 |
+
if gamma not in aggregated:
|
| 546 |
+
aggregated[gamma] = {"val_losses": [], "means": [], "medians": []}
|
| 547 |
+
aggregated[gamma]["val_losses"].append(r["val_loss"])
|
| 548 |
+
aggregated[gamma]["means"].append(r["mean_delta_e"])
|
| 549 |
+
aggregated[gamma]["medians"].append(r["median_delta_e"])
|
| 550 |
+
|
| 551 |
+
results = []
|
| 552 |
+
for gamma in sorted(aggregated.keys()):
|
| 553 |
+
agg = aggregated[gamma]
|
| 554 |
+
results.append({
|
| 555 |
+
"gamma": gamma,
|
| 556 |
+
"val_loss": np.mean(agg["val_losses"]),
|
| 557 |
+
"val_loss_std": np.std(agg["val_losses"]),
|
| 558 |
+
"mean_delta_e": np.mean(agg["means"]),
|
| 559 |
+
"mean_delta_e_std": np.std(agg["means"]),
|
| 560 |
+
"median_delta_e": np.mean(agg["medians"]),
|
| 561 |
+
"median_delta_e_std": np.std(agg["medians"]),
|
| 562 |
+
"num_runs": len(agg["means"]),
|
| 563 |
+
})
|
| 564 |
+
|
| 565 |
+
# Print results
|
| 566 |
+
LOGGER.info("\n" + "=" * 80)
|
| 567 |
+
LOGGER.info("GAMMA SWEEP RESULTS (averaged over %d runs)", num_runs)
|
| 568 |
+
LOGGER.info("=" * 80)
|
| 569 |
+
LOGGER.info("")
|
| 570 |
+
LOGGER.info("%-8s %-14s %-14s %-14s", "Gamma", "Val Loss", "Mean ΔE", "Median ΔE")
|
| 571 |
+
LOGGER.info("-" * 50)
|
| 572 |
+
|
| 573 |
+
for r in results:
|
| 574 |
+
LOGGER.info(
|
| 575 |
+
"%-8.1f %-14s %-14s %-14s",
|
| 576 |
+
r["gamma"],
|
| 577 |
+
f"{r['val_loss']:.6f}±{r['val_loss_std']:.4f}",
|
| 578 |
+
f"{r['mean_delta_e']:.4f}±{r['mean_delta_e_std']:.4f}",
|
| 579 |
+
f"{r['median_delta_e']:.4f}±{r['median_delta_e_std']:.4f}",
|
| 580 |
+
)
|
| 581 |
+
|
| 582 |
+
# Find best by mean Delta-E
|
| 583 |
+
best_by_mean = min(results, key=lambda x: x["mean_delta_e"])
|
| 584 |
+
best_by_median = min(results, key=lambda x: x["median_delta_e"])
|
| 585 |
+
|
| 586 |
+
LOGGER.info("")
|
| 587 |
+
LOGGER.info("Best gamma by MEAN Delta-E: %.1f (ΔE = %.4f ± %.4f)",
|
| 588 |
+
best_by_mean["gamma"], best_by_mean["mean_delta_e"],
|
| 589 |
+
best_by_mean["mean_delta_e_std"])
|
| 590 |
+
LOGGER.info("Best gamma by MEDIAN Delta-E: %.1f (ΔE = %.4f ± %.4f)",
|
| 591 |
+
best_by_median["gamma"], best_by_median["median_delta_e"],
|
| 592 |
+
best_by_median["median_delta_e_std"])
|
| 593 |
+
|
| 594 |
+
# Save results
|
| 595 |
+
results_file = PROJECT_ROOT / "models" / "from_xyY" / "gamma_sweep_results_averaged.npz"
|
| 596 |
+
np.savez(results_file, results=results, all_results=all_results)
|
| 597 |
+
LOGGER.info("\nResults saved to: %s", results_file)
|
| 598 |
+
|
| 599 |
+
LOGGER.info("\n" + "=" * 80)
|
| 600 |
+
|
| 601 |
+
|
| 602 |
+
if __name__ == "__main__":
|
| 603 |
+
logging.basicConfig(level=logging.INFO, format="%(message)s", force=True)
|
| 604 |
+
|
| 605 |
+
main()
|
learning_munsell/training/from_xyY/train_multi_head_large.py
ADDED
|
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Train multi-head ML model on large dataset (2M samples) for xyY to Munsell conversion.
|
| 3 |
+
|
| 4 |
+
This script trains on the larger dataset for potentially improved accuracy.
|
| 5 |
+
Uses the same architecture as train_multi_head_mlp.py but with the large dataset.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import logging
|
| 9 |
+
|
| 10 |
+
import click
|
| 11 |
+
import mlflow
|
| 12 |
+
import mlflow.pytorch
|
| 13 |
+
import numpy as np
|
| 14 |
+
import torch
|
| 15 |
+
from numpy.typing import NDArray
|
| 16 |
+
from torch import nn, optim
|
| 17 |
+
from torch.utils.data import DataLoader, TensorDataset
|
| 18 |
+
|
| 19 |
+
from learning_munsell import PROJECT_ROOT
|
| 20 |
+
from learning_munsell.models.networks import MultiHeadMLPToMunsell
|
| 21 |
+
from learning_munsell.utilities.common import log_training_epoch, setup_mlflow_experiment
|
| 22 |
+
from learning_munsell.utilities.data import (
|
| 23 |
+
MUNSELL_NORMALIZATION_PARAMS,
|
| 24 |
+
XYY_NORMALIZATION_PARAMS,
|
| 25 |
+
normalize_munsell,
|
| 26 |
+
)
|
| 27 |
+
from learning_munsell.utilities.losses import weighted_mse_loss
|
| 28 |
+
from learning_munsell.utilities.training import train_epoch, validate
|
| 29 |
+
|
| 30 |
+
LOGGER = logging.getLogger(__name__)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
@click.command()
|
| 34 |
+
@click.option("--epochs", default=300, help="Number of training epochs")
|
| 35 |
+
@click.option("--batch-size", default=2048, help="Batch size for training")
|
| 36 |
+
@click.option("--lr", default=5e-4, help="Learning rate")
|
| 37 |
+
@click.option("--patience", default=30, help="Early stopping patience")
|
| 38 |
+
def main(epochs: int, batch_size: int, lr: float, patience: int) -> None:
|
| 39 |
+
"""
|
| 40 |
+
Train multi-head MLP on large dataset (2M samples) for xyY to Munsell.
|
| 41 |
+
|
| 42 |
+
Notes
|
| 43 |
+
-----
|
| 44 |
+
The training pipeline:
|
| 45 |
+
1. Loads training and validation data from large cached .npz file
|
| 46 |
+
2. Normalizes xyY inputs (already [0,1]) and Munsell outputs to [0,1]
|
| 47 |
+
3. Creates multi-head MLP with shared encoder and component-specific heads
|
| 48 |
+
4. Trains with weighted MSE loss (emphasizing chroma)
|
| 49 |
+
5. Uses Adam optimizer with ReduceLROnPlateau scheduler
|
| 50 |
+
6. Applies early stopping based on validation loss (patience=30)
|
| 51 |
+
7. Exports best model to ONNX format
|
| 52 |
+
8. Logs metrics and artifacts to MLflow
|
| 53 |
+
"""
|
| 54 |
+
|
| 55 |
+
LOGGER.info("=" * 80)
|
| 56 |
+
LOGGER.info("Multi-Head Model Training on Large Dataset (2M samples)")
|
| 57 |
+
LOGGER.info("=" * 80)
|
| 58 |
+
|
| 59 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 60 |
+
if torch.backends.mps.is_available():
|
| 61 |
+
device = torch.device("mps")
|
| 62 |
+
LOGGER.info("Using device: %s", device)
|
| 63 |
+
|
| 64 |
+
# Load large training data
|
| 65 |
+
data_dir = PROJECT_ROOT / "data"
|
| 66 |
+
cache_file = data_dir / "training_data_large.npz"
|
| 67 |
+
|
| 68 |
+
if not cache_file.exists():
|
| 69 |
+
LOGGER.error("Error: Large training data not found at %s", cache_file)
|
| 70 |
+
LOGGER.error("Please run generate_large_training_data.py first")
|
| 71 |
+
return
|
| 72 |
+
|
| 73 |
+
LOGGER.info("Loading large training data from %s...", cache_file)
|
| 74 |
+
data = np.load(cache_file)
|
| 75 |
+
|
| 76 |
+
X_train = data["X_train"]
|
| 77 |
+
y_train = data["y_train"]
|
| 78 |
+
X_val = data["X_val"]
|
| 79 |
+
y_val = data["y_val"]
|
| 80 |
+
|
| 81 |
+
LOGGER.info("Train samples: %d", len(X_train))
|
| 82 |
+
LOGGER.info("Validation samples: %d", len(X_val))
|
| 83 |
+
|
| 84 |
+
# Normalize outputs (xyY inputs are already in [0, 1] range)
|
| 85 |
+
# Use shared normalization parameters covering the full Munsell space for generalization
|
| 86 |
+
output_params = MUNSELL_NORMALIZATION_PARAMS
|
| 87 |
+
y_train_norm = normalize_munsell(y_train, output_params)
|
| 88 |
+
y_val_norm = normalize_munsell(y_val, output_params)
|
| 89 |
+
|
| 90 |
+
# Convert to PyTorch tensors
|
| 91 |
+
X_train_t = torch.FloatTensor(X_train)
|
| 92 |
+
y_train_t = torch.FloatTensor(y_train_norm)
|
| 93 |
+
X_val_t = torch.FloatTensor(X_val)
|
| 94 |
+
y_val_t = torch.FloatTensor(y_val_norm)
|
| 95 |
+
|
| 96 |
+
# Create data loaders (larger batch size for larger dataset)
|
| 97 |
+
train_dataset = TensorDataset(X_train_t, y_train_t)
|
| 98 |
+
val_dataset = TensorDataset(X_val_t, y_val_t)
|
| 99 |
+
|
| 100 |
+
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
|
| 101 |
+
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
|
| 102 |
+
|
| 103 |
+
# Initialize model
|
| 104 |
+
model = MultiHeadMLPToMunsell().to(device)
|
| 105 |
+
LOGGER.info("")
|
| 106 |
+
LOGGER.info("Model architecture:")
|
| 107 |
+
LOGGER.info("%s", model)
|
| 108 |
+
|
| 109 |
+
total_params = sum(p.numel() for p in model.parameters())
|
| 110 |
+
LOGGER.info("Total parameters: %s", f"{total_params:,}")
|
| 111 |
+
|
| 112 |
+
# Training setup
|
| 113 |
+
learning_rate = lr
|
| 114 |
+
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
|
| 115 |
+
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
|
| 116 |
+
optimizer, mode="min", factor=0.5, patience=10
|
| 117 |
+
)
|
| 118 |
+
criterion = weighted_mse_loss
|
| 119 |
+
|
| 120 |
+
# MLflow setup
|
| 121 |
+
run_name = setup_mlflow_experiment("from_xyY", "multi_head_large")
|
| 122 |
+
|
| 123 |
+
LOGGER.info("")
|
| 124 |
+
LOGGER.info("MLflow run: %s", run_name)
|
| 125 |
+
|
| 126 |
+
# Training loop
|
| 127 |
+
best_val_loss = float("inf")
|
| 128 |
+
patience_counter = 0
|
| 129 |
+
|
| 130 |
+
LOGGER.info("")
|
| 131 |
+
LOGGER.info("Starting training...")
|
| 132 |
+
|
| 133 |
+
with mlflow.start_run(run_name=run_name):
|
| 134 |
+
mlflow.log_params(
|
| 135 |
+
{
|
| 136 |
+
"model": "multi_head_large",
|
| 137 |
+
"learning_rate": learning_rate,
|
| 138 |
+
"batch_size": batch_size,
|
| 139 |
+
"num_epochs": epochs,
|
| 140 |
+
"patience": patience,
|
| 141 |
+
"total_params": total_params,
|
| 142 |
+
"train_samples": len(X_train),
|
| 143 |
+
"val_samples": len(X_val),
|
| 144 |
+
"dataset": "large_2M",
|
| 145 |
+
}
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
for epoch in range(epochs):
|
| 149 |
+
train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
|
| 150 |
+
val_loss = validate(model, val_loader, criterion, device)
|
| 151 |
+
|
| 152 |
+
scheduler.step(val_loss)
|
| 153 |
+
|
| 154 |
+
log_training_epoch(
|
| 155 |
+
epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"]
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
LOGGER.info(
|
| 159 |
+
"Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f, LR: %.6f",
|
| 160 |
+
epoch + 1,
|
| 161 |
+
epochs,
|
| 162 |
+
train_loss,
|
| 163 |
+
val_loss,
|
| 164 |
+
optimizer.param_groups[0]["lr"],
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
if val_loss < best_val_loss:
|
| 168 |
+
best_val_loss = val_loss
|
| 169 |
+
patience_counter = 0
|
| 170 |
+
|
| 171 |
+
model_directory = PROJECT_ROOT / "models" / "from_xyY"
|
| 172 |
+
model_directory.mkdir(exist_ok=True)
|
| 173 |
+
checkpoint_file = model_directory / "multi_head_large_best.pth"
|
| 174 |
+
|
| 175 |
+
torch.save(
|
| 176 |
+
{
|
| 177 |
+
"model_state_dict": model.state_dict(),
|
| 178 |
+
"output_params": output_params,
|
| 179 |
+
"epoch": epoch,
|
| 180 |
+
"val_loss": val_loss,
|
| 181 |
+
},
|
| 182 |
+
checkpoint_file,
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
LOGGER.info(" → Saved best model (val_loss: %.6f)", val_loss)
|
| 186 |
+
else:
|
| 187 |
+
patience_counter += 1
|
| 188 |
+
if patience_counter >= patience:
|
| 189 |
+
LOGGER.info("")
|
| 190 |
+
LOGGER.info("Early stopping after %d epochs", epoch + 1)
|
| 191 |
+
break
|
| 192 |
+
|
| 193 |
+
mlflow.log_metrics(
|
| 194 |
+
{
|
| 195 |
+
"best_val_loss": best_val_loss,
|
| 196 |
+
"final_epoch": epoch + 1,
|
| 197 |
+
}
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
# Export to ONNX
|
| 201 |
+
LOGGER.info("")
|
| 202 |
+
LOGGER.info("Exporting model to ONNX...")
|
| 203 |
+
model.eval()
|
| 204 |
+
|
| 205 |
+
checkpoint = torch.load(checkpoint_file, weights_only=False)
|
| 206 |
+
model.load_state_dict(checkpoint["model_state_dict"])
|
| 207 |
+
|
| 208 |
+
dummy_input = torch.randn(1, 3).to(device)
|
| 209 |
+
|
| 210 |
+
model_directory = PROJECT_ROOT / "models" / "from_xyY"
|
| 211 |
+
onnx_file = model_directory / "multi_head_large.onnx"
|
| 212 |
+
torch.onnx.export(
|
| 213 |
+
model,
|
| 214 |
+
dummy_input,
|
| 215 |
+
onnx_file,
|
| 216 |
+
export_params=True,
|
| 217 |
+
opset_version=15,
|
| 218 |
+
input_names=["xyY"],
|
| 219 |
+
output_names=["munsell_spec"],
|
| 220 |
+
dynamic_axes={"xyY": {0: "batch_size"}, "munsell_spec": {0: "batch_size"}},
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
params_file = model_directory / "multi_head_large_normalization_params.npz"
|
| 224 |
+
input_params = XYY_NORMALIZATION_PARAMS
|
| 225 |
+
np.savez(
|
| 226 |
+
params_file,
|
| 227 |
+
input_params=input_params,
|
| 228 |
+
output_params=output_params,
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
mlflow.log_artifact(str(checkpoint_file))
|
| 232 |
+
mlflow.log_artifact(str(onnx_file))
|
| 233 |
+
mlflow.log_artifact(str(params_file))
|
| 234 |
+
mlflow.pytorch.log_model(model, "model")
|
| 235 |
+
|
| 236 |
+
LOGGER.info("ONNX model saved to: %s", onnx_file)
|
| 237 |
+
LOGGER.info("Normalization parameters saved to: %s", params_file)
|
| 238 |
+
LOGGER.info("Artifacts logged to MLflow")
|
| 239 |
+
|
| 240 |
+
LOGGER.info("=" * 80)
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
if __name__ == "__main__":
|
| 244 |
+
logging.basicConfig(level=logging.INFO, format="%(message)s", force=True)
|
| 245 |
+
|
| 246 |
+
main()
|
learning_munsell/training/from_xyY/train_multi_head_mlp.py
ADDED
|
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Train multi-head ML model for xyY to Munsell conversion.
|
| 3 |
+
|
| 4 |
+
Architecture:
|
| 5 |
+
- Shared encoder: 3 inputs → 512-dim features
|
| 6 |
+
- 4 separate heads (one per component):
|
| 7 |
+
- Hue head (circular/angular)
|
| 8 |
+
- Value head (linear lightness)
|
| 9 |
+
- Chroma head (non-linear saturation - larger capacity)
|
| 10 |
+
- Code head (discrete categorical)
|
| 11 |
+
|
| 12 |
+
This architecture allows each component to learn specialized features
|
| 13 |
+
while sharing the general color space understanding.
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
import logging
|
| 17 |
+
import click
|
| 18 |
+
import mlflow
|
| 19 |
+
import mlflow.pytorch
|
| 20 |
+
import numpy as np
|
| 21 |
+
import torch
|
| 22 |
+
from numpy.typing import NDArray
|
| 23 |
+
from torch import nn, optim
|
| 24 |
+
from torch.utils.data import DataLoader, TensorDataset
|
| 25 |
+
|
| 26 |
+
from learning_munsell import PROJECT_ROOT
|
| 27 |
+
from learning_munsell.models.networks import MultiHeadMLPToMunsell
|
| 28 |
+
from learning_munsell.utilities.common import log_training_epoch, setup_mlflow_experiment
|
| 29 |
+
from learning_munsell.utilities.data import (
|
| 30 |
+
MUNSELL_NORMALIZATION_PARAMS,
|
| 31 |
+
XYY_NORMALIZATION_PARAMS,
|
| 32 |
+
normalize_munsell,
|
| 33 |
+
)
|
| 34 |
+
from learning_munsell.utilities.losses import weighted_mse_loss
|
| 35 |
+
from learning_munsell.utilities.training import train_epoch, validate
|
| 36 |
+
|
| 37 |
+
LOGGER = logging.getLogger(__name__)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
@click.command()
|
| 41 |
+
@click.option("--epochs", default=200, help="Number of training epochs")
|
| 42 |
+
@click.option("--batch-size", default=1024, help="Batch size for training")
|
| 43 |
+
@click.option("--lr", default=5e-4, help="Learning rate")
|
| 44 |
+
@click.option("--patience", default=20, help="Early stopping patience")
|
| 45 |
+
def main(epochs: int, batch_size: int, lr: float, patience: int) -> None:
|
| 46 |
+
"""
|
| 47 |
+
Train multi-head MLP for xyY to Munsell conversion.
|
| 48 |
+
|
| 49 |
+
Notes
|
| 50 |
+
-----
|
| 51 |
+
The training pipeline:
|
| 52 |
+
1. Loads training and validation data from cached .npz file
|
| 53 |
+
2. Normalizes xyY inputs (already [0,1]) and Munsell outputs to [0,1]
|
| 54 |
+
3. Creates multi-head MLP with shared encoder and component-specific heads
|
| 55 |
+
4. Trains with weighted MSE loss (emphasizing chroma)
|
| 56 |
+
5. Uses Adam optimizer with no learning rate scheduling
|
| 57 |
+
6. Applies early stopping based on validation loss (patience=20)
|
| 58 |
+
7. Exports best model to ONNX format
|
| 59 |
+
8. Logs metrics and artifacts to MLflow
|
| 60 |
+
"""
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
LOGGER.info("=" * 80)
|
| 64 |
+
LOGGER.info("ML-Based xyY to Munsell Conversion: Multi-Head Model Training")
|
| 65 |
+
LOGGER.info("=" * 80)
|
| 66 |
+
|
| 67 |
+
# Set device
|
| 68 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 69 |
+
LOGGER.info("Using device: %s", device)
|
| 70 |
+
|
| 71 |
+
# Load training data
|
| 72 |
+
data_dir = PROJECT_ROOT / "data"
|
| 73 |
+
cache_file = data_dir / "training_data.npz"
|
| 74 |
+
|
| 75 |
+
if not cache_file.exists():
|
| 76 |
+
LOGGER.error("Error: Training data not found at %s", cache_file)
|
| 77 |
+
LOGGER.error("Please run 01_generate_training_data.py first")
|
| 78 |
+
return
|
| 79 |
+
|
| 80 |
+
LOGGER.info("Loading training data from %s...", cache_file)
|
| 81 |
+
data = np.load(cache_file)
|
| 82 |
+
|
| 83 |
+
X_train = data["X_train"]
|
| 84 |
+
y_train = data["y_train"]
|
| 85 |
+
X_val = data["X_val"]
|
| 86 |
+
y_val = data["y_val"]
|
| 87 |
+
|
| 88 |
+
LOGGER.info("Train samples: %d", len(X_train))
|
| 89 |
+
LOGGER.info("Validation samples: %d", len(X_val))
|
| 90 |
+
|
| 91 |
+
# Normalize outputs (xyY inputs are already in [0, 1] range)
|
| 92 |
+
# Use shared normalization parameters covering the full Munsell space for generalization
|
| 93 |
+
output_params = MUNSELL_NORMALIZATION_PARAMS
|
| 94 |
+
y_train_norm = normalize_munsell(y_train, output_params)
|
| 95 |
+
y_val_norm = normalize_munsell(y_val, output_params)
|
| 96 |
+
|
| 97 |
+
# Convert to PyTorch tensors
|
| 98 |
+
X_train_t = torch.FloatTensor(X_train)
|
| 99 |
+
y_train_t = torch.FloatTensor(y_train_norm)
|
| 100 |
+
X_val_t = torch.FloatTensor(X_val)
|
| 101 |
+
y_val_t = torch.FloatTensor(y_val_norm)
|
| 102 |
+
|
| 103 |
+
# Create data loaders
|
| 104 |
+
train_dataset = TensorDataset(X_train_t, y_train_t)
|
| 105 |
+
val_dataset = TensorDataset(X_val_t, y_val_t)
|
| 106 |
+
|
| 107 |
+
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
|
| 108 |
+
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
|
| 109 |
+
|
| 110 |
+
# Initialize model
|
| 111 |
+
model = MultiHeadMLPToMunsell().to(device)
|
| 112 |
+
LOGGER.info("")
|
| 113 |
+
LOGGER.info("Model architecture:")
|
| 114 |
+
LOGGER.info("%s", model)
|
| 115 |
+
|
| 116 |
+
# Count parameters
|
| 117 |
+
total_params = sum(p.numel() for p in model.parameters())
|
| 118 |
+
LOGGER.info("Total parameters: %s", f"{total_params:,}")
|
| 119 |
+
|
| 120 |
+
# Count parameters per component
|
| 121 |
+
encoder_params = sum(p.numel() for p in model.encoder.parameters())
|
| 122 |
+
hue_params = sum(p.numel() for p in model.hue_head.parameters())
|
| 123 |
+
value_params = sum(p.numel() for p in model.value_head.parameters())
|
| 124 |
+
chroma_params = sum(p.numel() for p in model.chroma_head.parameters())
|
| 125 |
+
code_params = sum(p.numel() for p in model.code_head.parameters())
|
| 126 |
+
|
| 127 |
+
LOGGER.info(" - Shared encoder: %s", f"{encoder_params:,}")
|
| 128 |
+
LOGGER.info(" - Hue head: %s", f"{hue_params:,}")
|
| 129 |
+
LOGGER.info(" - Value head: %s", f"{value_params:,}")
|
| 130 |
+
LOGGER.info(" - Chroma head: %s (WIDER)", f"{chroma_params:,}")
|
| 131 |
+
LOGGER.info(" - Code head: %s", f"{code_params:,}")
|
| 132 |
+
|
| 133 |
+
# Training setup
|
| 134 |
+
optimizer = optim.Adam(model.parameters(), lr=lr)
|
| 135 |
+
# Use weighted MSE with default weights
|
| 136 |
+
weights = torch.tensor([1.0, 1.0, 3.0, 0.5])
|
| 137 |
+
criterion = lambda pred, target: weighted_mse_loss(pred, target, weights)
|
| 138 |
+
|
| 139 |
+
# MLflow setup
|
| 140 |
+
run_name = setup_mlflow_experiment("from_xyY", "multi_head")
|
| 141 |
+
|
| 142 |
+
LOGGER.info("")
|
| 143 |
+
LOGGER.info("MLflow run: %s", run_name)
|
| 144 |
+
|
| 145 |
+
# Training loop
|
| 146 |
+
best_val_loss = float("inf")
|
| 147 |
+
patience_counter = 0
|
| 148 |
+
|
| 149 |
+
LOGGER.info("")
|
| 150 |
+
LOGGER.info("Starting training...")
|
| 151 |
+
|
| 152 |
+
with mlflow.start_run(run_name=run_name):
|
| 153 |
+
# Log parameters
|
| 154 |
+
mlflow.log_params(
|
| 155 |
+
{
|
| 156 |
+
"model": "multi_head",
|
| 157 |
+
"learning_rate": lr,
|
| 158 |
+
"batch_size": batch_size,
|
| 159 |
+
"num_epochs": epochs,
|
| 160 |
+
"patience": patience,
|
| 161 |
+
"total_params": total_params,
|
| 162 |
+
}
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
for epoch in range(epochs):
|
| 166 |
+
train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
|
| 167 |
+
val_loss = validate(model, val_loader, criterion, device)
|
| 168 |
+
|
| 169 |
+
# Log to MLflow
|
| 170 |
+
log_training_epoch(
|
| 171 |
+
epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"]
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
LOGGER.info(
|
| 175 |
+
"Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f",
|
| 176 |
+
epoch + 1,
|
| 177 |
+
epochs,
|
| 178 |
+
train_loss,
|
| 179 |
+
val_loss,
|
| 180 |
+
)
|
| 181 |
+
|
| 182 |
+
# Early stopping
|
| 183 |
+
if val_loss < best_val_loss:
|
| 184 |
+
best_val_loss = val_loss
|
| 185 |
+
patience_counter = 0
|
| 186 |
+
|
| 187 |
+
# Save best model
|
| 188 |
+
model_directory = PROJECT_ROOT / "models" / "from_xyY"
|
| 189 |
+
model_directory.mkdir(exist_ok=True)
|
| 190 |
+
checkpoint_file = model_directory / "multi_head_best.pth"
|
| 191 |
+
|
| 192 |
+
torch.save(
|
| 193 |
+
{
|
| 194 |
+
"model_state_dict": model.state_dict(),
|
| 195 |
+
"output_params": output_params,
|
| 196 |
+
"epoch": epoch,
|
| 197 |
+
"val_loss": val_loss,
|
| 198 |
+
},
|
| 199 |
+
checkpoint_file,
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
LOGGER.info(" → Saved best model (val_loss: %.6f)", val_loss)
|
| 203 |
+
else:
|
| 204 |
+
patience_counter += 1
|
| 205 |
+
if patience_counter >= patience:
|
| 206 |
+
LOGGER.info("")
|
| 207 |
+
LOGGER.info("Early stopping after %d epochs", epoch + 1)
|
| 208 |
+
break
|
| 209 |
+
|
| 210 |
+
# Log final metrics
|
| 211 |
+
mlflow.log_metrics(
|
| 212 |
+
{
|
| 213 |
+
"best_val_loss": best_val_loss,
|
| 214 |
+
"final_epoch": epoch + 1,
|
| 215 |
+
}
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
# Export to ONNX
|
| 219 |
+
LOGGER.info("")
|
| 220 |
+
LOGGER.info("Exporting model to ONNX...")
|
| 221 |
+
model.eval()
|
| 222 |
+
|
| 223 |
+
# Load best model
|
| 224 |
+
checkpoint = torch.load(checkpoint_file)
|
| 225 |
+
model.load_state_dict(checkpoint["model_state_dict"])
|
| 226 |
+
|
| 227 |
+
# Create dummy input
|
| 228 |
+
dummy_input = torch.randn(1, 3).to(device)
|
| 229 |
+
|
| 230 |
+
# Export
|
| 231 |
+
onnx_file = model_directory / "multi_head.onnx"
|
| 232 |
+
torch.onnx.export(
|
| 233 |
+
model,
|
| 234 |
+
dummy_input,
|
| 235 |
+
onnx_file,
|
| 236 |
+
export_params=True,
|
| 237 |
+
opset_version=15,
|
| 238 |
+
input_names=["xyY"],
|
| 239 |
+
output_names=["munsell_spec"],
|
| 240 |
+
dynamic_axes={"xyY": {0: "batch_size"}, "munsell_spec": {0: "batch_size"}},
|
| 241 |
+
)
|
| 242 |
+
|
| 243 |
+
# Save normalization parameters alongside model
|
| 244 |
+
params_file = model_directory / "multi_head_normalization_params.npz"
|
| 245 |
+
input_params = XYY_NORMALIZATION_PARAMS
|
| 246 |
+
np.savez(
|
| 247 |
+
params_file,
|
| 248 |
+
input_params=input_params,
|
| 249 |
+
output_params=output_params,
|
| 250 |
+
)
|
| 251 |
+
|
| 252 |
+
# Log artifacts to MLflow
|
| 253 |
+
mlflow.log_artifact(str(checkpoint_file))
|
| 254 |
+
mlflow.log_artifact(str(onnx_file))
|
| 255 |
+
mlflow.log_artifact(str(params_file))
|
| 256 |
+
mlflow.pytorch.log_model(model, "model")
|
| 257 |
+
|
| 258 |
+
LOGGER.info("ONNX model saved to: %s", onnx_file)
|
| 259 |
+
LOGGER.info("Normalization parameters saved to: %s", params_file)
|
| 260 |
+
LOGGER.info("Artifacts logged to MLflow")
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
LOGGER.info("=" * 80)
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
if __name__ == "__main__":
|
| 267 |
+
logging.basicConfig(level=logging.INFO, format="%(message)s", force=True)
|
| 268 |
+
|
| 269 |
+
main()
|
learning_munsell/training/from_xyY/train_multi_head_multi_error_predictor.py
ADDED
|
@@ -0,0 +1,378 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Train Multi-Head error predictor for Multi-Head base model.
|
| 3 |
+
|
| 4 |
+
Architecture:
|
| 5 |
+
- 4 independent error correction branches (one per component)
|
| 6 |
+
- Each branch: 7 inputs (xyY + base_pred) → encoder → decoder → 1 error output
|
| 7 |
+
- Chroma branch: WIDER (1.5x capacity for hardest component)
|
| 8 |
+
|
| 9 |
+
Complete independence matches the Multi-Head base model philosophy.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import logging
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
from typing import Any
|
| 15 |
+
|
| 16 |
+
import click
|
| 17 |
+
import mlflow
|
| 18 |
+
import mlflow.pytorch
|
| 19 |
+
import numpy as np
|
| 20 |
+
import onnxruntime as ort
|
| 21 |
+
import torch
|
| 22 |
+
from numpy.typing import NDArray
|
| 23 |
+
from torch import nn, optim
|
| 24 |
+
from torch.utils.data import DataLoader, TensorDataset
|
| 25 |
+
|
| 26 |
+
from learning_munsell import PROJECT_ROOT
|
| 27 |
+
from learning_munsell.models.networks import (
|
| 28 |
+
ComponentErrorPredictor,
|
| 29 |
+
MultiHeadErrorPredictorToMunsell,
|
| 30 |
+
)
|
| 31 |
+
from learning_munsell.utilities.common import log_training_epoch, setup_mlflow_experiment
|
| 32 |
+
from learning_munsell.utilities.data import normalize_munsell, normalize_xyY
|
| 33 |
+
from learning_munsell.utilities.losses import precision_focused_loss
|
| 34 |
+
from learning_munsell.utilities.training import train_epoch, validate
|
| 35 |
+
|
| 36 |
+
LOGGER = logging.getLogger(__name__)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def load_base_model(
|
| 40 |
+
model_path: Path, params_path: Path
|
| 41 |
+
) -> tuple[ort.InferenceSession, dict, dict]:
|
| 42 |
+
"""
|
| 43 |
+
Load Multi-Head base ONNX model and normalization parameters.
|
| 44 |
+
|
| 45 |
+
Parameters
|
| 46 |
+
----------
|
| 47 |
+
model_path : Path
|
| 48 |
+
Path to Multi-Head base model ONNX file.
|
| 49 |
+
params_path : Path
|
| 50 |
+
Path to normalization parameters .npz file.
|
| 51 |
+
|
| 52 |
+
Returns
|
| 53 |
+
-------
|
| 54 |
+
session : ort.InferenceSession
|
| 55 |
+
ONNX Runtime inference session.
|
| 56 |
+
input_params : dict
|
| 57 |
+
Input normalization ranges.
|
| 58 |
+
output_params : dict
|
| 59 |
+
Output normalization ranges.
|
| 60 |
+
"""
|
| 61 |
+
session = ort.InferenceSession(str(model_path))
|
| 62 |
+
params = np.load(params_path, allow_pickle=True)
|
| 63 |
+
return session, params["input_params"].item(), params["output_params"].item()
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
@click.command()
|
| 67 |
+
@click.option(
|
| 68 |
+
"--base-model",
|
| 69 |
+
type=click.Path(exists=True, path_type=Path),
|
| 70 |
+
default=None,
|
| 71 |
+
help="Path to Multi-Head base model ONNX file",
|
| 72 |
+
)
|
| 73 |
+
@click.option(
|
| 74 |
+
"--params",
|
| 75 |
+
type=click.Path(exists=True, path_type=Path),
|
| 76 |
+
default=None,
|
| 77 |
+
help="Path to normalization params file",
|
| 78 |
+
)
|
| 79 |
+
@click.option(
|
| 80 |
+
"--epochs",
|
| 81 |
+
type=int,
|
| 82 |
+
default=200,
|
| 83 |
+
help="Number of training epochs",
|
| 84 |
+
)
|
| 85 |
+
@click.option(
|
| 86 |
+
"--batch-size",
|
| 87 |
+
type=int,
|
| 88 |
+
default=1024,
|
| 89 |
+
help="Batch size for training",
|
| 90 |
+
)
|
| 91 |
+
@click.option(
|
| 92 |
+
"--lr",
|
| 93 |
+
type=float,
|
| 94 |
+
default=3e-4,
|
| 95 |
+
help="Learning rate",
|
| 96 |
+
)
|
| 97 |
+
@click.option(
|
| 98 |
+
"--patience",
|
| 99 |
+
type=int,
|
| 100 |
+
default=20,
|
| 101 |
+
help="Patience for early stopping",
|
| 102 |
+
)
|
| 103 |
+
def main(
|
| 104 |
+
base_model: Path | None,
|
| 105 |
+
params: Path | None,
|
| 106 |
+
epochs: int,
|
| 107 |
+
batch_size: int,
|
| 108 |
+
lr: float,
|
| 109 |
+
patience: int,
|
| 110 |
+
) -> None:
|
| 111 |
+
"""
|
| 112 |
+
Train Multi-Head error predictor with 4 independent branches.
|
| 113 |
+
|
| 114 |
+
Parameters
|
| 115 |
+
----------
|
| 116 |
+
base_model : Path or None
|
| 117 |
+
Path to Multi-Head base model ONNX file. Uses default if None.
|
| 118 |
+
params : Path or None
|
| 119 |
+
Path to normalization parameters. Uses default if None.
|
| 120 |
+
|
| 121 |
+
Notes
|
| 122 |
+
-----
|
| 123 |
+
The training pipeline:
|
| 124 |
+
1. Loads pre-trained base model
|
| 125 |
+
2. Generates base model predictions for training data
|
| 126 |
+
3. Computes residual errors between predictions and targets
|
| 127 |
+
4. Trains error predictor on these residuals
|
| 128 |
+
5. Uses precision-focused loss function
|
| 129 |
+
6. Learning rate scheduling with ReduceLROnPlateau
|
| 130 |
+
7. Early stopping based on validation loss
|
| 131 |
+
8. Exports model to ONNX format
|
| 132 |
+
9. Logs metrics and artifacts to MLflow
|
| 133 |
+
"""
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
LOGGER.info("=" * 80)
|
| 137 |
+
LOGGER.info("Multi-Head Error Predictor: 4 Independent Branches")
|
| 138 |
+
LOGGER.info("=" * 80)
|
| 139 |
+
|
| 140 |
+
# Set device
|
| 141 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 142 |
+
LOGGER.info("Using device: %s", device)
|
| 143 |
+
|
| 144 |
+
# Paths
|
| 145 |
+
model_directory = PROJECT_ROOT / "models" / "from_xyY"
|
| 146 |
+
data_dir = PROJECT_ROOT / "data"
|
| 147 |
+
|
| 148 |
+
# Use provided paths or defaults
|
| 149 |
+
if base_model is None:
|
| 150 |
+
base_model = model_directory / "multi_head.onnx"
|
| 151 |
+
if params is None:
|
| 152 |
+
params = model_directory / "multi_head_normalization_params.npz"
|
| 153 |
+
|
| 154 |
+
cache_file = data_dir / "training_data.npz"
|
| 155 |
+
|
| 156 |
+
# Load base model
|
| 157 |
+
LOGGER.info("")
|
| 158 |
+
LOGGER.info("Loading Multi-Head base model from %s...", base_model)
|
| 159 |
+
base_session, input_params, output_params = load_base_model(base_model, params)
|
| 160 |
+
|
| 161 |
+
# Load training data
|
| 162 |
+
LOGGER.info("Loading training data from %s...", cache_file)
|
| 163 |
+
data = np.load(cache_file)
|
| 164 |
+
X_train = data["X_train"]
|
| 165 |
+
y_train = data["y_train"]
|
| 166 |
+
X_val = data["X_val"]
|
| 167 |
+
y_val = data["y_val"]
|
| 168 |
+
|
| 169 |
+
LOGGER.info("Train samples: %d", len(X_train))
|
| 170 |
+
LOGGER.info("Validation samples: %d", len(X_val))
|
| 171 |
+
|
| 172 |
+
# Generate base model predictions
|
| 173 |
+
LOGGER.info("")
|
| 174 |
+
LOGGER.info("Generating Multi-Head base model predictions...")
|
| 175 |
+
X_train_norm = normalize_xyY(X_train, input_params)
|
| 176 |
+
y_train_norm = normalize_munsell(y_train, output_params)
|
| 177 |
+
|
| 178 |
+
# Base predictions (normalized)
|
| 179 |
+
base_pred_train_norm = base_session.run(None, {"xyY": X_train_norm})[0]
|
| 180 |
+
|
| 181 |
+
X_val_norm = normalize_xyY(X_val, input_params)
|
| 182 |
+
y_val_norm = normalize_munsell(y_val, output_params)
|
| 183 |
+
base_pred_val_norm = base_session.run(None, {"xyY": X_val_norm})[0]
|
| 184 |
+
|
| 185 |
+
# Compute errors (in normalized space)
|
| 186 |
+
error_train = y_train_norm - base_pred_train_norm
|
| 187 |
+
error_val = y_val_norm - base_pred_val_norm
|
| 188 |
+
|
| 189 |
+
# Statistics
|
| 190 |
+
LOGGER.info("")
|
| 191 |
+
LOGGER.info("Multi-Head base model error statistics (normalized space):")
|
| 192 |
+
LOGGER.info(" Mean absolute error: %.6f", np.mean(np.abs(error_train)))
|
| 193 |
+
LOGGER.info(" Std of error: %.6f", np.std(error_train))
|
| 194 |
+
LOGGER.info(" Max absolute error: %.6f", np.max(np.abs(error_train)))
|
| 195 |
+
|
| 196 |
+
# Create combined input: [xyY_norm, base_prediction_norm]
|
| 197 |
+
X_train_combined = np.concatenate([X_train_norm, base_pred_train_norm], axis=1)
|
| 198 |
+
X_val_combined = np.concatenate([X_val_norm, base_pred_val_norm], axis=1)
|
| 199 |
+
|
| 200 |
+
# Convert to PyTorch tensors
|
| 201 |
+
X_train_t = torch.FloatTensor(X_train_combined)
|
| 202 |
+
error_train_t = torch.FloatTensor(error_train)
|
| 203 |
+
X_val_t = torch.FloatTensor(X_val_combined)
|
| 204 |
+
error_val_t = torch.FloatTensor(error_val)
|
| 205 |
+
|
| 206 |
+
# Create data loaders
|
| 207 |
+
train_dataset = TensorDataset(X_train_t, error_train_t)
|
| 208 |
+
val_dataset = TensorDataset(X_val_t, error_val_t)
|
| 209 |
+
|
| 210 |
+
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
|
| 211 |
+
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
|
| 212 |
+
|
| 213 |
+
# Initialize Multi-Head error predictor
|
| 214 |
+
model = MultiHeadErrorPredictorToMunsell().to(device)
|
| 215 |
+
LOGGER.info("")
|
| 216 |
+
LOGGER.info("Multi-Head error predictor architecture:")
|
| 217 |
+
LOGGER.info("%s", model)
|
| 218 |
+
|
| 219 |
+
# Count parameters
|
| 220 |
+
total_params = sum(p.numel() for p in model.parameters())
|
| 221 |
+
LOGGER.info("Total parameters: %s", f"{total_params:,}")
|
| 222 |
+
|
| 223 |
+
# Count parameters per branch
|
| 224 |
+
hue_params = sum(p.numel() for p in model.hue_branch.parameters())
|
| 225 |
+
value_params = sum(p.numel() for p in model.value_branch.parameters())
|
| 226 |
+
chroma_params = sum(p.numel() for p in model.chroma_branch.parameters())
|
| 227 |
+
code_params = sum(p.numel() for p in model.code_branch.parameters())
|
| 228 |
+
|
| 229 |
+
LOGGER.info(" - Hue branch: %s", f"{hue_params:,}")
|
| 230 |
+
LOGGER.info(" - Value branch: %s", f"{value_params:,}")
|
| 231 |
+
LOGGER.info(" - Chroma branch: %s (WIDER 1.5x)", f"{chroma_params:,}")
|
| 232 |
+
LOGGER.info(" - Code branch: %s", f"{code_params:,}")
|
| 233 |
+
|
| 234 |
+
# Training setup with precision-focused loss
|
| 235 |
+
LOGGER.info("")
|
| 236 |
+
LOGGER.info("Using precision-focused loss function:")
|
| 237 |
+
LOGGER.info(" - MSE (weight: 1.0)")
|
| 238 |
+
LOGGER.info(" - MAE (weight: 0.5)")
|
| 239 |
+
LOGGER.info(" - Log penalty for small errors (weight: 0.3)")
|
| 240 |
+
LOGGER.info(" - Huber loss with delta=0.01 (weight: 0.5)")
|
| 241 |
+
|
| 242 |
+
optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
|
| 243 |
+
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
|
| 244 |
+
optimizer, mode="min", factor=0.5, patience=5
|
| 245 |
+
)
|
| 246 |
+
criterion = precision_focused_loss
|
| 247 |
+
|
| 248 |
+
# MLflow setup
|
| 249 |
+
run_name = setup_mlflow_experiment("from_xyY", "multi_head_multi_error_predictor")
|
| 250 |
+
|
| 251 |
+
LOGGER.info("")
|
| 252 |
+
LOGGER.info("MLflow run: %s", run_name)
|
| 253 |
+
|
| 254 |
+
# Training loop
|
| 255 |
+
best_val_loss = float("inf")
|
| 256 |
+
patience_counter = 0
|
| 257 |
+
|
| 258 |
+
LOGGER.info("")
|
| 259 |
+
LOGGER.info("Starting training...")
|
| 260 |
+
|
| 261 |
+
with mlflow.start_run(run_name=run_name):
|
| 262 |
+
# Log hyperparameters
|
| 263 |
+
mlflow.log_params(
|
| 264 |
+
{
|
| 265 |
+
"num_epochs": epochs,
|
| 266 |
+
"batch_size": batch_size,
|
| 267 |
+
"learning_rate": lr,
|
| 268 |
+
"weight_decay": 1e-5,
|
| 269 |
+
"optimizer": "AdamW",
|
| 270 |
+
"scheduler": "ReduceLROnPlateau",
|
| 271 |
+
"criterion": "precision_focused_loss",
|
| 272 |
+
"patience": patience,
|
| 273 |
+
"total_params": total_params,
|
| 274 |
+
}
|
| 275 |
+
)
|
| 276 |
+
|
| 277 |
+
for epoch in range(epochs):
|
| 278 |
+
train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
|
| 279 |
+
val_loss = validate(model, val_loader, criterion, device)
|
| 280 |
+
|
| 281 |
+
# Update learning rate
|
| 282 |
+
scheduler.step(val_loss)
|
| 283 |
+
|
| 284 |
+
# Log to MLflow
|
| 285 |
+
log_training_epoch(
|
| 286 |
+
epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"]
|
| 287 |
+
)
|
| 288 |
+
|
| 289 |
+
LOGGER.info(
|
| 290 |
+
"Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f, LR: %.6f",
|
| 291 |
+
epoch + 1,
|
| 292 |
+
epochs,
|
| 293 |
+
train_loss,
|
| 294 |
+
val_loss,
|
| 295 |
+
optimizer.param_groups[0]["lr"],
|
| 296 |
+
)
|
| 297 |
+
|
| 298 |
+
# Early stopping
|
| 299 |
+
if val_loss < best_val_loss:
|
| 300 |
+
best_val_loss = val_loss
|
| 301 |
+
patience_counter = 0
|
| 302 |
+
|
| 303 |
+
# Save best model
|
| 304 |
+
model_directory.mkdir(exist_ok=True)
|
| 305 |
+
checkpoint_file = (
|
| 306 |
+
model_directory / "multi_head_multi_error_predictor_best.pth"
|
| 307 |
+
)
|
| 308 |
+
|
| 309 |
+
torch.save(
|
| 310 |
+
{
|
| 311 |
+
"model_state_dict": model.state_dict(),
|
| 312 |
+
"epoch": epoch,
|
| 313 |
+
"val_loss": val_loss,
|
| 314 |
+
},
|
| 315 |
+
checkpoint_file,
|
| 316 |
+
)
|
| 317 |
+
|
| 318 |
+
LOGGER.info(" → Saved best model (val_loss: %.6f)", val_loss)
|
| 319 |
+
else:
|
| 320 |
+
patience_counter += 1
|
| 321 |
+
if patience_counter >= patience:
|
| 322 |
+
LOGGER.info("")
|
| 323 |
+
LOGGER.info("Early stopping after %d epochs", epoch + 1)
|
| 324 |
+
break
|
| 325 |
+
|
| 326 |
+
# Log final metrics
|
| 327 |
+
mlflow.log_metrics(
|
| 328 |
+
{
|
| 329 |
+
"best_val_loss": best_val_loss,
|
| 330 |
+
"final_epoch": epoch + 1,
|
| 331 |
+
}
|
| 332 |
+
)
|
| 333 |
+
|
| 334 |
+
# Export to ONNX
|
| 335 |
+
LOGGER.info("")
|
| 336 |
+
LOGGER.info("Exporting Multi-Head error predictor to ONNX...")
|
| 337 |
+
model.eval()
|
| 338 |
+
|
| 339 |
+
# Load best model
|
| 340 |
+
checkpoint = torch.load(checkpoint_file)
|
| 341 |
+
model.load_state_dict(checkpoint["model_state_dict"])
|
| 342 |
+
|
| 343 |
+
# Create dummy input (xyY_norm + base_pred_norm = 7 inputs)
|
| 344 |
+
dummy_input = torch.randn(1, 7).to(device)
|
| 345 |
+
|
| 346 |
+
# Export
|
| 347 |
+
onnx_file = model_directory / "multi_head_multi_error_predictor.onnx"
|
| 348 |
+
torch.onnx.export(
|
| 349 |
+
model,
|
| 350 |
+
dummy_input,
|
| 351 |
+
onnx_file,
|
| 352 |
+
export_params=True,
|
| 353 |
+
opset_version=15,
|
| 354 |
+
input_names=["combined_input"],
|
| 355 |
+
output_names=["error_correction"],
|
| 356 |
+
dynamic_axes={
|
| 357 |
+
"combined_input": {0: "batch_size"},
|
| 358 |
+
"error_correction": {0: "batch_size"},
|
| 359 |
+
},
|
| 360 |
+
)
|
| 361 |
+
|
| 362 |
+
LOGGER.info("Multi-Head error predictor ONNX model saved to: %s", onnx_file)
|
| 363 |
+
|
| 364 |
+
# Log artifacts
|
| 365 |
+
mlflow.log_artifact(str(checkpoint_file))
|
| 366 |
+
mlflow.log_artifact(str(onnx_file))
|
| 367 |
+
|
| 368 |
+
# Log model
|
| 369 |
+
mlflow.pytorch.log_model(model, "model")
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
LOGGER.info("=" * 80)
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
if __name__ == "__main__":
|
| 376 |
+
logging.basicConfig(level=logging.INFO, format="%(message)s", force=True)
|
| 377 |
+
|
| 378 |
+
main()
|
learning_munsell/training/from_xyY/train_multi_head_multi_error_predictor_large.py
ADDED
|
@@ -0,0 +1,409 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Train Multi-Head error predictor on large dataset (2M samples).
|
| 3 |
+
|
| 4 |
+
Architecture:
|
| 5 |
+
- 4 independent error correction branches (one per component)
|
| 6 |
+
- Each branch: 7 inputs (xyY + base_pred) → encoder → decoder → 1 error output
|
| 7 |
+
- Chroma branch: WIDER (1.5x capacity for hardest component)
|
| 8 |
+
|
| 9 |
+
Uses the large dataset for improved model training.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import logging
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
import click
|
| 15 |
+
import mlflow
|
| 16 |
+
import mlflow.pytorch
|
| 17 |
+
import numpy as np
|
| 18 |
+
import onnxruntime as ort
|
| 19 |
+
import torch
|
| 20 |
+
from numpy.typing import NDArray
|
| 21 |
+
from torch import nn, optim
|
| 22 |
+
from torch.utils.data import DataLoader, TensorDataset
|
| 23 |
+
|
| 24 |
+
from learning_munsell import PROJECT_ROOT
|
| 25 |
+
from learning_munsell.models.networks import (
|
| 26 |
+
ComponentErrorPredictor,
|
| 27 |
+
MultiHeadErrorPredictorToMunsell,
|
| 28 |
+
)
|
| 29 |
+
from learning_munsell.utilities.common import log_training_epoch, setup_mlflow_experiment
|
| 30 |
+
from learning_munsell.utilities.data import normalize_xyY, normalize_munsell
|
| 31 |
+
from learning_munsell.utilities.losses import precision_focused_loss
|
| 32 |
+
from learning_munsell.utilities.training import train_epoch, validate
|
| 33 |
+
|
| 34 |
+
LOGGER = logging.getLogger(__name__)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def load_base_model(
|
| 38 |
+
model_path: Path, params_path: Path
|
| 39 |
+
) -> tuple[ort.InferenceSession, dict, dict]:
|
| 40 |
+
"""
|
| 41 |
+
Load the base ONNX model and normalization parameters.
|
| 42 |
+
|
| 43 |
+
Parameters
|
| 44 |
+
----------
|
| 45 |
+
model_path : Path
|
| 46 |
+
Path to the ONNX model file.
|
| 47 |
+
params_path : Path
|
| 48 |
+
Path to the normalization parameters file (.npz).
|
| 49 |
+
|
| 50 |
+
Returns
|
| 51 |
+
-------
|
| 52 |
+
session : ort.InferenceSession
|
| 53 |
+
ONNX Runtime inference session.
|
| 54 |
+
input_params : dict
|
| 55 |
+
Input normalization parameters.
|
| 56 |
+
output_params : dict
|
| 57 |
+
Output normalization parameters.
|
| 58 |
+
"""
|
| 59 |
+
session = ort.InferenceSession(str(model_path))
|
| 60 |
+
params = np.load(params_path, allow_pickle=True)
|
| 61 |
+
return session, params["input_params"].item(), params["output_params"].item()
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
@click.command()
|
| 65 |
+
@click.option(
|
| 66 |
+
"--base-model",
|
| 67 |
+
type=click.Path(exists=True, path_type=Path),
|
| 68 |
+
default=None,
|
| 69 |
+
help="Path to Multi-Head large base model ONNX file",
|
| 70 |
+
)
|
| 71 |
+
@click.option(
|
| 72 |
+
"--params",
|
| 73 |
+
type=click.Path(exists=True, path_type=Path),
|
| 74 |
+
default=None,
|
| 75 |
+
help="Path to normalization params file",
|
| 76 |
+
)
|
| 77 |
+
@click.option(
|
| 78 |
+
"--output-suffix",
|
| 79 |
+
type=str,
|
| 80 |
+
default="large",
|
| 81 |
+
help="Suffix for output filenames (default: 'large')",
|
| 82 |
+
)
|
| 83 |
+
@click.option(
|
| 84 |
+
"--epochs",
|
| 85 |
+
type=int,
|
| 86 |
+
default=300,
|
| 87 |
+
help="Number of training epochs (default: 300)",
|
| 88 |
+
)
|
| 89 |
+
@click.option(
|
| 90 |
+
"--batch-size",
|
| 91 |
+
type=int,
|
| 92 |
+
default=2048,
|
| 93 |
+
help="Batch size for training (default: 2048)",
|
| 94 |
+
)
|
| 95 |
+
@click.option(
|
| 96 |
+
"--lr",
|
| 97 |
+
type=float,
|
| 98 |
+
default=3e-4,
|
| 99 |
+
help="Learning rate (default: 3e-4)",
|
| 100 |
+
)
|
| 101 |
+
@click.option(
|
| 102 |
+
"--patience",
|
| 103 |
+
type=int,
|
| 104 |
+
default=30,
|
| 105 |
+
help="Early stopping patience (default: 30)",
|
| 106 |
+
)
|
| 107 |
+
def main(
|
| 108 |
+
base_model: Path | None,
|
| 109 |
+
params: Path | None,
|
| 110 |
+
output_suffix: str,
|
| 111 |
+
epochs: int,
|
| 112 |
+
batch_size: int,
|
| 113 |
+
lr: float,
|
| 114 |
+
patience: int,
|
| 115 |
+
) -> None:
|
| 116 |
+
"""
|
| 117 |
+
Train Multi-Head error predictor on large dataset.
|
| 118 |
+
|
| 119 |
+
This script trains an error predictor on top of the Multi-Head large
|
| 120 |
+
base model, using the 2M sample dataset for improved accuracy.
|
| 121 |
+
|
| 122 |
+
Parameters
|
| 123 |
+
----------
|
| 124 |
+
base_model : Path, optional
|
| 125 |
+
Path to the Multi-Head large base model ONNX file.
|
| 126 |
+
Default: models/from_xyY/multi_head_large.onnx
|
| 127 |
+
params : Path, optional
|
| 128 |
+
Path to the normalization parameters file.
|
| 129 |
+
Default: models/from_xyY/multi_head_large_normalization_params.npz
|
| 130 |
+
output_suffix : str
|
| 131 |
+
Suffix for output filenames (default: 'large').
|
| 132 |
+
|
| 133 |
+
Notes
|
| 134 |
+
-----
|
| 135 |
+
The training pipeline:
|
| 136 |
+
1. Loads pre-trained Multi-Head large base model
|
| 137 |
+
2. Generates base model predictions for training data (in batches)
|
| 138 |
+
3. Computes residual errors between predictions and targets
|
| 139 |
+
4. Trains multi-head error predictor on these residuals
|
| 140 |
+
5. Uses precision-focused loss function
|
| 141 |
+
6. Learning rate scheduling with ReduceLROnPlateau
|
| 142 |
+
7. Early stopping based on validation loss
|
| 143 |
+
8. Exports model to ONNX format
|
| 144 |
+
9. Logs metrics and artifacts to MLflow
|
| 145 |
+
"""
|
| 146 |
+
|
| 147 |
+
LOGGER.info("=" * 80)
|
| 148 |
+
LOGGER.info("Multi-Head Error Predictor: Large Dataset (2M samples)")
|
| 149 |
+
LOGGER.info("=" * 80)
|
| 150 |
+
|
| 151 |
+
# Set device
|
| 152 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 153 |
+
if torch.backends.mps.is_available():
|
| 154 |
+
device = torch.device("mps")
|
| 155 |
+
LOGGER.info("Using device: %s", device)
|
| 156 |
+
|
| 157 |
+
# Paths
|
| 158 |
+
model_directory = PROJECT_ROOT / "models" / "from_xyY"
|
| 159 |
+
data_dir = PROJECT_ROOT / "data"
|
| 160 |
+
|
| 161 |
+
# Use provided paths or defaults for large model
|
| 162 |
+
if base_model is None:
|
| 163 |
+
base_model = model_directory / "multi_head_large.onnx"
|
| 164 |
+
if params is None:
|
| 165 |
+
params = model_directory / "multi_head_large_normalization_params.npz"
|
| 166 |
+
|
| 167 |
+
cache_file = data_dir / "training_data_large.npz"
|
| 168 |
+
|
| 169 |
+
if not cache_file.exists():
|
| 170 |
+
LOGGER.error("Error: Large training data not found at %s", cache_file)
|
| 171 |
+
LOGGER.error("Please run generate_large_training_data.py first")
|
| 172 |
+
return
|
| 173 |
+
|
| 174 |
+
if not base_model.exists():
|
| 175 |
+
LOGGER.error("Error: Multi-Head large base model not found at %s", base_model)
|
| 176 |
+
LOGGER.error("Please run train_multi_head_large.py first")
|
| 177 |
+
return
|
| 178 |
+
|
| 179 |
+
# Load base model
|
| 180 |
+
LOGGER.info("")
|
| 181 |
+
LOGGER.info("Loading Multi-Head large base model from %s...", base_model)
|
| 182 |
+
base_session, input_params, output_params = load_base_model(base_model, params)
|
| 183 |
+
|
| 184 |
+
# Load training data
|
| 185 |
+
LOGGER.info("Loading large training data from %s...", cache_file)
|
| 186 |
+
data = np.load(cache_file)
|
| 187 |
+
X_train = data["X_train"]
|
| 188 |
+
y_train = data["y_train"]
|
| 189 |
+
X_val = data["X_val"]
|
| 190 |
+
y_val = data["y_val"]
|
| 191 |
+
|
| 192 |
+
LOGGER.info("Train samples: %d", len(X_train))
|
| 193 |
+
LOGGER.info("Validation samples: %d", len(X_val))
|
| 194 |
+
|
| 195 |
+
# Generate base model predictions
|
| 196 |
+
LOGGER.info("")
|
| 197 |
+
LOGGER.info("Generating Multi-Head large base model predictions...")
|
| 198 |
+
X_train_norm = normalize_xyY(X_train, input_params)
|
| 199 |
+
y_train_norm = normalize_munsell(y_train, output_params)
|
| 200 |
+
|
| 201 |
+
# Base predictions (normalized) - process in batches for memory efficiency
|
| 202 |
+
LOGGER.info(" Processing training set predictions...")
|
| 203 |
+
inference_batch_size = 50000
|
| 204 |
+
base_pred_train_norm = []
|
| 205 |
+
for i in range(0, len(X_train_norm), inference_batch_size):
|
| 206 |
+
batch = X_train_norm[i : i + inference_batch_size]
|
| 207 |
+
pred = base_session.run(None, {"xyY": batch})[0]
|
| 208 |
+
base_pred_train_norm.append(pred)
|
| 209 |
+
base_pred_train_norm = np.concatenate(base_pred_train_norm, axis=0)
|
| 210 |
+
|
| 211 |
+
X_val_norm = normalize_xyY(X_val, input_params)
|
| 212 |
+
y_val_norm = normalize_munsell(y_val, output_params)
|
| 213 |
+
|
| 214 |
+
LOGGER.info(" Processing validation set predictions...")
|
| 215 |
+
base_pred_val_norm = []
|
| 216 |
+
for i in range(0, len(X_val_norm), inference_batch_size):
|
| 217 |
+
batch = X_val_norm[i : i + inference_batch_size]
|
| 218 |
+
pred = base_session.run(None, {"xyY": batch})[0]
|
| 219 |
+
base_pred_val_norm.append(pred)
|
| 220 |
+
base_pred_val_norm = np.concatenate(base_pred_val_norm, axis=0)
|
| 221 |
+
|
| 222 |
+
# Compute errors (in normalized space)
|
| 223 |
+
error_train = y_train_norm - base_pred_train_norm
|
| 224 |
+
error_val = y_val_norm - base_pred_val_norm
|
| 225 |
+
|
| 226 |
+
# Statistics
|
| 227 |
+
LOGGER.info("")
|
| 228 |
+
LOGGER.info("Multi-Head large base model error statistics (normalized space):")
|
| 229 |
+
LOGGER.info(" Mean absolute error: %.6f", np.mean(np.abs(error_train)))
|
| 230 |
+
LOGGER.info(" Std of error: %.6f", np.std(error_train))
|
| 231 |
+
LOGGER.info(" Max absolute error: %.6f", np.max(np.abs(error_train)))
|
| 232 |
+
|
| 233 |
+
# Create combined input: [xyY_norm, base_prediction_norm]
|
| 234 |
+
X_train_combined = np.concatenate([X_train_norm, base_pred_train_norm], axis=1)
|
| 235 |
+
X_val_combined = np.concatenate([X_val_norm, base_pred_val_norm], axis=1)
|
| 236 |
+
|
| 237 |
+
# Convert to PyTorch tensors
|
| 238 |
+
X_train_t = torch.FloatTensor(X_train_combined)
|
| 239 |
+
error_train_t = torch.FloatTensor(error_train)
|
| 240 |
+
X_val_t = torch.FloatTensor(X_val_combined)
|
| 241 |
+
error_val_t = torch.FloatTensor(error_val)
|
| 242 |
+
|
| 243 |
+
# Create data loaders (larger batch size for large dataset)
|
| 244 |
+
train_dataset = TensorDataset(X_train_t, error_train_t)
|
| 245 |
+
val_dataset = TensorDataset(X_val_t, error_val_t)
|
| 246 |
+
|
| 247 |
+
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
|
| 248 |
+
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
|
| 249 |
+
|
| 250 |
+
# Initialize Multi-Head error predictor
|
| 251 |
+
model = MultiHeadErrorPredictorToMunsell().to(device)
|
| 252 |
+
LOGGER.info("")
|
| 253 |
+
LOGGER.info("Multi-Head error predictor architecture:")
|
| 254 |
+
LOGGER.info("%s", model)
|
| 255 |
+
|
| 256 |
+
# Count parameters
|
| 257 |
+
total_params = sum(p.numel() for p in model.parameters())
|
| 258 |
+
LOGGER.info("Total parameters: %s", f"{total_params:,}")
|
| 259 |
+
|
| 260 |
+
# Count parameters per branch
|
| 261 |
+
hue_params = sum(p.numel() for p in model.hue_branch.parameters())
|
| 262 |
+
value_params = sum(p.numel() for p in model.value_branch.parameters())
|
| 263 |
+
chroma_params = sum(p.numel() for p in model.chroma_branch.parameters())
|
| 264 |
+
code_params = sum(p.numel() for p in model.code_branch.parameters())
|
| 265 |
+
|
| 266 |
+
LOGGER.info(" - Hue branch: %s", f"{hue_params:,}")
|
| 267 |
+
LOGGER.info(" - Value branch: %s", f"{value_params:,}")
|
| 268 |
+
LOGGER.info(" - Chroma branch: %s (WIDER 1.5x)", f"{chroma_params:,}")
|
| 269 |
+
LOGGER.info(" - Code branch: %s", f"{code_params:,}")
|
| 270 |
+
|
| 271 |
+
# Training setup
|
| 272 |
+
LOGGER.info("")
|
| 273 |
+
LOGGER.info("Using precision-focused loss function:")
|
| 274 |
+
LOGGER.info(" - MSE (weight: 1.0)")
|
| 275 |
+
LOGGER.info(" - MAE (weight: 0.5)")
|
| 276 |
+
LOGGER.info(" - Log penalty for small errors (weight: 0.3)")
|
| 277 |
+
LOGGER.info(" - Huber loss with delta=0.01 (weight: 0.5)")
|
| 278 |
+
|
| 279 |
+
optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
|
| 280 |
+
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
|
| 281 |
+
optimizer, mode="min", factor=0.5, patience=10
|
| 282 |
+
)
|
| 283 |
+
criterion = precision_focused_loss
|
| 284 |
+
|
| 285 |
+
# MLflow setup
|
| 286 |
+
run_name = setup_mlflow_experiment(
|
| 287 |
+
"from_xyY", f"multi_head_multi_error_predictor_{output_suffix}"
|
| 288 |
+
)
|
| 289 |
+
|
| 290 |
+
LOGGER.info("")
|
| 291 |
+
LOGGER.info("MLflow run: %s", run_name)
|
| 292 |
+
|
| 293 |
+
# Training loop
|
| 294 |
+
best_val_loss = float("inf")
|
| 295 |
+
patience_counter = 0
|
| 296 |
+
|
| 297 |
+
LOGGER.info("")
|
| 298 |
+
LOGGER.info("Starting training...")
|
| 299 |
+
|
| 300 |
+
with mlflow.start_run(run_name=run_name):
|
| 301 |
+
mlflow.log_params(
|
| 302 |
+
{
|
| 303 |
+
"model": f"multi_head_multi_error_predictor_{output_suffix}",
|
| 304 |
+
"num_epochs": epochs,
|
| 305 |
+
"batch_size": batch_size,
|
| 306 |
+
"learning_rate": lr,
|
| 307 |
+
"weight_decay": 1e-5,
|
| 308 |
+
"optimizer": "AdamW",
|
| 309 |
+
"scheduler": "ReduceLROnPlateau",
|
| 310 |
+
"criterion": "precision_focused_loss",
|
| 311 |
+
"patience": patience,
|
| 312 |
+
"total_params": total_params,
|
| 313 |
+
"train_samples": len(X_train),
|
| 314 |
+
"val_samples": len(X_val),
|
| 315 |
+
"dataset": "large_2M",
|
| 316 |
+
}
|
| 317 |
+
)
|
| 318 |
+
|
| 319 |
+
for epoch in range(epochs):
|
| 320 |
+
train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
|
| 321 |
+
val_loss = validate(model, val_loader, criterion, device)
|
| 322 |
+
|
| 323 |
+
scheduler.step(val_loss)
|
| 324 |
+
|
| 325 |
+
log_training_epoch(
|
| 326 |
+
epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"]
|
| 327 |
+
)
|
| 328 |
+
|
| 329 |
+
LOGGER.info(
|
| 330 |
+
"Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f, LR: %.6f",
|
| 331 |
+
epoch + 1,
|
| 332 |
+
epochs,
|
| 333 |
+
train_loss,
|
| 334 |
+
val_loss,
|
| 335 |
+
optimizer.param_groups[0]["lr"],
|
| 336 |
+
)
|
| 337 |
+
|
| 338 |
+
if val_loss < best_val_loss:
|
| 339 |
+
best_val_loss = val_loss
|
| 340 |
+
patience_counter = 0
|
| 341 |
+
|
| 342 |
+
model_directory.mkdir(exist_ok=True)
|
| 343 |
+
checkpoint_file = (
|
| 344 |
+
model_directory / f"multi_head_multi_error_predictor_{output_suffix}_best.pth"
|
| 345 |
+
)
|
| 346 |
+
|
| 347 |
+
torch.save(
|
| 348 |
+
{
|
| 349 |
+
"model_state_dict": model.state_dict(),
|
| 350 |
+
"epoch": epoch,
|
| 351 |
+
"val_loss": val_loss,
|
| 352 |
+
"output_params": output_params,
|
| 353 |
+
},
|
| 354 |
+
checkpoint_file,
|
| 355 |
+
)
|
| 356 |
+
|
| 357 |
+
LOGGER.info(" → Saved best model (val_loss: %.6f)", val_loss)
|
| 358 |
+
else:
|
| 359 |
+
patience_counter += 1
|
| 360 |
+
if patience_counter >= patience:
|
| 361 |
+
LOGGER.info("")
|
| 362 |
+
LOGGER.info("Early stopping after %d epochs", epoch + 1)
|
| 363 |
+
break
|
| 364 |
+
|
| 365 |
+
mlflow.log_metrics(
|
| 366 |
+
{
|
| 367 |
+
"best_val_loss": best_val_loss,
|
| 368 |
+
"final_epoch": epoch + 1,
|
| 369 |
+
}
|
| 370 |
+
)
|
| 371 |
+
|
| 372 |
+
# Export to ONNX
|
| 373 |
+
LOGGER.info("")
|
| 374 |
+
LOGGER.info("Exporting Multi-Head error predictor to ONNX...")
|
| 375 |
+
model.eval()
|
| 376 |
+
|
| 377 |
+
checkpoint = torch.load(checkpoint_file, weights_only=False)
|
| 378 |
+
model.load_state_dict(checkpoint["model_state_dict"])
|
| 379 |
+
|
| 380 |
+
dummy_input = torch.randn(1, 7).to(device)
|
| 381 |
+
|
| 382 |
+
onnx_file = model_directory / f"multi_head_multi_error_predictor_{output_suffix}.onnx"
|
| 383 |
+
torch.onnx.export(
|
| 384 |
+
model,
|
| 385 |
+
dummy_input,
|
| 386 |
+
onnx_file,
|
| 387 |
+
export_params=True,
|
| 388 |
+
opset_version=15,
|
| 389 |
+
input_names=["combined_input"],
|
| 390 |
+
output_names=["error_correction"],
|
| 391 |
+
dynamic_axes={
|
| 392 |
+
"combined_input": {0: "batch_size"},
|
| 393 |
+
"error_correction": {0: "batch_size"},
|
| 394 |
+
},
|
| 395 |
+
)
|
| 396 |
+
|
| 397 |
+
LOGGER.info("Multi-Head error predictor ONNX model saved to: %s", onnx_file)
|
| 398 |
+
|
| 399 |
+
mlflow.log_artifact(str(checkpoint_file))
|
| 400 |
+
mlflow.log_artifact(str(onnx_file))
|
| 401 |
+
mlflow.pytorch.log_model(model, "model")
|
| 402 |
+
|
| 403 |
+
LOGGER.info("=" * 80)
|
| 404 |
+
|
| 405 |
+
|
| 406 |
+
if __name__ == "__main__":
|
| 407 |
+
logging.basicConfig(level=logging.INFO, format="%(message)s", force=True)
|
| 408 |
+
|
| 409 |
+
main()
|
learning_munsell/training/from_xyY/train_multi_head_st2084.py
ADDED
|
@@ -0,0 +1,313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Train multi-head ML model for xyY to Munsell conversion with ST.2084 (PQ) encoded Y.
|
| 3 |
+
|
| 4 |
+
Experiment: Apply SMPTE ST.2084 (Perceptual Quantizer) encoding to Y before
|
| 5 |
+
normalization. ST.2084 is designed for perceptual uniformity across a wide
|
| 6 |
+
luminance range, potentially providing better alignment with Munsell Value
|
| 7 |
+
than simple gamma correction.
|
| 8 |
+
|
| 9 |
+
The multi-head architecture has separate heads for each Munsell component,
|
| 10 |
+
so PQ encoding on Y should primarily benefit Value prediction without
|
| 11 |
+
negatively impacting Chroma prediction.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import logging
|
| 15 |
+
from typing import Any
|
| 16 |
+
|
| 17 |
+
import click
|
| 18 |
+
import mlflow
|
| 19 |
+
import mlflow.pytorch
|
| 20 |
+
import numpy as np
|
| 21 |
+
import torch
|
| 22 |
+
from colour.models import eotf_inverse_ST2084
|
| 23 |
+
from numpy.typing import NDArray
|
| 24 |
+
from torch import nn, optim
|
| 25 |
+
from torch.utils.data import DataLoader, TensorDataset
|
| 26 |
+
|
| 27 |
+
from learning_munsell import PROJECT_ROOT
|
| 28 |
+
from learning_munsell.models.networks import MultiHeadMLPToMunsell
|
| 29 |
+
from learning_munsell.utilities.common import log_training_epoch, setup_mlflow_experiment
|
| 30 |
+
from learning_munsell.utilities.data import (
|
| 31 |
+
MUNSELL_NORMALIZATION_PARAMS,
|
| 32 |
+
normalize_munsell,
|
| 33 |
+
)
|
| 34 |
+
from learning_munsell.utilities.losses import weighted_mse_loss
|
| 35 |
+
from learning_munsell.utilities.training import train_epoch, validate
|
| 36 |
+
|
| 37 |
+
LOGGER = logging.getLogger(__name__)
|
| 38 |
+
|
| 39 |
+
# Peak luminance for ST.2084 scaling
|
| 40 |
+
# Munsell Y is relative luminance [0, 1], we scale to cd/m² for ST.2084
|
| 41 |
+
# Using 100 cd/m² as reference white (typical SDR display)
|
| 42 |
+
L_P_REFERENCE = 100.0
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def normalize_inputs(
|
| 46 |
+
X: NDArray, L_p: float = L_P_REFERENCE
|
| 47 |
+
) -> tuple[NDArray, dict[str, Any]]:
|
| 48 |
+
"""
|
| 49 |
+
Normalize xyY inputs to [0, 1] range with ST.2084 (PQ) encoding on Y.
|
| 50 |
+
|
| 51 |
+
Parameters
|
| 52 |
+
----------
|
| 53 |
+
X : ndarray
|
| 54 |
+
xyY values of shape (n, 3) where columns are [x, y, Y].
|
| 55 |
+
L_p : float
|
| 56 |
+
Peak luminance in cd/m² for ST.2084 scaling.
|
| 57 |
+
|
| 58 |
+
Returns
|
| 59 |
+
-------
|
| 60 |
+
ndarray
|
| 61 |
+
Normalized values with ST.2084-encoded Y, dtype float32.
|
| 62 |
+
dict
|
| 63 |
+
Normalization parameters including L_p and encoding type.
|
| 64 |
+
"""
|
| 65 |
+
# xyY chromaticity and luminance ranges (all [0, 1])
|
| 66 |
+
x_range = (0.0, 1.0)
|
| 67 |
+
y_range = (0.0, 1.0)
|
| 68 |
+
Y_range = (0.0, 1.0)
|
| 69 |
+
|
| 70 |
+
X_norm = X.copy()
|
| 71 |
+
X_norm[:, 0] = (X[:, 0] - x_range[0]) / (x_range[1] - x_range[0])
|
| 72 |
+
X_norm[:, 1] = (X[:, 1] - y_range[0]) / (y_range[1] - y_range[0])
|
| 73 |
+
|
| 74 |
+
# Normalize Y first, then apply ST.2084
|
| 75 |
+
Y_normalized = (X[:, 2] - Y_range[0]) / (Y_range[1] - Y_range[0])
|
| 76 |
+
# Clip to avoid numerical issues
|
| 77 |
+
Y_normalized = np.clip(Y_normalized, 0, 1)
|
| 78 |
+
# Scale to cd/m² and apply ST.2084 inverse EOTF (PQ encoding)
|
| 79 |
+
# ST.2084 expects absolute luminance in cd/m²
|
| 80 |
+
Y_cdm2 = Y_normalized * L_p
|
| 81 |
+
# eotf_inverse_ST2084 returns values in [0, 1] for the 10000 cd/m² range
|
| 82 |
+
# We use a custom L_p to scale appropriately
|
| 83 |
+
X_norm[:, 2] = eotf_inverse_ST2084(Y_cdm2, L_p=L_p)
|
| 84 |
+
|
| 85 |
+
params = {
|
| 86 |
+
"x_range": x_range,
|
| 87 |
+
"y_range": y_range,
|
| 88 |
+
"Y_range": Y_range,
|
| 89 |
+
"encoding": "ST2084",
|
| 90 |
+
"L_p": L_p,
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
return X_norm, params
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
@click.command()
|
| 97 |
+
@click.option("--epochs", default=200, help="Number of training epochs")
|
| 98 |
+
@click.option("--batch-size", default=1024, help="Batch size for training")
|
| 99 |
+
@click.option("--lr", default=5e-4, help="Learning rate")
|
| 100 |
+
@click.option("--patience", default=20, help="Early stopping patience")
|
| 101 |
+
def main(epochs: int, batch_size: int, lr: float, patience: int) -> None:
|
| 102 |
+
"""
|
| 103 |
+
Train the multi-head model with ST.2084 (PQ) encoded Y input.
|
| 104 |
+
|
| 105 |
+
Notes
|
| 106 |
+
-----
|
| 107 |
+
The training pipeline:
|
| 108 |
+
1. Loads training and validation data from cache
|
| 109 |
+
2. Normalizes inputs with ST.2084 (PQ) encoding on Y
|
| 110 |
+
3. Normalizes Munsell outputs to [0, 1] range
|
| 111 |
+
4. Trains multi-head MLP with weighted MSE loss
|
| 112 |
+
5. Uses early stopping based on validation loss
|
| 113 |
+
6. Exports best model to ONNX format
|
| 114 |
+
7. Logs metrics and artifacts to MLflow
|
| 115 |
+
|
| 116 |
+
ST.2084 (Perceptual Quantizer) encoding is designed for perceptual
|
| 117 |
+
uniformity across a wide luminance range, potentially providing better
|
| 118 |
+
alignment with Munsell Value than simple gamma correction. The multi-head
|
| 119 |
+
architecture isolates this effect to the Value head without negatively
|
| 120 |
+
impacting Chroma prediction.
|
| 121 |
+
"""
|
| 122 |
+
|
| 123 |
+
LOGGER.info("=" * 80)
|
| 124 |
+
LOGGER.info("ML-Based xyY to Munsell Conversion: Multi-Head ST.2084 Experiment")
|
| 125 |
+
LOGGER.info("ST.2084 (PQ) encoding applied to Y component (L_p=%.0f cd/m²)", L_P_REFERENCE)
|
| 126 |
+
LOGGER.info("=" * 80)
|
| 127 |
+
|
| 128 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 129 |
+
LOGGER.info("Using device: %s", device)
|
| 130 |
+
|
| 131 |
+
# Load training data
|
| 132 |
+
data_dir = PROJECT_ROOT / "data"
|
| 133 |
+
cache_file = data_dir / "training_data.npz"
|
| 134 |
+
|
| 135 |
+
if not cache_file.exists():
|
| 136 |
+
LOGGER.error("Error: Training data not found at %s", cache_file)
|
| 137 |
+
LOGGER.error("Please run 01_generate_training_data.py first")
|
| 138 |
+
return
|
| 139 |
+
|
| 140 |
+
LOGGER.info("Loading training data from %s...", cache_file)
|
| 141 |
+
data = np.load(cache_file)
|
| 142 |
+
|
| 143 |
+
X_train = data["X_train"]
|
| 144 |
+
y_train = data["y_train"]
|
| 145 |
+
X_val = data["X_val"]
|
| 146 |
+
y_val = data["y_val"]
|
| 147 |
+
|
| 148 |
+
LOGGER.info("Train samples: %d", len(X_train))
|
| 149 |
+
LOGGER.info("Validation samples: %d", len(X_val))
|
| 150 |
+
|
| 151 |
+
# Normalize data with ST.2084 encoding
|
| 152 |
+
X_train_norm, input_params = normalize_inputs(X_train, L_p=L_P_REFERENCE)
|
| 153 |
+
X_val_norm, _ = normalize_inputs(X_val, L_p=L_P_REFERENCE)
|
| 154 |
+
|
| 155 |
+
output_params = MUNSELL_NORMALIZATION_PARAMS
|
| 156 |
+
y_train_norm = normalize_munsell(y_train, output_params)
|
| 157 |
+
y_val_norm = normalize_munsell(y_val, output_params)
|
| 158 |
+
|
| 159 |
+
LOGGER.info("")
|
| 160 |
+
LOGGER.info("Input normalization with ST.2084 (L_p=%.0f):", L_P_REFERENCE)
|
| 161 |
+
LOGGER.info(" Y range after ST.2084: [%.4f, %.4f]", X_train_norm[:, 2].min(), X_train_norm[:, 2].max())
|
| 162 |
+
|
| 163 |
+
# Convert to PyTorch tensors
|
| 164 |
+
X_train_t = torch.FloatTensor(X_train_norm)
|
| 165 |
+
y_train_t = torch.FloatTensor(y_train_norm)
|
| 166 |
+
X_val_t = torch.FloatTensor(X_val_norm)
|
| 167 |
+
y_val_t = torch.FloatTensor(y_val_norm)
|
| 168 |
+
|
| 169 |
+
# Create data loaders
|
| 170 |
+
train_dataset = TensorDataset(X_train_t, y_train_t)
|
| 171 |
+
val_dataset = TensorDataset(X_val_t, y_val_t)
|
| 172 |
+
|
| 173 |
+
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
|
| 174 |
+
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
|
| 175 |
+
|
| 176 |
+
# Initialize model
|
| 177 |
+
model = MultiHeadMLPToMunsell().to(device)
|
| 178 |
+
LOGGER.info("")
|
| 179 |
+
LOGGER.info("Model architecture:")
|
| 180 |
+
LOGGER.info("%s", model)
|
| 181 |
+
|
| 182 |
+
total_params = sum(p.numel() for p in model.parameters())
|
| 183 |
+
LOGGER.info("Total parameters: %s", f"{total_params:,}")
|
| 184 |
+
|
| 185 |
+
# Training setup
|
| 186 |
+
optimizer = optim.Adam(model.parameters(), lr=lr)
|
| 187 |
+
criterion = weighted_mse_loss
|
| 188 |
+
|
| 189 |
+
# MLflow setup
|
| 190 |
+
run_name = setup_mlflow_experiment("from_xyY", "multi_head_st2084")
|
| 191 |
+
|
| 192 |
+
LOGGER.info("")
|
| 193 |
+
LOGGER.info("MLflow run: %s", run_name)
|
| 194 |
+
|
| 195 |
+
# Training loop
|
| 196 |
+
best_val_loss = float("inf")
|
| 197 |
+
patience_counter = 0
|
| 198 |
+
|
| 199 |
+
LOGGER.info("")
|
| 200 |
+
LOGGER.info("Starting training...")
|
| 201 |
+
|
| 202 |
+
with mlflow.start_run(run_name=run_name):
|
| 203 |
+
mlflow.log_params(
|
| 204 |
+
{
|
| 205 |
+
"model": "multi_head_st2084",
|
| 206 |
+
"num_epochs": epochs,
|
| 207 |
+
"batch_size": batch_size,
|
| 208 |
+
"learning_rate": lr,
|
| 209 |
+
"optimizer": "Adam",
|
| 210 |
+
"criterion": "weighted_mse_loss",
|
| 211 |
+
"patience": patience,
|
| 212 |
+
"total_params": total_params,
|
| 213 |
+
"encoding": "ST2084",
|
| 214 |
+
"L_p": L_P_REFERENCE,
|
| 215 |
+
}
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
for epoch in range(epochs):
|
| 219 |
+
train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
|
| 220 |
+
val_loss = validate(model, val_loader, criterion, device)
|
| 221 |
+
|
| 222 |
+
log_training_epoch(
|
| 223 |
+
epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"]
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
LOGGER.info(
|
| 227 |
+
"Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f",
|
| 228 |
+
epoch + 1,
|
| 229 |
+
epochs,
|
| 230 |
+
train_loss,
|
| 231 |
+
val_loss,
|
| 232 |
+
)
|
| 233 |
+
|
| 234 |
+
if val_loss < best_val_loss:
|
| 235 |
+
best_val_loss = val_loss
|
| 236 |
+
patience_counter = 0
|
| 237 |
+
|
| 238 |
+
model_directory = PROJECT_ROOT / "models" / "from_xyY"
|
| 239 |
+
model_directory.mkdir(exist_ok=True)
|
| 240 |
+
checkpoint_file = model_directory / "multi_head_st2084_best.pth"
|
| 241 |
+
|
| 242 |
+
torch.save(
|
| 243 |
+
{
|
| 244 |
+
"model_state_dict": model.state_dict(),
|
| 245 |
+
"input_params": input_params,
|
| 246 |
+
"output_params": output_params,
|
| 247 |
+
"epoch": epoch,
|
| 248 |
+
"val_loss": val_loss,
|
| 249 |
+
},
|
| 250 |
+
checkpoint_file,
|
| 251 |
+
)
|
| 252 |
+
|
| 253 |
+
LOGGER.info(" → Saved best model (val_loss: %.6f)", val_loss)
|
| 254 |
+
else:
|
| 255 |
+
patience_counter += 1
|
| 256 |
+
if patience_counter >= patience:
|
| 257 |
+
LOGGER.info("")
|
| 258 |
+
LOGGER.info("Early stopping after %d epochs", epoch + 1)
|
| 259 |
+
break
|
| 260 |
+
|
| 261 |
+
mlflow.log_metrics(
|
| 262 |
+
{
|
| 263 |
+
"best_val_loss": best_val_loss,
|
| 264 |
+
"final_epoch": epoch + 1,
|
| 265 |
+
}
|
| 266 |
+
)
|
| 267 |
+
|
| 268 |
+
# Export to ONNX
|
| 269 |
+
LOGGER.info("")
|
| 270 |
+
LOGGER.info("Exporting model to ONNX...")
|
| 271 |
+
model.eval()
|
| 272 |
+
|
| 273 |
+
checkpoint = torch.load(checkpoint_file)
|
| 274 |
+
model.load_state_dict(checkpoint["model_state_dict"])
|
| 275 |
+
|
| 276 |
+
dummy_input = torch.randn(1, 3).to(device)
|
| 277 |
+
|
| 278 |
+
onnx_file = model_directory / "multi_head_st2084.onnx"
|
| 279 |
+
torch.onnx.export(
|
| 280 |
+
model,
|
| 281 |
+
dummy_input,
|
| 282 |
+
onnx_file,
|
| 283 |
+
export_params=True,
|
| 284 |
+
opset_version=17,
|
| 285 |
+
input_names=["xyY_st2084"],
|
| 286 |
+
output_names=["munsell_spec"],
|
| 287 |
+
dynamic_axes={"xyY_st2084": {0: "batch_size"}, "munsell_spec": {0: "batch_size"}},
|
| 288 |
+
)
|
| 289 |
+
|
| 290 |
+
# Save normalization parameters (including ST.2084 info)
|
| 291 |
+
params_file = model_directory / "multi_head_st2084_normalization_params.npz"
|
| 292 |
+
np.savez(
|
| 293 |
+
params_file,
|
| 294 |
+
input_params=input_params,
|
| 295 |
+
output_params=output_params,
|
| 296 |
+
)
|
| 297 |
+
|
| 298 |
+
LOGGER.info("ONNX model saved to: %s", onnx_file)
|
| 299 |
+
LOGGER.info("Normalization parameters saved to: %s", params_file)
|
| 300 |
+
LOGGER.info("IMPORTANT: Input Y must be ST.2084-encoded with L_p=%.0f", L_P_REFERENCE)
|
| 301 |
+
|
| 302 |
+
mlflow.log_artifact(str(checkpoint_file))
|
| 303 |
+
mlflow.log_artifact(str(onnx_file))
|
| 304 |
+
mlflow.log_artifact(str(params_file))
|
| 305 |
+
mlflow.pytorch.log_model(model, "model")
|
| 306 |
+
|
| 307 |
+
LOGGER.info("=" * 80)
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
if __name__ == "__main__":
|
| 311 |
+
logging.basicConfig(level=logging.INFO, format="%(message)s", force=True)
|
| 312 |
+
|
| 313 |
+
main()
|