KelSolaar committed on
Commit
3c7db92
·
0 Parent(s):

Initial commit.

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +37 -0
  2. .gitignore +37 -0
  3. .pre-commit-config.yaml +39 -0
  4. LICENSE +11 -0
  5. README.md +284 -0
  6. docs/_static/gamma_sweep_plot.pdf +0 -0
  7. docs/_static/gamma_sweep_plot.png +3 -0
  8. docs/learning_munsell.md +588 -0
  9. docs/training_data_large_hue_density.png +3 -0
  10. learning_munsell/__init__.py +7 -0
  11. learning_munsell/analysis/__init__.py +1 -0
  12. learning_munsell/analysis/error_analysis.py +381 -0
  13. learning_munsell/analysis/hue_density_plot.py +73 -0
  14. learning_munsell/comparison/from_xyY/__init__.py +1 -0
  15. learning_munsell/comparison/from_xyY/compare_all_models.py +1594 -0
  16. learning_munsell/comparison/from_xyY/compare_gamma_model.py +452 -0
  17. learning_munsell/comparison/to_xyY/__init__.py +1 -0
  18. learning_munsell/comparison/to_xyY/compare_all_models.py +584 -0
  19. learning_munsell/data_generation/generate_training_data.py +310 -0
  20. learning_munsell/data_generation/generate_training_data_uniform.py +245 -0
  21. learning_munsell/interpolation/__init__.py +1 -0
  22. learning_munsell/interpolation/from_xyY/__init__.py +43 -0
  23. learning_munsell/interpolation/from_xyY/compare_methods.py +219 -0
  24. learning_munsell/interpolation/from_xyY/delaunay_interpolator.py +283 -0
  25. learning_munsell/interpolation/from_xyY/kdtree_interpolator.py +263 -0
  26. learning_munsell/interpolation/from_xyY/rbf_interpolator.py +300 -0
  27. learning_munsell/losses/__init__.py +17 -0
  28. learning_munsell/losses/jax_delta_e.py +299 -0
  29. learning_munsell/models/__init__.py +46 -0
  30. learning_munsell/models/networks.py +1507 -0
  31. learning_munsell/training/from_xyY/__init__.py +1 -0
  32. learning_munsell/training/from_xyY/hyperparameter_search_error_predictor.py +506 -0
  33. learning_munsell/training/from_xyY/hyperparameter_search_multi_head.py +547 -0
  34. learning_munsell/training/from_xyY/hyperparameter_search_multi_head_error_predictor.py +555 -0
  35. learning_munsell/training/from_xyY/hyperparameter_search_multi_mlp.py +473 -0
  36. learning_munsell/training/from_xyY/refine_multi_head_real.py +357 -0
  37. learning_munsell/training/from_xyY/train_deep_wide.py +370 -0
  38. learning_munsell/training/from_xyY/train_ft_transformer.py +358 -0
  39. learning_munsell/training/from_xyY/train_mixture_of_experts.py +622 -0
  40. learning_munsell/training/from_xyY/train_mlp.py +273 -0
  41. learning_munsell/training/from_xyY/train_mlp_attention.py +462 -0
  42. learning_munsell/training/from_xyY/train_mlp_error_predictor.py +461 -0
  43. learning_munsell/training/from_xyY/train_mlp_gamma.py +309 -0
  44. learning_munsell/training/from_xyY/train_multi_head_3stage_error_predictor.py +423 -0
  45. learning_munsell/training/from_xyY/train_multi_head_circular.py +496 -0
  46. learning_munsell/training/from_xyY/train_multi_head_cross_attention_error_predictor.py +667 -0
  47. learning_munsell/training/from_xyY/train_multi_head_gamma.py +311 -0
  48. learning_munsell/training/from_xyY/train_multi_head_gamma_sweep.py +637 -0
  49. learning_munsell/training/from_xyY/train_multi_head_large.py +249 -0
  50. learning_munsell/training/from_xyY/train_multi_head_mlp.py +273 -0
.gitattributes ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.onnx.data filter=lfs diff=lfs merge=lfs -text
18
+ *.ot filter=lfs diff=lfs merge=lfs -text
19
+ *.parquet filter=lfs diff=lfs merge=lfs -text
20
+ *.pb filter=lfs diff=lfs merge=lfs -text
21
+ *.pickle filter=lfs diff=lfs merge=lfs -text
22
+ *.pkl filter=lfs diff=lfs merge=lfs -text
23
+ *.pt filter=lfs diff=lfs merge=lfs -text
24
+ *.pth filter=lfs diff=lfs merge=lfs -text
25
+ *.rar filter=lfs diff=lfs merge=lfs -text
26
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
27
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
29
+ *.tar filter=lfs diff=lfs merge=lfs -text
30
+ *.tflite filter=lfs diff=lfs merge=lfs -text
31
+ *.tgz filter=lfs diff=lfs merge=lfs -text
32
+ *.wasm filter=lfs diff=lfs merge=lfs -text
33
+ *.xz filter=lfs diff=lfs merge=lfs -text
34
+ *.zip filter=lfs diff=lfs merge=lfs -text
35
+ *.zst filter=lfs diff=lfs merge=lfs -text
36
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
37
+ *.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Common Files
2
+ *.egg-info
3
+ *.pyc
4
+ *.pyo
5
+ .DS_Store
6
+ .coverage*
7
+ uv.lock
8
+
9
+ # Common Directories
10
+ .fleet/
11
+ .idea/
12
+ .ipynb_checkpoints/
13
+ .python-version
14
+ .vs/
15
+ .vscode/
16
+ .sandbox/
17
+ build/
18
+ dist/
19
+ docs/_build/
20
+ docs/generated/
21
+ node_modules/
22
+ references/
23
+
24
+ __pycache__
25
+
26
+ .claude/settings.local.json
27
+ .claude/scratchpad.md
28
+
29
+ # Project Directories
30
+ data/
31
+ logs/
32
+ mlartifacts/
33
+ mlruns/
34
+ mlruns.db
35
+ reports/
36
+ results/
37
+ runs/
.pre-commit-config.yaml ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ repos:
2
+ - repo: https://github.com/pre-commit/pre-commit-hooks
3
+ rev: "v5.0.0"
4
+ hooks:
5
+ - id: check-added-large-files
6
+ - id: check-case-conflict
7
+ - id: check-merge-conflict
8
+ - id: check-symlinks
9
+ - id: check-yaml
10
+ - id: debug-statements
11
+ - id: end-of-file-fixer
12
+ - id: mixed-line-ending
13
+ - id: requirements-txt-fixer
14
+ - id: trailing-whitespace
15
+ - repo: https://github.com/codespell-project/codespell
16
+ rev: v2.4.1
17
+ hooks:
18
+ - id: codespell
19
+ args: ["--ignore-words-list=colour"]
20
+ - repo: https://github.com/PyCQA/isort
21
+ rev: "6.0.1"
22
+ hooks:
23
+ - id: isort
24
+ - repo: https://github.com/astral-sh/ruff-pre-commit
25
+ rev: "v0.12.4"
26
+ hooks:
27
+ - id: ruff-format
28
+ - id: ruff
29
+ args: [--fix]
30
+ - repo: https://github.com/pre-commit/mirrors-prettier
31
+ rev: "v4.0.0-alpha.8"
32
+ hooks:
33
+ - id: prettier
34
+ - repo: https://github.com/pre-commit/pygrep-hooks
35
+ rev: "v1.10.0"
36
+ hooks:
37
+ - id: rst-backticks
38
+ - id: rst-directive-colons
39
+ - id: rst-inline-touching-normal
LICENSE ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Copyright 2025 Colour Developers
2
+
3
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
4
+
5
+ 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
6
+
7
+ 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
8
+
9
+ 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
10
+
11
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
README.md ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: bsd-3-clause
3
+ language:
4
+ - en
5
+ tags:
6
+ - python
7
+ - colour
8
+ - color
9
+ - colour-science
10
+ - color-science
11
+ - colour-spaces
12
+ - color-spaces
13
+ - colourspace
14
+ - colorspace
15
+ pipeline_tag: tabular-regression
16
+ library_name: onnxruntime
17
+ metrics:
18
+ - mae
19
+ model-index:
20
+ - name: from_xyY (CIE xyY to Munsell)
21
+ results:
22
+ - task:
23
+ type: tabular-regression
24
+ name: CIE xyY to Munsell Specification
25
+ dataset:
26
+ name: CIE xyY to Munsell Specification
27
+ type: munsell-renotation
28
+ metrics:
29
+ - type: delta-e
30
+ value: 0.51
31
+ name: Delta-E CIE2000
32
+ - type: inference_time_ms
33
+ value: 0.061
34
+ name: Inference Time (ms/sample)
35
+ - name: to_xyY (Munsell to CIE xyY)
36
+ results:
37
+ - task:
38
+ type: tabular-regression
39
+ name: Munsell Specification to CIE xyY
40
+ dataset:
41
+ name: Munsell Specification to CIE xyY
42
+ type: munsell-renotation
43
+ metrics:
44
+ - type: delta-e
45
+ value: 0.48
46
+ name: Delta-E CIE2000
47
+ - type: inference_time_ms
48
+ value: 0.066
49
+ name: Inference Time (ms/sample)
50
+ ---
51
+
52
+ # Learning Munsell - Machine Learning for Munsell Color Conversions
53
+
54
+ A project implementing machine learning-based methods for bidirectional conversion between CIE xyY colourspace values and Munsell specifications.
55
+
56
+ **Two Conversion Directions:**
57
+
58
+ - **from_xyY**: CIE xyY to Munsell specification
59
+ - **to_xyY**: Munsell specification to CIE xyY
60
+
61
+ ## Project Overview
62
+
63
+ ### Objective
64
+
65
+ Provide 100-1000x speedup for batch Munsell conversions compared to colour-science routines while maintaining high accuracy.
66
+
67
+ ### Results
68
+
69
+ **from_xyY** (CIE xyY to Munsell) — 33 models evaluated on all 2,734 REAL Munsell colors:
70
+
71
+ | Model | Delta-E | Speed (ms) |
72
+ | ---------------------------------------------------------------------------------- | -------- | ---------- |
73
+ | Colour Library (Baseline) | 0.00 | 116.3 |
74
+ | **Multi-MLP + Multi-Error Predictor** | **0.51** | 0.061 |
75
+ | Multi-ResNet + Multi-Error Predictor (Large Dataset) | 0.52 | 0.096 |
76
+ | Transformer + Error Predictor (Large Dataset) | 0.52 | 0.163 |
77
+ | Multi-MLP (Classification Code) + Code-Aware Multi-Error Predictor (Large Dataset) | 0.53 | 0.118 |
78
+ | Multi-Head + Multi-Error Predictor (Large Dataset) | 0.53 | 0.046 |
79
+ | Multi-MLP (Classification Code) + Code-Aware Multi-Error Predictor | 0.53 | 0.052 |
80
+ | Multi-MLP (Classification Code) + Multi-Error Predictor | 0.53 | 0.050 |
81
+ | MLP + Error Predictor | 0.53 | 0.036 |
82
+ | Multi-Head + Multi-Error Predictor | 0.54 | 0.057 |
83
+ | Multi-ResNet (Large Dataset) | 0.56 | 0.047 |
84
+ | Multi-Head (Large Dataset) | 0.57 | 0.012 |
85
+ | FT-Transformer | 0.70 | 0.067 |
86
+ | Unified MLP | 0.71 | 0.074 |
87
+ | Transformer (Large Dataset) | 0.73 | 0.120 |
88
+ | Mixture of Experts | 0.74 | 0.021 |
89
+ | Multi-MLP | 0.91 | 0.027 |
90
+ | Multi-MLP (Classification Code) | 0.92 | 0.026 |
91
+ | Multi-MLP (Classification Code) (Large Dataset) | 0.89 | 0.066 |
92
+ | MLP + Self-Attention | 0.88 | 0.185 |
93
+ | Deep + Wide | 1.18 | 0.078 |
94
+ | MLP (Base Only) | 1.30 | **0.009** |
95
+ | Multi-MLP (Hue Angle sin/cos) | 1.78 | 0.021 |
96
+
97
+ - **Best Accuracy**: Multi-MLP + Multi-Error Predictor — Delta-E 0.51, 1,905x faster
98
+ - **Fastest**: MLP Base Only (0.009 ms/sample) — 13,365x faster than Colour library
99
+ - **Best Balance**: MLP + Error Predictor — 3,196x faster with Delta-E 0.53
100
+
101
+ **to_xyY** (Munsell to CIE xyY) — 6 models evaluated on all 2,734 REAL Munsell colors:
102
+
103
+ | Model | Delta-E | Speed (ms) |
104
+ | ------------------------------------- | -------- | ---------- |
105
+ | Colour Library (Baseline) | 0.00 | 1.40 |
106
+ | **Multi-MLP + Multi-Error Predictor** | **0.48** | 0.066 |
107
+ | Multi-Head + Multi-Error Predictor | 0.51 | 0.054 |
108
+ | Simple MLP | 0.52 | 0.002 |
109
+ | Multi-MLP | 0.57 | 0.028 |
110
+ | Multi-Head | 0.60 | 0.013 |
111
+ | Multi-MLP + Error Predictor | 0.61 | 0.060 |
112
+
113
+ - **Best Accuracy**: Multi-MLP + Multi-Error Predictor — Delta-E 0.48, 21x faster
114
+ - **Fastest**: Simple MLP (0.002 ms/sample) — 669x faster than Colour library
115
+
116
+ ### Approach
117
+
118
+ - **33 models** tested for from_xyY (MLP, Multi-Head, Multi-MLP, Multi-ResNet, Transformers, FT-Transformer, Mixture of Experts)
119
+ - **6 models** tested for to_xyY (Simple MLP, Multi-Head, Multi-MLP with error predictors)
120
+ - **Two-stage models** (base + error predictor) proved most effective
121
+ - **Best model**: Multi-MLP + Multi-Error Predictor with Delta-E 0.51
122
+ - **Training data**: ~1.4M samples from dense xyY grid with boundary refinement and forward Munsell sampling
123
+ - **Deployment**: ONNX format with ONNX Runtime
124
+
125
+ For detailed architecture comparisons, model benchmarks, training pipeline details, and experimental results, see [docs/learning_munsell.md](docs/learning_munsell.md).
126
+
127
+ ## Installation
128
+
129
+ **Dependencies (Runtime)**:
130
+
131
+ - numpy >= 2.0
132
+ - onnxruntime >= 1.16
133
+
134
+ **Dependencies (Training)**:
135
+
136
+ - torch >= 2.0
137
+ - scikit-learn >= 1.3
138
+ - matplotlib >= 3.9
139
+ - mlflow >= 2.10
140
+ - optuna >= 3.0
141
+ - colour-science >= 0.4.7
142
+ - click >= 8.0
143
+ - onnx >= 1.15
144
+ - onnxscript >= 0.5.6
145
+ - tqdm >= 4.66
146
+ - jax >= 0.4.20
147
+ - jaxlib >= 0.4.20
148
+ - flax >= 0.10.7
149
+ - optax >= 0.2.6
150
+ - scipy >= 1.12
151
+ - tensorboard >= 2.20
152
+
153
+ From the project root:
154
+
155
+ ```bash
156
+ cd learning-munsell
157
+
158
+ # Install all dependencies (creates virtual environment automatically)
159
+ uv sync
160
+ ```
161
+
162
+ ## Usage
163
+
164
+ ### Generate Training Data
165
+
166
+ ```bash
167
+ uv run python learning_munsell/data_generation/generate_training_data.py
168
+ ```
169
+
170
+ **Note**: This step is computationally expensive (uses iterative algorithm for ground truth).
171
+
172
+ ### Train Models
173
+
174
+ **xyY to Munsell (from_xyY)**
175
+
176
+ Best performing model (Multi-MLP + Multi-Error Predictor):
177
+
178
+ ```bash
179
+ # Train base Multi-MLP
180
+ uv run python learning_munsell/training/from_xyY/train_multi_mlp.py
181
+
182
+ # Train multi-error predictor
183
+ uv run python learning_munsell/training/from_xyY/train_multi_mlp_multi_error_predictor.py
184
+ ```
185
+
186
+ Alternative architectures:
187
+
188
+ ```bash
189
+ uv run python learning_munsell/training/from_xyY/train_multi_resnet_large.py
190
+ uv run python learning_munsell/training/from_xyY/train_multi_head_large.py
191
+ uv run python learning_munsell/training/from_xyY/train_ft_transformer.py
192
+ uv run python learning_munsell/training/from_xyY/train_unified_mlp.py
193
+ uv run python learning_munsell/training/from_xyY/train_deep_wide.py
194
+ uv run python learning_munsell/training/from_xyY/train_mlp_attention.py
195
+ ```
196
+
197
+ **Munsell to xyY (to_xyY)**
198
+
199
+ Best performing model (Multi-MLP + Multi-Error Predictor):
200
+
201
+ ```bash
202
+ # Train base Multi-MLP
203
+ uv run python learning_munsell/training/to_xyY/train_multi_mlp.py
204
+
205
+ # Train multi-error predictor
206
+ uv run python learning_munsell/training/to_xyY/train_multi_mlp_multi_error_predictor.py
207
+ ```
208
+
209
+ Other architectures:
210
+
211
+ ```bash
212
+ uv run python learning_munsell/training/to_xyY/train_multi_head.py
213
+ uv run python learning_munsell/training/to_xyY/train_multi_mlp_error_predictor.py
214
+ uv run python learning_munsell/training/to_xyY/train_multi_head_multi_error_predictor.py
215
+ ```
216
+
217
+ Train the differentiable approximator for use in Delta-E loss:
218
+
219
+ ```bash
220
+ uv run python learning_munsell/training/to_xyY/train_munsell_to_xyY_approximator.py
221
+ ```
222
+
223
+ ### Hyperparameter Search
224
+
225
+ ```bash
226
+ uv run python learning_munsell/training/from_xyY/hyperparameter_search_multi_head.py
227
+ uv run python learning_munsell/training/from_xyY/hyperparameter_search_multi_head_error_predictor.py
228
+ ```
229
+
230
+ ### Compare All Models
231
+
232
+ ```bash
233
+ uv run python learning_munsell/comparison/from_xyY/compare_all_models.py
234
+ uv run python learning_munsell/comparison/to_xyY/compare_all_models.py
235
+ ```
236
+
237
+ Generates comprehensive HTML reports at `reports/from_xyY/model_comparison.html` and `reports/to_xyY/model_comparison.html`.
238
+
239
+ ### Monitor Training
240
+
241
+ **MLflow**:
242
+
243
+ ```bash
244
+ uv run mlflow ui --backend-store-uri "sqlite:///mlruns.db" --port=5000
245
+ ```
246
+
247
+ Open <http://localhost:5000> in your browser.
248
+
249
+ ## Directory Structure
250
+
251
+ ```
252
+ learning-munsell/
253
+ +-- data/ # Training data
254
+ | +-- training_data.npz # Generated training samples
255
+ | +-- training_data_large.npz # Large dataset (~1.4M samples)
256
+ | +-- training_data_params.json # Generation parameters
257
+ | +-- training_data_large_params.json
258
+ +-- models/ # Trained models (ONNX + PyTorch)
259
+ | +-- from_xyY/ # xyY to Munsell models (38 ONNX models)
260
+ | | +-- multi_mlp_multi_error_predictor.onnx # BEST
261
+ | | +-- ... (additional model variants)
262
+ | +-- to_xyY/ # Munsell to xyY models (6 ONNX models)
263
+ | +-- multi_mlp_multi_error_predictor.onnx # BEST
264
+ | +-- ... (additional model variants)
265
+ +-- learning_munsell/ # Source code
266
+ | +-- analysis/ # Analysis scripts
267
+ | +-- comparison/ # Model comparison scripts
268
+ | +-- data_generation/ # Data generation scripts
269
+ | +-- interpolation/ # Classical interpolation methods
270
+ | +-- losses/ # Loss functions (JAX Delta-E)
271
+ | +-- models/ # Model architecture definitions
272
+ | +-- training/ # Model training scripts
273
+ | +-- utilities/ # Shared utilities
274
+ +-- docs/ # Documentation
275
+ +-- reports/ # HTML comparison reports
276
+ +-- logs/ # Script output logs
277
+ +-- mlruns.db # MLflow experiment tracking database
278
+ ```
279
+
280
+ ## About
281
+
282
+ **Learning Munsell** by Colour Developers
283
+ Research project for the Colour library
284
+ <https://github.com/colour-science/colour>
docs/_static/gamma_sweep_plot.pdf ADDED
Binary file (22.2 kB). View file
 
docs/_static/gamma_sweep_plot.png ADDED

Git LFS Details

  • SHA256: e2a0d5dc57c0d37d5889cff4ac41a08b490387a54615d4372af5e5bd86018e36
  • Pointer size: 131 Bytes
  • Size of remote file: 137 kB
docs/learning_munsell.md ADDED
@@ -0,0 +1,588 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Learning Munsell
2
+
3
+ Technical documentation covering performance benchmarks, training methodology, architecture design, and experimental findings.
4
+
5
+ ## Overview
6
+
7
+ This project implements ML models for bidirectional conversion between CIE xyY colorspace values and Munsell specifications:
8
+
9
+ - **xyY to Munsell (from_xyY)**: 33 models, best Delta-E 0.51
10
+ - **Munsell to xyY (to_xyY)**: 6 models, best Delta-E 0.48
11
+
12
+ ### Delta-E Interpretation
13
+
14
+ - **< 1.0**: Not perceptible by human eye
15
+ - **1-2**: Perceptible through close observation
16
+ - **2-10**: Perceptible at a glance
17
+ - **> 10**: Colors are perceived as completely different
18
+
19
+ Our best models achieve **Delta-E 0.48-0.51**, meaning the difference between ML prediction and iterative algorithm is **not perceptible by the human eye**.
20
+
21
+ ## xyY to Munsell (from_xyY)
22
+
23
+ ### Performance Benchmarks
24
+
25
+ 33 models compared using all 2,734 REAL Munsell colors:
26
+
27
+ | Model | Delta-E | Speed (ms) |
28
+ | ---------------------------------------------------------------------------------- | -------- | ---------- |
29
+ | Colour Library (Baseline) | 0.00 | 116.3 |
30
+ | **Multi-MLP + Multi-Error Predictor** | **0.51** | 0.061 |
31
+ | Multi-ResNet + Multi-Error Predictor (Large Dataset) | 0.52 | 0.096 |
32
+ | Transformer + Error Predictor (Large Dataset) | 0.52 | 0.163 |
33
+ | Multi-Head + Multi-Error Predictor (Large Dataset) | 0.53 | 0.046 |
34
+ | Multi-MLP (Classification Code) + Code-Aware Multi-Error Predictor (Large Dataset) | 0.53 | 0.118 |
35
+ | Multi-MLP (Classification Code) + Code-Aware Multi-Error Predictor | 0.53 | 0.052 |
36
+ | Multi-MLP (Classification Code) + Multi-Error Predictor | 0.53 | 0.050 |
37
+ | MLP + Error Predictor | 0.53 | 0.036 |
38
+ | Multi-Head + Multi-Error Predictor | 0.54 | 0.057 |
39
+ | Multi-ResNet (Large Dataset) | 0.56 | 0.047 |
40
+ | Multi-MLP + Multi-Error Predictor (Large Dataset) | 0.56 | 0.060 |
41
+ | Multi-Head + Cross-Attention Error Predictor | 0.57 | 0.032 |
42
+ | Multi-Head (Large Dataset) | 0.57 | 0.012 |
43
+ | FT-Transformer | 0.70 | 0.067 |
44
+ | Unified MLP | 0.71 | 0.074 |
45
+ | Transformer (Large Dataset) | 0.73 | 0.120 |
46
+ | Mixture of Experts | 0.74 | 0.021 |
47
+ | Multi-MLP | 0.91 | 0.027 |
48
+ | Multi-MLP (Classification Code) | 0.92 | 0.026 |
49
+ | Multi-MLP (Classification Code) (Large Dataset) | 0.89 | 0.066 |
50
+ | MLP + Self-Attention | 0.88 | 0.185 |
51
+ | Deep + Wide | 1.18 | 0.078 |
52
+ | MLP (Base Only) | 1.30 | **0.009** |
53
+ | Multi-MLP (Hue Angle sin/cos) | 1.78 | 0.021 |
54
+
55
+ Note: The Colour library baseline had 171 convergence failures out of 2,734 samples (6.3% failure rate). See the full HTML report for all 33 models.
56
+
57
+ **Best Models**:
58
+
59
+ - **Best Accuracy**: Multi-MLP + Multi-Error Predictor — Delta-E 0.51, 1,905x faster
60
+ - **Fastest**: MLP Base Only (0.009 ms/sample) — 13,365x faster than Colour library
61
+ - **Best Balance**: MLP + Error Predictor — 3,196x faster with Delta-E 0.53
62
+
63
+ ### Model Architectures
64
+
65
+ 33 models were systematically evaluated:
66
+
67
+ **Single-Stage Models**
68
+
69
+ 1. **MLP (Base Only)** - Simple MLP network, 3 inputs to 4 outputs
70
+ 2. **Unified MLP** - Single large MLP with shared features
71
+ 3. **Multi-Head** - Shared encoder with 4 independent decoder heads
72
+ 4. **Multi-Head (Large Dataset)** - Multi-Head trained on 1.4M samples
73
+ 5. **Multi-MLP** - 4 completely independent MLP branches (one per output)
74
+ 6. **Multi-MLP (Large Dataset)** - Multi-MLP trained on 1.4M samples
75
+ 7. **MLP + Self-Attention** - MLP with attention mechanism for feature weighting
76
+ 8. **Deep + Wide** - Combined deep and wide network paths
77
+ 9. **Mixture of Experts** - Gating network selecting specialized expert networks
78
+ 10. **Transformer (Large Dataset)** - Feature Tokenizer Transformer for tabular data
79
+ 11. **FT-Transformer** - Feature Tokenizer Transformer (standard size)
80
+
81
+ **Two-Stage Models**
82
+
83
+ 12. **MLP + Error Predictor** - Base MLP with unified error correction
84
+ 13. **Multi-Head + Multi-Error Predictor** - Multi-Head with 4 independent error predictors
85
+ 14. **Multi-Head + Multi-Error Predictor (Large Dataset)** - Large dataset variant
86
+ 15. **Multi-MLP + Multi-Error Predictor** - 4 independent branches with 4 independent error predictors
87
+ 16. **Multi-MLP + Multi-Error Predictor (Large Dataset)** - Large dataset variant
88
+ 17. **Multi-ResNet + Multi-Error Predictor (Large Dataset)** - Deep ResNet-style branches
89
+ 18. **Transformer + Error Predictor (Large Dataset)** - Transformer base with error correction
90
+ 19. **Multi-Head + Cross-Attention Error Predictor** - Cross-attention mechanism for error correction
91
+ 20. **Multi-Head + 3-Stage Error Predictor** - Three cascaded error correction stages
92
+
93
+ **Alternative Code Prediction Models**
94
+
95
+ 21. **Multi-MLP (Classification Code)** - 3 regression branches + 1 classification branch (10 logits for hue code)
96
+ 22. **Multi-MLP (Classification Code) + Multi-Error Predictor** - Classification Code base with 3-branch error predictor (hue, value, chroma)
97
+ 23. **Multi-MLP (Classification Code) + Code-Aware Multi-Error Predictor** - Classification Code base with code-aware 3-branch error predictor (input: xyY + regression + code one-hot)
98
+ 24. **Multi-MLP (Classification Code) (Large Dataset)** - Classification Code trained on 1.4M samples
99
+ 25. **Multi-MLP (Classification Code) + Code-Aware Multi-Error Predictor (Large Dataset)** - Large dataset variant of code-aware error predictor
100
+ 26. **Multi-MLP (Hue Angle sin/cos)** - Encodes full Munsell angle as sin/cos pair, eliminates separate code branch
101
+ The **Multi-MLP + Multi-Error Predictor** architecture achieved the best results with Delta-E 0.51.
102
+
103
+ ### Training Methodology
104
+
105
+ **Data Generation**
106
+
107
+ 1. **Dense xyY Grid** (~500K samples)
108
+ - Regular grid in valid xyY space (MacAdam limits for Illuminant C)
109
+ - Captures general input distribution
110
+ 2. **Boundary Refinement** (~700K samples)
111
+ - Adaptive dense sampling near Munsell gamut boundaries
112
+ - Uses `maximum_chroma_from_renotation` to detect edges
113
+ - Focuses on regions where iterative algorithm is most complex
114
+ - Includes Y/GY/G hue regions with high value/chroma (challenging areas)
115
+ 3. **Forward Augmentation** (~200K samples)
116
+ - Dense Munsell space sampling via `munsell_specification_to_xyY`
117
+ - Ensures coverage of known valid colors
118
+
119
+ Total: ~1.4M samples for large dataset training.
120
+
121
+ **Loss Functions**
122
+
123
+ Two loss function approaches were tested:
124
+
125
+ _Precision-Focused Loss_ (Default):
126
+
127
+ ```
128
+ total_loss = 1.0 * MSE + 0.5 * MAE + 0.3 * log_penalty + 0.5 * huber_loss
129
+ ```
130
+
131
+ - MSE: Standard mean squared error
132
+ - MAE: Mean absolute error
133
+ - Log penalty: Heavily penalizes small errors (pushes toward high precision)
134
+ - Huber loss: Small delta (0.01) for precision on small errors
135
+
136
+ _Pure MSE Loss_ (Optimized config):
137
+
138
+ ```
139
+ total_loss = MSE
140
+ ```
141
+
142
+ Interestingly, the precision-focused loss achieved better Delta-E despite higher validation MSE, suggesting the custom weighting better correlates with perceptual accuracy.
143
+
144
+ ### Design Rationale
145
+
146
+ **Two-Stage Architecture**
147
+
148
+ The error predictor stage corrects systematic biases in the base model:
149
+
150
+ 1. Base model learns the general xyY to Munsell mapping
151
+ 2. Error predictor learns residual corrections specific to each component
152
+ 3. Combined prediction: `final = base_prediction + error_correction`
153
+
154
+ This decomposition allows each stage to specialize and reduces the complexity each network must learn.
155
+
156
+ **Independent Branch Design**
157
+
158
+ Munsell components have different characteristics:
159
+
160
+ - **Hue**: Circular (0-10, wrapping), most complex
161
+ - **Value**: Linear (0-10), easiest to predict
162
+ - **Chroma**: Highly variable range depending on hue/value
163
+ - **Code**: Discrete hue sector (0-9)
164
+
165
+ Shared encoders force compromises between these different prediction tasks. Independent branches allow full specialization.
166
+
167
+ **Architecture Details**
168
+
169
+ _MLP (Base Only)_
170
+
171
+ Simple feedforward network predicting all 4 outputs simultaneously:
172
+
173
+ Input (3) ──► Linear Layers ──► Output (4: hue, value, chroma, code)
174
+
175
+ - Smallest model (~2.3 MB ONNX)
176
+ - Fastest inference (0.009 ms)
177
+ - Baseline for comparison
178
+
179
+ _Unified MLP_
180
+
181
+ Single large MLP with shared internal features:
182
+
183
+ Input (3) ──► 128 ──► 256 ──► 512 ──► 256 ──► 128 ──► Output (4)
184
+
185
+ - Shared representations across all outputs
186
+ - Moderate size, good speed
187
+
188
+ _Multi-Head MLP_
189
+
190
+ Shared encoder with specialized decoder heads:
191
+
192
+ Input (3) ──► SHARED ENCODER (3→128→256→512) ──┬──► Hue Head (512→256→128→1)
193
+ ├──► Value Head (512→256→128→1)
194
+ ├──► Chroma Head (512→384→256→128→1)
195
+ └──► Code Head (512→256→128→1)
196
+
197
+ - Shared encoder learns common color space features
198
+ - 4 specialized decoder heads branch from shared representation
199
+ - Parameter efficient (encoder weights shared)
200
+ - Fast inference (encoder computed once)
201
+
202
+ _Multi-MLP_
203
+
204
+ Fully independent branches with no weight sharing:
205
+
206
+ Input (3) ──► Hue Branch (3→128→256→512→256→128→1)
207
+ Input (3) ──► Value Branch (3→128→256→512→256→128→1)
208
+ Input (3) ──► Chroma Branch (3→256→512→1024→512→256→1) [2x wider]
209
+ Input (3) ──► Code Branch (3→128→256→512→256→128→1)
210
+
211
+ - 4 completely independent MLPs
212
+ - Each branch learns its own features from scratch
213
+ - Chroma branch is wider (2x) to handle its complexity
214
+ - Better accuracy than Multi-Head on large dataset (Delta-E 0.52 vs 0.56 with error predictors)
215
+
216
+ _Multi-ResNet_
217
+
218
+ Deep branches with residual-style connections:
219
+
220
+ Input (3) ──► Hue Branch (3→256→512→512→512→256→1) [6 layers]
221
+ Input (3) ──► Value Branch (3→256→512→512→512→256→1) [6 layers]
222
+ Input (3) ──► Chroma Branch (3→512→1024→1024→1024→512→1) [6 layers, 2x wider]
223
+ Input (3) ──► Code Branch (3→256→512→512→512→256→1) [6 layers]
224
+
225
+ - Deeper architecture than Multi-MLP
226
+ - BatchNorm + SiLU activation
227
+ - Strong accuracy when combined with error predictor (Delta-E 0.52)
228
+ - Largest model (~14MB base, ~28MB with error predictor)
229
+
230
+ _Deep + Wide_
231
+
232
+ Combined deep and wide network paths:
233
+
234
+ Input (3) ──┬──► Deep Path (multiple layers) ──┬──► Concat ──► Output (4)
235
+ └──► Wide Path (direct connection) ─┘
236
+
237
+ - Deep path captures complex patterns
238
+ - Wide path preserves direct input information
239
+ - Good for mixed linear/nonlinear relationships
240
+
241
+ _MLP + Self-Attention_
242
+
243
+ MLP with attention mechanism for feature weighting:
244
+
245
+ Input (3) ──► MLP ──► Self-Attention ──► Output (4)
246
+
247
+ - Attention weights learn feature importance
248
+ - Slower due to attention computation (0.173 ms)
249
+ - Did not improve over simpler MLPs
250
+
251
+ _Mixture of Experts_
252
+
253
+ Gating network selecting specialized expert networks:
254
+
255
+ Input (3) ──► Gating Network ──► Weighted sum of Expert outputs ──► Output (4)
256
+
257
+ - Multiple expert networks specialize in different input regions
258
+ - Gating network learns which expert to use
259
+ - More complex but did not outperform Multi-MLP
260
+
261
+ _FT-Transformer_
262
+
263
+ Feature Tokenizer Transformer for tabular data:
264
+
265
+ Input (3) ──► Feature Tokenizer ──► Transformer Blocks ──► Output (4)
266
+
267
+ - Each input feature tokenized separately
268
+ - Self-attention across feature tokens
269
+ - Good for tabular data with feature interactions
270
+ - Slower inference due to attention computation
271
+
272
+ _Error Predictor (Two-Stage)_
273
+
274
+ Second-stage network that corrects base model errors:
275
+
276
+ Stage 1: Input (3) ──► Base Model ──► Base Prediction (4)
277
+ Stage 2: [Input (3), Base Prediction (4)] ──► Error Predictor ──► Error Correction (4)
278
+ Final: Base Prediction + Error Correction = Final Output
279
+
280
+ - Learns residual corrections for each component
281
+ - Can have unified (1 network) or multi (4 networks) error predictors
282
+ - Consistently improves accuracy across all base architectures
283
+ - Best results: Multi-MLP + Multi-Error Predictor (Delta-E 0.51)
284
+
285
+ **Loss-Metric Mismatch**
286
+
287
+ An important finding: **optimizing MSE does not optimize Delta-E**.
288
+
289
+ The Optuna hyperparameter search minimized validation MSE, but the best MSE configuration did not achieve the best Delta-E. This is because:
290
+
291
+ - MSE treats all component errors equally
292
+ - Delta-E (CIE2000) weights errors based on human perception
293
+ - The precision-focused loss with custom weights better approximates perceptual importance
294
+
295
+ **Weighted Boundary Loss (Experimental)**
296
+
297
+ Analysis of model errors revealed systematic underperformance on Y/GY/G hues (Yellow/Green-Yellow/Green) with high value and chroma. The weighted boundary loss approach was explored to address this by:
298
+
299
+ 1. Applying 3x loss weight to samples in challenging regions:
300
+ - Hue: 0.18-0.35 (normalized range covering Y/GY/G)
301
+ - Value > 0.7 (high brightness)
302
+ - Chroma > 0.5 (high saturation)
303
+ 2. Adding boundary penalty to prevent predictions exceeding Munsell gamut limits
304
+
305
+ **Finding**: The large dataset approach (~1.4M samples with dense boundary sampling) naturally provides sufficient coverage of these challenging regions. Both the weighted boundary loss model and the standard dataset models achieve similar results, making explicit loss weighting optional. The best overall model is Multi-MLP + Multi-Error Predictor with Delta-E 0.51.
306
+
307
+ ### Hue Code Boundary Analysis
308
+
309
+ The best model (Multi-MLP + Multi-Error Predictor, Delta-E 0.51) treats the hue code as a continuous regression target. After rounding to the nearest integer, **13.1% of predictions** (335/2,563 real Munsell colours) have incorrect hue codes — mostly off-by-one at hue family boundaries (e.g., model predicts code 9.198, rounds to 9 instead of correct 10).
310
+
311
+ While the forward direction (Munsell to xyY) is essentially perfect, these code errors produce large visible colour shifts in round-trip (xyY to Munsell to xyY) comparisons because adjacent hue codes represent different hue families (e.g., P vs PB).
312
+
313
+ Two alternative architectures were designed and evaluated to address this:
314
+
315
+ **Approach A: Classification Head for Code** (`MultiMLPClassCodeToMunsell`)
316
+
317
+ - 3 regression branches (hue, value, chroma) identical to standard Multi-MLP
318
+ - 1 classification branch outputting 10 logits (one per hue code 1-10)
319
+ - Loss: weighted MSE on regression targets + cross-entropy on code class
320
+ - Inference: `code = argmax(logits) + 1`
321
+ - Rationale: Discrete classification avoids the rounding boundary problem entirely
322
+
323
+ Architecture:
324
+
325
+ Input (3) ──► Hue Branch (3→128→256→512→256→128→1)
326
+ Input (3) ──► Value Branch (3→128→256→512→256→128→1)
327
+ Input (3) ──► Chroma Branch (3→256→512→1024→512→256→1) [2x wider]
328
+ Input (3) ──► Code Branch (3→128→256→512→256→128→10) [10 logits]
329
+
330
+ **Approach B: Hue Angle sin/cos Encoding** (`MultiMLPHueAngleToMunsell`)
331
+
332
+ - Encodes the full Munsell angle (`hue + (code - 1) * 10`) as sin/cos pair
333
+ - Eliminates separate hue and code branches entirely
334
+ - Loss: MSE on [sin, cos, value_norm, chroma_norm]
335
+ - Inference: `angle = atan2(sin, cos) * 100 / (2*pi)`, then decompose to hue + code
336
+ - Rationale: Continuous circular encoding naturally handles wrap-around
337
+
338
+ Architecture:
339
+
340
+ Input (3) ──► Hue Angle Branch (3→128→256→512→256→128→2) [sin, cos]
341
+ Input (3) ──► Value Branch (3→128→256→512→256→128→1)
342
+ Input (3) ──► Chroma Branch (3→256→512→1024→512→256→1) [2x wider]
343
+
344
+ **Results**
345
+
346
+ | Model | Delta-E | Code Accuracy | Code MAE | Speed (ms) |
347
+ | ---------------------------------------------------------- | -------- | ------------- | -------- | ---------- |
348
+ | Multi-MLP + Multi-Error Predictor | 0.51 | 86.9% | - | 0.058 |
349
+ | **Classification Code + Code-Aware Multi-Error Predictor** | **0.53** | **100.0%** | 0.0004 | 0.052 |
350
+ | Classification Code + Multi-Error Predictor | 0.53 | **100.0%** | 0.0004 | 0.050 |
351
+ | Multi-MLP (Classification Code) | 0.92 | **100.0%** | 0.0004 | 0.026 |
352
+ | Multi-MLP (Hue Angle sin/cos) | 1.78 | 96.3% | - | 0.021 |
353
+
354
+ _Key Findings_:
355
+
356
+ - **Classification Code + Code-Aware Multi-Error Predictor** achieves the best of both worlds: **perfect code accuracy (100%)** with **Delta-E 0.53** — competitive with the overall best model (0.51). The code-aware 3-branch error predictor receives the one-hot encoded classified hue code (10 dims) alongside xyY input and regression predictions (input_dim=16), allowing each branch to learn hue-family-specific corrections. This marginally outperforms the standard multi-error predictor (Delta-E 0.5255 vs 0.5278).
357
+ - **Classification Code** (base only) achieves perfect code accuracy but Delta-E 0.92 without error correction.
358
+ - **Hue Angle sin/cos** improves code accuracy to 96.3% (vs 86.9% baseline) but has the worst Delta-E (1.78).
359
+
360
+ ### Experimental Findings
361
+
362
+ The following experiments were conducted but did not improve results:
363
+
364
+ **Circular Hue Encoding**
365
+
366
+ Two approaches to encoding hue as sin/cos were tested, motivated by systematic hue regression errors at hue family boundaries (e.g., 10B, 10RP, 10R) where the scalar hue regression suffers from the 0/10 discontinuity:
367
+
368
+ | Swatch | Centore | ONNX (Large) | Max RGB diff |
369
+ | -------- | ------------- | ------------- | ------------ |
370
+ | 10B 4/6 | 1.1PB 4.0/5.9 | 0.3PB 4.0/5.9 | 0.157 |
371
+ | 10RP 4/6 | 0.8R 4.0/5.8 | 2.7RP 4.1/5.7 | 0.098 |
372
+ | 10R 9/6 | 0.7YR 9.1/6.4 | 6.5YR 9.1/6.5 | 0.081 |
373
+ | 10RP 3/6 | 0.6R 3.0/5.8 | 4.8RP 3.1/5.8 | 0.072 |
374
+
375
+ _Approach A — Full Angle sin/cos_ (`MultiMLPHueAngleToMunsell`): Encodes the full Munsell angle (`hue + (code - 1) * 10`, range 0-100) as a sin/cos pair, eliminating the separate code branch entirely.
376
+
377
+ _Approach B — Within-Family sin/cos + Classification Code_ (`MultiMLPClassCodeCircularToMunsell`): Retains the perfect classification code head while encoding only the within-family hue (0-10) as sin/cos, hypothesizing that the boundary discontinuity could be smoothed without losing code accuracy.
378
+
379
+ _Results_:
380
+
381
+ | Model | Delta-E | Hue MAE | Code Acc. |
382
+ | --------------------------------------------------- | ------- | ------- | --------- |
383
+ | Classification Code + Code-Aware Error Pred (Large) | 0.53 | 0.045 | 100.0% |
384
+ | Full Angle sin/cos (Approach A) | 1.78 | - | 96.3% |
385
+ | Within-Family sin/cos + Code (Approach B, Large) | 1.71 | 0.817 | 100.0% |
386
+ | Approach B + Code-Aware Error Predictor (Large) | 1.88 | 1.140 | 100.0% |
387
+
388
+ Both circular approaches performed significantly worse than scalar hue regression (Delta-E 1.71-1.88 vs 0.53).
389
+
390
+ _Key Takeaway_: **Hue within a Munsell family is not circular.** Within a single family, hue 0 and hue 10 are genuinely different points (opposite edges of the family). The sin/cos encoding wraps them as if they were the same, confusing the network and degrading hue prediction by ~18x (MAE 0.817 vs 0.045). The circularity only exists across the full 0-100 Munsell angle (where 100 = 0), but Approach A showed that encoding the full angle is also ineffective — the atan2 recovery introduces its own ambiguities across the 10 hue families. The boundary errors in the scalar regression model (max RGB diff 0.157 on 4 out of 2,734 swatches) are a better trade-off than the global accuracy degradation from circular encoding.
391
+
392
+ **Delta-E Training**
393
+
394
+ Training with differentiable Delta-E CIE2000 loss via round-trip through the Munsell-to-xyY approximator.
395
+
396
+ _Hypothesis_: Perceptual Delta-E loss might outperform MSE-trained models.
397
+
398
+ _Implementation_: JAX/Flax model with combined MSE + Delta-E loss. Requires lower learning rate (1e-4 vs 3e-4) for stability; higher rates cause NaN gradients.
399
+
400
+ _Results_: While Delta-E is comparable, **hue accuracy is ~10x worse**:
401
+
402
+ | Metric (Normalized MAE) | Delta-E Model | MSE Model |
403
+ | ------------------------ | ------------- | --------- |
404
+ | Hue MAE | 0.30 | 0.03 |
405
+ | Value MAE | 0.002 | 0.004 |
406
+ | Chroma MAE | 0.007 | 0.008 |
407
+ | Code MAE | 0.07 | 0.01 |
408
+ | **Delta-E (perceptual)** | **0.52** | **0.50** |
409
+
410
+ _Key Takeaway_: **Perceptual similarity != specification accuracy**. The MSE model's slightly better Delta-E (0.50 vs 0.52) comes at the cost of ~10x worse hue accuracy, making it unsuitable for specification prediction. Delta-E is too permissive for hue, allowing the model to find "shortcuts" that minimize perceptual difference without correctly predicting the Munsell specification.
411
+
412
+ **Classical Interpolation**
413
+
414
+ Classical interpolation methods were tested on 4,995 reference Munsell colors (80% train / 20% test split). ML evaluated on 2,734 REAL Munsell colors.
415
+
416
+ _Results (Validation MAE)_:
417
+
418
+ | Component | RBF | KD-Tree | Delaunay | ML (Best) |
419
+ | --------- | ---- | ------- | -------- | --------- |
420
+ | Hue | 1.40 | 1.40 | 1.29 | **0.03** |
421
+ | Value | 0.01 | 0.10 | 0.02 | 0.05 |
422
+ | Chroma | 0.22 | 0.99 | 0.35 | **0.11** |
423
+ | Code | 0.33 | 0.28 | 0.28 | **0.00** |
424
+
425
+ _Key Insight_: The reference dataset (4,995 colors) is too sparse for 3D xyY interpolation. Classical methods fail on hue prediction (MAE ~1.3-1.4), while ML achieves 47x better hue accuracy and 2-3x better chroma/code accuracy.
426
+
427
+ **Circular Hue Loss**
428
+
429
+ Circular distance metrics for hue prediction, accounting for cyclic nature (0-10 wraps).
430
+
431
+ _Results_: The circular loss model performed **21x worse** on hue MAE (5.14 vs 0.24). Combined with the circular encoding findings above, this provides strong evidence that within-family Munsell hue should be treated as a bounded linear quantity, not a circular one.
432
+
433
+ _Key Takeaway_: **Mathematical correctness != training effectiveness**. The circular distance creates gradient discontinuities that harm optimization.
434
+
435
+ **REAL-Only Refinement**
436
+
437
+ Fine-tuning using only REAL Munsell colors (2,734) instead of ALL colors (4,995).
438
+
439
+ _Results_: Essentially identical performance (Delta-E 1.5233 vs 1.5191).
440
+
441
+ _Key Takeaway_: **Data quality is not the bottleneck**. Both REAL and extrapolated colors are sufficiently accurate.
442
+
443
+ **Gamma Normalization**
444
+
445
+ Gamma correction to the Y (luminance) channel during normalization.
446
+
447
+ _Results_: No consistent improvement across gamma values 1.0-3.0:
448
+
449
+ | Gamma | Median ΔE (± std) |
450
+ | -------------- | ----------------- |
451
+ | 1.0 (baseline) | 0.730 ± 0.054 |
452
+ | 2.5 (best) | 0.683 ± 0.132 |
453
+
454
+ ![Gamma sweep results](_static/gamma_sweep_plot.png)
455
+
456
+ _Key Takeaway_: **Gamma normalization does not provide consistent improvement**. Standard deviations overlap - differences are within noise.
457
+
458
+ **Uniform Random Sampling**
459
+
460
+ The default training data generator perturbs around the 4,995 base colours from `MUNSELL_COLOURS_ALL` (hue prefixes 2.5, 5, 7.5, 10), creating islands of coverage with gaps between prefixes and zero samples below hue 1.0. A uniform random sampling approach was tested to fill these gaps.
461
+
462
+ _Implementation_: Sample hue uniformly in [0, 10], value in [1, 9], chroma in [0, 50], and code uniformly from {1, ..., 10}. Invalid specifications are discarded. 2M valid samples generated.
463
+
464
+ ![Training data density](training_data_large_hue_density.png)
465
+
466
+ _Results_:
467
+
468
+ | Model | Delta-E | Code Acc. | Hue MAE |
469
+ | --------------------------------------------- | ------- | --------- | ------- |
470
+ | Class. Code + Error Pred (perturbation, 500K) | 0.53 | 100.0% | 0.048 |
471
+ | Class. Code + Error Pred (uniform, 2M) | 2.23 | 82.2% | 1.186 |
472
+
473
+ _Key Takeaway_: **Uniform random sampling produces significantly worse models** (Delta-E 2.23 vs 0.53, code accuracy 82.2% vs 100%). The perturbation-based approach concentrates samples around actual Munsell colours where the interpolation function has structure, while uniform sampling wastes density on regions of the space where many hue/value/chroma combinations are out of gamut and the accepted samples are too sparse to learn the mapping accurately. The coverage gaps at hue boundaries in the perturbation-based data are a minor visual artefact (affecting ~4 out of 2,734 swatches) that does not justify the dramatic accuracy loss from uniform sampling.
474
+
475
+ ## Munsell to xyY (to_xyY)
476
+
477
+ ### Performance Benchmarks
478
+
479
+ 6 models compared using all 2,734 REAL Munsell colors:
480
+
481
+ | Model | Delta-E | Speed (ms) |
482
+ | ------------------------------------- | -------- | ---------- |
483
+ | Colour Library (Baseline) | 0.00 | 1.40 |
484
+ | **Multi-MLP + Multi-Error Predictor** | **0.48** | 0.066 |
485
+ | Multi-Head + Multi-Error Predictor | 0.51 | 0.054 |
486
+ | Simple MLP | 0.52 | 0.002 |
487
+ | Multi-MLP | 0.57 | 0.028 |
488
+ | Multi-Head | 0.60 | 0.013 |
489
+ | Multi-MLP + Error Predictor | 0.61 | 0.060 |
490
+
491
+ **Best Models**:
492
+
493
+ - **Best Accuracy**: Multi-MLP + Multi-Error Predictor — Delta-E 0.48, 21x faster
494
+ - **Fastest**: Simple MLP (0.002 ms/sample) — 669x faster than Colour library
495
+
496
+ ### Model Architectures
497
+
498
+ 6 models were evaluated for the Munsell to xyY direction:
499
+
500
+ **Single-Stage Models**
501
+
502
+ 1. **Simple MLP** - Basic MLP network, 4 inputs to 3 outputs
503
+ 2. **Multi-Head** - Shared encoder with 3 independent decoder heads (x, y, Y)
504
+ 3. **Multi-MLP** - 3 completely independent MLP branches
505
+
506
+ **Two-Stage Models**
507
+
508
+ 4. **Multi-MLP + Error Predictor** - Base Multi-MLP with unified error correction
509
+ 5. **Multi-MLP + Multi-Error Predictor** - 3 independent error predictors (BEST)
510
+ 6. **Multi-Head + Multi-Error Predictor** - Multi-Head with error correction
511
+
512
+ The **Multi-MLP + Multi-Error Predictor** architecture achieved the best results with Delta-E 0.48.
513
+
514
+ ### Differentiable Approximator
515
+
516
+ A small MLP (68K parameters) trained to approximate the Munsell to xyY conversion for use in differentiable Delta-E loss:
517
+
518
+ - **Architecture**: 4 -> 128 -> 256 -> 128 -> 3 with LayerNorm + SiLU
519
+ - **Accuracy**: MAE ~0.0006 for x, y, and Y components
520
+ - **Output formats**: PyTorch (.pth), ONNX, and JAX-compatible weights (.npz)
521
+
522
+ This enables differentiable Munsell to xyY conversion, which was previously only possible through non-differentiable lookup tables.
523
+
524
+ ## Shared Infrastructure
525
+
526
+ ### Hyperparameter Optimization
527
+
528
+ Optuna was used for systematic hyperparameter search over:
529
+
530
+ - Learning rate (1e-4 to 1e-3)
531
+ - Batch size (256, 512, 1024)
532
+ - Dropout rate (0.0 to 0.2)
533
+ - Chroma branch width multiplier (1.0 to 2.0)
534
+ - Loss function weights (MSE, Huber)
535
+
536
+ Key finding: **No dropout (0.0)** consistently performed better across all models in both conversion directions, contrary to typical deep learning recommendations for regularization.
537
+
538
+ ### Training Infrastructure
539
+
540
+ - **Optimizer**: AdamW with weight decay
541
+ - **Scheduler**: ReduceLROnPlateau (patience=10, factor=0.5)
542
+ - **Early stopping**: Patience=20 epochs
543
+ - **Checkpointing**: Best model saved based on validation loss
544
+ - **Logging**: MLflow for experiment tracking
545
+
546
+ ### JAX Delta-E Implementation
547
+
548
+ Located in `learning_munsell/losses/jax_delta_e.py`:
549
+
550
+ - Differentiable xyY -> XYZ -> Lab color space conversions
551
+ - Full CIE 2000 Delta-E implementation with gradient support
552
+ - JIT-compiled functions for performance
553
+
554
+ Usage:
555
+
556
+ ```python
557
+ from learning_munsell.losses import delta_E_loss, delta_E_CIE2000
558
+
559
+ # Compute perceptual loss between predicted and target xyY
560
+ loss = delta_E_loss(pred_xyY, target_xyY)
561
+ ```
562
+
563
+ ## Limitations
564
+
565
+ ### BatchNorm Instability on MPS
566
+
567
+ Models using `BatchNorm1d` layers exhibit numerical instability when trained on Apple Silicon GPUs via the MPS backend:
568
+
569
+ 1. **Validation loss spikes** during training
570
+ 2. **Occasional extreme outputs** during inference (e.g., 20M instead of ~0.1)
571
+ 3. **Non-reproducible behavior**
572
+
573
+ **Affected Models**: Large dataset error predictors using BatchNorm.
574
+
575
+ **Workarounds**:
576
+
577
+ 1. Use CPU for training
578
+ 2. Replace BatchNorm with LayerNorm
579
+ 3. Use smaller models (300K samples vs 2M)
580
+ 4. Skip error predictor stage for affected models
581
+
582
+ The recommended production model (`multi_mlp_multi_error_predictor.onnx`) does not exhibit this instability.
583
+
584
+ **References**:
585
+
586
+ - [BatchNorm non-trainable exception](https://github.com/pytorch/pytorch/issues/98602)
587
+ - [ONNX export incorrect on MPS](https://github.com/pytorch/pytorch/issues/83230)
588
+ - [MPS kernel bugs](https://elanapearl.github.io/blog/2025/the-bug-that-taught-me-pytorch/)
docs/training_data_large_hue_density.png ADDED

Git LFS Details

  • SHA256: 64aa9228d2ce0eb0db3697681b18c859236f82b5b7cb70ee8e45e0ff502e20d6
  • Pointer size: 130 Bytes
  • Size of remote file: 51.8 kB
learning_munsell/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ """Learning Munsell - Machine Learning for Munsell Color Conversions."""
2
+
3
+ from pathlib import Path
4
+
5
+ __all__ = ["PROJECT_ROOT"]
6
+
7
+ PROJECT_ROOT = Path(__file__).parent.parent
learning_munsell/analysis/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Analysis utilities for Munsell color conversion models."""
learning_munsell/analysis/error_analysis.py ADDED
@@ -0,0 +1,381 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Analyze error distribution to identify problematic regions in Munsell space.
3
+
4
+ This script:
5
+ 1. Runs the best model on all REAL Munsell colors
6
+ 2. Computes Delta-E for each sample
7
+ 3. Identifies samples with high error (Delta-E > threshold)
8
+ 4. Analyzes patterns: which hue families, value ranges, chroma ranges have issues
9
+ 5. Outputs statistics and visualizations
10
+ """
11
+
12
+ import logging
13
+ from collections import defaultdict
14
+
15
+ import numpy as np
16
+ import onnxruntime as ort
17
+ from colour import XYZ_to_Lab, xyY_to_XYZ
18
+ from colour.difference import delta_E_CIE2000
19
+ from colour.notation.datasets.munsell import MUNSELL_COLOURS_REAL
20
+ from colour.notation.munsell import (
21
+ CCS_ILLUMINANT_MUNSELL,
22
+ munsell_colour_to_munsell_specification,
23
+ munsell_specification_to_xyY,
24
+ )
25
+
26
+ from learning_munsell import PROJECT_ROOT
27
+
28
# Bare-message logging so the tabular report columns line up in the console.
logging.basicConfig(level=logging.INFO, format="%(message)s")
LOGGER = logging.getLogger(__name__)

# Munsell hue-family letter for each integer hue code.
# Both 10 and 0 map to "RP": a code of 0 is handled as an alias of code 10.
HUE_NAMES = {
    1: "R",
    2: "YR",
    3: "Y",
    4: "GY",
    5: "G",
    6: "BG",
    7: "B",
    8: "PB",
    9: "P",
    10: "RP",
    0: "RP",
}
44
+
45
+
46
def load_model_and_params(model_name: str) -> tuple:
    """Load an ONNX inference session plus its normalization parameters.

    Parameters
    ----------
    model_name
        File stem of the model under ``models/from_xyY``.

    Returns
    -------
    tuple
        ``(session, input_parameters, output_parameters)`` where the two
        parameter objects are the dicts stored inside the ``.npz`` archive.

    Raises
    ------
    FileNotFoundError
        If either the ``.onnx`` model or the ``.npz`` parameter file is
        missing.
    """
    base_dir = PROJECT_ROOT / "models" / "from_xyY"
    onnx_path = base_dir / f"{model_name}.onnx"
    npz_path = base_dir / f"{model_name}_normalization_parameters.npz"

    if not onnx_path.exists():
        raise FileNotFoundError(f"Model not found: {onnx_path}")
    if not npz_path.exists():
        raise FileNotFoundError(f"Params not found: {npz_path}")

    session = ort.InferenceSession(str(onnx_path))
    # ``allow_pickle`` is required: the archive stores Python dict objects.
    archive = np.load(npz_path, allow_pickle=True)
    return (
        session,
        archive["input_parameters"].item(),
        archive["output_parameters"].item(),
    )
66
+
67
+
68
def normalize_input(xyY: np.ndarray, params: dict) -> np.ndarray:
    """Min-max normalize xyY input using the stored training ranges.

    The Y channel is assumed to arrive on a 0-100 scale and is rescaled
    to 0-1 before its min-max normalization, mirroring how the training
    data was prepared. ``x``/``y`` are normalized directly against their
    stored ranges.
    """
    x_lo, x_hi = params["x_range"]
    y_lo, y_hi = params["y_range"]
    Y_lo, Y_hi = params["Y_range"]

    out = np.copy(xyY).astype(np.float32)
    # Bring Y into 0-1 first so the stored Y range applies to it.
    out[..., 2] = xyY[..., 2] / 100.0
    out[..., 0] = (xyY[..., 0] - x_lo) / (x_hi - x_lo)
    out[..., 1] = (xyY[..., 1] - y_lo) / (y_hi - y_lo)
    out[..., 2] = (out[..., 2] - Y_lo) / (Y_hi - Y_lo)
    return out
83
+
84
+
85
def denormalize_output(pred: np.ndarray, params: dict) -> np.ndarray:
    """Map normalized network outputs back to Munsell component ranges.

    Each of the four output channels (hue, value, chroma, code) is scaled
    from [0, 1] back to its stored ``(min, max)`` range.
    """
    result = np.copy(pred)
    range_keys = ("hue_range", "value_range", "chroma_range", "code_range")
    for channel, key in enumerate(range_keys):
        lo, hi = params[key][0], params[key][1]
        result[..., channel] = pred[..., channel] * (hi - lo) + lo
    return result
105
+
106
+
107
def compute_delta_e(pred_spec: np.ndarray, gt_xyY: np.ndarray) -> float:
    """CIE2000 Delta-E between a predicted spec (via xyY) and ground-truth xyY.

    The prediction is converted spec -> xyY -> XYZ -> Lab; the ground truth
    (whose Y is on the 0-100 scale) is rescaled to 0-1 before the same
    XYZ/Lab conversion. Returns ``numpy.nan`` when any conversion fails,
    e.g. when the predicted specification falls outside the Munsell gamut.
    """
    try:
        predicted_Lab = XYZ_to_Lab(
            xyY_to_XYZ(munsell_specification_to_xyY(pred_spec)),
            CCS_ILLUMINANT_MUNSELL,
        )

        # Reference Y arrives on a 0-100 scale; rescale to 0-1 to match.
        reference_xyY = gt_xyY.copy()
        reference_xyY[2] = gt_xyY[2] / 100.0
        reference_Lab = XYZ_to_Lab(
            xyY_to_XYZ(reference_xyY), CCS_ILLUMINANT_MUNSELL
        )

        return delta_E_CIE2000(reference_Lab, predicted_Lab)
    except Exception:  # noqa: BLE001
        return np.nan
123
+
124
+
125
def analyze_errors(
    model_name: str = "multi_head_large", threshold: float = 3.0
) -> list:
    """Analyze error distribution for a model.

    Runs the ONNX model over every colour in ``MUNSELL_COLOURS_REAL``,
    computes a CIE2000 Delta-E per sample, then logs aggregate statistics
    broken down by hue family, value range and chroma range, the 20 worst
    samples, and per-component errors for the high-error subset.

    Parameters
    ----------
    model_name
        File stem of the ONNX model under ``models/from_xyY``.
    threshold
        Delta-E above which a sample is counted as "high error".

    Returns
    -------
    list
        One dict per successfully evaluated sample with keys
        ``munsell_str``, ``gt_spec``, ``pred_spec``, ``delta_e``, ``hue``,
        ``value``, ``chroma``, ``code`` and ``gt_xyY``.
    """
    LOGGER.info("=" * 80)
    LOGGER.info("Error Analysis for %s", model_name)
    LOGGER.info("=" * 80)

    # Load model
    session, input_parameters, output_parameters = load_model_and_params(model_name)
    input_name = session.get_inputs()[0].name

    # Collect data
    results = []

    for munsell_spec_tuple, xyY_gt in MUNSELL_COLOURS_REAL:
        hue_code_str, value, chroma = munsell_spec_tuple
        munsell_str = f"{hue_code_str} {value}/{chroma}"

        try:
            gt_spec = munsell_colour_to_munsell_specification(munsell_str)
            gt_xyY = np.array(xyY_gt)

            # Predict
            xyY_norm = normalize_input(gt_xyY.reshape(1, 3), input_parameters)
            pred_norm = session.run(None, {input_name: xyY_norm})[0]
            pred_spec = denormalize_output(pred_norm, output_parameters)[0]

            # Clamp to valid ranges
            # NOTE(review): these bounds presumably mirror the training data
            # ranges — confirm against the data generator.
            pred_spec[0] = np.clip(pred_spec[0], 0.5, 10.0)
            pred_spec[1] = np.clip(pred_spec[1], 1.0, 9.0)
            pred_spec[2] = np.clip(pred_spec[2], 0.0, 50.0)
            pred_spec[3] = np.clip(pred_spec[3], 1.0, 10.0)
            # Hue code is a discrete family index; snap to nearest integer.
            pred_spec[3] = np.round(pred_spec[3])

            # Compute Delta-E
            delta_e = compute_delta_e(pred_spec, gt_xyY)

            # NaN means a conversion failed inside compute_delta_e; the
            # sample is silently excluded from the statistics.
            if not np.isnan(delta_e):
                results.append(
                    {
                        "munsell_str": munsell_str,
                        "gt_spec": gt_spec,
                        "pred_spec": pred_spec,
                        "delta_e": delta_e,
                        "hue": gt_spec[0],
                        "value": gt_spec[1],
                        "chroma": gt_spec[2],
                        "code": int(gt_spec[3]),
                        "gt_xyY": gt_xyY,
                    }
                )
        except Exception as e:  # noqa: BLE001
            LOGGER.warning("Failed for %s: %s", munsell_str, e)

    LOGGER.info("\nTotal samples evaluated: %d", len(results))

    # Overall statistics
    delta_es = [r["delta_e"] for r in results]
    LOGGER.info("\nOverall Delta-E Statistics:")
    LOGGER.info(" Mean: %.4f", np.mean(delta_es))
    LOGGER.info(" Median: %.4f", np.median(delta_es))
    LOGGER.info(" Std: %.4f", np.std(delta_es))
    LOGGER.info(" Min: %.4f", np.min(delta_es))
    LOGGER.info(" Max: %.4f", np.max(delta_es))

    # Distribution: cumulative counts under a set of fixed thresholds.
    LOGGER.info("\nDelta-E Distribution:")
    for thresh in [1.0, 2.0, 3.0, 5.0, 10.0]:
        count = sum(1 for d in delta_es if d <= thresh)
        pct = 100 * count / len(delta_es)
        LOGGER.info(" <= %.1f: %4d (%.1f%%)", thresh, count, pct)

    # High error samples
    high_error = [r for r in results if r["delta_e"] > threshold]
    LOGGER.info(
        "\nSamples with Delta-E > %.1f: %d (%.1f%%)",
        threshold,
        len(high_error),
        100 * len(high_error) / len(results),
    )

    # Analyze by hue family
    LOGGER.info("\n%s", "=" * 40)
    LOGGER.info("Analysis by Hue Family")
    LOGGER.info("=" * 40)

    # Group Delta-E values by hue-family letter (unknown codes become "?N").
    by_hue = defaultdict(list)
    for r in results:
        hue_name = HUE_NAMES.get(r["code"], f"?{r['code']}")
        by_hue[hue_name].append(r["delta_e"])

    LOGGER.info(
        "\n%-4s %5s %6s %6s %6s %s",
        "Hue",
        "Count",
        "Mean",
        "Median",
        "Max",
        ">3.0",
    )
    for hue_name in ["R", "YR", "Y", "GY", "G", "BG", "B", "PB", "P", "RP"]:
        if hue_name in by_hue:
            des = by_hue[hue_name]
            high = sum(1 for d in des if d > 3.0)
            LOGGER.info(
                "%-4s %5d %6.2f %6.2f %6.2f %d (%.0f%%)",
                hue_name,
                len(des),
                np.mean(des),
                np.median(des),
                np.max(des),
                high,
                100 * high / len(des),
            )

    # Analyze by value range
    LOGGER.info("\n%s", "=" * 40)
    LOGGER.info("Analysis by Value Range")
    LOGGER.info("=" * 40)

    # Half-open [min, max) bins over the Munsell value axis.
    value_ranges = [(1, 3), (3, 5), (5, 7), (7, 9)]
    LOGGER.info(
        "\n%-8s %5s %6s %6s %6s %s",
        "Value",
        "Count",
        "Mean",
        "Median",
        "Max",
        ">3.0",
    )
    for v_min, v_max in value_ranges:
        des = [r["delta_e"] for r in results if v_min <= r["value"] < v_max]
        if des:
            high = sum(1 for d in des if d > 3.0)
            LOGGER.info(
                "[%d-%d) %5d %6.2f %6.2f %6.2f %d (%.0f%%)",
                v_min,
                v_max,
                len(des),
                np.mean(des),
                np.median(des),
                np.max(des),
                high,
                # ``des`` is non-empty inside this branch, so the inline
                # guard is redundant but harmless.
                100 * high / len(des) if des else 0,
            )

    # Analyze by chroma range
    LOGGER.info("\n%s", "=" * 40)
    LOGGER.info("Analysis by Chroma Range")
    LOGGER.info("=" * 40)

    # Half-open [min, max) bins over the Munsell chroma axis.
    chroma_ranges = [(0, 4), (4, 8), (8, 12), (12, 20), (20, 50)]
    LOGGER.info(
        "\n%-8s %5s %6s %6s %6s %s",
        "Chroma",
        "Count",
        "Mean",
        "Median",
        "Max",
        ">3.0",
    )
    for c_min, c_max in chroma_ranges:
        des = [r["delta_e"] for r in results if c_min <= r["chroma"] < c_max]
        if des:
            high = sum(1 for d in des if d > 3.0)
            LOGGER.info(
                "[%2d-%2d) %5d %6.2f %6.2f %6.2f %d (%.0f%%)",
                c_min,
                c_max,
                len(des),
                np.mean(des),
                np.median(des),
                np.max(des),
                high,
                100 * high / len(des) if des else 0,
            )

    # Top 20 worst samples
    LOGGER.info("\n%s", "=" * 40)
    LOGGER.info("Top 20 Worst Samples")
    LOGGER.info("=" * 40)

    worst = sorted(results, key=lambda r: r["delta_e"], reverse=True)[:20]
    LOGGER.info(
        "\n%-15s %6s %-20s %-20s", "Munsell", "DeltaE", "GT Spec", "Pred Spec"
    )
    for r in worst:
        gs = r["gt_spec"]
        ps = r["pred_spec"]
        gt = f"[{gs[0]:.1f}, {gs[1]:.1f}, {gs[2]:.1f}, {int(gs[3])}]"
        pred = f"[{ps[0]:.1f}, {ps[1]:.1f}, {ps[2]:.1f}, {int(ps[3])}]"
        LOGGER.info(
            "%-15s %6.2f %-20s %-20s", r["munsell_str"], r["delta_e"], gt, pred
        )

    # Analyze component errors for high-error samples
    LOGGER.info("\n%s", "=" * 40)
    LOGGER.info("Component Errors for High-Error Samples (Delta-E > %.1f)", threshold)
    LOGGER.info("=" * 40)

    if high_error:
        # Absolute per-component prediction errors for the worst subset,
        # to show which component drives the high Delta-E values.
        hue_errors = [abs(r["pred_spec"][0] - r["gt_spec"][0]) for r in high_error]
        value_errors = [abs(r["pred_spec"][1] - r["gt_spec"][1]) for r in high_error]
        chroma_errors = [abs(r["pred_spec"][2] - r["gt_spec"][2]) for r in high_error]
        code_errors = [abs(r["pred_spec"][3] - r["gt_spec"][3]) for r in high_error]

        LOGGER.info("\n%-10s %6s %6s %6s", "Component", "Mean", "Median", "Max")
        LOGGER.info(
            "%-10s %6.2f %6.2f %6.2f",
            "Hue",
            np.mean(hue_errors),
            np.median(hue_errors),
            np.max(hue_errors),
        )
        LOGGER.info(
            "%-10s %6.2f %6.2f %6.2f",
            "Value",
            np.mean(value_errors),
            np.median(value_errors),
            np.max(value_errors),
        )
        LOGGER.info(
            "%-10s %6.2f %6.2f %6.2f",
            "Chroma",
            np.mean(chroma_errors),
            np.median(chroma_errors),
            np.max(chroma_errors),
        )
        LOGGER.info(
            "%-10s %6.2f %6.2f %6.2f",
            "Code",
            np.mean(code_errors),
            np.median(code_errors),
            np.max(code_errors),
        )

    return results
363
+
364
+
365
def main() -> None:
    """Run error analysis over the candidate models, skipping missing ones."""
    candidate_models = ("multi_head_large",)

    for candidate in candidate_models:
        try:
            analyze_errors(candidate, threshold=3.0)
        except FileNotFoundError as exc:
            # A missing model/params file is non-fatal: log and move on.
            LOGGER.warning("Skipping %s: %s", candidate, exc)
        LOGGER.info("\n")
378
+
379
+
380
+ if __name__ == "__main__":
381
+ main()
learning_munsell/analysis/hue_density_plot.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /// script
2
+ # dependencies = ["numpy", "matplotlib"]
3
+ # ///
4
+ """
5
+ Plot training data density in Munsell hue x value space.
6
+
7
+ Each cell shows the sample count for a (hue_bin, value_bin) combination,
8
+ aggregated across all codes and chromas. Hue on the x-axis (0-10),
9
+ value on the y-axis (1-9).
10
+
11
+ Usage:
12
+ uv run learning_munsell/analysis/hue_density_plot.py
13
+ uv run learning_munsell/analysis/hue_density_plot.py --data data/training_data.npz
14
+ """
15
+
16
+ import argparse
17
+
18
+ import matplotlib.colors as mcolors
19
+ import matplotlib.pyplot as plt
20
+ import numpy as np
21
+
22
+
23
def main(
    data_path: str = "data/training_data_large.npz",
    output: str = "docs/training_data_large_hue_density.png",
) -> None:
    """Plot a hue x value density heatmap of the training data.

    Parameters
    ----------
    data_path
        Path to the ``.npz`` training data file containing ``y_train``
        with columns [hue, value, chroma, code].
    output
        Path of the PNG file the heatmap is written to.
    """
    data = np.load(data_path)
    y_train = data["y_train"]

    hue = y_train[:, 0]
    value = y_train[:, 1]

    # 0.25-wide bins across the full [0, 10] range on both axes.
    hue_bins = np.linspace(0, 10, 41)
    value_bins = np.linspace(0, 10, 41)

    hist, _, _ = np.histogram2d(hue, value, bins=[hue_bins, value_bins])

    # Log scale to reveal structure across the full dynamic range.
    # Empty cells become NaN so they render as background, not as zero.
    hist_log = hist.copy()
    hist_log[hist_log == 0] = np.nan

    fig, ax = plt.subplots(figsize=(14, 5))
    im = ax.pcolormesh(
        hue_bins,
        value_bins,
        hist_log.T,
        cmap="inferno",
        shading="flat",
        # Guard vmax: LogNorm requires vmax > 0 even for an empty histogram.
        norm=mcolors.LogNorm(vmin=1, vmax=max(hist.max(), 1)),
    )
    ax.set_xlabel("Hue (0-10)")
    ax.set_ylabel("Value (0-10)")
    ax.set_title(
        f"Training Data Density in Munsell Space "
        f"({len(y_train):,} samples, all codes, log scale)\n"
        f"hue < 1: {(hue < 1).sum():,} "
        f"({(hue < 1).mean() * 100:.1f}%) | "
        f"hue < 2: {(hue < 2).sum():,} "
        f"({(hue < 2).mean() * 100:.1f}%)"
    )
    fig.colorbar(im, ax=ax, label="Sample count")

    fig.savefig(output, dpi=150, bbox_inches="tight")
    # Release the figure so repeated calls do not accumulate memory.
    plt.close(fig)
64
+
65
+
66
if __name__ == "__main__":
    # CLI entry point: optionally override the training-data location.
    argument_parser = argparse.ArgumentParser()
    argument_parser.add_argument(
        "--data",
        default="data/training_data_large.npz",
        help="Path to training data .npz file",
    )
    arguments = argument_parser.parse_args()
    main(arguments.data)
learning_munsell/comparison/from_xyY/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Comparison scripts for xyY to Munsell conversion models."""
learning_munsell/comparison/from_xyY/compare_all_models.py ADDED
@@ -0,0 +1,1594 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Compare all ML models for xyY to Munsell conversion on real Munsell data.
3
+
4
+ Models to compare:
5
+ 1. MLP (Base only)
6
+ 2. MLP + Error Predictor (Two-stage)
7
+ 3. Unified MLP
8
+ 4. MLP + Self-Attention
9
+ 5. MLP + Self-Attention + Error Predictor
10
+ 6. Deep + Wide
11
+ 7. Mixture of Experts
12
+ 8. FT-Transformer
13
+ """
14
+
15
+ import logging
16
+ import time
17
+ import warnings
18
+ from datetime import datetime
19
+ from pathlib import Path
20
+ from typing import Any
21
+
22
+ import numpy as np
23
+ import onnxruntime as ort
24
+ from colour import XYZ_to_Lab, xyY_to_XYZ
25
+ from colour.difference import delta_E_CIE2000
26
+ from colour.notation.datasets.munsell import MUNSELL_COLOURS_REAL
27
+ from colour.notation.munsell import (
28
+ CCS_ILLUMINANT_MUNSELL,
29
+ munsell_colour_to_munsell_specification,
30
+ munsell_specification_to_xyY,
31
+ xyY_to_munsell_specification,
32
+ )
33
+ from numpy.typing import NDArray
34
+
35
+ from learning_munsell import PROJECT_ROOT
36
+ from learning_munsell.utilities.common import (
37
+ benchmark_inference_speed,
38
+ get_model_size_mb,
39
+ )
40
+
41
+ logging.basicConfig(level=logging.INFO, format="%(message)s")
42
+ LOGGER = logging.getLogger(__name__)
43
+
44
+
45
+ def normalize_input(X: NDArray, params: dict[str, Any] | None) -> NDArray:
46
+ """Normalize xyY input.
47
+
48
+ If params is None, xyY is assumed to already be in [0, 1] range (no normalization needed).
49
+ """
50
+ if params is None:
51
+ # xyY is already in [0, 1] range - no normalization needed
52
+ return X.astype(np.float32)
53
+
54
+ X_norm = np.copy(X)
55
+ X_norm[..., 0] = (X[..., 0] - params["x_range"][0]) / (
56
+ params["x_range"][1] - params["x_range"][0]
57
+ )
58
+ X_norm[..., 1] = (X[..., 1] - params["y_range"][0]) / (
59
+ params["y_range"][1] - params["y_range"][0]
60
+ )
61
+ X_norm[..., 2] = (X[..., 2] - params["Y_range"][0]) / (
62
+ params["Y_range"][1] - params["Y_range"][0]
63
+ )
64
+ return X_norm.astype(np.float32)
65
+
66
+
67
def denormalize_output(y_norm: NDArray, params: dict[str, Any]) -> NDArray:
    """Map normalized Munsell outputs back to their original component scales."""
    y = np.copy(y_norm)
    # Each output channel is inverse min-max scaled with its training range.
    component_keys = ("hue_range", "value_range", "chroma_range", "code_range")
    for channel, key in enumerate(component_keys):
        low, high = params[key]
        y[..., channel] = y_norm[..., channel] * (high - low) + low
    return y
+
88
+
89
def clamp_munsell_specification(specification: NDArray) -> NDArray:
    """Clamp Munsell specification to valid ranges."""
    clamped = np.copy(specification)

    # (low, high) bounds per component.
    bounds = (
        (0.0, 10.0),  # Hue: [0, 10]
        (1.0, 9.0),  # Value: [1, 9] (colour library constraint)
        (0.0, 50.0),  # Chroma: [0, 50]
        (1.0, 10.0),  # Code: [1, 10]
    )
    for channel, (low, high) in enumerate(bounds):
        clamped[..., channel] = np.clip(specification[..., channel], low, high)

    return clamped
101
+
102
+
103
def decode_class_code_output(
    pred_norm: NDArray, output_parameters: dict[str, Any]
) -> NDArray:
    """Decode output from MultiMLPClassCodeToMunsell.

    Input shape: (N, 13) — [hue_norm, value_norm, chroma_norm, logit_0..logit_9]
    Output shape: (N, 4) — [hue, value, chroma, code] in original scale.
    """

    def _denormalize(channel: int, key: str) -> NDArray:
        # Inverse min-max scaling for one regression channel.
        low, high = output_parameters[key]
        return pred_norm[..., channel] * (high - low) + low

    hue = _denormalize(0, "hue_range")
    value = _denormalize(1, "value_range")
    chroma = _denormalize(2, "chroma_range")

    # Code is a 10-way classification head; class indices are 1-based.
    code = np.argmax(pred_norm[..., 3:], axis=-1).astype(np.float64) + 1.0

    return np.stack([hue, value, chroma, code], axis=-1)
131
+
132
+
133
def decode_hue_angle_output(
    pred: NDArray, output_parameters: dict[str, Any]
) -> NDArray:
    """Decode output from MultiMLPHueAngleToMunsell.

    Input shape: (N, 4) — [sin_angle, cos_angle, value_norm, chroma_norm]
    Output shape: (N, 4) — [hue, value, chroma, code] in original scale.
    """
    # Recover the hue angle in [0, 100) from its sin/cos encoding.
    angle = np.arctan2(pred[..., 0], pred[..., 1]) * 100.0 / (2.0 * np.pi)
    angle = angle % 100.0

    # Split the angle into an integer code (1-10) and a hue offset (0-10).
    code = np.clip(
        np.floor(angle / 10.0).astype(np.float64) + 1.0, 1.0, 10.0
    )
    hue = angle - (code - 1.0) * 10.0

    # Inverse min-max scale the remaining regression channels.
    value_low, value_high = output_parameters["value_range"]
    chroma_low, chroma_high = output_parameters["chroma_range"]
    value = pred[..., 2] * (value_high - value_low) + value_low
    chroma = pred[..., 3] * (chroma_high - chroma_low) + chroma_low

    return np.stack([hue, value, chroma, code], axis=-1)
166
+
167
+
168
def evaluate_model(
    session: ort.InferenceSession,
    X_norm: NDArray,
    ground_truth: NDArray,
    params: dict[str, Any],
    input_name: str = "xyY",
    reference_Lab: NDArray | None = None,
) -> dict[str, Any]:
    """Evaluate a single model."""
    predictions_norm = session.run(None, {input_name: X_norm})[0]
    predictions = denormalize_output(predictions_norm, params)
    absolute_errors = np.abs(predictions - ground_truth)

    # Per-component MAE plus the raw per-sample error vectors.
    result = {
        "hue_mae": np.mean(absolute_errors[:, 0]),
        "value_mae": np.mean(absolute_errors[:, 1]),
        "chroma_mae": np.mean(absolute_errors[:, 2]),
        "code_mae": np.mean(absolute_errors[:, 3]),
        "max_errors": np.max(absolute_errors, axis=1),
        "hue_errors": absolute_errors[:, 0],
        "value_errors": absolute_errors[:, 1],
        "chroma_errors": absolute_errors[:, 2],
        "code_errors": absolute_errors[:, 3],
    }

    if reference_Lab is None:
        result["delta_E"] = np.nan
        return result

    # Delta-E against ground truth: round-trip each prediction through
    # Munsell spec → xyY → XYZ → Lab and compare with CIE 2000.
    delta_E_values = []
    for index in range(len(predictions)):
        try:
            specification = clamp_munsell_specification(predictions[index])

            # The round-trip conversion expects an integer code.
            specification_for_conversion = specification.copy()
            specification_for_conversion[3] = round(specification[3])

            predicted_xyY = munsell_specification_to_xyY(
                specification_for_conversion
            )
            predicted_Lab = XYZ_to_Lab(
                xyY_to_XYZ(predicted_xyY), CCS_ILLUMINANT_MUNSELL
            )

            delta_E_values.append(
                delta_E_CIE2000(reference_Lab[index], predicted_Lab)
            )
        except (RuntimeError, ValueError):
            # Skip samples whose round-trip conversion fails.
            continue

    result["delta_E"] = np.mean(delta_E_values) if delta_E_values else np.nan
    return result
220
+
221
+
222
def generate_html_report(
    results: dict[str, dict[str, Any]],
    num_samples: int,
    output_file: Path,
    baseline_inference_time_ms: float | None = None,
) -> None:
    """Generate HTML report with visualizations.

    Parameters
    ----------
    results
        Mapping of model name to its metrics: the keys produced by
        ``evaluate_model`` plus ``model_size_mb`` and ``inference_time_ms``.
    num_samples
        Number of test samples, shown in the report header.
    output_file
        Path the HTML report is written to.
    baseline_inference_time_ms
        Per-sample time of the Colour Library iterative baseline, used for
        the "vs Baseline" speed multipliers.  When ``None`` the multipliers
        are rendered as 0.00x instead of raising (previously a TypeError).
    """
    # Average MAE across the four Munsell components, per model.
    avg_maes = {}
    for model_name, result in results.items():
        avg_maes[model_name] = np.mean(
            [
                result["hue_mae"],
                result["value_mae"],
                result["chroma_mae"],
                result["code_mae"],
            ]
        )

    # Sort by average MAE (best first).
    sorted_models = sorted(avg_maes.items(), key=lambda x: x[1])

    # Precision thresholds for the accuracy table.
    thresholds = [1e-4, 1e-3, 1e-2, 1e-1, 1.0]

    html = f"""<!DOCTYPE html>
<html lang="en" class="dark">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>ML Model Comparison Report - {datetime.now().strftime("%Y-%m-%d %H:%M")}</title>
    <script src="https://cdn.tailwindcss.com"></script>
    <script>
        tailwind.config = {{
            darkMode: 'class',
            theme: {{
                extend: {{
                    colors: {{
                        border: "hsl(240 3.7% 15.9%)",
                        input: "hsl(240 3.7% 15.9%)",
                        ring: "hsl(240 4.9% 83.9%)",
                        background: "hsl(240 10% 3.9%)",
                        foreground: "hsl(0 0% 98%)",
                        primary: {{
                            DEFAULT: "hsl(263 70% 60%)",
                            foreground: "hsl(0 0% 98%)",
                        }},
                        secondary: {{
                            DEFAULT: "hsl(240 3.7% 15.9%)",
                            foreground: "hsl(0 0% 98%)",
                        }},
                        muted: {{
                            DEFAULT: "hsl(240 3.7% 15.9%)",
                            foreground: "hsl(240 5% 64.9%)",
                        }},
                        accent: {{
                            DEFAULT: "hsl(240 3.7% 15.9%)",
                            foreground: "hsl(0 0% 98%)",
                        }},
                        card: {{
                            DEFAULT: "hsl(240 10% 6%)",
                            foreground: "hsl(0 0% 98%)",
                        }},
                    }}
                }}
            }}
        }}
    </script>
    <style>
        .gradient-primary {{
            background: linear-gradient(135deg, hsl(263 70% 50%) 0%, hsl(280 70% 45%) 100%);
        }}
        .bar-fill {{
            background: linear-gradient(90deg, hsl(263 70% 60%) 0%, hsl(280 70% 55%) 100%);
            transition: width 0.5s cubic-bezier(0.4, 0, 0.2, 1);
        }}
    </style>
</head>
<body class="bg-background text-foreground antialiased">
    <div class="max-w-7xl mx-auto p-6 space-y-6">
        <!-- Header -->
        <div class="gradient-primary rounded-lg p-8 shadow-2xl border border-primary/20">
            <h1 class="text-4xl font-bold text-white mb-2">ML Model Comparison Report</h1>
            <div class="text-white/90 space-y-1">
                <p class="text-lg">xyY to Munsell Specification Conversion</p>
                <p class="text-sm">Generated: {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}</p>
                <p class="text-sm">Test Samples: <span class="font-semibold">{num_samples:,}</span> real Munsell colors</p>
            </div>
        </div>
"""

    # Best models per metric (summary cards at the top of the report).
    delta_E_values = [
        r["delta_E"] for r in results.values() if not np.isnan(r["delta_E"])
    ]

    best_delta_E = (
        min(
            results.items(),
            key=lambda x: x[1]["delta_E"]
            if not np.isnan(x[1]["delta_E"])
            else float("inf"),
        )[0]
        if delta_E_values
        else None
    )
    best_avg = sorted_models[0][0]

    best_size = min(results.items(), key=lambda x: x[1]["model_size_mb"])[0]
    best_speed = min(results.items(), key=lambda x: x[1]["inference_time_ms"])[0]

    # Fix: when no model has a finite Delta-E, ``best_delta_E`` is None and
    # must not be used to index ``results``; render a dash instead.
    if best_delta_E is not None:
        best_delta_E_value = f"{results[best_delta_E]['delta_E']:.4f}"
        best_delta_E_label = best_delta_E
    else:
        best_delta_E_value = "—"
        best_delta_E_label = "—"

    html += f"""
        <!-- Best Models Summary -->
        <div class="bg-card rounded-lg border border-border p-6 shadow-lg">
            <h2 class="text-2xl font-semibold mb-6 pb-3 border-b border-primary/30">Best Models by Metric</h2>
            <div class="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-4">
                <div class="bg-gradient-to-br from-primary/10 to-primary/5 rounded-lg p-5 border border-primary/20">
                    <div class="text-xs font-semibold text-muted-foreground uppercase tracking-wide mb-2">Smallest Size</div>
                    <div class="text-3xl font-bold text-primary mb-3">{results[best_size]["model_size_mb"]:.2f} MB</div>
                    <div class="text-sm text-foreground/80">{best_size}</div>
                </div>
                <div class="bg-gradient-to-br from-primary/10 to-primary/5 rounded-lg p-5 border border-primary/20">
                    <div class="text-xs font-semibold text-muted-foreground uppercase tracking-wide mb-2">Fastest Speed</div>
                    <div class="text-3xl font-bold text-primary mb-3">{results[best_speed]["inference_time_ms"]:.4f} ms</div>
                    <div class="text-sm text-foreground/80">{best_speed}</div>
                </div>
                <div class="bg-gradient-to-br from-primary/10 to-primary/5 rounded-lg p-5 border border-primary/20">
                    <div class="text-xs font-semibold text-muted-foreground uppercase tracking-wide mb-2">Best Delta-E</div>
                    <div class="text-3xl font-bold text-primary mb-3">{best_delta_E_value}</div>
                    <div class="text-sm text-foreground/80">{best_delta_E_label}</div>
                </div>
                <div class="bg-gradient-to-br from-primary/10 to-primary/5 rounded-lg p-5 border border-primary/20">
                    <div class="text-xs font-semibold text-muted-foreground uppercase tracking-wide mb-2">Best Average MAE</div>
                    <div class="text-3xl font-bold text-primary mb-3">{avg_maes[best_avg]:.4f}</div>
                    <div class="text-sm text-foreground/80">{best_avg}</div>
                </div>
            </div>
        </div>
"""

    # Fix: the default ``None`` baseline previously crashed the division
    # below; fall back to 0.0 so every multiplier computes to 0.
    baseline_speed = (
        baseline_inference_time_ms
        if baseline_inference_time_ms is not None
        else 0.0
    )

    # Sort by Delta-E for the performance table (NaN sorts last).
    sorted_by_delta_E = sorted(
        results.items(),
        key=lambda x: x[1]["delta_E"]
        if not np.isnan(x[1]["delta_E"])
        else float("inf"),
    )

    # Find the fastest model (highest speed multiplier) for highlighting.
    max_speed_multiplier = 0.0
    best_multiplier_model = None
    for model_name, result in results.items():
        speed_ms = result["inference_time_ms"]
        if speed_ms > 0:
            speed_multiplier = baseline_speed / speed_ms
            if speed_multiplier > max_speed_multiplier:
                max_speed_multiplier = speed_multiplier
                best_multiplier_model = model_name

    html += """
        <!-- Performance Metrics Table -->
        <div class="bg-card rounded-lg border border-border p-6 shadow-lg">
            <h2 class="text-2xl font-semibold mb-6 pb-3 border-b border-primary/30">Model Performance Metrics</h2>
            <div class="overflow-x-auto">
                <table class="w-full text-sm">
                    <thead>
                        <tr class="border-b border-border">
                            <th class="text-left py-3 px-4 font-semibold text-muted-foreground">Model</th>
                            <th class="text-right py-3 px-4 font-semibold text-muted-foreground">
                                Size (MB)
                                <div class="text-xs font-normal text-muted-foreground/70 mt-1">ONNX files</div>
                            </th>
                            <th class="text-right py-3 px-4 font-semibold text-muted-foreground">
                                Speed (ms/sample)
                                <div class="text-xs font-normal text-muted-foreground/70 mt-1">10 iterations</div>
                            </th>
                            <th class="text-right py-3 px-4 font-semibold text-muted-foreground">
                                vs Baseline
                                <div class="text-xs font-normal text-muted-foreground/70 mt-1">Colour Iterative</div>
                            </th>
                            <th class="text-right py-3 px-4 font-semibold text-muted-foreground">
                                Delta-E
                                <div class="text-xs font-normal text-muted-foreground/70 mt-1">vs Colour Lib</div>
                            </th>
                            <th class="text-right py-3 px-4 font-semibold text-muted-foreground">Average MAE</th>
                        </tr>
                    </thead>
                    <tbody>
"""

    for model_name, result in sorted_by_delta_E:
        size_mb = result["model_size_mb"]
        speed_ms = result["inference_time_ms"]
        avg_mae = avg_maes[model_name]
        delta_E = result["delta_E"]

        # Relative speed: how many times faster than the baseline.
        speed_multiplier = baseline_speed / speed_ms if speed_ms > 0 else 0

        size_class = "text-primary font-semibold" if model_name == best_size else ""
        speed_class = "text-primary font-semibold" if model_name == best_speed else ""
        avg_class = "text-primary font-semibold" if model_name == best_avg else ""
        delta_E_class = (
            "text-primary font-semibold" if model_name == best_delta_E else ""
        )

        # Format Delta-E value; dash when it could not be computed.
        delta_E_str = f"{delta_E:.4f}" if not np.isnan(delta_E) else "—"

        # Highlight only the fastest model.
        if abs(speed_multiplier - 1.0) < 0.01:
            # Baseline itself.
            multiplier_class = "text-muted-foreground"
            multiplier_text = "1.0x"
        elif model_name == best_multiplier_model:
            # Fastest model (highest multiplier).
            multiplier_class = "text-primary font-semibold"
            if speed_multiplier > 1000:
                multiplier_text = f"{speed_multiplier:.0f}x"
            elif speed_multiplier > 100:
                multiplier_text = f"{speed_multiplier:.1f}x"
            else:
                multiplier_text = f"{speed_multiplier:.2f}x"
        elif speed_multiplier > 1.0:
            # Faster than baseline but not the fastest.
            multiplier_class = ""
            if speed_multiplier > 1000:
                multiplier_text = f"{speed_multiplier:.0f}x"
            elif speed_multiplier > 100:
                multiplier_text = f"{speed_multiplier:.1f}x"
            else:
                multiplier_text = f"{speed_multiplier:.2f}x"
        else:
            # Slower than baseline.
            multiplier_class = "text-destructive"
            multiplier_text = f"{speed_multiplier:.2f}x"

        html += f"""
                        <tr class="border-b border-border/50 hover:bg-muted/30 transition-colors">
                            <td class="py-3 px-4 font-medium">{model_name}</td>
                            <td class="py-3 px-4 text-right {size_class}">{size_mb:.2f}</td>
                            <td class="py-3 px-4 text-right {speed_class}">{speed_ms:.4f}</td>
                            <td class="py-3 px-4 text-right {multiplier_class}">{multiplier_text}</td>
                            <td class="py-3 px-4 text-right {delta_E_class}">{delta_E_str}</td>
                            <td class="py-3 px-4 text-right {avg_class}">{avg_mae:.4f}</td>
                        </tr>
"""

    html += """
                    </tbody>
                </table>
            </div>
            <div class="mt-6 p-4 bg-muted/30 rounded-md border border-primary/20">
                <div class="text-sm space-y-2">
                    <div><span class="text-primary font-semibold">Note:</span> Speed measured with 10 iterations (3 warmup + 10 benchmark) on 2,734 samples.</div>
                    <div class="text-xs text-muted-foreground">Two-stage models include both base and error predictor. Highlighted values show best in each metric.</div>
                    <div class="text-xs text-muted-foreground">Baseline comparison: Speed multipliers show relative performance vs Colour Library's iterative xyY_to_munsell_specification(). Values &lt;1.0x are faster.</div>
                </div>
            </div>
        </div>
"""

    # Overall ranking bar chart, ordered by Delta-E (lower is better).
    html += """
        <!-- Overall Ranking -->
        <div class="bg-card rounded-lg border border-border p-6 shadow-lg">
            <h2 class="text-2xl font-semibold mb-4 pb-2 border-b border-primary/30">Overall Ranking (by Delta-E)</h2>
            <div class="space-y-1">
"""

    sorted_by_delta_E_ranking = sorted(
        [
            (name, res["delta_E"])
            for name, res in results.items()
            if not np.isnan(res["delta_E"])
        ],
        key=lambda x: x[1],
    )

    max_delta_E = (
        max(delta_E for _, delta_E in sorted_by_delta_E_ranking)
        if sorted_by_delta_E_ranking
        else 1.0
    )
    for rank, (model_name, delta_E) in enumerate(sorted_by_delta_E_ranking, 1):
        width_pct = (delta_E / max_delta_E) * 100
        html += f"""
                <div class="flex items-center gap-3 p-2 rounded-md hover:bg-muted/50 transition-colors">
                    <div class="flex-none w-80 text-sm font-medium">
                        <span class="text-muted-foreground">{rank}.</span> {model_name}
                    </div>
                    <div class="flex-1 h-6 bg-muted rounded-md overflow-hidden">
                        <div class="bar-fill h-full rounded-md" style="width: {width_pct}%"></div>
                    </div>
                    <div class="flex-none w-20 text-right font-bold text-primary">{delta_E:.4f}</div>
                </div>
"""

    html += """
            </div>
        </div>
"""

    # Accuracy at precision thresholds table.
    html += """
        <div class="bg-card rounded-lg border border-border p-6 shadow-lg">
            <h2 class="text-2xl font-semibold mb-3 pb-3 border-b border-primary/30">Accuracy at Precision Thresholds</h2>
            <p class="text-sm text-muted-foreground mb-6">Percentage of predictions where max error across all components is below threshold:</p>
            <div class="overflow-x-auto">
                <table class="w-full text-sm">
                    <thead>
                        <tr class="border-b border-border">
                            <th class="text-left py-3 px-4 font-semibold text-muted-foreground">Model</th>
"""

    for threshold in thresholds:
        html += f'                            <th class="text-right py-3 px-4 font-semibold text-muted-foreground">&lt; {threshold:.0e}</th>\n'

    html += """
                        </tr>
                    </thead>
                    <tbody>
"""

    # Best (highest) and worst accuracy per threshold, to decide highlighting.
    best_accuracies = {}
    min_accuracies = {}
    for threshold in thresholds:
        accuracies = [
            np.mean(results[model_name]["max_errors"] < threshold) * 100
            for model_name, _ in sorted_models
        ]
        best_accuracies[threshold] = max(accuracies)
        min_accuracies[threshold] = min(accuracies)

    for model_name, _ in sorted_models:
        result = results[model_name]
        row_class = (
            "bg-primary/10 border-l-2 border-l-primary"
            if model_name == best_avg
            else ""
        )
        html += f"""
                        <tr class="border-b border-border hover:bg-muted/30 transition-colors {row_class}">
                            <td class="text-left py-3 px-4 font-medium">{model_name}</td>
"""
        for threshold in thresholds:
            accuracy_pct = np.mean(result["max_errors"] < threshold) * 100
            # Only highlight when there is meaningful spread (>0.1%)
            # between the best and worst model at this threshold.
            has_variation = (
                best_accuracies[threshold] - min_accuracies[threshold]
            ) > 0.1
            is_best = abs(accuracy_pct - best_accuracies[threshold]) < 0.01
            cell_class = (
                "text-right py-3 px-4 font-bold text-primary"
                if (has_variation and is_best)
                else "text-right py-3 px-4"
            )
            html += f'                            <td class="{cell_class}">{accuracy_pct:.2f}%</td>\n'

        html += """
                        </tr>
"""

    html += """
                    </tbody>
                </table>
            </div>
        </div>

    </div>
</body>
</html>
"""

    # Write HTML file.
    with open(output_file, "w") as f:
        f.write(html)

    LOGGER.info("")
    LOGGER.info("HTML report saved to: %s", output_file)
612
+
613
+
614
+ def main() -> None:
615
+ """Compare all models."""
616
+ LOGGER.info("=" * 80)
617
+ LOGGER.info("Comprehensive Model Comparison")
618
+ LOGGER.info("=" * 80)
619
+
620
+ # Paths
621
+ model_directory = PROJECT_ROOT / "models" / "from_xyY"
622
+
623
+ # Load real Munsell dataset
624
+ LOGGER.info("")
625
+ LOGGER.info("Loading real Munsell dataset...")
626
+ xyY_samples = []
627
+ ground_truth = []
628
+
629
+ for munsell_spec_tuple, xyY in MUNSELL_COLOURS_REAL:
630
+ try:
631
+ hue_code, value, chroma = munsell_spec_tuple
632
+ munsell_str = f"{hue_code} {value}/{chroma}"
633
+ spec = munsell_colour_to_munsell_specification(munsell_str)
634
+ xyY_scaled = np.array([xyY[0], xyY[1], xyY[2] / 100.0])
635
+ xyY_samples.append(xyY_scaled)
636
+ ground_truth.append(spec)
637
+ except Exception: # noqa: BLE001, S112
638
+ continue
639
+
640
+ xyY_samples = np.array(xyY_samples)
641
+ ground_truth = np.array(ground_truth)
642
+ LOGGER.info("Loaded %d valid Munsell colors", len(xyY_samples))
643
+
644
+ # Define models to compare
645
+ models = [
646
+ {
647
+ "name": "MLP",
648
+ "files": [model_directory / "mlp.onnx"],
649
+ "params_file": model_directory / "mlp_normalization_parameters.npz",
650
+ "type": "single",
651
+ },
652
+ {
653
+ "name": "MLP + Error Predictor",
654
+ "files": [
655
+ model_directory / "mlp.onnx",
656
+ model_directory / "mlp_error_predictor.onnx",
657
+ ],
658
+ "params_file": model_directory / "mlp_normalization_parameters.npz",
659
+ "type": "two_stage",
660
+ },
661
+ {
662
+ "name": "Unified MLP",
663
+ "files": [model_directory / "unified_mlp.onnx"],
664
+ "params_file": model_directory / "unified_mlp_normalization_parameters.npz",
665
+ "type": "single",
666
+ },
667
+ {
668
+ "name": "MLP + Self-Attention",
669
+ "files": [model_directory / "mlp_attention.onnx"],
670
+ "params_file": model_directory
671
+ / "mlp_attention_normalization_parameters.npz",
672
+ "type": "single",
673
+ },
674
+ {
675
+ "name": "Deep + Wide",
676
+ "files": [model_directory / "deep_wide.onnx"],
677
+ "params_file": model_directory / "deep_wide_normalization_parameters.npz",
678
+ "type": "single",
679
+ },
680
+ {
681
+ "name": "Mixture of Experts",
682
+ "files": [model_directory / "mixture_of_experts.onnx"],
683
+ "params_file": model_directory
684
+ / "mixture_of_experts_normalization_parameters.npz",
685
+ "type": "single",
686
+ },
687
+ {
688
+ "name": "FT-Transformer",
689
+ "files": [model_directory / "ft_transformer.onnx"],
690
+ "params_file": model_directory
691
+ / "ft_transformer_normalization_parameters.npz",
692
+ "type": "single",
693
+ },
694
+ {
695
+ "name": "Multi-Head",
696
+ "files": [model_directory / "multi_head.onnx"],
697
+ "params_file": model_directory / "multi_head_normalization_parameters.npz",
698
+ "type": "single",
699
+ },
700
+ {
701
+ "name": "Multi-Head + Multi-Error Predictor",
702
+ "files": [
703
+ model_directory / "multi_head.onnx",
704
+ model_directory / "multi_head_multi_error_predictor.onnx",
705
+ ],
706
+ "params_file": model_directory / "multi_head_normalization_parameters.npz",
707
+ "type": "two_stage",
708
+ },
709
+ {
710
+ "name": "Multi-Head + Cross-Attention Error Predictor",
711
+ "files": [
712
+ model_directory / "multi_head.onnx",
713
+ model_directory / "multi_head_cross_attention_error_predictor.onnx",
714
+ ],
715
+ "params_file": model_directory / "multi_head_normalization_parameters.npz",
716
+ "type": "two_stage",
717
+ },
718
+ {
719
+ "name": "Multi-MLP",
720
+ "files": [model_directory / "multi_mlp.onnx"],
721
+ "params_file": model_directory / "multi_mlp_normalization_parameters.npz",
722
+ "type": "single",
723
+ },
724
+ {
725
+ "name": "Multi-MLP + Multi-Error Predictor",
726
+ "files": [
727
+ model_directory / "multi_mlp.onnx",
728
+ model_directory / "multi_mlp_multi_error_predictor.onnx",
729
+ ],
730
+ "params_file": model_directory / "multi_mlp_normalization_parameters.npz",
731
+ "type": "two_stage",
732
+ },
733
+ {
734
+ "name": "Multi-Head (Circular Loss)",
735
+ "files": [model_directory / "multi_head_circular.onnx"],
736
+ "params_file": model_directory
737
+ / "multi_head_circular_normalization_parameters.npz",
738
+ "type": "single",
739
+ },
740
+ {
741
+ "name": "Multi-Head (Weighted + Boundary Loss)",
742
+ "files": [model_directory / "multi_head_weighted_boundary.onnx"],
743
+ "params_file": model_directory
744
+ / "multi_head_weighted_boundary_normalization_parameters.npz",
745
+ "type": "single",
746
+ },
747
+ {
748
+ "name": "Multi-Head (Weighted + Boundary Loss) + Multi-Error Predictor",
749
+ "files": [
750
+ model_directory / "multi_head_weighted_boundary.onnx",
751
+ model_directory
752
+ / "multi_head_weighted_boundary_multi_error_predictor_weighted_boundary.onnx",
753
+ ],
754
+ "params_file": model_directory
755
+ / "multi_head_weighted_boundary_normalization_parameters.npz",
756
+ "type": "two_stage",
757
+ },
758
+ {
759
+ "name": "Multi-MLP (Weighted + Boundary Loss)",
760
+ "files": [model_directory / "multi_mlp_weighted_boundary.onnx"],
761
+ "params_file": model_directory
762
+ / "multi_mlp_weighted_boundary_normalization_parameters.npz",
763
+ "type": "single",
764
+ },
765
+ {
766
+ "name": "Multi-MLP (Weighted + Boundary Loss) + Multi-Error Predictor",
767
+ "files": [
768
+ model_directory / "multi_mlp_weighted_boundary.onnx",
769
+ model_directory
770
+ / "multi_mlp_weighted_boundary_multi_error_predictor.onnx",
771
+ ],
772
+ "params_file": model_directory
773
+ / "multi_mlp_weighted_boundary_normalization_parameters.npz",
774
+ "type": "two_stage",
775
+ },
776
+ {
777
+ "name": "Multi-Head (Large Dataset)",
778
+ "files": [model_directory / "multi_head_large.onnx"],
779
+ "params_file": model_directory
780
+ / "multi_head_large_normalization_parameters.npz",
781
+ "type": "single",
782
+ },
783
+ {
784
+ "name": "Multi-Head + Multi-Error Predictor (Large Dataset)",
785
+ "files": [
786
+ model_directory / "multi_head_large.onnx",
787
+ model_directory / "multi_head_multi_error_predictor_large.onnx",
788
+ ],
789
+ "params_file": model_directory
790
+ / "multi_head_large_normalization_parameters.npz",
791
+ "type": "two_stage",
792
+ },
793
+ {
794
+ "name": "Multi-MLP (Large Dataset)",
795
+ "files": [model_directory / "multi_mlp_large.onnx"],
796
+ "params_file": model_directory
797
+ / "multi_mlp_large_normalization_parameters.npz",
798
+ "type": "single",
799
+ },
800
+ {
801
+ "name": "Multi-MLP + Multi-Error Predictor (Large Dataset)",
802
+ "files": [
803
+ model_directory / "multi_mlp_large.onnx",
804
+ model_directory / "multi_mlp_multi_error_predictor_large.onnx",
805
+ ],
806
+ "params_file": model_directory
807
+ / "multi_mlp_large_normalization_parameters.npz",
808
+ "type": "two_stage",
809
+ },
810
+ {
811
+ "name": "Multi-ResNet (Large Dataset)",
812
+ "files": [model_directory / "multi_resnet_large.onnx"],
813
+ "params_file": model_directory
814
+ / "multi_resnet_large_normalization_parameters.npz",
815
+ "type": "single",
816
+ },
817
+ {
818
+ "name": "Multi-ResNet + Multi-Error Predictor (Large Dataset)",
819
+ "files": [
820
+ model_directory / "multi_resnet_large.onnx",
821
+ model_directory / "multi_resnet_error_predictor_large.onnx",
822
+ ],
823
+ "params_file": model_directory
824
+ / "multi_resnet_large_normalization_parameters.npz",
825
+ "type": "two_stage",
826
+ },
827
+ {
828
+ "name": "Transformer (Large Dataset)",
829
+ "files": [model_directory / "transformer_large.onnx"],
830
+ "params_file": model_directory
831
+ / "transformer_large_normalization_parameters.npz",
832
+ "type": "single",
833
+ },
834
+ {
835
+ "name": "Transformer + Error Predictor (Large Dataset)",
836
+ "files": [
837
+ model_directory / "transformer_large.onnx",
838
+ model_directory / "transformer_multi_error_predictor_large.onnx",
839
+ ],
840
+ "params_file": model_directory
841
+ / "transformer_large_normalization_parameters.npz",
842
+ "type": "two_stage",
843
+ },
844
+ {
845
+ "name": "Multi-Head Refined (REAL Only)",
846
+ "files": [model_directory / "multi_head_refined_real.onnx"],
847
+ "params_file": model_directory
848
+ / "multi_head_refined_real_normalization_parameters.npz",
849
+ "type": "single",
850
+ },
851
+ {
852
+ "name": "Multi-Head + Multi-Error Predictor + 3-Stage Error Predictor",
853
+ "files": [
854
+ model_directory / "multi_head_large.onnx",
855
+ model_directory / "multi_head_multi_error_predictor_large.onnx",
856
+ model_directory / "multi_head_3stage_error_predictor.onnx",
857
+ ],
858
+ "params_file": model_directory
859
+ / "multi_head_large_normalization_parameters.npz",
860
+ "type": "three_stage",
861
+ },
862
+ {
863
+ "name": "Multi-MLP (Classification Code)",
864
+ "files": [model_directory / "multi_mlp_class_code.onnx"],
865
+ "params_file": model_directory
866
+ / "multi_mlp_class_code_normalization_parameters.npz",
867
+ "type": "class_code",
868
+ },
869
+ {
870
+ "name": "Multi-MLP (Classification Code) + Multi-Error Predictor",
871
+ "files": [
872
+ model_directory / "multi_mlp_class_code.onnx",
873
+ model_directory / "multi_mlp_class_code_multi_error_predictor.onnx",
874
+ ],
875
+ "params_file": model_directory
876
+ / "multi_mlp_class_code_normalization_parameters.npz",
877
+ "type": "class_code_multi_error_predictor",
878
+ },
879
+ {
880
+ "name": "Multi-MLP (Hue Angle sin/cos)",
881
+ "files": [model_directory / "multi_mlp_hue_angle.onnx"],
882
+ "params_file": model_directory
883
+ / "multi_mlp_hue_angle_normalization_parameters.npz",
884
+ "type": "hue_angle",
885
+ },
886
+ {
887
+ "name": "Multi-MLP (Class. Code) + Code-Aware Multi-Error Predictor",
888
+ "files": [
889
+ model_directory / "multi_mlp_class_code.onnx",
890
+ model_directory
891
+ / "multi_mlp_class_code_aware_multi_error_predictor.onnx",
892
+ ],
893
+ "params_file": model_directory
894
+ / "multi_mlp_class_code_normalization_parameters.npz",
895
+ "type": "class_code_aware_multi_error_predictor",
896
+ },
897
+ {
898
+ "name": "Multi-MLP (Classification Code) (Large Dataset)",
899
+ "files": [model_directory / "multi_mlp_class_code_large.onnx"],
900
+ "params_file": model_directory
901
+ / "multi_mlp_class_code_large_normalization_parameters.npz",
902
+ "type": "class_code",
903
+ },
904
+ {
905
+ "name": "Multi-MLP (Class. Code) + Code-Aware Multi-Error Predictor (Large Dataset)",
906
+ "files": [
907
+ model_directory / "multi_mlp_class_code_large.onnx",
908
+ model_directory
909
+ / "multi_mlp_class_code_aware_multi_error_predictor_large.onnx",
910
+ ],
911
+ "params_file": model_directory
912
+ / "multi_mlp_class_code_large_normalization_parameters.npz",
913
+ "type": "class_code_aware_multi_error_predictor",
914
+ },
915
+ {
916
+ "name": "Multi-MLP (Classification Code) (Uniform Dataset)",
917
+ "files": [model_directory / "multi_mlp_class_code_uniform.onnx"],
918
+ "params_file": model_directory
919
+ / "multi_mlp_class_code_uniform_normalization_parameters.npz",
920
+ "type": "class_code",
921
+ },
922
+ {
923
+ "name": "Multi-MLP (Class. Code) + Code-Aware Multi-Error Predictor (Uniform Dataset)",
924
+ "files": [
925
+ model_directory / "multi_mlp_class_code_uniform.onnx",
926
+ model_directory
927
+ / "multi_mlp_class_code_aware_multi_error_predictor_uniform.onnx",
928
+ ],
929
+ "params_file": model_directory
930
+ / "multi_mlp_class_code_uniform_normalization_parameters.npz",
931
+ "type": "class_code_aware_multi_error_predictor",
932
+ },
933
+ ]
934
+
935
+ # Benchmark colour library's iterative implementation first
936
+ LOGGER.info("")
937
+ LOGGER.info("=" * 80)
938
+ LOGGER.info("Colour Library (Iterative)")
939
+ LOGGER.info("=" * 80)
940
+
941
+ # Benchmark the iterative xyY_to_munsell_specification function
942
+ # Note: Using full dataset (100% of samples)
943
+
944
+ # Set random seed for reproducibility
945
+ np.random.seed(42)
946
+
947
+ # Use 100% of samples for comprehensive benchmarking
948
+ len(xyY_samples)
949
+ sampled_indices = np.arange(len(xyY_samples))
950
+ xyY_benchmark_samples = xyY_samples[sampled_indices]
951
+
952
+ # Measure inference time on sampled Munsell colors
953
+ start_time = time.perf_counter()
954
+ convergence_failures = 0
955
+ successful_inferences = 0
956
+
957
+ with warnings.catch_warnings():
958
+ warnings.simplefilter("ignore")
959
+ for xyy in xyY_benchmark_samples:
960
+ try:
961
+ xyY_to_munsell_specification(xyy)
962
+ successful_inferences += 1
963
+ except (RuntimeError, ValueError):
964
+ # Out-of-gamut color that doesn't converge or not in renotation system
965
+ convergence_failures += 1
966
+
967
+ end_time = time.perf_counter()
968
+
969
+ # Calculate average time per successful inference (in milliseconds)
970
+ total_time_s = end_time - start_time
971
+ colour_inference_time_ms = (
972
+ (total_time_s / successful_inferences) * 1000
973
+ if successful_inferences > 0
974
+ else 0
975
+ )
976
+
977
+ LOGGER.info("")
978
+ LOGGER.info("Performance Metrics:")
979
+ LOGGER.info(" Successful inferences: %d", successful_inferences)
980
+ LOGGER.info(" Convergence failures: %d", convergence_failures)
981
+ LOGGER.info(" Inference Speed: %.4f ms/sample", colour_inference_time_ms)
982
+ LOGGER.info(" Note: This is the baseline iterative implementation")
983
+
984
+ # Store the baseline speed
985
+ baseline_inference_time_ms = colour_inference_time_ms
986
+
987
+ # Convert ground truth Munsell specs to CIE Lab for Delta-E comparison
988
+ # Path: Munsell spec → xyY → XYZ → Lab
989
+ LOGGER.info("")
990
+ LOGGER.info("Converting ground truth to CIE Lab for Delta-E comparison...")
991
+ LOGGER.info(" Path: Munsell spec \u2192 xyY \u2192 XYZ \u2192 Lab")
992
+ reference_Lab = []
993
+ for spec in ground_truth:
994
+ try:
995
+ # Munsell specification → xyY
996
+ xyy = munsell_specification_to_xyY(spec)
997
+ # xyY → XYZ
998
+ XYZ = xyY_to_XYZ(xyy)
999
+ # XYZ → Lab (Illuminant C for Munsell)
1000
+ Lab = XYZ_to_Lab(XYZ, CCS_ILLUMINANT_MUNSELL)
1001
+ reference_Lab.append(Lab)
1002
+ except (RuntimeError, ValueError):
1003
+ # If conversion fails, use NaN
1004
+ reference_Lab.append(np.array([np.nan, np.nan, np.nan]))
1005
+
1006
+ reference_Lab = np.array(reference_Lab)
1007
+ LOGGER.info(
1008
+ " Converted %d ground truth specs to CIE Lab",
1009
+ len(reference_Lab),
1010
+ )
1011
+
1012
+ # Use the same sampled subset for ML model evaluations (for fair comparison)
1013
+ xyY_samples = xyY_benchmark_samples
1014
+ ground_truth = ground_truth[sampled_indices]
1015
+
1016
+ # Evaluate each model
1017
+ results = {}
1018
+
1019
+ for model_info in models:
1020
+ model_name = model_info["name"]
1021
+ LOGGER.info("")
1022
+ LOGGER.info("=" * 80)
1023
+ LOGGER.info(model_name)
1024
+ LOGGER.info("=" * 80)
1025
+
1026
+ # Load normalization params for this model
1027
+ params = np.load(model_info["params_file"], allow_pickle=True)
1028
+ # input_parameters may not exist if xyY is already in [0, 1] range
1029
+ input_parameters = (
1030
+ params["input_parameters"].item()
1031
+ if "input_parameters" in params.files
1032
+ else None
1033
+ )
1034
+ output_parameters = params["output_parameters"].item()
1035
+
1036
+ # Normalize input with this model's params (None means no normalization)
1037
+ X_norm = normalize_input(xyY_samples, input_parameters)
1038
+
1039
+ # Calculate model size
1040
+ model_size_mb = get_model_size_mb(model_info["files"])
1041
+
1042
+ if model_info["type"] == "two_stage":
1043
+ # Two-stage model
1044
+ base_session = ort.InferenceSession(str(model_info["files"][0]))
1045
+ error_session = ort.InferenceSession(str(model_info["files"][1]))
1046
+
1047
+ # Define inference callable for benchmarking
1048
+ def two_stage_inference(
1049
+ _base_session: ort.InferenceSession = base_session,
1050
+ _error_session: ort.InferenceSession = error_session,
1051
+ _X_norm: NDArray = X_norm,
1052
+ ) -> NDArray:
1053
+ base_pred = _base_session.run(None, {"xyY": _X_norm})[0]
1054
+ combined = np.concatenate([_X_norm, base_pred], axis=1).astype(
1055
+ np.float32
1056
+ )
1057
+ error_corr = _error_session.run(None, {"combined_input": combined})[0]
1058
+ return base_pred + error_corr
1059
+
1060
+ # Benchmark speed
1061
+ inference_time_ms = benchmark_inference_speed(two_stage_inference, X_norm)
1062
+
1063
+ # Get predictions
1064
+ base_pred_norm = base_session.run(None, {"xyY": X_norm})[0]
1065
+ combined_input = np.concatenate([X_norm, base_pred_norm], axis=1).astype(
1066
+ np.float32
1067
+ )
1068
+ error_correction_norm = error_session.run(
1069
+ None, {"combined_input": combined_input}
1070
+ )[0]
1071
+ final_pred_norm = base_pred_norm + error_correction_norm
1072
+ pred = denormalize_output(final_pred_norm, output_parameters)
1073
+ errors = np.abs(pred - ground_truth)
1074
+
1075
+ result = {
1076
+ "hue_mae": np.mean(errors[:, 0]),
1077
+ "value_mae": np.mean(errors[:, 1]),
1078
+ "chroma_mae": np.mean(errors[:, 2]),
1079
+ "code_mae": np.mean(errors[:, 3]),
1080
+ "max_errors": np.max(errors, axis=1),
1081
+ "hue_errors": errors[:, 0],
1082
+ "value_errors": errors[:, 1],
1083
+ "chroma_errors": errors[:, 2],
1084
+ "code_errors": errors[:, 3],
1085
+ "model_size_mb": model_size_mb,
1086
+ "inference_time_ms": inference_time_ms,
1087
+ }
1088
+
1089
+ # Compute Delta-E against ground truth
1090
+ delta_E_values = []
1091
+ for idx in range(len(pred)):
1092
+ try:
1093
+ # Convert ML prediction to Lab: Munsell spec → xyY → XYZ → Lab
1094
+ ml_spec = clamp_munsell_specification(pred[idx])
1095
+
1096
+ # Round Code to nearest integer before round-trip conversion
1097
+ ml_spec_for_conversion = ml_spec.copy()
1098
+ ml_spec_for_conversion[3] = round(ml_spec[3])
1099
+
1100
+ ml_xyy = munsell_specification_to_xyY(ml_spec_for_conversion)
1101
+ ml_XYZ = xyY_to_XYZ(ml_xyy)
1102
+ ml_Lab = XYZ_to_Lab(ml_XYZ, CCS_ILLUMINANT_MUNSELL)
1103
+
1104
+ # Get ground truth Lab
1105
+ reference_Lab_sample = reference_Lab[idx]
1106
+
1107
+ # Compute Delta-E CIE2000
1108
+ delta_E = delta_E_CIE2000(reference_Lab_sample, ml_Lab)
1109
+ delta_E_values.append(delta_E)
1110
+ except (RuntimeError, ValueError):
1111
+ # Skip if conversion fails
1112
+ continue
1113
+
1114
+ result["delta_E"] = np.mean(delta_E_values) if delta_E_values else np.nan
1115
+ elif model_info["type"] == "three_stage":
1116
+ # Three-stage model: base + error predictor 1 + error predictor 2
1117
+ base_session = ort.InferenceSession(str(model_info["files"][0]))
1118
+ error1_session = ort.InferenceSession(str(model_info["files"][1]))
1119
+ error2_session = ort.InferenceSession(str(model_info["files"][2]))
1120
+
1121
+ # Define inference callable for benchmarking
1122
+ def three_stage_inference(
1123
+ _base_session: ort.InferenceSession = base_session,
1124
+ _error1_session: ort.InferenceSession = error1_session,
1125
+ _error2_session: ort.InferenceSession = error2_session,
1126
+ _X_norm: NDArray = X_norm,
1127
+ ) -> NDArray:
1128
+ # Stage 1: Base model
1129
+ base_pred = _base_session.run(None, {"xyY": _X_norm})[0]
1130
+ # Stage 2: First error correction
1131
+ combined1 = np.concatenate([_X_norm, base_pred], axis=1).astype(
1132
+ np.float32
1133
+ )
1134
+ error1_corr = _error1_session.run(None, {"combined_input": combined1})[
1135
+ 0
1136
+ ]
1137
+ stage2_pred = base_pred + error1_corr
1138
+ # Stage 3: Second error correction
1139
+ combined2 = np.concatenate([_X_norm, stage2_pred], axis=1).astype(
1140
+ np.float32
1141
+ )
1142
+ error2_corr = _error2_session.run(None, {"combined_input": combined2})[
1143
+ 0
1144
+ ]
1145
+ return stage2_pred + error2_corr
1146
+
1147
+ # Benchmark speed
1148
+ inference_time_ms = benchmark_inference_speed(three_stage_inference, X_norm)
1149
+
1150
+ # Get predictions
1151
+ base_pred_norm = base_session.run(None, {"xyY": X_norm})[0]
1152
+ combined1 = np.concatenate([X_norm, base_pred_norm], axis=1).astype(
1153
+ np.float32
1154
+ )
1155
+ error1_corr_norm = error1_session.run(None, {"combined_input": combined1})[
1156
+ 0
1157
+ ]
1158
+ stage2_pred_norm = base_pred_norm + error1_corr_norm
1159
+ combined2 = np.concatenate([X_norm, stage2_pred_norm], axis=1).astype(
1160
+ np.float32
1161
+ )
1162
+ error2_corr_norm = error2_session.run(None, {"combined_input": combined2})[
1163
+ 0
1164
+ ]
1165
+ final_pred_norm = stage2_pred_norm + error2_corr_norm
1166
+ pred = denormalize_output(final_pred_norm, output_parameters)
1167
+ errors = np.abs(pred - ground_truth)
1168
+
1169
+ result = {
1170
+ "hue_mae": np.mean(errors[:, 0]),
1171
+ "value_mae": np.mean(errors[:, 1]),
1172
+ "chroma_mae": np.mean(errors[:, 2]),
1173
+ "code_mae": np.mean(errors[:, 3]),
1174
+ "max_errors": np.max(errors, axis=1),
1175
+ "hue_errors": errors[:, 0],
1176
+ "value_errors": errors[:, 1],
1177
+ "chroma_errors": errors[:, 2],
1178
+ "code_errors": errors[:, 3],
1179
+ "model_size_mb": model_size_mb,
1180
+ "inference_time_ms": inference_time_ms,
1181
+ }
1182
+
1183
+ # Compute Delta-E against ground truth for three-stage model
1184
+ delta_E_values = []
1185
+ for idx in range(len(pred)):
1186
+ try:
1187
+ ml_spec = clamp_munsell_specification(pred[idx])
1188
+ ml_spec_for_conversion = ml_spec.copy()
1189
+ ml_spec_for_conversion[3] = round(ml_spec[3])
1190
+ ml_xyy = munsell_specification_to_xyY(ml_spec_for_conversion)
1191
+ ml_XYZ = xyY_to_XYZ(ml_xyy)
1192
+ ml_Lab = XYZ_to_Lab(ml_XYZ, CCS_ILLUMINANT_MUNSELL)
1193
+ delta_E = delta_E_CIE2000(reference_Lab[idx], ml_Lab)
1194
+ delta_E_values.append(delta_E)
1195
+ except (RuntimeError, ValueError):
1196
+ continue
1197
+
1198
+ result["delta_E"] = np.mean(delta_E_values) if delta_E_values else np.nan
1199
+ elif model_info["type"] == "class_code":
1200
+ # Classification code model — output shape (N, 13)
1201
+ session = ort.InferenceSession(str(model_info["files"][0]))
1202
+
1203
+ def class_code_inference(
1204
+ _session: ort.InferenceSession = session, _X_norm: NDArray = X_norm
1205
+ ) -> NDArray:
1206
+ return _session.run(None, {"xyY": _X_norm})[0]
1207
+
1208
+ inference_time_ms = benchmark_inference_speed(class_code_inference, X_norm)
1209
+
1210
+ pred_raw = session.run(None, {"xyY": X_norm})[0]
1211
+ pred = decode_class_code_output(pred_raw, output_parameters)
1212
+ pred = clamp_munsell_specification(pred)
1213
+ errors = np.abs(pred - ground_truth)
1214
+
1215
+ result = {
1216
+ "hue_mae": np.mean(errors[:, 0]),
1217
+ "value_mae": np.mean(errors[:, 1]),
1218
+ "chroma_mae": np.mean(errors[:, 2]),
1219
+ "code_mae": np.mean(errors[:, 3]),
1220
+ "max_errors": np.max(errors, axis=1),
1221
+ "hue_errors": errors[:, 0],
1222
+ "value_errors": errors[:, 1],
1223
+ "chroma_errors": errors[:, 2],
1224
+ "code_errors": errors[:, 3],
1225
+ "model_size_mb": model_size_mb,
1226
+ "inference_time_ms": inference_time_ms,
1227
+ }
1228
+
1229
+ # Compute Delta-E
1230
+ delta_E_values = []
1231
+ for idx in range(len(pred)):
1232
+ try:
1233
+ ml_spec = clamp_munsell_specification(pred[idx])
1234
+ ml_spec_for_conversion = ml_spec.copy()
1235
+ ml_spec_for_conversion[3] = round(ml_spec[3])
1236
+ ml_xyy = munsell_specification_to_xyY(ml_spec_for_conversion)
1237
+ ml_XYZ = xyY_to_XYZ(ml_xyy)
1238
+ ml_Lab = XYZ_to_Lab(ml_XYZ, CCS_ILLUMINANT_MUNSELL)
1239
+ delta_E = delta_E_CIE2000(reference_Lab[idx], ml_Lab)
1240
+ delta_E_values.append(delta_E)
1241
+ except (RuntimeError, ValueError):
1242
+ continue
1243
+
1244
+ result["delta_E"] = np.mean(delta_E_values) if delta_E_values else np.nan
1245
+
1246
+ elif model_info["type"] == "class_code_multi_error_predictor":
1247
+ # Classification code base + multi-error predictor (regression only)
1248
+ base_session = ort.InferenceSession(str(model_info["files"][0]))
1249
+ error_session = ort.InferenceSession(str(model_info["files"][1]))
1250
+
1251
+ def class_code_mep_inference(
1252
+ _base: ort.InferenceSession = base_session,
1253
+ _error: ort.InferenceSession = error_session,
1254
+ _X_norm: NDArray = X_norm,
1255
+ ) -> NDArray:
1256
+ base_raw = _base.run(None, {"xyY": _X_norm})[0]
1257
+ base_reg = base_raw[:, :3]
1258
+ combined = np.concatenate([_X_norm, base_reg], axis=1).astype(
1259
+ np.float32
1260
+ )
1261
+ error_correction = _error.run(None, {"combined_input": combined})[0]
1262
+ corrected_reg = base_reg + error_correction
1263
+ return np.concatenate([corrected_reg, base_raw[:, 3:]], axis=1)
1264
+
1265
+ inference_time_ms = benchmark_inference_speed(
1266
+ class_code_mep_inference, X_norm
1267
+ )
1268
+
1269
+ pred_raw = class_code_mep_inference()
1270
+ pred = decode_class_code_output(pred_raw, output_parameters)
1271
+ pred = clamp_munsell_specification(pred)
1272
+ errors = np.abs(pred - ground_truth)
1273
+
1274
+ result = {
1275
+ "hue_mae": np.mean(errors[:, 0]),
1276
+ "value_mae": np.mean(errors[:, 1]),
1277
+ "chroma_mae": np.mean(errors[:, 2]),
1278
+ "code_mae": np.mean(errors[:, 3]),
1279
+ "max_errors": np.max(errors, axis=1),
1280
+ "hue_errors": errors[:, 0],
1281
+ "value_errors": errors[:, 1],
1282
+ "chroma_errors": errors[:, 2],
1283
+ "code_errors": errors[:, 3],
1284
+ "model_size_mb": model_size_mb,
1285
+ "inference_time_ms": inference_time_ms,
1286
+ }
1287
+
1288
+ # Compute Delta-E
1289
+ delta_E_values = []
1290
+ for idx in range(len(pred)):
1291
+ try:
1292
+ ml_spec = clamp_munsell_specification(pred[idx])
1293
+ ml_spec_for_conversion = ml_spec.copy()
1294
+ ml_spec_for_conversion[3] = round(ml_spec[3])
1295
+ ml_xyy = munsell_specification_to_xyY(ml_spec_for_conversion)
1296
+ ml_XYZ = xyY_to_XYZ(ml_xyy)
1297
+ ml_Lab = XYZ_to_Lab(ml_XYZ, CCS_ILLUMINANT_MUNSELL)
1298
+ delta_E = delta_E_CIE2000(reference_Lab[idx], ml_Lab)
1299
+ delta_E_values.append(delta_E)
1300
+ except (RuntimeError, ValueError):
1301
+ continue
1302
+
1303
+ result["delta_E"] = np.mean(delta_E_values) if delta_E_values else np.nan
1304
+
1305
+ elif model_info["type"] == "class_code_aware_multi_error_predictor":
1306
+ # Classification code base + code-aware multi-error predictor
1307
+ # Error predictor input: [xyY_norm(3) + regression_norm(3) + code_onehot(10)] = 16
1308
+ base_session = ort.InferenceSession(str(model_info["files"][0]))
1309
+ error_session = ort.InferenceSession(str(model_info["files"][1]))
1310
+
1311
+ def class_code_aware_mep_inference(
1312
+ _base: ort.InferenceSession = base_session,
1313
+ _error: ort.InferenceSession = error_session,
1314
+ _X_norm: NDArray = X_norm,
1315
+ ) -> NDArray:
1316
+ base_raw = _base.run(None, {"xyY": _X_norm})[0]
1317
+ base_reg = base_raw[:, :3]
1318
+ code_logits = base_raw[:, 3:]
1319
+ code_idx = np.argmax(code_logits, axis=-1)
1320
+ code_onehot = np.zeros((len(code_idx), 10), dtype=np.float32)
1321
+ code_onehot[np.arange(len(code_idx)), code_idx] = 1.0
1322
+ combined = np.concatenate(
1323
+ [_X_norm, base_reg, code_onehot], axis=1
1324
+ ).astype(np.float32)
1325
+ error_correction = _error.run(None, {"combined_input": combined})[0]
1326
+ corrected_reg = base_reg + error_correction
1327
+ return np.concatenate([corrected_reg, code_logits], axis=1)
1328
+
1329
+ inference_time_ms = benchmark_inference_speed(
1330
+ class_code_aware_mep_inference, X_norm
1331
+ )
1332
+
1333
+ pred_raw = class_code_aware_mep_inference()
1334
+ pred = decode_class_code_output(pred_raw, output_parameters)
1335
+ pred = clamp_munsell_specification(pred)
1336
+ errors = np.abs(pred - ground_truth)
1337
+
1338
+ result = {
1339
+ "hue_mae": np.mean(errors[:, 0]),
1340
+ "value_mae": np.mean(errors[:, 1]),
1341
+ "chroma_mae": np.mean(errors[:, 2]),
1342
+ "code_mae": np.mean(errors[:, 3]),
1343
+ "max_errors": np.max(errors, axis=1),
1344
+ "hue_errors": errors[:, 0],
1345
+ "value_errors": errors[:, 1],
1346
+ "chroma_errors": errors[:, 2],
1347
+ "code_errors": errors[:, 3],
1348
+ "model_size_mb": model_size_mb,
1349
+ "inference_time_ms": inference_time_ms,
1350
+ }
1351
+
1352
+ # Compute Delta-E
1353
+ delta_E_values = []
1354
+ for idx in range(len(pred)):
1355
+ try:
1356
+ ml_spec = clamp_munsell_specification(pred[idx])
1357
+ ml_spec_for_conversion = ml_spec.copy()
1358
+ ml_spec_for_conversion[3] = round(ml_spec[3])
1359
+ ml_xyy = munsell_specification_to_xyY(ml_spec_for_conversion)
1360
+ ml_XYZ = xyY_to_XYZ(ml_xyy)
1361
+ ml_Lab = XYZ_to_Lab(ml_XYZ, CCS_ILLUMINANT_MUNSELL)
1362
+ delta_E = delta_E_CIE2000(reference_Lab[idx], ml_Lab)
1363
+ delta_E_values.append(delta_E)
1364
+ except (RuntimeError, ValueError):
1365
+ continue
1366
+
1367
+ result["delta_E"] = np.mean(delta_E_values) if delta_E_values else np.nan
1368
+
1369
+ elif model_info["type"] == "hue_angle":
1370
+ # Hue angle sin/cos model — output shape (N, 4)
1371
+ session = ort.InferenceSession(str(model_info["files"][0]))
1372
+
1373
+ def hue_angle_inference(
1374
+ _session: ort.InferenceSession = session, _X_norm: NDArray = X_norm
1375
+ ) -> NDArray:
1376
+ return _session.run(None, {"xyY": _X_norm})[0]
1377
+
1378
+ inference_time_ms = benchmark_inference_speed(hue_angle_inference, X_norm)
1379
+
1380
+ pred_raw = session.run(None, {"xyY": X_norm})[0]
1381
+ pred = decode_hue_angle_output(pred_raw, output_parameters)
1382
+ pred = clamp_munsell_specification(pred)
1383
+ errors = np.abs(pred - ground_truth)
1384
+
1385
+ result = {
1386
+ "hue_mae": np.mean(errors[:, 0]),
1387
+ "value_mae": np.mean(errors[:, 1]),
1388
+ "chroma_mae": np.mean(errors[:, 2]),
1389
+ "code_mae": np.mean(errors[:, 3]),
1390
+ "max_errors": np.max(errors, axis=1),
1391
+ "hue_errors": errors[:, 0],
1392
+ "value_errors": errors[:, 1],
1393
+ "chroma_errors": errors[:, 2],
1394
+ "code_errors": errors[:, 3],
1395
+ "model_size_mb": model_size_mb,
1396
+ "inference_time_ms": inference_time_ms,
1397
+ }
1398
+
1399
+ # Compute Delta-E
1400
+ delta_E_values = []
1401
+ for idx in range(len(pred)):
1402
+ try:
1403
+ ml_spec = clamp_munsell_specification(pred[idx])
1404
+ ml_spec_for_conversion = ml_spec.copy()
1405
+ ml_spec_for_conversion[3] = round(ml_spec[3])
1406
+ ml_xyy = munsell_specification_to_xyY(ml_spec_for_conversion)
1407
+ ml_XYZ = xyY_to_XYZ(ml_xyy)
1408
+ ml_Lab = XYZ_to_Lab(ml_XYZ, CCS_ILLUMINANT_MUNSELL)
1409
+ delta_E = delta_E_CIE2000(reference_Lab[idx], ml_Lab)
1410
+ delta_E_values.append(delta_E)
1411
+ except (RuntimeError, ValueError):
1412
+ continue
1413
+
1414
+ result["delta_E"] = np.mean(delta_E_values) if delta_E_values else np.nan
1415
+
1416
+ else:
1417
+ # Single model
1418
+ session = ort.InferenceSession(str(model_info["files"][0]))
1419
+
1420
+ # Define inference callable for benchmarking
1421
+ def single_inference(
1422
+ _session: ort.InferenceSession = session, _X_norm: NDArray = X_norm
1423
+ ) -> NDArray:
1424
+ return _session.run(None, {"xyY": _X_norm})[0]
1425
+
1426
+ # Benchmark speed
1427
+ inference_time_ms = benchmark_inference_speed(single_inference, X_norm)
1428
+
1429
+ result = evaluate_model(
1430
+ session,
1431
+ X_norm,
1432
+ ground_truth,
1433
+ output_parameters,
1434
+ reference_Lab=reference_Lab,
1435
+ )
1436
+ result["model_size_mb"] = model_size_mb
1437
+ result["inference_time_ms"] = inference_time_ms
1438
+
1439
+ results[model_name] = result
1440
+
1441
+ # Compute code accuracy (exact match after rounding)
1442
+ pred_for_code = result.get("_pred") # set below for custom types
1443
+ if pred_for_code is None:
1444
+ # For standard models, recompute pred from raw errors + ground truth
1445
+ # Code accuracy: count where rounded code matches ground truth code
1446
+ code_errors = result["code_errors"]
1447
+ code_correct = np.sum(code_errors < 0.5)
1448
+ code_accuracy = code_correct / len(code_errors) * 100.0
1449
+ else:
1450
+ code_correct = np.sum(
1451
+ np.abs(pred_for_code[:, 3] - ground_truth[:, 3]) < 0.5
1452
+ )
1453
+ code_accuracy = code_correct / len(ground_truth) * 100.0
1454
+
1455
+ result["code_accuracy"] = code_accuracy
1456
+
1457
+ # Print results
1458
+ LOGGER.info("")
1459
+ LOGGER.info("Mean Absolute Errors:")
1460
+ LOGGER.info(" Hue: %.4f", result["hue_mae"])
1461
+ LOGGER.info(" Value: %.4f", result["value_mae"])
1462
+ LOGGER.info(" Chroma: %.4f", result["chroma_mae"])
1463
+ LOGGER.info(" Code: %.4f", result["code_mae"])
1464
+ LOGGER.info(" Code Accuracy: %.1f%%", result["code_accuracy"])
1465
+ if not np.isnan(result["delta_E"]):
1466
+ LOGGER.info(" Delta-E (vs Ground Truth): %.4f", result["delta_E"])
1467
+ LOGGER.info("")
1468
+ LOGGER.info("Performance Metrics:")
1469
+ LOGGER.info(" Model Size: %.2f MB", result["model_size_mb"])
1470
+ LOGGER.info(" Inference Speed: %.4f ms/sample", result["inference_time_ms"])
1471
+
1472
+ # Summary comparison
1473
+ LOGGER.info("")
1474
+ LOGGER.info("=" * 80)
1475
+ LOGGER.info("SUMMARY COMPARISON")
1476
+ LOGGER.info("=" * 80)
1477
+ LOGGER.info("")
1478
+
1479
+ if not results:
1480
+ LOGGER.info("⚠️ No models were successfully evaluated")
1481
+ return
1482
+
1483
+ # MAE comparison table
1484
+ LOGGER.info("Mean Absolute Error Comparison:")
1485
+ LOGGER.info("")
1486
+ header = "{:<35} {:>8} {:>8} {:>8} {:>8} {:>10} {:>10}".format(
1487
+ "Model",
1488
+ "Hue",
1489
+ "Value",
1490
+ "Chroma",
1491
+ "Code",
1492
+ "Code Acc%",
1493
+ "Delta-E",
1494
+ )
1495
+ LOGGER.info(header)
1496
+ LOGGER.info("-" * 100)
1497
+
1498
+ for model_name, result in results.items():
1499
+ delta_E_str = (
1500
+ f"{result['delta_E']:.4f}" if not np.isnan(result["delta_E"]) else "N/A"
1501
+ )
1502
+ code_acc = result.get("code_accuracy", 0.0)
1503
+ LOGGER.info(
1504
+ "%-35s %8.4f %8.4f %8.4f %8.4f %9.1f%% %10s",
1505
+ model_name[:35],
1506
+ result["hue_mae"],
1507
+ result["value_mae"],
1508
+ result["chroma_mae"],
1509
+ result["code_mae"],
1510
+ code_acc,
1511
+ delta_E_str,
1512
+ )
1513
+
1514
+ # Precision threshold comparison
1515
+ LOGGER.info("")
1516
+ LOGGER.info("Accuracy at Precision Thresholds:")
1517
+ LOGGER.info("")
1518
+
1519
+ thresholds = [1e-4, 1e-3, 1e-2, 1e-1, 1.0]
1520
+ header_parts = [f"{'Model/Threshold':<35}"]
1521
+ header_parts.extend(f"{f'< {threshold:.0e}':>10}" for threshold in thresholds)
1522
+ LOGGER.info(" ".join(header_parts))
1523
+ LOGGER.info("-" * 80)
1524
+
1525
+ for model_name, result in results.items():
1526
+ row_parts = [f"{model_name[:35]:<35}"]
1527
+ for threshold in thresholds:
1528
+ accuracy_pct = np.mean(result["max_errors"] < threshold) * 100
1529
+ row_parts.append(f"{accuracy_pct:9.2f}%")
1530
+ LOGGER.info(" ".join(row_parts))
1531
+
1532
+ # Performance metrics comparison
1533
+ LOGGER.info("")
1534
+ LOGGER.info("Model Size and Inference Speed Comparison:")
1535
+ LOGGER.info("")
1536
+ header = f"{'Model':<35} {'Size (MB)':>12} {'Speed (ms/sample)':>18}"
1537
+ LOGGER.info(header)
1538
+ LOGGER.info("-" * 80)
1539
+
1540
+ for model_name, result in results.items():
1541
+ LOGGER.info(
1542
+ "%-35s %11.2f %17.4f",
1543
+ model_name[:35],
1544
+ result["model_size_mb"],
1545
+ result["inference_time_ms"],
1546
+ )
1547
+
1548
+ # Find best model
1549
+ LOGGER.info("")
1550
+ LOGGER.info("=" * 80)
1551
+ LOGGER.info("BEST MODELS BY METRIC")
1552
+ LOGGER.info("=" * 80)
1553
+ LOGGER.info("")
1554
+
1555
+ metrics = ["hue_mae", "value_mae", "chroma_mae", "code_mae"]
1556
+ metric_names = ["Hue MAE", "Value MAE", "Chroma MAE", "Code MAE"]
1557
+
1558
+ for metric, metric_name in zip(metrics, metric_names, strict=False):
1559
+ best_model = min(results.items(), key=lambda x: x[1][metric])
1560
+ LOGGER.info(
1561
+ "%-15s: %s (%.4f)",
1562
+ metric_name,
1563
+ best_model[0],
1564
+ best_model[1][metric],
1565
+ )
1566
+
1567
+ # Overall best (average rank)
1568
+ LOGGER.info("")
1569
+ LOGGER.info("Overall Best (by average component MAE):")
1570
+ for model_name, result in results.items():
1571
+ avg_mae = np.mean(
1572
+ [
1573
+ result["hue_mae"],
1574
+ result["value_mae"],
1575
+ result["chroma_mae"],
1576
+ result["code_mae"],
1577
+ ]
1578
+ )
1579
+ LOGGER.info(" %s: %.4f", model_name, avg_mae)
1580
+
1581
+ LOGGER.info("")
1582
+ LOGGER.info("=" * 80)
1583
+
1584
+ # Generate HTML report
1585
+ report_dir = PROJECT_ROOT / "reports" / "from_xyY"
1586
+ report_dir.mkdir(exist_ok=True)
1587
+ report_file = report_dir / "model_comparison.html"
1588
+ generate_html_report(
1589
+ results, len(xyY_samples), report_file, baseline_inference_time_ms
1590
+ )
1591
+
1592
+
1593
# Script entry point: run the full from-xyY model comparison benchmark.
if __name__ == "__main__":
    main()
learning_munsell/comparison/from_xyY/compare_gamma_model.py ADDED
@@ -0,0 +1,452 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Quick comparison of the gamma-corrected models against baselines.
3
+
4
+ This script compares:
5
+ 1. MLP (Base) vs MLP (Gamma 2.33)
6
+ 2. Multi-Head (Base) vs Multi-Head (Gamma 2.33) vs Multi-Head (ST.2084)
7
+ """
8
+
9
+ import logging
10
+ from typing import Any
11
+
12
+ import numpy as np
13
+ import onnxruntime as ort
14
+ from colour import XYZ_to_Lab, xyY_to_XYZ
15
+ from colour.difference import delta_E_CIE2000
16
+ from colour.models import eotf_inverse_ST2084
17
+ from colour.notation.datasets.munsell import MUNSELL_COLOURS_REAL
18
+ from colour.notation.munsell import (
19
+ CCS_ILLUMINANT_MUNSELL,
20
+ munsell_colour_to_munsell_specification,
21
+ munsell_specification_to_xyY,
22
+ )
23
+ from numpy.typing import NDArray
24
+
25
+ from learning_munsell import PROJECT_ROOT
26
+
27
+ logging.basicConfig(level=logging.INFO, format="%(message)s")
28
+ LOGGER = logging.getLogger(__name__)
29
+
30
+
31
def normalize_input_standard(X: NDArray, params: dict[str, Any]) -> NDArray:
    """Min-max normalize an xyY array to [0, 1] using the ranges in *params*.

    *params* must provide ``x_range``, ``y_range`` and ``Y_range`` as
    ``(minimum, maximum)`` pairs.  The result is cast to ``float32`` for ONNX
    inference.
    """
    x_lo, x_hi = params["x_range"]
    y_lo, y_hi = params["y_range"]
    Y_lo, Y_hi = params["Y_range"]

    X_norm = np.copy(X)
    X_norm[..., 0] = (X[..., 0] - x_lo) / (x_hi - x_lo)
    X_norm[..., 1] = (X[..., 1] - y_lo) / (y_hi - y_lo)
    X_norm[..., 2] = (X[..., 2] - Y_lo) / (Y_hi - Y_lo)

    return X_norm.astype(np.float32)
44
+
45
+
46
def normalize_input_gamma(X: NDArray, params: dict[str, Any]) -> NDArray:
    """Normalize xyY with a gamma-encoded luminance channel.

    Chromaticity (x, y) is min-max scaled; Y is min-max scaled, clipped to
    [0, 1] and then raised to ``1 / gamma`` (``params["gamma"]``, default
    2.33) so the network sees a perceptually flatter luminance axis.
    """
    gamma = params.get("gamma", 2.33)
    x_lo, x_hi = params["x_range"]
    y_lo, y_hi = params["y_range"]
    Y_lo, Y_hi = params["Y_range"]

    X_norm = np.copy(X)
    X_norm[..., 0] = (X[..., 0] - x_lo) / (x_hi - x_lo)
    X_norm[..., 1] = (X[..., 1] - y_lo) / (y_hi - y_lo)

    # Rescale Y to the unit interval before applying the gamma curve.
    Y_unit = np.clip((X[..., 2] - Y_lo) / (Y_hi - Y_lo), 0, 1)
    X_norm[..., 2] = np.power(Y_unit, 1.0 / gamma)

    return X_norm.astype(np.float32)
63
+
64
+
65
def normalize_input_st2084(X: NDArray, params: dict[str, Any]) -> NDArray:
    """Normalize xyY with an ST.2084 (PQ) encoded luminance channel.

    Chromaticity (x, y) is min-max scaled; Y is min-max scaled, clipped to
    [0, 1], mapped to absolute luminance via the peak ``params["L_p"]``
    (default 100 cd/m²) and encoded with the ST.2084 inverse EOTF.
    """
    L_p = params.get("L_p", 100.0)
    x_lo, x_hi = params["x_range"]
    y_lo, y_hi = params["y_range"]
    Y_lo, Y_hi = params["Y_range"]

    X_norm = np.copy(X)
    X_norm[..., 0] = (X[..., 0] - x_lo) / (x_hi - x_lo)
    X_norm[..., 1] = (X[..., 1] - y_lo) / (y_hi - y_lo)

    # Rescale Y to [0, 1], convert to cd/m² and apply the PQ encoding.
    Y_unit = np.clip((X[..., 2] - Y_lo) / (Y_hi - Y_lo), 0, 1)
    X_norm[..., 2] = eotf_inverse_ST2084(Y_unit * L_p, L_p=L_p)

    return X_norm.astype(np.float32)
84
+
85
+
86
def denormalize_output(y_norm: NDArray, params: dict[str, Any]) -> NDArray:
    """Map a normalized Munsell prediction back to its physical ranges.

    Component order is hue, value, chroma, code; each is rescaled with the
    corresponding ``*_range`` pair from *params*.
    """
    component_keys = ("hue_range", "value_range", "chroma_range", "code_range")

    y = np.copy(y_norm)
    for channel, key in enumerate(component_keys):
        lo, hi = params[key]
        y[..., channel] = y_norm[..., channel] * (hi - lo) + lo

    return y
106
+
107
+
108
def clamp_munsell_specification(spec: NDArray) -> NDArray:
    """Clamp a Munsell specification into the ranges the converter accepts.

    Per-component (lower, upper) bounds; value is restricted to [1, 9]
    because of the colour library constraint.
    """
    bounds = (
        (0.0, 10.0),  # Hue
        (1.0, 9.0),  # Value
        (0.0, 50.0),  # Chroma
        (1.0, 10.0),  # Code
    )

    clamped = np.copy(spec)
    for channel, (lower, upper) in enumerate(bounds):
        clamped[..., channel] = np.clip(spec[..., channel], lower, upper)

    return clamped
118
+
119
+
120
def compute_delta_e(pred: NDArray, reference_Lab: NDArray) -> list[float]:
    """Compute CIE 2000 Delta-E between predicted Munsell colours and references.

    Each prediction is clamped to a valid specification, its code rounded to
    an integer, converted back to Lab and compared against the matching
    reference Lab.  Predictions the converter rejects are skipped, so the
    returned list may be shorter than *pred*.
    """
    delta_E_values = []
    for prediction, ref_Lab in zip(pred, reference_Lab, strict=False):
        try:
            spec = clamp_munsell_specification(prediction)
            spec_for_conversion = spec.copy()
            # The converter expects an integral hue-family code.
            spec_for_conversion[3] = round(spec[3])
            xyY = munsell_specification_to_xyY(spec_for_conversion)
            Lab = XYZ_to_Lab(xyY_to_XYZ(xyY), CCS_ILLUMINANT_MUNSELL)
            delta_E_values.append(delta_E_CIE2000(ref_Lab, Lab))
        except (RuntimeError, ValueError):
            continue
    return delta_E_values
136
+
137
+
138
def _log_section(title: str) -> None:
    """Log a dashed section header for one model evaluation."""
    LOGGER.info("\n%s", "-" * 40)
    LOGGER.info(title)
    LOGGER.info("-" * 40)


def _evaluate_onnx_model(
    onnx_file,
    params_file,
    input_key: str,
    normalize_fn,
    xyY_array: NDArray,
    ground_truth: NDArray,
    reference_Lab: NDArray,
) -> list[float]:
    """Evaluate one exported ONNX model on the real Munsell dataset.

    Loads the model and its normalization parameters, runs inference with
    *normalize_fn* applied to the input, logs the per-component MAE and
    Delta-E statistics, and returns the list of Delta-E values.  Returns an
    empty list when the model files are missing.
    """
    if not (onnx_file.exists() and params_file.exists()):
        LOGGER.info(" Model not found, skipping...")
        return []

    session = ort.InferenceSession(str(onnx_file))
    params_data = np.load(params_file, allow_pickle=True)
    input_parameters = params_data["input_parameters"].item()
    output_parameters = params_data["output_parameters"].item()

    X_norm = normalize_fn(xyY_array, input_parameters)
    pred_norm = session.run(None, {input_key: X_norm})[0]
    pred = denormalize_output(pred_norm, output_parameters)

    errors = np.abs(pred - ground_truth)
    delta_E_values = compute_delta_e(pred, reference_Lab)

    LOGGER.info(" Hue MAE: %.4f", np.mean(errors[:, 0]))
    LOGGER.info(" Value MAE: %.4f", np.mean(errors[:, 1]))
    LOGGER.info(" Chroma MAE: %.4f", np.mean(errors[:, 2]))
    LOGGER.info(" Code MAE: %.4f", np.mean(errors[:, 3]))
    LOGGER.info(
        " Delta-E: %.4f (mean), %.4f (median)",
        np.mean(delta_E_values),
        np.median(delta_E_values),
    )
    return delta_E_values


def _log_improvement(prefix: str, base_mean: float, candidate_mean: float) -> None:
    """Log the relative Delta-E change of a candidate versus a baseline."""
    improvement = (base_mean - candidate_mean) / base_mean * 100
    if improvement > 0:
        LOGGER.info("%s %.1f%% BETTER", prefix, improvement)
    else:
        LOGGER.info("%s %.1f%% WORSE", prefix, -improvement)


def main() -> None:
    """Compare the gamma-corrected models against their baselines.

    Loads the real Munsell dataset, evaluates the MLP and Multi-Head models
    under standard, gamma-2.33 and ST.2084 input encodings, and logs summary
    Delta-E comparisons between each variant and its baseline.
    """
    LOGGER.info("=" * 80)
    LOGGER.info("Gamma Model Comparison: MLP vs MLP (Gamma 2.33)")
    LOGGER.info("=" * 80)

    models_dir = PROJECT_ROOT / "models" / "from_xyY"

    # Load real Munsell data and pre-compute the reference Lab values.
    LOGGER.info("\nLoading real Munsell colours...")
    xyY_values = []
    munsell_specs = []
    reference_Lab = []

    for munsell_spec_tuple, xyY in MUNSELL_COLOURS_REAL:
        try:
            hue_code, value, chroma = munsell_spec_tuple
            munsell_str = f"{hue_code} {value}/{chroma}"
            spec = munsell_colour_to_munsell_specification(munsell_str)
            # The dataset stores Y in [0, 100]; rescale to [0, 1].
            xyY_scaled = np.array([xyY[0], xyY[1], xyY[2] / 100.0])

            XYZ = xyY_to_XYZ(xyY_scaled)
            Lab = XYZ_to_Lab(XYZ, CCS_ILLUMINANT_MUNSELL)

            xyY_values.append(xyY_scaled)
            munsell_specs.append(spec)
            reference_Lab.append(Lab)
        except (RuntimeError, ValueError):
            # Skip notations the colour library cannot parse or convert.
            continue

    xyY_array = np.array(xyY_values)
    ground_truth = np.array(munsell_specs)
    reference_Lab = np.array(reference_Lab)

    LOGGER.info("Loaded %d real Munsell colours", len(xyY_array))

    _log_section("1. MLP (Base) - Standard Normalization")
    delta_E_base = _evaluate_onnx_model(
        models_dir / "mlp.onnx",
        models_dir / "mlp_normalization_parameters.npz",
        "xyY",
        normalize_input_standard,
        xyY_array,
        ground_truth,
        reference_Lab,
    )

    _log_section("2. MLP (Gamma 2.33) - Gamma-Corrected Y")
    delta_E_gamma = _evaluate_onnx_model(
        models_dir / "mlp_gamma.onnx",
        models_dir / "mlp_gamma_normalization_parameters.npz",
        "xyY_gamma",
        normalize_input_gamma,
        xyY_array,
        ground_truth,
        reference_Lab,
    )

    # Summary comparison for MLP (only when both models were evaluated).
    if delta_E_base and delta_E_gamma:
        LOGGER.info("\n%s", "=" * 80)
        LOGGER.info("MLP COMPARISON SUMMARY")
        LOGGER.info("=" * 80)
        LOGGER.info("")
        LOGGER.info("Delta-E (lower is better):")
        LOGGER.info(
            " MLP (Base): %.4f mean, %.4f median",
            np.mean(delta_E_base),
            np.median(delta_E_base),
        )
        LOGGER.info(
            " MLP (Gamma): %.4f mean, %.4f median",
            np.mean(delta_E_gamma),
            np.median(delta_E_gamma),
        )
        LOGGER.info("")
        _log_improvement(
            " Gamma model is", np.mean(delta_E_base), np.mean(delta_E_gamma)
        )

    # Multi-Head experiments.
    LOGGER.info("\n%s", "=" * 80)
    LOGGER.info("MULTI-HEAD GAMMA EXPERIMENT")
    LOGGER.info("=" * 80)

    _log_section("3. Multi-Head (Base) - Standard Normalization")
    delta_E_mh_base = _evaluate_onnx_model(
        models_dir / "multi_head.onnx",
        models_dir / "multi_head_normalization_parameters.npz",
        "xyY",
        normalize_input_standard,
        xyY_array,
        ground_truth,
        reference_Lab,
    )

    _log_section("4. Multi-Head (Gamma 2.33) - Gamma-Corrected Y")
    delta_E_mh_gamma = _evaluate_onnx_model(
        models_dir / "multi_head_gamma.onnx",
        models_dir / "multi_head_gamma_normalization_parameters.npz",
        "xyY_gamma",
        normalize_input_gamma,
        xyY_array,
        ground_truth,
        reference_Lab,
    )

    _log_section("5. Multi-Head (ST.2084) - PQ-Encoded Y")
    delta_E_mh_st2084 = _evaluate_onnx_model(
        models_dir / "multi_head_st2084.onnx",
        models_dir / "multi_head_st2084_normalization_parameters.npz",
        "xyY_st2084",
        normalize_input_st2084,
        xyY_array,
        ground_truth,
        reference_Lab,
    )

    # Summary comparison for Multi-Head.
    if delta_E_mh_base and delta_E_mh_gamma:
        LOGGER.info("\n%s", "=" * 80)
        LOGGER.info("MULTI-HEAD COMPARISON SUMMARY")
        LOGGER.info("=" * 80)
        LOGGER.info("")
        LOGGER.info("Delta-E (lower is better):")
        LOGGER.info(
            " Multi-Head (Base): %.4f mean, %.4f median",
            np.mean(delta_E_mh_base),
            np.median(delta_E_mh_base),
        )
        LOGGER.info(
            " Multi-Head (Gamma): %.4f mean, %.4f median",
            np.mean(delta_E_mh_gamma),
            np.median(delta_E_mh_gamma),
        )
        if delta_E_mh_st2084:
            LOGGER.info(
                " Multi-Head (ST.2084): %.4f mean, %.4f median",
                np.mean(delta_E_mh_st2084),
                np.median(delta_E_mh_st2084),
            )
        LOGGER.info("")

        _log_improvement(
            " Multi-Head Gamma vs Base:",
            np.mean(delta_E_mh_base),
            np.mean(delta_E_mh_gamma),
        )
        if delta_E_mh_st2084:
            _log_improvement(
                " Multi-Head ST.2084 vs Base:",
                np.mean(delta_E_mh_base),
                np.mean(delta_E_mh_st2084),
            )
            _log_improvement(
                " Multi-Head ST.2084 vs Gamma:",
                np.mean(delta_E_mh_gamma),
                np.mean(delta_E_mh_st2084),
            )

    LOGGER.info("\n%s", "=" * 80)
449
+
450
+
451
# Entry point: run the gamma-model comparison when executed as a script.
if __name__ == "__main__":
    main()
learning_munsell/comparison/to_xyY/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Comparison scripts for Munsell to xyY conversion models."""
learning_munsell/comparison/to_xyY/compare_all_models.py ADDED
@@ -0,0 +1,584 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Compare all ML models for Munsell to xyY conversion on real Munsell data.
3
+
4
+ Models to compare:
5
+ 1. Simple MLP Approximator
6
+ 2. Multi-Head MLP
7
+ 3. Multi-Head MLP (Optimized) - with hyperparameter optimization
8
+ 4. Multi-Head + Multi-Error Predictor
9
+ 5. Multi-MLP - 3 independent branches
10
+ 6. Multi-MLP (Optimized) - 3 independent branches with optimized hyperparameters
11
+ 7. Multi-MLP + Error Predictor
12
+ 8. Multi-MLP + Multi-Error Predictor
13
+ 9. Multi-MLP (Optimized) + Multi-Error Predictor (Optimized)
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ import logging
19
+ import time
20
+ import warnings
21
+ from typing import TYPE_CHECKING
22
+
23
+ import numpy as np
24
+ import onnxruntime as ort
25
+ from colour import XYZ_to_Lab, xyY_to_XYZ
26
+ from colour.difference import delta_E_CIE2000
27
+ from colour.notation.datasets.munsell import MUNSELL_COLOURS_REAL
28
+ from colour.notation.munsell import (
29
+ CCS_ILLUMINANT_MUNSELL,
30
+ munsell_colour_to_munsell_specification,
31
+ munsell_specification_to_xyY,
32
+ )
33
+ from numpy.typing import NDArray # noqa: TC002
34
+
35
+ from learning_munsell import PROJECT_ROOT
36
+ from learning_munsell.utilities.common import (
37
+ benchmark_inference_speed,
38
+ generate_html_report_footer,
39
+ generate_html_report_header,
40
+ generate_ranking_section,
41
+ get_model_size_mb,
42
+ )
43
+
44
+ if TYPE_CHECKING:
45
+ from pathlib import Path
46
+
47
+ logging.basicConfig(level=logging.INFO, format="%(message)s")
48
+ LOGGER = logging.getLogger(__name__)
49
+
50
+
51
def normalize_munsell(munsell: np.ndarray) -> np.ndarray:
    """Normalize Munsell specs to [0, 1] range.

    Component divisors: hue (in decade) and value by 10, chroma by 50,
    code by 10.  Returns ``float32`` for ONNX inference.
    """
    divisors = np.array([10.0, 10.0, 50.0, 10.0])
    return (munsell / divisors).astype(np.float32)
+ return normalized.astype(np.float32)
59
+
60
+
61
def evaluate_model(
    session: ort.InferenceSession,
    X_norm: np.ndarray,
    ground_truth: np.ndarray,
    input_name: str = "munsell_normalized",
) -> dict:
    """Run one ONNX session over *X_norm* and score it against *ground_truth*.

    Returns a dict with per-component MAE (x, y, Y), the raw predictions,
    the absolute-error matrix and the per-sample maximum error.
    """
    predictions = session.run(None, {input_name: X_norm})[0]
    absolute_errors = np.abs(predictions - ground_truth)
    per_component_mae = np.mean(absolute_errors, axis=0)

    return {
        "x_mae": per_component_mae[0],
        "y_mae": per_component_mae[1],
        "Y_mae": per_component_mae[2],
        "predictions": predictions,
        "errors": absolute_errors,
        "max_errors": np.max(absolute_errors, axis=1),
    }
+ }
79
+
80
+
81
def compute_delta_E(
    ml_predictions: np.ndarray,
    reference_xyY: np.ndarray,
) -> float:
    """Compute mean Delta-E CIE2000 between ML predictions and reference xyY.

    Pairs that the colour pipeline rejects, or that yield NaN, are excluded
    from the mean; returns NaN when no pair succeeds.
    """
    collected = []

    for predicted_xyY, ref_xyY in zip(ml_predictions, reference_xyY, strict=False):
        try:
            predicted_Lab = XYZ_to_Lab(
                xyY_to_XYZ(predicted_xyY), CCS_ILLUMINANT_MUNSELL
            )
            ref_Lab = XYZ_to_Lab(xyY_to_XYZ(ref_xyY), CCS_ILLUMINANT_MUNSELL)

            value = delta_E_CIE2000(ref_Lab, predicted_Lab)
            if not np.isnan(value):
                collected.append(value)
        except (RuntimeError, ValueError):
            continue

    return np.mean(collected) if collected else np.nan
+ return np.mean(delta_E_values) if delta_E_values else np.nan
103
+
104
+
105
def generate_html_report(
    results: dict,
    num_samples: int,
    output_file: Path,
    baseline_inference_time_ms: float,
) -> None:
    """Generate an HTML comparison report with visualizations.

    Parameters
    ----------
    results
        Mapping of model name to its metrics; each entry provides ``x_mae``,
        ``y_mae``, ``Y_mae``, ``model_size_mb``, ``inference_time_ms``,
        ``delta_E`` and ``max_errors``.
    num_samples
        Number of evaluated samples, shown in the report header.
    output_file
        Destination path of the generated HTML file.
    baseline_inference_time_ms
        Colour-library baseline speed used for the "vs Baseline" column.
    """
    # Calculate average MAE per model across the three xyY components.
    avg_maes = {}
    for model_name, result in results.items():
        avg_maes[model_name] = np.mean(
            [
                result["x_mae"],
                result["y_mae"],
                result["Y_mae"],
            ]
        )

    # Sort by average MAE (best first).
    sorted_models = sorted(avg_maes.items(), key=lambda x: x[1])

    # Start HTML
    html = generate_html_report_header(
        title="ML Model Comparison Report",
        subtitle="Munsell to xyY Conversion",
        num_samples=num_samples,
    )

    # Best Models Summary
    best_size = min(results.items(), key=lambda x: x[1]["model_size_mb"])[0]
    best_speed = min(results.items(), key=lambda x: x[1]["inference_time_ms"])[0]
    best_avg = sorted_models[0][0]

    # Find best Delta-E, ignoring models whose Delta-E could not be computed.
    delta_E_results = [
        (n, r["delta_E"]) for n, r in results.items() if not np.isnan(r["delta_E"])
    ]
    best_delta_E = (
        min(delta_E_results, key=lambda x: x[1])[0] if delta_E_results else None
    )

    html += f"""
    <!-- Best Models Summary -->
    <div class="bg-card rounded-lg border border-border p-6 shadow-lg">
        <h2 class="text-2xl font-semibold mb-6 pb-3 border-b border-primary/30">Best Models by Metric</h2>
        <div class="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-4">
            <div class="bg-gradient-to-br from-primary/10 to-primary/5 rounded-lg p-5 border border-primary/20">
                <div class="text-xs font-semibold text-muted-foreground uppercase tracking-wide mb-2">Smallest Size</div>
                <div class="text-3xl font-bold text-primary mb-3">{results[best_size]["model_size_mb"]:.2f} MB</div>
                <div class="text-sm text-foreground/80">{best_size}</div>
            </div>
            <div class="bg-gradient-to-br from-primary/10 to-primary/5 rounded-lg p-5 border border-primary/20">
                <div class="text-xs font-semibold text-muted-foreground uppercase tracking-wide mb-2">Fastest Speed</div>
                <div class="text-3xl font-bold text-primary mb-3">{results[best_speed]["inference_time_ms"]:.4f} ms</div>
                <div class="text-sm text-foreground/80">{best_speed}</div>
            </div>
    """

    if best_delta_E:
        html += f"""
            <div class="bg-gradient-to-br from-primary/10 to-primary/5 rounded-lg p-5 border border-primary/20">
                <div class="text-xs font-semibold text-muted-foreground uppercase tracking-wide mb-2">Best Delta-E</div>
                <div class="text-3xl font-bold text-primary mb-3">{results[best_delta_E]["delta_E"]:.6f}</div>
                <div class="text-sm text-foreground/80">{best_delta_E}</div>
            </div>
        """

    html += f"""
            <div class="bg-gradient-to-br from-primary/10 to-primary/5 rounded-lg p-5 border border-primary/20">
                <div class="text-xs font-semibold text-muted-foreground uppercase tracking-wide mb-2">Best Average MAE</div>
                <div class="text-3xl font-bold text-primary mb-3">{avg_maes[best_avg]:.6f}</div>
                <div class="text-sm text-foreground/80">{best_avg}</div>
            </div>
        </div>
    </div>
    """

    # Performance Metrics Table
    sorted_by_avg_mae = sorted(results.items(), key=lambda x: avg_maes[x[0]])

    html += """
    <!-- Performance Metrics Table -->
    <div class="bg-card rounded-lg border border-border p-6 shadow-lg">
        <h2 class="text-2xl font-semibold mb-6 pb-3 border-b border-primary/30">Model Performance Metrics</h2>
        <div class="overflow-x-auto">
            <table class="w-full text-sm">
                <thead>
                    <tr class="border-b border-border">
                        <th class="text-left py-3 px-4 font-semibold text-muted-foreground">Model</th>
                        <th class="text-right py-3 px-4 font-semibold text-muted-foreground">Size (MB)</th>
                        <th class="text-right py-3 px-4 font-semibold text-muted-foreground">Speed (ms/sample)</th>
                        <th class="text-right py-3 px-4 font-semibold text-muted-foreground">vs Baseline</th>
                        <th class="text-right py-3 px-4 font-semibold text-muted-foreground">MAE x</th>
                        <th class="text-right py-3 px-4 font-semibold text-muted-foreground">MAE y</th>
                        <th class="text-right py-3 px-4 font-semibold text-muted-foreground">MAE Y</th>
                        <th class="text-right py-3 px-4 font-semibold text-muted-foreground">Delta-E</th>
                    </tr>
                </thead>
                <tbody>
    """

    for model_name, result in sorted_by_avg_mae:
        size_mb = result["model_size_mb"]
        speed_ms = result["inference_time_ms"]
        delta_E = result["delta_E"]

        # Calculate speedup vs the colour-library baseline.
        speedup = baseline_inference_time_ms / speed_ms if speed_ms > 0 else 0

        # Highlight the winning model in each column.
        size_class = "text-primary font-semibold" if model_name == best_size else ""
        speed_class = "text-primary font-semibold" if model_name == best_speed else ""
        delta_E_class = (
            "text-primary font-semibold" if model_name == best_delta_E else ""
        )

        delta_E_str = f"{delta_E:.6f}" if not np.isnan(delta_E) else "—"

        speedup_text = f"{speedup:.0f}x" if speedup > 100 else f"{speedup:.1f}x"

        html += f"""
                    <tr class="border-b border-border/50 hover:bg-muted/30 transition-colors">
                        <td class="py-3 px-4 font-medium">{model_name}</td>
                        <td class="py-3 px-4 text-right {size_class}">{size_mb:.2f}</td>
                        <td class="py-3 px-4 text-right {speed_class}">{speed_ms:.4f}</td>
                        <td class="py-3 px-4 text-right text-primary font-semibold">{speedup_text}</td>
                        <td class="py-3 px-4 text-right">{result["x_mae"]:.6f}</td>
                        <td class="py-3 px-4 text-right">{result["y_mae"]:.6f}</td>
                        <td class="py-3 px-4 text-right">{result["Y_mae"]:.6f}</td>
                        <td class="py-3 px-4 text-right {delta_E_class}">{delta_E_str}</td>
                    </tr>
        """

    html += """
                </tbody>
            </table>
        </div>
    </div>
    """

    # Add ranking section.
    # NOTE(review): generate_ranking_section is called with metric_key="avg_mae";
    # confirm that each results entry actually carries an "avg_mae" key (the
    # caller appears responsible for adding it).
    html += generate_ranking_section(
        results,
        metric_key="avg_mae",
        title="Overall Ranking (by Average MAE)",
    )

    # Precision thresholds
    thresholds = [1e-5, 1e-4, 1e-3, 1e-2, 1e-1]

    html += """
    <div class="bg-card rounded-lg border border-border p-6 shadow-lg">
        <h2 class="text-2xl font-semibold mb-3 pb-3 border-b border-primary/30">Accuracy at Precision Thresholds</h2>
        <p class="text-sm text-muted-foreground mb-6">Percentage of predictions where max error across all components is below threshold:</p>
        <div class="overflow-x-auto">
            <table class="w-full text-sm">
                <thead>
                    <tr class="border-b border-border">
                        <th class="text-left py-3 px-4 font-semibold text-muted-foreground">Model</th>
    """

    for threshold in thresholds:
        html += f'                        <th class="text-right py-3 px-4 font-semibold text-muted-foreground">&lt; {threshold:.0e}</th>\n'

    html += """
                    </tr>
                </thead>
                <tbody>
    """

    for model_name, _ in sorted_models:
        result = results[model_name]
        html += f"""
                    <tr class="border-b border-border hover:bg-muted/30 transition-colors">
                        <td class="text-left py-3 px-4 font-medium">{model_name}</td>
        """
        for threshold in thresholds:
            accuracy_pct = np.mean(result["max_errors"] < threshold) * 100
            html += f'                        <td class="text-right py-3 px-4">{accuracy_pct:.2f}%</td>\n'

        html += """
                    </tr>
        """

    html += """
                </tbody>
            </table>
        </div>
    </div>
    """

    html += generate_html_report_footer()

    # Write HTML file.  The report contains non-ASCII characters (e.g. the
    # em dash used for unavailable Delta-E values), so the encoding must be
    # pinned to UTF-8 rather than the platform default.
    with open(output_file, "w", encoding="utf-8") as f:
        f.write(html)

    LOGGER.info("")
    LOGGER.info("HTML report saved to: %s", output_file)
303
+
304
+
305
def main() -> None:
    """
    Compare all Munsell-to-xyY models against the colour-science baseline.

    Loads the real Munsell dataset, benchmarks the reference
    ``munsell_specification_to_xyY`` implementation, then evaluates every
    ONNX model (single-stage and two-stage base + error-predictor variants)
    for accuracy (MAE, Delta-E) and speed, logging a summary table and
    writing an HTML report under ``reports/to_xyY``.
    """
    LOGGER.info("=" * 80)
    LOGGER.info("Munsell to xyY Model Comparison")
    LOGGER.info("=" * 80)

    # Paths
    model_directory = PROJECT_ROOT / "models" / "to_xyY"

    # Load real Munsell dataset; entries that fail to parse are skipped so
    # the ground-truth arrays only contain valid colours.
    LOGGER.info("")
    LOGGER.info("Loading real Munsell dataset...")
    munsell_specs = []
    xyY_ground_truth = []

    for munsell_spec_tuple, xyY in MUNSELL_COLOURS_REAL:
        try:
            hue_code, value, chroma = munsell_spec_tuple
            munsell_str = f"{hue_code} {value}/{chroma}"
            spec = munsell_colour_to_munsell_specification(munsell_str)
            # Dataset Y is on a 0-100(+) scale; rescale into [0, 1].
            xyY_scaled = np.array([xyY[0], xyY[1], xyY[2] / 100.0])
            munsell_specs.append(spec)
            xyY_ground_truth.append(xyY_scaled)
        except Exception: # noqa: BLE001, S112
            continue

    munsell_specs = np.array(munsell_specs, dtype=np.float32)
    xyY_ground_truth = np.array(xyY_ground_truth, dtype=np.float32)
    LOGGER.info("Loaded %d valid Munsell colors", len(munsell_specs))

    # Normalize inputs
    munsell_normalized = normalize_munsell(munsell_specs)

    # Benchmark colour library first
    LOGGER.info("")
    LOGGER.info("=" * 80)
    LOGGER.info("Colour Library (munsell_specification_to_xyY)")
    LOGGER.info("=" * 80)

    # Benchmark the munsell_specification_to_xyY function
    # Note: Using full dataset (100% of samples)

    # Set random seed for reproducibility
    np.random.seed(42)

    # Use 100% of samples for comprehensive benchmarking
    sampled_indices = np.arange(len(munsell_specs))
    munsell_benchmark = munsell_specs[sampled_indices]

    start_time = time.perf_counter()
    colour_predictions = []
    successful_inferences = 0

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        for spec in munsell_benchmark:
            try:
                xyY = munsell_specification_to_xyY(spec)
                colour_predictions.append(xyY)
                successful_inferences += 1
            except (RuntimeError, ValueError):
                # Keep array alignment with the ground truth by recording NaN
                # for conversions the reference implementation rejects.
                colour_predictions.append(np.array([np.nan, np.nan, np.nan]))

    end_time = time.perf_counter()

    total_time_s = end_time - start_time
    # Per-sample latency over successful conversions only (0 if none).
    baseline_inference_time_ms = (
        (total_time_s / successful_inferences) * 1000
        if successful_inferences > 0
        else 0
    )
    colour_predictions = np.array(colour_predictions)

    LOGGER.info(" Successful inferences: %d", successful_inferences)
    LOGGER.info(" Inference Speed: %.4f ms/sample", baseline_inference_time_ms)

    # Define models to compare. "single" entries are one ONNX network;
    # "two_stage" entries pair a base network with an error predictor whose
    # output is added to the base prediction.
    models = [
        {
            "name": "Simple MLP",
            "files": [model_directory / "munsell_to_xyY_approximator.onnx"],
            "params_file": model_directory
            / "munsell_to_xyY_approximator_normalization_parameters.npz",
            "type": "single",
        },
        {
            "name": "Multi-Head",
            "files": [model_directory / "multi_head.onnx"],
            "params_file": model_directory / "multi_head_normalization_parameters.npz",
            "type": "single",
        },
        {
            "name": "Multi-Head + Multi-Error Predictor",
            "files": [
                model_directory / "multi_head.onnx",
                model_directory / "multi_head_multi_error_predictor.onnx",
            ],
            "params_file": model_directory
            / "multi_head_multi_error_predictor_normalization_parameters.npz",
            "type": "two_stage",
        },
        {
            "name": "Multi-MLP",
            "files": [model_directory / "multi_mlp.onnx"],
            "params_file": model_directory / "multi_mlp_normalization_parameters.npz",
            "type": "single",
        },
        {
            "name": "Multi-MLP + Error Predictor",
            "files": [
                model_directory / "multi_mlp.onnx",
                model_directory / "multi_mlp_error_predictor.onnx",
            ],
            "params_file": model_directory
            / "multi_mlp_error_predictor_normalization_parameters.npz",
            "type": "two_stage",
        },
        {
            "name": "Multi-MLP + Multi-Error Predictor",
            "files": [
                model_directory / "multi_mlp.onnx",
                model_directory / "multi_mlp_multi_error_predictor.onnx",
            ],
            "params_file": model_directory
            / "multi_mlp_multi_error_predictor_normalization_parameters.npz",
            "type": "two_stage",
        },
    ]

    # Evaluate each model
    results = {}

    for model_info in models:
        model_name = model_info["name"]
        LOGGER.info("")
        LOGGER.info("=" * 80)
        LOGGER.info(model_name)
        LOGGER.info("=" * 80)

        # Calculate model size
        model_size_mb = get_model_size_mb(model_info["files"])

        if model_info["type"] == "two_stage":
            # Two-stage model
            base_session = ort.InferenceSession(str(model_info["files"][0]))
            error_session = ort.InferenceSession(str(model_info["files"][1]))
            error_input_name = error_session.get_inputs()[0].name

            # Define inference callable. Default arguments bind the current
            # loop values, avoiding the late-binding-closure pitfall.
            def two_stage_inference(
                _base_session: ort.InferenceSession = base_session,
                _error_session: ort.InferenceSession = error_session,
                _munsell_normalized: NDArray = munsell_normalized,
                _error_input_name: str = error_input_name,
            ) -> NDArray:
                base_pred = _base_session.run(
                    None, {"munsell_normalized": _munsell_normalized}
                )[0]
                # Error predictor consumes the normalized input concatenated
                # with the base prediction.
                combined = np.concatenate(
                    [_munsell_normalized, base_pred], axis=1
                ).astype(np.float32)
                error_corr = _error_session.run(None, {_error_input_name: combined})[0]
                return base_pred + error_corr

            # Benchmark speed
            inference_time_ms = benchmark_inference_speed(
                two_stage_inference, munsell_normalized
            )

            # Get predictions
            base_pred = base_session.run(
                None, {"munsell_normalized": munsell_normalized}
            )[0]
            combined = np.concatenate([munsell_normalized, base_pred], axis=1).astype(
                np.float32
            )
            error_corr = error_session.run(None, {error_input_name: combined})[0]
            pred = base_pred + error_corr

            errors = np.abs(pred - xyY_ground_truth)
            result = {
                "x_mae": np.mean(errors[:, 0]),
                "y_mae": np.mean(errors[:, 1]),
                "Y_mae": np.mean(errors[:, 2]),
                "predictions": pred,
                "errors": errors,
                # Worst-case component error per sample, used by the
                # accuracy-threshold table in the HTML report.
                "max_errors": np.max(errors, axis=1),
            }
        else:
            # Single model
            session = ort.InferenceSession(str(model_info["files"][0]))

            # Define inference callable
            def single_inference(
                _session: ort.InferenceSession = session,
                _munsell_normalized: NDArray = munsell_normalized,
            ) -> NDArray:
                return _session.run(None, {"munsell_normalized": _munsell_normalized})[
                    0
                ]

            # Benchmark speed
            inference_time_ms = benchmark_inference_speed(
                single_inference, munsell_normalized
            )

            result = evaluate_model(session, munsell_normalized, xyY_ground_truth)

        result["model_size_mb"] = model_size_mb
        result["inference_time_ms"] = inference_time_ms
        result["avg_mae"] = np.mean([result["x_mae"], result["y_mae"], result["Y_mae"]])

        # Compute Delta-E against ground truth (measured xyY)
        sampled_predictions = result["predictions"][sampled_indices]
        result["delta_E"] = compute_delta_E(
            sampled_predictions,
            xyY_ground_truth,
        )

        results[model_name] = result

        # Print results
        LOGGER.info("")
        LOGGER.info("Mean Absolute Errors:")
        LOGGER.info(" x: %.6f", result["x_mae"])
        LOGGER.info(" y: %.6f", result["y_mae"])
        LOGGER.info(" Y: %.6f", result["Y_mae"])
        if not np.isnan(result["delta_E"]):
            LOGGER.info(" Delta-E (vs Ground Truth): %.6f", result["delta_E"])
        LOGGER.info("")
        LOGGER.info("Performance Metrics:")
        LOGGER.info(" Model Size: %.2f MB", result["model_size_mb"])
        LOGGER.info(" Inference Speed: %.4f ms/sample", result["inference_time_ms"])
        LOGGER.info(
            " Speedup vs Colour: %.1fx",
            baseline_inference_time_ms / inference_time_ms,
        )

    # Summary
    LOGGER.info("")
    LOGGER.info("=" * 80)
    LOGGER.info("SUMMARY COMPARISON")
    LOGGER.info("=" * 80)
    LOGGER.info("")

    if not results:
        LOGGER.info("No models were successfully evaluated")
        return

    # MAE comparison table
    LOGGER.info("Mean Absolute Error Comparison:")
    LOGGER.info("")
    header = f"{'Model':<40} {'x':>10} {'y':>10} {'Y':>10} {'Delta-E':>12}"
    LOGGER.info(header)
    LOGGER.info("-" * 85)

    for model_name, result in results.items():
        delta_E_str = (
            f"{result['delta_E']:.6f}" if not np.isnan(result["delta_E"]) else "N/A"
        )
        LOGGER.info(
            "%-40s %10.6f %10.6f %10.6f %12s",
            model_name,
            result["x_mae"],
            result["y_mae"],
            result["Y_mae"],
            delta_E_str,
        )

    # Generate HTML report
    report_dir = PROJECT_ROOT / "reports" / "to_xyY"
    report_dir.mkdir(parents=True, exist_ok=True)
    report_file = report_dir / "model_comparison.html"
    generate_html_report(
        results, len(munsell_specs), report_file, baseline_inference_time_ms
    )
581
+
582
+
583
# Script entry point.
if __name__ == "__main__":
    main()
learning_munsell/data_generation/generate_training_data.py ADDED
@@ -0,0 +1,310 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Generate training data for ML-based xyY to Munsell conversion.
3
+
4
+ Generates samples by sampling in Munsell space and converting to xyY via
5
+ forward conversion, guaranteeing 100% valid samples.
6
+
7
+ Usage:
8
+ uv run python -m learning_munsell.data_generation.generate_training_data
9
+ uv run python -m learning_munsell.data_generation.generate_training_data \\
10
+ --n-samples 2000000 --perturbation 0.10 --output training_data_large
11
+ """
12
+
13
+ import argparse
14
+ import json
15
+ import logging
16
+ import multiprocessing as mp
17
+ import warnings
18
+ from datetime import UTC, datetime
19
+
20
+ import numpy as np
21
+ from colour.notation.datasets.munsell import MUNSELL_COLOURS_ALL
22
+ from colour.notation.munsell import (
23
+ munsell_colour_to_munsell_specification,
24
+ munsell_specification_to_xyY,
25
+ )
26
+ from colour.utilities import ColourUsageWarning
27
+ from numpy.typing import NDArray
28
+ from sklearn.model_selection import train_test_split
29
+
30
+ from learning_munsell import PROJECT_ROOT
31
+
32
+ logging.basicConfig(level=logging.INFO, format="%(message)s")
33
+ LOGGER = logging.getLogger(__name__)
34
+
35
+
36
def _worker_generate_samples(
    args: tuple[int, NDArray, int, float],
) -> tuple[list[NDArray], list[NDArray]]:
    """
    Produce perturbed Munsell/xyY sample pairs in one worker process.

    Parameters
    ----------
    args : tuple
        ``(worker_id, base_specs, samples_per_base, perturbation_pct)``:
        the worker identifier (used to seed the RNG), the base Munsell
        specifications to perturb, how many perturbed samples to draw per
        base colour, and the perturbation amplitude as a fraction of each
        component's valid range.

    Returns
    -------
    tuple
        A list of xyY arrays and the matching list of perturbed Munsell
        specification arrays.
    """
    worker_id, base_specs, samples_per_base, perturbation_pct = args

    # Distinct, reproducible random stream per worker.
    np.random.seed(42 + worker_id)

    warnings.filterwarnings("ignore", category=ColourUsageWarning)
    warnings.filterwarnings("ignore", category=RuntimeWarning)

    xyY_out: list[NDArray] = []
    munsell_out: list[NDArray] = []

    # Valid extents of hue, value and chroma used to scale the perturbations.
    hue_extent = 9.5
    value_extent = 9.0
    chroma_extent = 50.0

    for base in base_specs:
        for _ in range(samples_per_base):
            # Draws happen in hue/value/chroma order so the random stream
            # matches the original implementation exactly.
            d_hue = np.random.uniform(
                -perturbation_pct * hue_extent, perturbation_pct * hue_extent
            )
            d_value = np.random.uniform(
                -perturbation_pct * value_extent, perturbation_pct * value_extent
            )
            d_chroma = np.random.uniform(
                -perturbation_pct * chroma_extent, perturbation_pct * chroma_extent
            )

            candidate = base.copy()
            candidate[0] = np.clip(base[0] + d_hue, 0.5, 10.0)
            candidate[1] = np.clip(base[1] + d_value, 1.0, 10.0)
            candidate[2] = np.clip(base[2] + d_chroma, 0.0, 50.0)

            try:
                xyY = munsell_specification_to_xyY(candidate)
            except Exception: # noqa: BLE001, S112
                # Out-of-gamut specification; skip it.
                continue
            xyY_out.append(xyY)
            munsell_out.append(candidate)

    return xyY_out, munsell_out
95
+
96
+
97
def generate_forward_munsell_samples(
    n_samples: int = 500000,
    perturbation_pct: float = 0.05,
    n_workers: int | None = None,
) -> tuple[NDArray, NDArray]:
    """
    Generate samples by perturbing reference Munsell colours and converting
    each perturbed specification to xyY via the forward conversion.

    Parameters
    ----------
    n_samples : int
        Target number of samples to generate.
    perturbation_pct : float
        Perturbation amplitude as a fraction of each component's valid range.
    n_workers : int, optional
        Number of parallel workers; defaults to the CPU count.

    Returns
    -------
    tuple
        xyY samples of shape (n, 3) and Munsell specifications of shape (n, 4).
    """
    if n_workers is None:
        n_workers = mp.cpu_count()

    LOGGER.info(
        "Generating %d samples with %.0f%% perturbations using %d workers...",
        n_samples,
        perturbation_pct * 100,
        n_workers,
    )

    # Turn every reference colour into a numeric [hue, value, chroma, code]
    # specification.
    base_specs = np.array(
        [
            munsell_colour_to_munsell_specification(
                f"{hue_code_str} {value}/{chroma}"
            )
            for (hue_code_str, value, chroma), _ in MUNSELL_COLOURS_ALL
        ]
    )
    samples_per_base = n_samples // len(base_specs) + 1

    LOGGER.info("Using %d base Munsell colors", len(base_specs))
    LOGGER.info("Generating ~%d samples per base color", samples_per_base)

    # Partition the base colours into contiguous chunks, one per worker; the
    # last worker also absorbs the division remainder.
    specs_per_worker = len(base_specs) // n_workers
    worker_args = []
    for worker_id in range(n_workers):
        lo = worker_id * specs_per_worker
        hi = lo + specs_per_worker if worker_id < n_workers - 1 else len(base_specs)
        worker_args.append(
            (worker_id, base_specs[lo:hi], samples_per_base, perturbation_pct)
        )

    LOGGER.info("Starting %d parallel workers...", n_workers)
    with mp.Pool(n_workers) as pool:
        chunks = pool.map(_worker_generate_samples, worker_args)

    # Flatten the per-worker results in worker order (deterministic).
    all_xyY: list[NDArray] = []
    all_munsell: list[NDArray] = []
    for xyY_chunk, munsell_chunk in chunks:
        all_xyY.extend(xyY_chunk)
        all_munsell.extend(munsell_chunk)

    # Trim the overshoot introduced by the per-base rounding above.
    del all_xyY[n_samples:]
    del all_munsell[n_samples:]

    LOGGER.info("Generated %d valid samples", len(all_xyY))
    return np.array(all_xyY), np.array(all_munsell)
172
+
173
+
174
def main(
    n_samples: int = 500000,
    perturbation_pct: float = 0.05,
    output: str = "training_data",
) -> None:
    """
    Generate and save training data.

    Parameters
    ----------
    n_samples : int
        Target number of samples to generate.
    perturbation_pct : float
        Perturbation as fraction of each component's valid range.
    output : str
        Output filename stem (no extension) under ``PROJECT_ROOT/data``.
    """
    LOGGER.info("=" * 80)
    LOGGER.info("Training Data Generation")
    LOGGER.info("=" * 80)

    output_dir = PROJECT_ROOT / "data"
    output_dir.mkdir(exist_ok=True)

    # Log the sampling strategy so the provenance of a dataset can be read
    # back from the run log.
    LOGGER.info("")
    LOGGER.info("SAMPLING STRATEGY")
    LOGGER.info("=" * 80)
    LOGGER.info("Forward Munsell->xyY sampling:")
    LOGGER.info(
        " - Base: %d colors from MUNSELL_COLOURS_ALL", len(MUNSELL_COLOURS_ALL)
    )
    LOGGER.info(
        " - Perturbations: +/-%.0f%% of valid range per component",
        perturbation_pct * 100,
    )
    LOGGER.info(
        " - Hue: +/-%.2f (+/-%.0f%% of 9.5 range)",
        perturbation_pct * 9.5,
        perturbation_pct * 100,
    )
    LOGGER.info(
        " - Value: +/-%.2f (+/-%.0f%% of 9.0 range)",
        perturbation_pct * 9.0,
        perturbation_pct * 100,
    )
    LOGGER.info(
        " - Chroma: +/-%.1f (+/-%.0f%% of 50.0 range)",
        perturbation_pct * 50.0,
        perturbation_pct * 100,
    )
    LOGGER.info(" - Target samples: %d", n_samples)
    LOGGER.info("=" * 80)
    LOGGER.info("")

    # Generate samples
    xyY_all, munsell_all = generate_forward_munsell_samples(
        n_samples=n_samples,
        perturbation_pct=perturbation_pct,
    )

    # Forward conversion only emits valid pairs, so the mask is all-True;
    # kept for schema compatibility with other data-generation scripts.
    valid_mask = np.ones(len(xyY_all), dtype=bool)

    LOGGER.info("")
    LOGGER.info("Sample statistics:")
    LOGGER.info(" Total samples generated: %d", len(xyY_all))
    LOGGER.info(" All samples are valid (100%% by forward conversion)")

    LOGGER.info("")
    LOGGER.info("Using %d valid samples for training", len(xyY_all))

    # Split into train/validation/test (70/15/15); the second split's
    # test_size is 0.15/0.85 so validation is 15% of the full dataset.
    X_temp, X_test, y_temp, y_test = train_test_split(
        xyY_all, munsell_all, test_size=0.15, random_state=42
    )
    X_train, X_val, y_train, y_val = train_test_split(
        X_temp, y_temp, test_size=0.15 / 0.85, random_state=42
    )

    LOGGER.info("")
    LOGGER.info("Data split:")
    LOGGER.info(" Train: %d samples", len(X_train))
    LOGGER.info(" Validation: %d samples", len(X_val))
    LOGGER.info(" Test: %d samples", len(X_test))

    # Save training data
    cache_file = output_dir / f"{output}.npz"
    np.savez_compressed(
        cache_file,
        X_train=X_train,
        y_train=y_train,
        X_val=X_val,
        y_val=y_val,
        X_test=X_test,
        y_test=y_test,
        xyY_all=xyY_all,
        munsell_all=munsell_all,
        valid_mask=valid_mask,
    )

    # Save parameters to sidecar file
    params_file = output_dir / f"{output}_params.json"
    params = {
        "n_samples": n_samples,
        "perturbation_pct": perturbation_pct,
        "n_base_colors": len(MUNSELL_COLOURS_ALL),
        "train_samples": len(X_train),
        "val_samples": len(X_val),
        "test_samples": len(X_test),
        "generated_at": datetime.now(UTC).isoformat(),
    }
    with open(params_file, "w") as f:
        json.dump(params, f, indent=2)

    LOGGER.info("")
    LOGGER.info("Training data saved to: %s", cache_file)
    LOGGER.info("Parameters saved to: %s", params_file)
    LOGGER.info("=" * 80)
280
+
281
+
282
# CLI entry point: parse arguments and delegate to main().
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Generate training data for xyY to Munsell conversion"
    )
    parser.add_argument(
        "--n-samples",
        type=int,
        default=500000,
        help="Number of samples to generate (default: 500000)",
    )
    parser.add_argument(
        "--perturbation",
        type=float,
        default=0.05,
        help="Perturbation as fraction of valid range (default: 0.05)",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="training_data",
        help="Output filename without extension (default: training_data)",
    )
    args = parser.parse_args()

    main(
        n_samples=args.n_samples,
        perturbation_pct=args.perturbation,
        output=args.output,
    )
learning_munsell/data_generation/generate_training_data_uniform.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Generate training data by uniform random sampling in Munsell space.
3
+
4
+ Samples hue, value, chroma, and code uniformly and converts to xyY via
5
+ ``munsell_specification_to_xyY``. Invalid specifications (out of gamut)
6
+ are discarded. This gives uniform coverage across the full Munsell space
7
+ without gaps between standard hue prefixes.
8
+
9
+ Usage:
10
+ uv run python -m learning_munsell.data_generation.generate_training_data_uniform
11
+ uv run python -m learning_munsell.data_generation.generate_training_data_uniform \\
12
+ --n-samples 2000000 --output training_data_large
13
+ """
14
+
15
+ import argparse
16
+ import json
17
+ import logging
18
+ import multiprocessing as mp
19
+ import warnings
20
+ from datetime import UTC, datetime
21
+
22
+ import numpy as np
23
+ from colour.notation.munsell import munsell_specification_to_xyY
24
+ from colour.utilities import ColourUsageWarning
25
+ from numpy.typing import NDArray
26
+ from sklearn.model_selection import train_test_split
27
+
28
+ from learning_munsell import PROJECT_ROOT
29
+
30
+ logging.basicConfig(level=logging.INFO, format="%(message)s")
31
+ LOGGER = logging.getLogger(__name__)
32
+
33
+
34
def _worker_generate_samples(
    args: tuple[int, int],
) -> tuple[list[NDArray], list[NDArray]]:
    """
    Draw uniformly random Munsell specifications until enough convert to xyY.

    Parameters
    ----------
    args : tuple
        ``(worker_id, n_samples)``: the worker identifier (seeds the RNG)
        and the number of valid samples this worker must produce.

    Returns
    -------
    tuple
        A list of xyY arrays and the matching list of Munsell
        specification arrays.
    """
    worker_id, n_samples = args

    # Distinct, reproducible random stream per worker.
    np.random.seed(42 + worker_id)

    warnings.filterwarnings("ignore", category=ColourUsageWarning)
    warnings.filterwarnings("ignore", category=RuntimeWarning)

    codes = np.arange(1, 11, dtype=np.float64)

    xyY_out: list[NDArray] = []
    munsell_out: list[NDArray] = []

    # Rejection sampling: keep drawing until the quota of convertible
    # (in-gamut) specifications is met.
    while len(xyY_out) < n_samples:
        # Draw order (hue, value, chroma, code) matches the original
        # implementation so the random stream is identical.
        hue = np.random.uniform(0.0, 10.0)
        value = np.random.uniform(1.0, 9.0)
        chroma = np.random.uniform(0.0, 50.0)
        code = codes[np.random.randint(len(codes))]
        candidate = np.array([hue, value, chroma, code])

        try:
            xyY = munsell_specification_to_xyY(candidate)
        except Exception: # noqa: BLE001, S112
            # Out-of-gamut specification; draw again.
            continue
        xyY_out.append(xyY)
        munsell_out.append(candidate)

    return xyY_out, munsell_out
82
+
83
+
84
def generate_uniform_munsell_samples(
    n_samples: int = 2000000,
    n_workers: int | None = None,
) -> tuple[NDArray, NDArray]:
    """
    Generate samples by uniform random sampling in Munsell space.

    Parameters
    ----------
    n_samples : int
        Target number of valid samples to generate.
    n_workers : int, optional
        Number of parallel workers; defaults to the CPU count.

    Returns
    -------
    tuple
        xyY samples of shape (n, 3) and Munsell specifications of shape (n, 4).
    """
    if n_workers is None:
        n_workers = mp.cpu_count()

    # Each worker fills an equal quota (+1 to cover the division remainder).
    samples_per_worker = n_samples // n_workers + 1

    LOGGER.info(
        "Generating %d samples using %d workers (%d per worker)...",
        n_samples,
        n_workers,
        samples_per_worker,
    )
    LOGGER.info(
        "Sampling uniformly: hue [0, 10], value [1, 9], chroma [0, 50], code [1, 10]"
    )

    worker_args = [(worker_id, samples_per_worker) for worker_id in range(n_workers)]

    LOGGER.info("Starting %d parallel workers...", n_workers)
    with mp.Pool(n_workers) as pool:
        chunks = pool.map(_worker_generate_samples, worker_args)

    # Flatten per-worker results in worker order (deterministic).
    all_xyY: list[NDArray] = []
    all_munsell: list[NDArray] = []
    for xyY_chunk, munsell_chunk in chunks:
        all_xyY.extend(xyY_chunk)
        all_munsell.extend(munsell_chunk)

    # Trim the overshoot from the per-worker rounding above.
    del all_xyY[n_samples:]
    del all_munsell[n_samples:]

    LOGGER.info("Generated %d valid samples", len(all_xyY))
    return np.array(all_xyY), np.array(all_munsell)
136
+
137
+
138
def main(
    n_samples: int = 2000000,
    output: str = "training_data_large",
) -> None:
    """
    Generate and save training data.

    Parameters
    ----------
    n_samples : int
        Target number of valid samples to generate.
    output : str
        Output filename stem (no extension) under ``PROJECT_ROOT/data``.
    """
    LOGGER.info("=" * 80)
    LOGGER.info("Training Data Generation (Uniform Sampling)")
    LOGGER.info("=" * 80)

    output_dir = PROJECT_ROOT / "data"
    output_dir.mkdir(exist_ok=True)

    # Log the sampling strategy so a dataset's provenance is in the run log.
    LOGGER.info("")
    LOGGER.info("SAMPLING STRATEGY")
    LOGGER.info("=" * 80)
    LOGGER.info("Uniform random sampling in Munsell space:")
    LOGGER.info(" - Hue: uniform [0, 10]")
    LOGGER.info(" - Value: uniform [1, 9]")
    LOGGER.info(" - Chroma: uniform [0, 50]")
    LOGGER.info(" - Code: uniform {1, 2, ..., 10}")
    LOGGER.info(" - Target samples: %d", n_samples)
    LOGGER.info("=" * 80)
    LOGGER.info("")

    xyY_all, munsell_all = generate_uniform_munsell_samples(
        n_samples=n_samples,
    )

    LOGGER.info("")
    LOGGER.info("Sample statistics:")
    LOGGER.info(" Total samples: %d", len(xyY_all))

    # Report low-hue density: uniform sampling should not leave gaps below
    # the standard hue prefixes.
    hue = munsell_all[:, 0]
    LOGGER.info(" hue < 1: %d (%.1f%%)", (hue < 1).sum(), (hue < 1).mean() * 100)
    LOGGER.info(" hue < 2: %d (%.1f%%)", (hue < 2).sum(), (hue < 2).mean() * 100)

    # Split into train/validation/test (70/15/15); 0.15/0.85 makes the
    # validation share 15% of the full dataset.
    X_temp, X_test, y_temp, y_test = train_test_split(
        xyY_all, munsell_all, test_size=0.15, random_state=42
    )
    X_train, X_val, y_train, y_val = train_test_split(
        X_temp, y_temp, test_size=0.15 / 0.85, random_state=42
    )

    LOGGER.info("")
    LOGGER.info("Data split:")
    LOGGER.info(" Train: %d samples", len(X_train))
    LOGGER.info(" Validation: %d samples", len(X_val))
    LOGGER.info(" Test: %d samples", len(X_test))

    cache_file = output_dir / f"{output}.npz"
    np.savez_compressed(
        cache_file,
        X_train=X_train,
        y_train=y_train,
        X_val=X_val,
        y_val=y_val,
        X_test=X_test,
        y_test=y_test,
        xyY_all=xyY_all,
        munsell_all=munsell_all,
        # All samples are valid by construction; mask kept for schema
        # compatibility with the perturbation-based generator.
        valid_mask=np.ones(len(xyY_all), dtype=bool),
    )

    # Sidecar JSON recording the generation parameters.
    params_file = output_dir / f"{output}_params.json"
    params = {
        "n_samples": n_samples,
        "sampling": "uniform",
        "hue_range": [0.0, 10.0],
        "value_range": [1.0, 9.0],
        "chroma_range": [0.0, 50.0],
        "code_range": [1, 10],
        "train_samples": len(X_train),
        "val_samples": len(X_val),
        "test_samples": len(X_test),
        "generated_at": datetime.now(UTC).isoformat(),
    }
    with open(params_file, "w") as f:
        json.dump(params, f, indent=2)

    LOGGER.info("")
    LOGGER.info("Training data saved to: %s", cache_file)
    LOGGER.info("Parameters saved to: %s", params_file)
    LOGGER.info("=" * 80)
222
+
223
+
224
# CLI entry point: parse arguments and delegate to main().
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Generate training data by uniform sampling in Munsell space"
    )
    parser.add_argument(
        "--n-samples",
        type=int,
        default=2000000,
        help="Number of valid samples to generate (default: 2000000)",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="training_data_large",
        help="Output filename without extension (default: training_data_large)",
    )
    args = parser.parse_args()

    main(
        n_samples=args.n_samples,
        output=args.output,
    )
learning_munsell/interpolation/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Interpolation-based methods for Munsell conversions."""
learning_munsell/interpolation/from_xyY/__init__.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Interpolation-based methods for xyY to Munsell conversions."""
2
+
3
+ import numpy as np
4
+ from colour.notation.datasets.munsell import MUNSELL_COLOURS_ALL
5
+ from colour.notation.munsell import munsell_colour_to_munsell_specification
6
+ from numpy.typing import NDArray
7
+
8
+
9
def load_munsell_reference_data() -> tuple[NDArray, NDArray]:
    """
    Load the colour-science reference Munsell dataset.

    Each entry of ``MUNSELL_COLOURS_ALL`` is converted into a numeric
    specification ``[hue, value, chroma, code]`` and paired with its xyY
    coordinates, with Y rescaled from the tabulated 0-102.57 range
    into [0, 1].

    Returns
    -------
    Tuple[NDArray, NDArray]
        X : xyY values of shape (4995, 3) with Y normalized to [0, 1]
        y : Munsell specifications of shape (4995, 4)
    """
    coordinates: list[NDArray] = []
    specifications: list[NDArray] = []

    for (hue_name, value, chroma), xyY in MUNSELL_COLOURS_ALL:
        # Rebuild the textual notation and convert it to the numeric form.
        notation = f"{hue_name} {value}/{chroma}"
        specifications.append(munsell_colour_to_munsell_specification(notation))

        # Tabulated Y is on a 0-100(+) scale; bring it into [0, 1].
        coordinates.append(np.array([xyY[0], xyY[1], xyY[2] / 100.0]))

    return np.array(coordinates), np.array(specifications)


__all__ = ["load_munsell_reference_data"]
learning_munsell/interpolation/from_xyY/compare_methods.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Compare classical interpolation methods against the best ML model.
3
+
4
+ Evaluates RBF, KD-Tree, and Delaunay interpolation on REAL Munsell colors
5
+ and compares with the Multi-Head (W+B) + Multi-Error Predictor (W+B) model.
6
+ """
7
+
8
+ import logging
9
+
10
+ import numpy as np
11
+ import onnxruntime as ort
12
+ from colour.notation.datasets.munsell import MUNSELL_COLOURS_ALL
13
+ from colour.notation.munsell import munsell_colour_to_munsell_specification
14
+ from scipy.interpolate import LinearNDInterpolator, RBFInterpolator
15
+ from scipy.spatial import KDTree
16
+ from sklearn.model_selection import train_test_split
17
+
18
+ from learning_munsell import PROJECT_ROOT
19
+
20
+ logging.basicConfig(level=logging.INFO, format="%(message)s")
21
+ LOGGER = logging.getLogger(__name__)
22
+
23
+
24
def load_reference_data() -> tuple[np.ndarray, np.ndarray]:
    """Build interpolator training data from every reference Munsell colour."""
    xyY_points: list[list[float]] = []
    specifications: list[np.ndarray] = []

    for (hue_name, value, chroma), xyY in MUNSELL_COLOURS_ALL:
        notation = f"{hue_name} {value}/{chroma}"
        specifications.append(munsell_colour_to_munsell_specification(notation))
        # Y is tabulated on a 0-100 scale; rescale into [0, 1].
        xyY_points.append([xyY[0], xyY[1], xyY[2] / 100.0])

    return np.array(xyY_points), np.array(specifications)
35
+
36
+
37
def evaluate(
    predictions: np.ndarray, y_true: np.ndarray, method_name: str
) -> dict[str, float]:
    """Compute, log and return the per-component MAE of *predictions*."""
    component_names = ("hue", "value", "chroma", "code")
    absolute_errors = np.abs(predictions - y_true)

    # One MAE per Munsell component, keyed by component name.
    results = {
        name: absolute_errors[:, column].mean()
        for column, name in enumerate(component_names)
    }

    LOGGER.info(" %s:", method_name)
    for comp in ["hue", "value", "chroma", "code"]:
        LOGGER.info(" %s MAE: %.4f", comp.capitalize(), results[comp])

    return results
52
+
53
+
54
def rbf_predict(
    X_train: np.ndarray, y_train: np.ndarray, X_test: np.ndarray
) -> np.ndarray:
    """Predict Munsell components via thin-plate-spline RBF interpolation."""
    # Fit one scalar interpolant per Munsell component, then stack the
    # predicted columns into an (n_test, 4) array.
    columns = []
    for component in range(4):
        interpolant = RBFInterpolator(
            X_train, y_train[:, component], kernel="thin_plate_spline"
        )
        columns.append(interpolant(X_test))
    return np.column_stack(columns)
63
+
64
+
65
def kdtree_predict(
    X_train: np.ndarray,
    y_train: np.ndarray,
    X_test: np.ndarray,
    k: int = 5,
) -> np.ndarray:
    """
    Predict Munsell components via k-nearest-neighbour inverse-distance
    weighting.

    Parameters
    ----------
    X_train : np.ndarray
        Reference xyY coordinates of shape (n_train, 3).
    y_train : np.ndarray
        Reference Munsell specifications of shape (n_train, 4).
    X_test : np.ndarray
        Query xyY coordinates of shape (n_test, 3).
    k : int
        Number of neighbours to blend; ``k=1`` degenerates to a plain
        nearest-neighbour lookup.

    Returns
    -------
    np.ndarray
        Predicted specifications of shape (n_test, 4).
    """
    tree = KDTree(X_train)
    distances, indices = tree.query(X_test, k=k)

    # scipy returns 1-D arrays when k == 1; normalise to (n_test, k) so the
    # axis-1 reductions below work for every k (the previous implementation
    # crashed for k=1).
    distances = np.asarray(distances).reshape(len(X_test), -1)
    indices = np.asarray(indices).reshape(len(X_test), -1)

    # Clamp to avoid division by zero when a query coincides with a
    # reference point.
    distances = np.maximum(distances, 1e-10)
    weights = 1.0 / (distances**2)
    weights /= weights.sum(axis=1, keepdims=True)

    # Weighted blend of the neighbours' specifications, vectorized over all
    # queries: (n_test, k, 1) * (n_test, k, 4) summed over the k axis.
    return np.sum(weights[:, :, np.newaxis] * y_train[indices], axis=1)
82
+
83
+
84
def delaunay_predict(
    X_train: np.ndarray, y_train: np.ndarray, X_test: np.ndarray
) -> np.ndarray:
    """
    Piecewise-linear interpolation on the Delaunay triangulation of
    ``X_train``, with nearest-neighbour fallback outside the convex hull.

    Parameters
    ----------
    X_train : np.ndarray
        Reference xyY coordinates of shape (n_train, 3).
    y_train : np.ndarray
        Reference Munsell specifications of shape (n_train, 4).
    X_test : np.ndarray
        Query xyY coordinates of shape (n_test, 3).

    Returns
    -------
    np.ndarray
        Predicted specifications of shape (n_test, 4).
    """
    # LinearNDInterpolator accepts multi-column values, so one shared
    # triangulation serves all four Munsell components (the previous
    # per-component loop rebuilt the same Delaunay triangulation 4x).
    interpolator = LinearNDInterpolator(X_train, y_train)
    predictions = interpolator(X_test)

    # Queries outside the convex hull come back as NaN; replace those rows
    # with the nearest reference specification.
    nan_mask = np.any(np.isnan(predictions), axis=1)
    if nan_mask.sum() > 0:
        tree = KDTree(X_train)
        _, indices = tree.query(X_test[nan_mask])
        predictions[nan_mask] = y_train[indices]

    return predictions
102
+
103
+
104
def ml_predict(X_test: np.ndarray) -> np.ndarray | None:
    """ML model prediction using base + error predictor."""
    model_dir = PROJECT_ROOT / "models" / "from_xyY"
    base_path = model_dir / "multi_head_weighted_boundary.onnx"
    error_path = (
        model_dir
        / "multi_head_weighted_boundary_multi_error_predictor_weighted_boundary.onnx"
    )

    if not (base_path.exists() and error_path.exists()):
        return None

    # Input is already normalized to [0, 1] for x, y, Y
    X_norm = X_test.astype(np.float32)

    # Base model prediction
    base_out = ort.InferenceSession(str(base_path)).run(None, {"xyY": X_norm})[0]

    # Error predictor (takes xyY + base predictions)
    combined = np.concatenate([X_norm, base_out], axis=1).astype(np.float32)
    error_out = ort.InferenceSession(str(error_path)).run(
        None, {"combined_input": combined}
    )[0]

    # Combined prediction (normalized)
    pred_norm = base_out + error_out

    # Denormalize using actual ranges from params file:
    # hue [0.5, 10], value [0, 10], chroma [0, 50], code [1, 10].
    lows = np.array([0.5, 0.0, 0.0, 1.0])
    highs = np.array([10.0, 10.0, 50.0, 10.0])
    return (pred_norm * (highs - lows) + lows).astype(pred_norm.dtype)
142
+
143
+
144
def main() -> None:
    """Compare all methods using held-out test set.

    Runs RBF, KD-Tree, Delaunay, and (if the ONNX models are present)
    the ML pipeline on the same 80/20 split of the reference data, then
    logs a per-component MAE summary table.
    """
    LOGGER.info("=" * 80)
    LOGGER.info("Classical Interpolation vs ML Model Comparison")
    LOGGER.info("=" * 80)

    LOGGER.info("")
    LOGGER.info("Loading data...")
    X_all, y_all = load_reference_data()

    # 80/20 train/test split for fair comparison
    # (random_state pinned so all methods see the identical split).
    X_train, X_test, y_train, y_test = train_test_split(
        X_all, y_all, test_size=0.2, random_state=42
    )
    LOGGER.info(" Total: %d colors", len(X_all))
    LOGGER.info(" Training: %d colors (80%%)", len(X_train))
    LOGGER.info(" Test: %d colors (20%%)", len(X_test))

    results = {}

    # RBF
    LOGGER.info("")
    LOGGER.info("-" * 60)
    LOGGER.info("RBF Interpolation (thin_plate_spline)")
    rbf_pred = rbf_predict(X_train, y_train, X_test)
    results["RBF"] = evaluate(rbf_pred, y_test, "RBF")

    # KD-Tree
    LOGGER.info("")
    LOGGER.info("-" * 60)
    LOGGER.info("KD-Tree Interpolation (k=5, IDW)")
    kdt_pred = kdtree_predict(X_train, y_train, X_test, k=5)
    results["KD-Tree"] = evaluate(kdt_pred, y_test, "KD-Tree")

    # Delaunay
    LOGGER.info("")
    LOGGER.info("-" * 60)
    LOGGER.info("Delaunay Interpolation (with NN fallback)")
    del_pred = delaunay_predict(X_train, y_train, X_test)
    results["Delaunay"] = evaluate(del_pred, y_test, "Delaunay")

    # ML — ml_predict returns None when the ONNX files are not on disk.
    LOGGER.info("")
    LOGGER.info("-" * 60)
    LOGGER.info("ML Model (Multi-Head W+B + Multi-Error Predictor W+B)")
    ml_pred = ml_predict(X_test)
    if ml_pred is not None:
        results["ML"] = evaluate(ml_pred, y_test, "ML")
    else:
        LOGGER.info(" Skipped (model not found)")

    # Summary table: one row per method, one column per Munsell component.
    LOGGER.info("")
    LOGGER.info("=" * 80)
    LOGGER.info("SUMMARY (MAE on %d held-out test colors)", len(X_test))
    LOGGER.info("=" * 80)
    LOGGER.info("")
    LOGGER.info("%-12s %8s %8s %8s %8s", "Method", "Hue", "Value", "Chroma", "Code")
    LOGGER.info("-" * 52)

    for method, mae in results.items():
        LOGGER.info(
            "%-12s %8.4f %8.4f %8.4f %8.4f",
            method,
            mae["hue"],
            mae["value"],
            mae["chroma"],
            mae["code"],
        )

    LOGGER.info("")
    LOGGER.info("=" * 80)
216
+
217
+
218
# Script entry point.
if __name__ == "__main__":
    main()
learning_munsell/interpolation/from_xyY/delaunay_interpolator.py ADDED
@@ -0,0 +1,283 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Delaunay triangulation based interpolation for xyY to Munsell conversion.
3
+
4
+ This approach uses scipy's LinearNDInterpolator which performs piecewise
5
+ linear interpolation based on Delaunay triangulation.
6
+
7
+ Uses the 4,995 reference colors from MUNSELL_COLOURS_ALL directly.
8
+
9
+ Advantages:
10
+ - Piecewise linear: exact at data points, linear between
11
+ - Handles irregular point distributions
12
+ - No hyperparameters to tune
13
+
14
+ Disadvantages:
15
+ - Returns NaN outside convex hull of data points
16
+ - Non-convex Munsell boundary may cause issues
17
+ - C0 continuous only (discontinuous gradients at cell boundaries)
18
+ """
19
+
20
+ import logging
21
+ import pickle
22
+ from pathlib import Path
23
+
24
+ import numpy as np
25
+ from numpy.typing import NDArray
26
+ from scipy.interpolate import LinearNDInterpolator
27
+ from scipy.spatial import KDTree
28
+ from sklearn.model_selection import train_test_split
29
+
30
+ from learning_munsell import PROJECT_ROOT, setup_logging
31
+ from learning_munsell.interpolation.from_xyY import load_munsell_reference_data
32
+
33
+ logging.basicConfig(level=logging.INFO, format="%(message)s")
34
+ LOGGER = logging.getLogger(__name__)
35
+
36
+
37
class MunsellDelaunayInterpolator:
    """
    Delaunay triangulation based interpolator for xyY to Munsell conversion.

    Uses LinearNDInterpolator for piecewise linear interpolation within
    the Delaunay triangulation. Falls back to nearest neighbor for points
    outside the convex hull.
    """

    def __init__(self, fallback_to_nearest: bool = True) -> None:
        """
        Initialize the Delaunay interpolator.

        Parameters
        ----------
        fallback_to_nearest
            If True, use nearest neighbor for points outside convex hull.
            If False, return NaN for such points.
        """
        self.fallback_to_nearest = fallback_to_nearest
        # One LinearNDInterpolator per Munsell component, keyed by name.
        self.interpolators: dict = {}
        # kdtree / y_data are only populated by fit() when fallback is enabled.
        self.kdtree: KDTree | None = None
        self.y_data: NDArray | None = None
        self.fitted = False

    def fit(self, X: NDArray, y: NDArray) -> "MunsellDelaunayInterpolator":
        """
        Build the Delaunay interpolator from training data.

        Parameters
        ----------
        X
            xyY input values of shape (n, 3)
        y
            Munsell output values [hue, value, chroma, code] of shape (n, 4)

        Returns
        -------
        self
        """
        LOGGER.info("Building Delaunay interpolator...")
        LOGGER.info(" Fallback to nearest: %s", self.fallback_to_nearest)
        LOGGER.info(" Data points: %d", len(X))

        component_names = ["hue", "value", "chroma", "code"]

        # Each component gets its own interpolant; they all share the same
        # underlying triangulation of X (rebuilt per call here).
        for i, name in enumerate(component_names):
            LOGGER.info(" Building %s interpolator...", name)
            self.interpolators[name] = LinearNDInterpolator(X, y[:, i])

        # Build KDTree for nearest neighbor fallback
        if self.fallback_to_nearest:
            LOGGER.info(" Building KD-Tree for fallback...")
            self.kdtree = KDTree(X)
            # Copy so later mutation of the caller's array cannot corrupt us.
            self.y_data = y.copy()

        self.fitted = True
        LOGGER.info("Delaunay interpolator built successfully")
        return self

    def predict(self, X: NDArray) -> NDArray:
        """
        Predict Munsell values using Delaunay interpolation.

        Parameters
        ----------
        X
            xyY input values of shape (n, 3)

        Returns
        -------
        NDArray
            Predicted Munsell values [hue, value, chroma, code] of shape (n, 4)

        Raises
        ------
        RuntimeError
            If called before fit().
        """
        if not self.fitted:
            msg = "Interpolator not fitted. Call fit() first."
            raise RuntimeError(msg)

        results = np.zeros((len(X), 4))

        for i, name in enumerate(["hue", "value", "chroma", "code"]):
            results[:, i] = self.interpolators[name](X)

        # Handle NaN values (points outside convex hull).
        # When fallback is disabled, NaN rows are returned unchanged.
        if self.fallback_to_nearest:
            nan_mask = np.any(np.isnan(results), axis=1)
            n_nan = nan_mask.sum()

            if n_nan > 0:
                LOGGER.debug(" %d points outside hull, using nearest neighbor", n_nan)
                # Find nearest neighbors for NaN points.
                # kdtree/y_data are set by fit() whenever fallback is enabled.
                _, indices = self.kdtree.query(X[nan_mask])
                results[nan_mask] = self.y_data[indices]

        return results

    def save(self, path: Path) -> None:
        """Save the interpolator to disk.

        NOTE(review): the pickle embeds scipy objects, so loading is
        presumably tied to a compatible scipy version — confirm before
        shipping serialized models across environments.
        """
        with open(path, "wb") as f:
            pickle.dump(
                {
                    "fallback_to_nearest": self.fallback_to_nearest,
                    "interpolators": self.interpolators,
                    "kdtree": self.kdtree,
                    "y_data": self.y_data,
                },
                f,
            )
        LOGGER.info("Saved Delaunay interpolator to %s", path)

    @classmethod
    def load(cls, path: Path) -> "MunsellDelaunayInterpolator":
        """Load the interpolator from disk.

        Only load files produced by save(); pickle deserialization of
        untrusted data can execute arbitrary code.
        """
        with open(path, "rb") as f:
            data = pickle.load(f)  # noqa: S301

        instance = cls(fallback_to_nearest=data["fallback_to_nearest"])
        instance.interpolators = data["interpolators"]
        instance.kdtree = data["kdtree"]
        instance.y_data = data["y_data"]
        instance.fitted = True

        LOGGER.info("Loaded Delaunay interpolator from %s", path)
        return instance
161
+
162
+
163
def evaluate_delaunay(
    interpolator: MunsellDelaunayInterpolator,
    X: NDArray,
    y: NDArray,
    name: str = "Test",
) -> dict:
    """Evaluate Delaunay interpolator performance."""
    predicted = interpolator.predict(X)

    # Rows with any NaN component (outside-hull points without fallback).
    has_nan = np.isnan(predicted).any(axis=1)
    if has_nan.sum() > 0:
        LOGGER.warning(" %d/%d predictions contain NaN", has_nan.sum(), len(X))

    # MAE is computed over valid rows only.
    valid = ~has_nan
    n_valid = valid.sum()
    if n_valid == 0:
        LOGGER.error(" All predictions are NaN!")
        return {c: float("nan") for c in ("hue", "value", "chroma", "code")}

    abs_err = np.abs(predicted[valid] - y[valid])

    results: dict = {}
    LOGGER.info("%s set MAE (%d/%d valid):", name, n_valid, len(X))
    for idx, label in enumerate(["Hue", "Value", "Chroma", "Code"]):
        component_mae = abs_err[:, idx].mean()
        results[label.lower()] = component_mae
        LOGGER.info(" %s: %.4f", label, component_mae)

    return results
200
+
201
+
202
def main() -> None:
    """Build and evaluate Delaunay interpolator using reference Munsell data.

    Compares the with-fallback and without-fallback configurations on an
    80/20 split, refits the winner on all data, and pickles it under
    models/from_xyY.
    """

    log_file = setup_logging("delaunay_interpolator", "from_xyY")

    LOGGER.info("=" * 80)
    LOGGER.info("Delaunay Interpolation for xyY to Munsell Conversion")
    LOGGER.info("Using MUNSELL_COLOURS_ALL reference data (4,995 colors)")
    LOGGER.info("=" * 80)

    # Load reference data from colour library
    LOGGER.info("")
    LOGGER.info("Loading reference Munsell data...")
    X_all, y_all = load_munsell_reference_data()
    LOGGER.info("Total reference colors: %d", len(X_all))

    # Split into train/validation (80/20)
    X_train, X_val, y_train, y_val = train_test_split(
        X_all, y_all, test_size=0.2, random_state=42
    )

    LOGGER.info("Train samples: %d", len(X_train))
    LOGGER.info("Validation samples: %d", len(X_val))

    # Test with and without fallback
    LOGGER.info("")
    LOGGER.info("Testing Delaunay interpolation...")
    LOGGER.info("-" * 60)

    best_config = None
    best_mae = float("inf")

    for fallback in [True, False]:
        LOGGER.info("")
        LOGGER.info("Fallback to nearest: %s", fallback)

        interpolator = MunsellDelaunayInterpolator(fallback_to_nearest=fallback)
        interpolator.fit(X_train, y_train)

        results = evaluate_delaunay(interpolator, X_val, y_val, "Validation")

        # Skip if results contain NaN
        if any(np.isnan(v) for v in results.values()):
            LOGGER.info(" Skipping due to NaN results")
            continue

        # Model selection criterion: sum of the four per-component MAEs.
        total_mae = sum(results.values())

        if total_mae < best_mae:
            best_mae = total_mae
            best_config = fallback

    # NOTE(review): if both configurations produced NaN, best_config stays
    # None and the final model is built with fallback_to_nearest=None
    # (falsy) — confirm this degenerate case is acceptable.
    LOGGER.info("")
    LOGGER.info("=" * 60)
    LOGGER.info("Best configuration: fallback_to_nearest=%s", best_config)
    LOGGER.info("=" * 60)

    # Train final model on ALL data
    LOGGER.info("")
    LOGGER.info("Training final model on all %d reference colors...", len(X_all))

    final_interpolator = MunsellDelaunayInterpolator(fallback_to_nearest=best_config)
    final_interpolator.fit(X_all, y_all)

    # Evaluation on the training data itself — an optimistic sanity check,
    # not a generalization estimate.
    LOGGER.info("")
    LOGGER.info("Final evaluation (training set = all data):")
    evaluate_delaunay(final_interpolator, X_all, y_all, "All data")

    # Save the model
    model_dir = PROJECT_ROOT / "models" / "from_xyY"
    model_dir.mkdir(parents=True, exist_ok=True)
    model_path = model_dir / "delaunay_interpolator.pkl"
    final_interpolator.save(model_path)

    LOGGER.info("")
    LOGGER.info("=" * 80)

    log_file.close()
280
+
281
+
282
# Script entry point.
if __name__ == "__main__":
    main()
learning_munsell/interpolation/from_xyY/kdtree_interpolator.py ADDED
@@ -0,0 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ KD-Tree based interpolation for xyY to Munsell conversion.
3
+
4
+ This approach uses scipy's KDTree for fast nearest neighbor lookups,
5
+ with optional weighted interpolation using k nearest neighbors.
6
+
7
+ Uses the 4,995 reference colors from MUNSELL_COLOURS_ALL directly.
8
+
9
+ Advantages over RBF:
10
+ - O(n) memory, O(log n) query time
11
+ - Scales to millions of data points
12
+ - No matrix inversion required
13
+
14
+ Advantages over ML:
15
+ - Deterministic
16
+ - No training required
17
+ - Easy to understand
18
+ """
19
+
20
+ import logging
21
+ import pickle
22
+ from pathlib import Path
23
+
24
+ import numpy as np
25
+ from numpy.typing import NDArray
26
+ from scipy.spatial import KDTree
27
+ from sklearn.model_selection import train_test_split
28
+
29
+ from learning_munsell import PROJECT_ROOT, setup_logging
30
+ from learning_munsell.interpolation.from_xyY import load_munsell_reference_data
31
+
32
+ logging.basicConfig(level=logging.INFO, format="%(message)s")
33
+ LOGGER = logging.getLogger(__name__)
34
+
35
+
36
class MunsellKDTreeInterpolator:
    """
    KD-Tree based interpolator for xyY to Munsell conversion.

    Uses k-nearest neighbors with inverse distance weighting
    for smooth interpolation.
    """

    def __init__(self, k: int = 5, power: float = 2.0) -> None:
        """
        Initialize the KD-Tree interpolator.

        Parameters
        ----------
        k
            Number of nearest neighbors to use for interpolation.
        power
            Power for inverse distance weighting. Higher = sharper.
        """
        self.k = k
        self.power = power
        # tree / y_data are populated by fit().
        self.tree: KDTree | None = None
        self.y_data: NDArray | None = None
        self.fitted = False

    def fit(self, X: NDArray, y: NDArray) -> "MunsellKDTreeInterpolator":
        """
        Build the KD-Tree from training data.

        Parameters
        ----------
        X
            xyY input values of shape (n, 3)
        y
            Munsell output values [hue, value, chroma, code] of shape (n, 4)

        Returns
        -------
        self
        """
        LOGGER.info("Building KD-Tree interpolator...")
        LOGGER.info(" k neighbors: %d", self.k)
        LOGGER.info(" IDW power: %.1f", self.power)
        LOGGER.info(" Data points: %d", len(X))

        self.tree = KDTree(X)
        # Copy so later mutation of the caller's array cannot corrupt us.
        self.y_data = y.copy()
        self.fitted = True

        LOGGER.info("KD-Tree built successfully")
        return self

    def predict(self, X: NDArray) -> NDArray:
        """
        Predict Munsell values using k-NN with IDW.

        Parameters
        ----------
        X
            xyY input values of shape (n, 3)

        Returns
        -------
        NDArray
            Predicted Munsell values [hue, value, chroma, code] of shape (n, 4)

        Raises
        ------
        RuntimeError
            If called before fit().
        """
        if not self.fitted:
            msg = "Interpolator not fitted. Call fit() first."
            raise RuntimeError(msg)

        # Query k nearest neighbors
        distances, indices = self.tree.query(X, k=self.k)

        # Ensure 2D arrays for consistent handling
        # (KDTree.query returns 1-D arrays when k == 1).
        if self.k == 1:
            distances = distances.reshape(-1, 1)
            indices = indices.reshape(-1, 1)

        # Inverse distance weighting
        # Avoid division by zero
        distances = np.maximum(distances, 1e-10)
        weights = 1.0 / (distances**self.power)
        weights /= weights.sum(axis=1, keepdims=True)

        # Weighted average of neighbor values
        results = np.zeros((len(X), 4))
        for i in range(len(X)):
            neighbor_values = self.y_data[indices[i]]
            if self.k == 1:
                # Single neighbor: its targets pass through unchanged.
                results[i] = neighbor_values.flatten()
            else:
                results[i] = np.sum(weights[i, :, np.newaxis] * neighbor_values, axis=0)

        return results

    def save(self, path: Path) -> None:
        """Save the interpolator to disk.

        NOTE(review): the pickle embeds a scipy KDTree, so loading is
        presumably tied to a compatible scipy version — confirm before
        shipping serialized models across environments.
        """
        with open(path, "wb") as f:
            pickle.dump(
                {
                    "k": self.k,
                    "power": self.power,
                    "tree": self.tree,
                    "y_data": self.y_data,
                },
                f,
            )
        LOGGER.info("Saved KD-Tree interpolator to %s", path)

    @classmethod
    def load(cls, path: Path) -> "MunsellKDTreeInterpolator":
        """Load the interpolator from disk.

        Only load files produced by save(); pickle deserialization of
        untrusted data can execute arbitrary code.
        """
        with open(path, "rb") as f:
            data = pickle.load(f)  # noqa: S301

        instance = cls(k=data["k"], power=data["power"])
        instance.tree = data["tree"]
        instance.y_data = data["y_data"]
        instance.fitted = True

        LOGGER.info("Loaded KD-Tree interpolator from %s", path)
        return instance
158
+
159
+
160
def evaluate_kdtree(
    interpolator: MunsellKDTreeInterpolator,
    X: NDArray,
    y: NDArray,
    name: str = "Test",
) -> dict:
    """Evaluate KD-Tree interpolator performance."""
    # Per-component absolute errors, columns [hue, value, chroma, code].
    abs_err = np.abs(interpolator.predict(X) - y)

    results: dict = {}
    LOGGER.info("%s set MAE:", name)
    for idx, label in enumerate(["Hue", "Value", "Chroma", "Code"]):
        component_mae = abs_err[:, idx].mean()
        results[label.lower()] = component_mae
        LOGGER.info(" %s: %.4f", label, component_mae)

    return results
180
+
181
+
182
def main() -> None:
    """Build and evaluate KD-Tree interpolator using reference Munsell data.

    Sweeps k over a fixed grid on an 80/20 split, refits the best k on
    all data, and pickles the result under models/from_xyY.
    """

    log_file = setup_logging("kdtree_interpolator", "from_xyY")

    LOGGER.info("=" * 80)
    LOGGER.info("KD-Tree Interpolation for xyY to Munsell Conversion")
    LOGGER.info("Using MUNSELL_COLOURS_ALL reference data (4,995 colors)")
    LOGGER.info("=" * 80)

    # Load reference data from colour library
    LOGGER.info("")
    LOGGER.info("Loading reference Munsell data...")
    X_all, y_all = load_munsell_reference_data()
    LOGGER.info("Total reference colors: %d", len(X_all))

    # Split into train/validation (80/20)
    X_train, X_val, y_train, y_val = train_test_split(
        X_all, y_all, test_size=0.2, random_state=42
    )

    LOGGER.info("Train samples: %d", len(X_train))
    LOGGER.info("Validation samples: %d", len(X_val))

    # Test different k values
    k_values = [1, 3, 5, 10, 20, 50]

    # best_k is always set: k_values is non-empty and MAEs are finite.
    best_k = None
    best_mae = float("inf")

    LOGGER.info("")
    LOGGER.info("Testing different k values...")
    LOGGER.info("-" * 60)

    for k in k_values:
        LOGGER.info("")
        LOGGER.info("k = %d:", k)

        interpolator = MunsellKDTreeInterpolator(k=k, power=2.0)
        interpolator.fit(X_train, y_train)

        # Selection criterion: sum of the four per-component MAEs.
        results = evaluate_kdtree(interpolator, X_val, y_val, "Validation")
        total_mae = sum(results.values())

        if total_mae < best_mae:
            best_mae = total_mae
            best_k = k

    LOGGER.info("")
    LOGGER.info("=" * 60)
    LOGGER.info("Best k: %d", best_k)
    LOGGER.info("=" * 60)

    # Train final model with best k on ALL data
    LOGGER.info("")
    LOGGER.info(
        "Training final model on all %d reference colors with k=%d...",
        len(X_all),
        best_k,
    )

    final_interpolator = MunsellKDTreeInterpolator(k=best_k, power=2.0)
    final_interpolator.fit(X_all, y_all)

    # Evaluation on the training data itself — an optimistic sanity check,
    # not a generalization estimate.
    LOGGER.info("")
    LOGGER.info("Final evaluation (training set = all data):")
    evaluate_kdtree(final_interpolator, X_all, y_all, "All data")

    # Save the model
    model_dir = PROJECT_ROOT / "models" / "from_xyY"
    model_dir.mkdir(parents=True, exist_ok=True)
    model_path = model_dir / "kdtree_interpolator.pkl"
    final_interpolator.save(model_path)

    LOGGER.info("")
    LOGGER.info("=" * 80)

    log_file.close()
260
+
261
+
262
# Script entry point.
if __name__ == "__main__":
    main()
learning_munsell/interpolation/from_xyY/rbf_interpolator.py ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ RBF (Radial Basis Function) interpolation for xyY to Munsell conversion.
3
+
4
+ This approach uses scipy's RBFInterpolator to build a lookup table
5
+ with smooth interpolation between known color samples.
6
+
7
+ Uses the 4,995 reference colors from MUNSELL_COLOURS_ALL directly.
8
+
9
+ Advantages over ML:
10
+ - Deterministic, no training required
11
+ - Exact interpolation at known points
12
+ - Smooth interpolation between points
13
+ - Easy to understand and debug
14
+
15
+ Disadvantages:
16
+ - Memory scales with number of data points
17
+ - Query time scales with data points (O(n) naive, can optimize)
18
+ - May struggle with extrapolation
19
+ """
20
+
21
+ import logging
22
+ import pickle
23
+ from pathlib import Path
24
+
25
+ import numpy as np
26
+ from numpy.typing import NDArray
27
+ from scipy.interpolate import RBFInterpolator
28
+ from sklearn.model_selection import train_test_split
29
+
30
+ from learning_munsell import PROJECT_ROOT, setup_logging
31
+ from learning_munsell.interpolation.from_xyY import load_munsell_reference_data
32
+
33
+ logging.basicConfig(level=logging.INFO, format="%(message)s")
34
+ LOGGER = logging.getLogger(__name__)
35
+
36
+
37
class MunsellRBFInterpolator:
    """
    RBF-based interpolator for xyY to Munsell conversion.

    Uses separate RBF interpolators for each Munsell component
    (hue, value, chroma, code) to allow independent kernel tuning.
    """

    def __init__(
        self,
        kernel: str = "thin_plate_spline",
        smoothing: float = 0.0,
        epsilon: float | None = None,
    ) -> None:
        """
        Initialize the RBF interpolator.

        Parameters
        ----------
        kernel
            RBF kernel type. Options: 'linear', 'thin_plate_spline',
            'cubic', 'quintic', 'multiquadric', 'inverse_multiquadric',
            'inverse_quadratic', 'gaussian'
        smoothing
            Smoothing parameter. 0 = exact interpolation.
        epsilon
            Shape parameter for kernels that use it.
        """
        self.kernel = kernel
        self.smoothing = smoothing
        self.epsilon = epsilon

        # One RBFInterpolator per Munsell component, keyed by name.
        self.interpolators: dict[str, RBFInterpolator] = {}
        self.fitted = False

    def fit(self, X: NDArray, y: NDArray) -> "MunsellRBFInterpolator":
        """
        Fit RBF interpolators to the training data.

        Parameters
        ----------
        X
            xyY input values of shape (n, 3)
        y
            Munsell output values [hue, value, chroma, code] of shape (n, 4)

        Returns
        -------
        self
        """
        LOGGER.info("Fitting RBF interpolators...")
        LOGGER.info(" Kernel: %s", self.kernel)
        LOGGER.info(" Smoothing: %s", self.smoothing)
        LOGGER.info(" Data points: %d", len(X))

        component_names = ["hue", "value", "chroma", "code"]

        for i, name in enumerate(component_names):
            LOGGER.info(" Building %s interpolator...", name)

            kwargs = {
                "kernel": self.kernel,
                "smoothing": self.smoothing,
            }
            # epsilon is only forwarded when set: some kernels reject it.
            if self.epsilon is not None:
                kwargs["epsilon"] = self.epsilon

            self.interpolators[name] = RBFInterpolator(X, y[:, i], **kwargs)

        self.fitted = True
        LOGGER.info("RBF interpolators fitted successfully")

        return self

    def predict(self, X: NDArray) -> NDArray:
        """
        Predict Munsell values for given xyY inputs.

        Parameters
        ----------
        X
            xyY input values of shape (n, 3)

        Returns
        -------
        NDArray
            Predicted Munsell values [hue, value, chroma, code] of shape (n, 4)

        Raises
        ------
        RuntimeError
            If called before fit().
        """
        if not self.fitted:
            msg = "Interpolator not fitted. Call fit() first."
            raise RuntimeError(msg)

        results = np.zeros((len(X), 4))

        for i, name in enumerate(["hue", "value", "chroma", "code"]):
            results[:, i] = self.interpolators[name](X)

        return results

    def save(self, path: Path) -> None:
        """Save the interpolator to disk.

        NOTE(review): the pickle embeds scipy RBFInterpolator objects, so
        loading is presumably tied to a compatible scipy version — confirm
        before shipping serialized models across environments.
        """
        with open(path, "wb") as f:
            pickle.dump(
                {
                    "kernel": self.kernel,
                    "smoothing": self.smoothing,
                    "epsilon": self.epsilon,
                    "interpolators": self.interpolators,
                },
                f,
            )
        LOGGER.info("Saved RBF interpolator to %s", path)

    @classmethod
    def load(cls, path: Path) -> "MunsellRBFInterpolator":
        """Load the interpolator from disk.

        Only load files produced by save(); pickle deserialization of
        untrusted data can execute arbitrary code.
        """
        with open(path, "rb") as f:
            data = pickle.load(f)  # noqa: S301

        instance = cls(
            kernel=data["kernel"],
            smoothing=data["smoothing"],
            epsilon=data["epsilon"],
        )
        instance.interpolators = data["interpolators"]
        instance.fitted = True

        LOGGER.info("Loaded RBF interpolator from %s", path)
        return instance
166
+
167
+
168
def evaluate_rbf(
    interpolator: MunsellRBFInterpolator,
    X: NDArray,
    y: NDArray,
    name: str = "Test",
) -> dict[str, float]:
    """
    Evaluate RBF interpolator performance.

    Parameters
    ----------
    interpolator
        Fitted RBF interpolator
    X
        Input xyY values
    y
        Ground truth Munsell values
    name
        Name for logging

    Returns
    -------
    dict
        Dictionary of MAE values for each component
    """
    # Per-component absolute errors, columns [hue, value, chroma, code].
    abs_err = np.abs(interpolator.predict(X) - y)

    results: dict[str, float] = {}
    LOGGER.info("%s set MAE:", name)
    for idx, label in enumerate(["Hue", "Value", "Chroma", "Code"]):
        component_mae = abs_err[:, idx].mean()
        results[label.lower()] = component_mae
        LOGGER.info(" %s: %.4f", label, component_mae)

    return results
206
+
207
+
208
def main() -> None:
    """Build and evaluate RBF interpolator using reference Munsell data.

    Sweeps a fixed list of (kernel, smoothing) configurations on an 80/20
    split, refits the winner on all data, and pickles it under
    models/from_xyY.
    """

    log_file = setup_logging("rbf_interpolator", "from_xyY")

    LOGGER.info("=" * 80)
    LOGGER.info("RBF Interpolation for xyY to Munsell Conversion")
    LOGGER.info("Using MUNSELL_COLOURS_ALL reference data (4,995 colors)")
    LOGGER.info("=" * 80)

    # Load reference data from colour library
    LOGGER.info("")
    LOGGER.info("Loading reference Munsell data...")
    X_all, y_all = load_munsell_reference_data()
    LOGGER.info("Total reference colors: %d", len(X_all))

    # Split into train/validation (80/20)
    X_train, X_val, y_train, y_val = train_test_split(
        X_all, y_all, test_size=0.2, random_state=42
    )

    LOGGER.info("Train samples: %d", len(X_train))
    LOGGER.info("Validation samples: %d", len(X_val))

    # Test different kernels: (kernel name, smoothing) pairs.
    kernels_to_test = [
        ("thin_plate_spline", 0.0),
        ("thin_plate_spline", 0.001),
        ("thin_plate_spline", 0.01),
        ("cubic", 0.0),
        ("linear", 0.0),
        ("multiquadric", 0.0),
    ]

    best_kernel = None
    best_smoothing = None
    best_mae = float("inf")

    LOGGER.info("")
    LOGGER.info("Testing different RBF kernels...")
    LOGGER.info("-" * 60)

    for kernel, smoothing in kernels_to_test:
        LOGGER.info("")
        LOGGER.info("Kernel: %s, Smoothing: %s", kernel, smoothing)

        # A failing kernel (e.g. one requiring epsilon) is logged and skipped.
        try:
            interpolator = MunsellRBFInterpolator(kernel=kernel, smoothing=smoothing)
            interpolator.fit(X_train, y_train)

            # Selection criterion: sum of the four per-component MAEs.
            results = evaluate_rbf(interpolator, X_val, y_val, "Validation")
            total_mae = sum(results.values())

            if total_mae < best_mae:
                best_mae = total_mae
                best_kernel = kernel
                best_smoothing = smoothing

        except Exception:
            LOGGER.exception(" Failed")

    # NOTE(review): if every configuration failed, best_kernel stays None
    # and the final construction below would raise — confirm this
    # degenerate case is acceptable.
    LOGGER.info("")
    LOGGER.info("=" * 60)
    LOGGER.info("Best configuration: %s with smoothing=%s", best_kernel, best_smoothing)
    LOGGER.info("=" * 60)

    # Train final model with best kernel on ALL data
    LOGGER.info("")
    LOGGER.info("Training final model on all %d reference colors...", len(X_all))

    final_interpolator = MunsellRBFInterpolator(
        kernel=best_kernel, smoothing=best_smoothing
    )
    final_interpolator.fit(X_all, y_all)

    # Evaluation on the training data itself — an optimistic sanity check,
    # not a generalization estimate.
    LOGGER.info("")
    LOGGER.info("Final evaluation (training set = all data):")
    evaluate_rbf(final_interpolator, X_all, y_all, "All data")

    # Save the model
    model_dir = PROJECT_ROOT / "models" / "from_xyY"
    model_dir.mkdir(parents=True, exist_ok=True)
    model_path = model_dir / "rbf_interpolator.pkl"
    final_interpolator.save(model_path)

    LOGGER.info("")
    LOGGER.info("=" * 80)

    log_file.close()
297
+
298
+
299
# Script entry point.
if __name__ == "__main__":
    main()
learning_munsell/losses/__init__.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Loss functions for Munsell ML training."""

# Re-export the JAX Delta-E helpers so callers can simply do
# ``from learning_munsell.losses import delta_E_loss``.
from learning_munsell.losses.jax_delta_e import (
    XYZ_to_Lab,
    delta_E_CIE2000,
    delta_E_loss,
    xyY_to_Lab,
    xyY_to_XYZ,
)

# Public API of the ``losses`` package (case-insensitive alphabetical order).
__all__ = [
    "delta_E_CIE2000",
    "delta_E_loss",
    "xyY_to_Lab",
    "xyY_to_XYZ",
    "XYZ_to_Lab",
]
learning_munsell/losses/jax_delta_e.py ADDED
@@ -0,0 +1,299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Differentiable Delta-E Loss Functions using JAX
3
+ ================================================
4
+
5
+ This module provides JAX implementations of color space conversions
6
+ and Delta-E (CIE2000) loss function for use in training.
7
+
8
+ The key insight is that we can compute Delta-E between:
9
+ - The input xyY (which we convert to Lab as the "target")
10
+ - The predicted Munsell converted back to Lab
11
+
12
+ For the Munsell -> xyY conversion, we either:
13
+ 1. Use a pre-trained neural network approximator
14
+ 2. Use differentiable interpolation on the Munsell Renotation data
15
+ """
16
+
17
+ from __future__ import annotations
18
+
19
+ import colour
20
+ import jax
21
+ import jax.numpy as jnp
22
+ import numpy as np
23
+ from jax import Array
24
+
25
+ # D65 illuminant XYZ reference values (standard for sRGB)
26
+ D65_XYZ = jnp.array([95.047, 100.0, 108.883])
27
+
28
+ # Illuminant C XYZ reference values (used by Munsell system)
29
+ ILLUMINANT_C_XYZ = jnp.array([98.074, 100.0, 118.232])
30
+
31
+
32
def xyY_to_XYZ(xyY: Array, scale_Y: bool = True) -> Array:
    """
    Convert CIE xyY to CIE XYZ.

    Parameters
    ----------
    xyY : Array
        CIE xyY values with shape (..., 3)
    scale_Y : bool
        If True, scale Y from 0-1 to 0-100 range (required for Lab conversion)

    Returns
    -------
    Array
        CIE XYZ values with shape (..., 3)
    """
    x = xyY[..., 0]
    y = xyY[..., 1]

    # The colour library works in a 0-100 luminance range, so scale on request.
    Y = xyY[..., 2] * 100.0 if scale_Y else xyY[..., 2]

    # Guard against division by zero; the y == 0 rows are overwritten below.
    zero_chromaticity = y == 0
    ratio = Y / jnp.where(zero_chromaticity, 1e-10, y)

    X = jnp.where(zero_chromaticity, 0.0, x * ratio)
    Z = jnp.where(zero_chromaticity, 0.0, (1.0 - x - y) * ratio)

    return jnp.stack([X, Y, Z], axis=-1)
67
+
68
+
69
def XYZ_to_Lab(XYZ: Array, illuminant: Array = ILLUMINANT_C_XYZ) -> Array:
    """
    Convert CIE XYZ to CIE Lab.

    Parameters
    ----------
    XYZ : Array
        CIE XYZ values with shape (..., 3)
    illuminant : Array
        Reference white XYZ values

    Returns
    -------
    Array
        CIE Lab values with shape (..., 3)
    """
    # Ratios relative to the reference white.
    t = XYZ / illuminant

    # CIE Lab companding: cube root above the knee, linear segment below it.
    delta = 6.0 / 29.0
    linear_segment = t / (3.0 * delta**2) + 4.0 / 29.0
    f_t = jnp.where(t > delta**3, jnp.cbrt(t), linear_segment)

    f_X = f_t[..., 0]
    f_Y = f_t[..., 1]
    f_Z = f_t[..., 2]

    return jnp.stack(
        [
            116.0 * f_Y - 16.0,
            500.0 * (f_X - f_Y),
            200.0 * (f_Y - f_Z),
        ],
        axis=-1,
    )
105
+
106
+
107
def xyY_to_Lab(xyY: Array, illuminant: Array = ILLUMINANT_C_XYZ) -> Array:
    """Convert CIE xyY directly to CIE Lab via an intermediate XYZ step."""
    XYZ = xyY_to_XYZ(xyY)
    return XYZ_to_Lab(XYZ, illuminant)
110
+
111
+
112
def delta_E_CIE2000(Lab_1: Array, Lab_2: Array) -> Array:
    """
    Compute CIE 2000 Delta-E color difference.

    This is a differentiable JAX implementation of the CIE 2000 Delta-E formula.

    The branch handling (hue wrap-around, achromatic colors) follows the
    standard CIEDE2000 definition; all conditionals are expressed with
    ``jnp.where`` so the function stays traceable/JIT-able.

    Parameters
    ----------
    Lab_1 : Array
        First CIE Lab color(s) with shape (..., 3)
    Lab_2 : Array
        Second CIE Lab color(s) with shape (..., 3)

    Returns
    -------
    Array
        Delta-E values with shape (...)

    Notes
    -----
    NOTE(review): ``jnp.sqrt`` has an infinite gradient at 0, so gradients may
    be NaN for exactly achromatic inputs (C' = 0) or identical colors — confirm
    this cannot occur in training batches, or add an epsilon if it can.
    """
    L_1, a_1, b_1 = Lab_1[..., 0], Lab_1[..., 1], Lab_1[..., 2]
    L_2, a_2, b_2 = Lab_2[..., 0], Lab_2[..., 1], Lab_2[..., 2]

    # Chroma
    C_1_ab = jnp.sqrt(a_1**2 + b_1**2)
    C_2_ab = jnp.sqrt(a_2**2 + b_2**2)

    C_bar_ab = (C_1_ab + C_2_ab) / 2
    C_bar_ab_7 = C_bar_ab**7

    # G factor for a' adjustment (25^7 = 6103515625.0)
    G = 0.5 * (1 - jnp.sqrt(C_bar_ab_7 / (C_bar_ab_7 + 6103515625.0)))

    # Adjusted a'
    a_p_1 = (1 + G) * a_1
    a_p_2 = (1 + G) * a_2

    # Adjusted chroma C'
    C_p_1 = jnp.sqrt(a_p_1**2 + b_1**2)
    C_p_2 = jnp.sqrt(a_p_2**2 + b_2**2)

    # Hue angle h' (in degrees)
    h_p_1 = jnp.degrees(jnp.arctan2(b_1, a_p_1)) % 360
    h_p_2 = jnp.degrees(jnp.arctan2(b_2, a_p_2)) % 360

    # Handle achromatic case
    h_p_1 = jnp.where((b_1 == 0) & (a_p_1 == 0), 0.0, h_p_1)
    h_p_2 = jnp.where((b_2 == 0) & (a_p_2 == 0), 0.0, h_p_2)

    # Delta L', C'
    delta_L_p = L_2 - L_1
    delta_C_p = C_p_2 - C_p_1

    # Delta h'
    h_p_diff = h_p_2 - h_p_1
    C_p_product = C_p_1 * C_p_2

    # Hue difference wraps to the shorter arc; zero when either color is
    # achromatic (C_p_product == 0).
    delta_h_p = jnp.where(
        C_p_product == 0,
        0.0,
        jnp.where(
            jnp.abs(h_p_diff) <= 180,
            h_p_diff,
            jnp.where(h_p_diff > 180, h_p_diff - 360, h_p_diff + 360),
        ),
    )

    # Delta H'
    delta_H_p = 2 * jnp.sqrt(C_p_product) * jnp.sin(jnp.radians(delta_h_p / 2))

    # Mean L', C'
    L_bar_p = (L_1 + L_2) / 2
    C_bar_p = (C_p_1 + C_p_2) / 2

    # Mean h'
    h_p_sum = h_p_1 + h_p_2
    h_p_abs_diff = jnp.abs(h_p_1 - h_p_2)

    # Mean hue also wraps; when achromatic the (unhalved) sum is used, matching
    # the reference CIEDE2000 implementation.
    h_bar_p = jnp.where(
        C_p_product == 0,
        h_p_sum,
        jnp.where(
            h_p_abs_diff <= 180,
            h_p_sum / 2,
            jnp.where(h_p_sum < 360, (h_p_sum + 360) / 2, (h_p_sum - 360) / 2),
        ),
    )

    # T factor
    T = (
        1
        - 0.17 * jnp.cos(jnp.radians(h_bar_p - 30))
        + 0.24 * jnp.cos(jnp.radians(2 * h_bar_p))
        + 0.32 * jnp.cos(jnp.radians(3 * h_bar_p + 6))
        - 0.20 * jnp.cos(jnp.radians(4 * h_bar_p - 63))
    )

    # Delta theta (rotation term is centered on the blue region, ~275 degrees)
    delta_theta = 30 * jnp.exp(-(((h_bar_p - 275) / 25) ** 2))

    # R_C (25^7 = 6103515625.0)
    C_bar_p_7 = C_bar_p**7
    R_C = 2 * jnp.sqrt(C_bar_p_7 / (C_bar_p_7 + 6103515625.0))

    # S_L, S_C, S_H (weighting functions)
    L_bar_p_minus_50_sq = (L_bar_p - 50) ** 2
    S_L = 1 + (0.015 * L_bar_p_minus_50_sq) / jnp.sqrt(20 + L_bar_p_minus_50_sq)
    S_C = 1 + 0.045 * C_bar_p
    S_H = 1 + 0.015 * C_bar_p * T

    # R_T (hue rotation term)
    R_T = -jnp.sin(jnp.radians(2 * delta_theta)) * R_C

    # Final Delta E; parametric factors fixed at the reference conditions.
    k_L, k_C, k_H = 1.0, 1.0, 1.0

    term_L = delta_L_p / (k_L * S_L)
    term_C = delta_C_p / (k_C * S_C)
    term_H = delta_H_p / (k_H * S_H)

    return jnp.sqrt(term_L**2 + term_C**2 + term_H**2 + R_T * term_C * term_H)
231
+
232
+
233
def delta_E_loss(pred_xyY: Array, target_xyY: Array) -> Array:
    """
    Compute mean Delta-E loss between predicted and target xyY values.

    This is the primary loss function for training with perceptual accuracy.

    Parameters
    ----------
    pred_xyY : Array
        Predicted xyY values with shape (batch, 3)
    target_xyY : Array
        Target xyY values with shape (batch, 3)

    Returns
    -------
    Array
        Scalar mean Delta-E loss
    """
    # Map both batches into Lab, then average the per-sample Delta-E values.
    per_sample = delta_E_CIE2000(xyY_to_Lab(pred_xyY), xyY_to_Lab(target_xyY))
    return jnp.mean(per_sample)
254
+
255
+
256
# JIT-compiled versions for performance.
# These are module-level so the XLA compilation cache is shared across callers.
xyY_to_XYZ_jit = jax.jit(xyY_to_XYZ)
XYZ_to_Lab_jit = jax.jit(XYZ_to_Lab)
xyY_to_Lab_jit = jax.jit(xyY_to_Lab)
delta_E_CIE2000_jit = jax.jit(delta_E_CIE2000)
delta_E_loss_jit = jax.jit(delta_E_loss)

# Gradient functions
# Gradient of the scalar loss w.r.t. its first argument (pred_xyY).
grad_delta_E_loss = jax.grad(delta_E_loss)
265
+
266
+
267
def test_jax_delta_e() -> None:
    """
    Test the JAX Delta-E implementation against the colour library.

    Converts two xyY samples to Lab with both the JAX code in this module
    and the ``colour`` reference implementation, asserts they agree, and
    checks that the loss gradient is finite.

    Raises
    ------
    AssertionError
        If the JAX and colour-science results diverge, or the gradient
        contains non-finite values.
    """
    # Test xyY values
    xyY_1 = np.array([0.3127, 0.3290, 0.5])  # D65 white point, Y=0.5
    xyY_2 = np.array([0.35, 0.35, 0.5])  # Slightly shifted

    # Convert using JAX
    Lab_1_jax = xyY_to_Lab(jnp.array(xyY_1))
    Lab_2_jax = xyY_to_Lab(jnp.array(xyY_2))
    delta_E_jax = delta_E_CIE2000(Lab_1_jax, Lab_2_jax)

    # Convert using colour library (Illuminant C, matching this module)
    illuminant_C = colour.CCS_ILLUMINANTS[
        "CIE 1931 2 Degree Standard Observer"
    ]["C"]
    Lab_1_colour = colour.XYZ_to_Lab(colour.xyY_to_XYZ(xyY_1), illuminant_C)
    Lab_2_colour = colour.XYZ_to_Lab(colour.xyY_to_XYZ(xyY_2), illuminant_C)
    delta_E_colour = colour.delta_E(Lab_1_colour, Lab_2_colour, method="CIE 2000")

    # The results were previously computed but never compared; actually
    # assert agreement now. Tolerances are loose enough to absorb the
    # rounded Illuminant C whitepoint constant and float32 math.
    np.testing.assert_allclose(np.asarray(Lab_1_jax), Lab_1_colour, atol=0.1)
    np.testing.assert_allclose(np.asarray(Lab_2_jax), Lab_2_colour, atol=0.1)
    np.testing.assert_allclose(float(delta_E_jax), float(delta_E_colour), atol=0.1)

    # Test gradient computation
    pred_xyY = jnp.array([[0.35, 0.35, 0.5]])
    target_xyY = jnp.array([[0.3127, 0.3290, 0.5]])

    grad_fn = jax.grad(lambda x: delta_E_loss(x, target_xyY))
    gradient = grad_fn(pred_xyY)
    assert bool(jnp.all(jnp.isfinite(gradient))), "non-finite loss gradient"


if __name__ == "__main__":
    test_jax_delta_e()
learning_munsell/models/__init__.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Neural network models for Munsell color conversions."""

# Re-export the public model classes from ``networks``. This list mirrors
# ``learning_munsell.models.networks.__all__``; the ResNet variants
# (ComponentResNet, MultiResNetToMunsell, MultiResNetErrorPredictorToMunsell)
# were previously missing here even though networks.py declares them public.
from learning_munsell.models.networks import (
    ComponentErrorPredictor,
    ComponentMLP,
    ComponentResNet,
    FeatureTokenizer,
    MLPToMunsell,
    MultiHeadErrorPredictorToMunsell,
    MultiHeadMLPToMunsell,
    MultiMLPClassCodeErrorPredictorToMunsell,
    MultiMLPClassCodeToMunsell,
    MultiMLPErrorPredictorToMunsell,
    MultiMLPErrorPredictorToxyY,
    MultiMLPHueAngleToMunsell,
    MultiMLPToMunsell,
    MultiMLPToxyY,
    MultiResNetErrorPredictorToMunsell,
    MultiResNetToMunsell,
    ResidualBlock,
    TransformerBlock,
    TransformerToMunsell,
)

__all__ = [
    # Building blocks
    "ResidualBlock",
    # Component networks (single output)
    "ComponentMLP",
    "ComponentResNet",
    "ComponentErrorPredictor",
    # Transformer building blocks
    "FeatureTokenizer",
    "TransformerBlock",
    # Composite models: xyY → Munsell
    "MLPToMunsell",
    "MultiHeadMLPToMunsell",
    "MultiMLPToMunsell",
    "MultiMLPClassCodeToMunsell",
    "MultiMLPHueAngleToMunsell",
    "MultiResNetToMunsell",
    "TransformerToMunsell",
    # Error predictors: xyY → Munsell
    "MultiHeadErrorPredictorToMunsell",
    "MultiMLPErrorPredictorToMunsell",
    "MultiMLPClassCodeErrorPredictorToMunsell",
    "MultiResNetErrorPredictorToMunsell",
    # Composite models: Munsell → xyY
    "MultiMLPToxyY",
    # Error predictors: Munsell → xyY
    "MultiMLPErrorPredictorToxyY",
]
learning_munsell/models/networks.py ADDED
@@ -0,0 +1,1507 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Reusable neural network building blocks.
3
+
4
+ Provides shared network architectures for training scripts,
5
+ including MLP components and error predictors.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import torch
11
+ from torch import Tensor, nn
12
+
13
# Public API of this module. Keep this list in sync with the package-level
# re-exports in ``learning_munsell/models/__init__.py``.
__all__ = [
    # Building blocks
    "ResidualBlock",
    # Component networks (single output)
    "ComponentMLP",
    "ComponentResNet",
    "ComponentErrorPredictor",
    # Transformer building blocks
    "FeatureTokenizer",
    "TransformerBlock",
    # Composite models: xyY → Munsell
    "MLPToMunsell",
    "MultiHeadMLPToMunsell",
    "MultiMLPToMunsell",
    "MultiMLPClassCodeToMunsell",
    "MultiMLPHueAngleToMunsell",
    "MultiResNetToMunsell",
    "TransformerToMunsell",
    # Error predictors: xyY → Munsell
    "MultiHeadErrorPredictorToMunsell",
    "MultiMLPErrorPredictorToMunsell",
    "MultiMLPClassCodeErrorPredictorToMunsell",
    "MultiResNetErrorPredictorToMunsell",
    # Composite models: Munsell → xyY
    "MultiMLPToxyY",
    # Error predictors: Munsell → xyY
    "MultiMLPErrorPredictorToxyY",
]
41
+
42
+
43
+ # =============================================================================
44
+ # Building Blocks
45
+ # =============================================================================
46
+
47
+
48
class ResidualBlock(nn.Module):
    """
    Residual block with GELU activation and batch normalization.

    Computes ``GELU(x + block(x))`` where ``block`` is::

        Linear -> GELU -> BatchNorm -> Linear -> BatchNorm

    Parameters
    ----------
    dim : int
        Dimension of input and output features.

    Attributes
    ----------
    block : nn.Sequential
        Sequential block with linear layers, GELU, and BatchNorm.
    activation : nn.GELU
        Final activation after residual addition.
    """

    def __init__(self, dim: int) -> None:
        """Initialize residual block."""
        super().__init__()
        # Layer order matters for reproducible parameter initialization.
        inner: list[nn.Module] = [
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.BatchNorm1d(dim),
            nn.Linear(dim, dim),
            nn.BatchNorm1d(dim),
        ]
        self.block = nn.Sequential(*inner)
        self.activation = nn.GELU()

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass with residual connection.

        Parameters
        ----------
        x : Tensor
            Input tensor of shape (batch_size, dim).

        Returns
        -------
        Tensor
            Output tensor of shape (batch_size, dim).
        """
        residual = self.block(x)
        return self.activation(residual + x)
97
+
98
+
99
+ # =============================================================================
100
+ # Component Networks (Single Output)
101
+ # =============================================================================
102
+
103
+
104
class ComponentMLP(nn.Module):
    """
    Independent MLP for a single Munsell component.

    Architecture: input_dim → 128 → 256 → 512 → 256 → 128 → output_dim
    (hidden widths scaled by ``width_multiplier``), with ReLU + BatchNorm
    after every hidden Linear and optional Dropout between stages.

    Parameters
    ----------
    input_dim : int, optional
        Input feature dimension. Default is 3 (for xyY).
    output_dim : int, optional
        Output feature dimension. Default is 1.
    width_multiplier : float, optional
        Multiplier for hidden layer dimensions. Default is 1.0.
    dropout : float, optional
        Dropout probability between layers. Default is 0.0 (disabled).

    Attributes
    ----------
    network : nn.Sequential
        Feed-forward network with encoder-decoder structure.
    """

    def __init__(
        self,
        input_dim: int = 3,
        output_dim: int = 1,
        width_multiplier: float = 1.0,
        dropout: float = 0.0,
    ) -> None:
        """Initialize the component-specific MLP."""
        super().__init__()

        # Scaled hidden widths for the encoder-decoder shape.
        h1 = int(128 * width_multiplier)
        h2 = int(256 * width_multiplier)
        h3 = int(512 * width_multiplier)

        # Encoder expands input_dim → h3, decoder contracts back to h1.
        stage_dims = [
            (input_dim, h1),
            (h1, h2),
            (h2, h3),
            (h3, h2),
            (h2, h1),
        ]

        layers: list[nn.Module] = []
        for index, (d_in, d_out) in enumerate(stage_dims):
            layers += [nn.Linear(d_in, d_out), nn.ReLU(), nn.BatchNorm1d(d_out)]
            # Dropout follows every stage except the last hidden one,
            # i.e. it never sits directly before the output projection.
            if dropout > 0 and index < len(stage_dims) - 1:
                layers.append(nn.Dropout(dropout))

        layers.append(nn.Linear(h1, output_dim))

        self.network = nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass through the component-specific network.

        Parameters
        ----------
        x : Tensor
            Input tensor of shape (batch_size, input_dim).

        Returns
        -------
        Tensor
            Output tensor of shape (batch_size, output_dim) containing the
            predicted component value(s).
        """
        return self.network(x)
221
+
222
+
223
class ComponentResNet(nn.Module):
    """
    Independent ResNet for a single Munsell component with true skip connections.

    Pipeline: ``input -> projection -> ResidualBlock x num_blocks -> output``.
    Each residual block computes ``activation(x + f(x))``, unlike
    ComponentMLP which has no skip connections.

    Parameters
    ----------
    input_dim : int, optional
        Input feature dimension. Default is 3 (for xyY).
    hidden_dim : int, optional
        Hidden dimension for residual blocks. Default is 256.
    num_blocks : int, optional
        Number of residual blocks. Default is 4.

    Attributes
    ----------
    input_proj : nn.Sequential
        Projects input to hidden dimension with GELU activation.
    res_blocks : nn.ModuleList
        List of ResidualBlock modules with skip connections.
    output_proj : nn.Linear
        Projects hidden dimension to single output.
    """

    def __init__(
        self,
        input_dim: int = 3,
        hidden_dim: int = 256,
        num_blocks: int = 4,
    ) -> None:
        """Initialize the component-specific ResNet."""
        super().__init__()

        # Lift input into the residual trunk's width.
        self.input_proj = nn.Sequential(nn.Linear(input_dim, hidden_dim), nn.GELU())

        # Residual trunk: num_blocks identical skip-connected blocks.
        self.res_blocks = nn.ModuleList(
            ResidualBlock(hidden_dim) for _ in range(num_blocks)
        )

        # Scalar regression head.
        self.output_proj = nn.Linear(hidden_dim, 1)

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass through the ResNet with skip connections.

        Parameters
        ----------
        x : Tensor
            Input tensor of shape (batch_size, input_dim).

        Returns
        -------
        Tensor
            Output tensor of shape (batch_size, 1).
        """
        hidden = self.input_proj(x)
        for res_block in self.res_blocks:
            # activation(hidden + f(hidden)) inside each block
            hidden = res_block(hidden)
        return self.output_proj(hidden)
294
+
295
+
296
class ComponentErrorPredictor(nn.Module):
    """
    Independent error predictor for a single Munsell component.

    A deep MLP that learns to predict a residual correction for one Munsell
    component (hue, value, chroma, or code).

    Parameters
    ----------
    input_dim : int, optional
        Input feature dimension. Default is 7 (xyY_norm + base_pred_norm).
    width_multiplier : float, optional
        Multiplier for hidden layer widths. Default is 1.0.
        Use 1.5 for chroma which requires more capacity.

    Attributes
    ----------
    network : nn.Sequential
        Feed-forward network: input → 128 → 256 → 512 → 256 → 128 → 1
        with GELU activations and BatchNorm after each hidden layer.

    Notes
    -----
    Default input is [xyY_norm (3) + base_pred_norm (4)] = 7 features.
    Output is a single scalar error correction for the component.
    """

    def __init__(
        self,
        input_dim: int = 7,
        width_multiplier: float = 1.0,
    ) -> None:
        """Initialize the error predictor."""
        super().__init__()

        # Scaled encoder-decoder widths.
        h1 = int(128 * width_multiplier)
        h2 = int(256 * width_multiplier)
        h3 = int(512 * width_multiplier)

        # Hidden trunk: expand to h3, then contract back down to h1.
        widths = [input_dim, h1, h2, h3, h2, h1]
        layers: list[nn.Module] = []
        for d_in, d_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(d_in, d_out), nn.GELU(), nn.BatchNorm1d(d_out)]

        # Scalar correction output.
        layers.append(nn.Linear(h1, 1))

        self.network = nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass through the error predictor.

        Parameters
        ----------
        x : Tensor
            Combined input of shape (batch_size, input_dim).

        Returns
        -------
        Tensor
            Predicted error correction of shape (batch_size, 1).
        """
        return self.network(x)
373
+
374
+
375
+ # =============================================================================
376
+ # Transformer Building Blocks
377
+ # =============================================================================
378
+
379
+
380
class FeatureTokenizer(nn.Module):
    """
    Tokenize each input feature into high-dimensional embedding.

    Each scalar input feature is mapped through its own linear embedding
    (akin to word embeddings in NLP), and a learnable CLS token used for
    regression output is prepended to the sequence.

    Parameters
    ----------
    num_features : int
        Number of input features to tokenize.
    embedding_dim : int
        Dimensionality of each token embedding.

    Attributes
    ----------
    feature_embeddings : nn.ModuleList
        List of linear layers, one per input feature.
    cls_token : nn.Parameter
        Learnable classification token prepended to feature tokens.
    """

    def __init__(self, num_features: int, embedding_dim: int) -> None:
        """Initialize the feature tokenizer."""
        super().__init__()
        # One independent 1 -> embedding_dim projection per feature.
        self.feature_embeddings = nn.ModuleList(
            nn.Linear(1, embedding_dim) for _ in range(num_features)
        )
        # Learnable CLS token, broadcast over the batch in forward().
        self.cls_token = nn.Parameter(torch.randn(1, 1, embedding_dim))

    def forward(self, x: Tensor) -> Tensor:
        """
        Transform input features into token embeddings.

        Parameters
        ----------
        x : Tensor
            Input tensor of shape (batch_size, num_features).

        Returns
        -------
        Tensor
            Token embeddings of shape (batch_size, 1+num_features, embedding_dim).
            First token is CLS, followed by feature tokens.
        """
        # Embed each scalar feature independently: (batch, 1, embedding_dim) each.
        per_feature = [
            projector(x[:, i : i + 1]).unsqueeze(1)
            for i, projector in enumerate(self.feature_embeddings)
        ]

        # Broadcast the CLS token across the batch and place it first.
        cls = self.cls_token.expand(x.size(0), -1, -1)
        return torch.cat([cls, *per_feature], dim=1)
449
+
450
+
451
+ class TransformerBlock(nn.Module):
452
+ """
453
+ Standard transformer block with multi-head attention and feedforward network.
454
+
455
+ Implements the classic transformer architecture with self-attention,
456
+ feedforward layers, layer normalization, and residual connections.
457
+
458
+ Parameters
459
+ ----------
460
+ embedding_dim : int
461
+ Dimension of token embeddings.
462
+ num_heads : int
463
+ Number of attention heads.
464
+ ff_dim : int
465
+ Hidden dimension of feedforward network.
466
+ dropout : float, optional
467
+ Dropout probability, default is 0.1.
468
+
469
+ Attributes
470
+ ----------
471
+ attention : nn.MultiheadAttention
472
+ Multi-head self-attention mechanism.
473
+ norm1 : nn.LayerNorm
474
+ Layer normalization after attention.
475
+ feedforward : nn.Sequential
476
+ Feedforward network with GELU activation.
477
+ norm2 : nn.LayerNorm
478
+ Layer normalization after feedforward.
479
+ """
480
+
481
+ def __init__(
482
+ self, embedding_dim: int, num_heads: int, ff_dim: int, dropout: float = 0.1
483
+ ) -> None:
484
+ """Initialize the transformer block."""
485
+ super().__init__()
486
+
487
+ self.attention = nn.MultiheadAttention(
488
+ embedding_dim, num_heads, dropout=dropout, batch_first=True
489
+ )
490
+ self.norm1 = nn.LayerNorm(embedding_dim)
491
+
492
+ self.feedforward = nn.Sequential(
493
+ nn.Linear(embedding_dim, ff_dim),
494
+ nn.GELU(),
495
+ nn.Dropout(dropout),
496
+ nn.Linear(ff_dim, embedding_dim),
497
+ nn.Dropout(dropout),
498
+ )
499
+ self.norm2 = nn.LayerNorm(embedding_dim)
500
+
501
+ def forward(self, x: Tensor) -> Tensor:
502
+ """
503
+ Apply transformer block to input tokens.
504
+
505
+ Parameters
506
+ ----------
507
+ x : Tensor
508
+ Input tokens of shape (batch_size, num_tokens, embedding_dim).
509
+
510
+ Returns
511
+ -------
512
+ Tensor
513
+ Transformed tokens of shape (batch_size, num_tokens, embedding_dim).
514
+ """
515
+ # Self-attention with residual
516
+ attn_output, _ = self.attention(x, x, x)
517
+ x = self.norm1(x + attn_output)
518
+
519
+ # Feedforward with residual
520
+ ff_output = self.feedforward(x)
521
+ return self.norm2(x + ff_output)
522
+
523
+
524
+ # =============================================================================
525
+ # Composite Models: xyY → Munsell
526
+ # =============================================================================
527
+
528
+
529
class MLPToMunsell(nn.Module):
    """
    Large MLP for xyY to Munsell conversion.

    Architecture: 3 → 128 → 256 → 512 → 512 → 256 → 128 → 4,
    with ReLU + BatchNorm after every hidden Linear.

    Attributes
    ----------
    network : nn.Sequential
        Feed-forward network with ReLU activations and BatchNorm.
    """

    def __init__(self) -> None:
        """Initialize the MunsellMLP network."""
        super().__init__()

        # Expand-then-contract hidden widths; the final Linear maps to the
        # four Munsell outputs without activation.
        widths = [3, 128, 256, 512, 512, 256, 128]
        layers: list[nn.Module] = []
        for d_in, d_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(d_in, d_out), nn.ReLU(), nn.BatchNorm1d(d_out)]
        layers.append(nn.Linear(128, 4))

        self.network = nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass through the network.

        Parameters
        ----------
        x : Tensor
            Input tensor of shape (batch_size, 3) containing normalized xyY values.

        Returns
        -------
        Tensor
            Output tensor of shape (batch_size, 4) containing normalized Munsell
            specifications [hue, value, chroma, code].
        """
        return self.network(x)
583
+
584
+
585
class MultiHeadMLPToMunsell(nn.Module):
    """
    Multi-head MLP for xyY to Munsell conversion.

    A shared encoder (3 → 128 → 256 → 512) feeds four specialized decoder
    heads, one per Munsell component (hue, value, chroma, code). The chroma
    head is wider (extra 384-unit layer) to handle the more complex
    non-linear relationship between xyY and chroma.

    Attributes
    ----------
    encoder : nn.Sequential
        Shared encoder: 3 → 128 → 256 → 512 with ReLU and BatchNorm.
    hue_head : nn.Sequential
        Hue decoder: 512 → 256 → 128 → 1 (circular component).
    value_head : nn.Sequential
        Value decoder: 512 → 256 → 128 → 1 (linear component).
    chroma_head : nn.Sequential
        Chroma decoder: 512 → 384 → 256 → 128 → 1 (wider for complexity).
    code_head : nn.Sequential
        Code decoder: 512 → 256 → 128 → 1 (discrete component).
    """

    @staticmethod
    def _stack(widths: list[int], *, bare_last: bool) -> nn.Sequential:
        """Build Linear/ReLU/BatchNorm stages; the last Linear is left bare
        (no activation/norm) when ``bare_last`` is True."""
        layers: list[nn.Module] = []
        last = len(widths) - 2
        for index, (d_in, d_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(nn.Linear(d_in, d_out))
            if not (bare_last and index == last):
                layers += [nn.ReLU(), nn.BatchNorm1d(d_out)]
        return nn.Sequential(*layers)

    def __init__(self) -> None:
        """Initialize the multi-head MLP model."""
        super().__init__()

        # Shared encoder - learns general color space features.
        self.encoder = self._stack([3, 128, 256, 512], bare_last=False)

        # Hue head - circular/angular component.
        self.hue_head = self._stack([512, 256, 128, 1], bare_last=True)

        # Value head - linear lightness.
        self.value_head = self._stack([512, 256, 128, 1], bare_last=True)

        # Chroma head - non-linear saturation (wider for the harder task).
        self.chroma_head = self._stack([512, 384, 256, 128, 1], bare_last=True)

        # Code head - discrete categorical.
        self.code_head = self._stack([512, 256, 128, 1], bare_last=True)

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass through the multi-head network.

        Parameters
        ----------
        x : Tensor
            Input xyY values of shape (batch_size, 3).

        Returns
        -------
        Tensor
            Concatenated Munsell predictions [hue, value, chroma, code]
            of shape (batch_size, 4).
        """
        shared = self.encoder(x)
        components = [
            self.hue_head(shared),
            self.value_head(shared),
            self.chroma_head(shared),
            self.code_head(shared),
        ]
        return torch.cat(components, dim=1)
702
+
703
+
704
class MultiMLPToMunsell(nn.Module):
    """
    Four independent ComponentMLP branches for xyY to Munsell conversion.

    Each Munsell component (hue, value, chroma, code) is predicted by its
    own MLP over the same xyY input; the chroma branch can be widened since
    its mapping from xyY is the most non-linear.

    Parameters
    ----------
    chroma_width_multiplier : float, optional
        Width multiplier for the chroma branch. Default is 2.0.
    dropout : float, optional
        Dropout probability shared by all branches. Default is 0.1.

    Attributes
    ----------
    hue_branch : ComponentMLP
        Standard-width (1.0x) branch for hue.
    value_branch : ComponentMLP
        Standard-width (1.0x) branch for value.
    chroma_branch : ComponentMLP
        Branch widened by ``chroma_width_multiplier``.
    code_branch : ComponentMLP
        Standard-width (1.0x) branch for the hue code.
    """

    def __init__(
        self, chroma_width_multiplier: float = 2.0, dropout: float = 0.1
    ) -> None:
        """Create one ComponentMLP branch per Munsell component."""
        super().__init__()

        def make_branch(width: float) -> ComponentMLP:
            # Every branch consumes the same 3-dimensional xyY input.
            return ComponentMLP(
                input_dim=3, width_multiplier=width, dropout=dropout
            )

        self.hue_branch = make_branch(1.0)
        self.value_branch = make_branch(1.0)
        self.chroma_branch = make_branch(chroma_width_multiplier)
        self.code_branch = make_branch(1.0)

    def forward(self, x: Tensor) -> Tensor:
        """
        Run every branch on the same input and concatenate the results.

        Parameters
        ----------
        x : Tensor
            Normalized xyY input of shape (batch_size, 3).

        Returns
        -------
        Tensor
            Concatenated [hue, value, chroma, code] predictions of shape
            (batch_size, 4).
        """
        branches = (
            self.hue_branch,
            self.value_branch,
            self.chroma_branch,
            self.code_branch,
        )
        return torch.cat([branch(x) for branch in branches], dim=1)
769
+
770
+
771
class MultiMLPClassCodeToMunsell(nn.Module):
    """
    Multi-MLP for xyY to Munsell conversion with a classification code head.

    Three regression branches (hue, value, chroma) mirror
    :class:`MultiMLPToMunsell`, while the code branch emits 10 logits for
    the hue family code (1-10). Classifying the code avoids the off-by-one
    rounding errors that continuous regression produces at hue family
    boundaries.

    Parameters
    ----------
    chroma_width_multiplier : float, optional
        Width multiplier for the chroma branch. Default is 2.0.
    dropout : float, optional
        Dropout probability shared by all branches. Default is 0.1.

    Attributes
    ----------
    hue_branch : ComponentMLP
        Regression branch for hue (1.0x width, 1 output).
    value_branch : ComponentMLP
        Regression branch for value (1.0x width, 1 output).
    chroma_branch : ComponentMLP
        Regression branch for chroma (configurable width, 1 output).
    code_branch : ComponentMLP
        Classification branch for the hue code (1.0x width, 10 logits).
    """

    def __init__(
        self, chroma_width_multiplier: float = 2.0, dropout: float = 0.1
    ) -> None:
        """Create the regression branches and the classification code head."""
        super().__init__()

        def make_branch(output_dim: int, width: float) -> ComponentMLP:
            return ComponentMLP(
                input_dim=3,
                output_dim=output_dim,
                width_multiplier=width,
                dropout=dropout,
            )

        self.hue_branch = make_branch(1, 1.0)
        self.value_branch = make_branch(1, 1.0)
        self.chroma_branch = make_branch(1, chroma_width_multiplier)
        # 10 logits, one per hue family, instead of a continuous code value.
        self.code_branch = make_branch(10, 1.0)

    def forward(self, x: Tensor) -> Tensor:
        """
        Run all branches and concatenate regressions with code logits.

        Parameters
        ----------
        x : Tensor
            Normalized xyY input of shape (batch_size, 3).

        Returns
        -------
        Tensor
            Concatenated [hue, value, chroma, code_logit_0, ...,
            code_logit_9] of shape (batch_size, 13).
        """
        regressions = [
            self.hue_branch(x),
            self.value_branch(x),
            self.chroma_branch(x),
        ]
        code_logits = self.code_branch(x)
        return torch.cat([*regressions, code_logits], dim=1)
844
+
845
+
846
class MultiMLPHueAngleToMunsell(nn.Module):
    """
    Multi-MLP for xyY to Munsell conversion with sin/cos hue angle encoding.

    The separate hue and code branches are replaced by a single branch
    predicting sin and cos of the Munsell angle
    (``angle = hue + (code - 1) * 10``). Discrete code prediction is thus
    eliminated: the full circular hue information lives in a continuous,
    wrap-safe representation.

    Parameters
    ----------
    chroma_width_multiplier : float, optional
        Width multiplier for the chroma branch. Default is 2.0.
    dropout : float, optional
        Dropout probability shared by all branches. Default is 0.1.

    Attributes
    ----------
    hue_angle_branch : ComponentMLP
        Branch producing [sin(angle), cos(angle)] (1.0x width, 2 outputs).
    value_branch : ComponentMLP
        Branch for value (1.0x width, 1 output).
    chroma_branch : ComponentMLP
        Branch for chroma (configurable width, 1 output).
    """

    def __init__(
        self, chroma_width_multiplier: float = 2.0, dropout: float = 0.1
    ) -> None:
        """Create the hue-angle, value and chroma branches."""
        super().__init__()

        def make_branch(output_dim: int, width: float) -> ComponentMLP:
            return ComponentMLP(
                input_dim=3,
                output_dim=output_dim,
                width_multiplier=width,
                dropout=dropout,
            )

        # Two outputs: sin and cos of the Munsell hue angle.
        self.hue_angle_branch = make_branch(2, 1.0)
        self.value_branch = make_branch(1, 1.0)
        self.chroma_branch = make_branch(1, chroma_width_multiplier)

    def forward(self, x: Tensor) -> Tensor:
        """
        Run all branches and concatenate their outputs.

        Parameters
        ----------
        x : Tensor
            Normalized xyY input of shape (batch_size, 3).

        Returns
        -------
        Tensor
            Concatenated [sin_angle, cos_angle, value, chroma] predictions
            of shape (batch_size, 4).
        """
        angle_terms = self.hue_angle_branch(x)
        value_term = self.value_branch(x)
        chroma_term = self.chroma_branch(x)
        return torch.cat([angle_terms, value_term, chroma_term], dim=1)
911
+
912
+
913
class MultiResNetToMunsell(nn.Module):
    """
    Multi-ResNet for xyY to Munsell conversion with true skip connections.

    Four independent ComponentResNet branches, one per Munsell component,
    each built from residual blocks with genuine skip connections. The
    chroma branch gets a larger hidden dimension.

    Parameters
    ----------
    hidden_dim : int, optional
        Hidden dimension of the residual blocks. Default is 256.
    num_blocks : int, optional
        Number of residual blocks per branch. Default is 4.
    chroma_hidden_dim : int, optional
        Hidden dimension for the chroma branch. Default is 512.

    Attributes
    ----------
    hue_branch : ComponentResNet
        ResNet branch for hue.
    value_branch : ComponentResNet
        ResNet branch for value.
    chroma_branch : ComponentResNet
        ResNet branch for chroma (larger hidden dimension).
    code_branch : ComponentResNet
        ResNet branch for the hue code.
    """

    def __init__(
        self,
        hidden_dim: int = 256,
        num_blocks: int = 4,
        chroma_hidden_dim: int = 512,
    ) -> None:
        """Create one ComponentResNet branch per Munsell component."""
        super().__init__()

        def make_branch(width: int) -> ComponentResNet:
            return ComponentResNet(
                input_dim=3, hidden_dim=width, num_blocks=num_blocks
            )

        self.hue_branch = make_branch(hidden_dim)
        self.value_branch = make_branch(hidden_dim)
        # Chroma gets extra capacity via a larger hidden dimension.
        self.chroma_branch = make_branch(chroma_hidden_dim)
        self.code_branch = make_branch(hidden_dim)

    def forward(self, x: Tensor) -> Tensor:
        """
        Run every ResNet branch on the same input and concatenate.

        Parameters
        ----------
        x : Tensor
            Normalized xyY input of shape (batch_size, 3).

        Returns
        -------
        Tensor
            Concatenated [hue, value, chroma, code] predictions of shape
            (batch_size, 4).
        """
        branches = (
            self.hue_branch,
            self.value_branch,
            self.chroma_branch,
            self.code_branch,
        )
        return torch.cat([branch(x) for branch in branches], dim=1)
983
+
984
+
985
class TransformerToMunsell(nn.Module):
    """
    Transformer for xyY to Munsell conversion.

    A feature tokenizer converts the 3 xyY features into 3 tokens plus a
    CLS token; a stack of transformer blocks applies self-attention; the
    normalized CLS representation then drives four component-specific
    output heads. The chroma head is deeper and wider because chroma is
    the hardest component to predict.

    Parameters
    ----------
    num_features : int, optional
        Number of input features. Default is 3 (xyY).
    embedding_dim : int, optional
        Token embedding dimension. Default is 256.
    num_blocks : int, optional
        Number of transformer blocks. Default is 6.
    num_heads : int, optional
        Number of attention heads. Default is 8.
    ff_dim : int, optional
        Feedforward hidden dimension. Default is 1024.
    dropout : float, optional
        Dropout probability. Default is 0.1.

    Attributes
    ----------
    tokenizer : FeatureTokenizer
        Maps input features to token embeddings with a CLS token.
    transformer_blocks : nn.ModuleList
        Stack of self-attention transformer blocks.
    final_norm : nn.LayerNorm
        Layer normalization applied before the output heads.
    hue_head : nn.Sequential
        Output head for hue.
    value_head : nn.Sequential
        Output head for value.
    chroma_head : nn.Sequential
        Deeper/wider output head for chroma.
    code_head : nn.Sequential
        Output head for the hue code.
    """

    def __init__(
        self,
        num_features: int = 3,
        embedding_dim: int = 256,
        num_blocks: int = 6,
        num_heads: int = 8,
        ff_dim: int = 1024,
        dropout: float = 0.1,
    ) -> None:
        """Initialize the tokenizer, transformer stack and output heads."""
        super().__init__()

        self.tokenizer = FeatureTokenizer(num_features, embedding_dim)

        self.transformer_blocks = nn.ModuleList(
            TransformerBlock(embedding_dim, num_heads, ff_dim, dropout)
            for _ in range(num_blocks)
        )

        self.final_norm = nn.LayerNorm(embedding_dim)

        def simple_head() -> nn.Sequential:
            # Shallow head shared by the hue, value and code components.
            return nn.Sequential(
                nn.Linear(embedding_dim, 128),
                nn.GELU(),
                nn.Dropout(dropout),
                nn.Linear(128, 1),
            )

        self.hue_head = simple_head()
        self.value_head = simple_head()
        # Extra depth/width: chroma is the most difficult component.
        self.chroma_head = nn.Sequential(
            nn.Linear(embedding_dim, 256),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(256, 128),
            nn.GELU(),
            nn.Linear(128, 1),
        )
        self.code_head = simple_head()

    def forward(self, x: Tensor) -> Tensor:
        """
        Predict a Munsell specification from a batch of xyY inputs.

        The CLS token representation (after the final normalization) feeds
        the four task-specific heads.

        Parameters
        ----------
        x : Tensor
            Normalized xyY input of shape (batch_size, 3).

        Returns
        -------
        Tensor
            Concatenated [hue, value, chroma, code] predictions of shape
            (batch_size, 4).
        """
        tokens = self.tokenizer(x)

        for block in self.transformer_blocks:
            tokens = block(tokens)

        cls_token = self.final_norm(tokens)[:, 0, :]

        heads = (
            self.hue_head,
            self.value_head,
            self.chroma_head,
            self.code_head,
        )
        return torch.cat([head(cls_token) for head in heads], dim=1)
1117
+
1118
+
1119
+ # =============================================================================
1120
+ # Error Predictors: xyY → Munsell
1121
+ # =============================================================================
1122
+
1123
+
1124
class MultiHeadErrorPredictorToMunsell(nn.Module):
    """
    Multi-head error predictor for xyY to Munsell conversion.

    Four ComponentErrorPredictor branches, each estimating the residual
    error of one Munsell component from the same combined input. The
    chroma branch is wider (1.5x by default) because chroma error patterns
    are the most complex.

    Parameters
    ----------
    input_dim : int, optional
        Input feature dimension. Default is 7.
    chroma_width : float, optional
        Width multiplier for the chroma branch. Default is 1.5.

    Attributes
    ----------
    hue_branch : ComponentErrorPredictor
        Error predictor for hue (1.0x width).
    value_branch : ComponentErrorPredictor
        Error predictor for value (1.0x width).
    chroma_branch : ComponentErrorPredictor
        Error predictor for chroma (``chroma_width`` multiplier).
    code_branch : ComponentErrorPredictor
        Error predictor for the hue code (1.0x width).
    """

    def __init__(
        self,
        input_dim: int = 7,
        chroma_width: float = 1.5,
    ) -> None:
        """Create one independent error predictor per component."""
        super().__init__()

        def make_branch(width: float) -> ComponentErrorPredictor:
            return ComponentErrorPredictor(
                input_dim=input_dim, width_multiplier=width
            )

        self.hue_branch = make_branch(1.0)
        self.value_branch = make_branch(1.0)
        self.chroma_branch = make_branch(chroma_width)
        self.code_branch = make_branch(1.0)

    def forward(self, x: Tensor) -> Tensor:
        """
        Run every branch on the same combined input and concatenate.

        Parameters
        ----------
        x : Tensor
            Combined input of shape (batch_size, input_dim).

        Returns
        -------
        Tensor
            Concatenated error corrections [hue, value, chroma, code]
            of shape (batch_size, 4).
        """
        branches = (
            self.hue_branch,
            self.value_branch,
            self.chroma_branch,
            self.code_branch,
        )
        return torch.cat([branch(x) for branch in branches], dim=1)
1196
+
1197
+
1198
class MultiMLPErrorPredictorToMunsell(nn.Module):
    """
    Multi-MLP error predictor for xyY to Munsell conversion.

    Four independent ComponentErrorPredictor branches, one per Munsell
    component error, all fed the concatenated [xyY_norm, base_pred_norm]
    input (7 features).

    Parameters
    ----------
    chroma_width : float, optional
        Width multiplier for the chroma branch. Default is 1.5.

    Attributes
    ----------
    hue_branch : ComponentErrorPredictor
        Error predictor for hue (1.0x width).
    value_branch : ComponentErrorPredictor
        Error predictor for value (1.0x width).
    chroma_branch : ComponentErrorPredictor
        Error predictor for chroma (``chroma_width`` multiplier).
    code_branch : ComponentErrorPredictor
        Error predictor for the hue code (1.0x width).
    """

    def __init__(self, chroma_width: float = 1.5) -> None:
        """Create one independent error predictor per component."""
        super().__init__()

        def make_branch(width: float) -> ComponentErrorPredictor:
            # Input: xyY (3) + base Munsell prediction (4) = 7 features.
            return ComponentErrorPredictor(input_dim=7, width_multiplier=width)

        self.hue_branch = make_branch(1.0)
        self.value_branch = make_branch(1.0)
        self.chroma_branch = make_branch(chroma_width)
        self.code_branch = make_branch(1.0)

    def forward(self, x: Tensor) -> Tensor:
        """
        Run every branch on the same combined input and concatenate.

        Parameters
        ----------
        x : Tensor
            Combined input [xyY_norm, base_pred_norm] of shape
            (batch_size, 7).

        Returns
        -------
        Tensor
            Concatenated error corrections [hue, value, chroma, code]
            of shape (batch_size, 4).
        """
        branches = (
            self.hue_branch,
            self.value_branch,
            self.chroma_branch,
            self.code_branch,
        )
        return torch.cat([branch(x) for branch in branches], dim=1)
1253
+
1254
+
1255
class MultiMLPClassCodeErrorPredictorToMunsell(nn.Module):
    """
    Multi-MLP error predictor for the Classification Code base model.

    Three independent ComponentErrorPredictor branches cover the regression
    outputs (hue, value, chroma). No code branch is present: the base
    model's classification head already achieves perfect code accuracy.

    Parameters
    ----------
    input_dim : int, optional
        Input feature dimension per branch. Default is 6
        (xyY_norm + regression_pred_norm). Use 16 for the code-aware
        variant (xyY_norm + regression_pred_norm + code_onehot).
    chroma_width : float, optional
        Width multiplier for the chroma branch. Default is 1.5.

    Attributes
    ----------
    hue_branch : ComponentErrorPredictor
        Error predictor for hue (1.0x width).
    value_branch : ComponentErrorPredictor
        Error predictor for value (1.0x width).
    chroma_branch : ComponentErrorPredictor
        Error predictor for chroma (``chroma_width`` multiplier).
    """

    def __init__(self, input_dim: int = 6, chroma_width: float = 1.5) -> None:
        """Create the three regression-error branches."""
        super().__init__()

        def make_branch(width: float) -> ComponentErrorPredictor:
            return ComponentErrorPredictor(
                input_dim=input_dim, width_multiplier=width
            )

        self.hue_branch = make_branch(1.0)
        self.value_branch = make_branch(1.0)
        self.chroma_branch = make_branch(chroma_width)

    def forward(self, x: Tensor) -> Tensor:
        """
        Run every branch on the same combined input and concatenate.

        Parameters
        ----------
        x : Tensor
            Combined input of shape (batch_size, input_dim).

        Returns
        -------
        Tensor
            Concatenated error corrections [hue, value, chroma]
            of shape (batch_size, 3).
        """
        branches = (self.hue_branch, self.value_branch, self.chroma_branch)
        return torch.cat([branch(x) for branch in branches], dim=1)
1315
+
1316
+
1317
class MultiResNetErrorPredictorToMunsell(nn.Module):
    """
    Multi-ResNet error predictor for xyY to Munsell conversion.

    Four independent ComponentResNet branches (true skip connections), one
    per Munsell component error, fed the concatenated
    [xyY_norm, base_pred_norm] input (7 features).

    Parameters
    ----------
    hidden_dim : int, optional
        Hidden dimension of the residual blocks. Default is 256.
    num_blocks : int, optional
        Number of residual blocks per branch. Default is 4.
    chroma_hidden_dim : int, optional
        Hidden dimension for the chroma branch. Default is 384.

    Attributes
    ----------
    hue_branch : ComponentResNet
        ResNet error predictor for hue.
    value_branch : ComponentResNet
        ResNet error predictor for value.
    chroma_branch : ComponentResNet
        ResNet error predictor for chroma (larger hidden dimension).
    code_branch : ComponentResNet
        ResNet error predictor for the hue code.
    """

    def __init__(
        self,
        hidden_dim: int = 256,
        num_blocks: int = 4,
        chroma_hidden_dim: int = 384,
    ) -> None:
        """Create one ResNet error-predictor branch per component."""
        super().__init__()

        def make_branch(width: int) -> ComponentResNet:
            # Input: xyY (3) + base Munsell prediction (4) = 7 features.
            return ComponentResNet(
                input_dim=7, hidden_dim=width, num_blocks=num_blocks
            )

        self.hue_branch = make_branch(hidden_dim)
        self.value_branch = make_branch(hidden_dim)
        self.chroma_branch = make_branch(chroma_hidden_dim)
        self.code_branch = make_branch(hidden_dim)

    def forward(self, x: Tensor) -> Tensor:
        """
        Run every branch on the same combined input and concatenate.

        Parameters
        ----------
        x : Tensor
            Combined input [xyY_norm, base_pred_norm] of shape
            (batch_size, 7).

        Returns
        -------
        Tensor
            Concatenated error corrections [hue, value, chroma, code]
            of shape (batch_size, 4).
        """
        branches = (
            self.hue_branch,
            self.value_branch,
            self.chroma_branch,
            self.code_branch,
        )
        return torch.cat([branch(x) for branch in branches], dim=1)
1388
+
1389
+
1390
+ # =============================================================================
1391
+ # Composite Models: Munsell → xyY
1392
+ # =============================================================================
1393
+
1394
+
1395
class MultiMLPToxyY(nn.Module):
    """
    Multi-MLP for Munsell to xyY conversion.

    Three independent ComponentMLP branches, one per xyY component, all
    consuming the 4-dimensional normalized Munsell specification. The Y
    (luminance) branch may be widened independently.

    Parameters
    ----------
    width_multiplier : float, optional
        Width multiplier for the x and y branches. Default is 1.0.
    y_width_multiplier : float, optional
        Width multiplier for the Y (luminance) branch. Default is 1.25.

    Attributes
    ----------
    x_branch : ComponentMLP
        Branch for the x chromaticity coordinate.
    y_branch : ComponentMLP
        Branch for the y chromaticity coordinate.
    Y_branch : ComponentMLP
        Branch for the Y luminance.
    """

    def __init__(
        self, width_multiplier: float = 1.0, y_width_multiplier: float = 1.25
    ) -> None:
        """Create one ComponentMLP branch per xyY component."""
        super().__init__()

        def make_branch(width: float) -> ComponentMLP:
            # Input is the 4-dimensional Munsell specification.
            return ComponentMLP(input_dim=4, width_multiplier=width)

        self.x_branch = make_branch(width_multiplier)
        self.y_branch = make_branch(width_multiplier)
        self.Y_branch = make_branch(y_width_multiplier)

    def forward(self, munsell: Tensor) -> Tensor:
        """
        Run every branch on the same Munsell input and concatenate.

        Parameters
        ----------
        munsell : Tensor
            Normalized Munsell specification [hue, value, chroma, code]
            of shape (batch_size, 4).

        Returns
        -------
        Tensor
            Predicted xyY values [x, y, Y] of shape (batch_size, 3).
        """
        branches = (self.x_branch, self.y_branch, self.Y_branch)
        return torch.cat([branch(munsell) for branch in branches], dim=1)
1447
+
1448
+
1449
+ # =============================================================================
1450
+ # Error Predictors: Munsell → xyY
1451
+ # =============================================================================
1452
+
1453
+
1454
class MultiMLPErrorPredictorToxyY(nn.Module):
    """
    Multi-MLP error predictor for Munsell to xyY conversion.

    Three independent ComponentErrorPredictor branches, one per xyY
    component error, all fed the concatenated [munsell_norm, base_pred]
    input (7 features).

    Parameters
    ----------
    width_multiplier : float, optional
        Width multiplier applied to all branches. Default is 1.0.

    Attributes
    ----------
    x_branch : ComponentErrorPredictor
        Error predictor for the x chromaticity coordinate.
    y_branch : ComponentErrorPredictor
        Error predictor for the y chromaticity coordinate.
    Y_branch : ComponentErrorPredictor
        Error predictor for the Y luminance.
    """

    def __init__(self, width_multiplier: float = 1.0) -> None:
        """Create one error-predictor branch per xyY component."""
        super().__init__()

        def make_branch() -> ComponentErrorPredictor:
            # Input: Munsell (4) + base xyY prediction (3) = 7 features.
            return ComponentErrorPredictor(
                input_dim=7, width_multiplier=width_multiplier
            )

        self.x_branch = make_branch()
        self.y_branch = make_branch()
        self.Y_branch = make_branch()

    def forward(self, combined_input: Tensor) -> Tensor:
        """
        Run every branch on the same combined input and concatenate.

        Parameters
        ----------
        combined_input : Tensor
            Combined input [munsell_norm, base_pred] of shape
            (batch_size, 7).

        Returns
        -------
        Tensor
            Concatenated error corrections [x, y, Y] of shape
            (batch_size, 3).
        """
        branches = (self.x_branch, self.y_branch, self.Y_branch)
        return torch.cat(
            [branch(combined_input) for branch in branches], dim=1
        )
learning_munsell/training/from_xyY/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Training scripts for xyY to Munsell conversion."""
learning_munsell/training/from_xyY/hyperparameter_search_error_predictor.py ADDED
@@ -0,0 +1,506 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Hyperparameter search for Multi-Error Predictor using Optuna.
3
+
4
+ Optimizes:
5
+ - Learning rate
6
+ - Batch size
7
+ - Chroma width multiplier
8
+ - Loss function weights (MSE, MAE, log penalty, Huber)
9
+ - Huber delta
10
+ - Dropout
11
+
12
+ Objective: Minimize validation loss
13
+ """
14
+
15
+ import logging
16
+ from datetime import datetime
17
+ from pathlib import Path
18
+
19
+ import mlflow
20
+ import numpy as np
21
+ import onnxruntime as ort
22
+ import optuna
23
+ import torch
24
+ from optuna.trial import Trial
25
+ from torch import nn, optim
26
+ from torch.utils.data import DataLoader, TensorDataset
27
+
28
+ from learning_munsell import PROJECT_ROOT
29
+ from learning_munsell.models.networks import (
30
+ MultiMLPErrorPredictorToMunsell,
31
+ )
32
+ from learning_munsell.utilities.common import setup_mlflow_experiment
33
+ from learning_munsell.utilities.data import normalize_munsell, normalize_xyY
34
+
35
+ LOGGER = logging.getLogger(__name__)
36
+
37
+
38
def precision_focused_loss(
    pred: torch.Tensor,
    target: torch.Tensor,
    mse_weight: float = 1.0,
    mae_weight: float = 0.5,
    log_weight: float = 0.3,
    huber_weight: float = 0.5,
    huber_delta: float = 0.01,
) -> torch.Tensor:
    """
    Precision-focused loss function with configurable weights.

    Combines multiple loss components to encourage accurate error
    prediction:

    - MSE: standard mean squared error.
    - MAE: mean absolute error for robustness.
    - Log penalty: ``log1p(|error| * 1000)`` penalizes small errors
      relatively heavily.
    - Huber loss: robust to outliers with adjustable delta.

    Parameters
    ----------
    pred : torch.Tensor
        Predicted values, shape (batch_size, n_components).
    target : torch.Tensor
        Target values, shape (batch_size, n_components).
    mse_weight : float, optional
        Weight for the MSE component. Default is 1.0.
    mae_weight : float, optional
        Weight for the MAE component. Default is 0.5.
    log_weight : float, optional
        Weight for the logarithmic penalty component. Default is 0.3.
    huber_weight : float, optional
        Weight for the Huber loss component. Default is 0.5.
    huber_delta : float, optional
        Delta parameter for the Huber loss transition point.
        Default is 0.01.

    Returns
    -------
    torch.Tensor
        Weighted combination of the loss components, scalar tensor.
    """
    # Compute the residual and its magnitude once; the previous version
    # recomputed ``pred - target`` / ``abs`` for every loss term.
    residual = pred - target
    abs_error = residual.abs()

    mse = torch.mean(residual**2)
    mae = torch.mean(abs_error)
    log_penalty = torch.mean(torch.log1p(abs_error * 1000.0))

    # Piecewise Huber: quadratic below delta, linear above.
    huber = torch.where(
        abs_error <= huber_delta,
        0.5 * abs_error**2,
        huber_delta * (abs_error - 0.5 * huber_delta),
    )
    huber_loss = torch.mean(huber)

    return (
        mse_weight * mse
        + mae_weight * mae
        + log_weight * log_penalty
        + huber_weight * huber_loss
    )
96
+ )
97
+
98
+
99
def load_base_model(
    model_path: Path, params_path: Path
) -> tuple[ort.InferenceSession, dict, dict]:
    """
    Load the base ONNX model together with its normalization parameters.

    Parameters
    ----------
    model_path : Path
        Path to the base model ONNX file.
    params_path : Path
        Path to the normalization parameters NPZ file.

    Returns
    -------
    ort.InferenceSession
        ONNX Runtime inference session for the base model.
    dict
        Input normalization parameters (x_range, y_range, Y_range).
    dict
        Output normalization parameters (hue_range, value_range,
        chroma_range, code_range).
    """
    inference_session = ort.InferenceSession(str(model_path))

    # Each parameter dict is stored as a 0-d object array; .item() unwraps it.
    archive = np.load(params_path, allow_pickle=True)
    input_params = archive["input_parameters"].item()
    output_params = archive["output_parameters"].item()

    return inference_session, input_params, output_params
129
+
130
+
131
def train_epoch(
    model: nn.Module,
    dataloader: DataLoader,
    optimizer: optim.Optimizer,
    device: torch.device,
    loss_params: dict[str, float],
) -> float:
    """
    Run a single training epoch and return the mean batch loss.

    Parameters
    ----------
    model : nn.Module
        Error predictor model to train.
    dataloader : DataLoader
        DataLoader providing training batches.
    optimizer : optim.Optimizer
        Optimizer for updating model parameters.
    device : torch.device
        Device to run training on (CPU, CUDA, or MPS).
    loss_params : dict of str to float
        Keyword arguments forwarded to ``precision_focused_loss``.

    Returns
    -------
    float
        Average training loss over all batches in the epoch.
    """
    model.train()
    running_loss = 0.0

    for inputs, targets in dataloader:
        inputs = inputs.to(device)  # noqa: PLW2901
        targets = targets.to(device)  # noqa: PLW2901

        loss = precision_focused_loss(model(inputs), targets, **loss_params)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    return running_loss / len(dataloader)
175
+
176
+
177
def validate(
    model: nn.Module,
    dataloader: DataLoader,
    device: torch.device,
    loss_params: dict[str, float],
) -> float:
    """
    Evaluate the model on the validation set and return the mean batch loss.

    Parameters
    ----------
    model : nn.Module
        Error predictor model to validate.
    dataloader : DataLoader
        DataLoader providing validation batches.
    device : torch.device
        Device to run validation on (CPU, CUDA, or MPS).
    loss_params : dict of str to float
        Keyword arguments forwarded to ``precision_focused_loss``.

    Returns
    -------
    float
        Average validation loss over all batches.
    """
    model.eval()
    running_loss = 0.0

    # No gradients needed during evaluation.
    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs = inputs.to(device)  # noqa: PLW2901
            targets = targets.to(device)  # noqa: PLW2901

            loss = precision_focused_loss(model(inputs), targets, **loss_params)
            running_loss += loss.item()

    return running_loss / len(dataloader)
215
+
216
+
217
def objective(trial: Trial) -> float:
    """
    Optuna objective function to minimize validation loss.

    This function defines the hyperparameter search space and training
    procedure for each trial. It optimizes:
    - Learning rate (5e-4 to 1e-3, log scale)
    - Batch size (512 or 1024)
    - Chroma branch width multiplier (1.0 to 1.5)
    - Dropout rate (0.1 to 0.2)
    - Loss function weights (MSE, Huber)
    - Huber delta parameter (0.01 to 0.05)

    The model is trained to predict the residual (normalized target minus
    base-model prediction), so its output can later correct the base model.

    Parameters
    ----------
    trial : Trial
        Optuna trial object for suggesting hyperparameters.

    Returns
    -------
    float
        Best validation loss achieved during training.

    Raises
    ------
    FileNotFoundError
        If base model or training data files are not found.
    optuna.TrialPruned
        If trial is pruned based on intermediate results.
    """

    # Hyperparameters to optimize - constrained based on Trial 0 insights
    lr = trial.suggest_float("lr", 5e-4, 1e-3, log=True)  # Higher LR worked well
    batch_size = trial.suggest_categorical(
        "batch_size", [512, 1024]
    )  # Smaller batches better
    chroma_width = trial.suggest_float(
        "chroma_width", 1.0, 1.5, step=0.25
    )  # Smaller worked
    dropout = trial.suggest_float("dropout", 0.1, 0.2, step=0.05)

    # Simplified loss - just MSE + optional small Huber (no log penalty!)
    mse_weight = trial.suggest_float("mse_weight", 1.0, 2.0, step=0.25)
    huber_weight = trial.suggest_float("huber_weight", 0.0, 0.5, step=0.25)
    huber_delta = trial.suggest_float("huber_delta", 0.01, 0.05, step=0.01)

    loss_params = {
        "mse_weight": mse_weight,
        "mae_weight": 0.0,  # Fixed at 0
        "log_weight": 0.0,  # Fixed at 0 (was causing scale issues)
        "huber_weight": huber_weight,
        "huber_delta": huber_delta,
    }

    LOGGER.info("")
    LOGGER.info("=" * 80)
    LOGGER.info("Trial %d", trial.number)
    LOGGER.info("=" * 80)
    LOGGER.info("  lr: %.6f", lr)
    LOGGER.info("  batch_size: %d", batch_size)
    LOGGER.info("  chroma_width: %.2f", chroma_width)
    LOGGER.info("  dropout: %.2f", dropout)
    LOGGER.info("  mse_weight: %.2f", mse_weight)
    LOGGER.info("  huber_weight: %.2f", huber_weight)
    LOGGER.info("  huber_delta: %.3f", huber_delta)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load base model and data
    model_dir = PROJECT_ROOT / "models" / "from_xyY"
    data_dir = PROJECT_ROOT / "data"

    base_model_path = model_dir / "multi_mlp.onnx"
    params_path = model_dir / "multi_mlp_normalization_parameters.npz"
    cache_file = data_dir / "training_data.npz"

    if not base_model_path.exists():
        msg = f"Base model not found: {base_model_path}"
        raise FileNotFoundError(msg)

    base_session, input_parameters, output_parameters = load_base_model(
        base_model_path, params_path
    )

    # Load data
    data = np.load(cache_file)
    X_train = data["X_train"]
    y_train = data["y_train"]
    X_val = data["X_val"]
    y_val = data["y_val"]

    # Normalize and generate base predictions.
    # NOTE(review): ONNX Runtime requires the input dtype to match the
    # exported graph (typically float32) — assumes normalize_xyY preserves a
    # compatible dtype; confirm against the export script.
    X_train_norm = normalize_xyY(X_train, input_parameters)
    y_train_norm = normalize_munsell(y_train, output_parameters)
    base_pred_train_norm = base_session.run(None, {"xyY": X_train_norm})[0]

    X_val_norm = normalize_xyY(X_val, input_parameters)
    y_val_norm = normalize_munsell(y_val, output_parameters)
    base_pred_val_norm = base_session.run(None, {"xyY": X_val_norm})[0]

    # Compute errors: the residuals in normalized space are the targets.
    error_train = y_train_norm - base_pred_train_norm
    error_val = y_val_norm - base_pred_val_norm

    # Combined input: normalized xyY concatenated with the base prediction.
    X_train_combined = np.concatenate([X_train_norm, base_pred_train_norm], axis=1)
    X_val_combined = np.concatenate([X_val_norm, base_pred_val_norm], axis=1)

    # PyTorch tensors
    X_train_t = torch.FloatTensor(X_train_combined)
    error_train_t = torch.FloatTensor(error_train)
    X_val_t = torch.FloatTensor(X_val_combined)
    error_val_t = torch.FloatTensor(error_val)

    # Data loaders
    train_dataset = TensorDataset(X_train_t, error_train_t)
    val_dataset = TensorDataset(X_val_t, error_val_t)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Initialize model
    model = MultiMLPErrorPredictorToMunsell(
        chroma_width=chroma_width, dropout=dropout
    ).to(device)

    total_params = sum(p.numel() for p in model.parameters())
    LOGGER.info("  Total parameters: %s", f"{total_params:,}")

    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)

    # MLflow setup
    run_name = setup_mlflow_experiment(
        "from_xyY", f"hparam_error_predictor_trial_{trial.number}"
    )

    # Training loop
    num_epochs = 100
    patience = 15
    best_val_loss = float("inf")
    patience_counter = 0

    with mlflow.start_run(run_name=run_name):
        mlflow.log_params(
            {
                "trial": trial.number,
                "lr": lr,
                "batch_size": batch_size,
                "chroma_width": chroma_width,
                "dropout": dropout,
                "mse_weight": mse_weight,
                "huber_weight": huber_weight,
                "huber_delta": huber_delta,
                "total_params": total_params,
            }
        )

        for epoch in range(num_epochs):
            train_loss = train_epoch(
                model, train_loader, optimizer, device, loss_params
            )
            val_loss = validate(model, val_loader, device, loss_params)

            mlflow.log_metrics(
                {
                    "train_loss": train_loss,
                    "val_loss": val_loss,
                },
                step=epoch,
            )

            if (epoch + 1) % 10 == 0:
                LOGGER.info(
                    "  Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f",
                    epoch + 1,
                    num_epochs,
                    train_loss,
                    val_loss,
                )

            # Early stopping on the validation loss.
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    LOGGER.info("  Early stopping at epoch %d", epoch + 1)
                    break

            # Report to Optuna so the MedianPruner can cut weak trials early.
            trial.report(val_loss, epoch)

            if trial.should_prune():
                LOGGER.info("  Trial pruned at epoch %d", epoch + 1)
                mlflow.log_metrics({"pruned": 1, "pruned_epoch": epoch})
                raise optuna.TrialPruned

        # Log final results
        mlflow.log_metrics(
            {
                "best_val_loss": best_val_loss,
                "final_train_loss": train_loss,
                "final_epoch": epoch + 1,
            }
        )

    LOGGER.info("  Final validation loss: %.6f", best_val_loss)

    return best_val_loss
424
+
425
+
426
def main() -> None:
    """
    Run hyperparameter search for Multi-MLP Error Predictor.

    Performs systematic hyperparameter optimization using Optuna with:
    - MedianPruner for early stopping of unpromising trials
    - 15 total trials
    - MLflow logging for each trial
    - Result visualization and saving

    The search aims to find optimal hyperparameters for predicting errors
    in a base Munsell prediction model, which can then be used to improve
    predictions by correcting systematic biases.
    """

    LOGGER.info("=" * 80)
    LOGGER.info("Multi-Error Predictor Hyperparameter Search with Optuna")
    LOGGER.info("=" * 80)

    study = optuna.create_study(
        direction="minimize",
        study_name="multi_mlp_error_predictor_hparam_search",
        pruner=optuna.pruners.MedianPruner(n_startup_trials=3, n_warmup_steps=10),
    )

    n_trials = 15

    LOGGER.info("")
    LOGGER.info("Starting hyperparameter search with %d trials...", n_trials)
    LOGGER.info("")

    study.optimize(objective, n_trials=n_trials, timeout=None)

    # Print results
    LOGGER.info("")
    LOGGER.info("=" * 80)
    LOGGER.info("Hyperparameter Search Results")
    LOGGER.info("=" * 80)
    LOGGER.info("")
    LOGGER.info("Best trial:")
    LOGGER.info("  Value (val_loss): %.6f", study.best_value)
    LOGGER.info("")
    LOGGER.info("Best hyperparameters:")
    for key, value in study.best_params.items():
        LOGGER.info("  %s: %s", key, value)

    # Save results.
    # parents=True matches the sibling multi-head search script and avoids
    # FileNotFoundError when the "results" directory does not exist yet.
    results_dir = PROJECT_ROOT / "results" / "from_xyY"
    results_dir.mkdir(exist_ok=True, parents=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    results_file = results_dir / f"error_predictor_hparam_search_{timestamp}.txt"

    with open(results_file, "w") as f:
        f.write("=" * 80 + "\n")
        f.write("Multi-Error Predictor Hyperparameter Search Results\n")
        f.write("=" * 80 + "\n\n")
        f.write(f"Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write(f"Number of trials: {len(study.trials)}\n")
        f.write(f"Best validation loss: {study.best_value:.6f}\n\n")
        f.write("Best hyperparameters:\n")
        for key, value in study.best_params.items():
            f.write(f"  {key}: {value}\n")
        f.write("\n\nAll trials:\n")
        f.write("-" * 80 + "\n")

        for t in study.trials:
            f.write(f"\nTrial {t.number}:\n")
            # BUG FIX: the original f-string placed the conditional inside the
            # format spec ("{trial.value:.6f if trial.value else 'Pruned'}"),
            # which raises ValueError("Invalid format specifier") at runtime.
            # Branch explicitly instead, and compare against None so a valid
            # loss of exactly 0.0 is not misreported as "Pruned".
            if t.value is not None:
                f.write(f"  Value: {t.value:.6f}\n")
            else:
                f.write("  Value: Pruned\n")
            f.write("  Params:\n")
            for key, value in t.params.items():
                f.write(f"    {key}: {value}\n")

    LOGGER.info("")
    LOGGER.info("Results saved to: %s", results_file)
501
+
502
+
503
if __name__ == "__main__":
    # Message-only console logging; force=True replaces any handlers that
    # were configured earlier (e.g. by an importing library).
    logging.basicConfig(level=logging.INFO, format="%(message)s", force=True)

    main()
learning_munsell/training/from_xyY/hyperparameter_search_multi_head.py ADDED
@@ -0,0 +1,547 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Hyperparameter search for Multi-Head model (xyY to Munsell) using Optuna.
3
+
4
+ Optimizes:
5
+ - Learning rate
6
+ - Batch size
7
+ - Encoder width multiplier (shared encoder capacity)
8
+ - Head width multiplier (component-specific head capacity)
9
+ - Chroma head width (specialized for chroma prediction)
10
+ - Dropout
11
+ - Weight decay
12
+
13
+ Objective: Minimize validation loss
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ import logging
19
+ from datetime import datetime
20
+ from typing import TYPE_CHECKING
21
+
22
+ import matplotlib.pyplot as plt
23
+ import mlflow
24
+ import numpy as np
25
+ import optuna
26
+ import torch
27
+ from torch import nn, optim
28
+ from torch.utils.data import DataLoader, TensorDataset
29
+
30
+ from learning_munsell import PROJECT_ROOT
31
+ from learning_munsell.utilities.common import setup_mlflow_experiment
32
+ from learning_munsell.utilities.data import (
33
+ MUNSELL_NORMALIZATION_PARAMS,
34
+ normalize_munsell,
35
+ )
36
+ from learning_munsell.utilities.losses import weighted_mse_loss
37
+ from learning_munsell.utilities.training import train_epoch, validate
38
+
39
+ if TYPE_CHECKING:
40
+ from optuna.trial import Trial
41
+
42
+ LOGGER = logging.getLogger(__name__)
43
+
44
+
45
class MultiHeadParametric(nn.Module):
    """
    Parametric Multi-Head model for hyperparameter search (xyY to Munsell).

    A shared encoder maps normalized xyY inputs to a feature vector, which is
    consumed by four independent heads: one each for hue, value, and code,
    plus a deeper/wider head dedicated to chroma (the hardest component).

    Architecture:
    - Shared encoder: 3 -> e_h1 -> e_h2 -> e_h3 (scaled by ``encoder_width``)
    - hue/value/code heads: e_h3 -> h_h2 -> h_h1 -> 1 (scaled by ``head_width``)
    - chroma head: e_h3 -> c_h3 -> c_h2 -> c_h1 -> 1
      (scaled by ``chroma_head_width``)

    Parameters
    ----------
    encoder_width : float, optional
        Width multiplier for shared encoder layers. Default is 1.0.
        Base dimensions: h1=128, h2=256, h3=512.
    head_width : float, optional
        Width multiplier for hue, value, and code heads. Default is 1.0.
        Base dimensions: h1=128, h2=256.
    chroma_head_width : float, optional
        Width multiplier for the chroma head. Default is 1.0.
        Base dimensions: h1=128, h2=256, h3=384.
    dropout : float, optional
        Dropout rate applied after each hidden layer. Default is 0.0.
    """

    def __init__(
        self,
        encoder_width: float = 1.0,
        head_width: float = 1.0,
        chroma_head_width: float = 1.0,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()

        def hidden_block(in_dim: int, out_dim: int) -> list:
            # Linear -> ReLU -> BatchNorm, with an optional trailing Dropout.
            block = [
                nn.Linear(in_dim, out_dim),
                nn.ReLU(),
                nn.BatchNorm1d(out_dim),
            ]
            if dropout > 0:
                block.append(nn.Dropout(dropout))
            return block

        def stack(dims, output_dim=None) -> nn.Sequential:
            # Chain hidden blocks over consecutive dims; optionally append a
            # final projection Linear with no activation.
            layers = []
            for in_dim, out_dim in zip(dims[:-1], dims[1:]):
                layers.extend(hidden_block(in_dim, out_dim))
            if output_dim is not None:
                layers.append(nn.Linear(dims[-1], output_dim))
            return nn.Sequential(*layers)

        # Scaled layer widths (same rounding as int(base * multiplier)).
        e_dims = [int(128 * encoder_width), int(256 * encoder_width), int(512 * encoder_width)]
        h_dims = [int(128 * head_width), int(256 * head_width)]
        c_dims = [int(128 * chroma_head_width), int(256 * chroma_head_width), int(384 * chroma_head_width)]

        # Shared encoder - learns general color space features.
        self.encoder = stack([3, *e_dims])

        # Component-specific heads (hue, value, code): e_h3 -> h_h2 -> h_h1 -> 1.
        head_shape = [e_dims[-1], h_dims[1], h_dims[0]]
        self.hue_head = stack(head_shape, output_dim=1)
        self.value_head = stack(head_shape, output_dim=1)
        self.code_head = stack(head_shape, output_dim=1)

        # Chroma head - deeper and wider for the harder task.
        self.chroma_head = stack(
            [e_dims[-1], c_dims[2], c_dims[1], c_dims[0]], output_dim=1
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the shared encoder and component heads.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch_size, 3) containing normalized
            xyY values.

        Returns
        -------
        torch.Tensor
            Predicted Munsell components, shape (batch_size, 4).
            Output order: [hue, value, chroma, code].
        """
        shared = self.encoder(x)

        hue = self.hue_head(shared)
        value = self.value_head(shared)
        chroma = self.chroma_head(shared)
        code = self.code_head(shared)

        return torch.cat([hue, value, chroma, code], dim=1)
223
+
224
+
225
def objective(trial: Trial) -> float:
    """
    Optuna objective function to minimize validation loss.

    This function defines the hyperparameter search space and training
    procedure for each trial. It optimizes:
    - Learning rate (1e-4 to 1e-3, log scale)
    - Batch size (256, 512, or 1024)
    - Encoder width multiplier (0.75 to 1.5)
    - Head width multiplier (0.75 to 1.5)
    - Chroma head width multiplier (1.0 to 1.75)
    - Dropout rate (0.0 to 0.2)
    - Weight decay (1e-5 to 1e-3, log scale)

    Parameters
    ----------
    trial : Trial
        Optuna trial object for suggesting hyperparameters.

    Returns
    -------
    float
        Best validation loss achieved during training.

    Raises
    ------
    optuna.TrialPruned
        If trial is pruned based on intermediate results.
    """

    # Suggest hyperparameters
    lr = trial.suggest_float("lr", 1e-4, 1e-3, log=True)
    batch_size = trial.suggest_categorical("batch_size", [256, 512, 1024])
    encoder_width = trial.suggest_float("encoder_width", 0.75, 1.5, step=0.25)
    head_width = trial.suggest_float("head_width", 0.75, 1.5, step=0.25)
    chroma_head_width = trial.suggest_float("chroma_head_width", 1.0, 1.75, step=0.25)
    dropout = trial.suggest_float("dropout", 0.0, 0.2, step=0.05)
    weight_decay = trial.suggest_float("weight_decay", 1e-5, 1e-3, log=True)

    LOGGER.info("")
    LOGGER.info("=" * 80)
    LOGGER.info("Trial %d", trial.number)
    LOGGER.info("=" * 80)
    LOGGER.info("  lr: %.6f", lr)
    LOGGER.info("  batch_size: %d", batch_size)
    LOGGER.info("  encoder_width: %.2f", encoder_width)
    LOGGER.info("  head_width: %.2f", head_width)
    LOGGER.info("  chroma_head_width: %.2f", chroma_head_width)
    LOGGER.info("  dropout: %.2f", dropout)
    LOGGER.info("  weight_decay: %.6f", weight_decay)

    # Set device (Apple Metal if available, otherwise CPU)
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    LOGGER.info("  device: %s", device)

    # Load data
    data_dir = PROJECT_ROOT / "data"
    cache_file = data_dir / "training_data.npz"
    data = np.load(cache_file)

    X_train = data["X_train"]
    y_train = data["y_train"]
    X_val = data["X_val"]
    y_val = data["y_val"]

    # Normalize outputs (xyY inputs are already in [0, 1] range)
    output_parameters = MUNSELL_NORMALIZATION_PARAMS
    y_train_norm = normalize_munsell(y_train, output_parameters)
    y_val_norm = normalize_munsell(y_val, output_parameters)

    # Convert to tensors
    X_train_t = torch.from_numpy(X_train).float()
    y_train_t = torch.from_numpy(y_train_norm).float()
    X_val_t = torch.from_numpy(X_val).float()
    y_val_t = torch.from_numpy(y_val_norm).float()

    train_loader = DataLoader(
        TensorDataset(X_train_t, y_train_t), batch_size=batch_size, shuffle=True
    )
    val_loader = DataLoader(
        TensorDataset(X_val_t, y_val_t), batch_size=batch_size, shuffle=False
    )

    LOGGER.info(
        "  Training samples: %d, Validation samples: %d", len(X_train_t), len(X_val_t)
    )

    # Initialize model
    model = MultiHeadParametric(
        encoder_width=encoder_width,
        head_width=head_width,
        chroma_head_width=chroma_head_width,
        dropout=dropout,
    ).to(device)

    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    LOGGER.info("  Total parameters: %s", f"{total_params:,}")

    # Training setup
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    # T_max matches num_epochs below so the cosine schedule spans the run.
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

    # MLflow setup
    run_name = setup_mlflow_experiment(
        "from_xyY", f"hparam_multi_head_trial_{trial.number}"
    )

    # Training loop with early stopping
    num_epochs = 100  # Reduced for hyperparameter search
    patience = 15
    best_val_loss = float("inf")
    patience_counter = 0

    with mlflow.start_run(run_name=run_name):
        mlflow.log_params(
            {
                "trial": trial.number,
                "lr": lr,
                "batch_size": batch_size,
                "encoder_width": encoder_width,
                "head_width": head_width,
                "chroma_head_width": chroma_head_width,
                "dropout": dropout,
                "weight_decay": weight_decay,
                "total_params": total_params,
            }
        )

        for epoch in range(num_epochs):
            train_loss = train_epoch(
                model, train_loader, optimizer, weighted_mse_loss, device
            )
            val_loss = validate(model, val_loader, weighted_mse_loss, device)
            scheduler.step()

            # Per-component MAE on the whole validation set.
            # NOTE(review): single full-batch forward pass — assumes the
            # validation set fits in device memory; model stays in eval mode
            # only if validate() left it there — confirm against the shared
            # training utilities.
            with torch.no_grad():
                pred_val = model(X_val_t.to(device))
                mae = torch.mean(torch.abs(pred_val - y_val_t.to(device)), dim=0).cpu()

            # Log to MLflow
            mlflow.log_metrics(
                {
                    "train_loss": train_loss,
                    "val_loss": val_loss,
                    "mae_hue": mae[0].item(),
                    "mae_value": mae[1].item(),
                    "mae_chroma": mae[2].item(),
                    "mae_code": mae[3].item(),
                    "learning_rate": optimizer.param_groups[0]["lr"],
                },
                step=epoch,
            )

            if (epoch + 1) % 10 == 0:
                LOGGER.info(
                    "  Epoch %03d/%d - Train: %.6f, Val: %.6f - "
                    "MAE: hue=%.6f, value=%.6f, chroma=%.6f, code=%.6f",
                    epoch + 1,
                    num_epochs,
                    train_loss,
                    val_loss,
                    mae[0],
                    mae[1],
                    mae[2],
                    mae[3],
                )

            # Early stopping
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    LOGGER.info("  Early stopping at epoch %d", epoch + 1)
                    break

            # Report intermediate value for pruning
            trial.report(val_loss, epoch)

            # Handle pruning
            if trial.should_prune():
                LOGGER.info("  Trial pruned at epoch %d", epoch + 1)
                mlflow.log_metrics({"pruned": 1, "pruned_epoch": epoch})
                raise optuna.TrialPruned

        # Log final results
        mlflow.log_metrics(
            {
                "best_val_loss": best_val_loss,
                "final_train_loss": train_loss,
                "final_mae_hue": mae[0].item(),
                "final_mae_value": mae[1].item(),
                "final_mae_chroma": mae[2].item(),
                "final_mae_code": mae[3].item(),
                "final_epoch": epoch + 1,
            }
        )

    LOGGER.info("  Final validation loss: %.6f", best_val_loss)

    return best_val_loss
429
+
430
+
431
def main() -> None:
    """
    Run hyperparameter search for Multi-Head model (xyY to Munsell).

    Performs systematic hyperparameter optimization using Optuna with:
    - MedianPruner for early stopping of unpromising trials
    - 20 total trials
    - MLflow logging for each trial
    - Result visualization using matplotlib (optimization history,
      parameter importances, parallel coordinate plot)

    The search aims to find optimal hyperparameters for converting xyY
    color coordinates to Munsell color specifications using a multi-head
    architecture with shared encoder and component-specific heads.
    """

    LOGGER.info("=" * 80)
    LOGGER.info("Multi-Head (from_xyY) Hyperparameter Search with Optuna")
    LOGGER.info("=" * 80)

    # Create study
    study = optuna.create_study(
        direction="minimize",
        study_name="multi_head_from_xyY_hparam_search",
        pruner=optuna.pruners.MedianPruner(n_startup_trials=3, n_warmup_steps=10),
    )

    # Run optimization
    n_trials = 20  # Number of trials to run

    LOGGER.info("")
    LOGGER.info("Starting hyperparameter search with %d trials...", n_trials)
    LOGGER.info("")

    study.optimize(objective, n_trials=n_trials, timeout=None)

    # Print results
    LOGGER.info("")
    LOGGER.info("=" * 80)
    LOGGER.info("Hyperparameter Search Results")
    LOGGER.info("=" * 80)
    LOGGER.info("")
    LOGGER.info("Best trial:")
    LOGGER.info("  Value (val_loss): %.6f", study.best_value)
    LOGGER.info("")
    LOGGER.info("Best hyperparameters:")
    for key, value in study.best_params.items():
        LOGGER.info("  %s: %s", key, value)

    # Save results
    results_dir = PROJECT_ROOT / "results" / "from_xyY"
    results_dir.mkdir(exist_ok=True, parents=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    results_file = results_dir / f"hparam_search_multi_head_{timestamp}.txt"

    with open(results_file, "w") as f:
        f.write("=" * 80 + "\n")
        f.write("Multi-Head (from_xyY) Hyperparameter Search Results\n")
        f.write("=" * 80 + "\n\n")
        f.write(f"Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write(f"Number of trials: {len(study.trials)}\n")
        f.write(f"Best validation loss: {study.best_value:.6f}\n\n")
        f.write("Best hyperparameters:\n")
        for key, value in study.best_params.items():
            f.write(f"  {key}: {value}\n")
        f.write("\n\nAll trials:\n")
        f.write("-" * 80 + "\n")

        for t in study.trials:
            f.write(f"\nTrial {t.number}:\n")
            # Pruned trials have no value; guard with an explicit None check.
            if t.value is not None:
                f.write(f"  Value: {t.value:.6f}\n")
            else:
                f.write("  Value: Pruned\n")
            f.write("  Params:\n")
            for key, value in t.params.items():
                f.write(f"    {key}: {value}\n")

    LOGGER.info("")
    LOGGER.info("Results saved to: %s", results_file)

    # Generate visualizations using matplotlib.
    # Imported locally because optuna's matplotlib integration is optional
    # and only needed once the search has finished.
    from optuna.visualization.matplotlib import (  # noqa: PLC0415
        plot_optimization_history,
        plot_parallel_coordinate,
        plot_param_importances,
    )

    # Optimization history
    ax = plot_optimization_history(study)
    ax.figure.savefig(
        results_dir / f"optimization_history_multi_head_{timestamp}.png", dpi=150
    )
    plt.close(ax.figure)

    # Parameter importances.
    # NOTE(review): presumably requires at least two completed (non-pruned)
    # trials to compute importances — verify for short test runs.
    ax = plot_param_importances(study)
    ax.figure.savefig(
        results_dir / f"param_importances_multi_head_{timestamp}.png", dpi=150
    )
    plt.close(ax.figure)

    # Parallel coordinate plot
    ax = plot_parallel_coordinate(study)
    ax.figure.savefig(
        results_dir / f"parallel_coordinate_multi_head_{timestamp}.png", dpi=150
    )
    plt.close(ax.figure)

    LOGGER.info("Visualizations saved to: %s", results_dir)
542
+
543
+
544
if __name__ == "__main__":
    # Message-only console logging; force=True replaces any handlers that
    # were configured earlier (e.g. by an importing library).
    logging.basicConfig(level=logging.INFO, format="%(message)s", force=True)

    main()
learning_munsell/training/from_xyY/hyperparameter_search_multi_head_error_predictor.py ADDED
@@ -0,0 +1,555 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Hyperparameter search for Multi-Head Error Predictor using Optuna.
3
+
4
+ Optimizes:
5
+ - Learning rate
6
+ - Batch size
7
+ - Width multipliers for each component branch (hue, value, chroma, code)
8
+ - Loss function component weights
9
+
10
+ Objective: Minimize validation loss (combined base + error predictor)
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import logging
16
+ from datetime import datetime
17
+ from typing import TYPE_CHECKING
18
+
19
+ import matplotlib.pyplot as plt
20
+ import mlflow
21
+ import numpy as np
22
+ import onnxruntime as ort
23
+ import optuna
24
+ import torch
25
+ from torch import nn, optim
26
+ from torch.utils.data import DataLoader, TensorDataset
27
+
28
+ from learning_munsell import PROJECT_ROOT
29
+ from learning_munsell.models.networks import ComponentErrorPredictor
30
+ from learning_munsell.utilities.common import setup_mlflow_experiment
31
+ from learning_munsell.utilities.data import normalize_munsell, normalize_xyY
32
+ from learning_munsell.utilities.training import train_epoch, validate
33
+
34
+ if TYPE_CHECKING:
35
+ from collections.abc import Callable
36
+ from pathlib import Path
37
+
38
+ from optuna.trial import Trial
39
+
40
+ LOGGER = logging.getLogger(__name__)
41
+
42
+
43
class MultiHeadErrorPredictorParametric(nn.Module):
    """
    Parametric Multi-Head error predictor with four independent branches.

    Wraps one :class:`ComponentErrorPredictor` per Munsell component
    (hue, value, chroma, code).  Every branch receives the same combined
    input, and each branch's capacity is tuned independently through a
    width multiplier — which is what the hyperparameter search exploits.

    Parameters
    ----------
    hue_width : float, optional
        Width multiplier for the hue branch. Default is 1.0.
    value_width : float, optional
        Width multiplier for the value branch. Default is 1.0.
    chroma_width : float, optional
        Width multiplier for the chroma branch. Default is 1.5.
    code_width : float, optional
        Width multiplier for the code branch. Default is 1.0.
    """

    def __init__(
        self,
        hue_width: float = 1.0,
        value_width: float = 1.0,
        chroma_width: float = 1.5,
        code_width: float = 1.0,
    ) -> None:
        super().__init__()

        # One independent predictor per component.  Attribute names are part
        # of the checkpoint (state_dict) layout and must not change.
        self.hue_branch = ComponentErrorPredictor(width_multiplier=hue_width)
        self.value_branch = ComponentErrorPredictor(width_multiplier=value_width)
        self.chroma_branch = ComponentErrorPredictor(width_multiplier=chroma_width)
        self.code_branch = ComponentErrorPredictor(width_multiplier=code_width)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run every branch on the same input and concatenate the outputs.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch_size, 7) containing normalized
            xyY values and base model predictions.

        Returns
        -------
        torch.Tensor
            Predicted errors for all components, shape (batch_size, 4),
            ordered [hue_error, value_error, chroma_error, code_error].
        """
        # Each branch processes the same combined input independently; the
        # ordering here fixes the output column order.
        ordered_branches = (
            self.hue_branch,
            self.value_branch,
            self.chroma_branch,
            self.code_branch,
        )
        return torch.cat([branch(x) for branch in ordered_branches], dim=1)
102
+
103
+
104
def load_base_model(
    model_path: Path, params_path: Path
) -> tuple[ort.InferenceSession, dict, dict]:
    """
    Load the base Multi-Head ONNX model and its normalization parameters.

    Parameters
    ----------
    model_path : Path
        Path to the base Multi-Head model ONNX file.
    params_path : Path
        Path to the normalization parameters NPZ file.

    Returns
    -------
    ort.InferenceSession
        ONNX Runtime inference session for the base model.
    dict
        Input normalization parameters (x_range, y_range, Y_range).
    dict
        Output normalization parameters (hue_range, value_range,
        chroma_range, code_range).
    """
    inference_session = ort.InferenceSession(str(model_path))

    # The NPZ stores each parameter dict as a 0-d object array;
    # ``.item()`` unwraps it back into a plain dictionary.
    archive = np.load(params_path, allow_pickle=True)
    input_parameters = archive["input_parameters"].item()
    output_parameters = archive["output_parameters"].item()

    return inference_session, input_parameters, output_parameters
134
+
135
+
136
def create_weighted_loss(
    mse_weight: float,
    mae_weight: float,
    log_weight: float,
    huber_weight: float,
    huber_delta: float,
) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]:
    """
    Create a weighted loss function combining multiple loss components.

    Parameters
    ----------
    mse_weight : float
        Weight for the mean-squared-error term.
    mae_weight : float
        Weight for the mean-absolute-error term.
    log_weight : float
        Weight for the logarithmic penalty term.
    huber_weight : float
        Weight for the Huber term.
    huber_delta : float
        Transition point between the quadratic and linear Huber regimes.

    Returns
    -------
    callable
        Loss function that accepts (pred, target) and returns a scalar loss.
    """

    def weighted_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Evaluate the weighted combination of the four loss terms.

        Parameters
        ----------
        pred : torch.Tensor
            Predicted values, shape (batch_size, n_components).
        target : torch.Tensor
            Target values, shape (batch_size, n_components).

        Returns
        -------
        torch.Tensor
            Weighted combination of loss components, scalar tensor.
        """
        residual = pred - target
        abs_residual = residual.abs()

        # Quadratic and absolute error terms.
        mse_term = (residual**2).mean()
        mae_term = abs_residual.mean()

        # log1p(|error| * 1000): amplifies small residuals so they still
        # produce a meaningful gradient signal.
        log_term = torch.log1p(abs_residual * 1000.0).mean()

        # Huber: quadratic below ``huber_delta``, linear above it.
        huber_term = torch.where(
            abs_residual <= huber_delta,
            0.5 * abs_residual**2,
            huber_delta * (abs_residual - 0.5 * huber_delta),
        ).mean()

        return (
            mse_weight * mse_term
            + mae_weight * mae_term
            + log_weight * log_term
            + huber_weight * huber_term
        )

    return weighted_loss
208
+
209
+
210
def objective(trial: Trial) -> float:
    """
    Optuna objective function to minimize validation loss.

    This function defines the hyperparameter search space and training
    procedure for each trial. It optimizes:
    - Learning rate (1e-4 to 1e-3, log scale)
    - Batch size (512, 1024, or 2048)
    - Width multipliers for each component branch
    - Loss function weights (MSE, MAE, log penalty, Huber)
    - Huber delta parameter (0.005 to 0.02)

    Parameters
    ----------
    trial : Trial
        Optuna trial object for suggesting hyperparameters.

    Returns
    -------
    float
        Best validation loss achieved during training.

    Raises
    ------
    optuna.TrialPruned
        If trial is pruned based on intermediate results.
    """

    # Suggest hyperparameters
    lr = trial.suggest_float("lr", 1e-4, 1e-3, log=True)
    batch_size = trial.suggest_categorical("batch_size", [512, 1024, 2048])
    hue_width = trial.suggest_float("hue_width", 0.75, 1.5, step=0.25)
    value_width = trial.suggest_float("value_width", 0.75, 1.5, step=0.25)
    chroma_width = trial.suggest_float("chroma_width", 1.0, 2.0, step=0.25)
    code_width = trial.suggest_float("code_width", 0.75, 1.5, step=0.25)

    # Loss function weights
    mse_weight = trial.suggest_float("mse_weight", 0.5, 2.0, step=0.5)
    mae_weight = trial.suggest_float("mae_weight", 0.0, 1.0, step=0.25)
    log_weight = trial.suggest_float("log_weight", 0.0, 0.5, step=0.1)
    huber_weight = trial.suggest_float("huber_weight", 0.0, 1.0, step=0.25)
    huber_delta = trial.suggest_float("huber_delta", 0.005, 0.02, step=0.005)

    LOGGER.info("")
    LOGGER.info("=" * 80)
    LOGGER.info("Trial %d", trial.number)
    LOGGER.info("=" * 80)
    LOGGER.info(" lr: %.6f", lr)
    LOGGER.info(" batch_size: %d", batch_size)
    LOGGER.info(" hue_width: %.2f", hue_width)
    LOGGER.info(" value_width: %.2f", value_width)
    LOGGER.info(" chroma_width: %.2f", chroma_width)
    LOGGER.info(" code_width: %.2f", code_width)
    LOGGER.info(" mse_weight: %.2f", mse_weight)
    LOGGER.info(" mae_weight: %.2f", mae_weight)
    LOGGER.info(" log_weight: %.2f", log_weight)
    LOGGER.info(" huber_weight: %.2f", huber_weight)
    LOGGER.info(" huber_delta: %.3f", huber_delta)

    # Set device
    # NOTE(review): this script targets Apple MPS (falls back to CPU); the
    # sibling multi-MLP search targets CUDA — presumably intentional, confirm.
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    LOGGER.info(" device: %s", device)

    # Paths
    model_directory = PROJECT_ROOT / "models" / "from_xyY"
    data_dir = PROJECT_ROOT / "data"

    base_model_path = model_directory / "multi_head.onnx"
    params_path = model_directory / "multi_head_normalization_parameters.npz"
    cache_file = data_dir / "training_data.npz"

    # Load base model
    base_session, input_parameters, output_parameters = load_base_model(
        base_model_path, params_path
    )

    # Load training data
    data = np.load(cache_file)
    X_train = data["X_train"]
    y_train = data["y_train"]
    X_val = data["X_val"]
    y_val = data["y_val"]

    # Normalize
    X_train_norm = normalize_xyY(X_train, input_parameters)
    y_train_norm = normalize_munsell(y_train, output_parameters)
    X_val_norm = normalize_xyY(X_val, input_parameters)
    y_val_norm = normalize_munsell(y_val, output_parameters)

    # Generate base model predictions
    base_pred_train_norm = base_session.run(None, {"xyY": X_train_norm})[0]
    base_pred_val_norm = base_session.run(None, {"xyY": X_val_norm})[0]

    # Compute errors
    # Training targets are the residuals between ground truth and the base
    # model's output, both in normalized space.
    error_train = y_train_norm - base_pred_train_norm
    error_val = y_val_norm - base_pred_val_norm

    # Create combined input: [xyY_norm, base_prediction_norm]
    X_train_combined = np.concatenate([X_train_norm, base_pred_train_norm], axis=1)
    X_val_combined = np.concatenate([X_val_norm, base_pred_val_norm], axis=1)

    # Convert to PyTorch tensors
    X_train_t = torch.FloatTensor(X_train_combined)
    error_train_t = torch.FloatTensor(error_train)
    X_val_t = torch.FloatTensor(X_val_combined)
    error_val_t = torch.FloatTensor(error_val)

    # Create data loaders
    train_loader = DataLoader(
        TensorDataset(X_train_t, error_train_t), batch_size=batch_size, shuffle=True
    )
    val_loader = DataLoader(
        TensorDataset(X_val_t, error_val_t), batch_size=batch_size, shuffle=False
    )

    LOGGER.info(
        " Training samples: %d, Validation samples: %d", len(X_train_t), len(X_val_t)
    )

    # Initialize error predictor model
    model = MultiHeadErrorPredictorParametric(
        hue_width=hue_width,
        value_width=value_width,
        chroma_width=chroma_width,
        code_width=code_width,
    ).to(device)

    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    LOGGER.info(" Total parameters: %s", f"{total_params:,}")

    # Training setup
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=5
    )

    # Create loss function
    criterion = create_weighted_loss(
        mse_weight, mae_weight, log_weight, huber_weight, huber_delta
    )

    # MLflow setup
    run_name = setup_mlflow_experiment(
        "from_xyY", f"hparam_multi_head_error_trial_{trial.number}"
    )

    # Training loop with early stopping
    num_epochs = 50  # Reduced for hyperparameter search
    patience = 10
    best_val_loss = float("inf")
    patience_counter = 0

    with mlflow.start_run(run_name=run_name):
        mlflow.log_params(
            {
                "lr": lr,
                "batch_size": batch_size,
                "hue_width": hue_width,
                "value_width": value_width,
                "chroma_width": chroma_width,
                "code_width": code_width,
                "mse_weight": mse_weight,
                "mae_weight": mae_weight,
                "log_weight": log_weight,
                "huber_weight": huber_weight,
                "huber_delta": huber_delta,
                "total_params": total_params,
                "trial_number": trial.number,
            }
        )

        for epoch in range(num_epochs):
            train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
            val_loss = validate(model, val_loader, criterion, device)
            scheduler.step(val_loss)

            # Log to MLflow
            mlflow.log_metrics(
                {
                    "train_loss": train_loss,
                    "val_loss": val_loss,
                    "learning_rate": optimizer.param_groups[0]["lr"],
                },
                step=epoch,
            )

            if (epoch + 1) % 10 == 0:
                LOGGER.info(
                    " Epoch %03d/%d - Train: %.6f, Val: %.6f, LR: %.6f",
                    epoch + 1,
                    num_epochs,
                    train_loss,
                    val_loss,
                    optimizer.param_groups[0]["lr"],
                )

            # Early stopping
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    LOGGER.info(" Early stopping at epoch %d", epoch + 1)
                    break

            # Report intermediate value for pruning
            trial.report(val_loss, epoch)

            # Handle pruning
            if trial.should_prune():
                LOGGER.info(" Trial pruned at epoch %d", epoch + 1)
                mlflow.log_metrics({"pruned": 1, "pruned_epoch": epoch})
                raise optuna.TrialPruned

        # Log final results
        mlflow.log_metrics(
            {
                "best_val_loss": best_val_loss,
                "final_train_loss": train_loss,
            }
        )

    LOGGER.info(" Final validation loss: %.6f", best_val_loss)

    return best_val_loss
437
+
438
+
439
def main() -> None:
    """
    Run hyperparameter search for Multi-Head Error Predictor.

    Performs systematic hyperparameter optimization using Optuna with:
    - MedianPruner for early stopping of unpromising trials
    - 30 total trials
    - MLflow logging for each trial
    - Result visualization using matplotlib (optimization history,
      parameter importances, parallel coordinate plot)

    The search aims to find optimal hyperparameters for predicting errors
    in a base Multi-Head model, allowing for error correction and improved
    Munsell predictions.
    """

    LOGGER.info("=" * 80)
    LOGGER.info("Multi-Head Error Predictor Hyperparameter Search with Optuna")
    LOGGER.info("=" * 80)

    # Create study
    study = optuna.create_study(
        direction="minimize",
        study_name="multi_head_error_predictor_hparam_search",
        pruner=optuna.pruners.MedianPruner(n_startup_trials=3, n_warmup_steps=5),
    )

    # Run optimization
    n_trials = 30  # Number of trials to run

    LOGGER.info("")
    LOGGER.info("Starting hyperparameter search with %d trials...", n_trials)
    LOGGER.info("")

    study.optimize(objective, n_trials=n_trials, timeout=None)

    # Print results
    LOGGER.info("")
    LOGGER.info("=" * 80)
    LOGGER.info("Hyperparameter Search Results")
    LOGGER.info("=" * 80)
    LOGGER.info("")
    LOGGER.info("Best trial:")
    LOGGER.info(" Value (val_loss): %.6f", study.best_value)
    LOGGER.info("")
    LOGGER.info("Best hyperparameters:")
    for key, value in study.best_params.items():
        LOGGER.info(" %s: %s", key, value)

    # Save results
    results_dir = PROJECT_ROOT / "results" / "from_xyY"
    results_dir.mkdir(exist_ok=True, parents=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    results_file = results_dir / f"hparam_search_multi_head_error_{timestamp}.txt"

    # Human-readable summary of the study plus every trial's parameters.
    with open(results_file, "w") as f:
        f.write("=" * 80 + "\n")
        f.write("Multi-Head Error Predictor Hyperparameter Search Results\n")
        f.write("=" * 80 + "\n\n")
        f.write(f"Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write(f"Number of trials: {len(study.trials)}\n")
        f.write(f"Best validation loss: {study.best_value:.6f}\n\n")
        f.write("Best hyperparameters:\n")
        for key, value in study.best_params.items():
            f.write(f" {key}: {value}\n")
        f.write("\n\nAll trials:\n")
        f.write("-" * 80 + "\n")

        for t in study.trials:
            f.write(f"\nTrial {t.number}:\n")
            # Pruned / failed trials have no value.
            if t.value is not None:
                f.write(f" Value: {t.value:.6f}\n")
            else:
                f.write(" Value: Pruned\n")
            f.write(" Params:\n")
            for key, value in t.params.items():
                f.write(f" {key}: {value}\n")

    LOGGER.info("")
    LOGGER.info("Results saved to: %s", results_file)

    # Generate visualizations using matplotlib
    from optuna.visualization.matplotlib import (  # noqa: PLC0415
        plot_optimization_history,
        plot_parallel_coordinate,
        plot_param_importances,
    )

    # Optimization history
    ax = plot_optimization_history(study)
    ax.figure.savefig(
        results_dir / f"optimization_history_multi_head_error_{timestamp}.png", dpi=150
    )
    plt.close(ax.figure)

    # Parameter importances
    ax = plot_param_importances(study)
    ax.figure.savefig(
        results_dir / f"param_importances_multi_head_error_{timestamp}.png", dpi=150
    )
    plt.close(ax.figure)

    # Parallel coordinate plot
    ax = plot_parallel_coordinate(study)
    ax.figure.savefig(
        results_dir / f"parallel_coordinate_multi_head_error_{timestamp}.png", dpi=150
    )
    plt.close(ax.figure)

    LOGGER.info("Visualizations saved to: %s", results_dir)
550
+
551
+
552
if __name__ == "__main__":
    # Bare-message console logging; ``force=True`` replaces any handlers an
    # imported library may have installed before this point.
    logging.basicConfig(level=logging.INFO, format="%(message)s", force=True)

    main()
learning_munsell/training/from_xyY/hyperparameter_search_multi_mlp.py ADDED
@@ -0,0 +1,473 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Hyperparameter search for Multi-MLP model using Optuna.
3
+
4
+ Optimizes:
5
+ - Learning rate
6
+ - Batch size
7
+ - Chroma width multiplier
8
+ - Chroma loss weight
9
+ - Code loss weight
10
+ - Dropout (optional)
11
+
12
+ Objective: Minimize validation loss
13
+ """
14
+
15
+ import logging
16
+ from datetime import datetime
17
+
18
+ import matplotlib.pyplot as plt
19
+ import mlflow
20
+ import numpy as np
21
+ import optuna
22
+ import torch
23
+ from optuna.trial import Trial
24
+ from torch import nn, optim
25
+ from torch.utils.data import DataLoader, TensorDataset
26
+
27
+ from learning_munsell import PROJECT_ROOT
28
+ from learning_munsell.models.networks import MultiMLPToMunsell
29
+ from learning_munsell.utilities.common import setup_mlflow_experiment
30
+ from learning_munsell.utilities.data import (
31
+ MUNSELL_NORMALIZATION_PARAMS,
32
+ normalize_munsell,
33
+ )
34
+
35
+ LOGGER = logging.getLogger(__name__)
36
+
37
+
38
def weighted_mse_loss(
    pred: torch.Tensor,
    target: torch.Tensor,
    hue_weight: float = 1.0,
    value_weight: float = 1.0,
    chroma_weight: float = 4.0,
    code_weight: float = 0.5,
) -> torch.Tensor:
    """
    Component-wise weighted MSE loss with configurable weights.

    Each of the four Munsell components gets its own weight so that harder
    or more important components contribute more to the gradient.

    Parameters
    ----------
    pred : torch.Tensor
        Predicted values, shape (batch_size, 4).
    target : torch.Tensor
        Target values, shape (batch_size, 4).
    hue_weight : float, optional
        Weight for hue component. Default is 1.0.
    value_weight : float, optional
        Weight for value component. Default is 1.0.
    chroma_weight : float, optional
        Weight for chroma component (typically higher). Default is 4.0.
    code_weight : float, optional
        Weight for code component (typically lower). Default is 0.5.

    Returns
    -------
    torch.Tensor
        Weighted MSE loss, scalar tensor.
    """
    # Per-component weights, broadcast over the batch dimension.
    component_weights = torch.tensor(
        [hue_weight, value_weight, chroma_weight, code_weight], device=pred.device
    )

    squared_error = (pred - target) ** 2
    return (squared_error * component_weights).mean()
79
+
80
+
81
def train_epoch(
    model: nn.Module,
    dataloader: DataLoader,
    optimizer: optim.Optimizer,
    device: torch.device,
    chroma_weight: float,
    code_weight: float,
) -> float:
    """
    Train the model for one epoch.

    Parameters
    ----------
    model : nn.Module
        Multi-MLP model to train.
    dataloader : DataLoader
        DataLoader providing training batches.
    optimizer : optim.Optimizer
        Optimizer for updating model parameters.
    device : torch.device
        Device to run training on (CPU, CUDA, or MPS).
    chroma_weight : float
        Weight for chroma component in loss function.
    code_weight : float
        Weight for code component in loss function.

    Returns
    -------
    float
        Average training loss over the epoch.
    """
    model.train()
    running_loss = 0.0

    for features, targets in dataloader:
        inputs = features.to(device)
        labels = targets.to(device)

        # Forward pass with the component-weighted loss.
        predictions = model(inputs)
        batch_loss = weighted_mse_loss(
            predictions, labels, chroma_weight=chroma_weight, code_weight=code_weight
        )

        # Backward pass and parameter update.
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()

    return running_loss / len(dataloader)
132
+
133
+
134
def validate(
    model: nn.Module,
    dataloader: DataLoader,
    device: torch.device,
    chroma_weight: float,
    code_weight: float,
) -> float:
    """
    Validate the model on the validation set.

    Parameters
    ----------
    model : nn.Module
        Multi-MLP model to validate.
    dataloader : DataLoader
        DataLoader providing validation batches.
    device : torch.device
        Device to run validation on (CPU, CUDA, or MPS).
    chroma_weight : float
        Weight for chroma component in loss function.
    code_weight : float
        Weight for code component in loss function.

    Returns
    -------
    float
        Average validation loss.
    """
    model.eval()
    running_loss = 0.0

    # No gradient tracking during evaluation.
    with torch.no_grad():
        for features, targets in dataloader:
            inputs = features.to(device)
            labels = targets.to(device)

            predictions = model(inputs)
            batch_loss = weighted_mse_loss(
                predictions, labels, chroma_weight=chroma_weight, code_weight=code_weight
            )

            running_loss += batch_loss.item()

    return running_loss / len(dataloader)
177
+
178
+
179
def objective(trial: Trial) -> float:
    """
    Optuna objective function to minimize validation loss.

    This function defines the hyperparameter search space and training
    procedure for each trial. It optimizes:
    - Learning rate (1e-4 to 1e-3, log scale)
    - Batch size (512, 1024, or 2048)
    - Chroma branch width multiplier (1.5 to 2.5)
    - Chroma loss weight (3.0 to 6.0)
    - Code loss weight (0.3 to 1.0)
    - Dropout rate (0.0 to 0.2)

    Parameters
    ----------
    trial : Trial
        Optuna trial object for suggesting hyperparameters.

    Returns
    -------
    float
        Best validation loss achieved during training.

    Raises
    ------
    FileNotFoundError
        If training data file is not found.
    optuna.TrialPruned
        If trial is pruned based on intermediate results.
    """

    # Suggest hyperparameters
    lr = trial.suggest_float("lr", 1e-4, 1e-3, log=True)
    batch_size = trial.suggest_categorical("batch_size", [512, 1024, 2048])
    chroma_width = trial.suggest_float("chroma_width", 1.5, 2.5, step=0.25)
    chroma_weight = trial.suggest_float("chroma_weight", 3.0, 6.0, step=0.5)
    code_weight = trial.suggest_float("code_weight", 0.3, 1.0, step=0.1)
    dropout = trial.suggest_float("dropout", 0.0, 0.2, step=0.05)

    LOGGER.info("")
    LOGGER.info("=" * 80)
    LOGGER.info("Trial %d", trial.number)
    LOGGER.info("=" * 80)
    LOGGER.info(" lr: %.6f", lr)
    LOGGER.info(" batch_size: %d", batch_size)
    LOGGER.info(" chroma_width: %.2f", chroma_width)
    LOGGER.info(" chroma_weight: %.1f", chroma_weight)
    LOGGER.info(" code_weight: %.1f", code_weight)
    LOGGER.info(" dropout: %.2f", dropout)

    # Set device
    # NOTE(review): targets CUDA here while the Multi-Head error-predictor
    # search targets MPS — presumably intentional, confirm.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load training data
    data_file = PROJECT_ROOT / "data" / "training_data.npz"

    if not data_file.exists():
        LOGGER.error("Training data not found at %s", data_file)
        LOGGER.error("Run generate_training_data.py first")
        msg = f"Training data not found: {data_file}"
        raise FileNotFoundError(msg)

    data = np.load(data_file)

    # Use pre-split data
    X_train = data["X_train"]
    y_train = data["y_train"]
    X_val = data["X_val"]
    y_val = data["y_val"]

    LOGGER.info(
        "Loaded %d training samples, %d validation samples", len(X_train), len(X_val)
    )

    # Normalize outputs (xyY inputs are already in [0, 1] range)
    output_parameters = MUNSELL_NORMALIZATION_PARAMS
    y_train = normalize_munsell(y_train, output_parameters)
    y_val = normalize_munsell(y_val, output_parameters)

    # Convert to PyTorch tensors
    X_train_t = torch.FloatTensor(X_train)
    y_train_t = torch.FloatTensor(y_train)
    X_val_t = torch.FloatTensor(X_val)
    y_val_t = torch.FloatTensor(y_val)

    # Create data loaders
    train_dataset = TensorDataset(X_train_t, y_train_t)
    val_dataset = TensorDataset(X_val_t, y_val_t)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Initialize model
    model = MultiMLPToMunsell(chroma_width_multiplier=chroma_width, dropout=dropout).to(
        device
    )

    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    LOGGER.info("Total parameters: %s", f"{total_params:,}")

    # Training setup
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # MLflow setup
    run_name = setup_mlflow_experiment(
        "from_xyY", f"hparam_multi_mlp_trial_{trial.number}"
    )

    # Training loop with early stopping
    num_epochs = 100  # Reduced for hyperparameter search
    patience = 15
    best_val_loss = float("inf")
    patience_counter = 0

    with mlflow.start_run(run_name=run_name):
        mlflow.log_params(
            {
                "trial": trial.number,
                "lr": lr,
                "batch_size": batch_size,
                "chroma_width": chroma_width,
                "chroma_weight": chroma_weight,
                "code_weight": code_weight,
                "dropout": dropout,
                "total_params": total_params,
            }
        )

        for epoch in range(num_epochs):
            # Uses the module-level train_epoch / validate helpers defined
            # in this file (component-weighted MSE).
            train_loss = train_epoch(
                model, train_loader, optimizer, device, chroma_weight, code_weight
            )
            val_loss = validate(model, val_loader, device, chroma_weight, code_weight)

            # Log to MLflow
            mlflow.log_metrics(
                {
                    "train_loss": train_loss,
                    "val_loss": val_loss,
                    "learning_rate": lr,
                },
                step=epoch,
            )

            if (epoch + 1) % 10 == 0:
                LOGGER.info(
                    " Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f",
                    epoch + 1,
                    num_epochs,
                    train_loss,
                    val_loss,
                )

            # Early stopping
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    LOGGER.info(" Early stopping at epoch %d", epoch + 1)
                    break

            # Report intermediate value for pruning
            trial.report(val_loss, epoch)

            # Handle pruning
            if trial.should_prune():
                LOGGER.info(" Trial pruned at epoch %d", epoch + 1)
                mlflow.log_metrics({"pruned": 1, "pruned_epoch": epoch})
                raise optuna.TrialPruned

        # Log final results
        mlflow.log_metrics(
            {
                "best_val_loss": best_val_loss,
                "final_train_loss": train_loss,
                "final_epoch": epoch + 1,
            }
        )

    LOGGER.info(" Final validation loss: %.6f", best_val_loss)

    return best_val_loss
364
+
365
+
366
def main() -> None:
    """
    Run hyperparameter search for Multi-MLP model.

    Performs systematic hyperparameter optimization using Optuna with:
    - MedianPruner for early stopping of unpromising trials
    - 15 total trials
    - MLflow logging for each trial
    - Result visualization using matplotlib (optimization history,
      parameter importances, parallel coordinate plot)

    The search aims to find optimal hyperparameters for converting xyY
    color coordinates to Munsell color specifications using a multi-MLP
    architecture with independent branches for each component.
    """

    LOGGER.info("=" * 80)
    LOGGER.info("Multi-MLP Hyperparameter Search with Optuna")
    LOGGER.info("=" * 80)

    # Create study
    study = optuna.create_study(
        direction="minimize",
        study_name="multi_mlp_hparam_search",
        pruner=optuna.pruners.MedianPruner(n_startup_trials=3, n_warmup_steps=10),
    )

    # Run optimization
    n_trials = 15  # Number of trials to run

    LOGGER.info("")
    LOGGER.info("Starting hyperparameter search with %d trials...", n_trials)
    LOGGER.info("")

    study.optimize(objective, n_trials=n_trials, timeout=None)

    # Print results
    LOGGER.info("")
    LOGGER.info("=" * 80)
    LOGGER.info("Hyperparameter Search Results")
    LOGGER.info("=" * 80)
    LOGGER.info("")
    LOGGER.info("Best trial:")
    LOGGER.info(" Value (val_loss): %.6f", study.best_value)
    LOGGER.info("")
    LOGGER.info("Best hyperparameters:")
    for key, value in study.best_params.items():
        LOGGER.info(" %s: %s", key, value)

    # Save results
    results_dir = PROJECT_ROOT / "results" / "from_xyY"
    # parents=True so a missing "results" parent directory does not raise
    # (consistent with the sibling search scripts).
    results_dir.mkdir(exist_ok=True, parents=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    results_file = results_dir / f"hparam_search_{timestamp}.txt"

    with open(results_file, "w") as f:
        f.write("=" * 80 + "\n")
        f.write("Multi-MLP Hyperparameter Search Results\n")
        f.write("=" * 80 + "\n\n")
        f.write(f"Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write(f"Number of trials: {len(study.trials)}\n")
        f.write(f"Best validation loss: {study.best_value:.6f}\n\n")
        f.write("Best hyperparameters:\n")
        for key, value in study.best_params.items():
            f.write(f" {key}: {value}\n")
        f.write("\n\nAll trials:\n")
        f.write("-" * 80 + "\n")

        for t in study.trials:
            f.write(f"\nTrial {t.number}:\n")
            # BUG FIX: the previous version placed the conditional inside the
            # f-string format spec (`{trial.value:.6f if ...}`), which is an
            # invalid format specification and raises ValueError at runtime;
            # pruned/failed trials also have value None. Branch explicitly
            # instead, matching the sibling search scripts.
            if t.value is not None:
                f.write(f" Value: {t.value:.6f}\n")
            else:
                f.write(" Value: Pruned\n")
            f.write(" Params:\n")
            for key, value in t.params.items():
                f.write(f" {key}: {value}\n")

    LOGGER.info("")
    LOGGER.info("Results saved to: %s", results_file)

    # Generate visualizations using matplotlib
    from optuna.visualization.matplotlib import (  # noqa: PLC0415
        plot_optimization_history,
        plot_parallel_coordinate,
        plot_param_importances,
    )

    # Optimization history
    ax = plot_optimization_history(study)
    ax.figure.savefig(results_dir / f"optimization_history_{timestamp}.png", dpi=150)
    plt.close(ax.figure)

    # Parameter importances
    ax = plot_param_importances(study)
    ax.figure.savefig(results_dir / f"param_importances_{timestamp}.png", dpi=150)
    plt.close(ax.figure)

    # Parallel coordinate plot
    ax = plot_parallel_coordinate(study)
    ax.figure.savefig(results_dir / f"parallel_coordinate_{timestamp}.png", dpi=150)
    plt.close(ax.figure)

    LOGGER.info("Visualizations saved to: %s", results_dir)
468
+
469
+
470
if __name__ == "__main__":
    # Bare-message console logging; ``force=True`` replaces any handlers an
    # imported library may have installed before this point.
    logging.basicConfig(level=logging.INFO, format="%(message)s", force=True)

    main()
learning_munsell/training/from_xyY/refine_multi_head_real.py ADDED
@@ -0,0 +1,357 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Refine Multi-Head model on REAL Munsell colors only.
3
+
4
+ This script fine-tunes the best Multi-Head model using only the 2734 real
5
+ (measured) Munsell colors, which should improve accuracy on the evaluation set.
6
+ """
7
+
8
+ import logging
9
+
10
+ import click
11
+ import mlflow
12
+ import mlflow.pytorch
13
+ import numpy as np
14
+ import torch
15
+ from colour.notation.datasets.munsell import MUNSELL_COLOURS_REAL
16
+ from colour.notation.munsell import (
17
+ munsell_colour_to_munsell_specification,
18
+ munsell_specification_to_xyY,
19
+ )
20
+ from numpy.typing import NDArray
21
+ from sklearn.model_selection import train_test_split
22
+ from torch import nn, optim
23
+ from torch.utils.data import DataLoader, TensorDataset
24
+
25
+ from learning_munsell import PROJECT_ROOT
26
+ from learning_munsell.models.networks import MultiHeadMLPToMunsell
27
+ from learning_munsell.utilities.common import (
28
+ log_training_epoch,
29
+ setup_mlflow_experiment,
30
+ )
31
+ from learning_munsell.utilities.data import (
32
+ MUNSELL_NORMALIZATION_PARAMS,
33
+ XYY_NORMALIZATION_PARAMS,
34
+ normalize_munsell,
35
+ )
36
+ from learning_munsell.utilities.training import train_epoch, validate
37
+
38
+ LOGGER = logging.getLogger(__name__)
39
+
40
+
41
def generate_real_samples(
    n_samples_per_color: int = 100,
    perturbation_pct: float = 0.05,
    seed: int | None = 42,
) -> tuple[NDArray, NDArray]:
    """
    Generate training samples from REAL (measured) Munsell colors only.

    Creates augmented samples by applying small perturbations to the 2734 real
    Munsell color specifications to increase training data while staying close
    to measured values.

    Parameters
    ----------
    n_samples_per_color : int, optional
        Number of perturbed samples to generate per real color (default is 100).
    perturbation_pct : float, optional
        Percentage of range to use for perturbations (default is 0.05 = 5%).
    seed : int or None, optional
        Seed for the global NumPy RNG (default is 42, preserving the original
        behavior). Pass ``None`` to leave the current global RNG state
        untouched, e.g. when drawing several independent sample sets.

    Returns
    -------
    xyY_samples : NDArray
        Array of shape (n_samples, 3) containing xyY coordinates.
    munsell_samples : NDArray
        Array of shape (n_samples, 4) containing Munsell specifications
        [hue, value, chroma, code].

    Notes
    -----
    Perturbations are applied uniformly within ±perturbation_pct of the
    component ranges:
    - Hue range: 9.5 (0.5 to 10.0)
    - Value range: 9.0 (1.0 to 10.0)
    - Chroma range: 50.0 (0.0 to 50.0)

    Invalid samples (that cannot be converted to xyY) are skipped, so the
    returned arrays may contain fewer than
    ``len(MUNSELL_COLOURS_REAL) * n_samples_per_color`` rows.

    Seeding mutates the *global* NumPy RNG state; callers relying on global
    randomness elsewhere should pass ``seed=None``.
    """
    LOGGER.info(
        "Generating samples from %d REAL Munsell colors...", len(MUNSELL_COLOURS_REAL)
    )

    # Seed the global RNG for reproducibility; None leaves it untouched.
    if seed is not None:
        np.random.seed(seed)

    # Component ranges used to scale the relative perturbation amplitude.
    hue_range = 9.5
    value_range = 9.0
    chroma_range = 50.0

    xyY_samples = []
    munsell_samples = []

    for munsell_spec_tuple, _ in MUNSELL_COLOURS_REAL:
        hue_code_str, value, chroma = munsell_spec_tuple
        munsell_str = f"{hue_code_str} {value}/{chroma}"
        base_spec = munsell_colour_to_munsell_specification(munsell_str)

        for _ in range(n_samples_per_color):
            hue_delta = np.random.uniform(
                -perturbation_pct * hue_range, perturbation_pct * hue_range
            )
            value_delta = np.random.uniform(
                -perturbation_pct * value_range, perturbation_pct * value_range
            )
            chroma_delta = np.random.uniform(
                -perturbation_pct * chroma_range, perturbation_pct * chroma_range
            )

            # Clip perturbed components back into the valid Munsell ranges.
            perturbed_spec = base_spec.copy()
            perturbed_spec[0] = np.clip(base_spec[0] + hue_delta, 0.5, 10.0)
            perturbed_spec[1] = np.clip(base_spec[1] + value_delta, 1.0, 10.0)
            perturbed_spec[2] = np.clip(base_spec[2] + chroma_delta, 0.0, 50.0)

            # Best-effort conversion: specifications outside the renotation
            # gamut raise, and such samples are deliberately dropped.
            try:
                xyY = munsell_specification_to_xyY(perturbed_spec)
                xyY_samples.append(xyY)
                munsell_samples.append(perturbed_spec)
            except Exception:  # noqa: BLE001, S112
                continue

    LOGGER.info("Generated %d samples", len(xyY_samples))
    return np.array(xyY_samples), np.array(munsell_samples)
120
+
121
+
122
@click.command()
@click.option("--epochs", default=300, help="Number of training epochs")
@click.option("--batch-size", default=512, help="Batch size for training")
@click.option("--lr", default=1e-5, help="Learning rate")
@click.option("--patience", default=30, help="Early stopping patience")
def main(epochs: int, batch_size: int, lr: float, patience: int) -> None:
    """
    Refine Multi-Head model on REAL Munsell colors only.

    Fine-tunes a pretrained Multi-Head MLP model using only the 2734 real
    (measured) Munsell colors with small perturbations. This refinement step
    aims to improve accuracy on actual measured colors by focusing the model
    on the real color gamut.

    Notes
    -----
    Training configuration:
    - Dataset: 2734 real Munsell colors with 200 samples per color
    - Perturbation: 3% of component ranges (smaller than initial training)
    - Learning rate: 1e-5 (lower for fine-tuning)
    - Batch size: 512
    - Early stopping: patience of 30 epochs
    - Optimizer: AdamW with weight decay 0.01
    - Scheduler: ReduceLROnPlateau with factor 0.5, patience 15

    Workflow:
    1. Generate augmented samples from real Munsell colors
    2. Load pretrained model (multi_head_large_best.pth)
    3. Fine-tune with lower learning rate
    4. Save best model based on validation loss
    5. Export to ONNX format
    6. Log metrics to MLflow

    Files generated:
    - multi_head_refined_real_best.pth: Best checkpoint
    - multi_head_refined_real.onnx: ONNX model
    - multi_head_refined_real_normalization_parameters.npz: Normalization params
    """
    LOGGER.info("=" * 80)
    LOGGER.info("Multi-Head Refinement on REAL Munsell Colors")
    LOGGER.info("=" * 80)

    # Prefer MPS over CUDA/CPU when available (Apple Silicon).
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if torch.backends.mps.is_available():
        device = torch.device("mps")
    LOGGER.info("Using device: %s", device)

    # Generate REAL-only samples
    LOGGER.info("")
    xyY_all, munsell_all = generate_real_samples(
        n_samples_per_color=200,  # 200 samples per real color
        perturbation_pct=0.03,  # Smaller perturbations for refinement
    )

    # Split data
    X_train, X_val, y_train, y_val = train_test_split(
        xyY_all, munsell_all, test_size=0.15, random_state=42
    )

    LOGGER.info("Train samples: %d", len(X_train))
    LOGGER.info("Validation samples: %d", len(X_val))

    # Normalize outputs (xyY inputs are already in [0, 1] range)
    # Use hardcoded ranges covering the full Munsell space for generalization
    output_parameters = MUNSELL_NORMALIZATION_PARAMS
    y_train_norm = normalize_munsell(y_train, output_parameters)
    y_val_norm = normalize_munsell(y_val, output_parameters)

    # Convert to tensors
    X_train_t = torch.FloatTensor(X_train)
    y_train_t = torch.FloatTensor(y_train_norm)
    X_val_t = torch.FloatTensor(X_val)
    y_val_t = torch.FloatTensor(y_val_norm)

    # Data loaders
    train_dataset = TensorDataset(X_train_t, y_train_t)
    val_dataset = TensorDataset(X_val_t, y_val_t)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Load pretrained model. Ensure the output directory exists up front so
    # checkpoint saving below cannot fail on a fresh checkout (the sibling
    # train_deep_wide.py script does the same).
    model_directory = PROJECT_ROOT / "models" / "from_xyY"
    model_directory.mkdir(parents=True, exist_ok=True)
    pretrained_path = model_directory / "multi_head_large_best.pth"

    model = MultiHeadMLPToMunsell().to(device)

    if pretrained_path.exists():
        LOGGER.info("")
        LOGGER.info("Loading pretrained model from %s...", pretrained_path)
        # weights_only=False: the checkpoint stores a dict with Python objects
        # (normalization parameters), not just tensors.
        checkpoint = torch.load(
            pretrained_path, weights_only=False, map_location=device
        )
        model.load_state_dict(checkpoint["model_state_dict"])
        LOGGER.info("Pretrained model loaded successfully")
    else:
        LOGGER.info("")
        LOGGER.info("No pretrained model found, training from scratch")

    total_params = sum(p.numel() for p in model.parameters())
    LOGGER.info("Total parameters: %s", f"{total_params:,}")

    # Fine-tuning with lower learning rate
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=15
    )
    criterion = nn.MSELoss()

    # MLflow setup
    run_name = setup_mlflow_experiment("from_xyY", "multi_head_refined_real")

    LOGGER.info("")
    LOGGER.info("MLflow run: %s", run_name)
    LOGGER.info("Learning rate: %e (fine-tuning)", lr)

    # Training loop
    best_val_loss = float("inf")
    patience_counter = 0

    # Hoisted out of the save branch so the name is always bound when the
    # export section below references it.
    checkpoint_file = model_directory / "multi_head_refined_real_best.pth"

    LOGGER.info("")
    LOGGER.info("Starting refinement training...")

    with mlflow.start_run(run_name=run_name):
        mlflow.log_params(
            {
                "model": "multi_head_refined_real",
                "learning_rate": lr,
                "batch_size": batch_size,
                "num_epochs": epochs,
                "patience": patience,
                "total_params": total_params,
                "train_samples": len(X_train),
                "val_samples": len(X_val),
                "dataset": "REAL_only",
                "perturbation_pct": 0.03,
                "samples_per_color": 200,
            }
        )

        for epoch in range(epochs):
            train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
            val_loss = validate(model, val_loader, criterion, device)

            scheduler.step(val_loss)

            log_training_epoch(
                epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"]
            )

            LOGGER.info(
                "Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f, LR: %.2e",
                epoch + 1,
                epochs,
                train_loss,
                val_loss,
                optimizer.param_groups[0]["lr"],
            )

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0

                torch.save(
                    {
                        "model_state_dict": model.state_dict(),
                        "output_parameters": output_parameters,
                        "epoch": epoch,
                        "val_loss": val_loss,
                    },
                    checkpoint_file,
                )

                LOGGER.info("  -> Saved best model (val_loss: %.6f)", val_loss)
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    LOGGER.info("")
                    LOGGER.info("Early stopping after %d epochs", epoch + 1)
                    break

        mlflow.log_metrics(
            {
                "best_val_loss": best_val_loss,
                "final_epoch": epoch + 1,
            }
        )

        # Export to ONNX
        LOGGER.info("")
        LOGGER.info("Exporting refined model to ONNX...")
        model.eval()

        # Reload the best checkpoint (the in-memory weights may be from a
        # later, worse epoch). map_location keeps the load device-agnostic.
        checkpoint = torch.load(
            checkpoint_file, weights_only=False, map_location=device
        )
        model.load_state_dict(checkpoint["model_state_dict"])

        # Export from CPU so the ONNX graph is device-independent.
        model_cpu = model.cpu()
        dummy_input = torch.randn(1, 3)

        onnx_file = model_directory / "multi_head_refined_real.onnx"
        torch.onnx.export(
            model_cpu,
            dummy_input,
            onnx_file,
            export_params=True,
            opset_version=14,
            input_names=["xyY"],
            output_names=["munsell_spec"],
            dynamic_axes={"xyY": {0: "batch_size"}, "munsell_spec": {0: "batch_size"}},
        )

        params_file = (
            model_directory / "multi_head_refined_real_normalization_parameters.npz"
        )
        input_parameters = XYY_NORMALIZATION_PARAMS
        np.savez(
            params_file,
            input_parameters=input_parameters,
            output_parameters=output_parameters,
        )

        mlflow.log_artifact(str(checkpoint_file))
        mlflow.log_artifact(str(onnx_file))
        mlflow.log_artifact(str(params_file))

        LOGGER.info("ONNX model saved to: %s", onnx_file)
        LOGGER.info("Normalization params saved to: %s", params_file)

    LOGGER.info("=" * 80)
352
+
353
+
354
if __name__ == "__main__":
    # Message-only format keeps the training progress readable on stdout;
    # force=True replaces any handlers already installed by imported libraries.
    logging.basicConfig(level=logging.INFO, format="%(message)s", force=True)

    main()
learning_munsell/training/from_xyY/train_deep_wide.py ADDED
@@ -0,0 +1,370 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Train Deep + Wide model for xyY to Munsell conversion.
3
+
4
+ Option 5: Hybrid Deep + Wide architecture
5
+ - Input: 3 features (xyY)
6
+ - Deep path: 3 → 512 → 1024 (ResBlocks) → 512
7
+ - Wide path: 3 → 128 (direct linear)
8
+ - Combine: [512, 128] → 256 → 4
9
+ - Output: 4 features (hue, value, chroma, code)
10
+ """
11
+
12
+ import logging
13
+
14
+ import click
15
+ import mlflow
16
+ import mlflow.pytorch
17
+ import numpy as np
18
+ import torch
19
+ from torch import nn, optim
20
+ from torch.utils.data import DataLoader, TensorDataset
21
+
22
+ from learning_munsell import PROJECT_ROOT
23
+ from learning_munsell.models.networks import ResidualBlock
24
+ from learning_munsell.utilities.common import (
25
+ log_training_epoch,
26
+ setup_mlflow_experiment,
27
+ )
28
+ from learning_munsell.utilities.data import (
29
+ MUNSELL_NORMALIZATION_PARAMS,
30
+ XYY_NORMALIZATION_PARAMS,
31
+ normalize_munsell,
32
+ )
33
+ from learning_munsell.utilities.losses import precision_focused_loss
34
+ from learning_munsell.utilities.training import train_epoch, validate
35
+
36
+ LOGGER = logging.getLogger(__name__)
37
+
38
+
39
class DeepWideNet(nn.Module):
    """
    Deep + Wide Network for xyY to Munsell conversion.

    Hybrid architecture inspired by Google's Wide & Deep Learning: a deep
    path learns complex non-linear transformations while a wide path keeps a
    direct, nearly-linear connection to the inputs so simple relationships
    are preserved. Both paths are concatenated before the output head.

    Parameters
    ----------
    num_residual_blocks : int, optional
        Number of residual blocks in the deep path. Default is 4.

    Attributes
    ----------
    deep_encoder : nn.Sequential
        Deep path encoder: 3 → 512 → 1024.
    deep_residual_blocks : nn.ModuleList
        Stack of residual blocks operating at width 1024.
    deep_decoder : nn.Sequential
        Deep path decoder: 1024 → 512.
    wide_path : nn.Sequential
        Wide path: 3 → 128.
    output_head : nn.Sequential
        Combined output: concat(512 + 128) → 256 → 4.

    Notes
    -----
    Layer layout:
    - Deep path: 3 → 512 → 1024 → (ResBlocks) → 1024 → 512
    - Wide path: 3 → 128
    - Combined: concat → 640 → 256 → 4
    """

    def __init__(self, num_residual_blocks: int = 4) -> None:
        """Initialize the deep and wide network."""
        super().__init__()

        # Each "dense" unit is Linear → GELU → BatchNorm; building them with
        # a helper keeps the Sequential indices (and state_dict keys)
        # identical to an explicit listing.
        def _dense(n_in: int, n_out: int) -> list[nn.Module]:
            return [nn.Linear(n_in, n_out), nn.GELU(), nn.BatchNorm1d(n_out)]

        # Deep path: complex non-linear transformation.
        self.deep_encoder = nn.Sequential(*_dense(3, 512), *_dense(512, 1024))

        self.deep_residual_blocks = nn.ModuleList(
            ResidualBlock(1024) for _ in range(num_residual_blocks)
        )

        self.deep_decoder = nn.Sequential(*_dense(1024, 512))

        # Wide path: direct low-capacity transformation of the raw inputs.
        self.wide_path = nn.Sequential(*_dense(3, 128))

        # Output head on the concatenated features: 512 + 128 = 640.
        self.output_head = nn.Sequential(*_dense(640, 256), nn.Linear(256, 4))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the deep and wide paths.

        Parameters
        ----------
        x : Tensor
            Input tensor of shape (batch_size, 3) containing normalized xyY
            values.

        Returns
        -------
        Tensor
            Output tensor of shape (batch_size, 4) containing normalized
            Munsell specifications [hue, value, chroma, code].

        Notes
        -----
        1. Deep path: encoder → residual blocks → decoder (3 → 512 → 1024 → 512)
        2. Wide path: direct transformation (3 → 128)
        3. Concatenation: 512 + 128 = 640 features
        4. Output head: 640 → 256 → 4
        """
        # Deep path.
        deep_features = self.deep_encoder(x)
        for residual in self.deep_residual_blocks:
            deep_features = residual(deep_features)
        deep_features = self.deep_decoder(deep_features)

        # Wide path.
        wide_features = self.wide_path(x)

        # Fuse both representations and project to the 4 Munsell components.
        return self.output_head(torch.cat((deep_features, wide_features), dim=1))
153
+
154
+
155
@click.command()
@click.option("--epochs", default=300, help="Number of training epochs")
@click.option("--batch-size", default=1024, help="Batch size for training")
@click.option("--lr", default=3e-4, help="Learning rate")
@click.option("--patience", default=20, help="Early stopping patience")
def main(epochs: int, batch_size: int, lr: float, patience: int) -> None:
    """
    Train the DeepWideNet model for xyY to Munsell conversion.

    Notes
    -----
    The training pipeline:
    1. Loads normalization parameters from existing config
    2. Loads training data from cache
    3. Normalizes inputs and outputs to [0, 1] range
    4. Creates PyTorch DataLoaders
    5. Initializes DeepWideNet with deep and wide paths
    6. Trains with AdamW optimizer and precision-focused loss
    7. Uses learning rate scheduler (ReduceLROnPlateau)
    8. Implements early stopping based on validation loss
    9. Exports best model to ONNX format
    10. Logs all metrics and artifacts to MLflow
    """
    LOGGER.info("=" * 80)
    LOGGER.info("Deep + Wide Network: xyY → Munsell")
    LOGGER.info("=" * 80)

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    LOGGER.info("Using device: %s", device)

    # Paths
    model_directory = PROJECT_ROOT / "models" / "from_xyY"
    data_dir = PROJECT_ROOT / "data"
    cache_file = data_dir / "training_data.npz"

    # Load training data
    LOGGER.info("")
    LOGGER.info("Loading training data from %s...", cache_file)
    data = np.load(cache_file)
    X_train = data["X_train"]
    y_train = data["y_train"]
    X_val = data["X_val"]
    y_val = data["y_val"]

    LOGGER.info("Train samples: %d", len(X_train))
    LOGGER.info("Validation samples: %d", len(X_val))

    # Normalize outputs (xyY inputs are already in [0, 1] range)
    # Use hardcoded ranges covering the full Munsell space for generalization
    output_parameters = MUNSELL_NORMALIZATION_PARAMS
    y_train_norm = normalize_munsell(y_train, output_parameters)
    y_val_norm = normalize_munsell(y_val, output_parameters)

    # Convert to PyTorch tensors
    X_train_t = torch.FloatTensor(X_train)
    y_train_t = torch.FloatTensor(y_train_norm)
    X_val_t = torch.FloatTensor(X_val)
    y_val_t = torch.FloatTensor(y_val_norm)

    # Create data loaders
    train_dataset = TensorDataset(X_train_t, y_train_t)
    val_dataset = TensorDataset(X_val_t, y_val_t)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Initialize model
    model = DeepWideNet(num_residual_blocks=4).to(device)
    LOGGER.info("")
    LOGGER.info("Deep + Wide architecture:")
    LOGGER.info("%s", model)

    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    LOGGER.info("Total parameters: %s", f"{total_params:,}")

    # Training setup
    learning_rate = lr
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=5
    )
    criterion = precision_focused_loss

    # MLflow setup
    run_name = setup_mlflow_experiment("from_xyY", "deep_wide")

    LOGGER.info("")
    LOGGER.info("MLflow run: %s", run_name)

    # Training loop
    best_val_loss = float("inf")
    patience_counter = 0

    # Ensure the checkpoint directory exists and bind the checkpoint path
    # once, before the loop, so the export section below always sees it.
    model_directory.mkdir(parents=True, exist_ok=True)
    checkpoint_file = model_directory / "deep_wide_best.pth"

    LOGGER.info("")
    LOGGER.info("Starting training...")

    with mlflow.start_run(run_name=run_name):
        # Log parameters
        mlflow.log_params(
            {
                "model": "deep_wide",
                "learning_rate": learning_rate,
                "batch_size": batch_size,
                "num_epochs": epochs,
                "patience": patience,
                "total_params": total_params,
            }
        )

        for epoch in range(epochs):
            train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
            val_loss = validate(model, val_loader, criterion, device)

            scheduler.step(val_loss)

            # Log to MLflow
            log_training_epoch(
                epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"]
            )

            LOGGER.info(
                "Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f, LR: %.6f",
                epoch + 1,
                epochs,
                train_loss,
                val_loss,
                optimizer.param_groups[0]["lr"],
            )

            # Early stopping
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0

                torch.save(
                    {
                        "model_state_dict": model.state_dict(),
                        "epoch": epoch,
                        "val_loss": val_loss,
                    },
                    checkpoint_file,
                )

                LOGGER.info("  → Saved best model (val_loss: %.6f)", val_loss)
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    LOGGER.info("")
                    LOGGER.info("Early stopping after %d epochs", epoch + 1)
                    break

        # Log final metrics
        mlflow.log_metrics(
            {
                "best_val_loss": best_val_loss,
                "final_epoch": epoch + 1,
            }
        )

        # Export to ONNX
        LOGGER.info("")
        LOGGER.info("Exporting to ONNX...")
        model.eval()

        # Reload the best checkpoint (the in-memory weights may be from a
        # later, worse epoch). map_location keeps the load device-agnostic,
        # consistent with the other training scripts in this package.
        checkpoint = torch.load(checkpoint_file, map_location=device)
        model.load_state_dict(checkpoint["model_state_dict"])

        dummy_input = torch.randn(1, 3).to(device)

        onnx_file = model_directory / "deep_wide.onnx"
        torch.onnx.export(
            model,
            dummy_input,
            onnx_file,
            export_params=True,
            opset_version=15,
            input_names=["xyY"],
            output_names=["munsell_spec"],
            dynamic_axes={
                "xyY": {0: "batch_size"},
                "munsell_spec": {0: "batch_size"},
            },
        )

        # Save normalization parameters alongside model
        params_file = model_directory / "deep_wide_normalization_parameters.npz"
        input_parameters = XYY_NORMALIZATION_PARAMS
        np.savez(
            params_file,
            input_parameters=input_parameters,
            output_parameters=output_parameters,
        )

        # Log artifacts to MLflow
        mlflow.log_artifact(str(checkpoint_file))
        mlflow.log_artifact(str(onnx_file))
        mlflow.log_artifact(str(params_file))
        mlflow.pytorch.log_model(model, "model")

        LOGGER.info("ONNX model saved to: %s", onnx_file)
        LOGGER.info("Normalization parameters saved to: %s", params_file)
        LOGGER.info("Artifacts logged to MLflow")

    LOGGER.info("=" * 80)
+
366
+
367
if __name__ == "__main__":
    # Message-only format keeps the training progress readable on stdout;
    # force=True replaces any handlers already installed by imported libraries.
    logging.basicConfig(level=logging.INFO, format="%(message)s", force=True)

    main()
learning_munsell/training/from_xyY/train_ft_transformer.py ADDED
@@ -0,0 +1,358 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Train FT-Transformer model for xyY to Munsell conversion.
3
+
4
+ Option 4: Feature Tokenizer + Transformer architecture
5
+ - Input: 3 features (xyY) → each becomes a 256-dim token
6
+ - Add [CLS] token for regression
7
+ - 4-6 transformer blocks with multi-head attention
8
+ - Output: Take [CLS] token → MLP → 4 features
9
+ """
10
+
11
+ import logging
12
+
13
+ import click
14
+ import mlflow
15
+ import mlflow.pytorch
16
+ import numpy as np
17
+ import torch
18
+ from torch import nn, optim
19
+ from torch.utils.data import DataLoader, TensorDataset
20
+
21
+ from learning_munsell import PROJECT_ROOT
22
+ from learning_munsell.models.networks import FeatureTokenizer, TransformerBlock
23
+ from learning_munsell.utilities.common import (
24
+ fix_onnx_dynamic_batch,
25
+ log_training_epoch,
26
+ setup_mlflow_experiment,
27
+ )
28
+ from learning_munsell.utilities.data import (
29
+ MUNSELL_NORMALIZATION_PARAMS,
30
+ XYY_NORMALIZATION_PARAMS,
31
+ normalize_munsell,
32
+ )
33
+ from learning_munsell.utilities.losses import precision_focused_loss
34
+ from learning_munsell.utilities.training import train_epoch, validate
35
+
36
+ LOGGER = logging.getLogger(__name__)
37
+
38
+
39
class FTTransformer(nn.Module):
    """
    Feature Tokenizer + Transformer for xyY to Munsell conversion.

    Adapts the transformer architecture to tabular inputs: every scalar
    feature is embedded as its own token, a learnable [CLS] token is
    prepended, and stacked self-attention blocks model feature interactions.
    The final [CLS] representation is decoded into the 4 Munsell components.

    Parameters
    ----------
    num_features : int, optional
        Number of input features (xyY), default is 3.
    embedding_dim : int, optional
        Dimension of token embeddings, default is 256.
    num_blocks : int, optional
        Number of transformer blocks, default is 4.
    num_heads : int, optional
        Number of attention heads, default is 4.
    ff_dim : int, optional
        Feedforward network hidden dimension, default is 512.
    dropout : float, optional
        Dropout probability, default is 0.1.

    Attributes
    ----------
    tokenizer : FeatureTokenizer
        Converts input features to token embeddings (and prepends [CLS]).
    transformer_blocks : nn.ModuleList
        Stack of transformer encoder blocks.
    output_head : nn.Sequential
        MLP mapping the [CLS] embedding to the 4 output values.
    """

    def __init__(
        self,
        num_features: int = 3,
        embedding_dim: int = 256,
        num_blocks: int = 4,
        num_heads: int = 4,
        ff_dim: int = 512,
        dropout: float = 0.1,
    ) -> None:
        """Initialize the FT-Transformer model."""
        super().__init__()

        # Per-feature tokenizer (also responsible for the [CLS] token).
        self.tokenizer = FeatureTokenizer(num_features, embedding_dim)

        # Encoder stack: identical blocks, applied sequentially in forward().
        self.transformer_blocks = nn.ModuleList(
            TransformerBlock(embedding_dim, num_heads, ff_dim, dropout)
            for _ in range(num_blocks)
        )

        # Regression head decoding the [CLS] token.
        head_layers = [
            nn.Linear(embedding_dim, 128),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(128, 4),
        ]
        self.output_head = nn.Sequential(*head_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the FT-Transformer.

        Parameters
        ----------
        x : Tensor
            Input xyY values of shape (batch_size, 3).

        Returns
        -------
        Tensor
            Predicted Munsell specification [hue, value, chroma, code]
            of shape (batch_size, 4).
        """
        # (batch_size, 1 + num_features, embedding_dim) — [CLS] comes first.
        tokens = self.tokenizer(x)

        # Run the encoder stack over the token sequence.
        for encoder in self.transformer_blocks:
            tokens = encoder(tokens)

        # Decode only the [CLS] token (position 0) into the 4 outputs.
        return self.output_head(tokens[:, 0, :])
137
+
138
+
139
@click.command()
@click.option("--epochs", default=300, help="Number of training epochs")
@click.option("--batch-size", default=1024, help="Batch size for training")
@click.option("--lr", default=3e-4, help="Learning rate")
@click.option("--patience", default=20, help="Early stopping patience")
def main(epochs: int, batch_size: int, lr: float, patience: int) -> None:
    """
    Train FT-Transformer model for xyY to Munsell conversion.

    Parameters
    ----------
    epochs : int
        Maximum number of training epochs.
    batch_size : int
        Training batch size.
    lr : float
        Learning rate for the AdamW optimizer.
    patience : int
        Early stopping patience (epochs without improvement).

    Notes
    -----
    The training pipeline:
    1. Loads training data from cache
    2. Normalizes outputs to [0, 1] range (xyY inputs already are)
    3. Creates PyTorch DataLoaders
    4. Initializes FT-Transformer with feature tokenization
    5. Trains with AdamW optimizer and precision-focused loss
    6. Uses learning rate scheduler (ReduceLROnPlateau)
    7. Implements early stopping based on validation loss
    8. Exports best model to ONNX format
    9. Logs all metrics and artifacts to MLflow
    """
    LOGGER.info("=" * 80)
    LOGGER.info("FT-Transformer: xyY → Munsell")
    LOGGER.info("=" * 80)

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    LOGGER.info("Using device: %s", device)

    # Paths
    model_directory = PROJECT_ROOT / "models" / "from_xyY"
    data_dir = PROJECT_ROOT / "data"
    cache_file = data_dir / "training_data.npz"

    # Load training data
    LOGGER.info("")
    LOGGER.info("Loading training data from %s...", cache_file)
    data = np.load(cache_file)
    X_train = data["X_train"]
    y_train = data["y_train"]
    X_val = data["X_val"]
    y_val = data["y_val"]

    LOGGER.info("Train samples: %d", len(X_train))
    LOGGER.info("Validation samples: %d", len(X_val))

    # Normalize outputs (xyY inputs are already in [0, 1] range)
    output_parameters = MUNSELL_NORMALIZATION_PARAMS
    y_train_norm = normalize_munsell(y_train, output_parameters)
    y_val_norm = normalize_munsell(y_val, output_parameters)

    # Convert to PyTorch tensors
    X_train_t = torch.FloatTensor(X_train)
    y_train_t = torch.FloatTensor(y_train_norm)
    X_val_t = torch.FloatTensor(X_val)
    y_val_t = torch.FloatTensor(y_val_norm)

    # Create data loaders
    train_dataset = TensorDataset(X_train_t, y_train_t)
    val_dataset = TensorDataset(X_val_t, y_val_t)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Initialize model
    model = FTTransformer(
        num_features=3,
        embedding_dim=256,
        num_blocks=4,
        num_heads=4,
        ff_dim=512,
        dropout=0.1,
    ).to(device)

    LOGGER.info("")
    LOGGER.info("FT-Transformer architecture:")
    LOGGER.info("%s", model)

    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    LOGGER.info("Total parameters: %s", f"{total_params:,}")

    # Training setup
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=5
    )
    criterion = precision_focused_loss

    # MLflow setup
    run_name = setup_mlflow_experiment("from_xyY", "ft_transformer")

    LOGGER.info("")
    LOGGER.info("MLflow run: %s", run_name)

    # Create the output directory up front so checkpointing cannot fail
    # mid-training; parents=True handles a missing "models" parent directory.
    model_directory.mkdir(parents=True, exist_ok=True)
    checkpoint_file = model_directory / "ft_transformer_best.pth"

    # Training loop state
    best_val_loss = float("inf")
    patience_counter = 0

    LOGGER.info("")
    LOGGER.info("Starting training...")

    with mlflow.start_run(run_name=run_name):
        mlflow.log_params(
            {
                "model": "ft_transformer",
                "learning_rate": lr,
                "batch_size": batch_size,
                "num_epochs": epochs,
                "patience": patience,
                "total_params": total_params,
            }
        )

        for epoch in range(epochs):
            train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
            val_loss = validate(model, val_loader, criterion, device)

            scheduler.step(val_loss)

            log_training_epoch(
                epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"]
            )

            LOGGER.info(
                "Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f, LR: %.6f",
                epoch + 1,
                epochs,
                train_loss,
                val_loss,
                optimizer.param_groups[0]["lr"],
            )

            # Early stopping: checkpoint on every improvement, bail out after
            # `patience` consecutive epochs without one.
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0

                torch.save(
                    {
                        "model_state_dict": model.state_dict(),
                        "epoch": epoch,
                        "val_loss": val_loss,
                    },
                    checkpoint_file,
                )

                LOGGER.info("  → Saved best model (val_loss: %.6f)", val_loss)
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    LOGGER.info("")
                    LOGGER.info("Early stopping after %d epochs", epoch + 1)
                    break

        mlflow.log_metrics(
            {
                "best_val_loss": best_val_loss,
                "final_epoch": epoch + 1,
            }
        )

        # Export to ONNX from the best checkpoint (not the last epoch's weights)
        LOGGER.info("")
        LOGGER.info("Exporting to ONNX...")
        model.eval()

        checkpoint = torch.load(checkpoint_file, map_location=device)
        model.load_state_dict(checkpoint["model_state_dict"])

        dummy_input = torch.randn(2, 3).to(device)

        onnx_file = model_directory / "ft_transformer.onnx"
        torch.onnx.export(
            model,
            dummy_input,
            onnx_file,
            export_params=True,
            opset_version=15,
            input_names=["xyY"],
            output_names=["munsell_spec"],
            dynamic_axes={
                "xyY": {0: "batch_size"},
                "munsell_spec": {0: "batch_size"},
            },
        )

        fix_onnx_dynamic_batch(onnx_file)

        # Save normalization parameters alongside model
        params_file = model_directory / "ft_transformer_normalization_parameters.npz"
        input_parameters = XYY_NORMALIZATION_PARAMS
        np.savez(
            params_file,
            input_parameters=input_parameters,
            output_parameters=output_parameters,
        )

        mlflow.log_artifact(str(checkpoint_file))
        mlflow.log_artifact(str(onnx_file))
        mlflow.log_artifact(str(params_file))
        mlflow.pytorch.log_model(model, "model")

        LOGGER.info("ONNX model saved to: %s", onnx_file)
        LOGGER.info("Normalization parameters saved to: %s", params_file)
        LOGGER.info("Artifacts logged to MLflow")

    LOGGER.info("=" * 80)


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO, format="%(message)s", force=True)

    main()
learning_munsell/training/from_xyY/train_mixture_of_experts.py ADDED
@@ -0,0 +1,622 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Train Mixture of Experts model for xyY to Munsell conversion.
3
+
4
+ Option 6: Mixture of Experts architecture
5
+ - Input: 3 features (xyY)
6
+ - Gating network: 3 → 128 → 64 → 4 (softmax weights)
7
+ - 4 Expert networks: Each 3 → 256 → 256 → 4 (MLP)
8
+ - Output: Weighted combination of expert outputs
9
+ """
10
+
11
+ import logging
12
+
13
+ import click
14
+ import mlflow
15
+ import mlflow.pytorch
16
+ import numpy as np
17
+ import torch
18
+ from torch import nn, optim
19
+ from torch.utils.data import DataLoader, TensorDataset
20
+
21
+ from learning_munsell import PROJECT_ROOT
22
+ from learning_munsell.models.networks import ResidualBlock
23
+ from learning_munsell.utilities.common import (
24
+ log_training_epoch,
25
+ setup_mlflow_experiment,
26
+ )
27
+ from learning_munsell.utilities.data import (
28
+ MUNSELL_NORMALIZATION_PARAMS,
29
+ XYY_NORMALIZATION_PARAMS,
30
+ normalize_munsell,
31
+ )
32
+
33
+ LOGGER = logging.getLogger(__name__)
34
+
35
+
36
+ class ExpertNetwork(nn.Module):
37
+ """
38
+ Single expert network with MLP architecture.
39
+
40
+ Each expert is a specialized neural network that learns to handle
41
+ specific regions of the input space. Uses residual connections for
42
+ improved gradient flow.
43
+
44
+ Architecture
45
+ ------------
46
+ - Encoder: 3 → 256 with GELU and BatchNorm
47
+ - Residual blocks: Configurable number of ResidualBlock(256)
48
+ - Decoder: 256 → 4
49
+
50
+ Parameters
51
+ ----------
52
+ num_residual_blocks : int, optional
53
+ Number of residual blocks, default is 2.
54
+
55
+ Attributes
56
+ ----------
57
+ encoder : nn.Sequential
58
+ Input encoding layer.
59
+ residual_blocks : nn.ModuleList
60
+ Stack of residual blocks.
61
+ decoder : nn.Sequential
62
+ Output decoding layer.
63
+ """
64
+
65
+ def __init__(self, num_residual_blocks: int = 2) -> None:
66
+ """Initialize the expert network."""
67
+ super().__init__()
68
+
69
+ self.encoder = nn.Sequential(
70
+ nn.Linear(3, 256),
71
+ nn.GELU(),
72
+ nn.BatchNorm1d(256),
73
+ )
74
+
75
+ self.residual_blocks = nn.ModuleList(
76
+ [ResidualBlock(256) for _ in range(num_residual_blocks)]
77
+ )
78
+
79
+ self.decoder = nn.Sequential(
80
+ nn.Linear(256, 4),
81
+ )
82
+
83
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
84
+ """
85
+ Forward pass through expert network.
86
+
87
+ Parameters
88
+ ----------
89
+ x : Tensor
90
+ Input xyY values of shape (batch_size, 3).
91
+
92
+ Returns
93
+ -------
94
+ Tensor
95
+ Expert's prediction of shape (batch_size, 4).
96
+ """
97
+ x = self.encoder(x)
98
+ for block in self.residual_blocks:
99
+ x = block(x)
100
+ return self.decoder(x)
101
+
102
+
103
class GatingNetwork(nn.Module):
    """
    Gating network producing per-sample expert weights.

    Maps each input to a probability distribution over the experts, so that
    different inputs are routed to different experts according to learned
    characteristics of the input.

    Architecture
    ------------
    3 → 128 → 64 → num_experts → softmax

    Parameters
    ----------
    num_experts : int
        Number of expert networks to gate.

    Attributes
    ----------
    gate : nn.Sequential
        MLP that maps inputs to expert logits.
    """

    def __init__(self, num_experts: int) -> None:
        """Initialize the gating network."""
        super().__init__()

        layers = [
            nn.Linear(3, 128),
            nn.GELU(),
            nn.BatchNorm1d(128),
            nn.Linear(128, 64),
            nn.GELU(),
            nn.BatchNorm1d(64),
            nn.Linear(64, num_experts),
        ]
        self.gate = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute expert weights for a batch of inputs.

        Parameters
        ----------
        x : Tensor
            Input xyY values of shape (batch_size, 3).

        Returns
        -------
        Tensor
            Softmax weights over experts of shape (batch_size, num_experts);
            each row sums to 1.
        """
        logits = self.gate(x)
        # Softmax over the expert dimension yields routing probabilities.
        return logits.softmax(dim=-1)
157
+
158
+
159
class MixtureOfExperts(nn.Module):
    """
    Mixture of Experts for xyY to Munsell conversion.

    Combines several specialized expert networks through learned gating
    weights, letting each expert focus on a different region of the input
    space (e.g. different colour ranges or hue families).

    Architecture
    ------------
    - Gating network: learns which expert(s) to use for each input
    - Multiple expert networks: each specializes in different input regions
    - Output: gate-weighted combination of the expert predictions
    - Load balancing: auxiliary loss encourages balanced expert usage

    Parameters
    ----------
    num_experts : int, optional
        Number of expert networks, default is 4.
    num_residual_blocks : int, optional
        Number of residual blocks per expert, default is 2.

    Attributes
    ----------
    num_experts : int
        Number of expert networks.
    gating_network : GatingNetwork
        Network that computes expert weights.
    experts : nn.ModuleList
        List of expert networks.
    load_balance_weight : float
        Weight for load balancing auxiliary loss.
    """

    def __init__(self, num_experts: int = 4, num_residual_blocks: int = 2) -> None:
        """Initialize the mixture of experts model."""
        super().__init__()

        self.num_experts = num_experts

        # Router over the experts
        self.gating_network = GatingNetwork(num_experts)

        # Pool of identically-shaped experts
        self.experts = nn.ModuleList(
            [ExpertNetwork(num_residual_blocks) for _ in range(num_experts)]
        )

        # Coefficient of the load-balancing auxiliary loss term
        self.load_balance_weight = 0.01

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass through the mixture of experts.

        Parameters
        ----------
        x : Tensor
            Input xyY values of shape (batch_size, 3).

        Returns
        -------
        tuple
            (output, gate_weights) where:
            - output: weighted expert predictions of shape (batch_size, 4)
            - gate_weights: expert weights of shape (batch_size, num_experts)
        """
        # Routing probabilities: (batch_size, num_experts)
        gate_weights = self.gating_network(x)

        # Run every expert on the full batch: (batch_size, num_experts, 4)
        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=1)

        # Gate-weighted sum over the expert axis: (batch_size, 4)
        output = torch.einsum("be,beo->bo", gate_weights, expert_outputs)

        return output, gate_weights
244
+
245
+
246
def precision_focused_loss(
    pred: torch.Tensor,
    target: torch.Tensor,
    gate_weights: torch.Tensor,
    load_balance_weight: float = 0.01,
) -> torch.Tensor:
    """
    Precision-focused loss with load balancing for a mixture of experts.

    Combines standard regression terms (MSE, MAE, log penalty, Huber) with an
    auxiliary load-balancing term that pushes the gating network towards using
    all experts roughly equally, preventing expert collapse.

    Parameters
    ----------
    pred : torch.Tensor
        Predicted values.
    target : torch.Tensor
        Target ground truth values.
    gate_weights : torch.Tensor
        Expert gating weights of shape (batch_size, num_experts).
    load_balance_weight : float, optional
        Weight for the load balancing auxiliary loss, default is 0.01.

    Returns
    -------
    torch.Tensor
        Combined loss value including the load balancing term.

    Notes
    -----
    The load balancing term is minimised when each expert receives an equal
    share (1/num_experts) of the total gate mass over the batch.
    """
    # Regression terms, all derived from the signed/absolute error once.
    error = pred - target
    abs_error = error.abs()

    mse = (error**2).mean()
    mae = abs_error.mean()
    # log1p of a scaled error heavily penalises even tiny residuals.
    log_penalty = torch.log1p(abs_error * 1000.0).mean()

    # Huber with a small delta: quadratic near zero, linear beyond.
    delta = 0.01
    quadratic = 0.5 * abs_error**2
    linear = delta * (abs_error - 0.5 * delta)
    huber_loss = torch.where(abs_error <= delta, quadratic, linear).mean()

    # Load balancing: compare each expert's share of total gate mass against
    # the uniform share 1/num_experts.
    expert_share = gate_weights.sum(dim=0)
    expert_share = expert_share / expert_share.sum()
    uniform_share = torch.full_like(expert_share, 1.0 / gate_weights.size(1))
    load_balance_loss = ((expert_share - uniform_share) ** 2).mean()

    return (
        1.0 * mse
        + 0.5 * mae
        + 0.3 * log_penalty
        + 0.5 * huber_loss
        + load_balance_weight * load_balance_loss
    )
310
+
311
+
312
def train_epoch(
    model: nn.Module,
    dataloader: DataLoader,
    optimizer: optim.Optimizer,
    device: torch.device,
) -> float:
    """
    Train the mixture of experts model for one epoch.

    Parameters
    ----------
    model : nn.Module
        The neural network model to train.
    dataloader : DataLoader
        DataLoader providing training batches (X, y).
    optimizer : optim.Optimizer
        Optimizer for updating model parameters.
    device : torch.device
        Device to run training on.

    Returns
    -------
    float
        Average loss for the epoch.

    Notes
    -----
    The loss combines the prediction error and a load-balancing term; both
    are computed by ``precision_focused_loss``, which receives the gate
    weights for the balancing component.
    """
    model.train()
    running_loss = 0.0

    for features, targets in dataloader:
        batch_x = features.to(device)
        batch_y = targets.to(device)

        predictions, gate_weights = model(batch_x)
        loss = precision_focused_loss(
            predictions, batch_y, gate_weights, model.load_balance_weight
        )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    return running_loss / len(dataloader)
361
+
362
+
363
def validate(model: nn.Module, dataloader: DataLoader, device: torch.device) -> float:
    """
    Evaluate the mixture of experts model on the validation set.

    Parameters
    ----------
    model : nn.Module
        The neural network model to validate.
    dataloader : DataLoader
        DataLoader providing validation batches (X, y).
    device : torch.device
        Device to run validation on.

    Returns
    -------
    float
        Average loss over the validation set.
    """
    model.eval()
    running_loss = 0.0

    # No gradients are needed for evaluation.
    with torch.no_grad():
        for features, targets in dataloader:
            batch_x = features.to(device)
            batch_y = targets.to(device)

            predictions, gate_weights = model(batch_x)
            loss = precision_focused_loss(
                predictions, batch_y, gate_weights, model.load_balance_weight
            )
            running_loss += loss.item()

    return running_loss / len(dataloader)
396
+
397
+
398
@click.command()
@click.option("--epochs", default=300, help="Number of training epochs")
@click.option("--batch-size", default=1024, help="Batch size for training")
@click.option("--lr", default=3e-4, help="Learning rate")
@click.option("--patience", default=20, help="Early stopping patience")
def main(epochs: int, batch_size: int, lr: float, patience: int) -> None:
    """
    Train mixture of experts model for xyY to Munsell conversion.

    Parameters
    ----------
    epochs : int
        Maximum number of training epochs.
    batch_size : int
        Training batch size.
    lr : float
        Learning rate for the AdamW optimizer.
    patience : int
        Early stopping patience (epochs without improvement).

    Notes
    -----
    The training pipeline:
    1. Loads training data from cache
    2. Normalizes outputs to [0, 1] range (xyY inputs already are)
    3. Creates PyTorch DataLoaders
    4. Initializes MixtureOfExperts with 4 expert networks
    5. Trains with AdamW optimizer and precision-focused loss
    6. Uses learning rate scheduler (ReduceLROnPlateau)
    7. Implements early stopping based on validation loss
    8. Exports best model to ONNX format
    9. Logs all metrics and artifacts to MLflow
    """
    LOGGER.info("=" * 80)
    LOGGER.info("Mixture of Experts: xyY → Munsell")
    LOGGER.info("=" * 80)

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    LOGGER.info("Using device: %s", device)

    # Paths
    model_directory = PROJECT_ROOT / "models" / "from_xyY"
    data_dir = PROJECT_ROOT / "data"
    cache_file = data_dir / "training_data.npz"

    # Load training data
    LOGGER.info("")
    LOGGER.info("Loading training data from %s...", cache_file)
    data = np.load(cache_file)
    X_train = data["X_train"]
    y_train = data["y_train"]
    X_val = data["X_val"]
    y_val = data["y_val"]

    LOGGER.info("Train samples: %d", len(X_train))
    LOGGER.info("Validation samples: %d", len(X_val))

    # Normalize outputs (xyY inputs are already in [0, 1] range)
    # Use hardcoded ranges covering the full Munsell space for generalization
    output_parameters = MUNSELL_NORMALIZATION_PARAMS
    y_train_norm = normalize_munsell(y_train, output_parameters)
    y_val_norm = normalize_munsell(y_val, output_parameters)

    # Convert to PyTorch tensors
    X_train_t = torch.FloatTensor(X_train)
    y_train_t = torch.FloatTensor(y_train_norm)
    X_val_t = torch.FloatTensor(X_val)
    y_val_t = torch.FloatTensor(y_val_norm)

    # Create data loaders
    train_dataset = TensorDataset(X_train_t, y_train_t)
    val_dataset = TensorDataset(X_val_t, y_val_t)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Initialize model
    model = MixtureOfExperts(num_experts=4, num_residual_blocks=2).to(device)
    LOGGER.info("")
    LOGGER.info("Mixture of Experts architecture:")
    LOGGER.info("%s", model)

    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    LOGGER.info("Total parameters: %s", f"{total_params:,}")

    # Training setup
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=5
    )

    # MLflow setup
    run_name = setup_mlflow_experiment("from_xyY", "mixture_of_experts")

    LOGGER.info("")
    LOGGER.info("MLflow run: %s", run_name)

    # Create the output directory up front so checkpointing cannot fail
    # mid-training; parents=True handles a missing "models" parent directory.
    model_directory.mkdir(parents=True, exist_ok=True)
    checkpoint_file = model_directory / "mixture_of_experts_best.pth"

    # Training loop state
    best_val_loss = float("inf")
    patience_counter = 0

    LOGGER.info("")
    LOGGER.info("Starting training...")

    with mlflow.start_run(run_name=run_name):
        mlflow.log_params(
            {
                "model": "mixture_of_experts",
                "learning_rate": lr,
                "batch_size": batch_size,
                "num_epochs": epochs,
                "patience": patience,
                "total_params": total_params,
            }
        )

        for epoch in range(epochs):
            train_loss = train_epoch(model, train_loader, optimizer, device)
            val_loss = validate(model, val_loader, device)

            scheduler.step(val_loss)

            log_training_epoch(
                epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"]
            )

            LOGGER.info(
                "Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f, LR: %.6f",
                epoch + 1,
                epochs,
                train_loss,
                val_loss,
                optimizer.param_groups[0]["lr"],
            )

            # Early stopping: checkpoint on every improvement, bail out after
            # `patience` consecutive epochs without one.
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0

                torch.save(
                    {
                        "model_state_dict": model.state_dict(),
                        "epoch": epoch,
                        "val_loss": val_loss,
                    },
                    checkpoint_file,
                )

                LOGGER.info("  → Saved best model (val_loss: %.6f)", val_loss)
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    LOGGER.info("")
                    LOGGER.info("Early stopping after %d epochs", epoch + 1)
                    break

        mlflow.log_metrics(
            {
                "best_val_loss": best_val_loss,
                "final_epoch": epoch + 1,
            }
        )

        # Export to ONNX (simplified - outputs only prediction, not gate weights)
        LOGGER.info("")
        LOGGER.info("Exporting to ONNX...")
        model.eval()

        checkpoint = torch.load(checkpoint_file, map_location=device)
        model.load_state_dict(checkpoint["model_state_dict"])

        # Wrapper that drops the gate weights so the exported graph has a
        # single output tensor.
        class MoEWrapper(nn.Module):
            def __init__(self, moe_model: nn.Module) -> None:
                super().__init__()
                self.moe_model = moe_model

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                output, _ = self.moe_model(x)
                return output

        wrapped_model = MoEWrapper(model).to(device)
        wrapped_model.eval()

        dummy_input = torch.randn(1, 3).to(device)

        onnx_file = model_directory / "mixture_of_experts.onnx"
        torch.onnx.export(
            wrapped_model,
            dummy_input,
            onnx_file,
            export_params=True,
            opset_version=15,
            input_names=["xyY"],
            output_names=["munsell_spec"],
            dynamic_axes={
                "xyY": {0: "batch_size"},
                "munsell_spec": {0: "batch_size"},
            },
        )

        # Save normalization parameters alongside model
        params_file = (
            model_directory / "mixture_of_experts_normalization_parameters.npz"
        )
        input_parameters = XYY_NORMALIZATION_PARAMS
        np.savez(
            params_file,
            input_parameters=input_parameters,
            output_parameters=output_parameters,
        )

        mlflow.log_artifact(str(checkpoint_file))
        mlflow.log_artifact(str(onnx_file))
        mlflow.log_artifact(str(params_file))
        mlflow.pytorch.log_model(model, "model")

        LOGGER.info("ONNX model saved to: %s", onnx_file)
        LOGGER.info("Normalization parameters saved to: %s", params_file)
        LOGGER.info("Artifacts logged to MLflow")

    LOGGER.info("=" * 80)


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO, format="%(message)s", force=True)

    main()
learning_munsell/training/from_xyY/train_mlp.py ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Train ML model for xyY to Munsell conversion.
3
+
4
+ This script trains a compact MLP/DNN model with architecture:
5
+ 3 inputs → [64, 128, 128, 64] hidden layers → 4 outputs
6
+
7
+ Target: < 1e-7 accuracy compared to iterative algorithm
8
+ """
9
+
10
+ import logging
11
+
12
+ import click
13
+ import mlflow
14
+ import mlflow.pytorch
15
+ import numpy as np
16
+ import torch
17
+ from torch import optim
18
+ from torch.utils.data import DataLoader, TensorDataset
19
+
20
+ from learning_munsell import PROJECT_ROOT
21
+ from learning_munsell.models.networks import MLPToMunsell
22
+ from learning_munsell.utilities.common import (
23
+ log_training_epoch,
24
+ setup_mlflow_experiment,
25
+ )
26
+ from learning_munsell.utilities.data import (
27
+ MUNSELL_NORMALIZATION_PARAMS,
28
+ XYY_NORMALIZATION_PARAMS,
29
+ normalize_munsell,
30
+ )
31
+ from learning_munsell.utilities.losses import weighted_mse_loss
32
+ from learning_munsell.utilities.training import train_epoch, validate
33
+
34
+ LOGGER = logging.getLogger(__name__)
35
+
36
+
37
@click.command()
@click.option("--epochs", default=300, help="Maximum training epochs.")
@click.option("--batch-size", default=1024, help="Training batch size.")
@click.option("--lr", default=5e-4, help="Learning rate.")
@click.option("--patience", default=20, help="Early stopping patience.")
def main(epochs: int, batch_size: int, lr: float, patience: int) -> None:
    """
    Train the MLPToMunsell model for xyY to Munsell conversion.

    Parameters
    ----------
    epochs : int
        Maximum number of training epochs.
    batch_size : int
        Training batch size.
    lr : float
        Learning rate for Adam optimizer.
    patience : int
        Early stopping patience (epochs without improvement).

    Notes
    -----
    The training pipeline:
    1. Loads training data from cache
    2. Normalizes Munsell outputs to [0, 1] range
    3. Trains compact MLP model (3 → [64, 128, 128, 64] → 4)
    4. Uses weighted MSE loss function
    5. Early stopping based on validation loss
    6. Exports model to ONNX format
    7. Logs metrics and artifacts to MLflow
    """
    LOGGER.info("=" * 80)
    LOGGER.info("ML-Based xyY to Munsell Conversion: Model Training")
    LOGGER.info("=" * 80)

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    LOGGER.info("Using device: %s", device)

    # Load training data
    data_dir = PROJECT_ROOT / "data"
    cache_file = data_dir / "training_data.npz"

    if not cache_file.exists():
        LOGGER.error("Error: Training data not found at %s", cache_file)
        LOGGER.error("Please run 01_generate_training_data.py first")
        return

    LOGGER.info("Loading training data from %s...", cache_file)
    data = np.load(cache_file)

    X_train = data["X_train"]
    y_train = data["y_train"]
    X_val = data["X_val"]
    y_val = data["y_val"]

    # Note: Invalid samples (outside Munsell gamut) are also stored in the cache
    # Available as: data['xyY_all'], data['munsell_all'], data['valid_mask']
    # These can be used for future enhancements like:
    # - Adversarial training to avoid extrapolation
    # - Gamut-aware loss functions
    # - Uncertainty estimation at boundaries

    LOGGER.info("Train samples: %d", len(X_train))
    LOGGER.info("Validation samples: %d", len(X_val))

    # Normalize outputs (xyY inputs are already in [0, 1] range)
    # Use hardcoded ranges covering the full Munsell space for generalization
    output_parameters = MUNSELL_NORMALIZATION_PARAMS
    y_train_norm = normalize_munsell(y_train, output_parameters)
    y_val_norm = normalize_munsell(y_val, output_parameters)

    # Convert to PyTorch tensors
    X_train_t = torch.FloatTensor(X_train)
    y_train_t = torch.FloatTensor(y_train_norm)
    X_val_t = torch.FloatTensor(X_val)
    y_val_t = torch.FloatTensor(y_val_norm)

    # Create data loaders
    train_dataset = TensorDataset(X_train_t, y_train_t)
    val_dataset = TensorDataset(X_val_t, y_val_t)

    # Larger batch size for larger dataset (500K samples)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Initialize model
    model = MLPToMunsell().to(device)
    LOGGER.info("")
    LOGGER.info("Model architecture:")
    LOGGER.info("%s", model)

    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    LOGGER.info("Total parameters: %s", f"{total_params:,}")

    # Training setup - lower learning rate for larger model
    optimizer = optim.Adam(model.parameters(), lr=lr)
    # Weighted MSE with default weights; the weights tensor must live on the
    # same device as the model outputs, otherwise the loss computation fails
    # on CUDA.
    weights = torch.tensor([1.0, 1.0, 2.0, 0.5], device=device)

    def criterion(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        return weighted_mse_loss(pred, target, weights)

    # MLflow setup
    run_name = setup_mlflow_experiment("from_xyY", "mlp")

    LOGGER.info("")
    LOGGER.info("MLflow run: %s", run_name)

    # Create the output directory up front so checkpointing cannot fail
    # mid-training; parents=True handles a missing "models" parent directory.
    model_directory = PROJECT_ROOT / "models" / "from_xyY"
    model_directory.mkdir(parents=True, exist_ok=True)
    checkpoint_file = model_directory / "mlp_best.pth"

    # Training loop state
    best_val_loss = float("inf")
    patience_counter = 0

    LOGGER.info("")
    LOGGER.info("Starting training...")

    with mlflow.start_run(run_name=run_name):
        # Log hyperparameters
        mlflow.log_params(
            {
                "epochs": epochs,
                "batch_size": batch_size,
                "learning_rate": lr,
                "optimizer": "Adam",
                "criterion": "weighted_mse_loss",
                "patience": patience,
                "total_params": total_params,
            }
        )

        for epoch in range(epochs):
            train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
            val_loss = validate(model, val_loader, criterion, device)

            # Log to MLflow
            log_training_epoch(
                epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"]
            )

            LOGGER.info(
                "Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f",
                epoch + 1,
                epochs,
                train_loss,
                val_loss,
            )

            # Early stopping: checkpoint on every improvement, bail out after
            # `patience` consecutive epochs without one.
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0

                # Save best model
                torch.save(
                    {
                        "model_state_dict": model.state_dict(),
                        "output_parameters": output_parameters,
                        "epoch": epoch,
                        "val_loss": val_loss,
                    },
                    checkpoint_file,
                )

                LOGGER.info("  → Saved best model (val_loss: %.6f)", val_loss)
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    LOGGER.info("")
                    LOGGER.info("Early stopping after %d epochs", epoch + 1)
                    break

        # Log final metrics
        mlflow.log_metrics(
            {
                "best_val_loss": best_val_loss,
                "final_epoch": epoch + 1,
            }
        )

        # Export to ONNX
        LOGGER.info("")
        LOGGER.info("Exporting model to ONNX...")
        model.eval()

        # Load best model
        checkpoint = torch.load(checkpoint_file, map_location=device)
        model.load_state_dict(checkpoint["model_state_dict"])

        # Create dummy input
        dummy_input = torch.randn(1, 3).to(device)

        # Export
        onnx_file = model_directory / "mlp.onnx"
        torch.onnx.export(
            model,
            dummy_input,
            onnx_file,
            export_params=True,
            opset_version=15,
            input_names=["xyY"],
            output_names=["munsell_spec"],
            dynamic_axes={"xyY": {0: "batch_size"}, "munsell_spec": {0: "batch_size"}},
        )

        # Save normalization parameters alongside model
        params_file = model_directory / "mlp_normalization_parameters.npz"
        input_parameters = XYY_NORMALIZATION_PARAMS
        np.savez(
            params_file,
            input_parameters=input_parameters,
            output_parameters=output_parameters,
        )

        LOGGER.info("ONNX model saved to: %s", onnx_file)
        LOGGER.info("Normalization parameters saved to: %s", params_file)

        # Log artifacts
        mlflow.log_artifact(str(checkpoint_file))
        mlflow.log_artifact(str(onnx_file))
        mlflow.log_artifact(str(params_file))

        # Log model
        mlflow.pytorch.log_model(model, "model")

    LOGGER.info("=" * 80)


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO, format="%(message)s", force=True)

    main()
learning_munsell/training/from_xyY/train_mlp_attention.py ADDED
@@ -0,0 +1,462 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Train MLP + Self-Attention model for xyY to Munsell conversion.
3
+
4
+ Option 1: MLP backbone with multi-head self-attention layers
5
+ - Input: 3 features (xyY)
6
+ - Architecture: 3 -> 512 -> 1024 + [Attention + ResBlock] x 4 -> 512 -> 4
7
+ - Output: 4 features (hue, value, chroma, code)
8
+ """
9
+
10
+ import logging
11
+
12
+ import click
13
+ import mlflow
14
+ import mlflow.pytorch
15
+ import numpy as np
16
+ import torch
17
+ from torch import nn, optim
18
+ from torch.utils.data import DataLoader, TensorDataset
19
+
20
+ from learning_munsell import PROJECT_ROOT
21
+ from learning_munsell.models.networks import ResidualBlock
22
+ from learning_munsell.utilities.common import (
23
+ log_training_epoch,
24
+ setup_mlflow_experiment,
25
+ )
26
+ from learning_munsell.utilities.data import (
27
+ MUNSELL_NORMALIZATION_PARAMS,
28
+ XYY_NORMALIZATION_PARAMS,
29
+ normalize_munsell,
30
+ )
31
+ from learning_munsell.utilities.losses import precision_focused_loss
32
+ from learning_munsell.utilities.training import train_epoch, validate
33
+
34
+ LOGGER = logging.getLogger(__name__)
35
+
36
+
37
class MultiHeadSelfAttention(nn.Module):
    """
    Multi-head self-attention over a flat feature vector.

    Projects the input into per-head query/key/value slices and mixes them
    with scaled dot-product attention.  NOTE(review): with a 2-D input of
    shape ``(batch_size, dim)`` the attention weights are computed across the
    heads themselves rather than across a token sequence — presumably
    intentional feature-interaction attention; confirm with the author.

    Parameters
    ----------
    dim
        Input and output feature dimension.
    num_heads
        Number of attention heads; must divide ``dim`` evenly.

    Attributes
    ----------
    query
        Linear projection producing query vectors.
    key
        Linear projection producing key vectors.
    value
        Linear projection producing value vectors.
    out
        Final linear projection applied after attention.
    scale
        Dot-product scaling factor, ``1 / sqrt(head_dim)``.
    """

    def __init__(self, dim: int, num_heads: int = 4) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.dim = dim
        self.head_dim = dim // num_heads

        assert dim % num_heads == 0, "dim must be divisible by num_heads"  # noqa: S101

        # Keep projection creation order (query, key, value, out) stable so
        # seeded initialisation remains reproducible.
        self.query = nn.Linear(dim, dim)
        self.key = nn.Linear(dim, dim)
        self.value = nn.Linear(dim, dim)
        self.out = nn.Linear(dim, dim)

        self.scale = self.head_dim**-0.5

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply multi-head self-attention to ``x``.

        Parameters
        ----------
        x
            Input tensor of shape ``(batch_size, dim)``.

        Returns
        -------
        torch.Tensor
            Attended tensor of shape ``(batch_size, dim)``.
        """
        count = x.size(0)
        head_shape = (count, self.num_heads, self.head_dim)

        # Project the input and split each projection into per-head slices.
        q = self.query(x).view(head_shape)
        k = self.key(x).view(head_shape)
        v = self.value(x).view(head_shape)

        # Scaled dot-product attention; softmax normalises along the last axis.
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        weights = torch.softmax(scores, dim=-1)

        # Mix values, merge the heads back into a flat vector, and project.
        merged = torch.matmul(weights, v).view(count, self.dim)
        return self.out(merged)
112
+
113
+
114
class AttentionResBlock(nn.Module):
    """
    Composite attention + residual block.

    Runs multi-head self-attention with an additive skip connection, then a
    residual MLP block, applying batch normalization after each stage.

    Parameters
    ----------
    dim
        Input and output feature dimension.
    num_heads
        Number of attention heads for the self-attention layer.

    Attributes
    ----------
    attention
        Multi-head self-attention layer.
    norm1
        Batch normalization applied after the attention skip connection.
    residual
        Residual MLP block.
    norm2
        Batch normalization applied after the residual block.
    """

    def __init__(self, dim: int, num_heads: int = 4) -> None:
        super().__init__()
        self.attention = MultiHeadSelfAttention(dim, num_heads)
        self.norm1 = nn.BatchNorm1d(dim)
        self.residual = ResidualBlock(dim)
        self.norm2 = nn.BatchNorm1d(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the attention and residual transformations.

        Parameters
        ----------
        x
            Input tensor of shape ``(batch_size, dim)``.

        Returns
        -------
        torch.Tensor
            Output tensor of shape ``(batch_size, dim)``.
        """
        # Attention with an additive skip connection, then normalize.
        attended = x + self.attention(x)
        attended = self.norm1(attended)

        # Residual MLP stage (its skip connection lives inside ResidualBlock).
        return self.norm2(self.residual(attended))
165
+
166
+
167
class MLPAttention(nn.Module):
    """
    MLP with self-attention for xyY to Munsell conversion.

    Architecture:
    - Input: 3 features (xyY normalized to [0, 1])
    - Encoder: 3 -> 512 -> 1024
    - Attention-ResBlocks at 1024-dim (configurable count)
    - Decoder: 1024 -> 512 -> 4
    - Output: 4 features (hue, value, chroma, code normalized)

    Parameters
    ----------
    num_blocks
        Number of attention-residual blocks in the middle.
    num_heads
        Number of attention heads in each attention layer.

    Attributes
    ----------
    encoder
        MLP that lifts the 3-D xyY input to the 1024-D working width.
    blocks
        ``nn.ModuleList`` of :class:`AttentionResBlock` modules.
    decoder
        MLP that projects the 1024-D features down to the 4-D Munsell output.
    """

    def __init__(self, num_blocks: int = 4, num_heads: int = 4) -> None:
        super().__init__()

        # Lift the 3-D xyY input to the 1024-D working width.
        self.encoder = nn.Sequential(
            nn.Linear(3, 512),
            nn.GELU(),
            nn.BatchNorm1d(512),
            nn.Linear(512, 1024),
            nn.GELU(),
            nn.BatchNorm1d(1024),
        )

        # Stack of attention + residual blocks at the working width.
        self.blocks = nn.ModuleList(
            [AttentionResBlock(1024, num_heads) for _ in range(num_blocks)]
        )

        # Project back down to the 4-D Munsell specification.
        self.decoder = nn.Sequential(
            nn.Linear(1024, 512),
            nn.GELU(),
            nn.BatchNorm1d(512),
            nn.Linear(512, 4),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Predict a Munsell specification from an xyY input.

        Parameters
        ----------
        x
            Input tensor of shape ``(batch_size, 3)`` containing normalized
            xyY values.

        Returns
        -------
        torch.Tensor
            Output tensor of shape ``(batch_size, 4)`` containing the
            normalized Munsell specification [hue, value, chroma, code].
        """
        features = self.encoder(x)

        for attn_block in self.blocks:
            features = attn_block(features)

        return self.decoder(features)
246
+
247
+
248
@click.command()
@click.option("--epochs", default=300, help="Number of training epochs")
@click.option("--batch-size", default=1024, help="Batch size for training")
@click.option("--lr", default=3e-4, help="Learning rate")
@click.option("--patience", default=20, help="Early stopping patience")
def main(epochs: int, batch_size: int, lr: float, patience: int) -> None:
    """
    Train MLP + Self-Attention model for xyY to Munsell conversion.

    Notes
    -----
    The training pipeline:
    1. Loads cached training data from disk
    2. Normalizes outputs (Munsell specification) to [0, 1]
    3. Creates MLPAttention model (4 blocks, 4 attention heads)
    4. Trains with precision-focused loss (MSE + MAE + log + Huber)
    5. Uses AdamW optimizer with ReduceLROnPlateau scheduler
    6. Applies early stopping based on validation loss
    7. Exports best model to ONNX format
    8. Logs metrics and artifacts to MLflow
    """
    LOGGER.info("=" * 80)
    LOGGER.info("MLP + Self-Attention: xyY → Munsell")
    LOGGER.info("=" * 80)

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    LOGGER.info("Using device: %s", device)

    # Paths
    model_directory = PROJECT_ROOT / "models" / "from_xyY"
    data_dir = PROJECT_ROOT / "data"
    cache_file = data_dir / "training_data.npz"

    # Load training data
    LOGGER.info("")
    LOGGER.info("Loading training data from %s...", cache_file)
    data = np.load(cache_file)
    X_train = data["X_train"]
    y_train = data["y_train"]
    X_val = data["X_val"]
    y_val = data["y_val"]

    LOGGER.info("Train samples: %d", len(X_train))
    LOGGER.info("Validation samples: %d", len(X_val))

    # NOTE(review): only the targets are normalized here; the xyY inputs are
    # fed to the network as loaded.  This assumes the cached ``X_train`` is
    # already normalized (XYY_NORMALIZATION_PARAMS is still exported for
    # inference below) — confirm against the data-generation script.
    output_parameters = MUNSELL_NORMALIZATION_PARAMS
    y_train_norm = normalize_munsell(y_train, output_parameters)
    y_val_norm = normalize_munsell(y_val, output_parameters)

    # Convert to PyTorch tensors
    X_train_t = torch.FloatTensor(X_train)
    y_train_t = torch.FloatTensor(y_train_norm)
    X_val_t = torch.FloatTensor(X_val)
    y_val_t = torch.FloatTensor(y_val_norm)

    # Create data loaders
    train_dataset = TensorDataset(X_train_t, y_train_t)
    val_dataset = TensorDataset(X_val_t, y_val_t)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Initialize model
    model = MLPAttention(num_blocks=4, num_heads=4).to(device)
    LOGGER.info("")
    LOGGER.info("MLP + Attention architecture:")
    LOGGER.info("%s", model)

    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    LOGGER.info("Total parameters: %s", f"{total_params:,}")

    # Training setup
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=5
    )
    criterion = precision_focused_loss

    # MLflow setup
    run_name = setup_mlflow_experiment("from_xyY", "mlp_attention")

    LOGGER.info("")
    LOGGER.info("MLflow run: %s", run_name)

    # Training loop
    best_val_loss = float("inf")
    patience_counter = 0

    LOGGER.info("")
    LOGGER.info("Starting training...")

    with mlflow.start_run(run_name=run_name):
        # Log hyperparameters
        mlflow.log_params(
            {
                "num_epochs": epochs,
                "batch_size": batch_size,
                "learning_rate": lr,
                "weight_decay": 1e-5,
                "optimizer": "AdamW",
                "scheduler": "ReduceLROnPlateau",
                "criterion": "precision_focused_loss",
                "patience": patience,
                "total_params": total_params,
                "num_blocks": 4,
                "num_heads": 4,
            }
        )

        for epoch in range(epochs):
            train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
            val_loss = validate(model, val_loader, criterion, device)

            scheduler.step(val_loss)

            log_training_epoch(
                epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"]
            )

            LOGGER.info(
                "Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f, LR: %.6f",
                epoch + 1,
                epochs,
                train_loss,
                val_loss,
                optimizer.param_groups[0]["lr"],
            )

            # Early stopping: checkpoint on improvement, otherwise count down.
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0

                # FIX: ``parents=True`` so the run also works when the parent
                # ``models`` directory does not exist yet.
                model_directory.mkdir(parents=True, exist_ok=True)
                checkpoint_file = model_directory / "mlp_attention_best.pth"

                torch.save(
                    {
                        "model_state_dict": model.state_dict(),
                        "epoch": epoch,
                        "val_loss": val_loss,
                    },
                    checkpoint_file,
                )

                LOGGER.info(" → Saved best model (val_loss: %.6f)", val_loss)
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    LOGGER.info("")
                    LOGGER.info("Early stopping after %d epochs", epoch + 1)
                    break

        # Log final metrics
        mlflow.log_metrics(
            {
                "best_val_loss": best_val_loss,
                "final_epoch": epoch + 1,
            }
        )

        # Export to ONNX
        LOGGER.info("")
        LOGGER.info("Exporting to ONNX...")
        model.eval()

        # Reload the best checkpoint before export; ``map_location`` keeps the
        # load robust regardless of the device the weights were saved on.
        checkpoint = torch.load(checkpoint_file, map_location=device)
        model.load_state_dict(checkpoint["model_state_dict"])

        dummy_input = torch.randn(1, 3).to(device)

        onnx_file = model_directory / "mlp_attention.onnx"
        torch.onnx.export(
            model,
            dummy_input,
            onnx_file,
            export_params=True,
            opset_version=15,
            input_names=["xyY"],
            output_names=["munsell_spec"],
            dynamic_axes={
                "xyY": {0: "batch_size"},
                "munsell_spec": {0: "batch_size"},
            },
        )

        # Save normalization parameters alongside model
        params_file = model_directory / "mlp_attention_normalization_parameters.npz"
        input_parameters = XYY_NORMALIZATION_PARAMS
        np.savez(
            params_file,
            input_parameters=input_parameters,
            output_parameters=output_parameters,
        )

        LOGGER.info("ONNX model saved to: %s", onnx_file)
        LOGGER.info("Normalization parameters saved to: %s", params_file)

        # Log artifacts
        mlflow.log_artifact(str(checkpoint_file))
        mlflow.log_artifact(str(onnx_file))
        mlflow.log_artifact(str(params_file))

        # Log model
        mlflow.pytorch.log_model(model, "model")

        LOGGER.info("=" * 80)
457
+
458
+
459
if __name__ == "__main__":
    # Configure plain-message logging for CLI use; ``force=True`` replaces any
    # handlers that an imported library may already have installed.
    logging.basicConfig(level=logging.INFO, format="%(message)s", force=True)

    main()
learning_munsell/training/from_xyY/train_mlp_error_predictor.py ADDED
@@ -0,0 +1,461 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Train error predictor with advanced MLP architecture.
3
+
4
+ Architecture features:
5
+ - Larger capacity: 7 → 256 → 512 → 512 → 256 → 4
6
+ - Residual connections (MLP-style) for better gradient flow
7
+ - Modern activation functions (GELU instead of ReLU)
8
+ - Precision-focused loss function
9
+
10
+ Generic error predictor that can work with any base model.
11
+ """
12
+
13
+ import logging
14
+ from pathlib import Path
15
+
16
+ import click
17
+ import mlflow
18
+ import mlflow.pytorch
19
+ import numpy as np
20
+ import onnxruntime as ort
21
+ import torch
22
+ from torch import nn, optim
23
+ from torch.utils.data import DataLoader, TensorDataset
24
+
25
+ from learning_munsell import PROJECT_ROOT
26
+ from learning_munsell.models.networks import ResidualBlock
27
+ from learning_munsell.utilities.common import (
28
+ log_training_epoch,
29
+ setup_mlflow_experiment,
30
+ )
31
+ from learning_munsell.utilities.data import normalize_munsell, normalize_xyY
32
+ from learning_munsell.utilities.losses import precision_focused_loss
33
+ from learning_munsell.utilities.training import train_epoch, validate
34
+
35
+ LOGGER = logging.getLogger(__name__)
36
+
37
+ # Note: This script has a custom ErrorPredictorMLP architecture
38
+ # so we don't import ComponentErrorPredictor/MultiHeadErrorPredictor
39
+ # from shared modules.
40
+
41
+
42
class ErrorPredictorMLP(nn.Module):
    """
    Residual-correction network for a two-stage Munsell predictor.

    Stage one (a separately trained base model) maps xyY coordinates to a
    Munsell specification; this network learns the residual that should be
    added to the base prediction.  It consumes both the normalized xyY input
    and the base model's normalized prediction.

    Architecture:
    - Input: 7 features (xyY_norm + base_pred_norm)
    - Encoder: 7 → 256 → 512
    - Residual blocks at 512-dim
    - Decoder: 512 → 256 → 128 → 4
    - GELU activations with batch normalization throughout

    Parameters
    ----------
    num_residual_blocks : int, optional
        Number of residual blocks in the middle of the network.
        Default is 3.

    Attributes
    ----------
    encoder : nn.Sequential
        Maps the 7-D input to a 512-D representation.
    residual_blocks : nn.ModuleList
        Residual blocks for deep feature extraction.
    decoder : nn.Sequential
        Maps the 512-D representation to the 4-D error prediction.
    """

    def __init__(self, num_residual_blocks: int = 3) -> None:
        super().__init__()

        # Encoder: lift the 7-D input to the 512-D working width.
        self.encoder = nn.Sequential(
            nn.Linear(7, 256),
            nn.GELU(),
            nn.BatchNorm1d(256),
            nn.Linear(256, 512),
            nn.GELU(),
            nn.BatchNorm1d(512),
        )

        # Residual blocks at the working width.
        self.residual_blocks = nn.ModuleList(
            [ResidualBlock(512) for _ in range(num_residual_blocks)]
        )

        # Decoder: funnel down to the 4-D error correction.
        self.decoder = nn.Sequential(
            nn.Linear(512, 256),
            nn.GELU(),
            nn.BatchNorm1d(256),
            nn.Linear(256, 128),
            nn.GELU(),
            nn.BatchNorm1d(128),
            nn.Linear(128, 4),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Predict the error correction for a base-model prediction.

        Parameters
        ----------
        x : Tensor
            Combined input [xyY_norm, base_pred_norm] of shape (batch_size, 7).

        Returns
        -------
        Tensor
            Predicted error correction of shape (batch_size, 4).
        """
        features = self.encoder(x)

        for res_block in self.residual_blocks:
            features = res_block(features)

        return self.decoder(features)
130
+
131
+
132
def load_base_model(
    model_path: Path, params_path: Path
) -> tuple[ort.InferenceSession, dict, dict]:
    """
    Load the base ONNX model and its normalization parameters.

    The base model is the first stage of the two-stage architecture: it maps
    xyY coordinates to an initial Munsell specification that the error
    predictor then corrects.

    Parameters
    ----------
    model_path : Path
        Path to the ONNX model file.
    params_path : Path
        Path to the .npz file holding the input and output normalization
        parameters.

    Returns
    -------
    session : ort.InferenceSession
        ONNX Runtime inference session for the base model.
    input_parameters : dict
        Input normalization ranges (x_range, y_range, Y_range).
    output_parameters : dict
        Output normalization ranges (hue_range, value_range, chroma_range,
        code_range).
    """
    session = ort.InferenceSession(str(model_path))

    # The parameter dicts were pickled into 0-d object arrays by ``np.savez``;
    # ``.item()`` unwraps them back into plain dictionaries.
    archive = np.load(params_path, allow_pickle=True)
    input_parameters = archive["input_parameters"].item()
    output_parameters = archive["output_parameters"].item()

    return session, input_parameters, output_parameters
165
+
166
+
167
@click.command()
@click.option(
    "--base-model",
    type=click.Path(exists=True, path_type=Path),
    help="Path to base model ONNX file",
)
@click.option(
    "--params",
    type=click.Path(exists=True, path_type=Path),
    help="Path to normalization params file",
)
@click.option(
    "--epochs",
    type=int,
    default=300,
    help="Number of training epochs",
)
@click.option(
    "--batch-size",
    type=int,
    default=1024,
    help="Batch size for training",
)
@click.option(
    "--lr",
    type=float,
    default=3e-4,
    help="Learning rate",
)
@click.option(
    "--patience",
    type=int,
    default=20,
    help="Patience for early stopping",
)
def main(
    base_model: Path | None,
    params: Path | None,
    epochs: int,
    batch_size: int,
    lr: float,
    patience: int,
) -> None:
    """
    Train error predictor with advanced MLP architecture.

    Parameters
    ----------
    base_model : Path or None
        Path to the base model ONNX file. If None, uses default path.
    params : Path or None
        Path to normalization parameters .npz file. If None, uses default path.

    Notes
    -----
    The training pipeline:
    1. Loads pre-trained base model
    2. Generates base model predictions for training data
    3. Computes residual errors between predictions and targets
    4. Trains error predictor on these residuals
    5. Uses precision-focused loss function
    6. Learning rate scheduling with ReduceLROnPlateau
    7. Early stopping based on validation loss
    8. Exports model to ONNX format
    9. Logs metrics and artifacts to MLflow
    """
    LOGGER.info("=" * 80)
    LOGGER.info("Error Predictor: MLP + GELU + Precision Loss")
    LOGGER.info("=" * 80)

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    LOGGER.info("Using device: %s", device)

    # Paths
    model_directory = PROJECT_ROOT / "models" / "from_xyY"
    data_dir = PROJECT_ROOT / "data"
    cache_file = data_dir / "training_data.npz"

    # FIX: the docstring advertises default paths, but the original code
    # passed ``None`` straight to ``load_base_model`` and crashed.  Derive
    # defaults from the conventional export names used by the training
    # scripts ("<stem>.onnx" / "<stem>_normalization_parameters.npz").
    base_model_path = (
        base_model
        if base_model is not None
        else model_directory / "xyY_to_munsell_specification.onnx"
    )
    params_path = (
        params
        if params is not None
        else base_model_path.with_name(
            f"{base_model_path.stem}_normalization_parameters.npz"
        )
    )

    # Base model name drives the error-predictor artifact names.
    base_model_name = base_model_path.stem

    # Load base model
    LOGGER.info("")
    LOGGER.info("Loading base model from %s...", base_model_path)
    base_session, input_parameters, output_parameters = load_base_model(
        base_model_path, params_path
    )

    # Load training data
    LOGGER.info("Loading training data from %s...", cache_file)
    data = np.load(cache_file)
    X_train = data["X_train"]
    y_train = data["y_train"]
    X_val = data["X_val"]
    y_val = data["y_val"]

    LOGGER.info("Train samples: %d", len(X_train))
    LOGGER.info("Validation samples: %d", len(X_val))

    # Generate base model predictions
    LOGGER.info("")
    LOGGER.info("Generating base model predictions...")
    # ONNX Runtime float tensors must be float32; cast defensively in case the
    # normalization helpers return float64.
    X_train_norm = normalize_xyY(X_train, input_parameters).astype(np.float32)
    y_train_norm = normalize_munsell(y_train, output_parameters)

    # Base predictions (normalized)
    base_pred_train_norm = base_session.run(None, {"xyY": X_train_norm})[0]

    X_val_norm = normalize_xyY(X_val, input_parameters).astype(np.float32)
    y_val_norm = normalize_munsell(y_val, output_parameters)
    base_pred_val_norm = base_session.run(None, {"xyY": X_val_norm})[0]

    # Residual targets: what must be ADDED to the base prediction
    # (in normalized space).
    error_train = y_train_norm - base_pred_train_norm
    error_val = y_val_norm - base_pred_val_norm

    # Statistics
    LOGGER.info("")
    LOGGER.info("Base model error statistics (normalized space):")
    LOGGER.info("  Mean absolute error: %.6f", np.mean(np.abs(error_train)))
    LOGGER.info("  Std of error: %.6f", np.std(error_train))
    LOGGER.info("  Max absolute error: %.6f", np.max(np.abs(error_train)))

    # Create combined input: [xyY_norm, base_prediction_norm]
    X_train_combined = np.concatenate([X_train_norm, base_pred_train_norm], axis=1)
    X_val_combined = np.concatenate([X_val_norm, base_pred_val_norm], axis=1)

    # Convert to PyTorch tensors
    X_train_t = torch.FloatTensor(X_train_combined)
    error_train_t = torch.FloatTensor(error_train)
    X_val_t = torch.FloatTensor(X_val_combined)
    error_val_t = torch.FloatTensor(error_val)

    # Create data loaders
    train_dataset = TensorDataset(X_train_t, error_train_t)
    val_dataset = TensorDataset(X_val_t, error_val_t)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Initialize error predictor model with MLP architecture
    model = ErrorPredictorMLP(num_residual_blocks=3).to(device)
    LOGGER.info("")
    LOGGER.info("Error predictor architecture:")
    LOGGER.info("%s", model)

    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    LOGGER.info("Total parameters: %s", f"{total_params:,}")

    # Training setup with precision-focused loss
    LOGGER.info("")
    LOGGER.info("Using precision-focused loss function:")
    LOGGER.info("  - MSE (weight: 1.0)")
    LOGGER.info("  - MAE (weight: 0.5)")
    LOGGER.info("  - Log penalty for small errors (weight: 0.3)")
    LOGGER.info("  - Huber loss with delta=0.01 (weight: 0.5)")

    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=5
    )
    criterion = precision_focused_loss

    # MLflow setup
    model_name = f"{base_model_name}_error_predictor"
    run_name = setup_mlflow_experiment("from_xyY", model_name)

    LOGGER.info("")
    LOGGER.info("MLflow run: %s", run_name)

    # Training loop
    best_val_loss = float("inf")
    patience_counter = 0

    LOGGER.info("")
    LOGGER.info("Starting training...")

    with mlflow.start_run(run_name=run_name):
        mlflow.log_params(
            {
                "model": model_name,
                "base_model": base_model_name,
                "learning_rate": lr,
                "batch_size": batch_size,
                "num_epochs": epochs,
                "patience": patience,
                "total_params": total_params,
            }
        )

        for epoch in range(epochs):
            train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
            val_loss = validate(model, val_loader, criterion, device)

            # Update learning rate
            scheduler.step(val_loss)

            log_training_epoch(
                epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"]
            )

            LOGGER.info(
                "Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f, LR: %.6f",
                epoch + 1,
                epochs,
                train_loss,
                val_loss,
                optimizer.param_groups[0]["lr"],
            )

            # Early stopping: checkpoint on improvement, otherwise count down.
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0

                # FIX: ``parents=True`` so the run also works when the parent
                # ``models`` directory does not exist yet.
                model_directory.mkdir(parents=True, exist_ok=True)
                checkpoint_file = (
                    model_directory / f"{base_model_name}_error_predictor_best.pth"
                )

                torch.save(
                    {
                        "model_state_dict": model.state_dict(),
                        "epoch": epoch,
                        "val_loss": val_loss,
                    },
                    checkpoint_file,
                )

                LOGGER.info(" → Saved best model (val_loss: %.6f)", val_loss)
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    LOGGER.info("")
                    LOGGER.info("Early stopping after %d epochs", epoch + 1)
                    break

        mlflow.log_metrics(
            {
                "best_val_loss": best_val_loss,
                "final_epoch": epoch + 1,
            }
        )

        # Export to ONNX
        LOGGER.info("")
        LOGGER.info("Exporting error predictor to ONNX...")
        model.eval()

        # Reload the best checkpoint before export; ``map_location`` keeps the
        # load robust regardless of the device the weights were saved on.
        checkpoint = torch.load(checkpoint_file, map_location=device)
        model.load_state_dict(checkpoint["model_state_dict"])

        # Create dummy input (xyY_norm + base_pred_norm = 7 inputs)
        dummy_input = torch.randn(1, 7).to(device)

        # Export
        onnx_file = model_directory / f"{base_model_name}_error_predictor.onnx"
        torch.onnx.export(
            model,
            dummy_input,
            onnx_file,
            export_params=True,
            opset_version=15,
            input_names=["combined_input"],
            output_names=["error_correction"],
            dynamic_axes={
                "combined_input": {0: "batch_size"},
                "error_correction": {0: "batch_size"},
            },
        )

        mlflow.log_artifact(str(checkpoint_file))
        mlflow.log_artifact(str(onnx_file))
        mlflow.pytorch.log_model(model, "model")

        LOGGER.info("Error predictor ONNX model saved to: %s", onnx_file)
        LOGGER.info("Artifacts logged to MLflow")

        LOGGER.info("=" * 80)
456
+
457
+
458
if __name__ == "__main__":
    # Configure plain-message logging for CLI use; ``force=True`` replaces any
    # handlers that an imported library may already have installed.
    logging.basicConfig(level=logging.INFO, format="%(message)s", force=True)

    main()
learning_munsell/training/from_xyY/train_mlp_gamma.py ADDED
@@ -0,0 +1,309 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Train ML model for xyY to Munsell conversion with gamma-corrected Y.
3
+
4
+ Experiment: Apply gamma 2.33 to Y before normalization to better align
5
+ with perceptual lightness (Munsell Value scale is perceptually uniform).
6
+ """
7
+
8
+ import logging
9
+ from typing import Any
10
+
11
+ import click
12
+ import mlflow
13
+ import mlflow.pytorch
14
+ import numpy as np
15
+ import torch
16
+ from numpy.typing import NDArray
17
+ from torch import optim
18
+ from torch.utils.data import DataLoader, TensorDataset
19
+
20
+ from learning_munsell import PROJECT_ROOT
21
+ from learning_munsell.models.networks import MLPToMunsell
22
+ from learning_munsell.utilities.common import (
23
+ log_training_epoch,
24
+ setup_mlflow_experiment,
25
+ )
26
+ from learning_munsell.utilities.data import (
27
+ MUNSELL_NORMALIZATION_PARAMS,
28
+ normalize_munsell,
29
+ )
30
+ from learning_munsell.utilities.losses import weighted_mse_loss
31
+ from learning_munsell.utilities.training import train_epoch, validate
32
+
33
+ LOGGER = logging.getLogger(__name__)
34
+
35
# Gamma value applied to the Y (luminance) component before training.
GAMMA = 2.33


def normalize_inputs(
    X: NDArray, gamma: float = GAMMA
) -> tuple[NDArray, dict[str, Any]]:
    """
    Normalize xyY inputs to [0, 1] range with gamma correction on Y.

    Parameters
    ----------
    X : ndarray
        xyY values of shape (n, 3) where columns are [x, y, Y].
    gamma : float
        Gamma value to apply to Y component.

    Returns
    -------
    ndarray
        Normalized values with gamma-corrected Y, dtype float32.
    dict
        Normalization parameters including gamma value.
    """
    # Typical ranges for xyY chromaticity coordinates and luminance.
    x_range = (0.0, 1.0)
    y_range = (0.0, 1.0)
    Y_range = (0.0, 1.0)

    # FIX: cast to float32 so the return value matches the documented dtype
    # (and what torch.FloatTensor expects downstream); a plain copy() would
    # silently preserve the caller's dtype (typically float64).
    X_norm = X.astype(np.float32, copy=True)
    X_norm[:, 0] = (X[:, 0] - x_range[0]) / (x_range[1] - x_range[0])
    X_norm[:, 1] = (X[:, 1] - y_range[0]) / (y_range[1] - y_range[0])

    # Normalize Y first, then apply gamma.
    Y_normalized = (X[:, 2] - Y_range[0]) / (Y_range[1] - Y_range[0])
    # Clip to avoid numerical issues with negative values
    Y_normalized = np.clip(Y_normalized, 0, 1)
    # Apply gamma: Y_gamma = Y^(1/gamma) - this spreads dark values, compresses light
    X_norm[:, 2] = np.power(Y_normalized, 1.0 / gamma)

    params = {
        "x_range": x_range,
        "y_range": y_range,
        "Y_range": Y_range,
        "gamma": gamma,
    }

    return X_norm, params
83
+
84
+
85
@click.command()
@click.option("--epochs", default=300, help="Number of training epochs")
@click.option("--batch-size", default=1024, help="Batch size for training")
@click.option("--lr", default=5e-4, help="Learning rate")
@click.option("--patience", default=20, help="Early stopping patience")
def main(epochs: int, batch_size: int, lr: float, patience: int) -> None:
    """
    Train MLP model with gamma-corrected Y input.

    Parameters
    ----------
    epochs : int
        Maximum number of training epochs.
    batch_size : int
        Batch size for the training and validation loaders.
    lr : float
        Learning rate for the Adam optimizer.
    patience : int
        Epochs without validation improvement before early stopping.

    Notes
    -----
    The training pipeline:
    1. Loads training and validation data from cache
    2. Normalizes inputs with gamma correction (gamma=2.33) on Y
    3. Normalizes Munsell outputs to [0, 1] range
    4. Trains MLP with weighted MSE loss
    5. Uses early stopping based on validation loss
    6. Exports best model to ONNX format
    7. Logs metrics and artifacts to MLflow

    The gamma correction on Y aligns with perceptual lightness. The gamma
    transformation spreads dark values and compresses light values, matching
    human lightness perception and the perceptually uniform Munsell Value scale.
    """

    LOGGER.info("=" * 80)
    LOGGER.info("ML-Based xyY to Munsell Conversion: Gamma Experiment")
    LOGGER.info("Gamma = %.2f applied to Y component", GAMMA)
    LOGGER.info("=" * 80)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    LOGGER.info("Using device: %s", device)

    # Load training data from the cached .npz produced by the data-generation
    # script.
    data_dir = PROJECT_ROOT / "data"
    cache_file = data_dir / "training_data.npz"

    if not cache_file.exists():
        LOGGER.error("Error: Training data not found at %s", cache_file)
        LOGGER.error("Please run 01_generate_training_data.py first")
        return

    LOGGER.info("Loading training data from %s...", cache_file)
    data = np.load(cache_file)

    X_train = data["X_train"]
    y_train = data["y_train"]
    X_val = data["X_val"]
    y_val = data["y_val"]

    LOGGER.info("Train samples: %d", len(X_train))
    LOGGER.info("Validation samples: %d", len(X_val))

    # Normalize inputs with gamma correction on Y; outputs with the shared
    # Munsell normalization parameters.
    X_train_norm, input_parameters = normalize_inputs(X_train, gamma=GAMMA)
    X_val_norm, _ = normalize_inputs(X_val, gamma=GAMMA)

    output_parameters = MUNSELL_NORMALIZATION_PARAMS
    y_train_norm = normalize_munsell(y_train, output_parameters)
    y_val_norm = normalize_munsell(y_val, output_parameters)

    LOGGER.info("")
    LOGGER.info("Input normalization with gamma=%.2f:", GAMMA)
    LOGGER.info(
        " Y range after gamma: [%.4f, %.4f]",
        X_train_norm[:, 2].min(),
        X_train_norm[:, 2].max(),
    )

    # Convert to PyTorch tensors
    X_train_t = torch.FloatTensor(X_train_norm)
    y_train_t = torch.FloatTensor(y_train_norm)
    X_val_t = torch.FloatTensor(X_val_norm)
    y_val_t = torch.FloatTensor(y_val_norm)

    # Create data loaders
    train_dataset = TensorDataset(X_train_t, y_train_t)
    val_dataset = TensorDataset(X_val_t, y_val_t)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Initialize model
    model = MLPToMunsell().to(device)
    LOGGER.info("")
    LOGGER.info("Model architecture:")
    LOGGER.info("%s", model)

    total_params = sum(p.numel() for p in model.parameters())
    LOGGER.info("Total parameters: %s", f"{total_params:,}")

    # Training setup
    optimizer = optim.Adam(model.parameters(), lr=lr)
    # Component weights: emphasize chroma (2.0), de-emphasize code (0.5)
    weights = torch.tensor([1.0, 1.0, 2.0, 0.5])

    def criterion(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # Weighted MSE over the four Munsell components [hue, value, chroma, code].
        return weighted_mse_loss(pred, target, weights)

    # MLflow setup
    run_name = setup_mlflow_experiment("from_xyY", f"mlp_gamma_{GAMMA}")

    LOGGER.info("")
    LOGGER.info("MLflow run: %s", run_name)

    # Training loop state.
    best_val_loss = float("inf")
    patience_counter = 0

    # FIX: create the model directory up-front (parents=True so a missing
    # top-level "models/" directory does not raise FileNotFoundError) and
    # derive the checkpoint path before the loop, so both names are defined
    # for the post-training export regardless of loop behavior.
    model_directory = PROJECT_ROOT / "models" / "from_xyY"
    model_directory.mkdir(parents=True, exist_ok=True)
    checkpoint_file = model_directory / "mlp_gamma_best.pth"

    LOGGER.info("")
    LOGGER.info("Starting training...")

    with mlflow.start_run(run_name=run_name):
        mlflow.log_params(
            {
                "num_epochs": epochs,
                "batch_size": batch_size,
                "learning_rate": lr,
                "optimizer": "Adam",
                "criterion": "weighted_mse_loss",
                "patience": patience,
                "total_params": total_params,
                "gamma": GAMMA,
            }
        )

        for epoch in range(epochs):
            train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
            val_loss = validate(model, val_loader, criterion, device)

            log_training_epoch(
                epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"]
            )

            LOGGER.info(
                "Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f",
                epoch + 1,
                epochs,
                train_loss,
                val_loss,
            )

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0

                # Checkpoint the best model together with the normalization
                # parameters required to reproduce inference.
                torch.save(
                    {
                        "model_state_dict": model.state_dict(),
                        "input_parameters": input_parameters,
                        "output_parameters": output_parameters,
                        "epoch": epoch,
                        "val_loss": val_loss,
                    },
                    checkpoint_file,
                )

                LOGGER.info(" → Saved best model (val_loss: %.6f)", val_loss)
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    LOGGER.info("")
                    LOGGER.info("Early stopping after %d epochs", epoch + 1)
                    break

        mlflow.log_metrics(
            {
                "best_val_loss": best_val_loss,
                "final_epoch": epoch + 1,
            }
        )

        # Export to ONNX
        LOGGER.info("")
        LOGGER.info("Exporting model to ONNX...")
        model.eval()

        # FIX: weights_only=False because the checkpoint stores plain dicts
        # (normalization parameters), which the PyTorch >= 2.6 default of
        # weights_only=True refuses to unpickle; matches the sibling
        # 3-stage training script.
        checkpoint = torch.load(checkpoint_file, weights_only=False)
        model.load_state_dict(checkpoint["model_state_dict"])

        # Dummy input: a single gamma-corrected xyY triplet.
        dummy_input = torch.randn(1, 3).to(device)

        onnx_file = model_directory / "mlp_gamma.onnx"
        torch.onnx.export(
            model,
            dummy_input,
            onnx_file,
            export_params=True,
            opset_version=15,
            input_names=["xyY_gamma"],
            output_names=["munsell_spec"],
            dynamic_axes={
                "xyY_gamma": {0: "batch_size"},
                "munsell_spec": {0: "batch_size"},
            },
        )

        # Save normalization parameters (including gamma)
        params_file = model_directory / "mlp_gamma_normalization_parameters.npz"
        np.savez(
            params_file,
            input_parameters=input_parameters,
            output_parameters=output_parameters,
        )

        LOGGER.info("ONNX model saved to: %s", onnx_file)
        LOGGER.info("Normalization parameters saved to: %s", params_file)
        LOGGER.info("IMPORTANT: Input Y must be gamma-corrected with gamma=%.2f", GAMMA)

        mlflow.log_artifact(str(checkpoint_file))
        mlflow.log_artifact(str(onnx_file))
        mlflow.log_artifact(str(params_file))
        mlflow.pytorch.log_model(model, "model")

    LOGGER.info("=" * 80)
304
+
305
+
306
# Script entry point: configure root logging (force=True replaces any
# handlers already installed by imported libraries) and run the Click command.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO, format="%(message)s", force=True)

    main()
learning_munsell/training/from_xyY/train_multi_head_3stage_error_predictor.py ADDED
@@ -0,0 +1,423 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Train second-stage error predictor for 3-stage model.
3
+
4
+ Architecture: Multi-Head + Multi-Error Predictor + Multi-Error Predictor
5
+ - Stage 1: Multi-Head base model (existing)
6
+ - Stage 2: First error predictor (existing)
7
+ - Stage 3: Second error predictor (this script) - learns residuals from stage 2
8
+
9
+ The second error predictor has the same architecture as the first but learns
10
+ the remaining errors after the first error correction is applied.
11
+ """
12
+
13
+ import logging
14
+ from pathlib import Path
15
+
16
+ import click
17
+ import mlflow
18
+ import mlflow.pytorch
19
+ import numpy as np
20
+ import onnxruntime as ort
21
+ import torch
22
+ from torch import optim
23
+ from torch.utils.data import DataLoader, TensorDataset
24
+
25
+ from learning_munsell import PROJECT_ROOT
26
+ from learning_munsell.models.networks import (
27
+ MultiHeadErrorPredictorToMunsell,
28
+ )
29
+ from learning_munsell.utilities.common import (
30
+ log_training_epoch,
31
+ setup_mlflow_experiment,
32
+ )
33
+ from learning_munsell.utilities.data import normalize_munsell, normalize_xyY
34
+ from learning_munsell.utilities.losses import precision_focused_loss
35
+ from learning_munsell.utilities.training import train_epoch, validate
36
+
37
+ LOGGER = logging.getLogger(__name__)
38
+
39
+
40
def _run_onnx_batched(
    session: "ort.InferenceSession",
    input_name: str,
    data: np.ndarray,
    batch_size: int = 50000,
) -> np.ndarray:
    """
    Run an ONNX session over ``data`` in batches and concatenate the outputs.

    ONNX Runtime requires float32 inputs, so each batch is cast before
    inference.  This consolidates the four near-identical batching loops
    used for the two pre-trained stages.
    """
    outputs = [
        session.run(
            None, {input_name: data[i : i + batch_size].astype(np.float32)}
        )[0]
        for i in range(0, len(data), batch_size)
    ]
    return np.concatenate(outputs, axis=0)


@click.command()
@click.option(
    "--base-model",
    type=click.Path(exists=True, path_type=Path),
    default=None,
    help="Path to Multi-Head base model ONNX file",
)
@click.option(
    "--first-error-predictor",
    type=click.Path(exists=True, path_type=Path),
    default=None,
    help="Path to first error predictor ONNX file",
)
@click.option(
    "--params",
    type=click.Path(exists=True, path_type=Path),
    default=None,
    help="Path to normalization params file",
)
@click.option(
    "--epochs",
    type=int,
    default=300,
    help="Number of training epochs (default: 300)",
)
@click.option(
    "--batch-size",
    type=int,
    default=2048,
    help="Batch size for training (default: 2048)",
)
@click.option(
    "--lr",
    type=float,
    default=3e-4,
    help="Learning rate (default: 3e-4)",
)
@click.option(
    "--patience",
    type=int,
    default=30,
    help="Early stopping patience (default: 30)",
)
def main(
    base_model: Path | None,
    first_error_predictor: Path | None,
    params: Path | None,
    epochs: int,
    batch_size: int,
    lr: float,
    patience: int,
) -> None:
    """
    Train the second-stage error predictor for the 3-stage model.

    This script trains the third stage of a 3-stage model:
    - Stage 1: Multi-Head base model (pre-trained)
    - Stage 2: First error predictor (pre-trained)
    - Stage 3: Second error predictor (trained by this script)

    The second error predictor learns the residual errors remaining after
    the first error correction is applied, further refining the predictions.

    Parameters
    ----------
    base_model : Path, optional
        Path to the Multi-Head base model ONNX file.
        Default: models/from_xyY/multi_head_large.onnx
    first_error_predictor : Path, optional
        Path to the first error predictor ONNX file.
        Default: models/from_xyY/multi_head_multi_error_predictor_large.onnx
    params : Path, optional
        Path to the normalization parameters file.
        Default: models/from_xyY/multi_head_large_normalization_parameters.npz
    epochs : int, optional
        Maximum number of training epochs.
    batch_size : int, optional
        Batch size for the training and validation loaders.
    lr : float, optional
        Learning rate for the AdamW optimizer.
    patience : int, optional
        Epochs without validation improvement before early stopping.

    Notes
    -----
    The training pipeline:
    1. Loads pre-trained Stage 1 and Stage 2 models
    2. Generates Stage 2 predictions (base + first error correction)
    3. Computes remaining residual errors
    4. Trains Stage 3 error predictor on these residuals
    5. Exports the model to ONNX format
    6. Logs metrics and artifacts to MLflow
    """

    LOGGER.info("=" * 80)
    LOGGER.info("Second Error Predictor: 3-Stage Model Training")
    LOGGER.info("Multi-Head + Multi-Error Predictor + Multi-Error Predictor")
    LOGGER.info("=" * 80)

    # Prefer CUDA, then Apple MPS, falling back to CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if torch.backends.mps.is_available():
        device = torch.device("mps")
    LOGGER.info("Using device: %s", device)

    model_directory = PROJECT_ROOT / "models" / "from_xyY"
    data_dir = PROJECT_ROOT / "data"

    if base_model is None:
        base_model = model_directory / "multi_head_large.onnx"
    if first_error_predictor is None:
        first_error_predictor = (
            model_directory / "multi_head_multi_error_predictor_large.onnx"
        )
    if params is None:
        params = model_directory / "multi_head_large_normalization_parameters.npz"

    cache_file = data_dir / "training_data_large.npz"

    if not cache_file.exists():
        LOGGER.error("Error: Large training data not found at %s", cache_file)
        return

    if not base_model.exists():
        LOGGER.error("Error: Base model not found at %s", base_model)
        return

    if not first_error_predictor.exists():
        LOGGER.error(
            "Error: First error predictor not found at %s", first_error_predictor
        )
        return

    # Load the two frozen upstream stages as ONNX inference sessions.
    LOGGER.info("")
    LOGGER.info("Loading Stage 1: Multi-Head base model from %s...", base_model)
    base_session = ort.InferenceSession(str(base_model))

    LOGGER.info(
        "Loading Stage 2: First error predictor from %s...", first_error_predictor
    )
    error_predictor_session = ort.InferenceSession(str(first_error_predictor))

    # Load normalization params (stored as pickled dicts inside the .npz).
    params_data = np.load(params, allow_pickle=True)
    input_parameters = params_data["input_parameters"].item()
    output_parameters = params_data["output_parameters"].item()

    # Load training data
    LOGGER.info("Loading large training data from %s...", cache_file)
    data = np.load(cache_file)
    X_train = data["X_train"]
    y_train = data["y_train"]
    X_val = data["X_val"]
    y_val = data["y_val"]

    LOGGER.info("Train samples: %d", len(X_train))
    LOGGER.info("Validation samples: %d", len(X_val))

    # Generate stage 2 predictions (base + first error correction)
    LOGGER.info("")
    LOGGER.info("Computing Stage 2 predictions (base + first error correction)...")

    X_train_norm = normalize_xyY(X_train, input_parameters)
    y_train_norm = normalize_munsell(y_train, output_parameters)
    X_val_norm = normalize_xyY(X_val, input_parameters)
    y_val_norm = normalize_munsell(y_val, output_parameters)

    inference_batch_size = 50000

    # Stage 1: Base model predictions.
    # FIX: inputs are now explicitly cast to float32 inside the batched
    # helper — previously only the Stage 2 combined inputs were cast, so a
    # float64 normalize_xyY output would make ONNX Runtime reject the run.
    LOGGER.info(" Stage 1: Base model predictions (training set)...")
    base_pred_train = _run_onnx_batched(
        base_session, "xyY", X_train_norm, inference_batch_size
    )

    LOGGER.info(" Stage 1: Base model predictions (validation set)...")
    base_pred_val = _run_onnx_batched(
        base_session, "xyY", X_val_norm, inference_batch_size
    )

    # Stage 2: First error predictor corrections on [xyY_norm, base_pred].
    LOGGER.info(" Stage 2: First error predictor corrections (training set)...")
    combined_train = np.concatenate([X_train_norm, base_pred_train], axis=1)
    error_correction_train = _run_onnx_batched(
        error_predictor_session, "combined_input", combined_train, inference_batch_size
    )

    LOGGER.info(" Stage 2: First error predictor corrections (validation set)...")
    combined_val = np.concatenate([X_val_norm, base_pred_val], axis=1)
    error_correction_val = _run_onnx_batched(
        error_predictor_session, "combined_input", combined_val, inference_batch_size
    )

    # Stage 2 predictions (base + first error correction)
    stage2_pred_train = base_pred_train + error_correction_train
    stage2_pred_val = base_pred_val + error_correction_val

    # Compute remaining errors for stage 3 — these residuals are the
    # regression targets for the second error predictor.
    error_train = y_train_norm - stage2_pred_train
    error_val = y_val_norm - stage2_pred_val

    # Statistics
    LOGGER.info("")
    LOGGER.info("Stage 2 prediction error statistics (normalized space):")
    LOGGER.info(" Mean absolute error: %.6f", np.mean(np.abs(error_train)))
    LOGGER.info(" Std of error: %.6f", np.std(error_train))
    LOGGER.info(" Max absolute error: %.6f", np.max(np.abs(error_train)))

    # Compare with stage 1 errors
    stage1_error_train = y_train_norm - base_pred_train
    LOGGER.info("")
    LOGGER.info("Stage 1 (base only) error statistics for comparison:")
    LOGGER.info(" Mean absolute error: %.6f", np.mean(np.abs(stage1_error_train)))
    LOGGER.info(" Std of error: %.6f", np.std(stage1_error_train))

    error_reduction = (
        (np.mean(np.abs(stage1_error_train)) - np.mean(np.abs(error_train)))
        / np.mean(np.abs(stage1_error_train))
        * 100
    )
    LOGGER.info("")
    LOGGER.info("Stage 2 error reduction vs Stage 1: %.1f%%", error_reduction)

    # Create combined input for stage 3: [xyY_norm, stage2_pred_norm]
    X_train_combined = np.concatenate([X_train_norm, stage2_pred_train], axis=1)
    X_val_combined = np.concatenate([X_val_norm, stage2_pred_val], axis=1)

    # Convert to PyTorch tensors
    X_train_t = torch.FloatTensor(X_train_combined)
    error_train_t = torch.FloatTensor(error_train)
    X_val_t = torch.FloatTensor(X_val_combined)
    error_val_t = torch.FloatTensor(error_val)

    train_dataset = TensorDataset(X_train_t, error_train_t)
    val_dataset = TensorDataset(X_val_t, error_val_t)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Initialize second error predictor (same architecture as first)
    model = MultiHeadErrorPredictorToMunsell().to(device)
    LOGGER.info("")
    LOGGER.info("Stage 3: Second error predictor architecture:")
    LOGGER.info("%s", model)

    total_params = sum(p.numel() for p in model.parameters())
    LOGGER.info("Total parameters: %s", f"{total_params:,}")

    # Training setup
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=10
    )
    criterion = precision_focused_loss

    run_name = setup_mlflow_experiment("from_xyY", "multi_head_3stage_error_predictor")

    LOGGER.info("")
    LOGGER.info("MLflow run: %s", run_name)

    # Training loop state.
    best_val_loss = float("inf")
    patience_counter = 0

    # FIX: create the model directory up-front (parents=True so a missing
    # "models/" parent does not raise) and derive the checkpoint path once,
    # instead of re-doing both on every improved epoch.
    model_directory.mkdir(parents=True, exist_ok=True)
    checkpoint_file = model_directory / "multi_head_3stage_error_predictor_best.pth"

    LOGGER.info("")
    LOGGER.info("Starting Stage 3 training...")

    with mlflow.start_run(run_name=run_name):
        mlflow.log_params(
            {
                "model": "multi_head_3stage_error_predictor",
                "num_epochs": epochs,
                "batch_size": batch_size,
                "learning_rate": lr,
                "weight_decay": 1e-5,
                "optimizer": "AdamW",
                "scheduler": "ReduceLROnPlateau",
                "criterion": "precision_focused_loss",
                "patience": patience,
                "total_params": total_params,
                "train_samples": len(X_train),
                "val_samples": len(X_val),
                "stage2_error_reduction_pct": error_reduction,
            }
        )

        for epoch in range(epochs):
            train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
            val_loss = validate(model, val_loader, criterion, device)

            scheduler.step(val_loss)

            log_training_epoch(
                epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"]
            )

            LOGGER.info(
                "Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f, LR: %.6f",
                epoch + 1,
                epochs,
                train_loss,
                val_loss,
                optimizer.param_groups[0]["lr"],
            )

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0

                torch.save(
                    {
                        "model_state_dict": model.state_dict(),
                        "epoch": epoch,
                        "val_loss": val_loss,
                    },
                    checkpoint_file,
                )

                LOGGER.info(" → Saved best model (val_loss: %.6f)", val_loss)
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    LOGGER.info("")
                    LOGGER.info("Early stopping after %d epochs", epoch + 1)
                    break

        mlflow.log_metrics(
            {
                "best_val_loss": best_val_loss,
                "final_epoch": epoch + 1,
            }
        )

        # Export to ONNX
        LOGGER.info("")
        LOGGER.info("Exporting Stage 3 error predictor to ONNX...")
        model.eval()

        # weights_only=False: the checkpoint contains non-tensor metadata.
        checkpoint = torch.load(checkpoint_file, weights_only=False)
        model.load_state_dict(checkpoint["model_state_dict"])

        # Dummy input: 3 normalized xyY values + 4 stage-2 predictions.
        dummy_input = torch.randn(1, 7).to(device)

        onnx_file = model_directory / "multi_head_3stage_error_predictor.onnx"
        torch.onnx.export(
            model,
            dummy_input,
            onnx_file,
            export_params=True,
            opset_version=15,
            input_names=["combined_input"],
            output_names=["error_correction"],
            dynamic_axes={
                "combined_input": {0: "batch_size"},
                "error_correction": {0: "batch_size"},
            },
        )

        LOGGER.info("Stage 3 error predictor ONNX model saved to: %s", onnx_file)

        mlflow.log_artifact(str(checkpoint_file))
        mlflow.log_artifact(str(onnx_file))
        mlflow.pytorch.log_model(model, "model")

    LOGGER.info("=" * 80)
418
+
419
+
420
# Script entry point: configure root logging (force=True replaces any
# handlers already installed by imported libraries) and run the Click command.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO, format="%(message)s", force=True)

    main()
learning_munsell/training/from_xyY/train_multi_head_circular.py ADDED
@@ -0,0 +1,496 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Train Multi-Head model with circular hue loss for xyY to Munsell conversion.
3
+
4
+ This version uses circular loss for the hue component (which wraps from 0-10)
5
+ to avoid penalizing predictions near the boundary.
6
+
7
+ Key Difference from Standard Training:
8
+ - Uses munsell_component_loss() which applies circular MSE for hue
9
+ - and regular MSE for value/chroma/code components
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import copy
15
+ import logging
16
+
17
+ import click
18
+ import mlflow
19
+ import mlflow.pytorch
20
+ import numpy as np
21
+ import torch
22
+ from torch import nn, optim
23
+ from torch.utils.data import DataLoader, TensorDataset
24
+
25
+ from learning_munsell import PROJECT_ROOT
26
+ from learning_munsell.training.from_xyY.hyperparameter_search_multi_head import (
27
+ MultiHeadParametric,
28
+ )
29
+ from learning_munsell.utilities.common import setup_mlflow_experiment
30
+ from learning_munsell.utilities.data import (
31
+ MUNSELL_NORMALIZATION_PARAMS,
32
+ normalize_munsell,
33
+ )
34
+
35
+ LOGGER = logging.getLogger(__name__)
36
+
37
+
38
+ def circular_mse_loss(
39
+ pred_hue: torch.Tensor, target_hue: torch.Tensor, hue_range: float = 1.0
40
+ ) -> torch.Tensor:
41
+ """
42
+ Circular MSE loss for hue component (normalized 0-1).
43
+
44
+ Parameters
45
+ ----------
46
+ pred_hue : Tensor
47
+ Predicted hue values (normalized 0-1)
48
+ target_hue : Tensor
49
+ Target hue values (normalized 0-1)
50
+ hue_range : float
51
+ Range of hue values (1.0 for normalized)
52
+
53
+ Returns
54
+ -------
55
+ Tensor
56
+ Circular MSE loss
57
+ """
58
+ diff = torch.abs(pred_hue - target_hue)
59
+ circular_diff = torch.min(diff, hue_range - diff)
60
+ return torch.mean(circular_diff**2)
61
+
62
+
63
+ def munsell_component_loss(
64
+ pred: torch.Tensor, target: torch.Tensor, hue_range: float = 1.0
65
+ ) -> torch.Tensor:
66
+ """
67
+ Component-wise loss for Munsell predictions.
68
+
69
+ Uses circular MSE for hue (component 0) and regular MSE
70
+ for value, chroma, code (components 1-3).
71
+
72
+ Parameters
73
+ ----------
74
+ pred : Tensor
75
+ Predictions [hue, value, chroma, code] (shape: [batch, 4])
76
+ target : Tensor
77
+ Ground truth [hue, value, chroma, code] (shape: [batch, 4])
78
+ hue_range : float
79
+ Range of normalized hue values (default 1.0)
80
+
81
+ Returns
82
+ -------
83
+ Tensor
84
+ Combined loss
85
+ """
86
+ hue_loss = circular_mse_loss(pred[:, 0], target[:, 0], hue_range)
87
+ other_loss = nn.functional.mse_loss(pred[:, 1:], target[:, 1:])
88
+ return hue_loss + other_loss
89
+
90
+
91
@click.command()
@click.option("--epochs", default=300, help="Number of training epochs")
@click.option("--batch-size", default=512, help="Batch size for training")
@click.option("--lr", default=0.000837, help="Learning rate")
@click.option("--patience", default=30, help="Early stopping patience")
def main(
    epochs: int,
    batch_size: int,
    lr: float,
    patience: int,
    encoder_width: float = 0.75,
    head_width: float = 1.5,
    chroma_head_width: float = 1.5,
    dropout: float = 0.0,
    weight_decay: float = 0.000013,
) -> tuple[MultiHeadParametric, float]:
    """
    Train Multi-Head model with circular hue loss.

    This script uses circular loss for the hue component (which wraps from
    0-10) to avoid penalizing predictions near the boundary.

    Parameters
    ----------
    epochs : int, optional
        Maximum number of training epochs.
    batch_size : int, optional
        Training batch size.
    lr : float, optional
        Learning rate for AdamW optimizer.
    patience : int, optional
        Early stopping patience, in epochs without validation improvement.
    encoder_width : float, optional
        Width multiplier for the shared encoder.
        NOTE: not exposed as a CLI option; only overridable when calling
        ``main`` directly.
    head_width : float, optional
        Width multiplier for hue, value, and code heads (not a CLI option).
    chroma_head_width : float, optional
        Width multiplier for chroma head (typically larger; not a CLI option).
    dropout : float, optional
        Dropout rate for regularization (not a CLI option).
    weight_decay : float, optional
        Weight decay for AdamW optimizer (not a CLI option).

    Returns
    -------
    model : MultiHeadParametric
        Trained model with best validation loss weights.
    best_val_loss : float
        Best validation loss achieved during training.

    Notes
    -----
    The training pipeline:
    1. Loads training data from cache
    2. Normalizes outputs to [0, 1] range
    3. Trains with circular MSE for hue and regular MSE for other components
    4. Uses CosineAnnealingLR scheduler
    5. Early stopping based on validation loss
    6. Exports model to ONNX format
    7. Logs metrics and artifacts to MLflow

    The circular loss experiment showed that while mathematically correct,
    the circular distance creates gradient discontinuities that harm
    optimization. This model is included for comparison purposes.
    """

    LOGGER.info("=" * 80)
    LOGGER.info("Training Multi-Head (Circular Hue Loss) for xyY to Munsell conversion")
    LOGGER.info("=" * 80)
    LOGGER.info("")
    LOGGER.info("Using Circular Loss for Hue Component")
    LOGGER.info("=" * 80)
    LOGGER.info("")
    LOGGER.info("Hyperparameters:")
    LOGGER.info(" lr: %.6f", lr)
    LOGGER.info(" batch_size: %d", batch_size)
    LOGGER.info(" encoder_width: %.2f", encoder_width)
    LOGGER.info(" head_width: %.2f", head_width)
    LOGGER.info(" chroma_head_width: %.2f", chroma_head_width)
    LOGGER.info(" dropout: %.2f", dropout)
    LOGGER.info(" weight_decay: %.6f", weight_decay)
    LOGGER.info("")

    # Prefer Apple Metal (MPS) when available, otherwise fall back to CPU.
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    LOGGER.info("Using device: %s", device)

    # Load data from cache
    data_dir = PROJECT_ROOT / "data"
    cache_file = data_dir / "training_data.npz"
    data = np.load(cache_file)

    X_train = data["X_train"]
    y_train = data["y_train"]
    X_val = data["X_val"]
    y_val = data["y_val"]

    LOGGER.info("Training samples: %d", len(X_train))
    LOGGER.info("Validation samples: %d", len(X_val))

    # Normalize outputs (xyY inputs already in [0, 1] range)
    # Use shared normalization parameters covering the full Munsell
    # space for generalization
    output_parameters = MUNSELL_NORMALIZATION_PARAMS
    y_train_norm = normalize_munsell(y_train, output_parameters)
    y_val_norm = normalize_munsell(y_val, output_parameters)

    # Convert to tensors
    X_train_t = torch.from_numpy(X_train).float()
    y_train_t = torch.from_numpy(y_train_norm).float()
    X_val_t = torch.from_numpy(X_val).float()
    y_val_t = torch.from_numpy(y_val_norm).float()

    train_loader = DataLoader(
        TensorDataset(X_train_t, y_train_t), batch_size=batch_size, shuffle=True
    )
    val_loader = DataLoader(
        TensorDataset(X_val_t, y_val_t), batch_size=batch_size, shuffle=False
    )

    # Create model
    model = MultiHeadParametric(
        encoder_width=encoder_width,
        head_width=head_width,
        chroma_head_width=chroma_head_width,
        dropout=dropout,
    ).to(device)

    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    # Anneal the learning rate over the full epoch budget, even if early
    # stopping ends training sooner.
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    total_params = sum(p.numel() for p in model.parameters())
    LOGGER.info("")
    LOGGER.info("Model parameters: %s", f"{total_params:,}")

    # Per-submodule parameter counts, for architecture-size logging only.
    encoder_params = sum(p.numel() for p in model.encoder.parameters())
    hue_params = sum(p.numel() for p in model.hue_head.parameters())
    value_params = sum(p.numel() for p in model.value_head.parameters())
    chroma_params = sum(p.numel() for p in model.chroma_head.parameters())
    code_params = sum(p.numel() for p in model.code_head.parameters())

    LOGGER.info(" - Shared encoder (%.2fx): %s", encoder_width, f"{encoder_params:,}")
    LOGGER.info(" - Hue head (%.2fx): %s", head_width, f"{hue_params:,}")
    LOGGER.info(" - Value head (%.2fx): %s", head_width, f"{value_params:,}")
    LOGGER.info(" - Chroma head (%.2fx): %s", chroma_head_width, f"{chroma_params:,}")
    LOGGER.info(" - Code head (%.2fx): %s", head_width, f"{code_params:,}")

    # MLflow setup
    run_name = setup_mlflow_experiment("from_xyY", "multi_head_circular")
    LOGGER.info("")
    LOGGER.info("MLflow run: %s", run_name)

    best_val_loss = float("inf")
    best_state = None
    patience_counter = 0

    LOGGER.info("")
    LOGGER.info("Starting training with circular hue loss...")

    with mlflow.start_run(run_name=run_name):
        mlflow.log_params(
            {
                "model": "multi_head_circular",
                "encoder_width": encoder_width,
                "head_width": head_width,
                "chroma_head_width": chroma_head_width,
                "dropout": dropout,
                "learning_rate": lr,
                "batch_size": batch_size,
                "weight_decay": weight_decay,
                "epochs": epochs,
                "patience": patience,
                "total_params": total_params,
                "loss_type": "circular_hue",
            }
        )

        for epoch in range(epochs):
            # Training
            model.train()
            train_loss = 0.0
            for X_batch, y_batch in train_loader:
                X_batch, y_batch = X_batch.to(device), y_batch.to(device)  # noqa: PLW2901

                optimizer.zero_grad()
                pred = model(X_batch)

                # Use circular loss for hue component
                loss = munsell_component_loss(pred, y_batch, hue_range=1.0)

                loss.backward()
                optimizer.step()
                # Accumulate sample-weighted loss so the epoch average is
                # exact even when the last batch is smaller.
                train_loss += loss.item() * len(X_batch)

            train_loss /= len(X_train_t)
            scheduler.step()

            # Validation
            model.eval()
            val_loss = 0.0
            with torch.no_grad():
                for X_batch, y_batch in val_loader:
                    X_batch, y_batch = X_batch.to(device), y_batch.to(device)  # noqa: PLW2901
                    pred = model(X_batch)
                    val_loss += munsell_component_loss(
                        pred, y_batch, hue_range=1.0
                    ).item() * len(X_batch)
            val_loss /= len(X_val_t)

            # Per-component MAE (denormalized for interpretability)
            with torch.no_grad():
                pred_val = model(X_val_t.to(device)).cpu()
                # Denormalize predictions and ground truth
                # NOTE: ``pred_denorm`` shares storage with ``pred_val``
                # (``Tensor.numpy()`` is zero-copy on CPU); each column
                # assignment below reads only its own, not-yet-written
                # column, so the aliasing is safe but fragile.
                pred_denorm = pred_val.numpy()
                hue_min, hue_max = output_parameters["hue_range"]
                value_min, value_max = output_parameters["value_range"]
                chroma_min, chroma_max = output_parameters["chroma_range"]
                code_min, code_max = output_parameters["code_range"]

                pred_denorm[:, 0] = (
                    pred_val[:, 0].numpy() * (hue_max - hue_min) + hue_min
                )  # hue
                pred_denorm[:, 1] = (
                    pred_val[:, 1].numpy() * (value_max - value_min) + value_min
                )  # value
                pred_denorm[:, 2] = (
                    pred_val[:, 2].numpy() * (chroma_max - chroma_min) + chroma_min
                )  # chroma
                pred_denorm[:, 3] = (
                    pred_val[:, 3].numpy() * (code_max - code_min) + code_min
                )  # code

                y_denorm = y_val_norm.copy()
                y_denorm[:, 0] = y_val_norm[:, 0] * (hue_max - hue_min) + hue_min
                y_denorm[:, 1] = y_val_norm[:, 1] * (value_max - value_min) + value_min
                y_denorm[:, 2] = (
                    y_val_norm[:, 2] * (chroma_max - chroma_min) + chroma_min
                )
                y_denorm[:, 3] = y_val_norm[:, 3] * (code_max - code_min) + code_min

                mae = np.mean(np.abs(pred_denorm - y_denorm), axis=0)

            mlflow.log_metrics(
                {
                    "train_loss": train_loss,
                    "val_loss": val_loss,
                    "mae_hue": mae[0],
                    "mae_value": mae[1],
                    "mae_chroma": mae[2],
                    "mae_code": mae[3],
                },
                step=epoch,
            )

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                # Deep-copy so later epochs can't mutate the saved weights.
                best_state = copy.deepcopy(model.state_dict())
                patience_counter = 0
                LOGGER.info(
                    "Epoch %03d/%d - Train: %.6f, Val: %.6f (best) - "
                    "MAE: hue=%.4f, value=%.4f, chroma=%.4f, code=%.4f",
                    epoch + 1,
                    epochs,
                    train_loss,
                    val_loss,
                    mae[0],
                    mae[1],
                    mae[2],
                    mae[3],
                )
            else:
                patience_counter += 1
                # Periodic heartbeat even when not improving.
                if (epoch + 1) % 50 == 0:
                    LOGGER.info(
                        "Epoch %03d/%d - Train: %.6f, Val: %.6f",
                        epoch + 1,
                        epochs,
                        train_loss,
                        val_loss,
                    )

            if patience_counter >= patience:
                LOGGER.info("Early stopping at epoch %d", epoch + 1)
                break

        # Load best model
        # NOTE(review): assumes at least one epoch ran, otherwise
        # ``best_state`` is still None.
        model.load_state_dict(best_state)

        # Final evaluation
        model.eval()
        with torch.no_grad():
            pred_val = model(X_val_t.to(device)).cpu()
            pred_denorm = pred_val.numpy()
            hue_min, hue_max = output_parameters["hue_range"]
            value_min, value_max = output_parameters["value_range"]
            chroma_min, chroma_max = output_parameters["chroma_range"]
            code_min, code_max = output_parameters["code_range"]

            pred_denorm[:, 0] = pred_val[:, 0].numpy() * (hue_max - hue_min) + hue_min
            pred_denorm[:, 1] = (
                pred_val[:, 1].numpy() * (value_max - value_min) + value_min
            )
            pred_denorm[:, 2] = (
                pred_val[:, 2].numpy() * (chroma_max - chroma_min) + chroma_min
            )
            pred_denorm[:, 3] = (
                pred_val[:, 3].numpy() * (code_max - code_min) + code_min
            )

            y_denorm = y_val_norm.copy()
            y_denorm[:, 0] = y_val_norm[:, 0] * (hue_max - hue_min) + hue_min
            y_denorm[:, 1] = y_val_norm[:, 1] * (value_max - value_min) + value_min
            y_denorm[:, 2] = y_val_norm[:, 2] * (chroma_max - chroma_min) + chroma_min
            y_denorm[:, 3] = y_val_norm[:, 3] * (code_max - code_min) + code_min

            mae = np.mean(np.abs(pred_denorm - y_denorm), axis=0)

        # Log final metrics
        mlflow.log_metrics(
            {
                "best_val_loss": best_val_loss,
                "final_mae_hue": mae[0],
                "final_mae_value": mae[1],
                "final_mae_chroma": mae[2],
                "final_mae_code": mae[3],
                "final_epoch": epoch + 1,
            }
        )

        LOGGER.info("")
        LOGGER.info("Final Results:")
        LOGGER.info(" Best Val Loss: %.6f", best_val_loss)
        LOGGER.info(" MAE hue: %.6f", mae[0])
        LOGGER.info(" MAE value: %.6f", mae[1])
        LOGGER.info(" MAE chroma: %.6f", mae[2])
        LOGGER.info(" MAE code: %.6f", mae[3])

        # Save model
        models_dir = PROJECT_ROOT / "models" / "from_xyY"
        models_dir.mkdir(exist_ok=True)

        checkpoint_path = models_dir / "multi_head_circular.pth"
        torch.save(
            {
                "model_state_dict": model.state_dict(),
                "output_parameters": output_parameters,
                "val_loss": best_val_loss,
                "mae": {
                    "hue": float(mae[0]),
                    "value": float(mae[1]),
                    "chroma": float(mae[2]),
                    "code": float(mae[3]),
                },
                "hyperparameters": {
                    "encoder_width": encoder_width,
                    "head_width": head_width,
                    "chroma_head_width": chroma_head_width,
                    "dropout": dropout,
                    "lr": lr,
                    "batch_size": batch_size,
                    "weight_decay": weight_decay,
                },
                "loss_type": "circular_hue",
            },
            checkpoint_path,
        )
        LOGGER.info("")
        LOGGER.info("Saved checkpoint: %s", checkpoint_path)

        # Export to ONNX
        # NOTE: the model is moved to CPU here, so the returned model lives
        # on CPU regardless of the training device.
        model.cpu().eval()
        dummy_input = torch.randn(1, 3)
        onnx_path = models_dir / "multi_head_circular.onnx"

        torch.onnx.export(
            model,
            dummy_input,
            onnx_path,
            input_names=["xyY"],  # Match other models for comparison compatibility
            output_names=["munsell_spec"],
            dynamic_axes={"xyY": {0: "batch"}, "munsell_spec": {0: "batch"}},
            opset_version=17,
        )
        LOGGER.info("Saved ONNX: %s", onnx_path)

        # Save normalization parameters
        # NOTE(review): only output parameters are saved here; the
        # cross-attention script expects an ``input_parameters`` key in the
        # base model's npz — confirm consumers of this file.
        params_path = models_dir / "multi_head_circular_normalization_parameters.npz"
        np.savez(
            params_path,
            output_parameters=output_parameters,
        )
        LOGGER.info("Saved normalization parameters: %s", params_path)

        # Log artifacts to MLflow
        mlflow.log_artifact(str(checkpoint_path))
        mlflow.log_artifact(str(onnx_path))
        mlflow.log_artifact(str(params_path))
        mlflow.pytorch.log_model(model, "model")
        LOGGER.info("Artifacts logged to MLflow")

    LOGGER.info("=" * 80)

    return model, best_val_loss
491
+
492
+
493
if __name__ == "__main__":
    # Configure root logging for CLI use; ``force=True`` replaces any
    # handlers that imported libraries may already have installed.
    logging.basicConfig(level=logging.INFO, format="%(message)s", force=True)

    main()
learning_munsell/training/from_xyY/train_multi_head_cross_attention_error_predictor.py ADDED
@@ -0,0 +1,667 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Train Multi-Head + Cross-Attention Error Predictor for xyY to Munsell conversion.
3
+
4
+ This version uses cross-attention between component branches to learn
5
+ correlations between errors in different Munsell components.
6
+
7
+ Key Features:
8
+ - Shared context encoder
9
+ - Multi-head cross-attention between components
10
+ - Component-specific prediction heads
11
+ - Residual connections
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import copy
17
+ import logging
18
+
19
+ import click
20
+ import mlflow
21
+ import mlflow.pytorch
22
+ import numpy as np
23
+ import onnxruntime as ort
24
+ import torch
25
+ from torch import nn, optim
26
+ from torch.utils.data import DataLoader, TensorDataset
27
+
28
+ from learning_munsell import PROJECT_ROOT
29
+ from learning_munsell.utilities.common import (
30
+ fix_onnx_dynamic_batch,
31
+ log_training_epoch,
32
+ setup_mlflow_experiment,
33
+ )
34
+ from learning_munsell.utilities.data import normalize_munsell, normalize_xyY
35
+
36
+ LOGGER = logging.getLogger(__name__)
37
+
38
+ # Note: This script has a custom CrossAttentionErrorPredictor architecture
39
+ # so we don't import ComponentErrorPredictor/MultiHeadErrorPredictor
40
+ # from shared modules.
41
+
42
+
43
class CustomMultiheadAttention(nn.Module):
    """
    Multi-head self-attention built from primitive ops for clean ONNX export.

    ``nn.MultiheadAttention`` produces reshape patterns that break with
    dynamic batch sizes during ONNX export; this module sticks to plain
    linear layers, ``matmul`` and ``softmax`` so the exported graph keeps a
    symbolic batch axis.

    Parameters
    ----------
    embed_dim : int
        Total dimension of the model (must be divisible by num_heads).
    num_heads : int
        Number of parallel attention heads.
    dropout : float, optional
        Dropout probability on attention weights.

    Attributes
    ----------
    embed_dim : int
        Total embedding dimension.
    num_heads : int
        Number of attention heads.
    head_dim : int
        Dimension of each attention head (embed_dim // num_heads).
    scale : float
        Scaling factor for attention scores (head_dim ** -0.5).
    q_proj : nn.Linear
        Query projection layer.
    k_proj : nn.Linear
        Key projection layer.
    v_proj : nn.Linear
        Value projection layer.
    out_proj : nn.Linear
        Output projection layer.
    dropout : nn.Dropout
        Dropout layer for attention weights.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
    ) -> None:
        """Initialize projections and bookkeeping for multi-head attention."""
        super().__init__()

        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"  # noqa: S101

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim**-0.5

        # Submodule creation order is kept stable: it determines both the
        # state-dict layout and seeded parameter initialization.
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply self-attention over the sequence axis.

        Parameters
        ----------
        x : Tensor
            Input tensor [batch, seq_len, embed_dim]

        Returns
        -------
        Tensor
            Output tensor [batch, seq_len, embed_dim]
        """
        seq_len = x.shape[1]

        def split_heads(t: torch.Tensor) -> torch.Tensor:
            # [batch, seq, embed] -> [batch, heads, seq, head_dim].  The -1
            # batch dimension keeps the ONNX graph dynamic in batch size.
            return t.reshape(-1, seq_len, self.num_heads, self.head_dim).transpose(
                1, 2
            )

        queries = split_heads(self.q_proj(x))
        keys = split_heads(self.k_proj(x))
        values = split_heads(self.v_proj(x))

        # Scaled dot-product attention over the sequence axis:
        # [batch, heads, seq, seq] weights.
        scores = torch.matmul(queries, keys.transpose(-2, -1)) * self.scale
        weights = self.dropout(torch.softmax(scores, dim=-1))

        # Weighted sum of values, then merge heads back into embed_dim
        # (again using -1 for the batch axis for ONNX friendliness).
        context = torch.matmul(weights, values)
        context = context.transpose(1, 2).contiguous()
        context = context.reshape(-1, seq_len, self.embed_dim)

        return self.out_proj(context)
154
+
155
+
156
class CrossAttentionErrorPredictor(nn.Module):
    """
    Error predictor with cross-attention between Munsell components.

    Uses cross-attention to learn correlations between errors in different
    Munsell components (hue, value, chroma, code).

    Parameters
    ----------
    input_dim : int, optional
        Input dimension (7 = xyY_norm + base_pred_norm).
    context_dim : int, optional
        Dimension of shared context features.
    component_dim : int, optional
        Dimension of component-specific features.
    n_components : int, optional
        Number of Munsell components (4).
    n_attention_heads : int, optional
        Number of attention heads for cross-attention.
    dropout : float, optional
        Dropout probability.

    Attributes
    ----------
    context_encoder : nn.Sequential
        Shared encoder: input_dim → 256 → context_dim.
    component_encoders : nn.ModuleList
        Component-specific encoders: context_dim → component_dim (x4).
    cross_attention : CustomMultiheadAttention
        Cross-attention module between component features.
    attention_norm : nn.LayerNorm
        Layer normalization after attention.
    component_decoders : nn.ModuleList
        Component-specific decoders: component_dim → 128 → 1 (x4).

    Notes
    -----
    Architecture:
    1. Shared context encoder: 7 → 256 → 512
    2. Component-specific encoders: 512 → 256 (x4)
    3. Multi-head cross-attention between components
    4. Residual connection + layer norm
    5. Component-specific decoders: 256 → 128 → 1
    """

    def __init__(
        self,
        input_dim: int = 7,
        context_dim: int = 512,
        component_dim: int = 256,
        n_components: int = 4,
        n_attention_heads: int = 4,
        dropout: float = 0.1,
    ) -> None:
        """Build the shared encoder, component branches, and decoders."""
        super().__init__()

        self.n_components = n_components
        self.component_dim = component_dim

        # Shared context encoder: input_dim → 256 → context_dim.
        self.context_encoder = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.GELU(),
            nn.LayerNorm(256),
            nn.Dropout(dropout),
            nn.Linear(256, context_dim),
            nn.GELU(),
            nn.LayerNorm(context_dim),
        )

        # One small encoder per Munsell component.
        self.component_encoders = nn.ModuleList(
            nn.Sequential(
                nn.Linear(context_dim, component_dim),
                nn.GELU(),
                nn.LayerNorm(component_dim),
            )
            for _ in range(n_components)
        )

        # Multi-head cross-attention (ONNX-friendly custom implementation).
        self.cross_attention = CustomMultiheadAttention(
            embed_dim=component_dim,
            num_heads=n_attention_heads,
            dropout=dropout,
        )

        # Normalizes the residual sum after attention.
        self.attention_norm = nn.LayerNorm(component_dim)

        # One small decoder head per Munsell component.
        self.component_decoders = nn.ModuleList(
            nn.Sequential(
                nn.Linear(component_dim, 128),
                nn.GELU(),
                nn.LayerNorm(128),
                nn.Dropout(dropout),
                nn.Linear(128, 1),
            )
            for _ in range(n_components)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass with cross-attention.

        Parameters
        ----------
        x : Tensor
            Input [xyY_norm (3) + base_pred_norm (4)] = 7 features

        Returns
        -------
        Tensor
            Predicted errors [hue_err, value_err, chroma_err, code_err]
        """
        # Shared context features: [batch, context_dim].
        context = self.context_encoder(x)

        # Per-component features stacked as a length-n_components
        # "sequence": [batch, n_components, component_dim].
        stacked = torch.stack(
            [encoder(context) for encoder in self.component_encoders], dim=1
        )

        # Let components attend to each other, then residual + layer norm.
        stacked = self.attention_norm(stacked + self.cross_attention(stacked))

        # torch.unbind (rather than indexing) keeps the batch dimension
        # intact when the model is traced for ONNX export.
        per_component = torch.unbind(stacked, dim=1)

        corrections = [
            decoder(feature)
            for decoder, feature in zip(self.component_decoders, per_component)
        ]

        # One scalar per component, concatenated: [batch, n_components].
        return torch.cat(corrections, dim=1)
309
+
310
+
311
def train_cross_attention_error_predictor(
    epochs: int = 300,
    batch_size: int = 1024,
    lr: float = 0.0005,
    dropout: float = 0.1,
    context_dim: int = 512,
    component_dim: int = 256,
    n_attention_heads: int = 4,
) -> tuple[CrossAttentionErrorPredictor, float]:
    """
    Train cross-attention error predictor.

    This model uses cross-attention between component branches to learn
    correlations between errors in different Munsell components.

    Parameters
    ----------
    epochs : int, optional
        Maximum number of training epochs.
    batch_size : int, optional
        Training batch size.
    lr : float, optional
        Learning rate for AdamW optimizer.
    dropout : float, optional
        Dropout rate for regularization.
    context_dim : int, optional
        Dimension of shared context features.
    component_dim : int, optional
        Dimension of component-specific features.
    n_attention_heads : int, optional
        Number of attention heads for cross-attention.

    Returns
    -------
    model : CrossAttentionErrorPredictor
        Trained model with best validation loss weights.
    best_val_loss : float
        Best validation loss achieved during training.

    Notes
    -----
    The training pipeline:
    1. Loads pre-trained Multi-Head base model
    2. Generates base model predictions for training data
    3. Computes residual errors between predictions and targets
    4. Trains cross-attention error predictor on these residuals
    5. Uses CosineAnnealingLR scheduler
    6. Early stopping based on validation loss
    7. Exports model to ONNX format
    8. Logs metrics and artifacts to MLflow

    Early-stopping patience (30 epochs) and AdamW weight decay (1e-5) are
    hard-coded below, unlike the other training scripts that expose them.
    """

    LOGGER.info("=" * 80)
    LOGGER.info("Training Multi-Head + Cross-Attention Error Predictor")
    LOGGER.info("=" * 80)
    LOGGER.info("")
    LOGGER.info("Architecture:")
    LOGGER.info(" - Shared context encoder: 7 → 256 → %d", context_dim)
    LOGGER.info(" - Component encoders: %d → %d (x4)", context_dim, component_dim)
    LOGGER.info(" - Cross-attention: %d heads", n_attention_heads)
    LOGGER.info(" - Component decoders: %d → 128 → 1 (x4)", component_dim)
    LOGGER.info("")
    LOGGER.info("Hyperparameters:")
    LOGGER.info(" lr: %.6f", lr)
    LOGGER.info(" batch_size: %d", batch_size)
    LOGGER.info(" dropout: %.2f", dropout)
    LOGGER.info("")

    # Prefer Apple Metal (MPS) when available, otherwise fall back to CPU.
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    LOGGER.info("Using device: %s", device)

    # Paths
    model_directory = PROJECT_ROOT / "models" / "from_xyY"
    data_dir = PROJECT_ROOT / "data"
    base_model_path = model_directory / "multi_head.onnx"
    params_path = model_directory / "multi_head_normalization_parameters.npz"
    cache_file = data_dir / "training_data.npz"

    # Load base model
    LOGGER.info("")
    LOGGER.info("Loading Multi-Head base model from %s...", base_model_path)
    base_session = ort.InferenceSession(str(base_model_path))
    # allow_pickle is required to recover the dict-valued arrays; the file
    # is produced by this project's own training scripts (trusted input).
    params = np.load(params_path, allow_pickle=True)
    input_parameters = params["input_parameters"].item()
    output_parameters = params["output_parameters"].item()

    # Load training data
    LOGGER.info("Loading training data from %s...", cache_file)
    data = np.load(cache_file)
    X_train = data["X_train"]
    y_train = data["y_train"]
    X_val = data["X_val"]
    y_val = data["y_val"]

    LOGGER.info("Train samples: %d", len(X_train))
    LOGGER.info("Validation samples: %d", len(X_val))

    # Generate base model predictions
    # NOTE: inference runs over the full train/val arrays in one call;
    # assumes the dataset fits in memory alongside the ONNX session.
    LOGGER.info("")
    LOGGER.info("Generating Multi-Head base model predictions...")
    X_train_norm = normalize_xyY(X_train, input_parameters)
    y_train_norm = normalize_munsell(y_train, output_parameters)
    base_pred_train_norm = base_session.run(None, {"xyY": X_train_norm})[0]

    X_val_norm = normalize_xyY(X_val, input_parameters)
    y_val_norm = normalize_munsell(y_val, output_parameters)
    base_pred_val_norm = base_session.run(None, {"xyY": X_val_norm})[0]

    # Compute errors (the regression target is the residual, in normalized
    # space, between ground truth and the base model's prediction).
    error_train = y_train_norm - base_pred_train_norm
    error_val = y_val_norm - base_pred_val_norm

    LOGGER.info("")
    LOGGER.info("Base model error statistics (normalized space):")
    LOGGER.info(" Mean absolute error: %.6f", np.mean(np.abs(error_train)))
    LOGGER.info(" Std of error: %.6f", np.std(error_train))
    LOGGER.info(" Max absolute error: %.6f", np.max(np.abs(error_train)))

    # Create combined input: [xyY_norm, base_prediction_norm]
    X_train_combined = np.concatenate([X_train_norm, base_pred_train_norm], axis=1)
    X_val_combined = np.concatenate([X_val_norm, base_pred_val_norm], axis=1)

    # Convert to PyTorch tensors
    X_train_t = torch.FloatTensor(X_train_combined)
    error_train_t = torch.FloatTensor(error_train)
    X_val_t = torch.FloatTensor(X_val_combined)
    error_val_t = torch.FloatTensor(error_val)

    # Create data loaders
    train_dataset = TensorDataset(X_train_t, error_train_t)
    val_dataset = TensorDataset(X_val_t, error_val_t)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Initialize model
    model = CrossAttentionErrorPredictor(
        input_dim=7,
        context_dim=context_dim,
        component_dim=component_dim,
        n_attention_heads=n_attention_heads,
        dropout=dropout,
    ).to(device)

    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    LOGGER.info("")
    LOGGER.info("Total parameters: %s", f"{total_params:,}")

    context_params = sum(p.numel() for p in model.context_encoder.parameters())
    attention_params = sum(p.numel() for p in model.cross_attention.parameters())
    LOGGER.info(" - Context encoder: %s", f"{context_params:,}")
    LOGGER.info(" - Cross-attention: %s", f"{attention_params:,}")

    # Training setup
    criterion = nn.MSELoss()
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
    # Anneal the learning rate over the full epoch budget, even if early
    # stopping ends training sooner.
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    # MLflow setup
    run_name = setup_mlflow_experiment("from_xyY", "cross_attention_error_predictor")
    LOGGER.info("")
    LOGGER.info("MLflow run: %s", run_name)

    # Training loop
    best_val_loss = float("inf")
    best_state = None
    patience = 30
    patience_counter = 0

    LOGGER.info("")
    LOGGER.info("Starting training...")

    with mlflow.start_run(run_name=run_name):
        mlflow.log_params(
            {
                "model": "cross_attention_error_predictor",
                "context_dim": context_dim,
                "component_dim": component_dim,
                "n_attention_heads": n_attention_heads,
                "dropout": dropout,
                "learning_rate": lr,
                "batch_size": batch_size,
                "epochs": epochs,
                "patience": patience,
                "total_params": total_params,
            }
        )

        for epoch in range(epochs):
            # Training
            model.train()
            train_loss = 0.0
            for X_batch, y_batch in train_loader:
                X_batch, y_batch = X_batch.to(device), y_batch.to(device)  # noqa: PLW2901

                optimizer.zero_grad()
                pred = model(X_batch)
                loss = criterion(pred, y_batch)
                loss.backward()
                optimizer.step()
                # Sample-weighted accumulation so the epoch average is exact
                # even when the last batch is smaller.
                train_loss += loss.item() * len(X_batch)

            train_loss /= len(X_train_t)
            scheduler.step()

            # Validation
            model.eval()
            val_loss = 0.0
            with torch.no_grad():
                for X_batch, y_batch in val_loader:
                    X_batch, y_batch = X_batch.to(device), y_batch.to(device)  # noqa: PLW2901
                    pred = model(X_batch)
                    val_loss += criterion(pred, y_batch).item() * len(X_batch)
            val_loss /= len(X_val_t)

            log_training_epoch(
                epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"]
            )

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                # Deep-copy so later epochs can't mutate the saved weights.
                best_state = copy.deepcopy(model.state_dict())
                patience_counter = 0
                LOGGER.info(
                    "Epoch %03d/%d - Train: %.6f, Val: %.6f (best) - LR: %.6f",
                    epoch + 1,
                    epochs,
                    train_loss,
                    val_loss,
                    optimizer.param_groups[0]["lr"],
                )
            else:
                patience_counter += 1
                # Periodic heartbeat even when not improving.
                if (epoch + 1) % 50 == 0:
                    LOGGER.info(
                        "Epoch %03d/%d - Train: %.6f, Val: %.6f",
                        epoch + 1,
                        epochs,
                        train_loss,
                        val_loss,
                    )

            if patience_counter >= patience:
                LOGGER.info("Early stopping at epoch %d", epoch + 1)
                break

        # Load best model
        # NOTE(review): assumes at least one epoch ran, otherwise
        # ``best_state`` is still None.
        model.load_state_dict(best_state)

        mlflow.log_metrics(
            {
                "best_val_loss": best_val_loss,
                "final_epoch": epoch + 1,
            }
        )

        LOGGER.info("")
        LOGGER.info("Final Results:")
        LOGGER.info(" Best Val Loss: %.6f", best_val_loss)

        # Save model
        model_directory.mkdir(exist_ok=True)
        checkpoint_path = (
            model_directory / "multi_head_cross_attention_error_predictor.pth"
        )

        torch.save(
            {
                "model_state_dict": model.state_dict(),
                "val_loss": best_val_loss,
                "hyperparameters": {
                    "context_dim": context_dim,
                    "component_dim": component_dim,
                    "n_attention_heads": n_attention_heads,
                    "dropout": dropout,
                    "lr": lr,
                    "batch_size": batch_size,
                },
            },
            checkpoint_path,
        )
        LOGGER.info("")
        LOGGER.info("Saved checkpoint: %s", checkpoint_path)

        # Export to ONNX
        # NOTE: the model is moved to CPU here, so the returned model lives
        # on CPU regardless of the training device.
        LOGGER.info("")
        LOGGER.info("Exporting error predictor to ONNX...")
        model.eval()
        model.cpu()

        dummy_input = torch.randn(1, 7)
        onnx_path = model_directory / "multi_head_cross_attention_error_predictor.onnx"

        torch.onnx.export(
            model,
            dummy_input,
            onnx_path,
            export_params=True,
            opset_version=17,
            input_names=["combined_input"],
            output_names=["error_correction"],
            dynamic_axes={
                "combined_input": {0: "batch_size"},
                "error_correction": {0: "batch_size"},
            },
        )

        # Post-process the exported graph so the batch axis stays dynamic.
        fix_onnx_dynamic_batch(onnx_path)

        mlflow.log_artifact(str(checkpoint_path))
        mlflow.log_artifact(str(onnx_path))
        mlflow.pytorch.log_model(model, "model")

        LOGGER.info("ONNX model saved to: %s", onnx_path)
        LOGGER.info("Artifacts logged to MLflow")

    LOGGER.info("=" * 80)

    return model, best_val_loss
630
+
631
+
632
@click.command()
@click.option("--epochs", default=300, help="Maximum number of training epochs.")
@click.option("--batch-size", default=1024, help="Training batch size.")
@click.option("--lr", default=0.0005, help="Learning rate for AdamW optimizer.")
@click.option("--dropout", default=0.1, help="Dropout rate.")
@click.option(
    "--context-dim", default=512, help="Dimension of shared context features."
)
@click.option("--component-dim", default=256, help="Dimension of component features.")
@click.option("--n-attention-heads", default=4, help="Number of attention heads.")
def main(
    epochs: int = 300,
    batch_size: int = 1024,
    lr: float = 0.0005,
    dropout: float = 0.1,
    context_dim: int = 512,
    component_dim: int = 256,
    n_attention_heads: int = 4,
) -> None:
    """Train Multi-Head + Cross-Attention Error Predictor."""

    # Thin CLI shim: collect the parsed options and hand them to the
    # training routine unchanged.
    options = {
        "epochs": epochs,
        "batch_size": batch_size,
        "lr": lr,
        "dropout": dropout,
        "context_dim": context_dim,
        "component_dim": component_dim,
        "n_attention_heads": n_attention_heads,
    }
    train_cross_attention_error_predictor(**options)
662
+
663
+
664
if __name__ == "__main__":
    # Configure root logging for CLI use; ``force=True`` replaces any
    # handlers that imported libraries may already have installed.
    logging.basicConfig(level=logging.INFO, format="%(message)s", force=True)

    main()
learning_munsell/training/from_xyY/train_multi_head_gamma.py ADDED
@@ -0,0 +1,311 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Train multi-head ML model for xyY to Munsell conversion with gamma-corrected Y.
3
+
4
+ Experiment: Apply gamma 2.33 to Y before normalization to better align
5
+ with perceptual lightness (Munsell Value scale is perceptually uniform).
6
+
7
+ The multi-head architecture has separate heads for each Munsell component,
8
+ so gamma correction on Y should primarily benefit Value prediction without
9
+ negatively impacting Chroma prediction (unlike the single MLP).
10
+ """
11
+
12
+ import logging
13
+ from typing import Any
14
+
15
+ import click
16
+ import mlflow
17
+ import mlflow.pytorch
18
+ import numpy as np
19
+ import torch
20
+ from numpy.typing import NDArray
21
+ from torch import optim
22
+ from torch.utils.data import DataLoader, TensorDataset
23
+
24
+ from learning_munsell import PROJECT_ROOT
25
+ from learning_munsell.models.networks import MultiHeadMLPToMunsell
26
+ from learning_munsell.utilities.common import (
27
+ log_training_epoch,
28
+ setup_mlflow_experiment,
29
+ )
30
+ from learning_munsell.utilities.data import (
31
+ MUNSELL_NORMALIZATION_PARAMS,
32
+ normalize_munsell,
33
+ )
34
+ from learning_munsell.utilities.losses import weighted_mse_loss
35
+ from learning_munsell.utilities.training import train_epoch, validate
36
+
37
+ LOGGER = logging.getLogger(__name__)
38
+
39
# Gamma value for Y transformation
GAMMA = 2.33


def normalize_inputs(
    X: NDArray, gamma: float = GAMMA
) -> tuple[NDArray, dict[str, Any]]:
    """
    Normalize xyY inputs to [0, 1] and gamma-encode the luminance channel.

    Parameters
    ----------
    X : ndarray
        xyY values of shape (n, 3), columns ordered as [x, y, Y].
    gamma : float
        Exponent denominator applied to Y, i.e. ``Y ** (1 / gamma)``.

    Returns
    -------
    ndarray
        Normalized values whose Y column is gamma-encoded.
    dict
        Normalization parameters, including the gamma value used.
    """
    # All three channels nominally live in [0, 1] already.
    x_range = (0.0, 1.0)
    y_range = (0.0, 1.0)
    Y_range = (0.0, 1.0)

    normalized = X.copy()
    normalized[:, 0] = (X[:, 0] - x_range[0]) / (x_range[1] - x_range[0])
    normalized[:, 1] = (X[:, 1] - y_range[0]) / (y_range[1] - y_range[0])

    # Rescale Y, clip away slightly negative inputs, then gamma-encode:
    # dark values are spread apart and bright values compressed, which
    # better matches the perceptually uniform Munsell Value axis.
    luminance = (X[:, 2] - Y_range[0]) / (Y_range[1] - Y_range[0])
    luminance = np.clip(luminance, 0, 1)
    normalized[:, 2] = np.power(luminance, 1.0 / gamma)

    return normalized, {
        "x_range": x_range,
        "y_range": y_range,
        "Y_range": Y_range,
        "gamma": gamma,
    }
87
+
88
+
89
@click.command()
@click.option("--epochs", default=300, help="Number of training epochs")
@click.option("--batch-size", default=1024, help="Batch size for training")
@click.option("--lr", default=5e-4, help="Learning rate")
@click.option("--patience", default=20, help="Early stopping patience")
def main(epochs: int, batch_size: int, lr: float, patience: int) -> None:
    """
    Train the multi-head model with gamma-corrected Y input.

    Notes
    -----
    The training pipeline:
    1. Loads training and validation data from cache
    2. Normalizes inputs with gamma correction (gamma=2.33) on Y
    3. Normalizes Munsell outputs to [0, 1] range
    4. Trains multi-head MLP with weighted MSE loss
    5. Uses early stopping based on validation loss
    6. Exports best model to ONNX format
    7. Logs metrics and artifacts to MLflow

    The gamma correction on Y aligns with perceptual lightness. The Munsell
    Value scale is perceptually uniform, so gamma correction should primarily
    benefit Value prediction without negatively impacting Chroma prediction.
    """

    LOGGER.info("=" * 80)
    LOGGER.info("ML-Based xyY to Munsell Conversion: Multi-Head Gamma Experiment")
    LOGGER.info("Gamma = %.2f applied to Y component", GAMMA)
    LOGGER.info("=" * 80)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    LOGGER.info("Using device: %s", device)

    # Load training data
    data_dir = PROJECT_ROOT / "data"
    cache_file = data_dir / "training_data.npz"

    if not cache_file.exists():
        LOGGER.error("Error: Training data not found at %s", cache_file)
        LOGGER.error("Please run 01_generate_training_data.py first")
        return

    LOGGER.info("Loading training data from %s...", cache_file)
    data = np.load(cache_file)

    X_train = data["X_train"]
    y_train = data["y_train"]
    X_val = data["X_val"]
    y_val = data["y_val"]

    LOGGER.info("Train samples: %d", len(X_train))
    LOGGER.info("Validation samples: %d", len(X_val))

    # Normalize data with gamma correction
    X_train_norm, input_parameters = normalize_inputs(X_train, gamma=GAMMA)
    X_val_norm, _ = normalize_inputs(X_val, gamma=GAMMA)

    # Use shared normalization parameters for Munsell outputs
    output_parameters = MUNSELL_NORMALIZATION_PARAMS
    y_train_norm = normalize_munsell(y_train, output_parameters)
    y_val_norm = normalize_munsell(y_val, output_parameters)

    LOGGER.info("")
    LOGGER.info("Input normalization with gamma=%.2f:", GAMMA)
    LOGGER.info(
        "  Y range after gamma: [%.4f, %.4f]",
        X_train_norm[:, 2].min(),
        X_train_norm[:, 2].max(),
    )

    # Convert to PyTorch tensors
    X_train_t = torch.FloatTensor(X_train_norm)
    y_train_t = torch.FloatTensor(y_train_norm)
    X_val_t = torch.FloatTensor(X_val_norm)
    y_val_t = torch.FloatTensor(y_val_norm)

    # Create data loaders
    train_dataset = TensorDataset(X_train_t, y_train_t)
    val_dataset = TensorDataset(X_val_t, y_val_t)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Initialize model
    model = MultiHeadMLPToMunsell().to(device)
    LOGGER.info("")
    LOGGER.info("Model architecture:")
    LOGGER.info("%s", model)

    total_params = sum(p.numel() for p in model.parameters())
    LOGGER.info("Total parameters: %s", f"{total_params:,}")

    # Training setup
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = weighted_mse_loss

    # MLflow setup
    run_name = setup_mlflow_experiment("from_xyY", f"multi_head_gamma_{GAMMA}")

    LOGGER.info("")
    LOGGER.info("MLflow run: %s", run_name)

    # Checkpoint location, prepared once before the loop. ``parents=True``
    # also creates the intermediate "models" directory, which a bare
    # ``mkdir(exist_ok=True)`` would raise FileNotFoundError on for a fresh
    # checkout; binding ``checkpoint_file`` here also keeps the ONNX export
    # below from hitting an unbound name when no epoch ever runs.
    model_directory = PROJECT_ROOT / "models" / "from_xyY"
    model_directory.mkdir(parents=True, exist_ok=True)
    checkpoint_file = model_directory / "multi_head_gamma_best.pth"

    # Training loop
    best_val_loss = float("inf")
    patience_counter = 0

    LOGGER.info("")
    LOGGER.info("Starting training...")

    with mlflow.start_run(run_name=run_name):
        mlflow.log_params(
            {
                "model": "multi_head_gamma",
                "num_epochs": epochs,
                "batch_size": batch_size,
                "learning_rate": lr,
                "optimizer": "Adam",
                "criterion": "weighted_mse_loss",
                "patience": patience,
                "total_params": total_params,
                "gamma": GAMMA,
            }
        )

        for epoch in range(epochs):
            train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
            val_loss = validate(model, val_loader, criterion, device)

            log_training_epoch(
                epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"]
            )

            LOGGER.info(
                "Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f",
                epoch + 1,
                epochs,
                train_loss,
                val_loss,
            )

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0

                torch.save(
                    {
                        "model_state_dict": model.state_dict(),
                        "input_parameters": input_parameters,
                        "output_parameters": output_parameters,
                        "epoch": epoch,
                        "val_loss": val_loss,
                    },
                    checkpoint_file,
                )

                LOGGER.info("  → Saved best model (val_loss: %.6f)", val_loss)
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    LOGGER.info("")
                    LOGGER.info("Early stopping after %d epochs", epoch + 1)
                    break

        mlflow.log_metrics(
            {
                "best_val_loss": best_val_loss,
                "final_epoch": epoch + 1,
            }
        )

        # Export to ONNX
        LOGGER.info("")
        LOGGER.info("Exporting model to ONNX...")
        model.eval()

        # Reload the best checkpoint so the exported graph reflects the best
        # validation loss, not the last (possibly over-fit) epoch.
        checkpoint = torch.load(checkpoint_file)
        model.load_state_dict(checkpoint["model_state_dict"])

        dummy_input = torch.randn(1, 3).to(device)

        onnx_file = model_directory / "multi_head_gamma.onnx"
        torch.onnx.export(
            model,
            dummy_input,
            onnx_file,
            export_params=True,
            opset_version=15,
            input_names=["xyY_gamma"],
            output_names=["munsell_spec"],
            dynamic_axes={
                "xyY_gamma": {0: "batch_size"},
                "munsell_spec": {0: "batch_size"},
            },
        )

        # Save normalization parameters (including gamma)
        params_file = model_directory / "multi_head_gamma_normalization_parameters.npz"
        np.savez(
            params_file,
            input_parameters=input_parameters,
            output_parameters=output_parameters,
        )

        LOGGER.info("ONNX model saved to: %s", onnx_file)
        LOGGER.info("Normalization parameters saved to: %s", params_file)
        LOGGER.info("IMPORTANT: Input Y must be gamma-corrected with gamma=%.2f", GAMMA)

        mlflow.log_artifact(str(checkpoint_file))
        mlflow.log_artifact(str(onnx_file))
        mlflow.log_artifact(str(params_file))
        mlflow.pytorch.log_model(model, "model")

    LOGGER.info("=" * 80)
306
+
307
+
308
if __name__ == "__main__":
    # Configure root logging before Click dispatches to ``main``;
    # ``force=True`` replaces any handlers a previously imported
    # library may have installed.
    logging.basicConfig(level=logging.INFO, format="%(message)s", force=True)

    main()
learning_munsell/training/from_xyY/train_multi_head_gamma_sweep.py ADDED
@@ -0,0 +1,637 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Train multi-head ML models with various gamma values to find optimal gamma.
3
+
4
+ Sweeps gamma from 1.0 to 3.0 in increments of 0.1 and evaluates each model
5
+ on real Munsell colours using Delta-E CIE2000.
6
+
7
+ Supports parallel execution with multiple runs per gamma for averaging.
8
+ """
9
+
10
+ import logging
11
+ from concurrent.futures import ProcessPoolExecutor, as_completed
12
+ from typing import Any
13
+
14
+ import numpy as np
15
+ import torch
16
+ from colour import XYZ_to_Lab, xyY_to_XYZ
17
+ from colour.difference import delta_E_CIE2000
18
+ from colour.notation.datasets.munsell import MUNSELL_COLOURS_REAL
19
+ from colour.notation.munsell import (
20
+ CCS_ILLUMINANT_MUNSELL,
21
+ munsell_specification_to_xyY,
22
+ )
23
+ from numpy.typing import NDArray
24
+ from torch import nn, optim
25
+ from torch.utils.data import DataLoader, TensorDataset
26
+
27
+ from learning_munsell import PROJECT_ROOT
28
+ from learning_munsell.models.networks import MultiHeadMLPToMunsell
29
+ from learning_munsell.utilities.data import (
30
+ MUNSELL_NORMALIZATION_PARAMS,
31
+ normalize_munsell,
32
+ )
33
+
34
+ LOGGER = logging.getLogger(__name__)
35
+
36
+
37
def normalize_inputs(X: NDArray, gamma: float) -> tuple[NDArray, dict[str, Any]]:
    """
    Normalize xyY inputs to [0, 1] with a gamma-encoded Y channel.

    Parameters
    ----------
    X : ndarray
        xyY values of shape (n, 3), columns ordered as [x, y, Y].
    gamma : float
        Exponent denominator applied to Y, i.e. ``Y ** (1 / gamma)``.

    Returns
    -------
    ndarray
        Normalized values with gamma-encoded Y.
    dict
        Normalization parameters including the gamma value.
    """
    # Every channel nominally already lives in [0, 1].
    x_range = (0.0, 1.0)
    y_range = (0.0, 1.0)
    Y_range = (0.0, 1.0)

    result = X.copy()
    result[:, 0] = (X[:, 0] - x_range[0]) / (x_range[1] - x_range[0])
    result[:, 1] = (X[:, 1] - y_range[0]) / (y_range[1] - y_range[0])

    # Clip protects the fractional power from slightly negative inputs.
    clipped_Y = np.clip((X[:, 2] - Y_range[0]) / (Y_range[1] - Y_range[0]), 0, 1)
    result[:, 2] = np.power(clipped_Y, 1.0 / gamma)

    params = {
        "x_range": x_range,
        "y_range": y_range,
        "Y_range": Y_range,
        "gamma": gamma,
    }

    return result, params
75
+
76
+
77
def denormalize_output(y_norm: NDArray, params: dict[str, Any]) -> NDArray:
    """
    Map normalized Munsell output from [0, 1] back to its native ranges.

    Parameters
    ----------
    y_norm : ndarray
        Normalized Munsell values in [0, 1] range.
    params : dict
        Normalization parameters holding a ``(low, high)`` tuple for each
        component under the ``*_range`` keys.

    Returns
    -------
    ndarray
        Denormalized Munsell values in original ranges.
    """
    # Channel order matches the Munsell specification layout:
    # [hue, value, chroma, code].
    component_keys = ("hue_range", "value_range", "chroma_range", "code_range")

    y = np.copy(y_norm)
    for channel, key in enumerate(component_keys):
        low, high = params[key]
        y[..., channel] = y_norm[..., channel] * (high - low) + low

    return y
111
+
112
+
113
+ def weighted_mse_loss(
114
+ pred: torch.Tensor, target: torch.Tensor, weights: torch.Tensor = None
115
+ ) -> torch.Tensor:
116
+ """
117
+ Component-wise weighted MSE loss.
118
+
119
+ Parameters
120
+ ----------
121
+ pred : Tensor
122
+ Predicted Munsell values.
123
+ target : Tensor
124
+ Ground truth Munsell values.
125
+ weights : Tensor, optional
126
+ Component weights [w_hue, w_value, w_chroma, w_code].
127
+
128
+ Returns
129
+ -------
130
+ Tensor
131
+ Weighted mean squared error loss.
132
+ """
133
+ if weights is None:
134
+ weights = torch.tensor([1.0, 1.0, 3.0, 0.5], device=pred.device)
135
+ mse = (pred - target) ** 2
136
+ weighted_mse = mse * weights
137
+ return weighted_mse.mean()
138
+
139
+
140
def clamp_munsell_specification(spec: NDArray) -> NDArray:
    """
    Clamp a Munsell specification into its valid component ranges.

    Parameters
    ----------
    spec : ndarray
        Munsell specification [hue, value, chroma, code].

    Returns
    -------
    ndarray
        Clamped Munsell specification within valid ranges.
    """
    # (low, high) bounds per channel: hue, value, chroma, code.
    bounds = ((0.5, 10.0), (1.0, 9.0), (0.0, 50.0), (1.0, 10.0))

    clamped = np.copy(spec)
    for axis, (low, high) in enumerate(bounds):
        clamped[..., axis] = np.clip(spec[..., axis], low, high)

    return clamped
160
+
161
+
162
def compute_delta_e(pred: NDArray, reference_Lab: NDArray) -> list[float]:
    """
    Compute Delta-E CIE2000 for predicted Munsell specifications.

    Parameters
    ----------
    pred : ndarray
        Predicted Munsell specifications.
    reference_Lab : ndarray
        Reference CIELAB values for comparison.

    Returns
    -------
    list of float
        Delta-E CIE2000 values for valid predictions.

    Notes
    -----
    Predictions that cannot be converted to valid xyY are skipped.
    """
    delta_E_values: list[float] = []

    for idx, specification in enumerate(pred):
        try:
            clamped = clamp_munsell_specification(specification)
            conversion_spec = clamped.copy()
            # The hue-family code must be a whole number for the conversion.
            conversion_spec[3] = round(clamped[3])

            predicted_xyY = munsell_specification_to_xyY(conversion_spec)
            predicted_Lab = XYZ_to_Lab(
                xyY_to_XYZ(predicted_xyY), CCS_ILLUMINANT_MUNSELL
            )
            delta_E_values.append(delta_E_CIE2000(reference_Lab[idx], predicted_Lab))
        except (RuntimeError, ValueError):
            # Out-of-gamut predictions are not convertible; skip them.
            continue

    return delta_E_values
196
+
197
+
198
def train_model(
    gamma: float,
    X_train: NDArray,
    y_train: NDArray,
    X_val: NDArray,
    y_val: NDArray,
    device: torch.device,
    num_epochs: int = 100,
    patience: int = 15,
) -> tuple[nn.Module, dict[str, Any], dict[str, Any], float]:
    """
    Train a multi-head model with specified gamma value.

    Parameters
    ----------
    gamma : float
        Gamma value for Y correction.
    X_train : ndarray
        Training inputs (xyY values).
    y_train : ndarray
        Training targets (Munsell specifications).
    X_val : ndarray
        Validation inputs.
    y_val : ndarray
        Validation targets.
    device : torch.device
        Device to run training on.
    num_epochs : int, optional
        Maximum number of training epochs. Default is 100.
    patience : int, optional
        Early stopping patience. Default is 15.

    Returns
    -------
    nn.Module
        Trained model with best validation loss.
    dict
        Input normalization parameters.
    dict
        Output normalization parameters.
    float
        Best validation loss achieved.
    """
    # Normalize data
    X_train_norm, input_parameters = normalize_inputs(X_train, gamma=gamma)
    X_val_norm, _ = normalize_inputs(X_val, gamma=gamma)

    # Use shared normalization parameters covering the full Munsell
    # space for generalization
    output_parameters = MUNSELL_NORMALIZATION_PARAMS
    y_train_norm = normalize_munsell(y_train, output_parameters)
    y_val_norm = normalize_munsell(y_val, output_parameters)

    # Convert to tensors
    X_train_t = torch.FloatTensor(X_train_norm)
    y_train_t = torch.FloatTensor(y_train_norm)
    X_val_t = torch.FloatTensor(X_val_norm)
    y_val_t = torch.FloatTensor(y_val_norm)

    # Create data loaders
    train_dataset = TensorDataset(X_train_t, y_train_t)
    val_dataset = TensorDataset(X_val_t, y_val_t)

    train_loader = DataLoader(train_dataset, batch_size=1024, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=1024, shuffle=False)

    # Initialize model
    model = MultiHeadMLPToMunsell().to(device)
    optimizer = optim.Adam(model.parameters(), lr=5e-4)
    criterion = weighted_mse_loss

    best_val_loss = float("inf")
    patience_counter = 0
    best_state = None

    for _epoch in range(num_epochs):
        # Train
        model.train()
        for X_batch, y_batch in train_loader:
            X_batch = X_batch.to(device)  # noqa: PLW2901
            y_batch = y_batch.to(device)  # noqa: PLW2901

            outputs = model(X_batch)
            loss = criterion(outputs, y_batch)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Validate
        model.eval()
        total_val_loss = 0.0
        with torch.no_grad():
            for X_batch, y_batch in val_loader:
                X_batch = X_batch.to(device)  # noqa: PLW2901
                y_batch = y_batch.to(device)  # noqa: PLW2901
                outputs = model(X_batch)
                loss = criterion(outputs, y_batch)
                total_val_loss += loss.item()
        val_loss = total_val_loss / len(val_loader)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            # Snapshot with clone(): ``state_dict`` returns references to
            # the live parameter tensors, so a shallow ``.copy()`` would
            # keep mutating as training continues and the "best" state
            # would silently become the *last* state.
            best_state = {
                key: value.detach().clone()
                for key, value in model.state_dict().items()
            }
        else:
            patience_counter += 1
            if patience_counter >= patience:
                break

    # Load best state
    if best_state is not None:
        model.load_state_dict(best_state)

    return model, input_parameters, output_parameters, best_val_loss
313
+
314
+
315
def evaluate_on_real_munsell(
    model: nn.Module,
    input_parameters: dict[str, Any],
    output_parameters: dict[str, Any],
    xyY_array: NDArray,
    reference_Lab: NDArray,
    device: torch.device,
) -> tuple[float, float]:
    """
    Evaluate model on real Munsell colors using Delta-E CIE2000.

    Parameters
    ----------
    model : nn.Module
        Trained model to evaluate.
    input_parameters : dict
        Input normalization parameters.
    output_parameters : dict
        Output normalization parameters.
    xyY_array : ndarray
        Real Munsell xyY values.
    reference_Lab : ndarray
        Reference CIELAB values for Delta-E computation.
    device : torch.device
        Device to run evaluation on.

    Returns
    -------
    float
        Mean Delta-E CIE2000.
    float
        Median Delta-E CIE2000.
    """
    model.eval()

    # Re-apply the exact gamma the model was trained with.
    X_norm, _ = normalize_inputs(xyY_array, gamma=input_parameters["gamma"])
    inputs = torch.FloatTensor(X_norm).to(device)

    with torch.no_grad():
        predictions_norm = model(inputs).cpu().numpy()

    predictions = denormalize_output(predictions_norm, output_parameters)
    errors = compute_delta_e(predictions, reference_Lab)

    return np.mean(errors), np.median(errors)
363
+
364
+
365
def run_single_trial(
    gamma: float,
    run_id: int,
    X_train: NDArray,
    y_train: NDArray,
    X_val: NDArray,
    y_val: NDArray,
    xyY_array: NDArray,
    reference_Lab: NDArray,
) -> dict[str, Any]:
    """
    Run a single training trial for a given gamma value.

    Parameters
    ----------
    gamma : float
        Gamma value for Y correction.
    run_id : int
        Run identifier for this trial.
    X_train : ndarray
        Training inputs.
    y_train : ndarray
        Training targets.
    X_val : ndarray
        Validation inputs.
    y_val : ndarray
        Validation targets.
    xyY_array : ndarray
        Real Munsell xyY values for evaluation.
    reference_Lab : ndarray
        Reference CIELAB values for Delta-E computation.

    Returns
    -------
    dict
        Results dictionary containing gamma, run_id, val_loss,
        mean_delta_e, and median_delta_e.

    Notes
    -----
    Uses CPU to avoid MPS multiprocessing issues.
    """
    # Each worker trains on CPU: MPS does not cooperate with multiprocessing.
    device = torch.device("cpu")

    trained_model, input_parameters, output_parameters, val_loss = train_model(
        gamma=gamma,
        X_train=X_train,
        y_train=y_train,
        X_val=X_val,
        y_val=y_val,
        device=device,
        num_epochs=100,
        patience=15,
    )

    mean_delta_e, median_delta_e = evaluate_on_real_munsell(
        trained_model,
        input_parameters,
        output_parameters,
        xyY_array,
        reference_Lab,
        device,
    )

    return {
        "gamma": gamma,
        "run_id": run_id,
        "val_loss": val_loss,
        "mean_delta_e": mean_delta_e,
        "median_delta_e": median_delta_e,
    }
432
+
433
+
434
def main() -> None:
    """
    Run gamma sweep experiment to find optimal gamma value.

    Notes
    -----
    The training pipeline:
    1. Loads training and validation data from cache
    2. Loads real Munsell colors for evaluation
    3. Sweeps gamma values from 1.0 to 3.0 in 0.1 increments
    4. Trains multiple models per gamma value for averaging
    5. Evaluates each model on real Munsell colors using Delta-E CIE2000
    6. Aggregates results and identifies best gamma value
    7. Saves results to NPZ file for analysis

    Uses parallel execution with ProcessPoolExecutor for efficiency.
    Each model is trained with early stopping and evaluated on validation set.
    """
    import argparse  # noqa: PLC0415

    parser = argparse.ArgumentParser(description="Gamma sweep with averaging")
    parser.add_argument("--runs", type=int, default=3, help="Number of runs per gamma")
    parser.add_argument(
        "--workers", type=int, default=4, help="Number of parallel workers"
    )
    args = parser.parse_args()

    num_runs = args.runs
    num_workers = args.workers

    LOGGER.info("=" * 80)
    LOGGER.info("Multi-Head Gamma Sweep: Finding Optimal Gamma Value")
    LOGGER.info("Testing gamma values from 1.0 to 3.0 in increments of 0.1")
    LOGGER.info("Runs per gamma: %d, Parallel workers: %d", num_runs, num_workers)
    LOGGER.info("=" * 80)

    # Load training data
    data_dir = PROJECT_ROOT / "data"
    cache_file = data_dir / "training_data.npz"

    if not cache_file.exists():
        LOGGER.error("Error: Training data not found at %s", cache_file)
        return

    LOGGER.info("\nLoading training data...")
    data = np.load(cache_file)
    X_train = data["X_train"]
    y_train = data["y_train"]
    X_val = data["X_val"]
    y_val = data["y_val"]
    LOGGER.info("Train samples: %d, Validation samples: %d", len(X_train), len(X_val))

    # Load real Munsell data for evaluation
    LOGGER.info("Loading real Munsell colours for evaluation...")
    xyY_values = []
    reference_Lab = []

    for _munsell_spec_tuple, xyY in MUNSELL_COLOURS_REAL:
        try:
            # Dataset luminance is in [0, 100]; the models expect [0, 1].
            xyY_scaled = np.array([xyY[0], xyY[1], xyY[2] / 100.0])
            XYZ = xyY_to_XYZ(xyY_scaled)
            Lab = XYZ_to_Lab(XYZ, CCS_ILLUMINANT_MUNSELL)
            xyY_values.append(xyY_scaled)
            reference_Lab.append(Lab)
        except (RuntimeError, ValueError):
            continue

    xyY_array = np.array(xyY_values)
    reference_Lab = np.array(reference_Lab)
    LOGGER.info("Loaded %d real Munsell colours", len(xyY_array))

    # Gamma values to test
    gamma_values = [round(1.0 + i * 0.1, 1) for i in range(21)]  # 1.0 to 3.0

    # Create all tasks: (gamma, run_id) pairs
    tasks = [(gamma, run_id) for gamma in gamma_values for run_id in range(num_runs)]
    total_tasks = len(tasks)

    LOGGER.info("\n%s", "-" * 80)
    LOGGER.info(
        "Starting gamma sweep: %d total tasks (%d gamma values x %d runs)",
        total_tasks,
        len(gamma_values),
        num_runs,
    )
    LOGGER.info("-" * 80)

    all_results = []
    completed = 0

    with ProcessPoolExecutor(max_workers=num_workers) as executor:
        futures = {
            executor.submit(
                run_single_trial,
                gamma,
                run_id,
                X_train,
                y_train,
                X_val,
                y_val,
                xyY_array,
                reference_Lab,
            ): (gamma, run_id)
            for gamma, run_id in tasks
        }

        for future in as_completed(futures):
            gamma, run_id = futures[future]
            try:
                result = future.result()
                all_results.append(result)
                completed += 1
                LOGGER.info(
                    "[%3d/%3d] gamma=%.1f run=%d: mean_ΔE=%.4f, median_ΔE=%.4f",
                    completed,
                    total_tasks,
                    gamma,
                    run_id,
                    result["mean_delta_e"],
                    result["median_delta_e"],
                )
            except Exception:
                LOGGER.exception(
                    "Task failed for gamma=%.1f run=%d",
                    gamma,
                    run_id,
                )
                completed += 1

    # Aggregate results by gamma (average across runs)
    aggregated = {}
    for r in all_results:
        gamma = r["gamma"]
        if gamma not in aggregated:
            aggregated[gamma] = {"val_losses": [], "means": [], "medians": []}
        aggregated[gamma]["val_losses"].append(r["val_loss"])
        aggregated[gamma]["means"].append(r["mean_delta_e"])
        aggregated[gamma]["medians"].append(r["median_delta_e"])

    results = []
    for gamma in sorted(aggregated.keys()):
        agg = aggregated[gamma]
        results.append(
            {
                "gamma": gamma,
                "val_loss": np.mean(agg["val_losses"]),
                "val_loss_std": np.std(agg["val_losses"]),
                "mean_delta_e": np.mean(agg["means"]),
                "mean_delta_e_std": np.std(agg["means"]),
                "median_delta_e": np.mean(agg["medians"]),
                "median_delta_e_std": np.std(agg["medians"]),
                "num_runs": len(agg["means"]),
            }
        )

    # Print results
    LOGGER.info("\n%s", "=" * 80)
    LOGGER.info("GAMMA SWEEP RESULTS (averaged over %d runs)", num_runs)
    LOGGER.info("=" * 80)
    LOGGER.info("")
    LOGGER.info("%-8s %-14s %-14s %-14s", "Gamma", "Val Loss", "Mean ΔE", "Median ΔE")
    LOGGER.info("-" * 50)

    for r in results:
        LOGGER.info(
            "%-8.1f %-14s %-14s %-14s",
            r["gamma"],
            f"{r['val_loss']:.6f}±{r['val_loss_std']:.4f}",
            f"{r['mean_delta_e']:.4f}±{r['mean_delta_e_std']:.4f}",
            f"{r['median_delta_e']:.4f}±{r['median_delta_e_std']:.4f}",
        )

    # Find best by mean Delta-E
    best_by_mean = min(results, key=lambda x: x["mean_delta_e"])
    best_by_median = min(results, key=lambda x: x["median_delta_e"])

    LOGGER.info("")
    LOGGER.info(
        "Best gamma by MEAN Delta-E: %.1f (ΔE = %.4f ± %.4f)",
        best_by_mean["gamma"],
        best_by_mean["mean_delta_e"],
        best_by_mean["mean_delta_e_std"],
    )
    LOGGER.info(
        "Best gamma by MEDIAN Delta-E: %.1f (ΔE = %.4f ± %.4f)",
        best_by_median["gamma"],
        best_by_median["median_delta_e"],
        best_by_median["median_delta_e_std"],
    )

    # Save results
    results_file = (
        PROJECT_ROOT / "models" / "from_xyY" / "gamma_sweep_results_averaged.npz"
    )
    # Nothing earlier in this script creates the output directory and
    # ``np.savez`` does not create parent directories, so ensure it exists.
    results_file.parent.mkdir(parents=True, exist_ok=True)
    np.savez(results_file, results=results, all_results=all_results)
    LOGGER.info("\nResults saved to: %s", results_file)

    LOGGER.info("\n%s", "=" * 80)
632
+
633
+
634
if __name__ == "__main__":
    # Configure root logging before launching the sweep; ``force=True``
    # replaces any handlers a previously imported library may have
    # installed.
    logging.basicConfig(level=logging.INFO, format="%(message)s", force=True)

    main()
learning_munsell/training/from_xyY/train_multi_head_large.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Train multi-head ML model on large dataset (2M samples) for xyY to Munsell conversion.
3
+
4
+ This script trains on the larger dataset for potentially improved accuracy.
5
+ Uses the same architecture as train_multi_head_mlp.py but with the large dataset.
6
+ """
7
+
8
+ import logging
9
+
10
+ import click
11
+ import mlflow
12
+ import mlflow.pytorch
13
+ import numpy as np
14
+ import torch
15
+ from torch import optim
16
+ from torch.utils.data import DataLoader, TensorDataset
17
+
18
+ from learning_munsell import PROJECT_ROOT
19
+ from learning_munsell.models.networks import MultiHeadMLPToMunsell
20
+ from learning_munsell.utilities.common import (
21
+ log_training_epoch,
22
+ setup_mlflow_experiment,
23
+ )
24
+ from learning_munsell.utilities.data import (
25
+ MUNSELL_NORMALIZATION_PARAMS,
26
+ XYY_NORMALIZATION_PARAMS,
27
+ normalize_munsell,
28
+ )
29
+ from learning_munsell.utilities.losses import weighted_mse_loss
30
+ from learning_munsell.utilities.training import train_epoch, validate
31
+
32
+ LOGGER = logging.getLogger(__name__)
33
+
34
+
35
@click.command()
@click.option("--epochs", default=300, help="Number of training epochs")
@click.option("--batch-size", default=2048, help="Batch size for training")
@click.option("--lr", default=5e-4, help="Learning rate")
@click.option("--patience", default=30, help="Early stopping patience")
def main(epochs: int, batch_size: int, lr: float, patience: int) -> None:
    """
    Train multi-head MLP on large dataset (2M samples) for xyY to Munsell.

    Notes
    -----
    The training pipeline:

    1.  Loads training and validation data from the large cached .npz file.
    2.  Normalizes Munsell outputs to [0, 1] (xyY inputs are already [0, 1]).
    3.  Creates a multi-head MLP with shared encoder and per-component heads.
    4.  Trains with a weighted MSE loss.
    5.  Uses the Adam optimizer with a ReduceLROnPlateau scheduler.
    6.  Applies early stopping based on validation loss (``patience`` epochs).
    7.  Exports the best model to ONNX format.
    8.  Logs metrics and artifacts to MLflow.
    """

    LOGGER.info("=" * 80)
    LOGGER.info("Multi-Head Model Training on Large Dataset (2M samples)")
    LOGGER.info("=" * 80)

    # Prefer CUDA, then Apple MPS, falling back to CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if torch.backends.mps.is_available():
        device = torch.device("mps")
    LOGGER.info("Using device: %s", device)

    # Load the large cached training data.
    data_dir = PROJECT_ROOT / "data"
    cache_file = data_dir / "training_data_large.npz"

    if not cache_file.exists():
        LOGGER.error("Error: Large training data not found at %s", cache_file)
        LOGGER.error("Please run generate_large_training_data.py first")
        return

    LOGGER.info("Loading large training data from %s...", cache_file)
    data = np.load(cache_file)

    X_train = data["X_train"]
    y_train = data["y_train"]
    X_val = data["X_val"]
    y_val = data["y_val"]

    LOGGER.info("Train samples: %d", len(X_train))
    LOGGER.info("Validation samples: %d", len(X_val))

    # Normalize outputs (xyY inputs are already in [0, 1] range). Shared
    # normalization parameters cover the full Munsell space so the model
    # generalizes beyond the training gamut.
    output_parameters = MUNSELL_NORMALIZATION_PARAMS
    y_train_norm = normalize_munsell(y_train, output_parameters)
    y_val_norm = normalize_munsell(y_val, output_parameters)

    # Convert to PyTorch tensors.
    X_train_t = torch.FloatTensor(X_train)
    y_train_t = torch.FloatTensor(y_train_norm)
    X_val_t = torch.FloatTensor(X_val)
    y_val_t = torch.FloatTensor(y_val_norm)

    # Create data loaders (larger batch size for the larger dataset).
    train_dataset = TensorDataset(X_train_t, y_train_t)
    val_dataset = TensorDataset(X_val_t, y_val_t)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Initialize model.
    model = MultiHeadMLPToMunsell().to(device)
    LOGGER.info("")
    LOGGER.info("Model architecture:")
    LOGGER.info("%s", model)

    total_params = sum(p.numel() for p in model.parameters())
    LOGGER.info("Total parameters: %s", f"{total_params:,}")

    # Training setup.
    # NOTE(review): the loss is used with its default weights here — confirm
    # they emphasise chroma as the docstring states.
    learning_rate = lr
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=10
    )
    criterion = weighted_mse_loss

    # Define the checkpoint location up-front so the post-training export can
    # always reference it, and create parent directories as needed
    # (`exist_ok` alone fails when `models/` itself is missing).
    model_directory = PROJECT_ROOT / "models" / "from_xyY"
    model_directory.mkdir(parents=True, exist_ok=True)
    checkpoint_file = model_directory / "multi_head_large_best.pth"

    # MLflow setup.
    run_name = setup_mlflow_experiment("from_xyY", "multi_head_large")

    LOGGER.info("")
    LOGGER.info("MLflow run: %s", run_name)

    # Training loop with early stopping.
    best_val_loss = float("inf")
    patience_counter = 0

    LOGGER.info("")
    LOGGER.info("Starting training...")

    with mlflow.start_run(run_name=run_name):
        mlflow.log_params(
            {
                "model": "multi_head_large",
                "learning_rate": learning_rate,
                "batch_size": batch_size,
                "num_epochs": epochs,
                "patience": patience,
                "total_params": total_params,
                "train_samples": len(X_train),
                "val_samples": len(X_val),
                "dataset": "large_2M",
            }
        )

        for epoch in range(epochs):
            train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
            val_loss = validate(model, val_loader, criterion, device)

            scheduler.step(val_loss)

            log_training_epoch(
                epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"]
            )

            LOGGER.info(
                "Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f, LR: %.6f",
                epoch + 1,
                epochs,
                train_loss,
                val_loss,
                optimizer.param_groups[0]["lr"],
            )

            if val_loss < best_val_loss:
                # New best model: reset patience and checkpoint everything
                # needed to reload it (weights + normalization parameters).
                best_val_loss = val_loss
                patience_counter = 0

                torch.save(
                    {
                        "model_state_dict": model.state_dict(),
                        "output_parameters": output_parameters,
                        "epoch": epoch,
                        "val_loss": val_loss,
                    },
                    checkpoint_file,
                )

                LOGGER.info("  → Saved best model (val_loss: %.6f)", val_loss)
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    LOGGER.info("")
                    LOGGER.info("Early stopping after %d epochs", epoch + 1)
                    break

        mlflow.log_metrics(
            {
                "best_val_loss": best_val_loss,
                "final_epoch": epoch + 1,
            }
        )

        # Export the best checkpoint to ONNX.
        LOGGER.info("")
        LOGGER.info("Exporting model to ONNX...")
        model.eval()

        # The checkpoint contains non-tensor objects (normalization
        # parameters), so the full pickle must be loaded; this is a trusted
        # local file written above.
        checkpoint = torch.load(checkpoint_file, weights_only=False)
        model.load_state_dict(checkpoint["model_state_dict"])

        dummy_input = torch.randn(1, 3).to(device)

        onnx_file = model_directory / "multi_head_large.onnx"
        torch.onnx.export(
            model,
            dummy_input,
            onnx_file,
            export_params=True,
            opset_version=15,
            input_names=["xyY"],
            output_names=["munsell_spec"],
            dynamic_axes={"xyY": {0: "batch_size"}, "munsell_spec": {0: "batch_size"}},
        )

        # Save normalization parameters alongside the model so inference can
        # reproduce the exact input/output scaling.
        params_file = model_directory / "multi_head_large_normalization_parameters.npz"
        input_parameters = XYY_NORMALIZATION_PARAMS
        np.savez(
            params_file,
            input_parameters=input_parameters,
            output_parameters=output_parameters,
        )

        mlflow.log_artifact(str(checkpoint_file))
        mlflow.log_artifact(str(onnx_file))
        mlflow.log_artifact(str(params_file))
        mlflow.pytorch.log_model(model, "model")

        LOGGER.info("ONNX model saved to: %s", onnx_file)
        LOGGER.info("Normalization parameters saved to: %s", params_file)
        LOGGER.info("Artifacts logged to MLflow")

    LOGGER.info("=" * 80)
244
+
245
+
246
if __name__ == "__main__":
    # Plain-message logging; ``force=True`` replaces any handlers that
    # imported libraries may already have installed.
    logging.basicConfig(force=True, format="%(message)s", level=logging.INFO)

    main()
learning_munsell/training/from_xyY/train_multi_head_mlp.py ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Train multi-head ML model for xyY to Munsell conversion.
3
+
4
+ Architecture:
5
+ - Shared encoder: 3 inputs → 512-dim features
6
+ - 4 separate heads (one per component):
7
+ - Hue head (circular/angular)
8
+ - Value head (linear lightness)
9
+ - Chroma head (non-linear saturation - larger capacity)
10
+ - Code head (discrete categorical)
11
+
12
+ This architecture allows each component to learn specialized features
13
+ while sharing the general color space understanding.
14
+ """
15
+
16
+ import logging
17
+
18
+ import click
19
+ import mlflow
20
+ import mlflow.pytorch
21
+ import numpy as np
22
+ import torch
23
+ from torch import optim
24
+ from torch.utils.data import DataLoader, TensorDataset
25
+
26
+ from learning_munsell import PROJECT_ROOT
27
+ from learning_munsell.models.networks import MultiHeadMLPToMunsell
28
+ from learning_munsell.utilities.common import (
29
+ log_training_epoch,
30
+ setup_mlflow_experiment,
31
+ )
32
+ from learning_munsell.utilities.data import (
33
+ MUNSELL_NORMALIZATION_PARAMS,
34
+ XYY_NORMALIZATION_PARAMS,
35
+ normalize_munsell,
36
+ )
37
+ from learning_munsell.utilities.losses import weighted_mse_loss
38
+ from learning_munsell.utilities.training import train_epoch, validate
39
+
40
+ LOGGER = logging.getLogger(__name__)
41
+
42
+
43
@click.command()
@click.option("--epochs", default=300, help="Number of training epochs")
@click.option("--batch-size", default=1024, help="Batch size for training")
@click.option("--lr", default=5e-4, help="Learning rate")
@click.option("--patience", default=20, help="Early stopping patience")
def main(epochs: int, batch_size: int, lr: float, patience: int) -> None:
    """
    Train multi-head MLP for xyY to Munsell conversion.

    Notes
    -----
    The training pipeline:

    1.  Loads training and validation data from the cached .npz file.
    2.  Normalizes Munsell outputs to [0, 1] (xyY inputs are already [0, 1]).
    3.  Creates a multi-head MLP with shared encoder and per-component heads.
    4.  Trains with a weighted MSE loss (weights emphasise chroma).
    5.  Uses the Adam optimizer with no learning rate scheduling.
    6.  Applies early stopping based on validation loss (``patience`` epochs).
    7.  Exports the best model to ONNX format.
    8.  Logs metrics and artifacts to MLflow.
    """

    LOGGER.info("=" * 80)
    LOGGER.info("ML-Based xyY to Munsell Conversion: Multi-Head Model Training")
    LOGGER.info("=" * 80)

    # Set device.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    LOGGER.info("Using device: %s", device)

    # Load cached training data.
    data_dir = PROJECT_ROOT / "data"
    cache_file = data_dir / "training_data.npz"

    if not cache_file.exists():
        LOGGER.error("Error: Training data not found at %s", cache_file)
        LOGGER.error("Please run 01_generate_training_data.py first")
        return

    LOGGER.info("Loading training data from %s...", cache_file)
    data = np.load(cache_file)

    X_train = data["X_train"]
    y_train = data["y_train"]
    X_val = data["X_val"]
    y_val = data["y_val"]

    LOGGER.info("Train samples: %d", len(X_train))
    LOGGER.info("Validation samples: %d", len(X_val))

    # Normalize outputs (xyY inputs are already in [0, 1] range). Shared
    # normalization parameters cover the full Munsell space so the model
    # generalizes beyond the training gamut.
    output_parameters = MUNSELL_NORMALIZATION_PARAMS
    y_train_norm = normalize_munsell(y_train, output_parameters)
    y_val_norm = normalize_munsell(y_val, output_parameters)

    # Convert to PyTorch tensors.
    X_train_t = torch.FloatTensor(X_train)
    y_train_t = torch.FloatTensor(y_train_norm)
    X_val_t = torch.FloatTensor(X_val)
    y_val_t = torch.FloatTensor(y_val_norm)

    # Create data loaders.
    train_dataset = TensorDataset(X_train_t, y_train_t)
    val_dataset = TensorDataset(X_val_t, y_val_t)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Initialize model.
    model = MultiHeadMLPToMunsell().to(device)
    LOGGER.info("")
    LOGGER.info("Model architecture:")
    LOGGER.info("%s", model)

    # Count parameters (total and per component head).
    total_params = sum(p.numel() for p in model.parameters())
    LOGGER.info("Total parameters: %s", f"{total_params:,}")

    encoder_params = sum(p.numel() for p in model.encoder.parameters())
    hue_params = sum(p.numel() for p in model.hue_head.parameters())
    value_params = sum(p.numel() for p in model.value_head.parameters())
    chroma_params = sum(p.numel() for p in model.chroma_head.parameters())
    code_params = sum(p.numel() for p in model.code_head.parameters())

    LOGGER.info("  - Shared encoder: %s", f"{encoder_params:,}")
    LOGGER.info("  - Hue head: %s", f"{hue_params:,}")
    LOGGER.info("  - Value head: %s", f"{value_params:,}")
    LOGGER.info("  - Chroma head: %s (WIDER)", f"{chroma_params:,}")
    LOGGER.info("  - Code head: %s", f"{code_params:,}")

    # Training setup.
    optimizer = optim.Adam(model.parameters(), lr=lr)
    # Weighted MSE: per-component weights [hue, value, chroma, code] — the
    # chroma component is up-weighted, the discrete code down-weighted.
    weights = torch.tensor([1.0, 1.0, 3.0, 0.5])

    def criterion(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        return weighted_mse_loss(pred, target, weights)

    # Define the checkpoint location up-front so the post-training export can
    # always reference it, and create parent directories as needed
    # (`exist_ok` alone fails when `models/` itself is missing).
    model_directory = PROJECT_ROOT / "models" / "from_xyY"
    model_directory.mkdir(parents=True, exist_ok=True)
    checkpoint_file = model_directory / "multi_head_best.pth"

    # MLflow setup.
    run_name = setup_mlflow_experiment("from_xyY", "multi_head")

    LOGGER.info("")
    LOGGER.info("MLflow run: %s", run_name)

    # Training loop with early stopping.
    best_val_loss = float("inf")
    patience_counter = 0

    LOGGER.info("")
    LOGGER.info("Starting training...")

    with mlflow.start_run(run_name=run_name):
        # Log parameters.
        mlflow.log_params(
            {
                "model": "multi_head",
                "learning_rate": lr,
                "batch_size": batch_size,
                "num_epochs": epochs,
                "patience": patience,
                "total_params": total_params,
            }
        )

        for epoch in range(epochs):
            train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
            val_loss = validate(model, val_loader, criterion, device)

            # Log to MLflow.
            log_training_epoch(
                epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"]
            )

            LOGGER.info(
                "Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f",
                epoch + 1,
                epochs,
                train_loss,
                val_loss,
            )

            # Early stopping bookkeeping.
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0

                # Checkpoint everything needed to reload the best model
                # (weights + normalization parameters).
                torch.save(
                    {
                        "model_state_dict": model.state_dict(),
                        "output_parameters": output_parameters,
                        "epoch": epoch,
                        "val_loss": val_loss,
                    },
                    checkpoint_file,
                )

                LOGGER.info("  → Saved best model (val_loss: %.6f)", val_loss)
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    LOGGER.info("")
                    LOGGER.info("Early stopping after %d epochs", epoch + 1)
                    break

        # Log final metrics.
        mlflow.log_metrics(
            {
                "best_val_loss": best_val_loss,
                "final_epoch": epoch + 1,
            }
        )

        # Export the best checkpoint to ONNX.
        LOGGER.info("")
        LOGGER.info("Exporting model to ONNX...")
        model.eval()

        # The checkpoint contains non-tensor objects (normalization
        # parameters), so the full pickle must be loaded explicitly:
        # PyTorch >= 2.6 defaults to ``weights_only=True`` which would
        # reject it. This is a trusted local file written above.
        checkpoint = torch.load(checkpoint_file, weights_only=False)
        model.load_state_dict(checkpoint["model_state_dict"])

        # Create dummy input for tracing.
        dummy_input = torch.randn(1, 3).to(device)

        # Export.
        onnx_file = model_directory / "multi_head.onnx"
        torch.onnx.export(
            model,
            dummy_input,
            onnx_file,
            export_params=True,
            opset_version=15,
            input_names=["xyY"],
            output_names=["munsell_spec"],
            dynamic_axes={"xyY": {0: "batch_size"}, "munsell_spec": {0: "batch_size"}},
        )

        # Save normalization parameters alongside the model so inference can
        # reproduce the exact input/output scaling.
        params_file = model_directory / "multi_head_normalization_parameters.npz"
        input_parameters = XYY_NORMALIZATION_PARAMS
        np.savez(
            params_file,
            input_parameters=input_parameters,
            output_parameters=output_parameters,
        )

        # Log artifacts to MLflow.
        mlflow.log_artifact(str(checkpoint_file))
        mlflow.log_artifact(str(onnx_file))
        mlflow.log_artifact(str(params_file))
        mlflow.pytorch.log_model(model, "model")

        LOGGER.info("ONNX model saved to: %s", onnx_file)
        LOGGER.info("Normalization parameters saved to: %s", params_file)
        LOGGER.info("Artifacts logged to MLflow")

    LOGGER.info("=" * 80)
268
+
269
+
270
if __name__ == "__main__":
    # Plain-message logging; ``force=True`` replaces any handlers that
    # imported libraries may already have installed.
    logging.basicConfig(force=True, format="%(message)s", level=logging.INFO)

    main()